diff --git a/mlModelSaver/__init__.py b/mlModelSaver/__init__.py index a313ae1..d2444cb 100644 --- a/mlModelSaver/__init__.py +++ b/mlModelSaver/__init__.py @@ -4,6 +4,7 @@ import json import os from functools import partial +import numpy as np def ensure_directory_exists(directory_path): """ @@ -35,11 +36,17 @@ def check_file_exists(file_path): supportedModels = { "sm.OLS": { - "supported": True + "supported": True, + "normalPredictorFunction": "predict" }, "sm.Logit": { - "supported": True - } + "supported": True, + "normalPredictorFunction": "predict" + }, + "sklearn.neighbors.KNeighborsClassifier": { + "supported": True, + "normalPredictorFunction": "predict_proba" + }, } supportedDataType = { @@ -66,13 +73,26 @@ def mlModelSavePredict(self, df, typeOfPredict = 'normal'): output = [] outputsName = self.mlModelSaverConfig.get("outputs", [{"name": "result"}]) outputsName = [item["name"] for item in outputsName] + modelType = self.mlModelSaverConfig['modelType'] + modelTypeConfig = supportedModels[modelType] if typeOfPredict == 'normal': - results = self.predict(dfAfterTransformation) + results = getattr(self, modelTypeConfig['normalPredictorFunction'])(dfAfterTransformation) + print(results) for value in results: - output.append({ - outputsName[0]: value, - }) + if isinstance(value, np.ndarray): + res = {} + for index, val in enumerate(value): + res[outputsName[index]] = val + output.append(res) + else: + output.append({ + outputsName[0]: value, + }) + print(type(value)) + print(value) + print(outputsName) return output + return output class MlModelSaver: