diff --git a/site/setup.cfg b/site/setup.cfg index 7bd0290c..17795307 100644 --- a/site/setup.cfg +++ b/site/setup.cfg @@ -48,12 +48,14 @@ invenio_base.apps = zenodo_rdm_moderation = zenodo_rdm.moderation.ext:ZenodoModeration invenio_openaire = zenodo_rdm.openaire.ext:OpenAIRE zenodo_rdm_stats = zenodo_rdm.stats.ext:ZenodoStats + zenodo_rdm_ml = zenodo_rdm.ml.ext:ZenodoML invenio_base.api_apps = zenodo_rdm_legacy = zenodo_rdm.legacy.ext:ZenodoLegacy profiler = zenodo_rdm.profiler:Profiler zenodo_rdm_metrics = zenodo_rdm.metrics.ext:ZenodoMetrics zenodo_rdm_moderation = zenodo_rdm.moderation.ext:ZenodoModeration invenio_openaire = zenodo_rdm.openaire.ext:OpenAIRE + zenodo_rdm_ml = zenodo_rdm.ml.ext:ZenodoML invenio_base.api_blueprints = zenodo_rdm_legacy = zenodo_rdm.legacy.views:blueprint zenodo_rdm_legacy_records = zenodo_rdm.legacy.views:create_legacy_records_bp diff --git a/site/zenodo_rdm/ml/__init__.py b/site/zenodo_rdm/ml/__init__.py new file mode 100644 index 00000000..4899db37 --- /dev/null +++ b/site/zenodo_rdm/ml/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2024 CERN. +# +# Zenodo-RDM is free software; you can redistribute it and/or modify +# it under the terms of the MIT License; see LICENSE file for more details. +"""Machine learning module.""" diff --git a/site/zenodo_rdm/ml/base.py b/site/zenodo_rdm/ml/base.py new file mode 100644 index 00000000..a85eb6a1 --- /dev/null +++ b/site/zenodo_rdm/ml/base.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2024 CERN. +# +# Zenodo-RDM is free software; you can redistribute it and/or modify +# it under the terms of the MIT License; see LICENSE file for more details. +"""Base class for ML models.""" + + +class MLModel: + """Base class for ML models.""" + + def __init__(self, version=None, **kwargs): + """Constructor.""" + self.version = version + + def process(self, data, preprocess=None, postprocess=None, raise_exc=True): + """Pipeline function to call pre/post process with predict.""" + try: + preprocessor = preprocess or self.preprocess + postprocessor = postprocess or self.postprocess + + preprocessed = preprocessor(data) + prediction = self.predict(preprocessed) + return postprocessor(prediction) + except Exception as e: + if raise_exc: + raise e + return None + + def predict(self, data): + """Predict method to be implemented by subclass.""" + raise NotImplementedError() + + def preprocess(self, data): + """Preprocess data.""" + return data + + def postprocess(self, data): + """Postprocess data.""" + return data diff --git a/site/zenodo_rdm/ml/config.py b/site/zenodo_rdm/ml/config.py new file mode 100644 index 00000000..e1cacf94 --- /dev/null +++ b/site/zenodo_rdm/ml/config.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2024 CERN. +# +# ZenodoRDM is free software; you can redistribute it and/or modify +# it under the terms of the MIT License; see LICENSE file for more details. + +"""Machine learning config.""" + +from .models import SpamDetectorScikit + +ML_MODELS = { + "spam_scikit": SpamDetectorScikit, +} +"""Machine learning models.""" + +# NOTE Model URL and model host need to be formattable strings for the model name. +ML_KUBEFLOW_MODEL_URL = "CHANGE-{0}-ME" +ML_KUBEFLOW_MODEL_HOST = "{0}-CHANGE" +ML_KUBEFLOW_TOKEN = "CHANGE SECRET" +"""Kubeflow connection config.""" diff --git a/site/zenodo_rdm/ml/ext.py b/site/zenodo_rdm/ml/ext.py new file mode 100644 index 00000000..18c72520 --- /dev/null +++ b/site/zenodo_rdm/ml/ext.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2024 CERN. +# +# ZenodoRDM is free software; you can redistribute it and/or modify +# it under the terms of the MIT License; see LICENSE file for more details. + +"""ZenodoRDM machine learning module.""" + +from flask import current_app + +from . import config + + +class ZenodoML: + """Zenodo machine learning extension.""" + + def __init__(self, app=None): + """Extension initialization.""" + if app: + self.init_app(app) + + @staticmethod + def init_config(app): + """Initialize configuration.""" + for k in dir(config): + if k.startswith("ML_"): + app.config.setdefault(k, getattr(config, k)) + + def init_app(self, app): + """Flask application initialization.""" + self.init_config(app) + app.extensions["zenodo-ml"] = self + + def _parse_model_name_version(self, model): + """Parse model name and version.""" + vals = model.rsplit(":") + version = vals[1] if len(vals) > 1 else None + return vals[0], version + + def models(self, model, **kwargs): + """Return model based on model name.""" + models = current_app.config.get("ML_MODELS", {}) + model_name, version = self._parse_model_name_version(model) + + if model_name not in models: + raise ValueError("Model not found/registered.") + + return models[model_name](version=version, **kwargs) diff --git a/site/zenodo_rdm/ml/models.py b/site/zenodo_rdm/ml/models.py new file mode 100644 index 00000000..bde9bdc5 --- /dev/null +++ b/site/zenodo_rdm/ml/models.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2024 CERN. +# +# ZenodoRDM is free software; you can redistribute it and/or modify +# it under the terms of the MIT License; see LICENSE file for more details. +"""Model definitions.""" + + +import json +import string + +import requests +from bs4 import BeautifulSoup +from flask import current_app + +from .base import MLModel + + +class SpamDetectorScikit(MLModel): + """Spam detection model based on Sklearn.""" + + MODEL_NAME = "sklearn-spam" + MAX_WORDS = 4000 + + def __init__(self, version, **kwargs): + """Constructor. Makes version required.""" + super().__init__(version, **kwargs) + + def preprocess(self, data): + """Preprocess data. + + Parse HTML, remove punctuation and truncate to max chars. + """ + text = BeautifulSoup(data, "html.parser").get_text() + trans_table = str.maketrans(string.punctuation, " " * len(string.punctuation)) + parts = text.translate(trans_table).lower().strip().split(" ") + if len(parts) >= self.MAX_WORDS: + parts = parts[: self.MAX_WORDS] + return " ".join(parts) + + def postprocess(self, data): + """Postprocess data. + + Gives spam and ham probability. + """ + result = { + "spam": data["outputs"][0]["data"][0], + "ham": data["outputs"][0]["data"][1], + } + return result + + def _send_request_kubeflow(self, data): + """Send predict request to Kubeflow.""" + payload = { + "inputs": [ + { + "name": "input-0", + "shape": [1], + "datatype": "BYTES", + "data": [f"{data}"], + } + ] + } + model_ref = self.MODEL_NAME + "-" + self.version + url = current_app.config.get("ML_KUBEFLOW_MODEL_URL").format(model_ref) + host = current_app.config.get("ML_KUBEFLOW_MODEL_HOST").format(model_ref) + access_token = current_app.config.get("ML_KUBEFLOW_TOKEN") + r = requests.post( + url, + headers={ + "Authorization": f"Bearer {access_token}", + "Content-Type": "application/json", + "Host": host, + }, + json=payload, + ) + if r.status_code != 200: + raise requests.RequestException("Prediction was not successful.", request=r) + return json.loads(r.text) + + def predict(self, data): + """Get prediction from model.""" + prediction = self._send_request_kubeflow(data) + return prediction diff --git a/site/zenodo_rdm/ml/proxies.py b/site/zenodo_rdm/ml/proxies.py new file mode 100644 index 00000000..41596c21 --- /dev/null +++ b/site/zenodo_rdm/ml/proxies.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2024 CERN. +# +# ZenodoRDM is free software; you can redistribute it and/or modify +# it under the terms of the MIT License; see LICENSE file for more details. +"""Proxy objects for easier access to application objects.""" + +from flask import current_app +from werkzeug.local import LocalProxy + +current_ml_models = LocalProxy(lambda: current_app.extensions["zenodo-ml"])