-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
55 lines (47 loc) · 1.47 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random
import time
import math
from Encoder import EncoderRNN
from Decoder import DecoderRNN
from AttnDecoder import AttnDecoderRNN
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
from helper_functions import prepareData,readLangs,filterPairs,trainIters,evaluate
# Detect GPU availability once; models are moved to CUDA below when True.
use_cuda = torch.cuda.is_available()
# NOTE(review): dead, commented-out constants — these presumably duplicate
# values defined inside helper_functions (SOS/EOS token ids, max sentence
# length, English prefix filter). Confirm and delete; keeping them here
# risks drift from the real definitions.
'''
SOS_token = 0
EOS_token = 1
MAX_LENGTH = 10
eng_prefixes = (
"i am ", "i m ",
"he is", "he s ",
"she is", "she s",
"you are", "you re ",
"we are", "we re ",
"they are", "they re "
)
'''
# Build English->French vocabularies and sentence pairs; the True flag is
# passed through to prepareData (presumably "reverse" the language pair —
# TODO confirm against helper_functions).
input_lang, output_lang, pairs = prepareData('eng', 'fra', True)
# Sanity check: show one random training pair.
print(random.choice(pairs))
# NOTE(review): teacher_forcing_ratio is never read in this file — verify
# whether trainIters uses its own value, then remove or wire it through.
teacher_forcing_ratio = 0.5
hidden_size = 256
# Encoder over the source vocabulary; attention decoder over the target
# vocabulary (1 layer, 10% dropout).
encoder1 = EncoderRNN(input_lang.n_words, hidden_size)
attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words,
1, dropout_p=0.1)
if use_cuda:
    encoder1 = encoder1.cuda()
    attn_decoder1 = attn_decoder1.cuda()
# Train for 75k iterations, logging progress every 5k.
trainIters(input_lang,output_lang,encoder1, attn_decoder1,pairs, n_iters=75000, print_every=5000)
def evaluateAndShowAttention(input_sentence):
    """Translate *input_sentence* with the trained models and print the result.

    Uses the module-level ``encoder1`` / ``attn_decoder1``; the attention
    weights returned by ``evaluate`` are currently unused.
    """
    decoded_words, attn_weights = evaluate(encoder1, attn_decoder1, input_sentence)
    translation = ' '.join(decoded_words)
    print('input =', input_sentence)
    print('output =', translation)
#evaluateAndShowAttention ("Your input sentence to be translated")