From 36b3c94a99704c8e1973ae5c043aec4870ae84ff Mon Sep 17 00:00:00 2001 From: Mark Heffernan <meheff@google.com> Date: Mon, 5 Mar 2018 13:44:42 -0800 Subject: [PATCH] Add methods for extracting the shapes of the entry computation from an HloProto. PiperOrigin-RevId: 187915821 --- tensorflow/compiler/xla/service/BUILD | 18 +++ tensorflow/compiler/xla/service/hlo_proto_util.cc | 135 +++++++++++++++++ tensorflow/compiler/xla/service/hlo_proto_util.h | 9 ++ .../compiler/xla/service/hlo_proto_util_test.cc | 161 +++++++++++++++++++++ 4 files changed, 323 insertions(+) create mode 100644 tensorflow/compiler/xla/service/hlo_proto_util_test.cc diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 6f52703..3eecc46 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -2387,6 +2387,24 @@ cc_library( ":hlo", ":hlo_proto", "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:util", + ], +) + +tf_cc_test( + name = "hlo_proto_util_test", + srcs = ["hlo_proto_util_test.cc"], + deps = [ + ":hlo", + ":hlo_proto", + ":hlo_proto_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + ], +) diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc index 78e6a10..f75c452 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -15,8 +15,112 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_proto_util.h" +#include <algorithm> + +#include "tensorflow/compiler/xla/util.h" + namespace xla { +namespace { + +// Returns the entry computation of the HLO module in the given HloProto. 
+StatusOr<const HloComputationProto*> GetEntryComputation( + const HloProto& hlo_proto) { + if (!hlo_proto.has_hlo_module()) { + return NotFound("HloProto missing HloModuleProto."); + } + + if (hlo_proto.hlo_module().entry_computation_name().empty()) { + return NotFound("HloProto has empty entry computation name."); + } + + const string& entry_computation_name = + hlo_proto.hlo_module().entry_computation_name(); + const HloComputationProto* entry_computation = nullptr; + for (const HloComputationProto& computation : + hlo_proto.hlo_module().computations()) { + if (computation.name() == entry_computation_name) { + if (entry_computation == nullptr) { + entry_computation = &computation; + } else { + return InvalidArgument( + "HloProto has multiple computations with entry computation named " + "%s.", + entry_computation_name.c_str()); + } + } + } + if (entry_computation == nullptr) { + return InvalidArgument("HloProto has no entry computation named %s.", + entry_computation_name.c_str()); + } + return entry_computation; +} + +// Returns the root instruction of the given computation proto. +StatusOr<const HloInstructionProto*> GetRootInstruction( + const HloComputationProto& computation) { + if (computation.root_name().empty()) { + return InvalidArgument("Missing root instruction name."); + } + + const HloInstructionProto* root = nullptr; + for (const HloInstructionProto& instruction : computation.instructions()) { + if (instruction.name() == computation.root_name()) { + if (root == nullptr) { + root = &instruction; + } else { + return InvalidArgument( + "Computation has multiple instructions named %s.", + computation.root_name().c_str()); + } + } + } + if (root == nullptr) { + return InvalidArgument("Computation has no instruction named %s.", + computation.root_name().c_str()); + } + return root; +} + +// Returns the parameters of the given computation. Parameter numbers are +// checked for validity and contiguousness. 
+StatusOr<std::vector<const HloInstructionProto*>> GetParameters( + const HloComputationProto& computation) { + std::vector<const HloInstructionProto*> parameters; + for (const HloInstructionProto& instruction : computation.instructions()) { + if (instruction.opcode() == HloOpcodeString(HloOpcode::kParameter)) { + parameters.push_back(&instruction); + } + } + + // Verify the uniqueness and validity of the parameter numbers. + tensorflow::gtl::FlatSet<int64> parameter_numbers; + for (const HloInstructionProto* parameter : parameters) { + if (parameter->parameter_number() < 0 || + parameter->parameter_number() >= parameters.size()) { + return InvalidArgument( + "Parameter instruction %s has invalid parameter number %lld.", + parameter->name().c_str(), parameter->parameter_number()); + } + if (parameter_numbers.count(parameter->parameter_number()) != 0) { + return InvalidArgument( + "Multiple parameter instructions have parameter number %lld.", + parameter->parameter_number()); + } + parameter_numbers.insert(parameter->parameter_number()); + } + + std::sort(parameters.begin(), parameters.end(), + [](const HloInstructionProto* a, const HloInstructionProto* b) { + return a->parameter_number() < b->parameter_number(); + }); + + return parameters; +} + +} // namespace + HloProto MakeHloProto(const HloModule& module, const BufferAssignment& assignment) { HloOrderingProto proto_ordering = @@ -35,4 +139,35 @@ HloProto MakeHloProto(const HloModule& module) { return proto; } +StatusOr<std::vector<const Shape*>> EntryComputationParameterShapes( + const HloProto& hlo_proto) { + TF_ASSIGN_OR_RETURN(const HloComputationProto* entry_computation, + GetEntryComputation(hlo_proto)); + TF_ASSIGN_OR_RETURN(std::vector<const HloInstructionProto*> parameters, + GetParameters(*entry_computation)); + std::vector<const Shape*> parameter_shapes; + for (const HloInstructionProto* parameter : parameters) { + if (!parameter->has_shape()) { + return InvalidArgument("Parameter instruction %s is missing shape.", + parameter->name().c_str()); + } + parameter_shapes.push_back(&parameter->shape()); + } + return parameter_shapes; +} + +StatusOr<const Shape*> 
EntryComputationOutputShape(const HloProto& hlo_proto) { + TF_ASSIGN_OR_RETURN(const HloComputationProto* entry_computation, + GetEntryComputation(hlo_proto)); + + TF_ASSIGN_OR_RETURN(const HloInstructionProto* root, + GetRootInstruction(*entry_computation)); + if (!root->has_shape()) { + return InvalidArgument("Instruction %s is missing shape.", + root->name().c_str()); + } + + return &root->shape(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.h b/tensorflow/compiler/xla/service/hlo_proto_util.h index 320288f..3d9c375 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.h +++ b/tensorflow/compiler/xla/service/hlo_proto_util.h @@ -35,6 +35,15 @@ HloProto MakeHloProto(const HloModule& module, // will not be included in the output. HloProto MakeHloProto(const HloModule& module); +// Returns the shapes of the parameters of the entry computation. Shape pointers +// refer to shapes inside of the given HloProto. +StatusOr<std::vector<const Shape*>> EntryComputationParameterShapes( + const HloProto& hlo_proto); + +// Returns the shape of the output of the entry computation. The shape pointer +// refers to the output shape inside of the given HloProto. +StatusOr<const Shape*> EntryComputationOutputShape(const HloProto& hlo_proto); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROTO_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/hlo_proto_util_test.cc b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc new file mode 100644 index 0000000..0c0abf1 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc @@ -0,0 +1,161 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_proto_util.h" + +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace xla { +namespace { + +class HloProtoUtilTest : public ::testing::Test {}; + +TEST_F(HloProtoUtilTest, ParamsAndOutputShape) { + HloProto hlo_proto; + HloModuleProto* module = hlo_proto.mutable_hlo_module(); + module->set_entry_computation_name("entry"); + HloComputationProto* computation = module->add_computations(); + computation->set_name("entry"); + computation->set_root_name("root"); + + HloInstructionProto* param0 = computation->add_instructions(); + param0->set_opcode(HloOpcodeString(HloOpcode::kParameter)); + param0->set_parameter_number(0); + *param0->mutable_shape() = ShapeUtil::MakeShape(F32, {42}); + + HloInstructionProto* param2 = computation->add_instructions(); + param2->set_opcode(HloOpcodeString(HloOpcode::kParameter)); + param2->set_parameter_number(2); + *param2->mutable_shape() = ShapeUtil::MakeShape(S32, {1, 2, 3}); + + HloInstructionProto* param1 = computation->add_instructions(); + param1->set_opcode(HloOpcodeString(HloOpcode::kParameter)); + param1->set_parameter_number(1); + 
*param1->mutable_shape() = ShapeUtil::MakeShape(F64, {}); + + HloInstructionProto* root = computation->add_instructions(); + root->set_opcode(HloOpcodeString(HloOpcode::kAdd)); + root->set_name("root"); + *root->mutable_shape() = ShapeUtil::MakeShape(U8, {2}); + + VLOG(1) << hlo_proto.DebugString(); + + TF_ASSERT_OK_AND_ASSIGN(std::vector<const Shape*> parameter_shapes, + EntryComputationParameterShapes(hlo_proto)); + ASSERT_EQ(parameter_shapes.size(), 3); + EXPECT_TRUE( + ShapeUtil::Equal(*parameter_shapes[0], ShapeUtil::MakeShape(F32, {42}))); + EXPECT_TRUE( + ShapeUtil::Equal(*parameter_shapes[1], ShapeUtil::MakeShape(F64, {}))); + EXPECT_TRUE(ShapeUtil::Equal(*parameter_shapes[2], + ShapeUtil::MakeShape(S32, {1, 2, 3}))); + + TF_ASSERT_OK_AND_ASSIGN(const Shape* output_shape, + EntryComputationOutputShape(hlo_proto)); + EXPECT_TRUE(ShapeUtil::Equal(*output_shape, ShapeUtil::MakeShape(U8, {2}))); +} + +TEST_F(HloProtoUtilTest, ParamsAndOutputShapeNoParameters) { + HloProto hlo_proto; + HloModuleProto* module = hlo_proto.mutable_hlo_module(); + module->set_entry_computation_name("entry"); + HloComputationProto* computation = module->add_computations(); + computation->set_name("entry"); + computation->set_root_name("root"); + + HloInstructionProto* root = computation->add_instructions(); + root->set_opcode(HloOpcodeString(HloOpcode::kAdd)); + root->set_name("root"); + *root->mutable_shape() = ShapeUtil::MakeShape(U8, {2}); + + TF_ASSERT_OK_AND_ASSIGN(std::vector<const Shape*> parameter_shapes, + EntryComputationParameterShapes(hlo_proto)); + ASSERT_EQ(parameter_shapes.size(), 0); +} + +TEST_F(HloProtoUtilTest, ParamsAndOutputShapeMissingModule) { + HloProto hlo_proto; + + auto status = EntryComputationParameterShapes(hlo_proto).status(); + ASSERT_FALSE(status.ok()); + ASSERT_THAT(status.error_message(), + ::testing::HasSubstr("missing HloModuleProto")); +} + +TEST_F(HloProtoUtilTest, ParamsAndOutputShapeMissingEntryComputation) { + HloProto hlo_proto; + HloModuleProto* module = 
hlo_proto.mutable_hlo_module(); + module->set_entry_computation_name("entry"); + HloComputationProto* computation = module->add_computations(); + computation->set_name("not_entry"); + + auto status = EntryComputationParameterShapes(hlo_proto).status(); + ASSERT_FALSE(status.ok()); + ASSERT_THAT(status.error_message(), + ::testing::HasSubstr("has no entry computation named")); +} + +TEST_F(HloProtoUtilTest, OutputShapeMissingEntryRoot) { + HloProto hlo_proto; + HloModuleProto* module = hlo_proto.mutable_hlo_module(); + module->set_entry_computation_name("entry"); + HloComputationProto* computation = module->add_computations(); + computation->set_name("entry"); + computation->set_root_name("root"); + + auto status = EntryComputationOutputShape(hlo_proto).status(); + ASSERT_FALSE(status.ok()); + ASSERT_THAT(status.error_message(), + ::testing::HasSubstr("has no instruction named")); +} + +TEST_F(HloProtoUtilTest, ParamsShapesMissingParameterNumbers) { + HloProto hlo_proto; + HloModuleProto* module = hlo_proto.mutable_hlo_module(); + module->set_entry_computation_name("entry"); + HloComputationProto* computation = module->add_computations(); + computation->set_name("entry"); + computation->set_root_name("root"); + + HloInstructionProto* param0 = computation->add_instructions(); + param0->set_opcode(HloOpcodeString(HloOpcode::kParameter)); + param0->set_parameter_number(0); + *param0->mutable_shape() = ShapeUtil::MakeShape(F32, {42}); + + HloInstructionProto* param2 = computation->add_instructions(); + param2->set_opcode(HloOpcodeString(HloOpcode::kParameter)); + param2->set_parameter_number(2); + *param2->mutable_shape() = ShapeUtil::MakeShape(S32, {1, 2, 3}); + + HloInstructionProto* root = computation->add_instructions(); + root->set_opcode(HloOpcodeString(HloOpcode::kAdd)); + root->set_name("root"); + *root->mutable_shape() = ShapeUtil::MakeShape(U8, {2}); + + auto status = EntryComputationParameterShapes(hlo_proto).status(); + ASSERT_FALSE(status.ok()); + 
ASSERT_THAT(status.error_message(), + ::testing::HasSubstr("invalid parameter number")); +} + +} // namespace +} // namespace xla -- 2.7.4