testonly = 1,
srcs = ["c_test_util.cc"],
hdrs = ["c_test_util.h"],
+ visibility = [
+ "//learning/brain:__subpackages__",
+ "//tensorflow:__subpackages__",
+ ],
deps = [
":c_api",
"//tensorflow/core:lib",
return Const(tensor.get(), graph, s, name);
}
-void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s,
- const char* name, TF_Operation** op, bool check) {
+void AddOpHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+ TF_Status* s, const char* name, TF_Operation** op,
+ bool check) {
TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
TF_AddInputList(desc, add_inputs, 2);
TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name) {
TF_Operation* op;
- AddHelper(l, r, graph, s, name, &op, true);
+ AddOpHelper(l, r, graph, s, name, &op, true);
return op;
}
TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name) {
TF_Operation* op;
- AddHelper(l, r, graph, s, name, &op, false);
+ AddOpHelper(l, r, graph, s, name, &op, false);
return op;
}
return TF_FinishOperation(desc, s);
}
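+// Shared helper for the binary-op wrappers below: creates an `op_name` op
+// with inputs {l, 0} and {r, 0}. If `check` is true, asserts that creation
+// succeeded.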
+void BinaryOpHelper(const char* op_name, TF_Operation* l, TF_Operation* r,
+ TF_Graph* graph, TF_Status* s, const char* name,
+ TF_Operation** op, bool check) {
+ TF_OperationDescription* desc = TF_NewOperation(graph, op_name, name);
+ TF_AddInput(desc, {l, 0});
+ TF_AddInput(desc, {r, 0});
+ *op = TF_FinishOperation(desc, s);
+ if (check) {
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ ASSERT_NE(*op, nullptr);
+ }
+}
+
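+// Creates a "Min" op that reduces `l` along the reduction indices given by
+// `r`.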
+TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+ TF_Status* s, const char* name) {
+ TF_Operation* op;
+ BinaryOpHelper("Min", l, r, graph, s, name, &op, true);
+ return op;
+}
+
TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
const char* name) {
TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
const char* name = "add");
+TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+ TF_Status* s, const char* name = "min");
+
TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s,
const char* name = "neg");
],
deps = [
":c_api",
+ "//tensorflow/c:c_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/mutex.h"
using tensorflow::string;
namespace {
-bool IsCPU(tensorflow::Device* d) {
+bool IsCPU(const tensorflow::Device* d) {
return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
}
-bool IsXLA(tensorflow::Device* d) {
+bool IsXLA(const tensorflow::Device* d) {
if (d == nullptr) return false;
const auto& device_type = d->attributes().device_type();
return device_type.find("XLA") != std::string::npos;
}
-string DeviceName(tensorflow::Device* d) {
+string DeviceName(const tensorflow::Device* d) {
return (d == nullptr) ? "cpu:0" : d->name();
}
+
+#ifdef TENSORFLOW_EAGER_USE_XLA
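+// Monotonically increasing counter; gives each function synthesized by
+// OpToFunction() a unique name.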
+std::atomic_int_fast64_t func_id_generator(0);
+#endif // TENSORFLOW_EAGER_USE_XLA
} // namespace
extern "C" {
return device->name().c_str();
}
+void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
+ op->use_xla = enable;
+#ifndef TENSORFLOW_EAGER_USE_XLA
+ LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
+ "built with XLA support.";
+#endif // TENSORFLOW_EAGER_USE_XLA
+}
+
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
// Questionable heuristic ...
//
}
return tensorflow::Status::OK();
}
+
+#ifdef TENSORFLOW_EAGER_USE_XLA
+// Synthesizes and returns a wrapper function over `op`, which must be a
+// primitive op (e.g. matmul).
+//
+// The wrapper function conforms to the function signature expected by
+// _XlaLaunchOp, with input params ordered by <constants, (variable) args and
+// resources>. For example, if the op has input params <Const1, Arg2, Const3,
+// Resource4, Arg5>, they will be reordered to <Const1, Const3, Arg2, Arg5,
+// Resource4> as the input params to the synthesized function.
+//
+// It populates `const_input_types`, `arg_input_types` and
+// `op_input_to_func_input` based on the reordering results, so that the caller
+// can use them to build an _XlaLaunchOp. On error, it returns nullptr and sets
+// `status` accordingly.
+const tensorflow::FunctionDef* OpToFunction(
+ TFE_Op* op, std::vector<TF_DataType>* const_input_types,
+ std::vector<TF_DataType>* arg_input_types,
+ tensorflow::gtl::FlatMap<int, int>* op_input_to_func_input,
+ TF_Status* status) {
+ DCHECK(!op->is_function());
+
+ tensorflow::FunctionDef fdef;
+
+ // Get the OpDef of the op we are trying to encapsulate.
+ TFE_Context* ctx = op->ctx;
+ const tensorflow::OpRegistrationData* op_data;
+ {
+ tensorflow::tf_shared_lock l(ctx->functions_mu);
+ status->status = ctx->func_lib_def.LookUp(op->name, &op_data);
+ if (!status->status.ok()) {
+ return nullptr;
+ }
+ }
+ const tensorflow::OpDef& op_def = op_data->op_def;
+
+ tensorflow::OpDef* signature = fdef.mutable_signature();
+
+ // Handle constant inputs.
+ const std::unordered_set<string> const_inputs(
+ *tensorflow::XlaOpRegistry::CompileTimeConstantInputs(op->name));
+
+  // First add placeholders for the input args, so that we can refer to them by
+ // position in the next loop. Also tally up the resource inputs.
+ int num_resource_inputs = 0;
+ for (int i = 0; i < op_def.input_arg_size(); ++i) {
+ if (op_def.input_arg(i).type() == tensorflow::DT_RESOURCE) {
+ ++num_resource_inputs;
+ }
+ signature->add_input_arg();
+ }
+
+ // Now we map the input params from `op_def` to `signature`, where the param
+ // ordering for `signature` is: <constants, args, resources>.
+ int const_index = 0;
+ int arg_index = const_inputs.size();
+ int resource_index = op_def.input_arg_size() - num_resource_inputs;
+ for (int i = 0; i < op_def.input_arg_size(); ++i) {
+ const tensorflow::OpDef::ArgDef& op_input_arg = op_def.input_arg(i);
+ tensorflow::OpDef::ArgDef* func_input_arg = nullptr;
+ if (const_inputs.find(op_input_arg.name()) != const_inputs.end()) {
+ VLOG(1) << "For const input, mapping op input " << i << " to func input "
+ << const_index;
+ (*op_input_to_func_input)[i] = const_index;
+ func_input_arg = signature->mutable_input_arg(const_index++);
+ const_input_types->push_back(
+ static_cast<TF_DataType>(op->inputs[i].dtype()));
+ } else if (op_input_arg.type() == tensorflow::DT_RESOURCE) {
+ VLOG(1) << "For resource input, mapping op input " << i
+ << " to func input " << resource_index;
+ (*op_input_to_func_input)[i] = resource_index;
+ func_input_arg = signature->mutable_input_arg(resource_index++);
+ } else {
+ VLOG(1) << "For arg input, mapping op input " << i << " to func input "
+ << arg_index;
+ (*op_input_to_func_input)[i] = arg_index;
+ func_input_arg = signature->mutable_input_arg(arg_index++);
+ arg_input_types->push_back(
+ static_cast<TF_DataType>(op->inputs[i].dtype()));
+ }
+
+ func_input_arg->set_name(op_input_arg.name());
+ func_input_arg->set_type(op->inputs[i].dtype());
+ }
+ VLOG(1) << "Added OpDef Inputs: " << fdef.DebugString();
+
+  // Resource args are at the end of the function input params, and we should
+ // have iterated over all of them.
+ DCHECK_EQ(signature->input_arg_size(), resource_index);
+
+ // Make the synthesized function's name unique.
+ signature->set_name(tensorflow::strings::StrCat(
+ op_def.name(), func_id_generator.fetch_add(1)));
+
+ // Add the node def and set its input names to match op_def's names.
+ const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
+ DCHECK_EQ(signature->input_arg_size(), ndef.input_size());
+ *fdef.add_node_def() = ndef;
+ for (int i = 0; i < op_def.input_arg_size(); ++i) {
+ fdef.mutable_node_def(0)->set_input(i, op_def.input_arg(i).name());
+ }
+ VLOG(1) << "Added NodeDef: " << fdef.DebugString();
+
+ // Fix the output names and set output types.
+ for (int i = 0; i < op_def.output_arg_size(); ++i) {
+ tensorflow::OpDef::ArgDef* arg = signature->add_output_arg();
+ const tensorflow::OpDef::ArgDef& op_def_arg = op_def.output_arg(i);
+ const string& out_tensor_name = tensorflow::strings::StrCat(
+ ndef.name(), ":", op_def_arg.name(), ":", 0);
+ arg->set_name(op_def_arg.name());
+ (*fdef.mutable_ret())[op_def_arg.name()] = out_tensor_name;
+ const string& type_attr = op_def_arg.type_attr();
+ if (!type_attr.empty()) {
+      auto attr_iter = ndef.attr().find(type_attr);
+      if (attr_iter == ndef.attr().end()) {
+        status->status = tensorflow::errors::InvalidArgument(
+            tensorflow::strings::StrCat("Could not find attr ", type_attr,
+                                        " in NodeDef ", ndef.DebugString()));
+        return nullptr;
+      }
+      arg->set_type(attr_iter->second.type());
+ }
+ }
+ VLOG(1) << "Fixed Output names and all types: " << fdef.DebugString();
+
+ tensorflow::mutex_lock l(ctx->functions_mu);
+ status->status = ctx->func_lib_def.AddFunctionDef(fdef);
+ if (!status->status.ok()) return nullptr;
+ const auto ret = ctx->func_lib_def.Find(signature->name());
+ DCHECK(ret != nullptr);
+ return ret;
+}
+
+// Builds an _XlaLaunchOp as a wrapper over `op`, so that `op` can be executed
+// via XLA.
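+// If `op` is a primitive op, a wrapper function is first synthesized via
+// OpToFunction() so that _XlaLaunchOp can call it.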
+std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
+ VLOG(1) << "Creating _XlaLaunchOp for TFE_Op " << op->name;
+ auto launch_op =
+ std::unique_ptr<TFE_Op>(TFE_NewOp(op->ctx, "_XlaLaunch", status));
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ if (op->device) {
+ TFE_OpSetDevice(launch_op.get(), op->device->name().c_str(), status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ }
+
+ const tensorflow::FunctionDef* fdef;
+ {
+ tensorflow::tf_shared_lock l(op->ctx->functions_mu);
+ fdef = op->ctx->func_lib_def.Find(op->name);
+ }
+ std::vector<TF_DataType> const_input_types;
+ std::vector<TF_DataType> arg_input_types;
+ tensorflow::gtl::FlatMap<int, int> op_input_to_func_input;
+ if (fdef == nullptr) {
+ // See if this is a primitive op, and if so create a function for it, so
+ // that _XlaLaunchOp can access it.
+ fdef = OpToFunction(op, &const_input_types, &arg_input_types,
+ &op_input_to_func_input, status);
+ if (!status->status.ok()) return nullptr;
+ } else {
+ // TODO(hongm): XlaOpRegistry::CompileTimeConstantInputs() does not work for
+ // functions, so we need to find another way to handle constant inputs.
+ for (int i = const_input_types.size();
+ i < fdef->signature().input_arg_size(); ++i) {
+ VLOG(1) << "Adding Targs from input arg " << i;
+ const tensorflow::OpDef::ArgDef& arg = fdef->signature().input_arg(i);
+ arg_input_types.push_back(static_cast<TF_DataType>(arg.type()));
+ }
+ }
+ DCHECK(fdef != nullptr);
+
+ // Copy inputs and their devices.
+ // Since input param reordering may have occurred between `op` and `launch_op`
+ // via `op_input_to_func_input`, adjust the actual inputs accordingly.
+ launch_op->inputs = op->inputs;
+ launch_op->input_devices = op->input_devices;
+ if (!op_input_to_func_input.empty()) {
+ DCHECK_EQ(op->inputs.size(), op_input_to_func_input.size());
+ if (!op->input_devices.empty()) {
+ DCHECK_EQ(op->input_devices.size(), op_input_to_func_input.size());
+ }
+ for (int i = 0; i < op_input_to_func_input.size(); ++i) {
+ VLOG(1) << "mapping op input " << i << " to func input "
+ << op_input_to_func_input[i];
+
+ launch_op->inputs[op_input_to_func_input[i]] = op->inputs[i];
+ if (!op->input_devices.empty()) {
+ launch_op->input_devices[op_input_to_func_input[i]] =
+ op->input_devices[i];
+ }
+ }
+ }
+ launch_op->attrs.NumInputs(op->inputs.size());
+
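+  // Set the Tconstants attr.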
+ TFE_OpSetAttrTypeList(launch_op.get(), "Tconstants", const_input_types.data(),
+ const_input_types.size());
+
+ // Set Targs and Nresources attrs.
+ TFE_OpSetAttrTypeList(launch_op.get(), "Targs", arg_input_types.data(),
+ arg_input_types.size());
+ const int num_resource_inputs = fdef->signature().input_arg_size() -
+ const_input_types.size() -
+ arg_input_types.size();
+ TFE_OpSetAttrInt(launch_op.get(), "Nresources", num_resource_inputs);
+
+ // Set Tresults attr.
+ std::vector<TF_DataType> tresults;
+ for (const tensorflow::OpDef::ArgDef& arg : fdef->signature().output_arg()) {
+ tresults.push_back(static_cast<TF_DataType>(arg.type()));
+ }
+ TFE_OpSetAttrTypeList(launch_op.get(), "Tresults", tresults.data(),
+ tresults.size());
+
+ // Set function attr.
+ tensorflow::AttrValue attr_value;
+ tensorflow::NameAttrList* func = attr_value.mutable_func();
+ func->set_name(fdef->signature().name());
+ launch_op->attrs.Set("function", attr_value);
+
+ return launch_op;
+}
+#endif // TENSORFLOW_EAGER_USE_XLA
} // namespace
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
// TODO(ashankar): ASSUMPTION: ctx->devices()[0] is always CPU
tensorflow::Device* device =
(op->device == nullptr) ? ctx->devices()[0] : op->device;
+
+#ifdef TENSORFLOW_EAGER_USE_XLA
+ std::unique_ptr<TFE_Op> xla_launch_op;
+ if (op->use_xla && op->name != "_XlaLaunch") {
+ xla_launch_op = BuildXlaLaunch(op, status);
+ if (!status->status.ok()) {
+ return;
+ }
+ op = xla_launch_op.get();
+ }
+#endif // TENSORFLOW_EAGER_USE_XLA
+
std::vector<tensorflow::Tensor> outputs(1);
const tensorflow::MemoryTypeVector* output_memory_types = nullptr;
tensorflow::Fprint128 cache_key = op->attrs.CacheKey(device->name());
TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(TFE_Op* op,
TF_Status* status);
+// When `enable` is set to 1, and if the TensorFlow library is built with XLA
+// support, a subsequent TFE_Execute() call on `op` will run the op via XLA.
+//
+// If the library is not built with XLA support, this call is a no-op.
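+//
+// A minimal usage sketch (MatMulOp stands in for any test helper that builds
+// a TFE_Op; it is not part of this API):
+//   TFE_Op* matmul = MatMulOp(ctx, m, m);
+//   TFE_OpSetXLACompilation(matmul, 1);
+//   TFE_Execute(matmul, &retvals[0], &num_retvals, status);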
+TF_CAPI_EXPORT extern void TFE_OpSetXLACompilation(TFE_Op* op,
+ unsigned char enable);
+
TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status);
TF_CAPI_EXPORT extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
std::vector<tensorflow::Tensor> inputs;
std::vector<tensorflow::Device*> input_devices;
tensorflow::Device* device;
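+  // When true, TFE_Execute() wraps this op in an _XlaLaunch op
+  // (see TFE_OpSetXLACompilation()).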
+ bool use_xla = false;
};
#endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
return op;
}
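+// Returns a 1-D int32 tensor handle holding the single value 1, for use as a
+// reduction axis.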
+TFE_TensorHandle* TestAxisTensorHandle() {
+ int64_t dims[] = {1};
+ int data[] = {1};
+ TF_Tensor* t = TF_AllocateTensor(
+ TF_INT32, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
+ memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
+ TF_Status* status = TF_NewStatus();
+ TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteTensor(t);
+ TF_DeleteStatus(status);
+ return th;
+}
+
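+// Builds a "Min" op that reduces `input` along `axis`, keeping the reduced
+// dimensions.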
+TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input,
+ TFE_TensorHandle* axis) {
+ TF_Status* status = TF_NewStatus();
+
+ TFE_Op* op = TFE_NewOp(ctx, "Min", status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpAddInput(op, input, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpAddInput(op, axis, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_OpSetAttrBool(op, "keep_dims", 1);
+  TFE_OpSetAttrType(op, "Tidx", TF_INT32);
+  TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(input));
+  TF_DeleteStatus(status);
+
+ return op;
+}
+
// If there is a GPU device, returns true and sets 'gpu_device_name'
// accordingly.
bool GetGPUDeviceName(TFE_Context* ctx, string* gpu_device_name) {
TF_DeleteStatus(status);
}
-TEST(CAPI, Execute) {
+TEST(CAPI, Execute_MatMul_CPU) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
TF_DeleteStatus(status);
}
+TEST(CAPI, Execute_Min_CPU) {
+ TF_Status* status = TF_NewStatus();
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+
+ TFE_TensorHandle* input = TestMatrixTensorHandle();
+ TFE_TensorHandle* axis = TestAxisTensorHandle();
+ TFE_Op* minOp = MinOp(ctx, input, axis);
+ TFE_TensorHandle* retvals[2] = {nullptr};
+ int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call.
+ TFE_Execute(minOp, &retvals[0], &num_retvals, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteOp(minOp);
+ TFE_DeleteTensorHandle(input);
+ TFE_DeleteTensorHandle(axis);
+ TFE_DeleteContext(ctx, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ ASSERT_EQ(1, num_retvals);
+
+ TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
+ TFE_DeleteTensorHandle(retvals[0]);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ float output[2] = {0};
+ EXPECT_EQ(sizeof(output), TF_TensorByteSize(t));
+ memcpy(&output[0], TF_TensorData(t), TF_TensorByteSize(t));
+ TF_DeleteTensor(t);
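+  // TestMatrixTensorHandle() holds [[1, 2], [3, 4]]; its min along axis 1
+  // (keeping dims) is [[1], [3]].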
+ EXPECT_EQ(1, output[0]);
+ EXPECT_EQ(3, output[1]);
+ TF_DeleteStatus(status);
+}
+
+#ifdef TENSORFLOW_EAGER_USE_XLA
+TEST(CAPI, Execute_MatMul_XLA_CPU) {
+ TF_Status* status = TF_NewStatus();
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+
+ TFE_TensorHandle* m = TestMatrixTensorHandle();
+ TFE_Op* matmul = MatMulOp(ctx, m, m);
+
+ TFE_OpSetXLACompilation(matmul, true);
+
+ TFE_TensorHandle* retvals[2] = {nullptr};
+ int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call.
+ TFE_Execute(matmul, &retvals[0], &num_retvals, status);
+  // The primitive matmul is rewritten into a synthesized _XlaLaunch wrapper
+  // (see BuildXlaLaunch()), so execution is expected to succeed.
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TFE_DeleteOp(matmul);
+ TFE_DeleteTensorHandle(m);
+ TFE_DeleteContext(ctx, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ EXPECT_EQ(1, num_retvals);
+
+ TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
+ TFE_DeleteTensorHandle(retvals[0]);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ float product[4] = {0};
+ EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
+ memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
+ TF_DeleteTensor(t);
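+  // [[1, 2], [3, 4]] matmul'd with itself yields [[7, 10], [15, 22]].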
+ EXPECT_EQ(7, product[0]);
+ EXPECT_EQ(10, product[1]);
+ EXPECT_EQ(15, product[2]);
+ EXPECT_EQ(22, product[3]);
+
+ TF_DeleteStatus(status);
+}
+
+TEST(CAPI, Execute_Min_XLA_CPU) {
+ TF_Status* status = TF_NewStatus();
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+
+ TFE_TensorHandle* input = TestMatrixTensorHandle();
+ TFE_TensorHandle* axis = TestAxisTensorHandle();
+ TFE_Op* minOp = MinOp(ctx, input, axis);
+
+ TFE_OpSetXLACompilation(minOp, true);
+
+ TFE_TensorHandle* retvals[2] = {nullptr};
+ int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call.
+ TFE_Execute(minOp, &retvals[0], &num_retvals, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteOp(minOp);
+ TFE_DeleteTensorHandle(input);
+ TFE_DeleteTensorHandle(axis);
+ TFE_DeleteContext(ctx, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ ASSERT_EQ(1, num_retvals);
+
+ TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
+ TFE_DeleteTensorHandle(retvals[0]);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ float output[2] = {0};
+ EXPECT_EQ(sizeof(output), TF_TensorByteSize(t));
+ memcpy(&output[0], TF_TensorData(t), TF_TensorByteSize(t));
+ TF_DeleteTensor(t);
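+  // Expects the same result as the non-XLA Execute_Min_CPU test above.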
+ EXPECT_EQ(1, output[0]);
+ EXPECT_EQ(3, output[1]);
+ TF_DeleteStatus(status);
+}
+#endif // TENSORFLOW_EAGER_USE_XLA
+
TEST(CAPI, ExecuteWithTracing) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_DeleteStatus(status);
}
-TEST(CAPI, Function) {
+TEST(CAPI, Function_ident_CPU) {
// First create a simple identity function.
TF_Graph* function_graph = TF_NewGraph();
TF_OperationDescription* arg_descr =
TF_DeleteStatus(status);
}
+#ifdef TENSORFLOW_EAGER_USE_XLA
+TEST(CAPI, Function_ident_XLA_CPU) {
+ // First create a simple identity function.
+ TF_Graph* function_graph = TF_NewGraph();
+ TF_OperationDescription* arg_descr =
+ TF_NewOperation(function_graph, "Placeholder", "arg");
+ TF_SetAttrType(arg_descr, "dtype", TF_INT32);
+ TF_Status* status = TF_NewStatus();
+ TF_Operation* arg = TF_FinishOperation(arg_descr, status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ TF_OperationDescription* id_descr =
+ TF_NewOperation(function_graph, "Identity", "id");
+ TF_SetAttrType(id_descr, "T", TF_INT32);
+ TF_AddInput(id_descr, {arg, 0});
+ TF_Operation* id = TF_FinishOperation(id_descr, status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ TF_Output input{arg, 0};
+ TF_Output output{id, 0};
+ TF_Function* fn =
+ TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1,
+ &output, nullptr, nullptr, "test", status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ TF_DeleteGraph(function_graph);
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+ TFE_ContextAddFunction(ctx, fn, status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ TF_DeleteFunction(fn);
+
+ TF_Tensor* t =
+ TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32));
+ *reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
+ TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ TF_DeleteTensor(t);
+
+ TFE_Op* op = TFE_NewOp(ctx, "ident", status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ TFE_OpAddInput(op, h, status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+
+ // Now run it via XLA.
+ TFE_OpSetXLACompilation(op, true);
+
+ std::vector<TFE_TensorHandle*> result;
+ result.push_back(nullptr);
+ int num_retvals = 1;
+ TFE_Execute(op, result.data(), &num_retvals, status);
+ TFE_DeleteOp(op);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ ASSERT_EQ(num_retvals, 1);
+
+ TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
+ TFE_DeleteTensorHandle(h);
+ TF_DeleteTensor(r);
+ TFE_DeleteTensorHandle(result[0]);
+ TFE_DeleteContext(ctx, status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ TF_DeleteStatus(status);
+}
+#endif // TENSORFLOW_EAGER_USE_XLA
+
string MatMulFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(