Source code for schrodinger.application.matsci.mlearn.sklearn_json

"""
# Third-party code. No Schrodinger Copyright.
"""

import json

from sklearn import discriminant_analysis
from sklearn import dummy  # noqa: F401
from sklearn import svm
from sklearn.decomposition import PCA
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import _gb_losses  # noqa: F401
from sklearn.linear_model import Lasso
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import Perceptron
from sklearn.linear_model import Ridge
from sklearn.naive_bayes import BernoulliNB
from sklearn.naive_bayes import ComplementNB
from sklearn.naive_bayes import GaussianNB
from sklearn.naive_bayes import MultinomialNB
from sklearn.neural_network import MLPClassifier
from sklearn.neural_network import MLPRegressor
from sklearn.svm import SVR
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import DecisionTreeRegressor

from . import classification as clf
from . import decomposition as dcp
from . import regression as reg

__version__ = '0.1.0'


def serialize_model(model):
    """
    Serialize a supported scikit-learn model with the matching helper.

    The model is tested with `isinstance` against each supported estimator
    class in a fixed order; the first class that matches decides which
    serializer from the classification, regression or decomposition module
    is called.

    :param model: the estimator instance to serialize

    :return: whatever the selected serializer returns for ``model``

    :raise ModellNotSupported: if ``model`` is not an instance of any of the
        supported estimator classes
    """
    # Order matters: the first isinstance() hit wins, exactly as in an
    # if/elif chain.
    serializers = (
        # Classifiers
        (LogisticRegression, clf.serialize_logistic_regression),
        (BernoulliNB, clf.serialize_bernoulli_nb),
        (GaussianNB, clf.serialize_gaussian_nb),
        (MultinomialNB, clf.serialize_multinomial_nb),
        (ComplementNB, clf.serialize_complement_nb),
        (discriminant_analysis.LinearDiscriminantAnalysis,
         clf.serialize_lda),
        (discriminant_analysis.QuadraticDiscriminantAnalysis,
         clf.serialize_qda),
        (svm.SVC, clf.serialize_svm),
        (Perceptron, clf.serialize_perceptron),
        (DecisionTreeClassifier, clf.serialize_decision_tree),
        (GradientBoostingClassifier, clf.serialize_gradient_boosting),
        (RandomForestClassifier, clf.serialize_random_forest),
        (MLPClassifier, clf.serialize_mlp),
        # Regressors
        (LinearRegression, reg.serialize_linear_regressor),
        (Lasso, reg.serialize_lasso_regressor),
        (Ridge, reg.serialize_ridge_regressor),
        (SVR, reg.serialize_svr),
        (DecisionTreeRegressor, reg.serialize_decision_tree_regressor),
        (GradientBoostingRegressor,
         reg.serialize_gradient_boosting_regressor),
        (RandomForestRegressor, reg.serialize_random_forest_regressor),
        (MLPRegressor, reg.serialize_mlp_regressor),
        # Decomposition
        (PCA, dcp.serialize_pca),
    )
    for estimator_class, serializer in serializers:
        if isinstance(model, estimator_class):
            return serializer(model)

    raise ModellNotSupported(
        'This model type is not currently supported. '
        'Email support@mlrequest.com to request a feature or report a bug.')
def deserialize_model(model_dict):
    """
    Rebuild a model from its serialized dictionary.

    The ``'meta'`` entry of the dictionary is compared against the known
    model tags and the dictionary is handed to the deserializer registered
    for the first tag that is equal to it.

    :param model_dict: serialized model; must provide a ``'meta'`` entry
        naming the model type

    :return: whatever the selected deserializer returns for ``model_dict``

    :raise KeyError: if ``model_dict`` has no ``'meta'`` entry
    :raise ModellNotSupported: if the ``'meta'`` value matches none of the
        known tags
    """
    # Read the tag first so that a missing 'meta' key fails before anything
    # else is touched.
    meta = model_dict['meta']

    deserializers = (
        # Classifiers
        ('lr', clf.deserialize_logistic_regression),
        ('bernoulli-nb', clf.deserialize_bernoulli_nb),
        ('gaussian-nb', clf.deserialize_gaussian_nb),
        ('multinomial-nb', clf.deserialize_multinomial_nb),
        ('complement-nb', clf.deserialize_complement_nb),
        ('lda', clf.deserialize_lda),
        ('qda', clf.deserialize_qda),
        ('svm', clf.deserialize_svm),
        ('perceptron', clf.deserialize_perceptron),
        ('decision-tree', clf.deserialize_decision_tree),
        ('gb', clf.deserialize_gradient_boosting),
        ('rf', clf.deserialize_random_forest),
        ('mlp', clf.deserialize_mlp),
        # Regressors
        ('linear-regression', reg.deserialize_linear_regressor),
        ('lasso-regression', reg.deserialize_lasso_regressor),
        ('ridge-regression', reg.deserialize_ridge_regressor),
        ('svr', reg.deserialize_svr),
        ('decision-tree-regression', reg.deserialize_decision_tree_regressor),
        ('gb-regression', reg.deserialize_gradient_boosting_regressor),
        ('rf-regression', reg.deserialize_random_forest_regressor),
        ('mlp-regression', reg.deserialize_mlp_regressor),
        # Decomposition
        ('pca', dcp.deserialize_pca),
    )
    # Equality scan rather than a dict lookup: an unhashable tag must still
    # end up at ModellNotSupported instead of raising TypeError.
    for tag, deserializer in deserializers:
        if meta == tag:
            return deserializer(model_dict)

    raise ModellNotSupported(
        'Model type not supported or corrupt JSON file. Email support@mlrequest.com to request a feature or report a bug.'
    )
def to_dict(model):
    """
    Return the serialized form of a model.

    Convenience alias that forwards to `serialize_model`.

    :param model: the estimator instance to serialize

    :return: the value returned by `serialize_model` for ``model``
    """
    serialized = serialize_model(model)
    return serialized
def from_dict(model_dict):
    """
    Rebuild a model from its serialized dictionary.

    Convenience alias that forwards to `deserialize_model`.

    :param model_dict: serialized model; `deserialize_model` reads its
        ``'meta'`` entry to pick the deserializer

    :return: the value returned by `deserialize_model` for ``model_dict``
    """
    restored = deserialize_model(model_dict)
    return restored
def to_json(model, model_name):
    """
    Serialize a model and write it to a JSON file.

    The JSON text is produced completely before the target file is opened,
    so that a failure in `serialize_model` (e.g. an unsupported model) or in
    the JSON encoding neither truncates an existing file nor leaves an empty
    or partially written one behind.

    :param model: the estimator instance to serialize
    :param model_name: path of the JSON file to write; an existing file is
        overwritten

    :raise ModellNotSupported: propagated from `serialize_model` when the
        model type is not supported
    """
    # Build the whole payload first; open(..., 'w') truncates immediately.
    payload = json.dumps(serialize_model(model))
    # json.dumps escapes non-ASCII characters by default, so the output is
    # valid in any ASCII-compatible encoding; UTF-8 is stated explicitly to
    # be independent of the platform's locale.
    with open(model_name, 'w', encoding='utf-8') as model_json:
        model_json.write(payload)
def from_json(model_name):
    """
    Load a model from a JSON file.

    The file is parsed with `json.load` and the resulting dictionary is
    passed to `deserialize_model`.

    :param model_name: path of the JSON file to read

    :return: the value returned by `deserialize_model` for the parsed file

    :raise ModellNotSupported: propagated from `deserialize_model` when the
        stored model type is not recognized
    """
    # JSON interchange text is UTF-8; do not depend on the locale's default
    # encoding when reading it.
    with open(model_name, 'r', encoding='utf-8') as model_json:
        model_dict = json.load(model_json)
    # The file is closed before the model is rebuilt.
    return deserialize_model(model_dict)
class ModellNotSupported(Exception):
    """
    Raised when a model type cannot be serialized or deserialized.

    The spelling of the class name is part of the public interface and is
    kept as is.
    """