Support sklearn.neighbors.KNeighborsClassifier
This commit is contained in:
parent
1ddae351a3
commit
1978b97b49
@ -4,6 +4,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
def ensure_directory_exists(directory_path):
|
def ensure_directory_exists(directory_path):
|
||||||
"""
|
"""
|
||||||
@ -35,11 +36,17 @@ def check_file_exists(file_path):
|
|||||||
|
|
||||||
supportedModels = {
|
supportedModels = {
|
||||||
"sm.OLS": {
|
"sm.OLS": {
|
||||||
"supported": True
|
"supported": True,
|
||||||
|
"normalPredictorFunction": "predict"
|
||||||
},
|
},
|
||||||
"sm.Logit": {
|
"sm.Logit": {
|
||||||
"supported": True
|
"supported": True,
|
||||||
}
|
"normalPredictorFunction": "predict"
|
||||||
|
},
|
||||||
|
"sklearn.neighbors.KNeighborsClassifier": {
|
||||||
|
"supported": True,
|
||||||
|
"normalPredictorFunction": "predict_proba"
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
supportedDataType = {
|
supportedDataType = {
|
||||||
@ -66,13 +73,26 @@ def mlModelSavePredict(self, df, typeOfPredict = 'normal'):
|
|||||||
output = []
|
output = []
|
||||||
outputsName = self.mlModelSaverConfig.get("outputs", [{"name": "result"}])
|
outputsName = self.mlModelSaverConfig.get("outputs", [{"name": "result"}])
|
||||||
outputsName = [item["name"] for item in outputsName]
|
outputsName = [item["name"] for item in outputsName]
|
||||||
|
modelType = self.mlModelSaverConfig['modelType']
|
||||||
|
modelTypeConfig = supportedModels[modelType]
|
||||||
if typeOfPredict == 'normal':
|
if typeOfPredict == 'normal':
|
||||||
results = self.predict(dfAfterTransformation)
|
results = getattr(self, modelTypeConfig['normalPredictorFunction'])(dfAfterTransformation)
|
||||||
|
print(results)
|
||||||
for value in results:
|
for value in results:
|
||||||
output.append({
|
if isinstance(value, np.ndarray):
|
||||||
outputsName[0]: value,
|
res = {}
|
||||||
})
|
for index, val in enumerate(value):
|
||||||
|
res[outputsName[index]] = val
|
||||||
|
output.append(res)
|
||||||
|
else:
|
||||||
|
output.append({
|
||||||
|
outputsName[0]: value,
|
||||||
|
})
|
||||||
|
print(type(value))
|
||||||
|
print(value)
|
||||||
|
print(outputsName)
|
||||||
return output
|
return output
|
||||||
|
return output
|
||||||
|
|
||||||
class MlModelSaver:
|
class MlModelSaver:
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user