From a799cdbe78ca2c2e9c41f2b1bf8a3f57162fbcea Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 8 May 2018 02:11:52 -0700 Subject: [PATCH] Automated g4 rollback of changelist 195748721 PiperOrigin-RevId: 195790581 --- tensorflow/compiler/jit/BUILD | 22 --- tensorflow/compiler/jit/create_xla_launch_op.cc | 206 +++++---------------- tensorflow/compiler/jit/create_xla_launch_op.h | 35 ---- .../compiler/jit/create_xla_launch_op_test.cc | 144 -------------- tensorflow/compiler/jit/kernels/xla_launch_op.cc | 90 ++------- tensorflow/compiler/jit/kernels/xla_launch_op.h | 51 ++--- .../compiler/jit/xla_compile_on_demand_op.cc | 3 +- tensorflow/compiler/jit/xla_launch_util.cc | 18 +- tensorflow/compiler/jit/xla_launch_util.h | 15 +- tensorflow/compiler/tests/BUILD | 4 - tensorflow/compiler/tests/eager_test.py | 112 +---------- .../python/examples/resnet50/resnet50_test.py | 55 +++--- tensorflow/python/eager/function.py | 127 ++++--------- 13 files changed, 164 insertions(+), 718 deletions(-) delete mode 100644 tensorflow/compiler/jit/create_xla_launch_op.h delete mode 100644 tensorflow/compiler/jit/create_xla_launch_op_test.cc diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index e942b46..07136d6 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -261,7 +261,6 @@ cc_library( name = "create_xla_launch_op", srcs = [ "create_xla_launch_op.cc", - "create_xla_launch_op.h", ], deps = [ ":common", @@ -271,27 +270,6 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - ], - alwayslink = 1, -) - -tf_cc_test( - name = "create_xla_launch_op_test", - srcs = [ - "create_xla_launch_op.h", - "create_xla_launch_op_test.cc", - ], - deps = [ - ":create_xla_launch_op", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:session_options", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", ], ) diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc index 6ac84dc..18d9013 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/jit/create_xla_launch_op.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/kernels/xla_launch_op.h" @@ -26,189 +25,78 @@ limitations under the License. namespace tensorflow { namespace { -// Utility which searches for values in a sorted list by scanning over it once. -// No matter how many times ScanForValue is called, the list is scanned at most -// once. However, if a call to ScanForValue skips over a value, that value is -// not revisited in future calls to ScanForValue, so callers must take -// care to order their calls. +// Givens a NodeDef 'ndef' and the function library runtime 'flr', if +// 'ndef' is a call to a compilable function defined in 'flr', returns OK +// and fills in 'kernel' with a XlaLaunchOp kernel which computes the +// node. Otherwise, returns a non-OK. // -// Useful for merging multiple sorted lists in O(n) time. -class SinglePassSearch { - public: - // Creates a SinglePassSearch object that can be used to search in `values`. - // Does not take ownership of `values`. `values` must outlive this. - // `values` must be sorted. - explicit SinglePassSearch(const std::vector* values) - : current_index_(0), values_(values) {} - - // Scans forward in the vector looking for "value", updating the internal - // position in to the vector. - // Returns true iff the vector contains the given value at or after current - // position. - // Not thread-safe. - bool ScanForValue(int value) { - while (current_index_ < values_->size() && - (*values_)[current_index_] <= value) { - if ((*values_)[current_index_] == value) { - current_index_++; - return true; - } - current_index_++; - } - return false; - } - - private: - int current_index_; - const std::vector* values_; -}; - -Status CompilationRequested(const FunctionLibraryRuntime& flr, - const NodeDef& node_def) { +// This routine is here so that FunctionLibraryRuntime can jit a +// specific function call as requested. +Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef, + std::unique_ptr* kernel) { bool xla_compile = false; - // Check if op is marked _XlaCompile=true. - Status status = flr.GetFunctionLibraryDefinition()->GetAttr( - node_def, kXlaCompileAttr, &xla_compile); - if (!status.ok() || !xla_compile) { - if (VLOG_IS_ON(3)) { - if (!status.ok()) { - VLOG(3) << "No " << kXlaCompileAttr << " attr defined for " - << node_def.op() << ". status=" << status.ToString(); - } else { - VLOG(3) << node_def.op() << " is explicitly marked not to be compiled"; - } - } - return Status(error::INVALID_ARGUMENT, ""); + if (!flr->GetFunctionLibraryDefinition() + ->GetAttr(ndef, kXlaCompileAttr, &xla_compile) + .ok() || + !xla_compile) { + // Not marked as _XlaCompile=true. + return errors::InvalidArgument("No ", kXlaCompileAttr, " for ", ndef.op()); + } + // Make sure that kernels have been registered on the JIT device. + XlaOpRegistry::RegisterCompilationKernels(); + if (!IsCompilable(flr, ndef)) { + // ndef is calling a function that XLA can't compile. + return errors::InvalidArgument("Not compilable: ", ndef.ShortDebugString()); } - return Status::OK(); -} - -// Given a FunctionLibraryRuntime and a NodeDef calling a function in the -// runtime, returns this function's body in `fbody` as well as the indices -// of its constant and resource arguments. -// `fbody` is owned by `flr`. -// `constant_arg_indices` and `resource_arg_indices` should be empty vector. -// They are sorted in ascending order on this function's return. -Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, - const NodeDef& node_def, - const FunctionBody** fbody, - std::vector* constant_arg_indices, - std::vector* resource_arg_indices) { FunctionLibraryRuntime::Handle handle; - // If node_def is not instantiable, e.g., the function does not exist, + // If ndef is not instantiable, e.g., the function does not exist, // simply bail out. TF_RETURN_IF_ERROR( - flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle)); - *fbody = flr->GetFunctionBody(handle); - CHECK(*fbody); // Can't be nullptr since we just instantiated it. - const DataTypeVector& arg_types = (*fbody)->arg_types; - std::vector const_args(arg_types.size()); + flr->Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle)); + const FunctionBody* fbody = flr->GetFunctionBody(handle); + CHECK(fbody); // Can't be nullptr since we just instantiated it. + std::vector const_args(fbody->arg_types.size()); // If we can't analyze the const args. Bail out. - TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*((*fbody)->graph), &const_args)); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*(fbody->graph), &const_args)); for (int i = 0; i < const_args.size(); ++i) { if (const_args[i]) { - constant_arg_indices->push_back(i); - } - } - - // There can be hundreds of resource variables. Reserve the space for them. - // We don't reserve for constants above as they are usually few. - resource_arg_indices->reserve(arg_types.size()); - for (int i = 0; i < arg_types.size(); ++i) { - if (arg_types[i] == DT_RESOURCE) { - resource_arg_indices->push_back(i); + // There is a const arg. Bail out. + return errors::InvalidArgument("Const arg: ", i, " in ", + DebugString(fbody->fdef)); } } - return Status::OK(); -} - -} // namespace - -Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def, - std::unique_ptr* kernel) { - TF_RETURN_IF_ERROR(CompilationRequested(*flr, node_def)); - - VLOG(3) << "Creating XlaLaunchOp for " << node_def.DebugString(); - - // Make sure that kernels have been registered on the JIT device. - XlaOpRegistry::RegisterCompilationKernels(); - if (!IsCompilable(flr, node_def)) { - // node_def is calling a function that XLA can't compile. - return errors::InvalidArgument("Not compilable: ", - node_def.ShortDebugString()); - } - - // Get function body, constant args, and resource args. - const FunctionBody* fbody = nullptr; - std::vector constant_arg_indices; - std::vector resource_arg_indices; - TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources( - flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices)); - - // Set input and output memory types. + NodeDef launch_def; + launch_def.set_name(ndef.name()); + launch_def.set_op("_XlaLaunch"); + launch_def.set_device(flr->device()->name()); + AddNodeAttr("Tconstants", DataTypeVector{}, &launch_def); + AddNodeAttr("Nresources", 0, &launch_def); + AddNodeAttr("Targs", fbody->arg_types, &launch_def); + AddNodeAttr("Tresults", fbody->ret_types, &launch_def); + NameAttrList func; + func.set_name(ndef.op()); + *(func.mutable_attr()) = ndef.attr(); + AddNodeAttr("function", func, &launch_def); + + // TODO(b/32387911): Handles the host memory types across function + // calls properly. For now, we assume all inputs and outputs are on + // the device memory. MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY); - // These indices are used only for optimization purposes. They allow us - // to loop over constant_arg_indices and resource_arg_indices only once - // while iterating over all the function arguments checking if it is a - // resource or a constant. - // The reason we optimized this code is because functions can have a lot of - // captured arguments. For example, the backward pass of ResNet50 takes in all - // 214 variables and a similar number of activations. - SinglePassSearch constants_search(&constant_arg_indices); - SinglePassSearch resources_search(&resource_arg_indices); - for (int i = 0; i < fbody->arg_types.size(); ++i) { - if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) { - // Compile-time constants and resource handles are expected to be in - // host memory. - input_memory_types[i] = HOST_MEMORY; - } - } - // One might wonder, about the case where a compile-time constant argument - // (which must be in host memory) is also used as an input into an op, - // e.g. Add, that expects its inputs in device memory. Here is how it - // works now. - // First, what do we mean by "op expects an input in XYZ memory"? - // There are two types of "ops" here: the tf2xla kernel and the HLO - // computation it builds. The tf2xla kernel needs to retrieve the actual - // numeric value of the compile-time constant tensors, so it really expects - // them to be on in host memory. However, for other inputs, it refers to them - // using xla::ComputationDataHandle, which is just a symbolic handle that - // xla::ComputationBuilder assigns. How does this handle gets assigned for - // constant arguments? Even constant arguments get an _Arg node in the graph - // instatiated for Function compilation. The tf2xla kernel for constant _Arg - // nodes takes the constant value, converts it to XlaLiteral, and feeds it - // to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This - // constant XlaLiteral is included in the HLO graph, and subsequently, in - // the actual executable, which is copied to the device before being - // executed. Thus, when this executable runs, the constant is available in - // device memory. - - // XlaLaunch kernel keeps all outputs (including constants, which it copies), - // in device memory MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY); - // Create the kernel. - NameAttrList function; - function.set_name(node_def.op()); - *(function.mutable_attr()) = node_def.attr(); - Device* dev = flr->device(); Status s; OpKernelConstruction construction( DeviceType(dev->device_type()), dev, - dev->GetAllocator(AllocatorAttributes()), &node_def, + dev->GetAllocator(AllocatorAttributes()), &launch_def, &fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types, fbody->ret_types, output_memory_types, flr->graph_def_version(), &s); - - *kernel = absl::make_unique( - &construction, constant_arg_indices, resource_arg_indices, function); + kernel->reset(new XlaLocalLaunchOp(&construction)); return s; } -namespace { - bool RegisterLaunchOpCreator() { RegisterDefaultCustomKernelCreator(CreateXlaLaunchOp); return true; diff --git a/tensorflow/compiler/jit/create_xla_launch_op.h b/tensorflow/compiler/jit/create_xla_launch_op.h deleted file mode 100644 index 98a22e3..0000000 --- a/tensorflow/compiler/jit/create_xla_launch_op.h +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_ -#define TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_ - -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/lib/core/status.h" - -namespace tensorflow { - -class FunctionLibraryRuntime; -class OpKernel; - -// Given a NodeDef 'node_def' and the function library runtime 'flr', if -// 'node_def' is a call to a compilable function defined in 'flr', returns OK -// and fills in 'kernel' with a XlaLaunchOp kernel which computes the -// node. Otherwise, returns a non-OK. -Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def, - std::unique_ptr* kernel); - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_ diff --git a/tensorflow/compiler/jit/create_xla_launch_op_test.cc b/tensorflow/compiler/jit/create_xla_launch_op_test.cc deleted file mode 100644 index c222824..0000000 --- a/tensorflow/compiler/jit/create_xla_launch_op_test.cc +++ /dev/null @@ -1,144 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/jit/create_xla_launch_op.h" - -#include "tensorflow/core/common_runtime/device_factory.h" -#include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/framework/function_testlib.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/public/session_options.h" -#include "tensorflow/core/public/version.h" - -namespace tensorflow { - -NodeDef ToNodeDef(const string& text) { - NodeDef node_def; - EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def)); - return node_def; -} - -// Create a FunctionDef that takes one resource and one regular param -FunctionDef XTimesY() { - return FunctionDefHelper::Define( - // Name - "XTimesY", - // Args - {"x: float", "y: resource"}, - // Return values - {"z: float"}, - // Attr def - {}, - // Nodes - { - {{"y0"}, "ReadVariableOp", {"y"}, {{"dtype", DT_FLOAT}}}, - {{"z"}, "Mul", {"x", "y0"}, {{"T", DT_FLOAT}}}, - }); -} - -class CreateXlaLaunchOpTest : public ::testing::Test { - protected: - void Init(const std::vector& flib) { - SessionOptions options; - auto* device_count = options.config.mutable_device_count(); - device_count->insert({"CPU", 1}); - TF_CHECK_OK(DeviceFactory::AddDevices( - options, "/job:localhost/replica:0/task:0", &devices_)); - - FunctionDefLibrary proto; - for (const auto& fdef : flib) { - *(proto.add_function()) = fdef; - } - lib_def_ = absl::make_unique( - OpRegistry::Global(), proto); - OptimizerOptions opts; - device_mgr_ = absl::make_unique(devices_); - pflr_ = absl::make_unique( - device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), - opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); - flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); - } - - FunctionLibraryRuntime* flr_; - std::vector devices_; - std::unique_ptr device_mgr_; - std::unique_ptr lib_def_; - std::unique_ptr pflr_; - - std::unique_ptr kernel_; -}; - -AttrValue BoolAttr(bool b) { - AttrValue v; - v.set_b(b); - return v; -} - -TEST_F(CreateXlaLaunchOpTest, OneFloatOneResourceArgument) { - FunctionDef fdef = XTimesY(); - (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(true); - Init({fdef}); - - Status status = CreateXlaLaunchOp( - flr_, ToNodeDef(R"pb( - name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b' - )pb"), &kernel_); - ASSERT_TRUE(status.ok()) << status.ToString(); - - EXPECT_EQ("XTimesY", kernel_->name()); - EXPECT_EQ("XTimesY", kernel_->type_string()); - - EXPECT_EQ(2, kernel_->num_inputs()); - EXPECT_EQ(DT_FLOAT, kernel_->input_type(0)); - EXPECT_EQ(DT_RESOURCE, kernel_->input_type(1)); - EXPECT_EQ(DEVICE_MEMORY, kernel_->input_memory_types()[0]); - EXPECT_EQ(HOST_MEMORY, kernel_->input_memory_types()[1]); - - EXPECT_EQ(1, kernel_->num_outputs()); - EXPECT_EQ(DT_FLOAT, kernel_->output_type(0)); - EXPECT_EQ(DEVICE_MEMORY, kernel_->output_memory_types()[0]); -} - -TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrNotSet) { - FunctionDef fdef = XTimesY(); - Init({fdef}); - - Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto( - name: 'XTimesY' - op: 'XTimesY' - input: 'a' - input: 'b' - )proto"), &kernel_); - EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString(); -} - -TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrIsSetToFalse) { - FunctionDef fdef = XTimesY(); - (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(false); - Init({fdef}); - - Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto( - name: 'XTimesY' - op: 'XTimesY' - input: 'a' - input: 'b' - )proto"), &kernel_); - EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString(); -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 86a9fd3..049d170 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -39,15 +39,15 @@ limitations under the License. namespace tensorflow { -XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, - const std::vector& constants, - const std::vector& resources, - const NameAttrList& function) - : OpKernel(ctx), - constants_(constants), - resources_(resources), - device_type_(ctx->device_type()), - function_(function) { +XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) + : OpKernel(ctx), device_type_(ctx->device_type()) { + const NameAttrList* func; + OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func)); + function_ = *func; + DataTypeVector constant_types; + OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types)); + num_constant_args_ = constant_types.size(); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Nresources", &num_resource_args_)); if (device_type_ == DeviceType(DEVICE_CPU)) { platform_id_ = se::host::kHostPlatformId; } else if (device_type_ == DeviceType(DEVICE_GPU)) { @@ -57,8 +57,8 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, } } -Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx, - XlaCompilationCache** cache) { +Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx, + XlaCompilationCache** cache) { const XlaDevice::Metadata* metadata; Status s = XlaDevice::GetMetadata(ctx, &metadata); if (s.ok()) { @@ -90,8 +90,8 @@ Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx, return Status::OK(); } -void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { - VLOG(1) << "XlaLocalLaunchOpBase::Compute " +void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { + VLOG(1) << "XlaLocalLaunchOp::Compute " << Canonicalize(function_.name(), AttrSlice(&function_.attr())); // We store information about the JIT-compiled XLA computation // in the ResourceMgr. @@ -124,7 +124,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { } std::map variables = - SnapshotResourceVariables(ctx, resources_); + SnapshotResourceVariables(ctx, num_resource_args_); xla::LocalClient* client = static_cast(cache->client()); @@ -161,7 +161,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { xla::LocalExecutable* executable; std::map constant_args; - for (int i : constants_) { + for (int i = 0; i < num_constant_args_; ++i) { constant_args.insert({i, ctx->input(i)}); } OP_REQUIRES_OK(ctx, cache->Compile(options, function_, constant_args, @@ -170,8 +170,8 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { VLOG(1) << "Executing XLA Computation..."; - XlaComputationLaunchContext launch_context(client, xla_allocator, - allocate_xla_tensors); + XlaComputationLaunchContext launch_context( + num_resource_args_, client, xla_allocator, allocate_xla_tensors); launch_context.PopulateInputs(ctx, kernel, variables); // Execute the computation. @@ -194,62 +194,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { VLOG(1) << "Done"; } -namespace { - -// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that -// in error case, it returns RET instead of void. -#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \ - do { \ - ::tensorflow::Status _s(__VA_ARGS__); \ - if (!TF_PREDICT_TRUE(_s.ok())) { \ - (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ - return RET; \ - } \ - } while (0) - -// Helper static functions to construct parameters for -// XlaLocalLaunchBase constructor from OpKernelConstruction. -std::vector ConstantsVector(OpKernelConstruction* ctx) { - DataTypeVector constant_types; - OP_REQUIRES_OK_RETURN(ctx, std::vector(), - ctx->GetAttr("Tconstants", &constant_types)); - std::vector constants(constant_types.size()); - std::iota(constants.begin(), constants.end(), 0); - return constants; -} - -std::vector ResourcesVector(OpKernelConstruction* ctx) { - DataTypeVector constant_types; - OP_REQUIRES_OK_RETURN(ctx, std::vector(), - ctx->GetAttr("Tconstants", &constant_types)); - - DataTypeVector arg_types; - OP_REQUIRES_OK_RETURN(ctx, std::vector(), - ctx->GetAttr("Targs", &arg_types)); - - int num_resources; - OP_REQUIRES_OK_RETURN(ctx, std::vector(), - ctx->GetAttr("Nresources", &num_resources)); - - std::vector resources(num_resources); - std::iota(resources.begin(), resources.end(), - constant_types.size() + arg_types.size()); - return resources; -} - -NameAttrList FunctionAttr(OpKernelConstruction* ctx) { - const NameAttrList* func; - OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func)); - return *func; -} - -#undef OP_REQUIRES_OK_RETURN -} // namespace - -XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) - : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx), - FunctionAttr(ctx)) {} - XlaLocalLaunchOp::~XlaLocalLaunchOp() { VLOG(1) << "XlaLocalLaunchOp destroyed"; } diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h index 8dfc4b3..8f8e646 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.h +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.h @@ -26,41 +26,6 @@ limitations under the License. namespace tensorflow { -// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp. -// The only difference is that it does not require arguments to follow -// the "constants, then regular args, then resources" order. -// It takes vectors of constant and resource arguments explicitly. -// It does not have corresponding OpDef because it is never present -// in the GraphDef. -// Currently, it is used by eager runtime. FunctionLibraryRuntime creates -// this kernel when asked to create a kernel for an XLA-compiled function. -class XlaLocalLaunchBase : public OpKernel { - public: - XlaLocalLaunchBase(OpKernelConstruction* ctx, - const std::vector& constants, - const std::vector& resources, - const NameAttrList& function); - XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete; - XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete; - ~XlaLocalLaunchBase() override = default; - - void Compute(OpKernelContext* ctx) override; - - protected: - // Builds a XlaCompilationCache class suitable for the current device. - Status BuildCompilationCache(OpKernelContext* ctx, - XlaCompilationCache** cache); - - // Indexes of compile-time constant inputs - std::vector constants_; - // Indexes of resource inputs - std::vector resources_; - - DeviceType device_type_; - NameAttrList function_; - se::Platform::Id platform_id_; -}; - // XlaLocalLaunchOp is used to replace a region of the TensorFlow graph // which will be compiled and executed using XLA. The XlaLocalLaunchOp is // responsible for handling interactions with the TensorFlow executor. @@ -70,12 +35,26 @@ class XlaLocalLaunchBase : public OpKernel { // XlaLocalLaunchOp uses xla::LocalClient::Compile() and // xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device // memory. -class XlaLocalLaunchOp : public XlaLocalLaunchBase { +class XlaLocalLaunchOp : public OpKernel { public: explicit XlaLocalLaunchOp(OpKernelConstruction* ctx); ~XlaLocalLaunchOp() override; + void Compute(OpKernelContext* ctx) override; + private: + // Builds a XlaCompilationCache class suitable for the current device. + Status BuildCompilationCache(OpKernelContext* ctx, + XlaCompilationCache** compiler); + + DeviceType device_type_; + NameAttrList function_; + int num_constant_args_; + // Number of resource variable arguments. + int num_resource_args_; + + se::Platform::Id platform_id_; + TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp); }; diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 6b83cf6..60458f6 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -48,12 +48,13 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, const XlaCompiler::CompilationResult* result, xla::LocalExecutable* executable) { std::map variables = GetVariables(ctx); + int64 num_resource_args = variables.size(); xla::LocalClient* client = metadata.client(); // Builds an XLA allocator for the device. XlaComputationLaunchContext launch_context( - client, client->backend().memory_allocator(), true); + num_resource_args, client, client->backend().memory_allocator(), true); launch_context.PopulateInputs(ctx, result, variables); diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 0223f97..33e5361 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -38,13 +38,14 @@ using xla::ScopedShapedBuffer; using xla::ShapedBuffer; } // anonymous namespace -std::map SnapshotResourceVariables( - OpKernelContext* ctx, const std::vector& variables) { +std::map SnapshotResourceVariables(OpKernelContext* ctx, + int num_variables) { std::map snapshot; - for (int i : variables) { + int first_variable = ctx->num_inputs() - num_variables; + for (int i = 0; i < num_variables; ++i) { Var* variable = nullptr; - ResourceHandle handle = HandleFromInput(ctx, i); - OptionalTensor& tensor = snapshot[i]; + ResourceHandle handle = HandleFromInput(ctx, first_variable + i); + OptionalTensor& tensor = snapshot[first_variable + i]; if (LookupResource(ctx, handle, &variable).ok()) { tf_shared_lock lock(*variable->mu()); tensor.name = handle.name(); @@ -111,9 +112,10 @@ ScopedShapedBuffer ExtractSubShapedBuffer( using internal::ExtractSubShapedBuffer; XlaComputationLaunchContext::XlaComputationLaunchContext( - xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator, - bool allocate_xla_tensors) - : client_(client), + int64 num_resource_args, xla::LocalClient* client, + xla::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors) + : num_resource_args_(num_resource_args), + client_(client), xla_allocator_(xla_allocator), allocate_xla_tensors_(allocate_xla_tensors) {} diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index a243125..38291b0 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -31,17 +31,15 @@ limitations under the License. namespace tensorflow { class XlaAllocator; -// Takes a snapshot of the values of resource variable arguments, whose -// indices are specified in `variables` argument. We snapshot tensors that back +// Takes a snapshot of the values of resource variable arguments, which are +// the last `num_variables` arguments. We snapshot tensors that back // resource variables since concurrent updates may modify the shape, and it is // important that the shapes used for compilation match the true shapes of the // buffers. // -// Returns a map of TensorFlow argument index to resource variable. If a -// resource variable is not initialized, the corresponding OptionalTensor -// will have its `present` field set to false. -std::map SnapshotResourceVariables( - OpKernelContext* ctx, const std::vector& variables); +// Returns a map of TensorFlow argument index to resource variable. +std::map SnapshotResourceVariables(OpKernelContext* ctx, + int num_variables); // Adapter class that wraps a Tensorflow allocator as an XLA allocator. // Assumes that the Tensorflow allocator permits asynchronous deallocation: @@ -74,7 +72,7 @@ class XlaComputationLaunchContext { // Create a new launch context. 'allocate_xla_tensors' is true if allocated // output tensors and variables are always XlaTensors. If false they are // assumed to be "normal" device pointers. - XlaComputationLaunchContext(xla::LocalClient* client, + XlaComputationLaunchContext(int64 num_resource_args, xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors); @@ -94,6 +92,7 @@ class XlaComputationLaunchContext { const std::vector& arguments() const { return arg_ptrs_; } private: + int64 num_resource_args_; xla::LocalClient* client_; xla::DeviceMemoryAllocator* xla_allocator_; bool allocate_xla_tensors_; diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 9791792..aaea83a 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -327,11 +327,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:layers", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn", "//tensorflow/python:platform_test", - "//tensorflow/python/eager:function", ], ) diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index 5ab1585..bdd0185 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -24,16 +24,10 @@ from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.core.protobuf import config_pb2 from tensorflow.python.eager import backprop from tensorflow.python.eager import context -from tensorflow.python.eager import function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.layers import convolutional -from tensorflow.python.layers import pooling from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import googletest @@ -49,7 +43,7 @@ class EagerTest(XLATestCase): def testExecuteListOutputLen0(self): with self.test_scope(): - empty = constant_op.constant([], dtype=dtypes.float32) + empty = constant_op.constant([], dtype=dtypes.int32) result = array_ops.unstack(empty, 0) self.assertTrue(isinstance(result, list)) self.assertEqual(0, len(result)) @@ -57,7 +51,7 @@ class EagerTest(XLATestCase): def testExecuteListOutputLen1(self): with self.test_scope(): split_dim = constant_op.constant(1) - value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]]) + value = constant_op.constant([[0, 1, 2], [3, 4, 5]]) result = array_ops.split(value, 1, axis=split_dim) self.assertTrue(isinstance(result, list)) self.assertEqual(1, len(result)) @@ -66,7 +60,7 @@ class EagerTest(XLATestCase): def testExecuteListOutputLen3(self): with self.test_scope(): split_dim = constant_op.constant(1) - value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]]) + value = constant_op.constant([[0, 1, 2], [3, 4, 5]]) result = array_ops.split(value, 3, axis=split_dim) self.assertTrue(isinstance(result, list)) self.assertEqual(3, len(result)) @@ -137,105 +131,7 @@ class EagerTest(XLATestCase): self.assertEqual(2., grads[0][0].numpy()) -class EagerFunctionTest(XLATestCase): - - def testBasic(self): - with self.test_scope(): - matmul = function.defun(math_ops.matmul, compiled=True) - t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) - sq = matmul(t, t, transpose_a=True) - self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20]) - - def testConv(self): - if 'GPU' in self.device: - # TODO(b/32333178) - self.skipTest('Current implementation of RandomStandardNormal kernel ' - 'is very slow on GPU, and has been blacklisted.') - with self.test_scope(): - data_format = 'channels_last' - conv = convolutional.Conv2D( - filters=1, kernel_size=2, padding='VALID', - data_format=data_format, activation=nn_ops.relu, - kernel_initializer=init_ops.ones_initializer(), - bias_initializer=init_ops.zeros_initializer()) - pool = pooling.MaxPooling2D(2, 2, data_format=data_format) - - def model(x): - x = conv(x) - return pool(x) - model = function.defun(model, compiled=True) - - x = array_ops.ones([1, 4, 4, 1]) - y = model(x) - self.assertAllEqual(y.numpy(), [[[[4.]]]]) - - def testReadVariable(self): - with self.test_scope(): - v = resource_variable_ops.ResourceVariable(1.0) - - @function.defun(compiled=True) - def f(): - return v.read_value() - - var = f() - self.assertEqual(1.0, var.numpy()) - - def testUpdateVariable(self): - with self.test_scope(): - v = resource_variable_ops.ResourceVariable(1.0) - - def f(v): - v.assign_add(1.0) - return v - - f = function.defun(f, compiled=True) - - var = f(v) - self.assertEqual(2.0, var.numpy()) - - def testAllArgumentKinds(self): - """Test a complex function that takes different argument kinds. - - tf2xla machinery that translates, compiles, and runs defuns - classifies arguments into: compile-time constants, regular tensors, - and resources. This test creates a function with a mix of all these - kinds. Moreover, the order of function arguments is intentionally mixed up. - - This also tests the case when the same argument is a compile-time constant - as well as used in an operation that normally expects its inputs to be - in device memory - addition in this case. - """ - with self.test_scope(): - def foo(c1, r1, v1, c2, v2, r2): - # c1 and c2 are compile-time constants - # r1 and r2 are regular tensors - # v1 and v2 are resource variables - a = c1 + r1 - b = math_ops.cast(c2, dtypes.float32) + v2 - c = array_ops.slice(v1, c1, c2) - d = r2 * v2 - return a, b, c, d - - foo = function.defun(foo, compiled=True) - - c1 = [0, 0] - c2 = array_ops.ones([2], dtype=dtypes.int32) - - r1 = array_ops.ones([2]) - r2 = [[2., 2.], [3., 3.]] - - v1 = resource_variable_ops.ResourceVariable([[1., 2.], [3., 4.]]) - v2 = resource_variable_ops.ResourceVariable([[10., 20.], [30., 40.]]) - - a, b, c, d = foo(c1, r1, v1, c2, v2, r2) - - self.assertAllEqual([1, 1], a.numpy()) - self.assertAllEqual([[11., 21.], [31., 41.]], b.numpy()) - self.assertAllEqual([[1.]], c.numpy()) - self.assertAllEqual([[20., 40.], [90., 120.]], d.numpy()) - - -if __name__ == '__main__': +if __name__ == "__main__": ops.enable_eager_execution( config=config_pb2.ConfigProto(log_device_placement=True)) googletest.main() diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py index b8f352d..8517a3b 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -36,7 +36,9 @@ def device_and_data_format(): 'channels_last') -def random_batch(batch_size, data_format): +def random_batch(batch_size, device_and_format=None): + _, data_format = device_and_format or device_and_data_format() + shape = (3, 224, 224) if data_format == 'channels_first' else (224, 224, 3) shape = (batch_size,) + shape @@ -68,7 +70,7 @@ class ResNet50Test(tf.test.TestCase): if defun: model.call = tfe.defun(model.call) with tf.device(device), tfe.execution_mode(execution_mode): - images, _ = random_batch(2, data_format) + images, _ = random_batch(2) output = model(images, training=False) tfe.async_wait() self.assertEqual((2, 1000), output.shape) @@ -89,7 +91,7 @@ class ResNet50Test(tf.test.TestCase): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format, include_top=False) with tf.device(device): - images, _ = random_batch(2, data_format) + images, _ = random_batch(2) output = model(images, training=False) output_shape = ((2, 2048, 1, 1) if data_format == 'channels_first' else (2, 1, 1, 2048)) @@ -99,7 +101,7 @@ class ResNet50Test(tf.test.TestCase): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format, include_top=False, pooling='avg') with tf.device(device): - images, _ = random_batch(2, data_format) + images, _ = random_batch(2) output = model(images, training=False) self.assertEqual((2, 2048), output.shape) @@ -113,7 +115,7 @@ class ResNet50Test(tf.test.TestCase): name='t0').as_default(), tf.contrib.summary.always_record_summaries(): with tf.device(device), tfe.execution_mode(execution_mode): optimizer = tf.train.GradientDescentOptimizer(0.1) - images, labels = random_batch(2, data_format) + images, labels = random_batch(2) train_one_step(model, images, labels, optimizer) self.assertEqual(320, len(model.variables)) tfe.async_wait() @@ -132,7 +134,7 @@ class ResNet50Test(tf.test.TestCase): model = resnet50.ResNet50(data_format) optimizer = tf.train.GradientDescentOptimizer(0.1) with tf.device(device): - images, labels = random_batch(2, data_format) + images, labels = random_batch(2) gc.disable() # Warm up. Note that this first run does create significant amounts of # garbage to be collected. The hope is that this is a build-only effect, @@ -200,18 +202,18 @@ class ResNet50Benchmarks(tf.test.Benchmark): # which forces a sync. This is a roundabout way, yes. tf.constant(1.).cpu() - def _benchmark_eager_apply(self, label, device_and_format, defun=False, - execution_mode=None, compiled=False): + def _benchmark_eager_apply(self, label, defun=False, execution_mode=None, + device_and_format=None): with tfe.execution_mode(execution_mode): - device, data_format = device_and_format + device, data_format = device_and_format or device_and_data_format() model = resnet50.ResNet50(data_format) if defun: - model.call = tfe.defun(model.call, compiled=compiled) + model.call = tfe.defun(model.call) batch_size = 64 num_burn = 5 num_iters = 30 with tf.device(device): - images, _ = random_batch(batch_size, data_format) + images, _ = random_batch(batch_size, device_and_format) for _ in xrange(num_burn): model(images, training=False).cpu() if execution_mode: @@ -225,34 +227,30 @@ class ResNet50Benchmarks(tf.test.Benchmark): self._report(label, start, num_iters, device, batch_size, data_format) def benchmark_eager_apply_sync(self): - self._benchmark_eager_apply('eager_apply', device_and_data_format(), - defun=False) + self._benchmark_eager_apply('eager_apply', defun=False) def benchmark_eager_apply_async(self): self._benchmark_eager_apply( - 'eager_apply_async', device_and_data_format(), defun=False, - execution_mode=tfe.ASYNC) + 'eager_apply_async', defun=False, execution_mode=tfe.ASYNC) def benchmark_eager_apply_with_defun(self): - self._benchmark_eager_apply('eager_apply_with_defun', - device_and_data_format(), defun=True) + self._benchmark_eager_apply('eager_apply_with_defun', defun=True) def _benchmark_eager_train(self, label, make_iterator, - device_and_format, defun=False, execution_mode=None, - compiled=False): + device_and_format=None): with tfe.execution_mode(execution_mode): - device, data_format = device_and_format + device, data_format = device_and_format or device_and_data_format() for batch_size in self._train_batch_sizes(): - (images, labels) = random_batch(batch_size, data_format) + (images, labels) = random_batch(batch_size, device_and_format) num_burn = 3 num_iters = 10 model = resnet50.ResNet50(data_format) if defun: - model.call = tfe.defun(model.call, compiled=compiled) + model.call = tfe.defun(model.call) optimizer = tf.train.GradientDescentOptimizer(0.1) with tf.device(device): @@ -275,21 +273,18 @@ class ResNet50Benchmarks(tf.test.Benchmark): self._report(label, start, num_iters, device, batch_size, data_format) def benchmark_eager_train_sync(self): - self._benchmark_eager_train('eager_train', MockIterator, - device_and_data_format(), defun=False) + self._benchmark_eager_train('eager_train', MockIterator, defun=False) def benchmark_eager_train_async(self): self._benchmark_eager_train( 'eager_train_async', MockIterator, - device_and_data_format(), defun=False, execution_mode=tfe.ASYNC) def benchmark_eager_train_with_defun(self): self._benchmark_eager_train( - 'eager_train_with_defun', MockIterator, - device_and_data_format(), defun=True) + 'eager_train_with_defun', MockIterator, defun=True) def benchmark_eager_train_datasets(self): @@ -299,8 +294,7 @@ class ResNet50Benchmarks(tf.test.Benchmark): return tfe.Iterator(ds) self._benchmark_eager_train( - 'eager_train_dataset', make_iterator, - device_and_data_format(), defun=False) + 'eager_train_dataset', make_iterator, defun=False) def benchmark_eager_train_datasets_with_defun(self): @@ -310,8 +304,7 @@ class ResNet50Benchmarks(tf.test.Benchmark): return tfe.Iterator(ds) self._benchmark_eager_train( - 'eager_train_dataset_with_defun', make_iterator, - device_and_data_format(), defun=True) + 'eager_train_dataset_with_defun', make_iterator, defun=True) if __name__ == '__main__': diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 60cfacc..741bd2a 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -23,7 +23,6 @@ import collections import numpy as np -from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import function_pb2 from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context @@ -226,7 +225,7 @@ def _inference_name(n): class _EagerDefinedFunction(object): """Function object with the interface of tf _DefinedFunction.""" - def __init__(self, name, graph, operations, inputs, outputs, attrs): + def __init__(self, name, graph, operations, inputs, outputs): """Initializes an eager defined function. Args: @@ -236,7 +235,6 @@ class _EagerDefinedFunction(object): which will be in the function inputs: the tensors in the graph to be used as inputs to the function outputs: the tensors in the graph which will be outputs to the function - attrs: dict mapping names of attributes to their AttrValue values """ fn = pywrap_tensorflow.TF_GraphToFunction_wrapper( graph._c_graph, # pylint: disable=protected-access @@ -248,14 +246,6 @@ class _EagerDefinedFunction(object): [], None, compat.as_str("")) - - for name, attr_value in attrs.items(): - serialized = attr_value.SerializeToString() - # TODO(iga): this creates and deletes a new TF_Status for every attr. - # It might be worth creating a convenient way to re-use status. - pywrap_tensorflow.TF_FunctionSetAttrValueProto( - fn, compat.as_str(name), serialized) - # TODO(apassos) avoid creating a FunctionDef (specially to grab the # signature, but also in general it's nice not to depend on it. with c_api_util.tf_buffer() as buffer_: @@ -297,6 +287,25 @@ def _flatten(sequence): class GraphModeFunction(object): """Callable object representing a graph-mode function. + + Args: + name: str the name of the created function + input_placeholders: list of placeholder values (tensors) to feed when + calling the wrapped function. + extra_inputs: Tensor inputs this function definition closed over which + are passed as arguments. Need to track so gradients are supported + correctly. + graph: the Graph from which the operations will be pulled. Used as + a context when computing gradients. + operations: the subset of Operations in the graph used in the function + definition. + outputs: a flat list of the Tensors in the graph used as outputs to the + function + func_outputs: a possibly nested python object which will be returned by + this function. The Tensors in this structure will be replaced by their + corresponding values in outputs. + output_shapes: List of shapes of all tensors in outputs + variables: (optional) List of variables to watch during function execution. """ def __init__(self, @@ -308,36 +317,9 @@ class GraphModeFunction(object): outputs, func_outputs, output_shapes, - variables=None, - attrs=None): - """Initialize a GraphModeFunction. - - Args: - name: str the name of the created function - input_placeholders: list of placeholder values (tensors) to feed when - calling the wrapped function. - extra_inputs: Tensor inputs this function definition closed over which - are passed as arguments. Need to track so gradients are supported - correctly. - graph: the Graph from which the operations will be pulled. Used as - a context when computing gradients. - operations: the subset of Operations in the graph used in the function - definition. - outputs: a flat list of the Tensors in the graph used as outputs to the - function - func_outputs: a possibly nested python object which will be returned by - this function. The Tensors in this structure will be replaced by their - corresponding values in outputs. - output_shapes: List of shapes of all tensors in outputs - variables: (optional) List of variables to watch during function - execution. - attrs: (optional) dict mapping names of attributes to their AttrValue - values. Attributes in `attrs` will be included in this function's - definition. - """ - self._attrs = attrs or {} + variables=None): defined_function = _EagerDefinedFunction( - name, graph, operations, input_placeholders, outputs, self._attrs) + name, graph, operations, input_placeholders, outputs) if len(input_placeholders) != len(defined_function.signature.input_arg): raise ValueError("Internal error: invalid lengths. %s %s" % ( len(input_placeholders), len(defined_function.signature.input_arg))) @@ -390,7 +372,7 @@ class GraphModeFunction(object): forward_name = _forward_name(self._func_name) self._forward_fdef = _EagerDefinedFunction( forward_name, self._graph, self._ops, self._input_placeholders, - filtered_outputs + captures, self._attrs) + filtered_outputs + captures) all_inputs = self._out_grad_placeholders + captures # Excluding input ops from the body as we do not intend to execute these # operations when the function is executed. @@ -404,7 +386,7 @@ class GraphModeFunction(object): bname = _backward_name(self._func_name) self._backward_function = GraphModeFunction( bname, all_inputs, [], self._graph, function_def_ops, - backward_outputs, in_gradients, output_shapes, attrs=self._attrs) + backward_outputs, in_gradients, output_shapes) def _backprop_call(self, args): """Calls the wrapped function and records the result on a tape.""" @@ -578,7 +560,7 @@ def _get_defun_inputs(args): return nest.pack_sequence_as(args, ret) -def _defun_internal(name, func, compiled, args, kwds): +def _defun_internal(name, func, args, kwds): """Defines and returns graph-mode version of func.""" graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access with context.graph_mode(): @@ -643,14 +625,9 @@ def _defun_internal(name, func, compiled, args, kwds): for f in tmp_graph._functions.values(): # pylint: disable=protected-access # TODO(ashankar): What about the gradient registry? _register(f._c_func.func) # pylint: disable=protected-access - - attrs = {} - if compiled: - attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=True) - return GraphModeFunction( fname, all_inputs, extra_inputs, tmp_graph, operations, func_def_outputs, - func_outputs, output_shapes, variables, attrs) + func_outputs, output_shapes, variables) # Defun uses this instead of Tensor as a cache key. Using dtype because @@ -692,7 +669,7 @@ def _register(fn): # TODO(apassos): better error messages for non-hashable arguments. -def named_defun(func, name, compiled=False): +def named_defun(func, name): """Defines a function with a given name. See the documentation for `defun` for more information on the semantics of the @@ -701,7 +678,6 @@ def named_defun(func, name, compiled=False): Args: func: the function to be wrapped. name: the name given to it. - compiled: if true, the framework will attempt to compile func with XLA. Returns: the wrapped function. @@ -718,13 +694,13 @@ def named_defun(func, name, compiled=False): if cache_key not in arguments_to_functions: arguments_to_functions[cache_key] = _defun_internal( - name, func, compiled, args, kwds) + name, func, args, kwds) return arguments_to_functions[cache_key](*args) return decorated -def defun(func=None, compiled=False): +def defun(func): """Decorator to compile func into graph_mode. `defun` converts a function that constructs a TensorFlow graph into a function @@ -767,45 +743,18 @@ def defun(func=None, compiled=False): ``` Args: - func: function to be compiled. If `func` is None, returns a - decorator that can be invoked with a single argument - `func`. The - end result is equivalent to providing all the arguments up front. - In other words, defun(compiled=True)(func) is equivalent to - defun(func, compiled=True). The former allows the following use case: - @tfe.defun(compiled=True) - def foo(...): - ... - compiled: If True, an attempt to compile `func` with XLA will be made. - If it fails, function will be run normally. Experimental. - Currently, supported only for execution on TPUs. + func: function to be compiled. Returns: - If `func` is not None, returns callable that will execute the compiled - function (and return zero or more `tf.Tensor` objects). - If `func` is None, returns a decorator that, when invoked with a single - `func` argument, returns a callable equivalent to the case above. + A callable that will execute the compiled function (and return zero + or more `tf.Tensor` objects). """ # TODO(apassos): deal with captured global state. Deal with control flow. - def decorated(function): - try: - name = function.__name__ - except AttributeError: - name = "function" - return tf_decorator.make_decorator( - function, named_defun(function, name, compiled=compiled)) - - # This code path is for the `foo = tfe.defun(foo, ...)` use case - if func is not None: - return decorated(func) - - # This code path is for the - # - # @tfe.defun(...) - # def foo(...): - # ... - # - # use case, which is equivalent to `foo = tfe.defun(...)(foo)` - return decorated + try: + name = func.__name__ + except AttributeError: + name = "function" + return tf_decorator.make_decorator(func, named_defun(func, name)) def make_defun_op(func, *args, **kwds): @@ -857,7 +806,7 @@ def make_defun_op(func, *args, **kwds): name = func.__name__ if any(isinstance(x, ops.EagerTensor) for x in kwds.values()): raise ValueError("Tensor keyword arguments are not supported.") - return _defun_internal(name, func, False, args, kwds) + return _defun_internal(name, func, args, kwds) class AutomaticControlDependencies(object): -- 2.7.4