From a5dd6369841a2ea9ecd6e2a56416859afb0c7716 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 16 Mar 2018 13:00:31 -0700 Subject: [PATCH] Move if_op kernel to //third_party/tensorflow/compiler/tf2xla/kernels PiperOrigin-RevId: 189381067 --- tensorflow/compiler/tf2xla/kernels/BUILD | 17 +++ tensorflow/compiler/tf2xla/kernels/if_op.cc | 226 ++++++++++++++++++++++++++++ tensorflow/compiler/tf2xla/kernels/if_op.h | 59 ++++++++ 3 files changed, 302 insertions(+) create mode 100644 tensorflow/compiler/tf2xla/kernels/if_op.cc create mode 100644 tensorflow/compiler/tf2xla/kernels/if_op.h diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index d2fa933..0bbfe86 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -93,6 +93,7 @@ tf_kernel_library( "shape_util.h", ], deps = [ + ":if_op", ":while_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -154,6 +155,22 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "if_op", + srcs = ["if_op.cc"], + hdrs = ["if_op.h"], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/ops:functional_ops", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + # Kernels that only work on CPU, because they use XLA custom calls. # Only link this when using the CPU backend for XLA. tf_kernel_library( diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc new file mode 100644 index 0000000..eefbe55 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -0,0 +1,226 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/kernels/if_op.h" + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { + +XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + const NameAttrList* name_attr; + OP_REQUIRES_OK(ctx, ctx->GetAttr("then_branch", &name_attr)); + then_branch_ = *name_attr; + OP_REQUIRES_OK(ctx, ctx->GetAttr("else_branch", &name_attr)); + else_branch_ = *name_attr; + + OP_REQUIRES_OK(ctx, ctx->GetAttr("Tcond", &cond_type_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_)); +} + +// TODO(b/35949885): There is duplication here with the handling of the +// while_op. Refactor the common code out/rework. +void XlaIfOp::Compile(XlaOpKernelContext* ctx) { + xla::ComputationBuilder* b = ctx->builder(); + + OP_REQUIRES(ctx, cond_type_ == DT_BOOL, + errors::InvalidArgument( + "Condition argument must be a boolean for XLA compilation")); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(0)), + errors::InvalidArgument( + "Condition argument must be a scalar for XLA compilation")); + + VLOG(1) << "Building If: " << input_types_.size() << " inputs"; + + std::vector inputs(input_types_.size()); + std::vector arguments(input_types_.size()); + for (int i = 0; i < input_types_.size(); ++i) { + XlaCompiler::Argument& arg = arguments[i]; + DataType type = ctx->input_type(i + 1); + if (type == DT_RESOURCE) { + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(i + 1, &resource)); + + arg.initialized = resource->initialized(); + arg.kind = XlaCompiler::Argument::kResource; + arg.resource_kind = resource->kind(); + OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b)); + + arg.type = resource->type(); + arg.shape = resource->shape(); + OP_REQUIRES(ctx, arg.initialized, + errors::Unimplemented("Uninitialized arguments: ", arg.name)); + arg.tensor_array_size = resource->tensor_array_size(); + for (const auto& gradient : resource->tensor_array_gradients()) { + arg.tensor_array_gradients.insert(gradient.first); + } + arg.name = resource->name(); + VLOG(2) << "Resource " << resource->name() + << " type: " << DataTypeString(arg.type) + << " shape: " << arg.shape.DebugString() + << " initialized: " << arg.initialized; + } else { + arg.kind = XlaCompiler::Argument::kParameter; + arg.type = input_types_[i]; + arg.shape = ctx->InputShape(i + 1); + inputs[i] = ctx->Input(i + 1); + VLOG(2) << "Arg type: " << DataTypeString(arg.type) + << " shape: " << arg.shape.DebugString(); + } + } + + // Compile both branches of the conditional. + XlaCompiler::CompileOptions options; + options.use_tuple_arg = true; + options.resolve_compile_time_constants = false; + options.return_updated_values_for_all_resources = true; + options.is_entry_computation = false; + XlaCompiler* compiler = ctx->compiler(); + + XlaCompiler::CompilationResult then_result; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, then_branch_, + arguments, &then_result)); + XlaCompiler::CompilationResult else_result; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, else_branch_, + arguments, &else_result)); + + for (XlaCompiler::CompilationResult* result : {&then_result, &else_result}) { + for (const XlaCompiler::ResourceUpdate& update : result->resource_updates) { + XlaResource* resource; + OP_REQUIRES_OK(ctx, + ctx->GetResourceInput(update.input_index + 1, &resource)); + XlaCompiler::Argument& arg = arguments[update.input_index]; + + // Add any TensorArray gradients touched by the then/else computation to + // the enclosing graph. + for (const string& grad_source : update.tensor_array_gradients_accessed) { + VLOG(5) << "TensorArray " << resource->name() << " accessed gradient " + << grad_source; + XlaResource* gradient; + OP_REQUIRES_OK(ctx, resource->GetOrCreateTensorArrayGradient( + grad_source, b, &gradient)); + } + // Add all of the TensorArray gradients to the argument. For simplicity, + // we always pass all known gradients. + for (const auto& gradient : resource->tensor_array_gradients()) { + arg.tensor_array_gradients.insert(gradient.first); + } + } + } + + // Check that both branches have identical input shapes. + OP_REQUIRES(ctx, then_result.xla_input_shapes.size() == 1, + errors::FailedPrecondition("Expected one input shape")); + xla::Shape then_input_shape = then_result.xla_input_shapes[0]; + OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(then_input_shape), + errors::FailedPrecondition("Expected tuple shape")); + OP_REQUIRES(ctx, else_result.xla_input_shapes.size() == 1, + errors::FailedPrecondition("Expected one input shape")); + xla::Shape else_input_shape = else_result.xla_input_shapes[0]; + OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(else_input_shape), + errors::FailedPrecondition("Expected tuple shape")); + OP_REQUIRES(ctx, + xla::ShapeUtil::Compatible(then_input_shape, else_input_shape), + errors::InvalidArgument( + "Input shapes of then and else branches do not match: ", + xla::ShapeUtil::HumanString(then_input_shape), " vs. ", + xla::ShapeUtil::HumanString(else_input_shape))); + + // Check that both branches have identical output shapes. + OP_REQUIRES( + ctx, + xla::ShapeUtil::Compatible(then_result.xla_output_shape, + else_result.xla_output_shape), + errors::InvalidArgument( + "Output shapes of then and else branches do not match: ", + xla::ShapeUtil::HumanString(then_result.xla_output_shape), " vs. ", + xla::ShapeUtil::HumanString(else_result.xla_output_shape))); + + VLOG(2) << "Input shape: " << xla::ShapeUtil::HumanString(then_input_shape); + VLOG(2) << "Output shape: " + << xla::ShapeUtil::HumanString(then_result.xla_output_shape); + + // We set return_updated_values_for_all_resources=true and we pass the same + // arguments to both computations, so the resource update count must match. + OP_REQUIRES(ctx, + then_result.resource_updates.size() == + else_result.resource_updates.size(), + errors::FailedPrecondition( + "Different number of resources in then and else branch")); + for (int i = 0; i < then_result.resource_updates.size(); ++i) { + const auto& lhs = then_result.resource_updates[i]; + const auto& rhs = else_result.resource_updates[i]; + bool equal = lhs.input_index == rhs.input_index && lhs.shape == rhs.shape && + lhs.tensor_array_gradients_accessed == + rhs.tensor_array_gradients_accessed; + OP_REQUIRES( + ctx, equal, + errors::FailedPrecondition( + "Mismatch in resource of then and else branch for resource ", i)); + } + + xla::ComputationDataHandle outputs = + b->Conditional(ctx->Input(0), b->Tuple(inputs), *then_result.computation, + b->Tuple(inputs), *else_result.computation); + // Sets non-variable outputs. + for (int i = 0; i < output_types_.size(); ++i) { + if (ctx->input_type(i) != DT_RESOURCE) { + xla::ComputationDataHandle output_handle = b->GetTupleElement(outputs, i); + if (VLOG_IS_ON(2)) { + LOG(INFO) << "Setting output " << i; + auto shape_or = b->GetShape(output_handle); + if (shape_or.ok()) { + LOG(INFO) << "Shape for output " << i << ": " + << xla::ShapeUtil::HumanString(*shape_or.ValueOrDie()); + } else { + LOG(INFO) << "Shape unknown for output " << i; + } + } + ctx->SetOutput(i, output_handle); + } + } + + // Updates the values of any resource variables modified by the conditional + // bodies. + for (XlaCompiler::CompilationResult* result : {&then_result, &else_result}) { + for (int i = 0; i < result->resource_updates.size(); ++i) { + const XlaCompiler::ResourceUpdate& update = result->resource_updates[i]; + XlaResource* resource; + OP_REQUIRES_OK(ctx, + ctx->GetResourceInput(update.input_index + 1, &resource)); + if (update.modified) { + int pos = result->outputs.size() + i; + OP_REQUIRES_OK(ctx, + resource->SetFromPack( + arguments[update.input_index].tensor_array_gradients, + b->GetTupleElement(outputs, pos), b)); + } + VLOG(2) << "If variable: pos: " << update.input_index + << " name: " << resource->name() + << " modified: " << update.modified + << " type: " << DataTypeString(update.type) + << " shape: " << update.shape.DebugString(); + } + } + VLOG(1) << "Done building If"; +} + +REGISTER_XLA_OP(Name("XlaIf").AllowResourceTypes(), XlaIfOp); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.h b/tensorflow/compiler/tf2xla/kernels/if_op.h new file mode 100644 index 0000000..f9bc98a --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/if_op.h @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_OP_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_OP_H_ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/attr_value.pb.h" + +namespace tensorflow { + +// This TensorFlow op provides a functional conditional primitive. +// +// The outputs of the then/else branches must agree on the number, types, and +// shapes of the Tensors carried around the two bodies. +// +// Computations in then/else bodies may read from and write to resource +// variables. +// Resource variables may be passed as arguments to the then/else function's +// bodies. The XlaCompiler converts resource variable arguments +// into parameters to the XLA computation and moves them to the end of the +// parameter list, and by using the `return_updated_values_for_all_variables` +// we ensure that all variables that appear in the input also appear at the +// end of the then/else bodies output. This ensures the then/else bodies output +// signatures match. +// +// It is the user's responsibility to ensure that each non-variable _Arg matches +// the corresponding _Retval. +class XlaIfOp : public XlaOpKernel { + public: + explicit XlaIfOp(OpKernelConstruction* ctx); + + void Compile(XlaOpKernelContext* ctx) override; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(XlaIfOp); + + NameAttrList then_branch_; + NameAttrList else_branch_; + DataType cond_type_; + DataTypeVector input_types_; + DataTypeVector output_types_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_OP_H_ -- 2.7.4