Support sklearn.neighbors.KNeighborsClassifier

This commit is contained in:
Jason Jafari 2024-06-16 13:12:16 -04:00
parent 1ddae351a3
commit 1978b97b49

View File

@ -4,6 +4,7 @@ import json
import os
from functools import partial
import numpy as np
def ensure_directory_exists(directory_path):
"""
@ -35,11 +36,17 @@ def check_file_exists(file_path):
supportedModels = {
"sm.OLS": {
"supported": True
"supported": True,
"normalPredictorFunction": "predict"
},
"sm.Logit": {
"supported": True
}
"supported": True,
"normalPredictorFunction": "predict"
},
"sklearn.neighbors.KNeighborsClassifier": {
"supported": True,
"normalPredictorFunction": "predict_proba"
},
}
supportedDataType = {
@ -66,13 +73,26 @@ def mlModelSavePredict(self, df, typeOfPredict = 'normal'):
output = []
outputsName = self.mlModelSaverConfig.get("outputs", [{"name": "result"}])
outputsName = [item["name"] for item in outputsName]
modelType = self.mlModelSaverConfig['modelType']
modelTypeConfig = supportedModels[modelType]
if typeOfPredict == 'normal':
results = self.predict(dfAfterTransformation)
results = getattr(self, modelTypeConfig['normalPredictorFunction'])(dfAfterTransformation)
print(results)
for value in results:
output.append({
outputsName[0]: value,
})
if isinstance(value, np.ndarray):
res = {}
for index, val in enumerate(value):
res[outputsName[index]] = val
output.append(res)
else:
output.append({
outputsName[0]: value,
})
print(type(value))
print(value)
print(outputsName)
return output
return output
class MlModelSaver: