forked from hunkim/DeepLearningZeroToAll
-
Notifications
You must be signed in to change notification settings - Fork 1
/
klab-12-1-rnn_hello_char.py
65 lines (51 loc) · 1.74 KB
/
klab-12-1-rnn_hello_char.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, TimeDistributed, Activation, LSTM
from keras.utils import np_utils
import os
# brew install graphviz
# pip3 install graphviz
# pip3 install pydot
from keras.utils.visualize_util import plot
# sample text
sample = "hihello"
char_set = list(set(sample)) # id -> char ['i', 'l', 'e', 'o', 'h']
char_dic = {w: i for i, w in enumerate(char_set)}
x_str = sample[:-1]
y_str = sample[1:]
data_dim = len(char_set)
timesteps = len(y_str)
nb_classes = len(char_set)
print(x_str, y_str)
x = [char_dic[c] for c in x_str] # char to index
y = [char_dic[c] for c in y_str] # char to index
# One-hot encoding
x = np_utils.to_categorical(x, nb_classes=nb_classes)
# reshape X to be [samples, time steps, features]
x = np.reshape(x, (-1, len(x), data_dim))
print(x.shape)
# One-hot encoding
y = np_utils.to_categorical(y, nb_classes=nb_classes)
# time steps
y = np.reshape(y, (-1, len(y), data_dim))
print(y.shape)
model = Sequential()
model.add(LSTM(nb_classes, input_shape=(
timesteps, data_dim), return_sequences=True))
model.add(TimeDistributed(Dense(nb_classes)))
model.add(Activation('softmax'))
model.summary()
# Store model graph in png
#plot(model, to_file=os.path.basename(__file__) + '.png', show_shapes=True)
model.compile(loss='categorical_crossentropy',
optimizer='rmsprop', metrics=['accuracy'])
model.fit(x, y, nb_epoch=1)
predictions = model.predict(x, verbose=0)
for i, prediction in enumerate(predictions):
print(prediction)
x_index = np.argmax(x[i], axis=1)
x_str = [char_set[j] for j in x_index]
print(x_index, ''.join(x_str))
index = np.argmax(prediction, axis=1)
result = [char_set[j] for j in index]
print(index, ''.join(result))