diff --git a/foolbox/models/tensorflow.py b/foolbox/models/tensorflow.py index 730b413f..ab0c0db6 100644 --- a/foolbox/models/tensorflow.py +++ b/foolbox/models/tensorflow.py @@ -50,6 +50,8 @@ def __init__( self._created_session = True else: self._created_session = False + assert session.graph == images.graph, \ + 'The default session uses the wrong graph' with session.graph.as_default(): self._session = session