#endif
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/eval_const_tensor.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/log_memory.h"
using tensorflow::NodeDef;
using tensorflow::OpDef;
using tensorflow::OpRegistry;
+using tensorflow::OutputTensor;
using tensorflow::PartialTensorShape;
using tensorflow::RunMetadata;
using tensorflow::RunOptions;
output_values, target_names, nullptr, status);
}
+unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
+ TF_Tensor** result, TF_Status* status) {
+ // Initialize the out-param so callers observe nullptr on every failure path.
+ *result = nullptr;
+ // Hold the graph lock while reading the graph/refiner state below.
+ mutex_lock l(graph->mu);
+ OutputTensor tensor(&output.oper->node, output.index);
+ bool evaluated;
+ Tensor result_tensor;
+ status->status = EvaluateConstantTensor(
+ tensor, graph->refiner, *graph->graph.op_registry(),
+ graph->graph.versions().producer(), &evaluated, &result_tensor);
+ // Only convert to a TF_Tensor when evaluation succeeded; a failed
+ // conversion downgrades the result to "not evaluated" so the caller
+ // never sees a true return with a bad status.
+ if (evaluated) {
+ DCHECK(status->status.ok());
+ *result = TF_TensorFromTensor(result_tensor, status);
+ if (!status->status.ok()) evaluated = false;
+ }
+ return evaluated;
+}
+
TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) {
tensorflow::OpList op_list;
if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) {
// Deleting a function does not remove it from any graphs it was copied to.
TF_CAPI_EXPORT extern void TF_DeleteFunction(TF_Function* func);
+// Attempts to evaluate `output`. This will only be possible if `output` doesn't
+// depend on any graph inputs (it is safe to call this function even when that
+// isn't the case).
+//
+// If the evaluation is successful, this function returns true and `output`'s
+// value is returned in `result`. Otherwise returns false. An error status is
+// returned if something is wrong with the graph or input. Note that this may
+// return false even if no error status is set.
+TF_CAPI_EXPORT extern unsigned char TF_TryEvaluateConstant(TF_Graph* graph,
+ TF_Output output,
+ TF_Tensor** result,
+ TF_Status* status);
+
// TODO(josh11b): Register OpDef, available to all operations added
// to this graph.
-// The following two may both benefit from a subgraph-definition API
-// that re-uses most of the graph-definition API.
-// TODO(andydavis): Add functions to a graph.
-
// --------------------------------------------------------------------------
// API for driving Graph execution.
}
%unignore SetRequireShapeInferenceFns;
+%unignore TF_TryEvaluateConstant_wrapper;
+%noexception TF_TryEvaluateConstant_wrapper;
%include "tensorflow/python/client/tf_session_helper.h"
return input_strs;
}
+PyObject* TF_TryEvaluateConstant_wrapper(TF_Graph* graph, TF_Output output,
+ TF_Status* status) {
+ TF_Tensor* result_tensor;
+ bool evaluated =
+ TF_TryEvaluateConstant(graph, output, &result_tensor, status);
+ // Not evaluated (or an error was set): report None to Python; the caller
+ // falls back to runtime evaluation.
+ if (!evaluated || TF_GetCode(status) != TF_OK) Py_RETURN_NONE;
+
+ // Take ownership of the tensor so it is released on every exit path.
+ Safe_TF_TensorPtr safe_result_tensor(result_tensor);
+ PyObject* out;
+ Status s = TF_TensorToPyArray(std::move(safe_result_tensor), &out);
+ // Mirror the conversion status into the TF_Status the caller inspects.
+ Set_TF_Status_from_Status(status, s);
+ if (!s.ok()) Py_RETURN_NONE;
+ return out;
+}
+
} // namespace tensorflow
std::vector<string> TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
TF_ImportGraphDefResults* results);
+// If evaluation was possible, returns the numpy ndarray of the evaluated
+// result. Otherwise returns None.
+PyObject* TF_TryEvaluateConstant_wrapper(TF_Graph* graph, TF_Output output,
+ TF_Status* status);
+
} // namespace tensorflow
#endif // TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_
from __future__ import division
from __future__ import print_function
+from tensorflow.python import pywrap_tensorflow as c_api
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops
pred_value = pred
elif isinstance(pred, ops.Tensor):
pred_value = tensor_util.constant_value(pred)
+ # TODO(skyewm): consider folding this into tensor_util.constant_value when
+ # _USE_C_API is removed (there may be performance and correctness bugs, so I
+ # wanted to limit the change hidden behind _USE_C_API).
+ # pylint: disable=protected-access
+ if pred_value is None and ops._USE_C_API:
+ with errors.raise_exception_on_not_ok_status() as status:
+ pred_value = c_api.TF_TryEvaluateConstant_wrapper(
+ pred.graph._c_graph, pred._as_tf_output(), status)
+ # pylint: enable=protected-access
+
else:
raise TypeError("`pred` must be a Tensor or a Python bool.")
return pred_value
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
@test_util.with_c_api
class SmartCondTest(test_util.TensorFlowTestCase):
- def testSmartCondTrue(self):
+ def testTrue(self):
with ops.Graph().as_default():
with session.Session():
x = constant_op.constant(2)
lambda: math_ops.multiply(y, 5))
self.assertEqual(z.eval(), 32)
- def testSmartCondFalse(self):
+ def testFalse(self):
with ops.Graph().as_default():
with session.Session():
x = constant_op.constant(4)
lambda: math_ops.multiply(y, 3))
self.assertEqual(z.eval(), 9)
- def testSmartCondMissingArg1(self):
+ def testUnknown(self):
+ # A placeholder predicate cannot be resolved at graph construction time,
+ # so the branch choice must happen at runtime based on the fed value.
+ with ops.Graph().as_default():
+ with session.Session():
+ x = array_ops.placeholder(dtype=dtypes.int32)
+ y = smart_cond.smart_cond(x > 0, lambda: constant_op.constant(1),
+ lambda: constant_op.constant(2))
+ self.assertEqual(y.eval(feed_dict={x: 1}), 1)
+ self.assertEqual(y.eval(feed_dict={x: -1}), 2)
+
+
+ def testEval(self):
+ # smart_cond should statically select the true branch when the predicate
+ # is an expression of constants that can be folded at construction time.
+ # Constant expression evaluation only works with the C API enabled.
+ if not ops._USE_C_API: return
+
+ with ops.Graph().as_default():
+ with session.Session():
+ x = constant_op.constant(1)
+ y = constant_op.constant(2)
+ # x * y > 0 can be evaluated at graph construction time, so the false
+ # branch shouldn't be evaluated at all.
+ def raise_exception():
+ raise RuntimeError("did not expect to be called")
+ z = smart_cond.smart_cond(x * y > 0, lambda: constant_op.constant(1),
+ raise_exception)
+ self.assertEqual(z.eval(feed_dict={x: 1}), 1)
+
+ def testPlaceholderWithDefault(self):
+ # The default value makes `x > 0` come out true when nothing is fed;
+ # feeding a negative value must still select the false branch at runtime.
+ with ops.Graph().as_default():
+ with session.Session():
+ x = array_ops.placeholder_with_default(1, shape=())
+ y = smart_cond.smart_cond(x > 0, lambda: constant_op.constant(1),
+ lambda: constant_op.constant(2))
+ self.assertEqual(y.eval(), 1)
+ self.assertEqual(y.eval(feed_dict={x: -1}), 2)
+
+
+ def testMissingArg1(self):
+ # Omitting true_fn must raise TypeError even for a plain-bool predicate.
with ops.Graph().as_default():
with session.Session():
x = constant_op.constant(1)
with self.assertRaises(TypeError):
smart_cond.smart_cond(True, false_fn=lambda: x)
- def testSmartCondMissingArg2(self):
+ def testMissingArg2(self):
with ops.Graph().as_default():
with session.Session():
x = constant_op.constant(1)