"//tensorflow/python:platform",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:script_ops",
+ "//tensorflow/python:smart_cond",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:state_ops",
"//tensorflow/python:state_ops_gen",
@@get_placeholders
+@@smart_cond
+@@smart_constant_value
+
@@CriticalSection
@@BoundedTensorSpec
from tensorflow.python.framework.ops import prepend_name_scope
from tensorflow.python.framework.ops import strip_name_scope
+from tensorflow.python.framework.smart_cond import smart_cond
+from tensorflow.python.framework.smart_cond import smart_constant_value
from tensorflow.python.framework.tensor_spec import BoundedTensorSpec
from tensorflow.python.framework.tensor_spec import TensorSpec
-from tensorflow.python.ops.control_flow_ops import smart_cond
-from tensorflow.python.ops.control_flow_ops import smart_constant_value
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = ['nest']
],
)
+py_library(
+ name = "smart_cond",
+ srcs = ["framework/smart_cond.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":control_flow_ops",
+ ":tensor_util",
+ ],
+)
+
+py_test(
+ name = "smart_cond_test",
+ size = "small",
+ srcs = ["framework/smart_cond_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":client_testlib",
+ ":constant_op",
+ ":framework_ops",
+ ":math_ops",
+ ":session",
+ ":smart_cond",
+ ],
+)
+
py_library(
name = "sparse_tensor",
srcs = ["framework/sparse_tensor.py"],
":control_flow_ops",
":framework_for_generated_wrappers",
":platform",
+ ":smart_cond",
":tensor_util",
":util",
":variable_scope",
--- /dev/null
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""smart_cond and related utilties."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import control_flow_ops
+
+
+def smart_cond(pred, true_fn=None, false_fn=None, name=None):
+ """Return either `true_fn()` if predicate `pred` is true else `false_fn()`.
+
+ If `pred` is a bool or has a constant value, we return either `true_fn()`
+ or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both.
+
+ Arguments:
+ pred: A scalar determining whether to return the result of `true_fn` or
+ `false_fn`.
+ true_fn: The callable to be performed if pred is true.
+ false_fn: The callable to be performed if pred is false.
+ name: Optional name prefix when using `tf.cond`.
+
+ Returns:
+ Tensors returned by the call to either `true_fn` or `false_fn`.
+
+ Raises:
+ TypeError: If `true_fn` or `false_fn` is not callable.
+ """
+ if not callable(true_fn):
+ raise TypeError("`true_fn` must be callable.")
+ if not callable(false_fn):
+ raise TypeError("`false_fn` must be callable.")
+
+ pred_value = smart_constant_value(pred)
+ if pred_value is not None:
+ if pred_value:
+ return true_fn()
+ else:
+ return false_fn()
+ else:
+ return control_flow_ops.cond(pred, true_fn=true_fn, false_fn=false_fn,
+ name=name)
+
+
+def smart_constant_value(pred):
+ """Return the bool value for `pred`, or None if `pred` had a dynamic value.
+
+ Arguments:
+ pred: A scalar, either a Python bool or tensor.
+
+ Returns:
+ True or False if `pred` has a constant boolean value, None otherwise.
+
+ Raises:
+ TypeError: If `pred` is not a Tensor or bool.
+ """
+ if isinstance(pred, bool):
+ pred_value = pred
+ elif isinstance(pred, ops.Tensor):
+ pred_value = tensor_util.constant_value(pred)
+ else:
+ raise TypeError("`pred` must be a Tensor or a Python bool.")
+ return pred_value
--- /dev/null
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import smart_cond
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import googletest
+
+
+@test_util.with_c_api
+class SmartCondTest(test_util.TensorFlowTestCase):
+
+ def testSmartCondTrue(self):
+ with ops.Graph().as_default():
+ with session.Session():
+ x = constant_op.constant(2)
+ y = constant_op.constant(5)
+ z = smart_cond.smart_cond(True, lambda: math_ops.multiply(x, 16),
+ lambda: math_ops.multiply(y, 5))
+ self.assertEqual(z.eval(), 32)
+
+ def testSmartCondFalse(self):
+ with ops.Graph().as_default():
+ with session.Session():
+ x = constant_op.constant(4)
+ y = constant_op.constant(3)
+ z = smart_cond.smart_cond(False, lambda: math_ops.multiply(x, 16),
+ lambda: math_ops.multiply(y, 3))
+ self.assertEqual(z.eval(), 9)
+
+ def testSmartCondMissingArg1(self):
+ with ops.Graph().as_default():
+ with session.Session():
+ x = constant_op.constant(1)
+ with self.assertRaises(TypeError):
+ smart_cond.smart_cond(True, false_fn=lambda: x)
+
+ def testSmartCondMissingArg2(self):
+ with ops.Graph().as_default():
+ with session.Session():
+ x = constant_op.constant(1)
+ with self.assertRaises(TypeError):
+ smart_cond.smart_cond(True, lambda: x)
+
+
+if __name__ == "__main__":
+ googletest.main()
from tensorflow.python.ops import variables
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.framework import ops
+from tensorflow.python.framework import smart_cond as smart_module
from tensorflow.python.framework import tensor_util
from tensorflow.python.util import nest
if isinstance(pred, variables.Variable):
return control_flow_ops.cond(
pred, true_fn=true_fn, false_fn=false_fn, name=name)
- return control_flow_ops.smart_cond(
+ return smart_module.smart_cond(
pred, true_fn=true_fn, false_fn=false_fn, name=name)
if isinstance(pred, variables.Variable):
return None
- return control_flow_ops.smart_constant_value(pred)
+ return smart_module.smart_constant_value(pred)
def object_list_uid(object_list):
@@no_op
@@count_up_to
@@cond
-@@smart_cond
@@case
@@while_loop
@@logical_and
# pylint: enable=redefined-outer-name
-def smart_cond(pred, true_fn=None, false_fn=None, name=None):
- """Return either `true_fn()` if predicate `pred` is true else `false_fn()`.
-
- If `pred` is a bool or has a constant value, we return either `true_fn()`
- or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both.
-
- Arguments:
- pred: A scalar determining whether to return the result of `true_fn` or
- `false_fn`.
- true_fn: The callable to be performed if pred is true.
- false_fn: The callable to be performed if pred is false.
- name: Optional name prefix when using `tf.cond`.
-
- Returns:
- Tensors returned by the call to either `true_fn` or `false_fn`.
-
- Raises:
- TypeError: If `true_fn` or `false_fn` is not callable.
- """
- if not callable(true_fn):
- raise TypeError("`true_fn` must be callable.")
- if not callable(false_fn):
- raise TypeError("`false_fn` must be callable.")
-
- pred_value = smart_constant_value(pred)
- if pred_value is not None:
- if pred_value:
- return true_fn()
- else:
- return false_fn()
- else:
- return cond(pred, true_fn=true_fn, false_fn=false_fn, name=name)
-
-
-def smart_constant_value(pred):
- """Return the bool value for `pred`, or None if `pred` had a dynamic value.
-
- Arguments:
- pred: A scalar, either a Python bool or tensor.
-
- Returns:
- True or False if `pred` has a constant boolean value, None otherwise.
-
- Raises:
- TypeError: If `pred` is not a Tensor or bool.
- """
- if isinstance(pred, bool):
- pred_value = pred
- elif isinstance(pred, ops.Tensor):
- pred_value = tensor_util.constant_value(pred)
- else:
- raise TypeError("`pred` must be a Tensor or a Python bool.")
- return pred_value
-
-
def _resource_safe_shape(t):
"""Returns the shape of t or the variable it points to."""
if t.dtype == dtypes.resource:
self.assertEquals(grad_x_false.eval(), 0.)
-@test_util.with_c_api
-class SmartCondTest(test_util.TensorFlowTestCase):
-
- def testSmartCondTrue(self):
- with ops.Graph().as_default():
- with session.Session():
- x = constant_op.constant(2)
- y = constant_op.constant(5)
- z = control_flow_ops.smart_cond(True, lambda: math_ops.multiply(x, 16),
- lambda: math_ops.multiply(y, 5))
- self.assertEqual(z.eval(), 32)
-
- def testSmartCondFalse(self):
- with ops.Graph().as_default():
- with session.Session():
- x = constant_op.constant(4)
- y = constant_op.constant(3)
- z = control_flow_ops.smart_cond(False, lambda: math_ops.multiply(x, 16),
- lambda: math_ops.multiply(y, 3))
- self.assertEqual(z.eval(), 9)
-
- def testSmartCondMissingArg1(self):
- with ops.Graph().as_default():
- with session.Session():
- x = constant_op.constant(1)
- with self.assertRaises(TypeError):
- control_flow_ops.smart_cond(True, false_fn=lambda: x)
-
- def testSmartCondMissingArg2(self):
- with ops.Graph().as_default():
- with session.Session():
- x = constant_op.constant(1)
- with self.assertRaises(TypeError):
- control_flow_ops.smart_cond(True, lambda: x)
-
-
@test_util.with_c_api
class CondTest(test_util.TensorFlowTestCase):