Support of tensorflow.keras instead of keras #94

Merged · 4 commits · Nov 21, 2019
1 change: 1 addition & 0 deletions .gitignore
@@ -5,6 +5,7 @@ __pycache__/
 .idea/
 .DS_Store

+*.h5
 *.tsv
 *.tar.gz
 *out*
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,4 +1,5 @@
 keras==2.3.1
-numpy==1.17.3
+numpy==1.16.2
 # tensorflow-gpu==1.14.0
 h5py
+gast==0.2.2
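
The numpy downgrade and the new gast pin track TensorFlow 1.x's own requirements (autograph in TF 1.x breaks with gast 0.3+). A quick sanity check of the resulting environment, as a minimal sketch assuming tensorflow or tensorflow-gpu 1.x was installed manually per the commented-out pin:

    import numpy as np
    import tensorflow as tf

    print('numpy:', np.__version__)       # pinned to 1.16.2 above
    print('tensorflow:', tf.__version__)  # expected 1.x, e.g. 1.14.0
    # The rest of the PR swaps every import to the tf-bundled Keras:
    from tensorflow.keras.layers import Conv1D  # should import cleanly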
9 changes: 6 additions & 3 deletions setup.py
@@ -2,14 +2,17 @@

 setup(
     name='keras-tcn',
-    version='2.8.3',
+    version='2.9.0',
     description='Keras TCN',
     author='Philippe Remy',
     license='MIT',
     long_description_content_type='text/markdown',
     long_description=open('README.md').read(),
     packages=['tcn'],
     # manually install tensorflow or tensorflow-gpu
-    install_requires=['numpy',
-                      'keras']
+    install_requires=[
+        'numpy==1.16.2',
+        'keras',
+        'gast==0.2.2'
+    ]
 )
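
Since tensorflow is deliberately left out of install_requires, the user keeps the choice between the CPU and GPU package. A guard a downstream script could add, as an illustrative sketch (not part of the PR):

    try:
        import tensorflow  # noqa: F401
    except ImportError as err:
        raise ImportError('keras-tcn needs tensorflow or tensorflow-gpu; '
                          'install one manually, e.g. tensorflow==1.14.0') from err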
8 changes: 4 additions & 4 deletions tasks/imdb_tcn.py
@@ -5,10 +5,10 @@
 Based on: https://github.com/keras-team/keras/blob/master/examples/imdb_bidirectional_lstm.py
 """
 import numpy as np
-from keras import Model, Input
-from keras.datasets import imdb
-from keras.layers import Dense, Dropout, Embedding
-from keras.preprocessing import sequence
+from tensorflow.keras import Model, Input
+from tensorflow.keras.datasets import imdb
+from tensorflow.keras.layers import Dense, Dropout, Embedding
+from tensorflow.keras.preprocessing import sequence

 from tcn import TCN

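The swap has to be all-or-nothing: layers from standalone keras and from tensorflow.keras are distinct classes and cannot be mixed in one model. A toy illustration of the target style (a sketch, not code from the PR):

    from tensorflow.keras import Input, Model
    from tensorflow.keras.layers import Dense

    # Built purely from tensorflow.keras parts; inserting a layer imported
    # from standalone `keras` here would fail when the graph is built.
    i = Input(shape=(10,))
    o = Dense(1, activation='sigmoid')(i)
    model = Model(inputs=[i], outputs=[o])
    model.compile(optimizer='adam', loss='binary_crossentropy')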
58 changes: 30 additions & 28 deletions tasks/save_reload_model.py
@@ -1,40 +1,42 @@
 import os

 import numpy as np
-from keras import Model, Input
-from keras.layers import Dense, Dropout, Embedding
-
-from tcn import TCN
+from tensorflow.keras.models import Sequential, model_from_json
+from tensorflow.keras.layers import Dense, Dropout, Embedding
+from tcn.tcn import TCN

-# simple TCN model.
+# define input shape
 max_len = 100
 max_features = 50
-i = Input(shape=(max_len,))
-x = Embedding(max_features, 16)(i)
-x = TCN(nb_filters=12,
-        dropout_rate=0.5,  # with dropout here.
-        kernel_size=6,
-        dilations=[1, 2, 4])(x)
-x = Dropout(0.5)(x)  # and dropout here.
-x = Dense(1, activation='sigmoid')(x)
-
-model = Model(inputs=[i], outputs=[x])
+# make model
+model = Sequential(layers=[Embedding(max_features, 16, input_shape=(max_len,)),
+                           TCN(nb_filters=12,
+                               dropout_rate=0.5,
+                               kernel_size=6,
+                               dilations=[1, 2, 4]),
+                           Dropout(0.5),
+                           Dense(units=1, activation='sigmoid')])

-if os.path.exists('tcn.npz'):
-    # Load checkpoint if file exists.
-    w = np.load('tcn.npz', allow_pickle=True)['w']
-    print('Model reloaded.')
-    model.set_weights(w.tolist())
-else:
-    # Save the checkpoint.
-    w = np.array(model.get_weights())
-    np.savez_compressed(file='tcn.npz', w=w, allow_pickle=True)
-    print('First time.')
+# get model as json string and save to file
+model_as_json = model.to_json()
+with open(r'model.json', "w") as json_file:
+    json_file.write(model_as_json)
+# save weights to file (for this format, need h5py installed)
+model.save_weights('weights.h5')

 # Make inference.
-# The value for [First time] and [Model reloaded] should be the same. Run the script twice!
 inputs = np.ones(shape=(1, 100))
 out1 = model.predict(inputs)[0, 0]
 print('*' * 80)
-print(out1)
+print('Inference after creation:', out1)
+
+# load model from file
+loaded_json = open(r'model.json', 'r').read()
+reloaded_model = model_from_json(loaded_json, custom_objects={'TCN': TCN})
+
+# restore weights
+reloaded_model.load_weights(r'weights.h5')
+
+# Make inference.
+out2 = reloaded_model.predict(inputs)[0, 0]
+print('*' * 80)
+print('Inference after loading:', out2)
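
Because the script now saves and reloads within a single run, the two printed values can be compared directly instead of across runs. A minimal follow-up check, assuming it is appended to the script above:

    # Identical architecture and weights, so both inferences must agree.
    assert np.isclose(out1, out2), 'reloaded model diverged from the original'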
2 changes: 1 addition & 1 deletion tasks/save_reload_sequential_model.py
@@ -1,7 +1,7 @@
 import numpy as np
 from keras.models import Sequential, model_from_json
 from keras.layers import Dense, Dropout, Embedding
-from tcn.tcn import TCN
+from tcn import TCN

 # define input shape
 max_len = 100
13 changes: 7 additions & 6 deletions tasks/sequential.py
@@ -5,11 +5,11 @@
 Based on: https://github.com/keras-team/keras/blob/master/examples/imdb_bidirectional_lstm.py
 """
 import numpy as np
-from keras import Sequential
-from keras.callbacks import Callback
-from keras.datasets import imdb
-from keras.layers import Dense, Dropout, Embedding
-from keras.preprocessing import sequence
+from tensorflow.keras import Sequential
+from tensorflow.keras.callbacks import Callback
+from tensorflow.keras.datasets import imdb
+from tensorflow.keras.layers import Dense, Dropout, Embedding
+from tensorflow.keras.preprocessing import sequence

 from tcn import TCN

@@ -50,7 +50,8 @@ class TestCallback(Callback):

     def on_epoch_end(self, epoch, logs=None):
         print(logs)
-        assert logs['val_accuracy'] > 0.78
+        acc_key = 'val_accuracy' if 'val_accuracy' in logs else 'val_acc'
+        assert logs[acc_key] > 0.78


print('Train...')
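The two-key lookup absorbs the metric rename between implementations: standalone Keras 2.3 logs 'val_accuracy' while tf.keras in TF 1.x still logs 'val_acc'. An equivalent lookup with dict.get, as a sketch under the same assumption about the keys:

    acc = logs.get('val_accuracy', logs.get('val_acc'))
    assert acc is not None and acc > 0.78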
4 changes: 2 additions & 2 deletions tasks/time_series_forecasting.py
@@ -2,8 +2,8 @@
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
-from keras import Input, Model
-from keras.layers import Dense
+from tensorflow.keras import Input, Model
+from tensorflow.keras.layers import Dense

 from tcn import TCN

33 changes: 16 additions & 17 deletions tcn/tcn.py
@@ -1,13 +1,12 @@
 from typing import List

-import keras.backend as K
-import keras.layers
-from keras import optimizers
-from keras.layers import Layer
-from keras.layers import Activation, Lambda, add
-from keras.layers import Conv1D, SpatialDropout1D
-from keras.layers import Dense, BatchNormalization
-from keras.models import Input, Model
+from tensorflow.keras import backend as K, Model, Input, optimizers
+from tensorflow.keras import layers
+from tensorflow.keras.layers import Activation, SpatialDropout1D, Lambda
+from tensorflow.keras.layers import Layer
+from tensorflow.python.keras.layers import Conv1D
+from tensorflow.python.layers.core import Dense
+from tensorflow.python.layers.normalization import BatchNormalization


class ResidualBlock(Layer):
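
Note that the last three new imports in the hunk above reach into the private tensorflow.python namespace. The same names are exposed publicly in TF 1.x, so a more stable equivalent would be the following sketch (assuming the public tf.keras layers behave identically here):

    from tensorflow.keras.layers import Conv1D
    from tensorflow.keras.layers import Dense, BatchNormalization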
@@ -58,7 +57,7 @@ def _add_and_activate_layer(self, layer):

         Args:
             layer: Appends layer to internal layer list and builds it based on the current output
-                shape of ResidualBlock. Updates current output shape.
+                shape of ResidualBlocK. Updates current output shape.

         """
         self.residual_layers.append(layer)
@@ -90,7 +89,7 @@ def build(self, input_shape):

         if not self.last_block:
             # 1x1 conv to match the shapes (channel dimension).
-            name = 'conv1D_{}'.format(k+1)
+            name = 'conv1D_{}'.format(k + 1)
             with K.name_scope(name):
                 # make and build this layer separately because it directly uses input_shape
                 self.shape_match_conv = Conv1D(filters=self.nb_filters,

@@ -100,7 +99,7 @@
                                                kernel_initializer=self.kernel_initializer)

         else:
-                self.shape_match_conv = Lambda(lambda x: x, name='identity')
+            self.shape_match_conv = Lambda(lambda x: x, name='identity')

         self.shape_match_conv.build(input_shape)
         self.res_output_shape = self.shape_match_conv.compute_output_shape(input_shape)
@@ -128,7 +127,7 @@ def call(self, inputs, training=None):
             x = layer(x)

         x2 = self.shape_match_conv(inputs)
-        res_x = add([x2, x])
+        res_x = layers.add([x2, x])
         return [self.final_activation(res_x), x]

     def compute_output_shape(self, input_shape):
@@ -159,7 +158,7 @@ class TCN(Layer):
         dilations: The list of the dilations. Example is: [1, 2, 4, 8, 16, 32, 64].
         nb_stacks : The number of stacks of residual blocks to use.
         padding: The padding to use in the convolutional layers, 'causal' or 'same'.
-        use_skip_connections: Boolean. If we want to add skip connections from input to each residual block.
+        use_skip_connections: Boolean. If we want to add skip connections from input to each residual blocK.
         return_sequences: Boolean. Whether to return the last output in the output sequence, or the full sequence.
         activation: The activation used in the residual blocks o = Activation(x + F(x)).
         dropout_rate: Float between 0 and 1. Fraction of the input units to drop.
@@ -237,7 +236,7 @@ def build(self, input_shape):
                                                       dropout_rate=self.dropout_rate,
                                                       use_batch_norm=self.use_batch_norm,
                                                       kernel_initializer=self.kernel_initializer,
-                                                      last_block = len(self.residual_blocks)+1 == total_num_blocks,
+                                                      last_block=len(self.residual_blocks) + 1 == total_num_blocks,
                                                       name='residual_block_{}'.format(len(self.residual_blocks))))
             # build newest residual block
             self.residual_blocks[-1].build(self.build_output_shape)
@@ -270,7 +269,7 @@ def call(self, inputs, training=None):
             skip_connections.append(skip_out)

         if self.use_skip_connections:
-            x = add(skip_connections)
+            x = layers.add(skip_connections)
         if not self.return_sequences:
             x = self.lambda_layer(x)
         return x
@@ -303,7 +302,7 @@ def compiled_tcn(num_feat,  # type: int
                  dilations,  # type: List[int]
                  nb_stacks,  # type: int
                  max_len,  # type: int
-                 output_len=1,  #type: int
+                 output_len=1,  # type: int
                  padding='causal',  # type: str
                  use_skip_connections=True,  # type: bool
                  return_sequences=True,
@@ -328,7 +327,7 @@ def compiled_tcn(num_feat,  # type: int
         nb_stacks : The number of stacks of residual blocks to use.
         max_len: The maximum sequence length, use None if the sequence length is dynamic.
         padding: The padding to use in the convolutional layers.
-        use_skip_connections: Boolean. If we want to add skip connections from input to each residual block.
+        use_skip_connections: Boolean. If we want to add skip connections from input to each residual blocK.
         return_sequences: Boolean. Whether to return the last output in the output sequence, or the full sequence.
         regression: Whether the output should be continuous or discrete.
         dropout_rate: Float between 0 and 1. Fraction of the input units to drop.