"//tensorflow/compiler/xla:all_files",
"//tensorflow/compiler/xla/client:all_files",
"//tensorflow/compiler/xla/client/lib:all_files",
+ "//tensorflow/compiler/xla/client/xla_client:all_files",
"//tensorflow/compiler/xla/legacy_flags:all_files",
"//tensorflow/compiler/xla/python:all_files",
"//tensorflow/compiler/xla/service:all_files",
--- /dev/null
+# Description:
+# The new XLA client libraries.
+#
+# This is NOT YET ready to use.
+
+licenses(["notice"]) # Apache 2.0
+
+# Everything in this package defaults to being visible only to the XLA
+# "friends" package group declared below.
+package(default_visibility = [":friends"])
+
+package_group(
+ name = "friends",
+ includes = [
+ "//tensorflow/compiler/xla:friends",
+ ],
+)
+
+# Filegroup used to collect source files for dependency checking.
+filegroup(
+ name = "c_srcs",
+ data = glob([
+ "**/*.cc",
+ "**/*.h",
+ ]),
+)
+
+# Brings in the tf_cc_test macro used by the test targets below.
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+
+# TODO(b/74197823): Replace computation_builder with xla_builder.
+#
+# HLO-proto-based client builder library (XlaBuilder / XlaOp / XlaComputation).
+cc_library(
+ name = "xla_builder",
+ srcs = ["xla_builder.cc"],
+ hdrs = ["xla_builder.h"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_proto",
+ "//tensorflow/compiler/xla/service:shape_inference",
+ "//tensorflow/core:lib",
+ ],
+)
+
+# Unit tests for :xla_builder.
+tf_cc_test(
+ name = "xla_builder_test",
+ srcs = ["xla_builder_test.cc"],
+ deps = [
+ ":xla_builder",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_matchers",
+ "//tensorflow/core:test",
+ ],
+)
+
+# -----------------------------------------------------------------------------
+
+# Filegroup exposing this package's sources to the rest of TensorFlow (see the
+# matching ":all_files" entries in parent BUILD files).
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+
+#include <string>
+#include <utility>
+
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/shape_inference.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace xla {
+
+using tensorflow::strings::StrCat;
+
+namespace {
+
+// Returns a process-wide unique id. Ids start at 0 and increase by one per
+// call; the counter is guarded by a linker-initialized mutex so this is safe
+// to call from multiple threads.
+int64 GetUniqueId() {
+ static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
+ static int64 counter = 0;
+ tensorflow::mutex_lock lock(mu);
+ return counter++;
+}
+
+// Returns true if an instruction with the given opcode can be the root of the
+// computation. Send, Outfeed and Trace are side-effecting and produce no
+// usable result, so they may not be roots.
+bool CanBeRoot(HloOpcode opcode) {
+ return opcode != HloOpcode::kSend && opcode != HloOpcode::kOutfeed &&
+ opcode != HloOpcode::kTrace;
+}
+
+// Stores the string form of `opcode` on the instruction proto (the proto
+// carries opcodes as strings, not enum values).
+void SetOpcode(HloInstructionProto* instr, HloOpcode opcode) {
+ instr->set_opcode(HloOpcodeString(opcode));
+}
+
+} // namespace
+
+// Returns a copy of the shape of the instruction `op` refers to, or an error
+// if `op` belongs to a different builder or is out of range.
+StatusOr<std::unique_ptr<Shape>> XlaBuilder::GetShape(const XlaOp& op) const {
+ TF_ASSIGN_OR_RETURN(auto instr, LookUpInstruction(op));
+ return MakeUnique<Shape>(instr->shape());
+}
+
+// Convenience accessor that delegates to the owning builder. Fails the
+// RET_CHECK on a default-constructed (error) XlaOp, whose builder is null.
+StatusOr<Shape> XlaOp::GetShape() const {
+ TF_RET_CHECK(builder_ != nullptr);
+ TF_ASSIGN_OR_RETURN(auto shape, builder_->GetShape(*this));
+ return *shape;
+}
+
+// Constructs a builder for a computation with the given name; instructions
+// accumulate in instructions_ until Build() is called.
+XlaBuilder::XlaBuilder(const string& computation_name)
+ : name_(computation_name) {}
+
+XlaBuilder::~XlaBuilder() {}
+
+// Records `error` for deferred reporting:
+// - dies immediately if die_immediately_on_error_ is set;
+// - otherwise keeps only the FIRST error (plus a backtrace), which Build()
+//   later returns. Subsequent errors are dropped.
+void XlaBuilder::NoteError(const Status& error) {
+ CHECK(!error.ok());
+ if (die_immediately_on_error_) {
+ LOG(FATAL) << "error building computation: " << error;
+ }
+
+ if (first_error_.ok()) {
+ first_error_ = error;
+ first_error_backtrace_.CreateCurrent(/*skip_count=*/1);
+ }
+}
+
+// Finalizes the computation: picks a root instruction, validates and collects
+// parameters into the program shape, moves all instructions into an
+// HloModuleProto, and returns the result as an XlaComputation. If any enqueue
+// call failed, returns the first deferred error annotated with its backtrace.
+StatusOr<XlaComputation> XlaBuilder::Build() {
+ if (!first_error_.ok()) {
+ string backtrace;
+ first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
+ return AppendStatus(first_error_, backtrace);
+ }
+
+ HloComputationProto entry;
+ ProgramShape* program_shape = entry.mutable_program_shape();
+
+ entry.set_name(name_);
+
+ // Not all instructions can be roots. Walk backwards from the last added
+ // instruction until a valid root is found.
+ for (int64 i = instructions_.size() - 1; i >= 0; i--) {
+ TF_ASSIGN_OR_RETURN(HloOpcode opcode,
+ StringToHloOpcode(instructions_[i].opcode()));
+ if (CanBeRoot(opcode)) {
+ entry.set_root_name(instructions_[i].name());
+ *program_shape->mutable_result() = instructions_[i].shape();
+ break;
+ }
+ }
+ if (entry.root_name().empty()) {
+ return FailedPrecondition("no root instruction was found");
+ }
+
+ // Check that the parameter numbers are continuous from 0, and add parameter
+ // shapes and names to the program shape.
+ const int64 param_count = parameter_numbers_.size();
+ for (int64 i = 0; i < param_count; i++) {
+ program_shape->add_parameters();
+ program_shape->add_parameter_names();
+ }
+ for (const HloInstructionProto& instr : instructions_) {
+ // Parameter number uniqueness is guaranteed in XlaBuilder::Parameter(). So
+ // to verify continuity, we just need to verify that every parameter is in
+ // the right range.
+ if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
+ const int64 index = instr.parameter_number();
+ TF_RET_CHECK(index >= 0 && index < param_count)
+ << "invalid parameter number: " << index;
+ *program_shape->mutable_parameters(index) = instr.shape();
+ *program_shape->mutable_parameter_names(index) = instr.name();
+ }
+ }
+
+ // Swap (rather than copy) each instruction proto into the computation; the
+ // builder's instruction list is consumed by Build().
+ for (auto& instruction : instructions_) {
+ entry.add_instructions()->Swap(&instruction);
+ }
+
+ const int64 id = GetUniqueId();
+ entry.set_id(id);
+ XlaComputation computation(id);
+ HloModuleProto* module = computation.mutable_proto();
+ module->set_name(entry.name());
+ module->set_entry_computation_name(entry.name());
+ *module->mutable_program_shape() = entry.program_shape();
+ // Embedded computations (collected from Call()) come first; the entry
+ // computation is added last.
+ for (auto& e : embedded_) {
+ module->add_computations()->Swap(&e.second);
+ }
+ module->add_computations()->Swap(&entry);
+
+ // std::move is needed: XlaComputation's copy constructor is deleted, so the
+ // local must be moved into the returned StatusOr.
+ return std::move(computation);
+}
+
+// Enqueues an add instruction: infers the result shape from the operand
+// shapes (and optional broadcast dimensions), records the operand names, and
+// appends the instruction. Failures are deferred via NoteErrorOrReturn.
+XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ auto op = [&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ SetOpcode(&instr, HloOpcode::kAdd);
+ TF_ASSIGN_OR_RETURN(const auto* lhs_instr, LookUpInstruction(lhs));
+ TF_ASSIGN_OR_RETURN(const auto* rhs_instr, LookUpInstruction(rhs));
+ TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
+ ShapeInference::InferBinaryOpShape(
+ HloOpcode::kAdd, lhs_instr->shape(),
+ rhs_instr->shape(), broadcast_dimensions));
+ instr.add_operand_names(lhs_instr->name());
+ instr.add_operand_names(rhs_instr->name());
+ return AddInstruction(std::move(instr));
+ };
+ return NoteErrorOrReturn(op());
+}
+
+// Enqueues a constant instruction whose shape and value come from `literal`.
+// This path cannot fail, so no error-deferral wrapper is needed.
+XlaOp XlaBuilder::ConstantLiteral(const Literal& literal) {
+ HloInstructionProto instr;
+ SetOpcode(&instr, HloOpcode::kConstant);
+ *instr.mutable_shape() = literal.shape();
+ *instr.mutable_literal() = literal.ToProto();
+ return AddInstruction(std::move(instr));
+}
+
+// Enqueues a call instruction invoking `computation` on `operands`. The result
+// shape is inferred from the operand shapes and the callee's program shape;
+// the callee's computations are recorded in embedded_ (keyed by unique id) so
+// Build() can emit them into the module. Failures are deferred via
+// NoteErrorOrReturn.
+XlaOp XlaBuilder::Call(const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<XlaOp> operands) {
+ auto op = [&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ SetOpcode(&instr, HloOpcode::kCall);
+ // Look up each operand exactly once, collecting both its shape (for shape
+ // inference) and its name (for the instruction's operand list).
+ std::vector<const Shape*> operand_shapes;
+ for (const auto& operand : operands) {
+ TF_ASSIGN_OR_RETURN(const auto* input, LookUpInstruction(operand));
+ operand_shapes.push_back(&input->shape());
+ instr.add_operand_names(input->name());
+ }
+ TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
+ ShapeInference::InferCallShape(
+ operand_shapes,
+ /*to_apply=*/computation.GetProgramShape()));
+
+ // Add called computation.
+ *instr.add_called_computation_names() = computation.proto().name();
+ for (const HloComputationProto& e : computation.proto().computations()) {
+ embedded_.insert({e.id(), e});
+ }
+
+ return AddInstruction(std::move(instr));
+ };
+ return NoteErrorOrReturn(op());
+}
+
+// Enqueues a parameter instruction with the given number, shape and name.
+// Each parameter number may be registered at most once per builder; a
+// duplicate produces an InvalidArgument error (deferred until Build() unless
+// die_immediately_on_error_ is set).
+XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape,
+ const string& name) {
+ auto op = [&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ SetOpcode(&instr, HloOpcode::kParameter);
+ if (parameter_numbers_.find(parameter_number) != parameter_numbers_.end()) {
+ return InvalidArgument("parameter %lld already registered",
+ parameter_number);
+ }
+ parameter_numbers_.insert(parameter_number);
+ instr.set_parameter_number(parameter_number);
+ instr.set_name(name);
+ *instr.mutable_shape() = shape;
+ return AddInstruction(std::move(instr));
+ };
+ return NoteErrorOrReturn(op());
+}
+
+// Appends `instr` to the computation, assigning it a name made unique by
+// suffixing its handle (the index in instructions_), and returns an XlaOp
+// referring to it.
+XlaOp XlaBuilder::AddInstruction(HloInstructionProto&& instr) {
+ const int64 handle = instructions_.size();
+ if (instr.name().empty()) {
+ instr.set_name(StrCat(instr.opcode(), ".", handle));
+ } else {
+ // Append the handle to make sure the name is unique.
+ instr.set_name(StrCat(instr.name(), ".", handle));
+ }
+ // `instr` was taken by rvalue reference; move it into the vector rather
+ // than copying the (potentially large, e.g. constant-literal) proto.
+ instructions_.push_back(std::move(instr));
+
+ XlaOp op(handle, this);
+ return op;
+}
+
+// Returns a pointer to the instruction proto that `op` refers to, or an
+// InvalidArgument error if the handle is out of range. RET_CHECKs that `op`
+// was created by this builder.
+StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
+ const XlaOp& op) const {
+ TF_RET_CHECK(op.builder_ == this);
+ // Test the negative case explicitly so the range check below compares two
+ // signed values instead of relying on a negative handle wrapping around in
+ // a signed/unsigned comparison against size().
+ if (op.handle() < 0 ||
+ op.handle() >= static_cast<int64>(instructions_.size())) {
+ return InvalidArgument("no XlaOp value %lld", op.handle());
+ }
+ return &instructions_[op.handle()];
+}
+
+} // namespace xla
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// TODO(b/74197823): Replace computation_builder.h with this file.
+//
+// This is NOT YET ready to use.
+
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_
+
+#include <map>
+#include <string>
+#include <utility>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/stacktrace.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+class XlaBuilder;
+
+// This represents an instruction that has been enqueued using the XlaBuilder.
+// This is used to pass to subsequent computations that depend upon the
+// instruction as an operand.
+//
+// TODO(b/74197823): Replace xla::ComputationDataHandle with this one.
+class XlaOp {
+ public:
+ // Returns the shape of this operand by asking the builder that created it.
+ StatusOr<Shape> GetShape() const;
+
+ private:
+ // Default-constructed XlaOp (used as the error value) has a null builder.
+ XlaOp() : handle_(0), builder_(nullptr) {}
+ XlaOp(int64 handle, XlaBuilder* builder)
+ : handle_(handle), builder_(builder) {}
+
+ int64 handle() const { return handle_; }
+ friend class XlaBuilder;
+
+ // Index of the instruction in the owning builder's instruction list.
+ int64 handle_;
+ XlaBuilder* builder_; // Not owned.
+};
+
+// The computation graph that the user builds up with the XlaBuilder.
+//
+// Move-only: copying is deleted; ownership of the module proto transfers on
+// move.
+//
+// TODO(b/74197823): Replace xla::Computation with this one.
+class XlaComputation {
+ public:
+ XlaComputation(const XlaComputation&) = delete;
+ XlaComputation& operator=(const XlaComputation&) = delete;
+
+ XlaComputation(XlaComputation&& from) { *this = std::move(from); }
+
+ XlaComputation& operator=(XlaComputation&& from) {
+ // Move the member directly: from.proto() returns a const reference, so
+ // std::move applied to it would silently perform a copy instead of a
+ // move.
+ proto_ = std::move(from.proto_);
+ unique_id_ = from.unique_id_;
+ return *this;
+ }
+
+ // Returns the "program shape" (parameter and return shapes) for this
+ // computation.
+ const ProgramShape& GetProgramShape() const { return proto_.program_shape(); }
+
+ const HloModuleProto& proto() const { return proto_; }
+
+ private:
+ // Creates a computation with the given unique id and an empty module proto;
+ // only XlaBuilder may construct one.
+ XlaComputation(const int64 unique_id) : unique_id_(unique_id) {}
+ HloModuleProto* mutable_proto() { return &proto_; }
+ friend class XlaBuilder;
+
+ int64 unique_id_;
+ HloModuleProto proto_;
+};
+
+// A convenient interface for building up computations.
+//
+// Thread-compatible.
+//
+// TODO(b/74197823): Replace xla::ComputationBuilder with this one.
+class XlaBuilder {
+ public:
+ // computation_name: name to use for the built computation.
+ XlaBuilder(const string& computation_name);
+
+ XlaBuilder(const XlaBuilder&) = delete;
+ XlaBuilder& operator=(const XlaBuilder&) = delete;
+
+ ~XlaBuilder();
+
+ // Returns the computation name.
+ const string& name() const { return name_; }
+
+ // Sets the builder to a mode where it will die immediately when an error is
+ // encountered, rather than producing it in a deferred fashion when Build() is
+ // called (which is the default).
+ void set_die_immediately_on_error(bool enabled) {
+ die_immediately_on_error_ = enabled;
+ }
+
+ // Enqueues an add instruction onto the computation.
+ XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a call instruction onto the computation.
+ XlaOp Call(const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<XlaOp> operands);
+
+ // Enqueues a "retrieve parameter value" instruction for a parameter that was
+ // passed to the computation.
+ XlaOp Parameter(int64 parameter_number, const Shape& shape,
+ const string& name);
+
+ // Enqueues a constant with the value of the given literal onto the
+ // computation.
+ XlaOp ConstantLiteral(const Literal& literal);
+
+ // Enqueues a constant onto the computation. Methods are templated on the
+ // native host type (NativeT) which corresponds to a specific XLA
+ // PrimitiveType as given in the following table:
+ //
+ // Native Type PrimitiveType
+ // -----------------------------
+ // bool PRED
+ // int32 S32
+ // int64 S64
+ // uint32 U32
+ // uint64 U64
+ // float F32
+ // double F64
+ //
+ // Note: not all primitive types defined in xla_data.proto have a
+ // corresponding native type yet.
+ template <typename NativeT>
+ XlaOp ConstantR0(NativeT value);
+
+ // Returns the shape of the given op.
+ StatusOr<std::unique_ptr<Shape>> GetShape(const XlaOp& op) const;
+
+ // Builds the computation with the requested operations, or returns a non-ok
+ // status.
+ StatusOr<XlaComputation> Build();
+
+ private:
+ // Appends the instruction to instructions_, assigning it a unique name, and
+ // returns an XlaOp whose handle refers to it.
+ XlaOp AddInstruction(HloInstructionProto&& instr);
+
+ // Notes that the error occurred by:
+ // * storing it internally and capturing a backtrace if it's the first error
+ // (this deferred value will be produced on the call to Build())
+ // * dying if die_immediately_on_error_ is true
+ void NoteError(const Status& error);
+
+ // Unwraps op if it is ok; otherwise notes the error and returns a
+ // default-constructed (invalid) XlaOp.
+ XlaOp NoteErrorOrReturn(StatusOr<XlaOp>&& op) {
+ if (!op.ok()) {
+ NoteError(op.status());
+ return XlaOp();
+ }
+ return op.ConsumeValueOrDie();
+ }
+
+ // Returns a pointer to the instruction proto that `op` refers to, or an
+ // error if `op` belongs to another builder or its handle is out of range.
+ StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const;
+
+ string name_; // Name to use for the built computation.
+
+ // The first error encountered while building the computation.
+ // This is OK until the first error is encountered.
+ Status first_error_;
+
+ // The saved stack trace from the point at which the first error occurred.
+ tensorflow::SavedStackTrace first_error_backtrace_;
+
+ // The instructions of this computation.
+ std::vector<HloInstructionProto> instructions_;
+
+ // The embedded computations used by this computation. Each computation was
+ // the entry computation of some XlaComputation; the key is the unique id of
+ // that XlaComputation.
+ std::map<int64, HloComputationProto> embedded_;
+
+ // The unique parameter numbers.
+ tensorflow::gtl::FlatSet<int64> parameter_numbers_;
+
+ // Mode bit that indicates whether to die when a first error is encountered.
+ bool die_immediately_on_error_ = false;
+};
+
+// Wraps `value` in a rank-0 Literal and enqueues it as a constant
+// instruction.
+template <typename NativeT>
+XlaOp XlaBuilder::ConstantR0(NativeT value) {
+ return ConstantLiteral(*Literal::CreateR0<NativeT>(value));
+}
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+
+#include <string>
+
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+
+namespace {
+
+namespace op = xla::testing::opcode_matchers;
+
+using ::testing::HasSubstr;
+
+// TODO(b/74197823): Move the tests to service/.
+class XlaBuilderTest : public ::testing::Test {
+ protected:
+ // Builds the computation in `b` and round-trips it through the HLO proto
+ // into an HloModule so tests can inspect the resulting HLO graph.
+ StatusOr<std::unique_ptr<HloModule>> BuildHloModule(XlaBuilder* b) {
+ TF_ASSIGN_OR_RETURN(XlaComputation computation, b->Build());
+ const HloModuleProto& proto = computation.proto();
+ TF_ASSIGN_OR_RETURN(const auto& config,
+ HloModule::CreateModuleConfigFromProto(proto));
+ return HloModule::CreateFromProto(proto, config);
+ }
+
+ // Returns the name of the test currently being run.
+ string TestName() const {
+ return ::testing::UnitTest::GetInstance()->current_test_info()->name();
+ }
+};
+
+// Adding two constants produces an Add root over two Constant operands.
+TEST_F(XlaBuilderTest, OnePlusTwo) {
+ XlaBuilder b(TestName());
+ b.Add(b.ConstantR0<float>(1.0), b.ConstantR0<float>(2.0));
+ TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+ auto root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Add(op::Constant(), op::Constant()));
+}
+
+// A parameter can be mixed with a constant operand.
+TEST_F(XlaBuilderTest, ParamPlusConstant) {
+ XlaBuilder b(TestName());
+ auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {3, 5}), "x");
+ b.Add(x, b.ConstantR0<float>(1.0));
+ TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+ auto root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Add(op::Parameter(), op::Constant()));
+}
+
+// Adding two parameters with broadcast dimensions: the result shape matches
+// the larger operand, and the parameters appear in order.
+TEST_F(XlaBuilderTest, ParamPlusParam) {
+ XlaBuilder b(TestName());
+ const auto& x_shape = ShapeUtil::MakeShape(S32, {2, 4, 6});
+ const auto& y_shape = ShapeUtil::MakeShape(S32, {2, 4});
+ auto x = b.Parameter(0, x_shape, "x");
+ auto y = b.Parameter(1, y_shape, "y");
+ auto add = b.Add(x, y, /*broadcast_dimensions=*/{0, 1});
+
+ TF_ASSERT_OK_AND_ASSIGN(auto add_shape, add.GetShape());
+ EXPECT_TRUE(ShapeUtil::Equal(add_shape, x_shape));
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+ auto root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Add(op::Parameter(0), op::Parameter(1)));
+}
+
+// The same XlaOp may be used as both operands of one instruction.
+TEST_F(XlaBuilderTest, XPlusX) {
+ XlaBuilder b(TestName());
+ auto x = b.Parameter(0, ShapeUtil::MakeShape(S32, {1, 3, 5, 7}), "x");
+ b.Add(x, x);
+ TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+ auto root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Add(op::Parameter(0), op::Parameter(0)));
+}
+
+// Incompatible operand shapes surface as a deferred shape-inference error at
+// Build() time.
+TEST_F(XlaBuilderTest, ShapeInferenceError) {
+ XlaBuilder b(TestName());
+ auto x = b.Parameter(0, ShapeUtil::MakeShape(U32, {2, 4, 6}), "x");
+ auto y = b.Parameter(1, ShapeUtil::MakeShape(U32, {2, 4}), "y");
+ b.Add(x, y);
+ auto statusor = BuildHloModule(&b);
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(), HasSubstr("shape inference"));
+}
+
+// Re-using a parameter number within one builder is an error; parameter
+// numbers are per-builder, so a second builder may reuse 0.
+TEST_F(XlaBuilderTest, ParameterAlreadyRegistered) {
+ XlaBuilder b_call("add");
+ b_call.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "x");
+
+ XlaBuilder b(TestName());
+ auto x = b.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "x");
+ auto y = b.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "y");
+ b.Add(x, y);
+ auto statusor = BuildHloModule(&b);
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("parameter 0 already registered"));
+}
+
+// A computation built with one builder can be called twice from another
+// builder, with parameters and with constants.
+TEST_F(XlaBuilderTest, Call) {
+ XlaBuilder b_call("the_only_to_apply");
+ auto p0 = b_call.Parameter(0, ShapeUtil::MakeShape(F32, {}), "p0");
+ auto p1 = b_call.Parameter(1, ShapeUtil::MakeShape(F32, {}), "p1");
+ b_call.Add(p0, p1);
+ TF_ASSERT_OK_AND_ASSIGN(auto call, b_call.Build());
+ XlaBuilder b(TestName());
+ auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
+ auto one = b.ConstantR0<float>(1);
+ auto two = b.ConstantR0<float>(2);
+ b.Add(b.Call(call, {x, y}), b.Call(call, {one, two}));
+ TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+ auto root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Add(op::Call(op::Parameter(), op::Parameter()),
+ op::Call(op::Constant(), op::Constant())));
+}
+
+} // namespace
+} // namespace xla
// Gather dimension numbers.
xla.GatherDimensionNumbers gather_dimension_numbers = 33;
repeated int64 gather_window_bounds = 34;
+
+ // The id of this instruction.
+ int64 id = 35;
}
// Serialization of HloComputation.
// The program shape (with layout) of this computation.
xla.ProgramShape program_shape = 4;
+
+ // The id of this computation.
+ int64 id = 5;
}
// Serialization of HloModule.
// The program shape (with layout) of the entry computation.
xla.ProgramShape program_shape = 4;
+
+ // The id of this module.
+ int64 id = 5;
}
// Serialization of HloOrdering.
}
/* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
+ HloOpcode opcode, const Shape& lhs, const Shape& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ // Converts the HLO opcode to the legacy BinaryOperation enum and delegates
+ // to the existing overload.
+ return InferBinaryOpShape(OpcodeToBinaryOperation(opcode), lhs, rhs,
+ broadcast_dimensions);
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
BinaryOperation operation, const Shape& lhs, const Shape& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
VLOG(2) << tensorflow::strings::Printf(
static StatusOr<Shape> InferBinaryOpShape(
BinaryOperation operation, const Shape& lhs, const Shape& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ static StatusOr<Shape> InferBinaryOpShape(
+ HloOpcode opcode, const Shape& lhs, const Shape& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
static StatusOr<Shape> InferBinaryOpShape(HloOpcode opcode,
const HloInstruction* lhs,
const HloInstruction* rhs);