model-registry/mlModelSaver/mlModelSaver.py

67 lines
2.3 KiB
Python

import pickle
import json
supportedModels = {
"sm.OLS": {
"supported": True
}
}
supportedDataType = {
"int": {
"supported": True
},
"float": {
"supported": True
},
"binary":{
"supported": True
}
}
class MlModelSaver:
def __init__(self, config):
self.baseRelativePath = config.get('baseRelativePath', '.')
self.modelsFolder = config.get('modelsFolder', '~~modelsFolder')
def showSupportedModels(self):
supported_keys = [key for key, value in supportedModels.items() if value.get('supported')]
return supported_keys
def exportModel(self, model, config):
model.mlModelSaverConfig = config
isModelSupporter = supportedModels.get(
config.get("modelType", ''),
{}
).get("supported", False)
if not isModelSupporter:
raise ValueError(f'only {self.showSupportedModels()} are supported and {config.get("modelType", '')} is not supported')
print(model.mlModelSaverConfig)
# modelName = config['modelName']
# modelsConfig[modelName] = {}
# modelsConfig[modelName]['name'] = modelName
# model = config['model']
# inputs = config['inputs']
# output = config['output']
# transformers = config.get('transformers', [])
# description = config['description']
# modelsConfig[modelName]['description'] = description
# modelsConfig[modelName]['inputs'] = inputs
# if len(transformers) > 0:
# modelsConfig[modelName]['transformers'] = transformers
# modelsConfig[modelName]['output'] = output
# modelType = config.get('modelType', '')
# modelsConfig[modelName]['modelType'] = modelType
# if hasattr(model, 'customMetrics'):
# customMetrics = model.customMetrics
# modelsConfig[modelName]['customMetrics'] = customMetrics
# else:
# pass
# filename = f'{baseRelativePath}/models/{modelName}'
# pickle.dump(model, open(filename, 'wb'))
# with open(f'{baseRelativePath}/models/configs.json', "w") as outputFile:
# json.dump(modelsConfig, outputFile, indent = 4)
# loaded_model = pickle.load(open(filename, 'rb'))
# return loaded_model