[XLA] Add XlaBuilder. This is the first step of implementing the client-service inter...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 14 Mar 2018 18:49:04 +0000 (11:49 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 14 Mar 2018 18:57:18 +0000 (11:57 -0700)
PiperOrigin-RevId: 189061445

tensorflow/BUILD
tensorflow/compiler/xla/client/xla_client/BUILD [new file with mode: 0644]
tensorflow/compiler/xla/client/xla_client/xla_builder.cc [new file with mode: 0644]
tensorflow/compiler/xla/client/xla_client/xla_builder.h [new file with mode: 0644]
tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc [new file with mode: 0644]
tensorflow/compiler/xla/service/hlo.proto
tensorflow/compiler/xla/service/shape_inference.cc
tensorflow/compiler/xla/service/shape_inference.h

index 957528c..057ac79 100644 (file)
@@ -433,6 +433,7 @@ filegroup(
         "//tensorflow/compiler/xla:all_files",
         "//tensorflow/compiler/xla/client:all_files",
         "//tensorflow/compiler/xla/client/lib:all_files",
+        "//tensorflow/compiler/xla/client/xla_client:all_files",
         "//tensorflow/compiler/xla/legacy_flags:all_files",
         "//tensorflow/compiler/xla/python:all_files",
         "//tensorflow/compiler/xla/service:all_files",
diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD
new file mode 100644 (file)
index 0000000..b912889
--- /dev/null
@@ -0,0 +1,77 @@
+# Description:
+#   The new XLA client libraries.
+#
+# This is NOT YET ready to use.
+
+licenses(["notice"])  # Apache 2.0
+
+package(default_visibility = [":friends"])
+
+package_group(
+    name = "friends",
+    includes = [
+        "//tensorflow/compiler/xla:friends",
+    ],
+)
+
+# Filegroup used to collect source files for dependency checking.
+filegroup(
+    name = "c_srcs",
+    data = glob([
+        "**/*.cc",
+        "**/*.h",
+    ]),
+)
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+
+# TODO(b/74197823): Replace computation_builder with xla_builder.
+cc_library(
+    name = "xla_builder",
+    srcs = ["xla_builder.cc"],
+    hdrs = ["xla_builder.h"],
+    deps = [
+        "//tensorflow/compiler/xla:literal_util",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_proto",
+        "//tensorflow/compiler/xla/service:shape_inference",
+        "//tensorflow/core:lib",
+    ],
+)
+
+tf_cc_test(
+    name = "xla_builder_test",
+    srcs = ["xla_builder_test.cc"],
+    deps = [
+        ":xla_builder",
+        "//tensorflow/compiler/xla:literal_util",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla:test_helpers",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_matchers",
+        "//tensorflow/core:test",
+    ],
+)
+
+# -----------------------------------------------------------------------------
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
new file mode 100644 (file)
index 0000000..6328a4f
--- /dev/null
@@ -0,0 +1,253 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+
+#include <string>
+#include <utility>
+
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/shape_inference.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace xla {
+
+using tensorflow::strings::StrCat;
+
+namespace {
+
+int64 GetUniqueId() {
+  static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
+  static int64 built_counter = 0;
+  tensorflow::mutex_lock loc(mu);
+  const int64 id = built_counter++;
+  return id;
+}
+
+// Returns true if an instruction with the given opcode can be the root of the
+// computation.
+bool CanBeRoot(HloOpcode opcode) {
+  switch (opcode) {
+    case HloOpcode::kSend:
+    case HloOpcode::kOutfeed:
+    case HloOpcode::kTrace:
+      return false;
+    default:
+      return true;
+  }
+}
+
+void SetOpcode(HloInstructionProto* instr, HloOpcode opcode) {
+  instr->set_opcode(HloOpcodeString(opcode));
+}
+
+}  // namespace
+
+StatusOr<std::unique_ptr<Shape>> XlaBuilder::GetShape(const XlaOp& op) const {
+  TF_ASSIGN_OR_RETURN(auto instr, LookUpInstruction(op));
+  return MakeUnique<Shape>(instr->shape());
+}
+
+StatusOr<Shape> XlaOp::GetShape() const {
+  TF_RET_CHECK(builder_ != nullptr);
+  TF_ASSIGN_OR_RETURN(auto shape, builder_->GetShape(*this));
+  return *shape;
+}
+
+XlaBuilder::XlaBuilder(const string& computation_name)
+    : name_(computation_name) {}
+
+XlaBuilder::~XlaBuilder() {}
+
+void XlaBuilder::NoteError(const Status& error) {
+  CHECK(!error.ok());
+  if (die_immediately_on_error_) {
+    LOG(FATAL) << "error building computation: " << error;
+  }
+
+  if (first_error_.ok()) {
+    first_error_ = error;
+    first_error_backtrace_.CreateCurrent(/*skip_count=*/1);
+  }
+}
+
+StatusOr<XlaComputation> XlaBuilder::Build() {
+  if (!first_error_.ok()) {
+    string backtrace;
+    first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
+    return AppendStatus(first_error_, backtrace);
+  }
+
+  HloComputationProto entry;
+  ProgramShape* program_shape = entry.mutable_program_shape();
+
+  entry.set_name(name_);
+
+  // Not all instructions can be roots. Walk backwards from the last added
+  // instruction until a valid root is found.
+  for (int64 i = instructions_.size() - 1; i >= 0; i--) {
+    TF_ASSIGN_OR_RETURN(HloOpcode opcode,
+                        StringToHloOpcode(instructions_[i].opcode()));
+    if (CanBeRoot(opcode)) {
+      entry.set_root_name(instructions_[i].name());
+      *program_shape->mutable_result() = instructions_[i].shape();
+      break;
+    }
+  }
+  if (entry.root_name().empty()) {
+    return FailedPrecondition("no root instruction was found");
+  }
+
+  // Check that the parameter numbers are continuous from 0, and add parameter
+  // shapes and names to the program shape.
+  const int64 param_count = parameter_numbers_.size();
+  for (int64 i = 0; i < param_count; i++) {
+    program_shape->add_parameters();
+    program_shape->add_parameter_names();
+  }
+  for (const HloInstructionProto& instr : instructions_) {
+    // Parameter number uniqueness is guaranteed in XlaBuilder::Parameter(). So
+    // to verify continuity, we just need to verify that every parameter is in
+    // the right range.
+    if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
+      const int64 index = instr.parameter_number();
+      TF_RET_CHECK(index >= 0 && index < param_count)
+          << "invalid parameter number: " << index;
+      *program_shape->mutable_parameters(index) = instr.shape();
+      *program_shape->mutable_parameter_names(index) = instr.name();
+    }
+  }
+
+  for (auto& instruction : instructions_) {
+    entry.add_instructions()->Swap(&instruction);
+  }
+
+  const int64 id = GetUniqueId();
+  entry.set_id(id);
+  XlaComputation computation(id);
+  HloModuleProto* module = computation.mutable_proto();
+  module->set_name(entry.name());
+  module->set_entry_computation_name(entry.name());
+  *module->mutable_program_shape() = entry.program_shape();
+  for (auto& e : embedded_) {
+    module->add_computations()->Swap(&e.second);
+  }
+  module->add_computations()->Swap(&entry);
+
+  return std::move(computation);
+}
+
+XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs,
+                      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+  auto op = [&]() -> StatusOr<XlaOp> {
+    HloInstructionProto instr;
+    SetOpcode(&instr, HloOpcode::kAdd);
+    TF_ASSIGN_OR_RETURN(const auto* lhs_instr, LookUpInstruction(lhs));
+    TF_ASSIGN_OR_RETURN(const auto* rhs_instr, LookUpInstruction(rhs));
+    TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
+                        ShapeInference::InferBinaryOpShape(
+                            HloOpcode::kAdd, lhs_instr->shape(),
+                            rhs_instr->shape(), broadcast_dimensions));
+    instr.add_operand_names(lhs_instr->name());
+    instr.add_operand_names(rhs_instr->name());
+    return AddInstruction(std::move(instr));
+  };
+  return NoteErrorOrReturn(op());
+}
+
+XlaOp XlaBuilder::ConstantLiteral(const Literal& literal) {
+  HloInstructionProto instr;
+  SetOpcode(&instr, HloOpcode::kConstant);
+  *instr.mutable_shape() = literal.shape();
+  *instr.mutable_literal() = literal.ToProto();
+  return AddInstruction(std::move(instr));
+}
+
+XlaOp XlaBuilder::Call(const XlaComputation& computation,
+                       tensorflow::gtl::ArraySlice<XlaOp> operands) {
+  auto op = [&]() -> StatusOr<XlaOp> {
+    HloInstructionProto instr;
+    SetOpcode(&instr, HloOpcode::kCall);
+    std::vector<const Shape*> operand_shapes;
+    for (const auto& operand : operands) {
+      TF_ASSIGN_OR_RETURN(const auto* input, LookUpInstruction(operand));
+      operand_shapes.push_back(&input->shape());
+    }
+    TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
+                        ShapeInference::InferCallShape(
+                            operand_shapes,
+                            /*to_apply=*/computation.GetProgramShape()));
+
+    // Add input operands.
+    for (const auto& operand : operands) {
+      TF_ASSIGN_OR_RETURN(auto operand_instr, LookUpInstruction(operand));
+      instr.add_operand_names(operand_instr->name());
+    }
+
+    // Add called computation.
+    *instr.add_called_computation_names() = computation.proto().name();
+    for (const HloComputationProto& e : computation.proto().computations()) {
+      embedded_.insert({e.id(), e});
+    }
+
+    return AddInstruction(std::move(instr));
+  };
+  return NoteErrorOrReturn(op());
+}
+
+XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape,
+                            const string& name) {
+  auto op = [&]() -> StatusOr<XlaOp> {
+    HloInstructionProto instr;
+    SetOpcode(&instr, HloOpcode::kParameter);
+    if (parameter_numbers_.find(parameter_number) != parameter_numbers_.end()) {
+      return InvalidArgument("parameter %lld already registered",
+                             parameter_number);
+    }
+    parameter_numbers_.insert(parameter_number);
+    instr.set_parameter_number(parameter_number);
+    instr.set_name(name);
+    *instr.mutable_shape() = shape;
+    return AddInstruction(std::move(instr));
+  };
+  return NoteErrorOrReturn(op());
+}
+
+XlaOp XlaBuilder::AddInstruction(HloInstructionProto&& instr) {
+  const int64 handle = instructions_.size();
+  if (instr.name().empty()) {
+    instr.set_name(StrCat(instr.opcode(), ".", handle));
+  } else {
+    // Append the handle to make sure the name is unique.
+    instr.set_name(StrCat(instr.name(), ".", handle));
+  }
+  instructions_.push_back(instr);
+
+  XlaOp op(handle, this);
+  return op;
+}
+
+StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
+    const XlaOp& op) const {
+  TF_RET_CHECK(op.builder_ == this);
+  if (op.handle() >= instructions_.size() || op.handle() < 0) {
+    return InvalidArgument("no XlaOp value %lld", op.handle());
+  }
+  return &instructions_[op.handle()];
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
new file mode 100644 (file)
index 0000000..7632bd2
--- /dev/null
@@ -0,0 +1,216 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// TODO(b/74197823): Replace computation_builder.h with this file.
+//
+// This is NOT YET ready to use.
+
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_
+
+#include <map>
+#include <string>
+#include <utility>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/stacktrace.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+class XlaBuilder;
+
+// This represents an instruction that has been enqueued using the XlaBuilder.
+// This is used to pass to subsequent computations that depends upon the
+// instruction as an operand.
+//
+// TODO(b/74197823): Replace xla::ComputationDataHandle with this one.
+class XlaOp {
+ public:
+  StatusOr<Shape> GetShape() const;
+
+ private:
+  XlaOp() : handle_(0), builder_(nullptr) {}
+  XlaOp(int64 handle, XlaBuilder* builder)
+      : handle_(handle), builder_(builder) {}
+
+  int64 handle() const { return handle_; }
+  friend class XlaBuilder;
+
+  int64 handle_;
+  XlaBuilder* builder_;  // Not owned.
+};
+
+// The computation graph that the user builds up with the XlaBuilder.
+//
+// TODO(b/74197823): Replace xla::Computation with this one.
+class XlaComputation {
+ public:
+  XlaComputation(const XlaComputation&) = delete;
+  XlaComputation& operator=(const XlaComputation&) = delete;
+
+  XlaComputation(XlaComputation&& from) { *this = std::move(from); }
+
+  XlaComputation& operator=(XlaComputation&& from) {
+    proto_ = std::move(from.proto());
+    unique_id_ = from.unique_id_;
+    return *this;
+  }
+
+  // Returns the "program shape" (parameter and return shapes) for this
+  // computation.
+  const ProgramShape& GetProgramShape() const { return proto_.program_shape(); }
+
+  const HloModuleProto& proto() const { return proto_; }
+
+ private:
+  // Creates a null Computation.
+  XlaComputation(const int64 unique_id) : unique_id_(unique_id) {}
+  HloModuleProto* mutable_proto() { return &proto_; }
+  friend class XlaBuilder;
+
+  int64 unique_id_;
+  HloModuleProto proto_;
+};
+
+// A convenient interface for building up computations.
+//
+// Thread-compatible.
+//
+// TODO(b/74197823): Replace xla::ComputationBuilder with this one.
+class XlaBuilder {
+ public:
+  // computation_name: name to use for the built computation.
+  XlaBuilder(const string& computation_name);
+
+  XlaBuilder(const XlaBuilder&) = delete;
+  XlaBuilder& operator=(const XlaBuilder&) = delete;
+
+  ~XlaBuilder();
+
+  // Returns the computation name.
+  const string& name() const { return name_; }
+
+  // Sets the builder to a mode where it will die immediately when an error is
+  // encountered, rather than producing it in a deferred fashion when Build() is
+  // called (which is the default).
+  void set_die_immediately_on_error(bool enabled) {
+    die_immediately_on_error_ = enabled;
+  }
+
+  // Enqueues an add instruction onto the computation.
+  XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
+            tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+  // Enqueues a call instruction onto the computation.
+  XlaOp Call(const XlaComputation& computation,
+             tensorflow::gtl::ArraySlice<XlaOp> operands);
+
+  // Enqueues a "retrieve parameter value" instruction for a parameter that was
+  // passed to the computation.
+  XlaOp Parameter(int64 parameter_number, const Shape& shape,
+                  const string& name);
+
+  // Enqueues a constant with the value of the given literal onto the
+  // computation.
+  XlaOp ConstantLiteral(const Literal& literal);
+
+  // Enqueues a constant onto the computation. Methods are templated on the
+  // native host type (NativeT) which corresponds to a specific XLA
+  // PrimitiveType as given in the following table:
+  //
+  //  Native Type   PrimitiveType
+  // -----------------------------
+  //   bool           PRED
+  //   int32          S32
+  //   int64          S64
+  //   uint32         U32
+  //   uint64         U64
+  //   float          F32
+  //   double         F64
+  //
+  // Note: not all primitive types defined in xla_data.proto have a
+  // corresponding native type yet.
+  template <typename NativeT>
+  XlaOp ConstantR0(NativeT value);
+
+  // Returns the shape of the given op.
+  StatusOr<std::unique_ptr<Shape>> GetShape(const XlaOp& op) const;
+
+  // Builds the computation with the requested operations, or returns a non-ok
+  // status.
+  StatusOr<XlaComputation> Build();
+
+ private:
+  XlaOp AddInstruction(HloInstructionProto&& instr);
+
+  // Notes that the error occurred by:
+  // * storing it internally and capturing a backtrace if it's the first error
+  //   (this deferred value will be produced on the call to Build())
+  // * dying if die_immediately_on_error_ is true
+  void NoteError(const Status& error);
+
+  XlaOp NoteErrorOrReturn(StatusOr<XlaOp>&& op) {
+    if (!op.ok()) {
+      NoteError(op.status());
+      return XlaOp();
+    }
+    return op.ConsumeValueOrDie();
+  }
+
+  StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const;
+
+  string name_;  // Name to use for the built computation.
+
+  // The first error encountered while building the computation.
+  // This is OK until the first error is encountered.
+  Status first_error_;
+
+  // The saved stack trace from the point at which the first error occurred.
+  tensorflow::SavedStackTrace first_error_backtrace_;
+
+  // The instructions of this computation.
+  std::vector<HloInstructionProto> instructions_;
+
+  // The embedded computations used by this computation. Each computation was
+  // the entry computation of some XlaComputation, the key is the unique id of
+  // that XlaComputation.
+  std::map<int64, HloComputationProto> embedded_;
+
+  // The unique parameter numbers.
+  tensorflow::gtl::FlatSet<int64> parameter_numbers_;
+
+  // Mode bit that indicates whether to die when a first error is encountered.
+  bool die_immediately_on_error_ = false;
+};
+
+template <typename NativeT>
+XlaOp XlaBuilder::ConstantR0(NativeT value) {
+  return ConstantLiteral(*Literal::CreateR0<NativeT>(value));
+}
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc
new file mode 100644 (file)
index 0000000..a400e4e
--- /dev/null
@@ -0,0 +1,137 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+
+#include <string>
+
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+
+namespace {
+
+namespace op = xla::testing::opcode_matchers;
+
+using ::testing::HasSubstr;
+
+// TODO(b/74197823): Move the tests to service/.
+class XlaBuilderTest : public ::testing::Test {
+ protected:
+  StatusOr<std::unique_ptr<HloModule>> BuildHloModule(XlaBuilder* b) {
+    TF_ASSIGN_OR_RETURN(XlaComputation computation, b->Build());
+    const HloModuleProto& proto = computation.proto();
+    TF_ASSIGN_OR_RETURN(const auto& config,
+                        HloModule::CreateModuleConfigFromProto(proto));
+    return HloModule::CreateFromProto(proto, config);
+  }
+
+  // Returns the name of the test currently being run.
+  string TestName() const {
+    return ::testing::UnitTest::GetInstance()->current_test_info()->name();
+  }
+};
+
+TEST_F(XlaBuilderTest, OnePlusTwo) {
+  XlaBuilder b(TestName());
+  b.Add(b.ConstantR0<float>(1.0), b.ConstantR0<float>(2.0));
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, op::Add(op::Constant(), op::Constant()));
+}
+
+TEST_F(XlaBuilderTest, ParamPlusConstant) {
+  XlaBuilder b(TestName());
+  auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {3, 5}), "x");
+  b.Add(x, b.ConstantR0<float>(1.0));
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, op::Add(op::Parameter(), op::Constant()));
+}
+
+TEST_F(XlaBuilderTest, ParamPlusParam) {
+  XlaBuilder b(TestName());
+  const auto& x_shape = ShapeUtil::MakeShape(S32, {2, 4, 6});
+  const auto& y_shape = ShapeUtil::MakeShape(S32, {2, 4});
+  auto x = b.Parameter(0, x_shape, "x");
+  auto y = b.Parameter(1, y_shape, "y");
+  auto add = b.Add(x, y, /*broadcast_dimensions=*/{0, 1});
+
+  TF_ASSERT_OK_AND_ASSIGN(auto add_shape, add.GetShape());
+  EXPECT_TRUE(ShapeUtil::Equal(add_shape, x_shape));
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, op::Add(op::Parameter(0), op::Parameter(1)));
+}
+
+TEST_F(XlaBuilderTest, XPlusX) {
+  XlaBuilder b(TestName());
+  auto x = b.Parameter(0, ShapeUtil::MakeShape(S32, {1, 3, 5, 7}), "x");
+  b.Add(x, x);
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, op::Add(op::Parameter(0), op::Parameter(0)));
+}
+
+TEST_F(XlaBuilderTest, ShapeInferenceError) {
+  XlaBuilder b(TestName());
+  auto x = b.Parameter(0, ShapeUtil::MakeShape(U32, {2, 4, 6}), "x");
+  auto y = b.Parameter(1, ShapeUtil::MakeShape(U32, {2, 4}), "y");
+  b.Add(x, y);
+  auto statusor = BuildHloModule(&b);
+  ASSERT_FALSE(statusor.ok());
+  EXPECT_THAT(statusor.status().error_message(), HasSubstr("shape inference"));
+}
+
+TEST_F(XlaBuilderTest, ParameterAlreadyRegistered) {
+  XlaBuilder b_call("add");
+  b_call.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "x");
+
+  XlaBuilder b(TestName());
+  auto x = b.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "x");
+  auto y = b.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "y");
+  b.Add(x, y);
+  auto statusor = BuildHloModule(&b);
+  ASSERT_FALSE(statusor.ok());
+  EXPECT_THAT(statusor.status().error_message(),
+              HasSubstr("parameter 0 already registered"));
+}
+
+TEST_F(XlaBuilderTest, Call) {
+  XlaBuilder b_call("the_only_to_apply");
+  auto p0 = b_call.Parameter(0, ShapeUtil::MakeShape(F32, {}), "p0");
+  auto p1 = b_call.Parameter(1, ShapeUtil::MakeShape(F32, {}), "p1");
+  b_call.Add(p0, p1);
+  TF_ASSERT_OK_AND_ASSIGN(auto call, b_call.Build());
+  XlaBuilder b(TestName());
+  auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+  auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
+  auto one = b.ConstantR0<float>(1);
+  auto two = b.ConstantR0<float>(2);
+  b.Add(b.Call(call, {x, y}), b.Call(call, {one, two}));
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, op::Add(op::Call(op::Parameter(), op::Parameter()),
+                            op::Call(op::Constant(), op::Constant())));
+}
+
+}  // namespace
+}  // namespace xla
index 66fd317..bf903d6 100644 (file)
@@ -133,6 +133,9 @@ message HloInstructionProto {
   // Gather dimension numbers.
   xla.GatherDimensionNumbers gather_dimension_numbers = 33;
   repeated int64 gather_window_bounds = 34;
+
+  // The id of this instruction.
+  int64 id = 35;
 }
 
 // Serialization of HloComputation.
@@ -148,6 +151,9 @@ message HloComputationProto {
 
   // The program shape (with layout) of this computation.
   xla.ProgramShape program_shape = 4;
+
+  // The id of this computation.
+  int64 id = 5;
 }
 
 // Serialization of HloModule.
@@ -161,6 +167,9 @@ message HloModuleProto {
 
   // The program shape (with layout) of the entry computation.
   xla.ProgramShape program_shape = 4;
+
+  // The id of this module.
+  int64 id = 5;
 }
 
 // Serialization of HloOrdering.
index 74f744a..8c8bd6d 100644 (file)
@@ -945,6 +945,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
 }
 
 /* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
+    HloOpcode opcode, const Shape& lhs, const Shape& rhs,
+    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+  return InferBinaryOpShape(OpcodeToBinaryOperation(opcode), lhs, rhs,
+                            broadcast_dimensions);
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
     BinaryOperation operation, const Shape& lhs, const Shape& rhs,
     tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
   VLOG(2) << tensorflow::strings::Printf(
index 0d30452..085fdac 100644 (file)
@@ -56,6 +56,9 @@ class ShapeInference {
   static StatusOr<Shape> InferBinaryOpShape(
       BinaryOperation operation, const Shape& lhs, const Shape& rhs,
       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+  static StatusOr<Shape> InferBinaryOpShape(
+      HloOpcode opcode, const Shape& lhs, const Shape& rhs,
+      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
   static StatusOr<Shape> InferBinaryOpShape(HloOpcode opcode,
                                             const HloInstruction* lhs,
                                             const HloInstruction* rhs);