Support sklearn.neighbors.KNeighborsClassifier
This commit is contained in:
parent
1ddae351a3
commit
1978b97b49
@ -4,6 +4,7 @@ import json
|
||||
import os
|
||||
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
|
||||
def ensure_directory_exists(directory_path):
|
||||
"""
|
||||
@ -35,11 +36,17 @@ def check_file_exists(file_path):
|
||||
|
||||
supportedModels = {
|
||||
"sm.OLS": {
|
||||
"supported": True
|
||||
"supported": True,
|
||||
"normalPredictorFunction": "predict"
|
||||
},
|
||||
"sm.Logit": {
|
||||
"supported": True
|
||||
}
|
||||
"supported": True,
|
||||
"normalPredictorFunction": "predict"
|
||||
},
|
||||
"sklearn.neighbors.KNeighborsClassifier": {
|
||||
"supported": True,
|
||||
"normalPredictorFunction": "predict_proba"
|
||||
},
|
||||
}
|
||||
|
||||
supportedDataType = {
|
||||
@ -66,13 +73,26 @@ def mlModelSavePredict(self, df, typeOfPredict = 'normal'):
|
||||
output = []
|
||||
outputsName = self.mlModelSaverConfig.get("outputs", [{"name": "result"}])
|
||||
outputsName = [item["name"] for item in outputsName]
|
||||
modelType = self.mlModelSaverConfig['modelType']
|
||||
modelTypeConfig = supportedModels[modelType]
|
||||
if typeOfPredict == 'normal':
|
||||
results = self.predict(dfAfterTransformation)
|
||||
results = getattr(self, modelTypeConfig['normalPredictorFunction'])(dfAfterTransformation)
|
||||
print(results)
|
||||
for value in results:
|
||||
output.append({
|
||||
outputsName[0]: value,
|
||||
})
|
||||
if isinstance(value, np.ndarray):
|
||||
res = {}
|
||||
for index, val in enumerate(value):
|
||||
res[outputsName[index]] = val
|
||||
output.append(res)
|
||||
else:
|
||||
output.append({
|
||||
outputsName[0]: value,
|
||||
})
|
||||
print(type(value))
|
||||
print(value)
|
||||
print(outputsName)
|
||||
return output
|
||||
return output
|
||||
|
||||
class MlModelSaver:
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user