diff --git a/jack/debug/__init__.py b/jack/debug/__init__.py new file mode 100644 index 00000000..40a96afc --- /dev/null +++ b/jack/debug/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/jack/debug/base.py b/jack/debug/base.py new file mode 100644 index 00000000..6220c003 --- /dev/null +++ b/jack/debug/base.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- + +import tensorflow as tf + +import logging + +logger = logging.getLogger(__name__) + + +def test_update(feed_dict, train_op): + session = tf.Session() + session.run(tf.global_variables_initializer()) + + before = session.run(tf.trainable_variables()) + session.run(train_op, feed_dict=feed_dict) + after = session.run(tf.trainable_variables()) + + res = False + for b, a in zip(before, after): + # Check if anything changed + res |= (b != a).any() + + return res