diff --git a/pmip/data.py b/pmip/data.py index 0aa58d8..5419a16 100644 --- a/pmip/data.py +++ b/pmip/data.py @@ -55,6 +55,21 @@ def load_from_s3(filename=None, subdirectory=None, bucket=None): return obj_bytes +def load_from_fs_and_unpickle(filename=None, subdirectory=None): + if filename is None: + raise ValueError("You must specify a filename") + if subdirectory is None: + raise ValueError("You must specify a subdirectory") + + if not os.path.exists(subdirectory): + raise ValueError('Directory {} does not exist'.format(subdirectory)) + + with open(os.path.join(subdirectory, filename), 'rb') as pickle_file: + obj = pickle.load(pickle_file) + + return obj + + def load_from_s3_and_unpickle(filename=None, subdirectory=None, bucket=None): if filename is None: raise ValueError("You must specify a filename") diff --git a/pmip/routes.py b/pmip/routes.py index 4d70b3c..0a679fc 100644 --- a/pmip/routes.py +++ b/pmip/routes.py @@ -3,7 +3,7 @@ from flask import Flask from flask_restplus import Resource, Api, fields, abort -from pmip.data import load_from_s3_and_unpickle, get_latest_s3_dateint +from pmip.data import load_from_s3_and_unpickle, get_latest_s3_dateint, load_from_fs_and_unpickle DATA_DIR = "data" MODEL_FILENAME = "model.pkl" @@ -15,15 +15,22 @@ def possible_types(value): return value -latest_model_id = get_latest_s3_dateint( - datadir='models', - bucket=os.getenv('BUCKET') -) -model = load_from_s3_and_unpickle( - filename='model.pkl', - subdirectory=f'models/{latest_model_id}', - bucket=os.getenv('BUCKET') -) +if os.getenv('ENVIRONMENT', '') == 'dev': + latest_model_id = 'local' + model = load_from_fs_and_unpickle( + filename='model.pkl', + subdirectory='data', + ) +elif os.getenv('ENVIRONMENT', '') in ['staging', 'prod']: + latest_model_id = get_latest_s3_dateint( + datadir='models', + bucket=os.getenv('BUCKET') + ) + model = load_from_s3_and_unpickle( + filename='model.pkl', + subdirectory=f'models/{latest_model_id}', + bucket=os.getenv('BUCKET') + ) app = Flask(__name__) api = Api(app) @@ -86,7 +93,19 @@ def post(self): f"Don't recognize type {type}" ) - return {'result': result}, 201 + return {'result': result}, 200 + + +@api.route('/model-info') +class ModelInfo(Resource): + + # @api.marshal_with(request, code=201) + def get(self): + result = { + "model_id": latest_model_id + } + + return {'result': result}, 200 if __name__ == '__main__': diff --git a/scripts/train.sh b/scripts/train.sh index c5c6b12..3581403 100644 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -33,8 +33,7 @@ if [ "$ENVIRONMENT" == "staging" ]; then aws s3 cp $DIR/ $S3_DIR/ \ --recursive --exclude "*" --include "*.ipynb" --include "*.html" --include "*.pkl" - # Redeploy Lambda - wget -P /tmp https://download.docker.com/linux/debian/dists/stretch/pool/stable/amd64/docker-ce_18.09.0~3-0~debian-stretch_amd64.deb - dpkg -i /tmp/docker-ce_18.09.0~3-0~debian-stretch_amd64.deb - serverless deploy --region $([ -z "$AWS_DEFAULT_REGION" ] && aws configure get region || echo "$AWS_DEFAULT_REGION") + # Restart API + aws elasticbeanstalk restart-app-server --environment-name Pmip-env + #--region $([ -z "$AWS_DEFAULT_REGION" ] && aws configure get region || echo "$AWS_DEFAULT_REGION") fi