Enable eager execution while TensorFlow 2 model is loaded (#1945)
authorRoman Kazantsev <roman.kazantsev@intel.com>
Wed, 26 Aug 2020 14:16:58 +0000 (17:16 +0300)
committerGitHub <noreply@github.com>
Wed, 26 Aug 2020 14:16:58 +0000 (17:16 +0300)
model-optimizer/mo/front/tf/loader.py

index 8684374..2a227b7 100644 (file)
@@ -226,6 +226,8 @@ def load_tf_graph_def(graph_file_name: str = "", is_binary: bool = True, checkpo
         if model_dir:
             # saved model directory
             try:
+                # enable eager execution temporarily while TensorFlow 2 model is being loaded
+                tf_v1.enable_eager_execution()
                 # code to extract GraphDef for TF 2.0 SavedModel format
                 # tf.saved_model.load function throws TypeError for TF 1.x SavedModel format in case TF 1.x installed
                 imported = tf.saved_model.load(model_dir, saved_model_tags) # pylint: disable=E1120
@@ -233,8 +235,12 @@ def load_tf_graph_def(graph_file_name: str = "", is_binary: bool = True, checkpo
                 concrete_func = imported.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
                 frozen_func = convert_variables_to_constants_v2(concrete_func, lower_control_flow=False) # pylint: disable=E1123
                 graph_def = frozen_func.graph.as_graph_def(add_shapes=True)
+                # disable eager execution since next steps are executed with a graph in non-eager mode
+                tf_v1.disable_eager_execution()
                 return graph_def, variables_values
             except (TypeError, KeyError):
+                # disable eager execution since TensorFlow 1 model is handled
+                tf_v1.disable_eager_execution()
                 # code to extract GraphDef for TF 1.0 SavedModel format
                 tags = saved_model_tags if saved_model_tags is not None else [tf_v1.saved_model.tag_constants.SERVING]
                 with tf_v1.Session() as sess: