From 6b4b9d5c16cddebf6a91b0d027848c3c12371982 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 14 Mar 2018 11:49:04 -0700 Subject: [PATCH] [XLA] Add XlaBuilder. This is the first step of implementing the client-service interface redesign. Implemented ops: Add, Call, Constant, Parameter. PiperOrigin-RevId: 189061445 --- tensorflow/BUILD | 1 + tensorflow/compiler/xla/client/xla_client/BUILD | 77 +++++++ .../compiler/xla/client/xla_client/xla_builder.cc | 253 +++++++++++++++++++++ .../compiler/xla/client/xla_client/xla_builder.h | 216 ++++++++++++++++++ .../xla/client/xla_client/xla_builder_test.cc | 137 +++++++++++ tensorflow/compiler/xla/service/hlo.proto | 9 + tensorflow/compiler/xla/service/shape_inference.cc | 7 + tensorflow/compiler/xla/service/shape_inference.h | 3 + 8 files changed, 703 insertions(+) create mode 100644 tensorflow/compiler/xla/client/xla_client/BUILD create mode 100644 tensorflow/compiler/xla/client/xla_client/xla_builder.cc create mode 100644 tensorflow/compiler/xla/client/xla_client/xla_builder.h create mode 100644 tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 957528c..057ac79 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -433,6 +433,7 @@ filegroup( "//tensorflow/compiler/xla:all_files", "//tensorflow/compiler/xla/client:all_files", "//tensorflow/compiler/xla/client/lib:all_files", + "//tensorflow/compiler/xla/client/xla_client:all_files", "//tensorflow/compiler/xla/legacy_flags:all_files", "//tensorflow/compiler/xla/python:all_files", "//tensorflow/compiler/xla/service:all_files", diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD new file mode 100644 index 0000000..b912889 --- /dev/null +++ b/tensorflow/compiler/xla/client/xla_client/BUILD @@ -0,0 +1,77 @@ +# Description: +# The new XLA client libraries. +# +# This is NOT YET ready to use. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = [":friends"]) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/xla:friends", + ], +) + +# Filegroup used to collect source files for dependency checking. +filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +# TODO(b/74197823): Replace computation_builder with xla_builder. +cc_library( + name = "xla_builder", + srcs = ["xla_builder.cc"], + hdrs = ["xla_builder.h"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/compiler/xla/service:shape_inference", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "xla_builder_test", + srcs = ["xla_builder_test.cc"], + deps = [ + ":xla_builder", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/core:test", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc new file mode 100644 index 0000000..6328a4f --- /dev/null +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -0,0 +1,253 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/mutex.h" + +namespace xla { + +using tensorflow::strings::StrCat; + +namespace { + +int64 GetUniqueId() { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + static int64 built_counter = 0; + tensorflow::mutex_lock loc(mu); + const int64 id = built_counter++; + return id; +} + +// Returns true if an instruction with the given opcode can be the root of the +// computation. +bool CanBeRoot(HloOpcode opcode) { + switch (opcode) { + case HloOpcode::kSend: + case HloOpcode::kOutfeed: + case HloOpcode::kTrace: + return false; + default: + return true; + } +} + +void SetOpcode(HloInstructionProto* instr, HloOpcode opcode) { + instr->set_opcode(HloOpcodeString(opcode)); +} + +} // namespace + +StatusOr> XlaBuilder::GetShape(const XlaOp& op) const { + TF_ASSIGN_OR_RETURN(auto instr, LookUpInstruction(op)); + return MakeUnique(instr->shape()); +} + +StatusOr XlaOp::GetShape() const { + TF_RET_CHECK(builder_ != nullptr); + TF_ASSIGN_OR_RETURN(auto shape, builder_->GetShape(*this)); + return *shape; +} + +XlaBuilder::XlaBuilder(const string& computation_name) + : name_(computation_name) {} + +XlaBuilder::~XlaBuilder() {} + +void XlaBuilder::NoteError(const Status& error) { + CHECK(!error.ok()); + if (die_immediately_on_error_) { + LOG(FATAL) << "error building computation: " << error; + } + + if (first_error_.ok()) { + first_error_ = error; + first_error_backtrace_.CreateCurrent(/*skip_count=*/1); + } +} + +StatusOr XlaBuilder::Build() { + if (!first_error_.ok()) { + string backtrace; + first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); + return AppendStatus(first_error_, backtrace); + } + + HloComputationProto entry; + ProgramShape* program_shape = entry.mutable_program_shape(); + + entry.set_name(name_); + + // Not all instructions can be roots. Walk backwards from the last added + // instruction until a valid root is found. + for (int64 i = instructions_.size() - 1; i >= 0; i--) { + TF_ASSIGN_OR_RETURN(HloOpcode opcode, + StringToHloOpcode(instructions_[i].opcode())); + if (CanBeRoot(opcode)) { + entry.set_root_name(instructions_[i].name()); + *program_shape->mutable_result() = instructions_[i].shape(); + break; + } + } + if (entry.root_name().empty()) { + return FailedPrecondition("no root instruction was found"); + } + + // Check that the parameter numbers are continuous from 0, and add parameter + // shapes and names to the program shape. + const int64 param_count = parameter_numbers_.size(); + for (int64 i = 0; i < param_count; i++) { + program_shape->add_parameters(); + program_shape->add_parameter_names(); + } + for (const HloInstructionProto& instr : instructions_) { + // Parameter number uniqueness is guaranteed in XlaBuilder::Parameter(). So + // to verify continuity, we just need to verify that every parameter is in + // the right range. + if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter)) { + const int64 index = instr.parameter_number(); + TF_RET_CHECK(index >= 0 && index < param_count) + << "invalid parameter number: " << index; + *program_shape->mutable_parameters(index) = instr.shape(); + *program_shape->mutable_parameter_names(index) = instr.name(); + } + } + + for (auto& instruction : instructions_) { + entry.add_instructions()->Swap(&instruction); + } + + const int64 id = GetUniqueId(); + entry.set_id(id); + XlaComputation computation(id); + HloModuleProto* module = computation.mutable_proto(); + module->set_name(entry.name()); + module->set_entry_computation_name(entry.name()); + *module->mutable_program_shape() = entry.program_shape(); + for (auto& e : embedded_) { + module->add_computations()->Swap(&e.second); + } + module->add_computations()->Swap(&entry); + + return std::move(computation); +} + +XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + auto op = [&]() -> StatusOr { + HloInstructionProto instr; + SetOpcode(&instr, HloOpcode::kAdd); + TF_ASSIGN_OR_RETURN(const auto* lhs_instr, LookUpInstruction(lhs)); + TF_ASSIGN_OR_RETURN(const auto* rhs_instr, LookUpInstruction(rhs)); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferBinaryOpShape( + HloOpcode::kAdd, lhs_instr->shape(), + rhs_instr->shape(), broadcast_dimensions)); + instr.add_operand_names(lhs_instr->name()); + instr.add_operand_names(rhs_instr->name()); + return AddInstruction(std::move(instr)); + }; + return NoteErrorOrReturn(op()); +} + +XlaOp XlaBuilder::ConstantLiteral(const Literal& literal) { + HloInstructionProto instr; + SetOpcode(&instr, HloOpcode::kConstant); + *instr.mutable_shape() = literal.shape(); + *instr.mutable_literal() = literal.ToProto(); + return AddInstruction(std::move(instr)); +} + +XlaOp XlaBuilder::Call(const XlaComputation& computation, + tensorflow::gtl::ArraySlice operands) { + auto op = [&]() -> StatusOr { + HloInstructionProto instr; + SetOpcode(&instr, HloOpcode::kCall); + std::vector operand_shapes; + for (const auto& operand : operands) { + TF_ASSIGN_OR_RETURN(const auto* input, LookUpInstruction(operand)); + operand_shapes.push_back(&input->shape()); + } + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferCallShape( + operand_shapes, + /*to_apply=*/computation.GetProgramShape())); + + // Add input operands. + for (const auto& operand : operands) { + TF_ASSIGN_OR_RETURN(auto operand_instr, LookUpInstruction(operand)); + instr.add_operand_names(operand_instr->name()); + } + + // Add called computation. + *instr.add_called_computation_names() = computation.proto().name(); + for (const HloComputationProto& e : computation.proto().computations()) { + embedded_.insert({e.id(), e}); + } + + return AddInstruction(std::move(instr)); + }; + return NoteErrorOrReturn(op()); +} + +XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, + const string& name) { + auto op = [&]() -> StatusOr { + HloInstructionProto instr; + SetOpcode(&instr, HloOpcode::kParameter); + if (parameter_numbers_.find(parameter_number) != parameter_numbers_.end()) { + return InvalidArgument("parameter %lld already registered", + parameter_number); + } + parameter_numbers_.insert(parameter_number); + instr.set_parameter_number(parameter_number); + instr.set_name(name); + *instr.mutable_shape() = shape; + return AddInstruction(std::move(instr)); + }; + return NoteErrorOrReturn(op()); +} + +XlaOp XlaBuilder::AddInstruction(HloInstructionProto&& instr) { + const int64 handle = instructions_.size(); + if (instr.name().empty()) { + instr.set_name(StrCat(instr.opcode(), ".", handle)); + } else { + // Append the handle to make sure the name is unique. + instr.set_name(StrCat(instr.name(), ".", handle)); + } + instructions_.push_back(instr); + + XlaOp op(handle, this); + return op; +} + +StatusOr XlaBuilder::LookUpInstruction( + const XlaOp& op) const { + TF_RET_CHECK(op.builder_ == this); + if (op.handle() >= instructions_.size() || op.handle() < 0) { + return InvalidArgument("no XlaOp value %lld", op.handle()); + } + return &instructions_[op.handle()]; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h new file mode 100644 index 0000000..7632bd2 --- /dev/null +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h @@ -0,0 +1,216 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// TODO(b/74197823): Replace computation_builder.h with this file. +// +// This is NOT YET ready to use. + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stacktrace.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +class XlaBuilder; + +// This represents an instruction that has been enqueued using the XlaBuilder. +// This is used to pass to subsequent computations that depends upon the +// instruction as an operand. +// +// TODO(b/74197823): Replace xla::ComputationDataHandle with this one. +class XlaOp { + public: + StatusOr GetShape() const; + + private: + XlaOp() : handle_(0), builder_(nullptr) {} + XlaOp(int64 handle, XlaBuilder* builder) + : handle_(handle), builder_(builder) {} + + int64 handle() const { return handle_; } + friend class XlaBuilder; + + int64 handle_; + XlaBuilder* builder_; // Not owned. +}; + +// The computation graph that the user builds up with the XlaBuilder. +// +// TODO(b/74197823): Replace xla::Computation with this one. +class XlaComputation { + public: + XlaComputation(const XlaComputation&) = delete; + XlaComputation& operator=(const XlaComputation&) = delete; + + XlaComputation(XlaComputation&& from) { *this = std::move(from); } + + XlaComputation& operator=(XlaComputation&& from) { + proto_ = std::move(from.proto()); + unique_id_ = from.unique_id_; + return *this; + } + + // Returns the "program shape" (parameter and return shapes) for this + // computation. + const ProgramShape& GetProgramShape() const { return proto_.program_shape(); } + + const HloModuleProto& proto() const { return proto_; } + + private: + // Creates a null Computation. + XlaComputation(const int64 unique_id) : unique_id_(unique_id) {} + HloModuleProto* mutable_proto() { return &proto_; } + friend class XlaBuilder; + + int64 unique_id_; + HloModuleProto proto_; +}; + +// A convenient interface for building up computations. +// +// Thread-compatible. +// +// TODO(b/74197823): Replace xla::ComputationBuilder with this one. +class XlaBuilder { + public: + // computation_name: name to use for the built computation. + XlaBuilder(const string& computation_name); + + XlaBuilder(const XlaBuilder&) = delete; + XlaBuilder& operator=(const XlaBuilder&) = delete; + + ~XlaBuilder(); + + // Returns the computation name. + const string& name() const { return name_; } + + // Sets the builder to a mode where it will die immediately when an error is + // encountered, rather than producing it in a deferred fashion when Build() is + // called (which is the default). + void set_die_immediately_on_error(bool enabled) { + die_immediately_on_error_ = enabled; + } + + // Enqueues an add instruction onto the computation. + XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a call instruction onto the computation. + XlaOp Call(const XlaComputation& computation, + tensorflow::gtl::ArraySlice operands); + + // Enqueues a "retrieve parameter value" instruction for a parameter that was + // passed to the computation. + XlaOp Parameter(int64 parameter_number, const Shape& shape, + const string& name); + + // Enqueues a constant with the value of the given literal onto the + // computation. + XlaOp ConstantLiteral(const Literal& literal); + + // Enqueues a constant onto the computation. Methods are templated on the + // native host type (NativeT) which corresponds to a specific XLA + // PrimitiveType as given in the following table: + // + // Native Type PrimitiveType + // ----------------------------- + // bool PRED + // int32 S32 + // int64 S64 + // uint32 U32 + // uint64 U64 + // float F32 + // double F64 + // + // Note: not all primitive types defined in xla_data.proto have a + // corresponding native type yet. + template + XlaOp ConstantR0(NativeT value); + + // Returns the shape of the given op. + StatusOr> GetShape(const XlaOp& op) const; + + // Builds the computation with the requested operations, or returns a non-ok + // status. + StatusOr Build(); + + private: + XlaOp AddInstruction(HloInstructionProto&& instr); + + // Notes that the error occurred by: + // * storing it internally and capturing a backtrace if it's the first error + // (this deferred value will be produced on the call to Build()) + // * dying if die_immediately_on_error_ is true + void NoteError(const Status& error); + + XlaOp NoteErrorOrReturn(StatusOr&& op) { + if (!op.ok()) { + NoteError(op.status()); + return XlaOp(); + } + return op.ConsumeValueOrDie(); + } + + StatusOr LookUpInstruction(const XlaOp& op) const; + + string name_; // Name to use for the built computation. + + // The first error encountered while building the computation. + // This is OK until the first error is encountered. + Status first_error_; + + // The saved stack trace from the point at which the first error occurred. + tensorflow::SavedStackTrace first_error_backtrace_; + + // The instructions of this computation. + std::vector instructions_; + + // The embedded computations used by this computation. Each computation was + // the entry computation of some XlaComputation, the key is the unique id of + // that XlaComputation. + std::map embedded_; + + // The unique parameter numbers. + tensorflow::gtl::FlatSet parameter_numbers_; + + // Mode bit that indicates whether to die when a first error is encountered. + bool die_immediately_on_error_ = false; +}; + +template +XlaOp XlaBuilder::ConstantR0(NativeT value) { + return ConstantLiteral(*Literal::CreateR0(value)); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc new file mode 100644 index 0000000..a400e4e --- /dev/null +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc @@ -0,0 +1,137 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" + +#include + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +namespace { + +namespace op = xla::testing::opcode_matchers; + +using ::testing::HasSubstr; + +// TODO(b/74197823): Move the tests to service/. +class XlaBuilderTest : public ::testing::Test { + protected: + StatusOr> BuildHloModule(XlaBuilder* b) { + TF_ASSIGN_OR_RETURN(XlaComputation computation, b->Build()); + const HloModuleProto& proto = computation.proto(); + TF_ASSIGN_OR_RETURN(const auto& config, + HloModule::CreateModuleConfigFromProto(proto)); + return HloModule::CreateFromProto(proto, config); + } + + // Returns the name of the test currently being run. + string TestName() const { + return ::testing::UnitTest::GetInstance()->current_test_info()->name(); + } +}; + +TEST_F(XlaBuilderTest, OnePlusTwo) { + XlaBuilder b(TestName()); + b.Add(b.ConstantR0(1.0), b.ConstantR0(2.0)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Constant(), op::Constant())); +} + +TEST_F(XlaBuilderTest, ParamPlusConstant) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {3, 5}), "x"); + b.Add(x, b.ConstantR0(1.0)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Parameter(), op::Constant())); +} + +TEST_F(XlaBuilderTest, ParamPlusParam) { + XlaBuilder b(TestName()); + const auto& x_shape = ShapeUtil::MakeShape(S32, {2, 4, 6}); + const auto& y_shape = ShapeUtil::MakeShape(S32, {2, 4}); + auto x = b.Parameter(0, x_shape, "x"); + auto y = b.Parameter(1, y_shape, "y"); + auto add = b.Add(x, y, /*broadcast_dimensions=*/{0, 1}); + + TF_ASSERT_OK_AND_ASSIGN(auto add_shape, add.GetShape()); + EXPECT_TRUE(ShapeUtil::Equal(add_shape, x_shape)); + + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Parameter(0), op::Parameter(1))); +} + +TEST_F(XlaBuilderTest, XPlusX) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(S32, {1, 3, 5, 7}), "x"); + b.Add(x, x); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Parameter(0), op::Parameter(0))); +} + +TEST_F(XlaBuilderTest, ShapeInferenceError) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(U32, {2, 4, 6}), "x"); + auto y = b.Parameter(1, ShapeUtil::MakeShape(U32, {2, 4}), "y"); + b.Add(x, y); + auto statusor = BuildHloModule(&b); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), HasSubstr("shape inference")); +} + +TEST_F(XlaBuilderTest, ParameterAlreadyRegistered) { + XlaBuilder b_call("add"); + b_call.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "x"); + + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "x"); + auto y = b.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "y"); + b.Add(x, y); + auto statusor = BuildHloModule(&b); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("parameter 0 already registered")); +} + +TEST_F(XlaBuilderTest, Call) { + XlaBuilder b_call("the_only_to_apply"); + auto p0 = b_call.Parameter(0, ShapeUtil::MakeShape(F32, {}), "p0"); + auto p1 = b_call.Parameter(1, ShapeUtil::MakeShape(F32, {}), "p1"); + b_call.Add(p0, p1); + TF_ASSERT_OK_AND_ASSIGN(auto call, b_call.Build()); + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); + auto one = b.ConstantR0(1); + auto two = b.ConstantR0(2); + b.Add(b.Call(call, {x, y}), b.Call(call, {one, two})); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Call(op::Parameter(), op::Parameter()), + op::Call(op::Constant(), op::Constant()))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 66fd317..bf903d6 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -133,6 +133,9 @@ message HloInstructionProto { // Gather dimension numbers. xla.GatherDimensionNumbers gather_dimension_numbers = 33; repeated int64 gather_window_bounds = 34; + + // The id of this instruction. + int64 id = 35; } // Serialization of HloComputation. @@ -148,6 +151,9 @@ message HloComputationProto { // The program shape (with layout) of this computation. xla.ProgramShape program_shape = 4; + + // The id of this computation. + int64 id = 5; } // Serialization of HloModule. @@ -161,6 +167,9 @@ message HloModuleProto { // The program shape (with layout) of the entry computation. xla.ProgramShape program_shape = 4; + + // The id of this module. + int64 id = 5; } // Serialization of HloOrdering. diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 74f744a..8c8bd6d 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -945,6 +945,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } /* static */ StatusOr ShapeInference::InferBinaryOpShape( + HloOpcode opcode, const Shape& lhs, const Shape& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return InferBinaryOpShape(OpcodeToBinaryOperation(opcode), lhs, rhs, + broadcast_dimensions); +} + +/* static */ StatusOr ShapeInference::InferBinaryOpShape( BinaryOperation operation, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { VLOG(2) << tensorflow::strings::Printf( diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 0d30452..085fdac 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -56,6 +56,9 @@ class ShapeInference { static StatusOr InferBinaryOpShape( BinaryOperation operation, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions); + static StatusOr InferBinaryOpShape( + HloOpcode opcode, const Shape& lhs, const Shape& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); static StatusOr InferBinaryOpShape(HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs); -- 2.7.4