Make tf.py_func and tf.smart_cond play better with eager mode.
authorAkshay Modi <nareshmodi@google.com>
Sat, 17 Feb 2018 00:40:02 +0000 (16:40 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 17 Feb 2018 00:44:38 +0000 (16:44 -0800)
PiperOrigin-RevId: 186063941

tensorflow/python/layers/utils.py
tensorflow/python/ops/script_ops.py

index 1bbf4e6dffd3415ba246e26cd92923df8116edab..1195284024fad488bf1792a6e017a5bb6271d214 100644 (file)
@@ -20,6 +20,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.eager import context
 from tensorflow.python.ops import variables
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.framework import ops
@@ -201,6 +202,12 @@ def smart_cond(pred, fn1, fn2, name=None):
   if not callable(fn2):
     raise TypeError('`fn2` must be callable.')
 
+  if context.in_eager_mode():
+    if pred:
+      return fn1()
+    else:
+      return fn2()
+
   pred_value = constant_value(pred)
   if pred_value is not None:
     if pred_value:
index 0ba29cbf329e2f36c5788bde8acc4ef7f1fe6f74..c7e8c28efdbfc116a70e72b0a04a7bd946f0e37d 100644 (file)
@@ -33,6 +33,7 @@ from tensorflow.python.eager import context
 from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import gen_script_ops
+from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -318,6 +319,12 @@ def py_func(func, inp, Tout, stateful=True, name=None):
   Returns:
     A list of `Tensor` or a single `Tensor` which `func` computes.
   """
+  if context.in_eager_mode():
+    result = func(*[x.numpy() for x in inp])
+    result = nest.flatten(result)
+
+    return [x if x is None else ops.convert_to_tensor(x) for x in result]
+
   return _internal_py_func(
       func=func, inp=inp, Tout=Tout, stateful=stateful, eager=False, name=name)