forked from hunkim/DeepLearningZeroToAll
-
Notifications
You must be signed in to change notification settings - Fork 1
/
klab-12-2-rnn_long_char.py
73 lines (56 loc) · 2.2 KB
/
klab-12-2-rnn_long_char.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, TimeDistributed, Activation, LSTM
from keras.utils import np_utils
import os
# brew install graphviz
# pip3 install graphviz
# pip3 install pydot
from keras.utils.visualize_util import plot
# sample sentence
sentence = "if you want to build a ship, don't drum up people together to collect wood and don't assign them tasks and work, but rather teach them to long for the endless immensity of the sea."
# Vocabulary, id -> char. Sorted so the char <-> id mapping is identical on
# every run: the iteration order of a plain set of strings depends on hash
# randomization (PYTHONHASHSEED) and would reshuffle the ids between runs.
char_set = sorted(set(sentence))  # id -> char [' ', "'", ',', '.', 'a', ...]
char_dic = {w: i for i, w in enumerate(char_set)}  # char -> id
data_dim = len(char_set)  # one-hot feature size per time step
seq_length = timesteps = 10  # characters per training window
nb_classes = len(char_set)  # output classes = vocabulary size
# Slide a seq_length-character window over the sentence. Each input window is
# paired with the same window shifted one character to the right, so the
# target at every time step is the character that follows the input there.
dataX = []
dataY = []
for start in range(len(sentence) - seq_length):
    stop = start + seq_length
    x_str = sentence[start:stop]
    y_str = sentence[start + 1:stop + 1]
    print(x_str, '->', y_str)
    # characters -> integer ids
    dataX.append([char_dic[ch] for ch in x_str])
    dataY.append([char_dic[ch] for ch in y_str])
# One-hot encode the integer ids. Indexing an identity matrix with the
# (samples, seq_length) id array yields (samples, seq_length, nb_classes)
# directly, i.e. [samples, time steps, features], with no reshape needed.
# This replaces np_utils.to_categorical(..., nb_classes=...): that keyword was
# renamed num_classes in Keras 2, where the old call raises a TypeError.
one_hot = np.eye(nb_classes)
dataX = one_hot[np.asarray(dataX)]
print(dataX.shape)
# Targets get the same per-time-step one-hot encoding.
dataY = one_hot[np.asarray(dataY)]
print(dataY.shape)
# Two stacked LSTMs that each return the full sequence, followed by a
# per-time-step Dense layer and softmax over the character vocabulary.
model = Sequential([
    LSTM(nb_classes, input_shape=(timesteps, data_dim), return_sequences=True),
    LSTM(nb_classes, return_sequences=True),
    TimeDistributed(Dense(nb_classes)),
    Activation('softmax'),
])
model.summary()

# To store the model graph as a PNG (needs graphviz and pydot), call:
#   plot(model, to_file=os.path.basename(__file__) + '.png', show_shapes=True)

model.compile(
    loss='categorical_crossentropy',
    optimizer='rmsprop',
    metrics=['accuracy'],
)
# NOTE: nb_epoch is the Keras 1 keyword (Keras 2 renamed it to epochs).
model.fit(dataX, dataY, nb_epoch=1000)
# Run the trained model on every training window. Decode both the one-hot
# input and the per-step argmax prediction back to text for a side-by-side view.
predictions = model.predict(dataX, verbose=0)
for window, prediction in zip(dataX, predictions):
    input_ids = np.argmax(window, axis=1)
    predicted_ids = np.argmax(prediction, axis=1)
    input_text = ''.join(char_set[k] for k in input_ids)
    predicted_text = ''.join(char_set[k] for k in predicted_ids)
    print(input_text, ' -> ', predicted_text)