from __future__ import division
from __future__ import print_function
+from tensorflow.python.eager import context
from tensorflow.python.ops import variables
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.framework import ops
if not callable(fn2):
raise TypeError('`fn2` must be callable.')
+ if context.in_eager_mode():
+ if pred:
+ return fn1()
+ else:
+ return fn2()
+
pred_value = constant_value(pred)
if pred_value is not None:
if pred_value:
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_script_ops
+from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
Returns:
A list of `Tensor` or a single `Tensor` which `func` computes.
"""
+ if context.in_eager_mode():
+ result = func(*[x.numpy() for x in inp])
+ result = nest.flatten(result)
+
+ return [x if x is None else ops.convert_to_tensor(x) for x in result]
+
return _internal_py_func(
func=func, inp=inp, Tout=Tout, stateful=stateful, eager=False, name=name)