Actually expose smart_cond and smart_constant_value in tf.contrib.framework
authorSkye Wanderman-Milne <skyewm@google.com>
Mon, 26 Feb 2018 19:43:14 +0000 (11:43 -0800)
committerGunhan Gulsoy <gunan@google.com>
Tue, 27 Feb 2018 22:33:33 +0000 (14:33 -0800)
Also moves these methods into their own file in python/framework. This avoids further bloating control_flow_ops.py and makes the BUILD deps easier for a future change I'm working on.

PiperOrigin-RevId: 187055501

tensorflow/contrib/framework/BUILD
tensorflow/contrib/framework/__init__.py
tensorflow/python/BUILD
tensorflow/python/framework/smart_cond.py [new file with mode: 0644]
tensorflow/python/framework/smart_cond_test.py [new file with mode: 0644]
tensorflow/python/layers/utils.py
tensorflow/python/ops/control_flow_ops.py
tensorflow/python/ops/control_flow_ops_test.py

index 1accb31..50868c6 100644 (file)
@@ -63,6 +63,7 @@ tf_custom_op_py_library(
         "//tensorflow/python:platform",
         "//tensorflow/python:pywrap_tensorflow",
         "//tensorflow/python:script_ops",
+        "//tensorflow/python:smart_cond",
         "//tensorflow/python:sparse_tensor",
         "//tensorflow/python:state_ops",
         "//tensorflow/python:state_ops_gen",
index deeb5be..8063250 100644 (file)
@@ -87,6 +87,9 @@ See the @{$python/contrib.framework} guide.
 
 @@get_placeholders
 
+@@smart_cond
+@@smart_constant_value
+
 @@CriticalSection
 
 @@BoundedTensorSpec
@@ -104,10 +107,10 @@ from tensorflow.contrib.framework.python.ops import *
 
 from tensorflow.python.framework.ops import prepend_name_scope
 from tensorflow.python.framework.ops import strip_name_scope
+from tensorflow.python.framework.smart_cond import smart_cond
+from tensorflow.python.framework.smart_cond import smart_constant_value
 from tensorflow.python.framework.tensor_spec import BoundedTensorSpec
 from tensorflow.python.framework.tensor_spec import TensorSpec
-from tensorflow.python.ops.control_flow_ops import smart_cond
-from tensorflow.python.ops.control_flow_ops import smart_constant_value
 from tensorflow.python.util.all_util import remove_undocumented
 
 _allowed_symbols = ['nest']
index 4c8c735..b0cb48c 100644 (file)
@@ -766,6 +766,31 @@ py_library(
 )
 
 py_library(
+    name = "smart_cond",
+    srcs = ["framework/smart_cond.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":control_flow_ops",
+        ":tensor_util",
+    ],
+)
+
+py_test(
+    name = "smart_cond_test",
+    size = "small",
+    srcs = ["framework/smart_cond_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":client_testlib",
+        ":constant_op",
+        ":framework_ops",
+        ":math_ops",
+        ":session",
+        ":smart_cond",
+    ],
+)
+
+py_library(
     name = "sparse_tensor",
     srcs = ["framework/sparse_tensor.py"],
     srcs_version = "PY2AND3",
@@ -4091,6 +4116,7 @@ py_library(
         ":control_flow_ops",
         ":framework_for_generated_wrappers",
         ":platform",
+        ":smart_cond",
         ":tensor_util",
         ":util",
         ":variable_scope",
diff --git a/tensorflow/python/framework/smart_cond.py b/tensorflow/python/framework/smart_cond.py
new file mode 100644 (file)
index 0000000..f97bb01
--- /dev/null
@@ -0,0 +1,79 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""smart_cond and related utilties."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import control_flow_ops
+
+
+def smart_cond(pred, true_fn=None, false_fn=None, name=None):
+  """Return either `true_fn()` if predicate `pred` is true else `false_fn()`.
+
+  If `pred` is a bool or has a constant value, we return either `true_fn()`
+  or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both.
+
+  Arguments:
+    pred: A scalar determining whether to return the result of `true_fn` or
+      `false_fn`.
+    true_fn: The callable to be performed if pred is true.
+    false_fn: The callable to be performed if pred is false.
+    name: Optional name prefix when using `tf.cond`.
+
+  Returns:
+    Tensors returned by the call to either `true_fn` or `false_fn`.
+
+  Raises:
+    TypeError: If `true_fn` or `false_fn` is not callable.
+  """
+  if not callable(true_fn):
+    raise TypeError("`true_fn` must be callable.")
+  if not callable(false_fn):
+    raise TypeError("`false_fn` must be callable.")
+
+  pred_value = smart_constant_value(pred)
+  if pred_value is not None:
+    if pred_value:
+      return true_fn()
+    else:
+      return false_fn()
+  else:
+    return control_flow_ops.cond(pred, true_fn=true_fn, false_fn=false_fn,
+                                 name=name)
+
+
+def smart_constant_value(pred):
+  """Return the bool value for `pred`, or None if `pred` had a dynamic value.
+
+  Arguments:
+    pred: A scalar, either a Python bool or tensor.
+
+  Returns:
+    True or False if `pred` has a constant boolean value, None otherwise.
+
+  Raises:
+    TypeError: If `pred` is not a Tensor or bool.
+  """
+  if isinstance(pred, bool):
+    pred_value = pred
+  elif isinstance(pred, ops.Tensor):
+    pred_value = tensor_util.constant_value(pred)
+  else:
+    raise TypeError("`pred` must be a Tensor or a Python bool.")
+  return pred_value
diff --git a/tensorflow/python/framework/smart_cond_test.py b/tensorflow/python/framework/smart_cond_test.py
new file mode 100644 (file)
index 0000000..b682506
--- /dev/null
@@ -0,0 +1,66 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import smart_cond
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import googletest
+
+
+@test_util.with_c_api
+class SmartCondTest(test_util.TensorFlowTestCase):
+
+  def testSmartCondTrue(self):
+    with ops.Graph().as_default():
+      with session.Session():
+        x = constant_op.constant(2)
+        y = constant_op.constant(5)
+        z = smart_cond.smart_cond(True, lambda: math_ops.multiply(x, 16),
+                                  lambda: math_ops.multiply(y, 5))
+        self.assertEqual(z.eval(), 32)
+
+  def testSmartCondFalse(self):
+    with ops.Graph().as_default():
+      with session.Session():
+        x = constant_op.constant(4)
+        y = constant_op.constant(3)
+        z = smart_cond.smart_cond(False, lambda: math_ops.multiply(x, 16),
+                                  lambda: math_ops.multiply(y, 3))
+        self.assertEqual(z.eval(), 9)
+
+  def testSmartCondMissingArg1(self):
+    with ops.Graph().as_default():
+      with session.Session():
+        x = constant_op.constant(1)
+        with self.assertRaises(TypeError):
+          smart_cond.smart_cond(True, false_fn=lambda: x)
+
+  def testSmartCondMissingArg2(self):
+    with ops.Graph().as_default():
+      with session.Session():
+        x = constant_op.constant(1)
+        with self.assertRaises(TypeError):
+          smart_cond.smart_cond(True, lambda: x)
+
+
+if __name__ == "__main__":
+  googletest.main()
index 484c6fc..3b156c3 100644 (file)
@@ -24,6 +24,7 @@ from tensorflow.python.eager import context
 from tensorflow.python.ops import variables
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import smart_cond as smart_module
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.util import nest
 
@@ -201,7 +202,7 @@ def smart_cond(pred, true_fn=None, false_fn=None, name=None):
   if isinstance(pred, variables.Variable):
     return control_flow_ops.cond(
         pred, true_fn=true_fn, false_fn=false_fn, name=name)
-  return control_flow_ops.smart_cond(
+  return smart_module.smart_cond(
       pred, true_fn=true_fn, false_fn=false_fn, name=name)
 
 
@@ -228,7 +229,7 @@ def constant_value(pred):
 
   if isinstance(pred, variables.Variable):
     return None
-  return control_flow_ops.smart_constant_value(pred)
+  return smart_module.smart_constant_value(pred)
 
 
 def object_list_uid(object_list):
index 8218e60..152578c 100644 (file)
@@ -23,7 +23,6 @@ See the @{$python/control_flow_ops} guide.
 @@no_op
 @@count_up_to
 @@cond
-@@smart_cond
 @@case
 @@while_loop
 @@logical_and
@@ -2128,61 +2127,6 @@ def cond(pred,
 # pylint: enable=redefined-outer-name
 
 
-def smart_cond(pred, true_fn=None, false_fn=None, name=None):
-  """Return either `true_fn()` if predicate `pred` is true else `false_fn()`.
-
-  If `pred` is a bool or has a constant value, we return either `true_fn()`
-  or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both.
-
-  Arguments:
-    pred: A scalar determining whether to return the result of `true_fn` or
-      `false_fn`.
-    true_fn: The callable to be performed if pred is true.
-    false_fn: The callable to be performed if pred is false.
-    name: Optional name prefix when using `tf.cond`.
-
-  Returns:
-    Tensors returned by the call to either `true_fn` or `false_fn`.
-
-  Raises:
-    TypeError: If `true_fn` or `false_fn` is not callable.
-  """
-  if not callable(true_fn):
-    raise TypeError("`true_fn` must be callable.")
-  if not callable(false_fn):
-    raise TypeError("`false_fn` must be callable.")
-
-  pred_value = smart_constant_value(pred)
-  if pred_value is not None:
-    if pred_value:
-      return true_fn()
-    else:
-      return false_fn()
-  else:
-    return cond(pred, true_fn=true_fn, false_fn=false_fn, name=name)
-
-
-def smart_constant_value(pred):
-  """Return the bool value for `pred`, or None if `pred` had a dynamic value.
-
-  Arguments:
-    pred: A scalar, either a Python bool or tensor.
-
-  Returns:
-    True or False if `pred` has a constant boolean value, None otherwise.
-
-  Raises:
-    TypeError: If `pred` is not a Tensor or bool.
-  """
-  if isinstance(pred, bool):
-    pred_value = pred
-  elif isinstance(pred, ops.Tensor):
-    pred_value = tensor_util.constant_value(pred)
-  else:
-    raise TypeError("`pred` must be a Tensor or a Python bool.")
-  return pred_value
-
-
 def _resource_safe_shape(t):
   """Returns the shape of t or the variable it points to."""
   if t.dtype == dtypes.resource:
index adc8c51..f22f305 100644 (file)
@@ -350,42 +350,6 @@ class SwitchTestCase(test_util.TensorFlowTestCase):
 
 
 @test_util.with_c_api
-class SmartCondTest(test_util.TensorFlowTestCase):
-
-  def testSmartCondTrue(self):
-    with ops.Graph().as_default():
-      with session.Session():
-        x = constant_op.constant(2)
-        y = constant_op.constant(5)
-        z = control_flow_ops.smart_cond(True, lambda: math_ops.multiply(x, 16),
-                                        lambda: math_ops.multiply(y, 5))
-        self.assertEqual(z.eval(), 32)
-
-  def testSmartCondFalse(self):
-    with ops.Graph().as_default():
-      with session.Session():
-        x = constant_op.constant(4)
-        y = constant_op.constant(3)
-        z = control_flow_ops.smart_cond(False, lambda: math_ops.multiply(x, 16),
-                                        lambda: math_ops.multiply(y, 3))
-        self.assertEqual(z.eval(), 9)
-
-  def testSmartCondMissingArg1(self):
-    with ops.Graph().as_default():
-      with session.Session():
-        x = constant_op.constant(1)
-        with self.assertRaises(TypeError):
-          control_flow_ops.smart_cond(True, false_fn=lambda: x)
-
-  def testSmartCondMissingArg2(self):
-    with ops.Graph().as_default():
-      with session.Session():
-        x = constant_op.constant(1)
-        with self.assertRaises(TypeError):
-          control_flow_ops.smart_cond(True, lambda: x)
-
-
-@test_util.with_c_api
 class CondTest(test_util.TensorFlowTestCase):
 
   def testCondTrue(self):