
Translation results from the notebook-based script differ from those decoded during training #12

Open
annisamansa opened this issue Dec 25, 2020 · 0 comments


annisamansa commented Dec 25, 2020

Dear authors,
I wrote a script based on Load_model_and_translate_baseline.ipynb to translate a test file. However, its translations are very different from the ones decoded during training. Moreover, they bear little relation to the source text.
I am sure the source text is the same as the dev set.
How can I make the two outputs match?
My script is below:

import sys
import pickle
import numpy as np
import tensorflow as tf

REPO_PATH = 'xxxxxxxxxxxxx/good-translation-wrong-in-context/'
sys.path.insert(0, REPO_PATH)
print(sys.path)
import lib
import lib.task.seq2seq.models.transformer as tr
VOC_PATH = REPO_PATH + '/scripts/build/'

def cli_main():
    # Load the source and target vocabularies built for training
    with open(VOC_PATH + 'src.voc', 'rb') as f:
        inp_voc = pickle.load(f)
    with open(VOC_PATH + 'dst.voc', 'rb') as f:
        out_voc = pickle.load(f)
    testid = inp_voc.ids(['BOS', 'EOS', 'UNK'])
    print(testid)

    # Start from a fresh graph; InteractiveSession installs itself as the default session
    tf.reset_default_graph()
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.99, allow_growth=True)
    sess = tf.InteractiveSession(config=tf.ConfigProto(gpu_options=gpu_options))

    hp = {
        # same values as in the training script
        }
    model = tr.Model('mod', inp_voc, out_voc, inference_mode='fast', **hp)

    # Restore the trained weights into the model's trainable variables
    path_to_ckpt = REPO_PATH + '/scripts/build/checkpoint/model-2048.npz'
    var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    lib.train.saveload.load(path_to_ckpt, var_list)

    path_to_testset = REPO_PATH + '/scripts/e2c/test_src'
    with open(path_to_testset) as f:
        test_src = f.readlines()
    print(test_src)

    # Print the result instead of discarding it
    print(model.translate_lines(test_src))

if __name__ == '__main__':
    cli_main()
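To make the comparison concrete, here is a minimal sketch for checking line by line how many of the script's translations are identical to the dev-set hypotheses decoded during training. `compare_translations` and `compare_files` are hypothetical helpers, not part of the repository, and the sketch assumes both outputs have been saved as plain-text files with one sentence per line, in the same order.

```python
from itertools import zip_longest
from typing import List, Tuple


def compare_translations(script_lines: List[str],
                         train_lines: List[str]) -> Tuple[float, List[int]]:
    """Return (fraction of identical lines, 0-based indices of differing lines).

    Lines are compared after stripping surrounding whitespace, so trailing
    newlines left by readlines() are ignored. If one list is longer than the
    other, each extra line counts as a mismatch.
    """
    mismatches: List[int] = []
    total = 0
    for i, (a, b) in enumerate(zip_longest(script_lines, train_lines)):
        total += 1
        # zip_longest pads the shorter list with None
        if a is None or b is None or a.strip() != b.strip():
            mismatches.append(i)
    if total == 0:
        return 1.0, []
    return (total - len(mismatches)) / total, mismatches


def compare_files(script_path: str, train_path: str) -> Tuple[float, List[int]]:
    """Read two one-sentence-per-line files and compare them line by line."""
    with open(script_path, encoding='utf-8') as f:
        script_lines = f.readlines()
    with open(train_path, encoding='utf-8') as f:
        train_lines = f.readlines()
    return compare_translations(script_lines, train_lines)
```

For example, `ratio, bad = compare_files('script_out.txt', 'train_dev_decodes.txt')` (hypothetical file names) gives the share of matching lines and the positions where the two outputs diverge. If almost every line differs, that points to a systematic difference between the two setups rather than to a few unlucky sentences.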
