# Source code for schrodinger.application.matsci.mlearn.sklearn_json
"""
# Third-party code. No Schrodinger Copyright.
"""
import json
from sklearn import discriminant_analysis
from sklearn import dummy # noqa: F401
from sklearn import svm
from sklearn.decomposition import PCA
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import _gb_losses # noqa: F401
from sklearn.linear_model import Lasso
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import Perceptron
from sklearn.linear_model import Ridge
from sklearn.naive_bayes import BernoulliNB
from sklearn.naive_bayes import ComplementNB
from sklearn.naive_bayes import GaussianNB
from sklearn.naive_bayes import MultinomialNB
from sklearn.neural_network import MLPClassifier
from sklearn.neural_network import MLPRegressor
from sklearn.svm import SVR
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import DecisionTreeRegressor
from . import classification as clf
from . import decomposition as dcp
from . import regression as reg
# Version string of this module.
__version__ = '0.1.0'
def serialize_model(model):
    """
    Hand a model to the serializer that matches its class.

    The model is tested with `isinstance` against each supported class, in
    the order of the table below, and the serializer paired with the first
    matching class is called.

    :param model: the object to serialize
    :return: the value returned by the selected serializer
    :raise ModellNotSupported: if the model is not an instance of any of the
        supported classes
    """
    # Ordered (class, serializer) pairs; the first isinstance match wins.
    serializers = (
        # Classifiers
        (LogisticRegression, clf.serialize_logistic_regression),
        (BernoulliNB, clf.serialize_bernoulli_nb),
        (GaussianNB, clf.serialize_gaussian_nb),
        (MultinomialNB, clf.serialize_multinomial_nb),
        (ComplementNB, clf.serialize_complement_nb),
        (discriminant_analysis.LinearDiscriminantAnalysis, clf.serialize_lda),
        (discriminant_analysis.QuadraticDiscriminantAnalysis,
         clf.serialize_qda),
        (svm.SVC, clf.serialize_svm),
        (Perceptron, clf.serialize_perceptron),
        (DecisionTreeClassifier, clf.serialize_decision_tree),
        (GradientBoostingClassifier, clf.serialize_gradient_boosting),
        (RandomForestClassifier, clf.serialize_random_forest),
        (MLPClassifier, clf.serialize_mlp),
        # Regressors
        (LinearRegression, reg.serialize_linear_regressor),
        (Lasso, reg.serialize_lasso_regressor),
        (Ridge, reg.serialize_ridge_regressor),
        (SVR, reg.serialize_svr),
        (DecisionTreeRegressor, reg.serialize_decision_tree_regressor),
        (GradientBoostingRegressor,
         reg.serialize_gradient_boosting_regressor),
        (RandomForestRegressor, reg.serialize_random_forest_regressor),
        (MLPRegressor, reg.serialize_mlp_regressor),
        # Decomposition
        (PCA, dcp.serialize_pca),
    )

    for estimator_type, serializer in serializers:
        if isinstance(model, estimator_type):
            return serializer(model)

    raise ModellNotSupported(
        'This model type is not currently supported. Email support@mlrequest.com to request a feature or report a bug.'
    )
def deserialize_model(model_dict):
    """
    Rebuild a model from its serialized dictionary.

    The value stored under the ``'meta'`` key selects which deserializer is
    called with the whole dictionary.

    :param model_dict: mapping holding the serialized model; must contain a
        ``'meta'`` entry
    :return: the value returned by the selected deserializer
    :raise KeyError: if ``model_dict`` has no ``'meta'`` entry
    :raise ModellNotSupported: if the ``'meta'`` value matches none of the
        known tags
    """
    # Read the tag first so a missing key fails before anything else happens.
    meta = model_dict['meta']

    # Ordered (tag, deserializer) pairs.
    deserializers = (
        # Classifiers
        ('lr', clf.deserialize_logistic_regression),
        ('bernoulli-nb', clf.deserialize_bernoulli_nb),
        ('gaussian-nb', clf.deserialize_gaussian_nb),
        ('multinomial-nb', clf.deserialize_multinomial_nb),
        ('complement-nb', clf.deserialize_complement_nb),
        ('lda', clf.deserialize_lda),
        ('qda', clf.deserialize_qda),
        ('svm', clf.deserialize_svm),
        ('perceptron', clf.deserialize_perceptron),
        ('decision-tree', clf.deserialize_decision_tree),
        ('gb', clf.deserialize_gradient_boosting),
        ('rf', clf.deserialize_random_forest),
        ('mlp', clf.deserialize_mlp),
        # Regressors
        ('linear-regression', reg.deserialize_linear_regressor),
        ('lasso-regression', reg.deserialize_lasso_regressor),
        ('ridge-regression', reg.deserialize_ridge_regressor),
        ('svr', reg.deserialize_svr),
        ('decision-tree-regression', reg.deserialize_decision_tree_regressor),
        ('gb-regression', reg.deserialize_gradient_boosting_regressor),
        ('rf-regression', reg.deserialize_random_forest_regressor),
        ('mlp-regression', reg.deserialize_mlp_regressor),
        # Decomposition
        ('pca', dcp.deserialize_pca),
    )

    # Compare with == instead of a dict lookup so that an unhashable tag
    # (e.g. a list) still ends in ModellNotSupported rather than a TypeError.
    for tag, deserializer in deserializers:
        if meta == tag:
            return deserializer(model_dict)

    raise ModellNotSupported(
        'Model type not supported or corrupt JSON file. Email support@mlrequest.com to request a feature or report a bug.'
    )
def to_dict(model):
    """
    Serialize a model; thin wrapper around `serialize_model`.

    :param model: the object to serialize
    :return: the value returned by `serialize_model`
    """
    serialized = serialize_model(model)
    return serialized
def from_dict(model_dict):
    """
    Rebuild a model; thin wrapper around `deserialize_model`.

    :param model_dict: mapping holding the serialized model
    :return: the value returned by `deserialize_model`
    """
    model = deserialize_model(model_dict)
    return model
def to_json(model, model_name):
    """
    Serialize a model and write the result to a file as JSON.

    :param model: the object to serialize
    :param model_name: the file to write, passed to `open`; an existing file
        is overwritten
    :raise ModellNotSupported: propagated from `serialize_model`; in that
        case the target file is not touched
    """
    # Serialize BEFORE opening the target: mode 'w' truncates (or creates) the
    # file on open, so a failing serialization would otherwise wipe an
    # existing file or leave an empty one behind.
    serialized = serialize_model(model)

    # Explicit encoding so the result does not depend on the platform locale.
    with open(model_name, 'w', encoding='utf-8') as model_json:
        json.dump(serialized, model_json)
def from_json(model_name):
    """
    Read a JSON file and rebuild the model it describes.

    :param model_name: the file to read, passed to `open`
    :return: the value returned by `deserialize_model` for the parsed content
    """
    # Decode as UTF-8 explicitly instead of relying on the platform locale,
    # which can mis-decode or reject non-ASCII JSON text.
    with open(model_name, 'r', encoding='utf-8') as model_json:
        model_dict = json.load(model_json)

    # The file is already closed here; deserialization only needs the dict.
    return deserialize_model(model_dict)
class ModellNotSupported(Exception):
    """
    Raised by `serialize_model` and `deserialize_model` when the given model
    class or ``'meta'`` tag is not one of the supported ones.
    """