name = "create_xla_launch_op",
srcs = [
"create_xla_launch_op.cc",
- "create_xla_launch_op.h",
],
deps = [
":common",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
- ],
- alwayslink = 1,
-)
-
-tf_cc_test(
- name = "create_xla_launch_op_test",
- srcs = [
- "create_xla_launch_op.h",
- "create_xla_launch_op_test.cc",
- ],
- deps = [
- ":create_xla_launch_op",
- "//tensorflow/core:core_cpu_internal",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core:session_options",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- "//tensorflow/core:testlib",
],
)
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/jit/create_xla_launch_op.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
namespace tensorflow {
namespace {
-// Utility which searches for values in a sorted list by scanning over it once.
-// No matter how many times ScanForValue is called, the list is scanned at most
-// once. However, if a call to ScanForValue skips over a value, that value is
-// not revisited in future calls to ScanForValue, so callers must take
-// care to order their calls.
+// Given a NodeDef 'ndef' and the function library runtime 'flr', if
+// 'ndef' is a call to a compilable function defined in 'flr', returns OK
+// and fills in 'kernel' with a XlaLaunchOp kernel which computes the
+// node. Otherwise, returns a non-OK status.
//
-// Useful for merging multiple sorted lists in O(n) time.
-class SinglePassSearch {
- public:
- // Creates a SinglePassSearch object that can be used to search in `values`.
- // Does not take ownership of `values`. `values` must outlive this.
- // `values` must be sorted.
- explicit SinglePassSearch(const std::vector<int>* values)
- : current_index_(0), values_(values) {}
-
- // Scans forward in the vector looking for "value", updating the internal
- // position in to the vector.
- // Returns true iff the vector contains the given value at or after current
- // position.
- // Not thread-safe.
- bool ScanForValue(int value) {
- while (current_index_ < values_->size() &&
- (*values_)[current_index_] <= value) {
- if ((*values_)[current_index_] == value) {
- current_index_++;
- return true;
- }
- current_index_++;
- }
- return false;
- }
-
- private:
- int current_index_;
- const std::vector<int>* values_;
-};
-
-Status CompilationRequested(const FunctionLibraryRuntime& flr,
- const NodeDef& node_def) {
+// This routine is here so that FunctionLibraryRuntime can jit a
+// specific function call as requested.
+Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef,
+ std::unique_ptr<OpKernel>* kernel) {
bool xla_compile = false;
- // Check if op is marked _XlaCompile=true.
- Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
- node_def, kXlaCompileAttr, &xla_compile);
- if (!status.ok() || !xla_compile) {
- if (VLOG_IS_ON(3)) {
- if (!status.ok()) {
- VLOG(3) << "No " << kXlaCompileAttr << " attr defined for "
- << node_def.op() << ". status=" << status.ToString();
- } else {
- VLOG(3) << node_def.op() << " is explicitly marked not to be compiled";
- }
- }
- return Status(error::INVALID_ARGUMENT, "");
+ if (!flr->GetFunctionLibraryDefinition()
+ ->GetAttr(ndef, kXlaCompileAttr, &xla_compile)
+ .ok() ||
+ !xla_compile) {
+ // Not marked as _XlaCompile=true.
+ return errors::InvalidArgument("No ", kXlaCompileAttr, " for ", ndef.op());
+ }
+ // Make sure that kernels have been registered on the JIT device.
+ XlaOpRegistry::RegisterCompilationKernels();
+ if (!IsCompilable(flr, ndef)) {
+ // ndef is calling a function that XLA can't compile.
+ return errors::InvalidArgument("Not compilable: ", ndef.ShortDebugString());
}
- return Status::OK();
-}
-
-// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
-// runtime, returns this function's body in `fbody` as well as the indices
-// of its constant and resource arguments.
-// `fbody` is owned by `flr`.
-// `constant_arg_indices` and `resource_arg_indices` should be empty vector.
-// They are sorted in ascending order on this function's return.
-Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
- const NodeDef& node_def,
- const FunctionBody** fbody,
- std::vector<int>* constant_arg_indices,
- std::vector<int>* resource_arg_indices) {
FunctionLibraryRuntime::Handle handle;
- // If node_def is not instantiable, e.g., the function does not exist,
+ // If ndef is not instantiable, e.g., the function does not exist,
// simply bail out.
TF_RETURN_IF_ERROR(
- flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle));
- *fbody = flr->GetFunctionBody(handle);
- CHECK(*fbody); // Can't be nullptr since we just instantiated it.
- const DataTypeVector& arg_types = (*fbody)->arg_types;
- std::vector<bool> const_args(arg_types.size());
+ flr->Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle));
+ const FunctionBody* fbody = flr->GetFunctionBody(handle);
+ CHECK(fbody); // Can't be nullptr since we just instantiated it.
+ std::vector<bool> const_args(fbody->arg_types.size());
// If we can't analyze the const args. Bail out.
- TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*((*fbody)->graph), &const_args));
+ TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*(fbody->graph), &const_args));
for (int i = 0; i < const_args.size(); ++i) {
if (const_args[i]) {
- constant_arg_indices->push_back(i);
- }
- }
-
- // There can be hundreds of resource variables. Reserve the space for them.
- // We don't reserve for constants above as they are usually few.
- resource_arg_indices->reserve(arg_types.size());
- for (int i = 0; i < arg_types.size(); ++i) {
- if (arg_types[i] == DT_RESOURCE) {
- resource_arg_indices->push_back(i);
+ // There is a const arg. Bail out.
+ return errors::InvalidArgument("Const arg: ", i, " in ",
+ DebugString(fbody->fdef));
}
}
- return Status::OK();
-}
-
-} // namespace
-
-Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
- std::unique_ptr<OpKernel>* kernel) {
- TF_RETURN_IF_ERROR(CompilationRequested(*flr, node_def));
-
- VLOG(3) << "Creating XlaLaunchOp for " << node_def.DebugString();
-
- // Make sure that kernels have been registered on the JIT device.
- XlaOpRegistry::RegisterCompilationKernels();
- if (!IsCompilable(flr, node_def)) {
- // node_def is calling a function that XLA can't compile.
- return errors::InvalidArgument("Not compilable: ",
- node_def.ShortDebugString());
- }
-
- // Get function body, constant args, and resource args.
- const FunctionBody* fbody = nullptr;
- std::vector<int> constant_arg_indices;
- std::vector<int> resource_arg_indices;
- TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
- flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices));
-
- // Set input and output memory types.
+ NodeDef launch_def;
+ launch_def.set_name(ndef.name());
+ launch_def.set_op("_XlaLaunch");
+ launch_def.set_device(flr->device()->name());
+ AddNodeAttr("Tconstants", DataTypeVector{}, &launch_def);
+ AddNodeAttr("Nresources", 0, &launch_def);
+ AddNodeAttr("Targs", fbody->arg_types, &launch_def);
+ AddNodeAttr("Tresults", fbody->ret_types, &launch_def);
+ NameAttrList func;
+ func.set_name(ndef.op());
+ *(func.mutable_attr()) = ndef.attr();
+ AddNodeAttr("function", func, &launch_def);
+
+ // TODO(b/32387911): Handle the host memory types across function
+ // calls properly. For now, we assume all inputs and outputs are on
+ // the device memory.
MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
- // These indices are used only for optimization purposes. They allow us
- // to loop over constant_arg_indices and resource_arg_indices only once
- // while iterating over all the function arguments checking if it is a
- // resource or a constant.
- // The reason we optimized this code is because functions can have a lot of
- // captured arguments. For example, the backward pass of ResNet50 takes in all
- // 214 variables and a similar number of activations.
- SinglePassSearch constants_search(&constant_arg_indices);
- SinglePassSearch resources_search(&resource_arg_indices);
- for (int i = 0; i < fbody->arg_types.size(); ++i) {
- if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
- // Compile-time constants and resource handles are expected to be in
- // host memory.
- input_memory_types[i] = HOST_MEMORY;
- }
- }
- // One might wonder, about the case where a compile-time constant argument
- // (which must be in host memory) is also used as an input into an op,
- // e.g. Add, that expects its inputs in device memory. Here is how it
- // works now.
- // First, what do we mean by "op expects an input in XYZ memory"?
- // There are two types of "ops" here: the tf2xla kernel and the HLO
- // computation it builds. The tf2xla kernel needs to retrieve the actual
- // numeric value of the compile-time constant tensors, so it really expects
- // them to be on in host memory. However, for other inputs, it refers to them
- // using xla::ComputationDataHandle, which is just a symbolic handle that
- // xla::ComputationBuilder assigns. How does this handle gets assigned for
- // constant arguments? Even constant arguments get an _Arg node in the graph
- // instatiated for Function compilation. The tf2xla kernel for constant _Arg
- // nodes takes the constant value, converts it to XlaLiteral, and feeds it
- // to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
- // constant XlaLiteral is included in the HLO graph, and subsequently, in
- // the actual executable, which is copied to the device before being
- // executed. Thus, when this executable runs, the constant is available in
- // device memory.
-
- // XlaLaunch kernel keeps all outputs (including constants, which it copies),
- // in device memory
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
- // Create the kernel.
- NameAttrList function;
- function.set_name(node_def.op());
- *(function.mutable_attr()) = node_def.attr();
-
Device* dev = flr->device();
Status s;
OpKernelConstruction construction(
DeviceType(dev->device_type()), dev,
- dev->GetAllocator(AllocatorAttributes()), &node_def,
+ dev->GetAllocator(AllocatorAttributes()), &launch_def,
&fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);
-
- *kernel = absl::make_unique<XlaLocalLaunchBase>(
- &construction, constant_arg_indices, resource_arg_indices, function);
+ kernel->reset(new XlaLocalLaunchOp(&construction));
return s;
}
-namespace {
-
bool RegisterLaunchOpCreator() {
RegisterDefaultCustomKernelCreator(CreateXlaLaunchOp);
return true;
+++ /dev/null
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
-#define TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
-
-#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/lib/core/status.h"
-
-namespace tensorflow {
-
-class FunctionLibraryRuntime;
-class OpKernel;
-
-// Given a NodeDef 'node_def' and the function library runtime 'flr', if
-// 'node_def' is a call to a compilable function defined in 'flr', returns OK
-// and fills in 'kernel' with a XlaLaunchOp kernel which computes the
-// node. Otherwise, returns a non-OK.
-Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
- std::unique_ptr<OpKernel>* kernel);
-
-} // namespace tensorflow
-
-#endif // TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
+++ /dev/null
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/jit/create_xla_launch_op.h"
-
-#include "tensorflow/core/common_runtime/device_factory.h"
-#include "tensorflow/core/common_runtime/function.h"
-#include "tensorflow/core/framework/function_testlib.h"
-#include "tensorflow/core/framework/node_def_builder.h"
-#include "tensorflow/core/framework/tensor_testutil.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/public/session_options.h"
-#include "tensorflow/core/public/version.h"
-
-namespace tensorflow {
-
-NodeDef ToNodeDef(const string& text) {
- NodeDef node_def;
- EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
- return node_def;
-}
-
-// Create a FunctionDef that takes one resource and one regular param
-FunctionDef XTimesY() {
- return FunctionDefHelper::Define(
- // Name
- "XTimesY",
- // Args
- {"x: float", "y: resource"},
- // Return values
- {"z: float"},
- // Attr def
- {},
- // Nodes
- {
- {{"y0"}, "ReadVariableOp", {"y"}, {{"dtype", DT_FLOAT}}},
- {{"z"}, "Mul", {"x", "y0"}, {{"T", DT_FLOAT}}},
- });
-}
-
-class CreateXlaLaunchOpTest : public ::testing::Test {
- protected:
- void Init(const std::vector<FunctionDef>& flib) {
- SessionOptions options;
- auto* device_count = options.config.mutable_device_count();
- device_count->insert({"CPU", 1});
- TF_CHECK_OK(DeviceFactory::AddDevices(
- options, "/job:localhost/replica:0/task:0", &devices_));
-
- FunctionDefLibrary proto;
- for (const auto& fdef : flib) {
- *(proto.add_function()) = fdef;
- }
- lib_def_ = absl::make_unique<FunctionLibraryDefinition>(
- OpRegistry::Global(), proto);
- OptimizerOptions opts;
- device_mgr_ = absl::make_unique<DeviceMgr>(devices_);
- pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
- device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
- opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
- flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
- }
-
- FunctionLibraryRuntime* flr_;
- std::vector<Device*> devices_;
- std::unique_ptr<DeviceMgr> device_mgr_;
- std::unique_ptr<FunctionLibraryDefinition> lib_def_;
- std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
-
- std::unique_ptr<OpKernel> kernel_;
-};
-
-AttrValue BoolAttr(bool b) {
- AttrValue v;
- v.set_b(b);
- return v;
-}
-
-TEST_F(CreateXlaLaunchOpTest, OneFloatOneResourceArgument) {
- FunctionDef fdef = XTimesY();
- (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(true);
- Init({fdef});
-
- Status status = CreateXlaLaunchOp(
- flr_, ToNodeDef(R"pb(
- name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b'
- )pb"), &kernel_);
- ASSERT_TRUE(status.ok()) << status.ToString();
-
- EXPECT_EQ("XTimesY", kernel_->name());
- EXPECT_EQ("XTimesY", kernel_->type_string());
-
- EXPECT_EQ(2, kernel_->num_inputs());
- EXPECT_EQ(DT_FLOAT, kernel_->input_type(0));
- EXPECT_EQ(DT_RESOURCE, kernel_->input_type(1));
- EXPECT_EQ(DEVICE_MEMORY, kernel_->input_memory_types()[0]);
- EXPECT_EQ(HOST_MEMORY, kernel_->input_memory_types()[1]);
-
- EXPECT_EQ(1, kernel_->num_outputs());
- EXPECT_EQ(DT_FLOAT, kernel_->output_type(0));
- EXPECT_EQ(DEVICE_MEMORY, kernel_->output_memory_types()[0]);
-}
-
-TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrNotSet) {
- FunctionDef fdef = XTimesY();
- Init({fdef});
-
- Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto(
- name: 'XTimesY'
- op: 'XTimesY'
- input: 'a'
- input: 'b'
- )proto"), &kernel_);
- EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString();
-}
-
-TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrIsSetToFalse) {
- FunctionDef fdef = XTimesY();
- (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(false);
- Init({fdef});
-
- Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto(
- name: 'XTimesY'
- op: 'XTimesY'
- input: 'a'
- input: 'b'
- )proto"), &kernel_);
- EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString();
-}
-
-} // namespace tensorflow
namespace tensorflow {
-XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
- const std::vector<int>& constants,
- const std::vector<int>& resources,
- const NameAttrList& function)
- : OpKernel(ctx),
- constants_(constants),
- resources_(resources),
- device_type_(ctx->device_type()),
- function_(function) {
+XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx), device_type_(ctx->device_type()) {
+ const NameAttrList* func;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func));
+ function_ = *func;
+ DataTypeVector constant_types;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types));
+ num_constant_args_ = constant_types.size();
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("Nresources", &num_resource_args_));
if (device_type_ == DeviceType(DEVICE_CPU)) {
platform_id_ = se::host::kHostPlatformId;
} else if (device_type_ == DeviceType(DEVICE_GPU)) {
}
}
-Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx,
- XlaCompilationCache** cache) {
+Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx,
+ XlaCompilationCache** cache) {
const XlaDevice::Metadata* metadata;
Status s = XlaDevice::GetMetadata(ctx, &metadata);
if (s.ok()) {
return Status::OK();
}
-void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
- VLOG(1) << "XlaLocalLaunchOpBase::Compute "
+void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
+ VLOG(1) << "XlaLocalLaunchOp::Compute "
<< Canonicalize(function_.name(), AttrSlice(&function_.attr()));
// We store information about the JIT-compiled XLA computation
// in the ResourceMgr.
}
std::map<int, OptionalTensor> variables =
- SnapshotResourceVariables(ctx, resources_);
+ SnapshotResourceVariables(ctx, num_resource_args_);
xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
xla::LocalExecutable* executable;
std::map<int, Tensor> constant_args;
- for (int i : constants_) {
+ for (int i = 0; i < num_constant_args_; ++i) {
constant_args.insert({i, ctx->input(i)});
}
OP_REQUIRES_OK(ctx, cache->Compile(options, function_, constant_args,
VLOG(1) << "Executing XLA Computation...";
- XlaComputationLaunchContext launch_context(client, xla_allocator,
- allocate_xla_tensors);
+ XlaComputationLaunchContext launch_context(
+ num_resource_args_, client, xla_allocator, allocate_xla_tensors);
launch_context.PopulateInputs(ctx, kernel, variables);
// Execute the computation.
VLOG(1) << "Done";
}
-namespace {
-
-// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
-// in error case, it returns RET instead of void.
-#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \
- do { \
- ::tensorflow::Status _s(__VA_ARGS__); \
- if (!TF_PREDICT_TRUE(_s.ok())) { \
- (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
- return RET; \
- } \
- } while (0)
-
-// Helper static functions to construct parameters for
-// XlaLocalLaunchBase constructor from OpKernelConstruction.
-std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
- DataTypeVector constant_types;
- OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
- ctx->GetAttr("Tconstants", &constant_types));
- std::vector<int> constants(constant_types.size());
- std::iota(constants.begin(), constants.end(), 0);
- return constants;
-}
-
-std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
- DataTypeVector constant_types;
- OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
- ctx->GetAttr("Tconstants", &constant_types));
-
- DataTypeVector arg_types;
- OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
- ctx->GetAttr("Targs", &arg_types));
-
- int num_resources;
- OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
- ctx->GetAttr("Nresources", &num_resources));
-
- std::vector<int> resources(num_resources);
- std::iota(resources.begin(), resources.end(),
- constant_types.size() + arg_types.size());
- return resources;
-}
-
-NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
- const NameAttrList* func;
- OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
- return *func;
-}
-
-#undef OP_REQUIRES_OK_RETURN
-} // namespace
-
-XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
- : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
- FunctionAttr(ctx)) {}
-
XlaLocalLaunchOp::~XlaLocalLaunchOp() {
VLOG(1) << "XlaLocalLaunchOp destroyed";
}
namespace tensorflow {
-// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
-// The only difference is that it does not require arguments to follow
-// the "constants, then regular args, then resources" order.
-// It takes vectors of constant and resource arguments explicitly.
-// It does not have corresponding OpDef because it is never present
-// in the GraphDef.
-// Currently, it is used by eager runtime. FunctionLibraryRuntime creates
-// this kernel when asked to create a kernel for an XLA-compiled function.
-class XlaLocalLaunchBase : public OpKernel {
- public:
- XlaLocalLaunchBase(OpKernelConstruction* ctx,
- const std::vector<int>& constants,
- const std::vector<int>& resources,
- const NameAttrList& function);
- XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
- XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
- ~XlaLocalLaunchBase() override = default;
-
- void Compute(OpKernelContext* ctx) override;
-
- protected:
- // Builds a XlaCompilationCache class suitable for the current device.
- Status BuildCompilationCache(OpKernelContext* ctx,
- XlaCompilationCache** cache);
-
- // Indexes of compile-time constant inputs
- std::vector<int> constants_;
- // Indexes of resource inputs
- std::vector<int> resources_;
-
- DeviceType device_type_;
- NameAttrList function_;
- se::Platform::Id platform_id_;
-};
-
// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
// which will be compiled and executed using XLA. The XlaLocalLaunchOp is
// responsible for handling interactions with the TensorFlow executor.
// XlaLocalLaunchOp uses xla::LocalClient::Compile() and
// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
// memory.
-class XlaLocalLaunchOp : public XlaLocalLaunchBase {
+class XlaLocalLaunchOp : public OpKernel {
public:
explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
~XlaLocalLaunchOp() override;
+ void Compute(OpKernelContext* ctx) override;
+
private:
+ // Builds a XlaCompilationCache class suitable for the current device.
+ Status BuildCompilationCache(OpKernelContext* ctx,
+ XlaCompilationCache** compiler);
+
+ DeviceType device_type_;
+ NameAttrList function_;
+ int num_constant_args_;
+ // Number of resource variable arguments.
+ int num_resource_args_;
+
+ se::Platform::Id platform_id_;
+
TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
};
const XlaCompiler::CompilationResult* result,
xla::LocalExecutable* executable) {
std::map<int, OptionalTensor> variables = GetVariables(ctx);
+ int64 num_resource_args = variables.size();
xla::LocalClient* client = metadata.client();
// Builds an XLA allocator for the device.
XlaComputationLaunchContext launch_context(
- client, client->backend().memory_allocator(), true);
+ num_resource_args, client, client->backend().memory_allocator(), true);
launch_context.PopulateInputs(ctx, result, variables);
using xla::ShapedBuffer;
} // anonymous namespace
-std::map<int, OptionalTensor> SnapshotResourceVariables(
- OpKernelContext* ctx, const std::vector<int>& variables) {
+std::map<int, OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
+ int num_variables) {
std::map<int, OptionalTensor> snapshot;
- for (int i : variables) {
+ int first_variable = ctx->num_inputs() - num_variables;
+ for (int i = 0; i < num_variables; ++i) {
Var* variable = nullptr;
- ResourceHandle handle = HandleFromInput(ctx, i);
- OptionalTensor& tensor = snapshot[i];
+ ResourceHandle handle = HandleFromInput(ctx, first_variable + i);
+ OptionalTensor& tensor = snapshot[first_variable + i];
if (LookupResource(ctx, handle, &variable).ok()) {
tf_shared_lock lock(*variable->mu());
tensor.name = handle.name();
using internal::ExtractSubShapedBuffer;
XlaComputationLaunchContext::XlaComputationLaunchContext(
- xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator,
- bool allocate_xla_tensors)
- : client_(client),
+ int64 num_resource_args, xla::LocalClient* client,
+ xla::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors)
+ : num_resource_args_(num_resource_args),
+ client_(client),
xla_allocator_(xla_allocator),
allocate_xla_tensors_(allocate_xla_tensors) {}
namespace tensorflow {
class XlaAllocator;
-// Takes a snapshot of the values of resource variable arguments, whose
-// indices are specified in `variables` argument. We snapshot tensors that back
+// Takes a snapshot of the values of resource variable arguments, which are
+// the last `num_variables` arguments. We snapshot tensors that back
// resource variables since concurrent updates may modify the shape, and it is
// important that the shapes used for compilation match the true shapes of the
// buffers.
//
-// Returns a map of TensorFlow argument index to resource variable. If a
-// resource variable is not initialized, the corresponding OptionalTensor
-// will have its `present` field set to false.
-std::map<int, OptionalTensor> SnapshotResourceVariables(
- OpKernelContext* ctx, const std::vector<int>& variables);
+// Returns a map of TensorFlow argument index to resource variable.
+std::map<int, OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
+ int num_variables);
// Adapter class that wraps a Tensorflow allocator as an XLA allocator.
// Assumes that the Tensorflow allocator permits asynchronous deallocation:
// Create a new launch context. 'allocate_xla_tensors' is true if allocated
// output tensors and variables are always XlaTensors. If false they are
// assumed to be "normal" device pointers.
- XlaComputationLaunchContext(xla::LocalClient* client,
+ XlaComputationLaunchContext(int64 num_resource_args, xla::LocalClient* client,
xla::DeviceMemoryAllocator* xla_allocator,
bool allocate_xla_tensors);
const std::vector<xla::ShapedBuffer*>& arguments() const { return arg_ptrs_; }
private:
+ int64 num_resource_args_;
xla::LocalClient* client_;
xla::DeviceMemoryAllocator* xla_allocator_;
bool allocate_xla_tensors_;
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:layers",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:nn",
"//tensorflow/python:platform_test",
- "//tensorflow/python/eager:function",
],
)
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
-from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.layers import convolutional
-from tensorflow.python.layers import pooling
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import googletest
def testExecuteListOutputLen0(self):
with self.test_scope():
- empty = constant_op.constant([], dtype=dtypes.float32)
+ empty = constant_op.constant([], dtype=dtypes.int32)
result = array_ops.unstack(empty, 0)
self.assertTrue(isinstance(result, list))
self.assertEqual(0, len(result))
def testExecuteListOutputLen1(self):
with self.test_scope():
split_dim = constant_op.constant(1)
- value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
+ value = constant_op.constant([[0, 1, 2], [3, 4, 5]])
result = array_ops.split(value, 1, axis=split_dim)
self.assertTrue(isinstance(result, list))
self.assertEqual(1, len(result))
def testExecuteListOutputLen3(self):
with self.test_scope():
split_dim = constant_op.constant(1)
- value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
+ value = constant_op.constant([[0, 1, 2], [3, 4, 5]])
result = array_ops.split(value, 3, axis=split_dim)
self.assertTrue(isinstance(result, list))
self.assertEqual(3, len(result))
self.assertEqual(2., grads[0][0].numpy())
-class EagerFunctionTest(XLATestCase):
-
- def testBasic(self):
- with self.test_scope():
- matmul = function.defun(math_ops.matmul, compiled=True)
- t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
- sq = matmul(t, t, transpose_a=True)
- self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20])
-
- def testConv(self):
- if 'GPU' in self.device:
- # TODO(b/32333178)
- self.skipTest('Current implementation of RandomStandardNormal kernel '
- 'is very slow on GPU, and has been blacklisted.')
- with self.test_scope():
- data_format = 'channels_last'
- conv = convolutional.Conv2D(
- filters=1, kernel_size=2, padding='VALID',
- data_format=data_format, activation=nn_ops.relu,
- kernel_initializer=init_ops.ones_initializer(),
- bias_initializer=init_ops.zeros_initializer())
- pool = pooling.MaxPooling2D(2, 2, data_format=data_format)
-
- def model(x):
- x = conv(x)
- return pool(x)
- model = function.defun(model, compiled=True)
-
- x = array_ops.ones([1, 4, 4, 1])
- y = model(x)
- self.assertAllEqual(y.numpy(), [[[[4.]]]])
-
- def testReadVariable(self):
- with self.test_scope():
- v = resource_variable_ops.ResourceVariable(1.0)
-
- @function.defun(compiled=True)
- def f():
- return v.read_value()
-
- var = f()
- self.assertEqual(1.0, var.numpy())
-
- def testUpdateVariable(self):
- with self.test_scope():
- v = resource_variable_ops.ResourceVariable(1.0)
-
- def f(v):
- v.assign_add(1.0)
- return v
-
- f = function.defun(f, compiled=True)
-
- var = f(v)
- self.assertEqual(2.0, var.numpy())
-
- def testAllArgumentKinds(self):
- """Test a complex function that takes different argument kinds.
-
- tf2xla machinery that translates, compiles, and runs defuns
- classifies arguments into: compile-time constants, regular tensors,
- and resources. This test creates a function with a mix of all these
- kinds. Moreover, the order of function arguments is intentionally mixed up.
-
- This also tests the case when the same argument is a compile-time constant
- as well as used in an operation that normally expects its inputs to be
- in device memory - addition in this case.
- """
- with self.test_scope():
- def foo(c1, r1, v1, c2, v2, r2):
- # c1 and c2 are compile-time constants
- # r1 and r2 are regular tensors
- # v1 and v2 are resource variables
- a = c1 + r1
- b = math_ops.cast(c2, dtypes.float32) + v2
- c = array_ops.slice(v1, c1, c2)
- d = r2 * v2
- return a, b, c, d
-
- foo = function.defun(foo, compiled=True)
-
- c1 = [0, 0]
- c2 = array_ops.ones([2], dtype=dtypes.int32)
-
- r1 = array_ops.ones([2])
- r2 = [[2., 2.], [3., 3.]]
-
- v1 = resource_variable_ops.ResourceVariable([[1., 2.], [3., 4.]])
- v2 = resource_variable_ops.ResourceVariable([[10., 20.], [30., 40.]])
-
- a, b, c, d = foo(c1, r1, v1, c2, v2, r2)
-
- self.assertAllEqual([1, 1], a.numpy())
- self.assertAllEqual([[11., 21.], [31., 41.]], b.numpy())
- self.assertAllEqual([[1.]], c.numpy())
- self.assertAllEqual([[20., 40.], [90., 120.]], d.numpy())
-
-
-if __name__ == '__main__':
+if __name__ == "__main__":
ops.enable_eager_execution(
config=config_pb2.ConfigProto(log_device_placement=True))
googletest.main()
'channels_last')
-def random_batch(batch_size, data_format):
+def random_batch(batch_size, device_and_format=None):
+ _, data_format = device_and_format or device_and_data_format()
+
shape = (3, 224, 224) if data_format == 'channels_first' else (224, 224, 3)
shape = (batch_size,) + shape
if defun:
model.call = tfe.defun(model.call)
with tf.device(device), tfe.execution_mode(execution_mode):
- images, _ = random_batch(2, data_format)
+ images, _ = random_batch(2)
output = model(images, training=False)
tfe.async_wait()
self.assertEqual((2, 1000), output.shape)
device, data_format = device_and_data_format()
model = resnet50.ResNet50(data_format, include_top=False)
with tf.device(device):
- images, _ = random_batch(2, data_format)
+ images, _ = random_batch(2)
output = model(images, training=False)
output_shape = ((2, 2048, 1, 1)
if data_format == 'channels_first' else (2, 1, 1, 2048))
device, data_format = device_and_data_format()
model = resnet50.ResNet50(data_format, include_top=False, pooling='avg')
with tf.device(device):
- images, _ = random_batch(2, data_format)
+ images, _ = random_batch(2)
output = model(images, training=False)
self.assertEqual((2, 2048), output.shape)
name='t0').as_default(), tf.contrib.summary.always_record_summaries():
with tf.device(device), tfe.execution_mode(execution_mode):
optimizer = tf.train.GradientDescentOptimizer(0.1)
- images, labels = random_batch(2, data_format)
+ images, labels = random_batch(2)
train_one_step(model, images, labels, optimizer)
self.assertEqual(320, len(model.variables))
tfe.async_wait()
model = resnet50.ResNet50(data_format)
optimizer = tf.train.GradientDescentOptimizer(0.1)
with tf.device(device):
- images, labels = random_batch(2, data_format)
+ images, labels = random_batch(2)
gc.disable()
# Warm up. Note that this first run does create significant amounts of
# garbage to be collected. The hope is that this is a build-only effect,
# which forces a sync. This is a roundabout way, yes.
tf.constant(1.).cpu()
- def _benchmark_eager_apply(self, label, device_and_format, defun=False,
- execution_mode=None, compiled=False):
+ def _benchmark_eager_apply(self, label, defun=False, execution_mode=None,
+ device_and_format=None):
with tfe.execution_mode(execution_mode):
- device, data_format = device_and_format
+ device, data_format = device_and_format or device_and_data_format()
model = resnet50.ResNet50(data_format)
if defun:
- model.call = tfe.defun(model.call, compiled=compiled)
+ model.call = tfe.defun(model.call)
batch_size = 64
num_burn = 5
num_iters = 30
with tf.device(device):
- images, _ = random_batch(batch_size, data_format)
+ images, _ = random_batch(batch_size, device_and_format)
for _ in xrange(num_burn):
model(images, training=False).cpu()
if execution_mode:
self._report(label, start, num_iters, device, batch_size, data_format)
def benchmark_eager_apply_sync(self):
- self._benchmark_eager_apply('eager_apply', device_and_data_format(),
- defun=False)
+ self._benchmark_eager_apply('eager_apply', defun=False)
def benchmark_eager_apply_async(self):
self._benchmark_eager_apply(
- 'eager_apply_async', device_and_data_format(), defun=False,
- execution_mode=tfe.ASYNC)
+ 'eager_apply_async', defun=False, execution_mode=tfe.ASYNC)
def benchmark_eager_apply_with_defun(self):
- self._benchmark_eager_apply('eager_apply_with_defun',
- device_and_data_format(), defun=True)
+ self._benchmark_eager_apply('eager_apply_with_defun', defun=True)
def _benchmark_eager_train(self,
label,
make_iterator,
- device_and_format,
defun=False,
execution_mode=None,
- compiled=False):
+ device_and_format=None):
with tfe.execution_mode(execution_mode):
- device, data_format = device_and_format
+ device, data_format = device_and_format or device_and_data_format()
for batch_size in self._train_batch_sizes():
- (images, labels) = random_batch(batch_size, data_format)
+ (images, labels) = random_batch(batch_size, device_and_format)
num_burn = 3
num_iters = 10
model = resnet50.ResNet50(data_format)
if defun:
- model.call = tfe.defun(model.call, compiled=compiled)
+ model.call = tfe.defun(model.call)
optimizer = tf.train.GradientDescentOptimizer(0.1)
with tf.device(device):
self._report(label, start, num_iters, device, batch_size, data_format)
def benchmark_eager_train_sync(self):
- self._benchmark_eager_train('eager_train', MockIterator,
- device_and_data_format(), defun=False)
+ self._benchmark_eager_train('eager_train', MockIterator, defun=False)
def benchmark_eager_train_async(self):
self._benchmark_eager_train(
'eager_train_async',
MockIterator,
- device_and_data_format(),
defun=False,
execution_mode=tfe.ASYNC)
def benchmark_eager_train_with_defun(self):
self._benchmark_eager_train(
- 'eager_train_with_defun', MockIterator,
- device_and_data_format(), defun=True)
+ 'eager_train_with_defun', MockIterator, defun=True)
def benchmark_eager_train_datasets(self):
return tfe.Iterator(ds)
self._benchmark_eager_train(
- 'eager_train_dataset', make_iterator,
- device_and_data_format(), defun=False)
+ 'eager_train_dataset', make_iterator, defun=False)
def benchmark_eager_train_datasets_with_defun(self):
return tfe.Iterator(ds)
self._benchmark_eager_train(
- 'eager_train_dataset_with_defun', make_iterator,
- device_and_data_format(), defun=True)
+ 'eager_train_dataset_with_defun', make_iterator, defun=True)
if __name__ == '__main__':
import numpy as np
-from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
class _EagerDefinedFunction(object):
"""Function object with the interface of tf _DefinedFunction."""
- def __init__(self, name, graph, operations, inputs, outputs, attrs):
+ def __init__(self, name, graph, operations, inputs, outputs):
"""Initializes an eager defined function.
Args:
which will be in the function
inputs: the tensors in the graph to be used as inputs to the function
outputs: the tensors in the graph which will be outputs to the function
- attrs: dict mapping names of attributes to their AttrValue values
"""
fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
graph._c_graph, # pylint: disable=protected-access
[],
None,
compat.as_str(""))
-
- for name, attr_value in attrs.items():
- serialized = attr_value.SerializeToString()
- # TODO(iga): this creates and deletes a new TF_Status for every attr.
- # It might be worth creating a convenient way to re-use status.
- pywrap_tensorflow.TF_FunctionSetAttrValueProto(
- fn, compat.as_str(name), serialized)
-
# TODO(apassos) avoid creating a FunctionDef (specially to grab the
# signature, but also in general it's nice not to depend on it.
with c_api_util.tf_buffer() as buffer_:
class GraphModeFunction(object):
"""Callable object representing a graph-mode function.
+
+ Args:
+ name: str the name of the created function
+ input_placeholders: list of placeholder values (tensors) to feed when
+ calling the wrapped function.
+ extra_inputs: Tensor inputs this function definition closed over which
+ are passed as arguments. Need to track so gradients are supported
+ correctly.
+ graph: the Graph from which the operations will be pulled. Used as
+ a context when computing gradients.
+ operations: the subset of Operations in the graph used in the function
+ definition.
+ outputs: a flat list of the Tensors in the graph used as outputs to the
+ function
+ func_outputs: a possibly nested python object which will be returned by
+ this function. The Tensors in this structure will be replaced by their
+ corresponding values in outputs.
+ output_shapes: List of shapes of all tensors in outputs
+ variables: (optional) List of variables to watch during function execution.
"""
def __init__(self,
outputs,
func_outputs,
output_shapes,
- variables=None,
- attrs=None):
- """Initialize a GraphModeFunction.
-
- Args:
- name: str the name of the created function
- input_placeholders: list of placeholder values (tensors) to feed when
- calling the wrapped function.
- extra_inputs: Tensor inputs this function definition closed over which
- are passed as arguments. Need to track so gradients are supported
- correctly.
- graph: the Graph from which the operations will be pulled. Used as
- a context when computing gradients.
- operations: the subset of Operations in the graph used in the function
- definition.
- outputs: a flat list of the Tensors in the graph used as outputs to the
- function
- func_outputs: a possibly nested python object which will be returned by
- this function. The Tensors in this structure will be replaced by their
- corresponding values in outputs.
- output_shapes: List of shapes of all tensors in outputs
- variables: (optional) List of variables to watch during function
- execution.
- attrs: (optional) dict mapping names of attributes to their AttrValue
- values. Attributes in `attrs` will be included in this function's
- definition.
- """
- self._attrs = attrs or {}
+ variables=None):
defined_function = _EagerDefinedFunction(
- name, graph, operations, input_placeholders, outputs, self._attrs)
+ name, graph, operations, input_placeholders, outputs)
if len(input_placeholders) != len(defined_function.signature.input_arg):
raise ValueError("Internal error: invalid lengths. %s %s" % (
len(input_placeholders), len(defined_function.signature.input_arg)))
forward_name = _forward_name(self._func_name)
self._forward_fdef = _EagerDefinedFunction(
forward_name, self._graph, self._ops, self._input_placeholders,
- filtered_outputs + captures, self._attrs)
+ filtered_outputs + captures)
all_inputs = self._out_grad_placeholders + captures
# Excluding input ops from the body as we do not intend to execute these
# operations when the function is executed.
bname = _backward_name(self._func_name)
self._backward_function = GraphModeFunction(
bname, all_inputs, [], self._graph, function_def_ops,
- backward_outputs, in_gradients, output_shapes, attrs=self._attrs)
+ backward_outputs, in_gradients, output_shapes)
def _backprop_call(self, args):
"""Calls the wrapped function and records the result on a tape."""
return nest.pack_sequence_as(args, ret)
-def _defun_internal(name, func, compiled, args, kwds):
+def _defun_internal(name, func, args, kwds):
"""Defines and returns graph-mode version of func."""
graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
with context.graph_mode():
for f in tmp_graph._functions.values(): # pylint: disable=protected-access
# TODO(ashankar): What about the gradient registry?
_register(f._c_func.func) # pylint: disable=protected-access
-
- attrs = {}
- if compiled:
- attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=True)
-
return GraphModeFunction(
fname, all_inputs, extra_inputs, tmp_graph, operations, func_def_outputs,
- func_outputs, output_shapes, variables, attrs)
+ func_outputs, output_shapes, variables)
# Defun uses this instead of Tensor as a cache key. Using dtype because
# TODO(apassos): better error messages for non-hashable arguments.
-def named_defun(func, name, compiled=False):
+def named_defun(func, name):
"""Defines a function with a given name.
See the documentation for `defun` for more information on the semantics of the
Args:
func: the function to be wrapped.
name: the name given to it.
- compiled: if true, the framework will attempt to compile func with XLA.
Returns:
the wrapped function.
if cache_key not in arguments_to_functions:
arguments_to_functions[cache_key] = _defun_internal(
- name, func, compiled, args, kwds)
+ name, func, args, kwds)
return arguments_to_functions[cache_key](*args)
return decorated
-def defun(func=None, compiled=False):
+def defun(func):
"""Decorator to compile func into graph_mode.
`defun` converts a function that constructs a TensorFlow graph into a function
```
Args:
- func: function to be compiled. If `func` is None, returns a
- decorator that can be invoked with a single argument - `func`. The
- end result is equivalent to providing all the arguments up front.
- In other words, defun(compiled=True)(func) is equivalent to
- defun(func, compiled=True). The former allows the following use case:
- @tfe.defun(compiled=True)
- def foo(...):
- ...
- compiled: If True, an attempt to compile `func` with XLA will be made.
- If it fails, function will be run normally. Experimental.
- Currently, supported only for execution on TPUs.
+ func: function to be compiled.
Returns:
- If `func` is not None, returns callable that will execute the compiled
- function (and return zero or more `tf.Tensor` objects).
- If `func` is None, returns a decorator that, when invoked with a single
- `func` argument, returns a callable equivalent to the case above.
+ A callable that will execute the compiled function (and return zero
+ or more `tf.Tensor` objects).
"""
# TODO(apassos): deal with captured global state. Deal with control flow.
- def decorated(function):
- try:
- name = function.__name__
- except AttributeError:
- name = "function"
- return tf_decorator.make_decorator(
- function, named_defun(function, name, compiled=compiled))
-
- # This code path is for the `foo = tfe.defun(foo, ...)` use case
- if func is not None:
- return decorated(func)
-
- # This code path is for the
- #
- # @tfe.defun(...)
- # def foo(...):
- # ...
- #
- # use case, which is equivalent to `foo = tfe.defun(...)(foo)`
- return decorated
+ try:
+ name = func.__name__
+ except AttributeError:
+ name = "function"
+ return tf_decorator.make_decorator(func, named_defun(func, name))
def make_defun_op(func, *args, **kwds):
name = func.__name__
if any(isinstance(x, ops.EagerTensor) for x in kwds.values()):
raise ValueError("Tensor keyword arguments are not supported.")
- return _defun_internal(name, func, False, args, kwds)
+ return _defun_internal(name, func, args, kwds)
class AutomaticControlDependencies(object):