Add TF_TryEvaluateConstant to the C API and have smart_cond call it.
authorSkye Wanderman-Milne <skyewm@google.com>
Tue, 6 Mar 2018 21:38:56 +0000 (13:38 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 6 Mar 2018 21:42:58 +0000 (13:42 -0800)
This effectively plumbs EvaluateConstantTensor to smart_cond. This makes smart_cond even smarter by trying to evaluate the predicate
if it can't statically infer it.

PiperOrigin-RevId: 188073244

tensorflow/c/c_api.cc
tensorflow/c/c_api.h
tensorflow/python/client/tf_session.i
tensorflow/python/client/tf_session_helper.cc
tensorflow/python/client/tf_session_helper.h
tensorflow/python/framework/smart_cond.py
tensorflow/python/framework/smart_cond_test.py

index 85f1d16..3d0e886 100644 (file)
@@ -30,6 +30,7 @@ limitations under the License.
 #endif
 #include "tensorflow/c/c_api_internal.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/eval_const_tensor.h"
 #include "tensorflow/core/common_runtime/shape_refiner.h"
 #include "tensorflow/core/framework/allocation_description.pb.h"
 #include "tensorflow/core/framework/log_memory.h"
@@ -73,6 +74,7 @@ using tensorflow::NodeBuilder;
 using tensorflow::NodeDef;
 using tensorflow::OpDef;
 using tensorflow::OpRegistry;
+using tensorflow::OutputTensor;
 using tensorflow::PartialTensorShape;
 using tensorflow::RunMetadata;
 using tensorflow::RunOptions;
@@ -2682,6 +2684,24 @@ void TF_SessionPRun(TF_Session* session, const char* handle,
                 output_values, target_names, nullptr, status);
 }
 
+unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
+                                     TF_Tensor** result, TF_Status* status) {
+  *result = nullptr;
+  mutex_lock l(graph->mu);
+  OutputTensor tensor(&output.oper->node, output.index);
+  bool evaluated;
+  Tensor result_tensor;
+  status->status = EvaluateConstantTensor(
+      tensor, graph->refiner, *graph->graph.op_registry(),
+      graph->graph.versions().producer(), &evaluated, &result_tensor);
+  if (evaluated) {
+    DCHECK(status->status.ok());
+    *result = TF_TensorFromTensor(result_tensor, status);
+    if (!status->status.ok()) evaluated = false;
+  }
+  return evaluated;
+}
+
 TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) {
   tensorflow::OpList op_list;
   if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) {
index ad592ef..b32f574 100644 (file)
@@ -1275,13 +1275,22 @@ TF_CAPI_EXPORT extern void TF_FunctionGetAttrValueProto(
 // Deleting a function does not remove it from any graphs it was copied to.
 TF_CAPI_EXPORT extern void TF_DeleteFunction(TF_Function* func);
 
+// Attempts to evaluate `output`. This will only be possible if `output` doesn't
+// depend on any graph inputs (this function is safe to call if this isn't the
+// case though).
+//
+// If the evaluation is successful, this function returns true and `output`s
+// value is returned in `result`. Otherwise returns false. An error status is
+// returned if something is wrong with the graph or input. Note that this may
+// return false even if no error status is set.
+TF_CAPI_EXPORT extern unsigned char TF_TryEvaluateConstant(TF_Graph* graph,
+                                                           TF_Output output,
+                                                           TF_Tensor** result,
+                                                           TF_Status* status);
+
 // TODO(josh11b): Register OpDef, available to all operations added
 // to this graph.
 
-// The following two may both benefit from a subgraph-definition API
-// that re-uses most of the graph-definition API.
-// TODO(andydavis): Add functions to a graph.
-
 // --------------------------------------------------------------------------
 // API for driving Graph execution.
 
index f305cd2..53557ac 100644 (file)
@@ -720,6 +720,8 @@ def TF_Reset(target, containers=None, config=None):
 }
 
 %unignore SetRequireShapeInferenceFns;
+%unignore TF_TryEvaluateConstant_wrapper;
+%noexception TF_TryEvaluateConstant_wrapper;
 
 %include "tensorflow/python/client/tf_session_helper.h"
 
index 361dbc2..a8ab917 100644 (file)
@@ -493,4 +493,19 @@ std::vector<string> TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
   return input_strs;
 }
 
+PyObject* TF_TryEvaluateConstant_wrapper(TF_Graph* graph, TF_Output output,
+                                         TF_Status* status) {
+  TF_Tensor* result_tensor;
+  bool evaluated =
+      TF_TryEvaluateConstant(graph, output, &result_tensor, status);
+  if (!evaluated || TF_GetCode(status) != TF_OK) Py_RETURN_NONE;
+
+  Safe_TF_TensorPtr safe_result_tensor(result_tensor);
+  PyObject* out;
+  Status s = TF_TensorToPyArray(std::move(safe_result_tensor), &out);
+  Set_TF_Status_from_Status(status, s);
+  if (!s.ok()) Py_RETURN_NONE;
+  return out;
+}
+
 }  // namespace tensorflow
index 29d5b28..83318dc 100644 (file)
@@ -213,6 +213,11 @@ std::vector<int64_t> TF_GraphGetTensorShape_wrapper(TF_Graph* graph,
 std::vector<string> TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
     TF_ImportGraphDefResults* results);
 
+// If evaluation was possible, returns the numpy ndarray of the evaluated
+// result. Otherwise returns None.
+PyObject* TF_TryEvaluateConstant_wrapper(TF_Graph* graph, TF_Output output,
+                                         TF_Status* status);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_
index f97bb01..4f2f1db 100644 (file)
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python import pywrap_tensorflow as c_api
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import control_flow_ops
@@ -74,6 +76,16 @@ def smart_constant_value(pred):
     pred_value = pred
   elif isinstance(pred, ops.Tensor):
     pred_value = tensor_util.constant_value(pred)
+    # TODO(skyewm): consider folding this into tensor_util.constant_value when
+    # _USE_C_API is removed (there may be performance and correctness bugs, so I
+    # wanted to limit the change hidden behind _USE_C_API).
+    # pylint: disable=protected-access
+    if pred_value is None and ops._USE_C_API:
+      with errors.raise_exception_on_not_ok_status() as status:
+        pred_value = c_api.TF_TryEvaluateConstant_wrapper(
+            pred.graph._c_graph, pred._as_tf_output(), status)
+    # pylint: enable=protected-access
+
   else:
     raise TypeError("`pred` must be a Tensor or a Python bool.")
   return pred_value
index b682506..3070355 100644 (file)
@@ -19,9 +19,11 @@ from __future__ import print_function
 
 from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import smart_cond
 from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import googletest
 
@@ -29,7 +31,7 @@ from tensorflow.python.platform import googletest
 @test_util.with_c_api
 class SmartCondTest(test_util.TensorFlowTestCase):
 
-  def testSmartCondTrue(self):
+  def testTrue(self):
     with ops.Graph().as_default():
       with session.Session():
         x = constant_op.constant(2)
@@ -38,7 +40,7 @@ class SmartCondTest(test_util.TensorFlowTestCase):
                                   lambda: math_ops.multiply(y, 5))
         self.assertEqual(z.eval(), 32)
 
-  def testSmartCondFalse(self):
+  def testFalse(self):
     with ops.Graph().as_default():
       with session.Session():
         x = constant_op.constant(4)
@@ -47,14 +49,48 @@ class SmartCondTest(test_util.TensorFlowTestCase):
                                   lambda: math_ops.multiply(y, 3))
         self.assertEqual(z.eval(), 9)
 
-  def testSmartCondMissingArg1(self):
+  def testUnknown(self):
+    with ops.Graph().as_default():
+      with session.Session():
+        x = array_ops.placeholder(dtype=dtypes.int32)
+        y = smart_cond.smart_cond(x > 0, lambda: constant_op.constant(1),
+                                  lambda: constant_op.constant(2))
+        self.assertEqual(y.eval(feed_dict={x: 1}), 1)
+        self.assertEqual(y.eval(feed_dict={x: -1}), 2)
+
+  def testEval(self):
+    # Constant expression evaluation only works with the C API enabled.
+    if not ops._USE_C_API: return
+
+    with ops.Graph().as_default():
+      with session.Session():
+        x = constant_op.constant(1)
+        y = constant_op.constant(2)
+        # x * y > 0 can be evaluated at graph construction time, so the false
+        # branch shouldn't be evaluated at all.
+        def raise_exception():
+          raise RuntimeError("did not expect to be called")
+        z = smart_cond.smart_cond(x * y > 0, lambda: constant_op.constant(1),
+                                  raise_exception)
+        self.assertEqual(z.eval(feed_dict={x: 1}), 1)
+
+  def testPlaceholderWithDefault(self):
+    with ops.Graph().as_default():
+      with session.Session():
+        x = array_ops.placeholder_with_default(1, shape=())
+        y = smart_cond.smart_cond(x > 0, lambda: constant_op.constant(1),
+                                  lambda: constant_op.constant(2))
+        self.assertEqual(y.eval(), 1)
+        self.assertEqual(y.eval(feed_dict={x: -1}), 2)
+
+  def testMissingArg1(self):
     with ops.Graph().as_default():
       with session.Session():
         x = constant_op.constant(1)
         with self.assertRaises(TypeError):
           smart_cond.smart_cond(True, false_fn=lambda: x)
 
-  def testSmartCondMissingArg2(self):
+  def testMissingArg2(self):
     with ops.Graph().as_default():
       with session.Session():
         x = constant_op.constant(1)