--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/kernels/if_op.h"
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_context.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+
+namespace tensorflow {
+
+XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ const NameAttrList* name_attr;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("then_branch", &name_attr));
+ then_branch_ = *name_attr;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("else_branch", &name_attr));
+ else_branch_ = *name_attr;
+
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("Tcond", &cond_type_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_));
+}
+
+// TODO(b/35949885): There is duplication here with the handling of the
+// while_op. Refactor the common code out/rework.
+void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
+ xla::ComputationBuilder* b = ctx->builder();
+
+ OP_REQUIRES(ctx, cond_type_ == DT_BOOL,
+ errors::InvalidArgument(
+ "Condition argument must be a boolean for XLA compilation"));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(0)),
+ errors::InvalidArgument(
+ "Condition argument must be a scalar for XLA compilation"));
+
+ VLOG(1) << "Building If: " << input_types_.size() << " inputs";
+
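+  // Input 0 is the condition; inputs 1..N carry the values and resources fed
+  // to both branches, hence the `i + 1` input indexing below.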
+ std::vector<xla::ComputationDataHandle> inputs(input_types_.size());
+ std::vector<XlaCompiler::Argument> arguments(input_types_.size());
+ for (int i = 0; i < input_types_.size(); ++i) {
+ XlaCompiler::Argument& arg = arguments[i];
+ DataType type = ctx->input_type(i + 1);
+ if (type == DT_RESOURCE) {
+ XlaResource* resource;
+ OP_REQUIRES_OK(ctx, ctx->GetResourceInput(i + 1, &resource));
+
+ arg.initialized = resource->initialized();
+ arg.kind = XlaCompiler::Argument::kResource;
+ arg.resource_kind = resource->kind();
+ OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b));
+
+      arg.type = resource->type();
+      arg.shape = resource->shape();
+      // Assign arg.name before the check below so that the error message can
+      // identify the offending resource.
+      arg.name = resource->name();
+      OP_REQUIRES(ctx, arg.initialized,
+                  errors::Unimplemented("Uninitialized arguments: ", arg.name));
+      arg.tensor_array_size = resource->tensor_array_size();
+      for (const auto& gradient : resource->tensor_array_gradients()) {
+        arg.tensor_array_gradients.insert(gradient.first);
+      }
+ VLOG(2) << "Resource " << resource->name()
+ << " type: " << DataTypeString(arg.type)
+ << " shape: " << arg.shape.DebugString()
+ << " initialized: " << arg.initialized;
+ } else {
+ arg.kind = XlaCompiler::Argument::kParameter;
+ arg.type = input_types_[i];
+ arg.shape = ctx->InputShape(i + 1);
+ inputs[i] = ctx->Input(i + 1);
+ VLOG(2) << "Arg type: " << DataTypeString(arg.type)
+ << " shape: " << arg.shape.DebugString();
+ }
+ }
+
+ // Compile both branches of the conditional.
+ XlaCompiler::CompileOptions options;
+ options.use_tuple_arg = true;
+ options.resolve_compile_time_constants = false;
+ options.return_updated_values_for_all_resources = true;
+ options.is_entry_computation = false;
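+  // Note: return_updated_values_for_all_resources=true makes each branch
+  // return every resource argument, whether or not it modified it, so the two
+  // branches keep identical output signatures (verified below).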
+ XlaCompiler* compiler = ctx->compiler();
+
+ XlaCompiler::CompilationResult then_result;
+ OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, then_branch_,
+ arguments, &then_result));
+ XlaCompiler::CompilationResult else_result;
+ OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, else_branch_,
+ arguments, &else_result));
+
+ for (XlaCompiler::CompilationResult* result : {&then_result, &else_result}) {
+ for (const XlaCompiler::ResourceUpdate& update : result->resource_updates) {
+ XlaResource* resource;
+ OP_REQUIRES_OK(ctx,
+ ctx->GetResourceInput(update.input_index + 1, &resource));
+ XlaCompiler::Argument& arg = arguments[update.input_index];
+
+ // Add any TensorArray gradients touched by the then/else computation to
+ // the enclosing graph.
+ for (const string& grad_source : update.tensor_array_gradients_accessed) {
+ VLOG(5) << "TensorArray " << resource->name() << " accessed gradient "
+ << grad_source;
+ XlaResource* gradient;
+ OP_REQUIRES_OK(ctx, resource->GetOrCreateTensorArrayGradient(
+ grad_source, b, &gradient));
+ }
+      // Add all of the TensorArray gradients to the argument. For simplicity,
+      // we always pass all known gradients; the merged set is consumed below
+      // when unpacking each branch's results via SetFromPack().
+ for (const auto& gradient : resource->tensor_array_gradients()) {
+ arg.tensor_array_gradients.insert(gradient.first);
+ }
+ }
+ }
+
+  // Check that both branches have compatible input shapes.
+ OP_REQUIRES(ctx, then_result.xla_input_shapes.size() == 1,
+ errors::FailedPrecondition("Expected one input shape"));
+ xla::Shape then_input_shape = then_result.xla_input_shapes[0];
+ OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(then_input_shape),
+ errors::FailedPrecondition("Expected tuple shape"));
+ OP_REQUIRES(ctx, else_result.xla_input_shapes.size() == 1,
+ errors::FailedPrecondition("Expected one input shape"));
+ xla::Shape else_input_shape = else_result.xla_input_shapes[0];
+ OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(else_input_shape),
+ errors::FailedPrecondition("Expected tuple shape"));
+ OP_REQUIRES(ctx,
+ xla::ShapeUtil::Compatible(then_input_shape, else_input_shape),
+ errors::InvalidArgument(
+ "Input shapes of then and else branches do not match: ",
+ xla::ShapeUtil::HumanString(then_input_shape), " vs. ",
+ xla::ShapeUtil::HumanString(else_input_shape)));
+
+  // Check that both branches have compatible output shapes.
+ OP_REQUIRES(
+ ctx,
+ xla::ShapeUtil::Compatible(then_result.xla_output_shape,
+ else_result.xla_output_shape),
+ errors::InvalidArgument(
+ "Output shapes of then and else branches do not match: ",
+ xla::ShapeUtil::HumanString(then_result.xla_output_shape), " vs. ",
+ xla::ShapeUtil::HumanString(else_result.xla_output_shape)));
+
+ VLOG(2) << "Input shape: " << xla::ShapeUtil::HumanString(then_input_shape);
+ VLOG(2) << "Output shape: "
+ << xla::ShapeUtil::HumanString(then_result.xla_output_shape);
+
+ // We set return_updated_values_for_all_resources=true and we pass the same
+ // arguments to both computations, so the resource update count must match.
+ OP_REQUIRES(ctx,
+ then_result.resource_updates.size() ==
+ else_result.resource_updates.size(),
+ errors::FailedPrecondition(
+ "Different number of resources in then and else branch"));
+ for (int i = 0; i < then_result.resource_updates.size(); ++i) {
+ const auto& lhs = then_result.resource_updates[i];
+ const auto& rhs = else_result.resource_updates[i];
+ bool equal = lhs.input_index == rhs.input_index && lhs.shape == rhs.shape &&
+ lhs.tensor_array_gradients_accessed ==
+ rhs.tensor_array_gradients_accessed;
+ OP_REQUIRES(
+ ctx, equal,
+ errors::FailedPrecondition(
+ "Mismatch in resource of then and else branch for resource ", i));
+ }
+
+ xla::ComputationDataHandle outputs =
+ b->Conditional(ctx->Input(0), b->Tuple(inputs), *then_result.computation,
+ b->Tuple(inputs), *else_result.computation);
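+  // Per xla::ComputationBuilder::Conditional semantics, exactly one branch
+  // computation executes, applied to its operand tuple; both branches here
+  // receive the same tuple of inputs. The result is a tuple holding the
+  // branch outputs first, followed by the updated resource values.
+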
+ // Sets non-variable outputs.
+ for (int i = 0; i < output_types_.size(); ++i) {
+ if (ctx->input_type(i) != DT_RESOURCE) {
+ xla::ComputationDataHandle output_handle = b->GetTupleElement(outputs, i);
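+      // The VLOG_IS_ON guard avoids paying for the extra GetShape() query
+      // unless verbose logging is enabled.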
+      if (VLOG_IS_ON(2)) {
+        VLOG(2) << "Setting output " << i;
+        auto shape_or = b->GetShape(output_handle);
+        if (shape_or.ok()) {
+          VLOG(2) << "Shape for output " << i << ": "
+                  << xla::ShapeUtil::HumanString(*shape_or.ValueOrDie());
+        } else {
+          VLOG(2) << "Shape unknown for output " << i;
+        }
+ }
+ ctx->SetOutput(i, output_handle);
+ }
+ }
+
+ // Updates the values of any resource variables modified by the conditional
+ // bodies.
+ for (XlaCompiler::CompilationResult* result : {&then_result, &else_result}) {
+ for (int i = 0; i < result->resource_updates.size(); ++i) {
+ const XlaCompiler::ResourceUpdate& update = result->resource_updates[i];
+ XlaResource* resource;
+ OP_REQUIRES_OK(ctx,
+ ctx->GetResourceInput(update.input_index + 1, &resource));
+ if (update.modified) {
+ int pos = result->outputs.size() + i;
+ OP_REQUIRES_OK(ctx,
+ resource->SetFromPack(
+ arguments[update.input_index].tensor_array_gradients,
+ b->GetTupleElement(outputs, pos), b));
+ }
+ VLOG(2) << "If variable: pos: " << update.input_index
+ << " name: " << resource->name()
+ << " modified: " << update.modified
+ << " type: " << DataTypeString(update.type)
+ << " shape: " << update.shape.DebugString();
+ }
+ }
+ VLOG(1) << "Done building If";
+}
+
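+// AllowResourceTypes() permits DT_RESOURCE inputs, which the XLA op registry
+// rejects by default, so the branches can read and update resource variables.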
+REGISTER_XLA_OP(Name("XlaIf").AllowResourceTypes(), XlaIfOp);
+
+} // namespace tensorflow
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/if_op.h
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_OP_H_
+#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_OP_H_
+
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+
+namespace tensorflow {
+
+// This TensorFlow op provides a functional conditional primitive.
+//
+// The then/else branches must produce the same number of outputs, and the
+// corresponding outputs must agree in type and shape.
+//
+// Computations in the then/else bodies may read from and write to resource
+// variables, and resource variables may be passed as arguments to the two
+// bodies. The XlaCompiler converts resource variable arguments into
+// parameters of the XLA computation and moves them to the end of the
+// parameter list, and by setting the
+// `return_updated_values_for_all_resources` compile option we ensure that
+// all variables that appear in the input also appear at the end of each
+// body's output. This guarantees that the output signatures of the then/else
+// bodies match.
+//
+// It is the user's responsibility to ensure that each non-variable _Arg matches
+// the corresponding _Retval.
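+//
+// Conceptually (a sketch of the semantics, not the exact registered op
+// signature):
+//
+//   outputs = cond ? then_branch(inputs...) : else_branch(inputs...)
+//
+// where `cond` is operand 0 (a scalar DT_BOOL, per the checks in Compile()),
+// `inputs...` are operands 1..N typed by the `Tin` attribute, and the
+// outputs are typed by the `Tout` attribute.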
+class XlaIfOp : public XlaOpKernel {
+ public:
+ explicit XlaIfOp(OpKernelConstruction* ctx);
+
+ void Compile(XlaOpKernelContext* ctx) override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaIfOp);
+
+ NameAttrList then_branch_;
+ NameAttrList else_branch_;
+ DataType cond_type_;
+ DataTypeVector input_types_;
+ DataTypeVector output_types_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_OP_H_