Move if_op kernel to //third_party/tensorflow/compiler/tf2xla/kernels
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 16 Mar 2018 20:00:31 +0000 (13:00 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 16 Mar 2018 20:09:35 +0000 (13:09 -0700)
PiperOrigin-RevId: 189381067

tensorflow/compiler/tf2xla/kernels/BUILD
tensorflow/compiler/tf2xla/kernels/if_op.cc [new file with mode: 0644]
tensorflow/compiler/tf2xla/kernels/if_op.h [new file with mode: 0644]

index d2fa933..0bbfe86 100644 (file)
@@ -93,6 +93,7 @@ tf_kernel_library(
         "shape_util.h",
     ],
     deps = [
+        ":if_op",
         ":while_op",
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:xla_compiler",
@@ -154,6 +155,22 @@ tf_kernel_library(
     ],
 )
 
+tf_kernel_library(
+    name = "if_op",
+    srcs = ["if_op.cc"],
+    hdrs = ["if_op.h"],
+    deps = [
+        "//tensorflow/compiler/tf2xla:common",
+        "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla/ops:functional_ops",
+        "//tensorflow/compiler/xla:literal_util",
+        "//tensorflow/compiler/xla/client:computation_builder",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+    ],
+)
+
 # Kernels that only work on CPU, because they use XLA custom calls.
 # Only link this when using the CPU backend for XLA.
 tf_kernel_library(
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc
new file mode 100644 (file)
index 0000000..eefbe55
--- /dev/null
@@ -0,0 +1,226 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/kernels/if_op.h"
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_context.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+
+namespace tensorflow {
+
+XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+  const NameAttrList* name_attr;
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("then_branch", &name_attr));
+  then_branch_ = *name_attr;
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("else_branch", &name_attr));
+  else_branch_ = *name_attr;
+
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tcond", &cond_type_));
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_));
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_));
+}
+
+// TODO(b/35949885): There is duplication here with the handling of the
+// while_op. Refactor the common code out/rework.
+void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
+  xla::ComputationBuilder* b = ctx->builder();
+
+  OP_REQUIRES(ctx, cond_type_ == DT_BOOL,
+              errors::InvalidArgument(
+                  "Condition argument must be a boolean for XLA compilation"));
+  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(0)),
+              errors::InvalidArgument(
+                  "Condition argument must be a scalar for XLA compilation"));
+
+  VLOG(1) << "Building If: " << input_types_.size() << " inputs";
+
+  std::vector<xla::ComputationDataHandle> inputs(input_types_.size());
+  std::vector<XlaCompiler::Argument> arguments(input_types_.size());
+  for (int i = 0; i < input_types_.size(); ++i) {
+    XlaCompiler::Argument& arg = arguments[i];
+    DataType type = ctx->input_type(i + 1);
+    if (type == DT_RESOURCE) {
+      XlaResource* resource;
+      OP_REQUIRES_OK(ctx, ctx->GetResourceInput(i + 1, &resource));
+
+      arg.initialized = resource->initialized();
+      arg.kind = XlaCompiler::Argument::kResource;
+      arg.resource_kind = resource->kind();
+      OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b));
+
+      arg.type = resource->type();
+      arg.shape = resource->shape();
+      OP_REQUIRES(ctx, arg.initialized,
+                  errors::Unimplemented("Uninitialized arguments: ", arg.name));
+      arg.tensor_array_size = resource->tensor_array_size();
+      for (const auto& gradient : resource->tensor_array_gradients()) {
+        arg.tensor_array_gradients.insert(gradient.first);
+      }
+      arg.name = resource->name();
+      VLOG(2) << "Resource " << resource->name()
+              << " type: " << DataTypeString(arg.type)
+              << " shape: " << arg.shape.DebugString()
+              << " initialized: " << arg.initialized;
+    } else {
+      arg.kind = XlaCompiler::Argument::kParameter;
+      arg.type = input_types_[i];
+      arg.shape = ctx->InputShape(i + 1);
+      inputs[i] = ctx->Input(i + 1);
+      VLOG(2) << "Arg type: " << DataTypeString(arg.type)
+              << " shape: " << arg.shape.DebugString();
+    }
+  }
+
+  // Compile both branches of the conditional.
+  XlaCompiler::CompileOptions options;
+  options.use_tuple_arg = true;
+  options.resolve_compile_time_constants = false;
+  options.return_updated_values_for_all_resources = true;
+  options.is_entry_computation = false;
+  XlaCompiler* compiler = ctx->compiler();
+
+  XlaCompiler::CompilationResult then_result;
+  OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, then_branch_,
+                                                arguments, &then_result));
+  XlaCompiler::CompilationResult else_result;
+  OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, else_branch_,
+                                                arguments, &else_result));
+
+  for (XlaCompiler::CompilationResult* result : {&then_result, &else_result}) {
+    for (const XlaCompiler::ResourceUpdate& update : result->resource_updates) {
+      XlaResource* resource;
+      OP_REQUIRES_OK(ctx,
+                     ctx->GetResourceInput(update.input_index + 1, &resource));
+      XlaCompiler::Argument& arg = arguments[update.input_index];
+
+      // Add any TensorArray gradients touched by the then/else computation to
+      // the enclosing graph.
+      for (const string& grad_source : update.tensor_array_gradients_accessed) {
+        VLOG(5) << "TensorArray " << resource->name() << " accessed gradient "
+                << grad_source;
+        XlaResource* gradient;
+        OP_REQUIRES_OK(ctx, resource->GetOrCreateTensorArrayGradient(
+                                grad_source, b, &gradient));
+      }
+      // Add all of the TensorArray gradients to the argument. For simplicity,
+      // we always pass all known gradients.
+      for (const auto& gradient : resource->tensor_array_gradients()) {
+        arg.tensor_array_gradients.insert(gradient.first);
+      }
+    }
+  }
+
+  // Check that both branches have identical input shapes.
+  OP_REQUIRES(ctx, then_result.xla_input_shapes.size() == 1,
+              errors::FailedPrecondition("Expected one input shape"));
+  xla::Shape then_input_shape = then_result.xla_input_shapes[0];
+  OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(then_input_shape),
+              errors::FailedPrecondition("Expected tuple shape"));
+  OP_REQUIRES(ctx, else_result.xla_input_shapes.size() == 1,
+              errors::FailedPrecondition("Expected one input shape"));
+  xla::Shape else_input_shape = else_result.xla_input_shapes[0];
+  OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(else_input_shape),
+              errors::FailedPrecondition("Expected tuple shape"));
+  OP_REQUIRES(ctx,
+              xla::ShapeUtil::Compatible(then_input_shape, else_input_shape),
+              errors::InvalidArgument(
+                  "Input shapes of then and else branches do not match: ",
+                  xla::ShapeUtil::HumanString(then_input_shape), " vs. ",
+                  xla::ShapeUtil::HumanString(else_input_shape)));
+
+  // Check that both branches have identical output shapes.
+  OP_REQUIRES(
+      ctx,
+      xla::ShapeUtil::Compatible(then_result.xla_output_shape,
+                                 else_result.xla_output_shape),
+      errors::InvalidArgument(
+          "Output shapes of then and else branches do not match: ",
+          xla::ShapeUtil::HumanString(then_result.xla_output_shape), " vs. ",
+          xla::ShapeUtil::HumanString(else_result.xla_output_shape)));
+
+  VLOG(2) << "Input shape: " << xla::ShapeUtil::HumanString(then_input_shape);
+  VLOG(2) << "Output shape: "
+          << xla::ShapeUtil::HumanString(then_result.xla_output_shape);
+
+  // We set return_updated_values_for_all_resources=true and we pass the same
+  // arguments to both computations, so the resource update count must match.
+  OP_REQUIRES(ctx,
+              then_result.resource_updates.size() ==
+                  else_result.resource_updates.size(),
+              errors::FailedPrecondition(
+                  "Different number of resources in then and else branch"));
+  for (int i = 0; i < then_result.resource_updates.size(); ++i) {
+    const auto& lhs = then_result.resource_updates[i];
+    const auto& rhs = else_result.resource_updates[i];
+    bool equal = lhs.input_index == rhs.input_index && lhs.shape == rhs.shape &&
+                 lhs.tensor_array_gradients_accessed ==
+                     rhs.tensor_array_gradients_accessed;
+    OP_REQUIRES(
+        ctx, equal,
+        errors::FailedPrecondition(
+            "Mismatch in resource of then and else branch for resource ", i));
+  }
+
+  xla::ComputationDataHandle outputs =
+      b->Conditional(ctx->Input(0), b->Tuple(inputs), *then_result.computation,
+                     b->Tuple(inputs), *else_result.computation);
+  // Sets non-variable outputs.
+  for (int i = 0; i < output_types_.size(); ++i) {
+    if (ctx->input_type(i) != DT_RESOURCE) {
+      xla::ComputationDataHandle output_handle = b->GetTupleElement(outputs, i);
+      if (VLOG_IS_ON(2)) {
+        LOG(INFO) << "Setting output " << i;
+        auto shape_or = b->GetShape(output_handle);
+        if (shape_or.ok()) {
+          LOG(INFO) << "Shape for output " << i << ": "
+                    << xla::ShapeUtil::HumanString(*shape_or.ValueOrDie());
+        } else {
+          LOG(INFO) << "Shape unknown for output " << i;
+        }
+      }
+      ctx->SetOutput(i, output_handle);
+    }
+  }
+
+  // Updates the values of any resource variables modified by the conditional
+  // bodies.
+  for (XlaCompiler::CompilationResult* result : {&then_result, &else_result}) {
+    for (int i = 0; i < result->resource_updates.size(); ++i) {
+      const XlaCompiler::ResourceUpdate& update = result->resource_updates[i];
+      XlaResource* resource;
+      OP_REQUIRES_OK(ctx,
+                     ctx->GetResourceInput(update.input_index + 1, &resource));
+      if (update.modified) {
+        int pos = result->outputs.size() + i;
+        OP_REQUIRES_OK(ctx,
+                       resource->SetFromPack(
+                           arguments[update.input_index].tensor_array_gradients,
+                           b->GetTupleElement(outputs, pos), b));
+      }
+      VLOG(2) << "If variable: pos: " << update.input_index
+              << " name: " << resource->name()
+              << " modified: " << update.modified
+              << " type: " << DataTypeString(update.type)
+              << " shape: " << update.shape.DebugString();
+    }
+  }
+  VLOG(1) << "Done building If";
+}
+
+REGISTER_XLA_OP(Name("XlaIf").AllowResourceTypes(), XlaIfOp);
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.h b/tensorflow/compiler/tf2xla/kernels/if_op.h
new file mode 100644 (file)
index 0000000..f9bc98a
--- /dev/null
@@ -0,0 +1,59 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_OP_H_
+#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_OP_H_
+
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+
+namespace tensorflow {
+
+// This TensorFlow op provides a functional conditional primitive.
+//
+// The outputs of the then/else branches must agree on the number, types, and
+// shapes of the Tensors carried around the two bodies.
+//
+// Computations in then/else bodies may read from and write to resource
+// variables.
+// Resource variables may be passed as arguments to the then/else function's
+// bodies. The XlaCompiler converts resource variable arguments
+// into parameters to the XLA computation and moves them to the end of the
+// parameter list, and by using the `return_updated_values_for_all_variables`
+// we ensure that all variables that appear in the input also appear at the
+// end of the then/else bodies output. This ensures the then/else bodies output
+// signatures match.
+//
+// It is the user's responsibility to ensure that each non-variable _Arg matches
+// the corresponding _Retval.
+class XlaIfOp : public XlaOpKernel {
+ public:
+  explicit XlaIfOp(OpKernelConstruction* ctx);
+
+  void Compile(XlaOpKernelContext* ctx) override;
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(XlaIfOp);
+
+  NameAttrList then_branch_;
+  NameAttrList else_branch_;
+  DataType cond_type_;
+  DataTypeVector input_types_;
+  DataTypeVector output_types_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_OP_H_