How to run the trained model using flask api? #129

Open
imsurinder90 opened this issue May 31, 2018 · 0 comments

I created a route, localhost:5050/predict, to run the trained model on a given statement:

{
    "statement": "You helped the customer in troubleshooting the Cable issue and you asked "
}

But it gives the following error:

  File "C:\Users\surinder.kumar01\AppData\Local\conda\conda\envs\tia\lib\site-packages\flask_restful\__init__.py", line 595, in dispatch_request
    resp = meth(*args, **kwargs)
  File "D:\surinder\ds\test\text_classification_projects\char-rnn-tensorflow\wordpredict.py", line 45, in post
    saver = tf.train.Saver()
  File "C:\Users\surinder.kumar01\AppData\Local\conda\conda\envs\tia\lib\site-packages\tensorflow\python\training\saver.py", line 1311, in __init__
    self.build()
  File "C:\Users\surinder.kumar01\AppData\Local\conda\conda\envs\tia\lib\site-packages\tensorflow\python\training\saver.py", line 1320, in build
    self._build(self._filename, build_save=True, build_restore=True)
  File "C:\Users\surinder.kumar01\AppData\Local\conda\conda\envs\tia\lib\site-packages\tensorflow\python\training\saver.py", line 1345, in _build
    raise ValueError("No variables to save")
ValueError: No variables to save

Here is the code:

from __future__ import print_function
import os
from six.moves import cPickle
import tensorflow as tf
from model import Model

from flask import Flask, request
from flask_restful import Resource, Api

app = Flask(__name__)
api = Api(app)

params = {
    'save_dir': 'save',
    'prime': '',
    'n': 500,
    'sample': 2
}


def get_model():
    with open(os.path.join(params['save_dir'], 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(params['save_dir'], 'chars_vocab.pkl'), 'rb') as f:
        chars, vocab = cPickle.load(f)
    return chars, vocab, Model(saved_args, training=False)


class predict(Resource):

    chars, vocab, model = get_model()

    def sample(self, statement, args, chars, vocab, model, saver, ckpt):
        with tf.Session() as sess:
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                result = model.sample(sess, chars, vocab, args['n'], statement, args['sample']).encode('utf-8')
                return result

    def post(self):
        statement = request.get_json(silent=True)['statement']
        result = None

        # tf.global_variables_initializer().run()
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(params['save_dir'])
        # with tf.Session() as sess:
        result = self.sample(
            statement, params, predict.chars,
            predict.vocab, predict.model, saver, ckpt
        ).decode('utf-8').split(".")[0]
        return {
            'statement': statement,
            'full_statement': result
        }

api.add_resource(predict, '/')

if __name__ == "__main__":
    app.run(debug=True)

Please help.
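
A possible explanation, offered as an assumption rather than a confirmed diagnosis: in TensorFlow 1.x the default graph is a property of the current thread. The model is built once in the class body (on the importing thread), while `post` runs on whichever thread the Flask server uses for the request, so `tf.train.Saver()` may be looking at a fresh, empty default graph. The sketch below does not use TensorFlow at all. It only mimics that behaviour with `threading.local`, and every name in it (`DefaultGraphStack`, `build_model`, `make_saver`, the variable names) is a hypothetical stand-in.

```python
import threading


class DefaultGraphStack(threading.local):
    """Toy stand-in for a per-thread default graph (not TensorFlow code)."""

    def __init__(self):
        # threading.local re-runs __init__ in every thread that touches the
        # object, so each thread starts with its own empty variable list.
        self.variables = []


default_graph = DefaultGraphStack()


def build_model():
    # Plays the role of Model(saved_args, training=False) in the class body.
    # The variable names are placeholders, not the real model's variables.
    default_graph.variables.extend(["embedding", "softmax_w", "softmax_b"])
    return default_graph.variables  # explicit handle to the populated "graph"


def make_saver(variables=None):
    # Plays the role of tf.train.Saver(): it fails when it finds no variables.
    if variables is None:
        variables = default_graph.variables
    if not variables:
        raise ValueError("No variables to save")
    return list(variables)


def run_in_thread(fn):
    """Run fn in a new thread; return a dict holding its result or its error."""
    box = {}

    def target():
        try:
            box["result"] = fn()
        except Exception as exc:
            box["error"] = exc

    worker = threading.Thread(target=target)
    worker.start()
    worker.join()
    return box


graph_handle = build_model()  # built on the main thread
implicit = run_in_thread(make_saver)  # worker relies on its own default
explicit = run_in_thread(lambda: make_saver(graph_handle))  # handle passed in

print(type(implicit.get("error")).__name__)
print(explicit.get("result"))
```

If this is indeed the cause, the usual remedy in TensorFlow 1.x is to keep a reference to the graph the model was built in (for example `graph = tf.get_default_graph()` right after `get_model()`) and create the `Saver` and `Session` inside `with graph.as_default():`. Building the `Saver` and restoring the checkpoint once at startup, instead of on every request, would also avoid repeated restores.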
