Skip to content

Commit

Permalink
Merge pull request #39 from vlasenkoalexey/master
Browse files Browse the repository at this point in the history
Making GraphDef Editor compatible with TensorFlow 2.x
  • Loading branch information
BryanCutler authored Dec 15, 2020
2 parents f9b5b76 + 9d10761 commit 8935fe2
Show file tree
Hide file tree
Showing 17 changed files with 54 additions and 19 deletions.
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,5 +126,22 @@ $ pip install ./graph_def_editor
`test.out` at the root of the project.


## TensorFlow versions compartibility

GraphDef Editor is fully supported for TensorFlow versions 1.14.x and 1.15.x.
For TensorFlow 2.x some transforms might not work.

To execute tests for specific TensorFlow version run following comand from the repository root:
```sh
docker run -v ${PWD}:/v -w /v tensorflow/tensorflow:<version>[-py3] bash -c "pip3 install -U pytest && pytest"
```

Pre 2.2.0 TensorFlow versions have -py3 suffix indicating that Python3 should be used.

To execute specific test:
```sh
docker run -v ${PWD}:/v -w /v tensorflow/tensorflow:<version>[-py3] python -m tests.transform_test
```



2 changes: 1 addition & 1 deletion graph_def_editor/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from distutils import dir_util
import os
from six import string_types
import tensorflow as tf
import tensorflow.compat.v1 as tf
import sys
if sys.version >= '3':
from typing import Tuple, Dict, FrozenSet, Iterable, Union, Set, Any
Expand Down
2 changes: 1 addition & 1 deletion graph_def_editor/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorflow.compat.v1 as tf
import sys
if sys.version >= '3':
from typing import Tuple, List, Iterable, Any, AbstractSet
Expand Down
2 changes: 1 addition & 1 deletion graph_def_editor/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from __future__ import print_function

import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf
import sys
if sys.version >= '3':
from typing import Tuple, Dict, Iterable, Union, Callable, Any
Expand Down
1 change: 0 additions & 1 deletion graph_def_editor/subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import six
from six import iteritems
from six import StringIO
import tensorflow as tf

from graph_def_editor import select, util
from graph_def_editor import graph as gde_graph
Expand Down
2 changes: 1 addition & 1 deletion graph_def_editor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================

import tensorflow as tf
import tensorflow.compat.v1 as tf
import sys
if sys.version >= '3':
from typing import AbstractSet
Expand Down
2 changes: 1 addition & 1 deletion graph_def_editor/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from six import iterkeys
from six import string_types
from six import StringIO
import tensorflow as tf
import tensorflow.compat.v1 as tf
import sys
if sys.version >= '3':
from typing import Iterable
Expand Down
2 changes: 1 addition & 1 deletion graph_def_editor/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

import numpy as np
from six import iteritems, string_types
import tensorflow as tf
import tensorflow.compat.v1 as tf

from graph_def_editor import graph, node, tensor

Expand Down
3 changes: 2 additions & 1 deletion tests/edit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from __future__ import print_function

import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()
import unittest

import graph_def_editor as gde
Expand Down
4 changes: 2 additions & 2 deletions tests/graph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
"""

import unittest
import tensorflow as tf
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()
import shutil
import tempfile

Expand Down
4 changes: 3 additions & 1 deletion tests/match_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

import unittest

import graph_def_editor as gde
Expand Down
4 changes: 3 additions & 1 deletion tests/reroute_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
from __future__ import print_function

import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

import unittest

import graph_def_editor as gde
Expand Down
8 changes: 7 additions & 1 deletion tests/rewrite_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import shutil
import tempfile
import unittest
import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

import numpy as np

import graph_def_editor as gde
Expand Down Expand Up @@ -759,3 +761,7 @@ def test_fold_batch_norms_up_fused_relu6(self):
# Make sure the rewrite happened
for n in g.nodes:
self.assertNotEqual(n.op_type, "FusedBatchNorm")


if __name__ == "__main__":
unittest.main()
4 changes: 3 additions & 1 deletion tests/select_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@

import re

import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

import unittest

import graph_def_editor as gde
Expand Down
4 changes: 3 additions & 1 deletion tests/subgraph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

import unittest

import graph_def_editor as gde
Expand Down
4 changes: 3 additions & 1 deletion tests/transform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
import functools
import numpy as np

import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

import unittest

import graph_def_editor as gde
Expand Down
8 changes: 5 additions & 3 deletions tests/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

import unittest

import graph_def_editor as gde
Expand Down Expand Up @@ -59,7 +61,7 @@ def test_unique_graph(self):
a0, b0, a1, b1 = (g0["a"], g0["b"], g1["a"], g1["b"])

print("g0['a'] returns {} (type {})".format(g0['a'], type(g0['a'])))

# Same graph, should be fine.
self.assertIsNone(gde.util.check_graphs(a0, b0))
# Two different graphs, should assert.
Expand Down Expand Up @@ -198,4 +200,4 @@ def test_identity(self):


if __name__ == "__main__":
unittest.TestCase.main()
unittest.main()

0 comments on commit 8935fe2

Please sign in to comment.