diff --git a/mlModelSaver/__init__.py b/mlModelSaver/__init__.py index bb06dcc..54ad448 100644 --- a/mlModelSaver/__init__.py +++ b/mlModelSaver/__init__.py @@ -51,6 +51,10 @@ supportedModels = { "supported": True, "normalPredictorFunction": "predict_proba" }, + "sklearn.tree.DecisionTreeRegressor": { + "supported": True, + "normalPredictorFunction": "predict" + }, } supportedDataType = { diff --git a/pytests/test_mlModelSaver.py b/pytests/test_mlModelSaver.py index 2248d2c..3cc8c20 100644 --- a/pytests/test_mlModelSaver.py +++ b/pytests/test_mlModelSaver.py @@ -27,7 +27,8 @@ def test_ensureCLassInstance(): 'sm.OLS', 'sm.Logit', 'sklearn.neighbors.KNeighborsClassifier', - 'sklearn.tree.DecisionTreeClassifier' + 'sklearn.tree.DecisionTreeClassifier', + 'sklearn.tree.DecisionTreeRegressor', ]