Add methods for extracting the shapes of the entry computation from an HloProto.
authorMark Heffernan <meheff@google.com>
Mon, 5 Mar 2018 21:44:42 +0000 (13:44 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 5 Mar 2018 21:48:53 +0000 (13:48 -0800)
PiperOrigin-RevId: 187915821

tensorflow/compiler/xla/service/BUILD
tensorflow/compiler/xla/service/hlo_proto_util.cc
tensorflow/compiler/xla/service/hlo_proto_util.h
tensorflow/compiler/xla/service/hlo_proto_util_test.cc [new file with mode: 0644]

index 6f52703..3eecc46 100644 (file)
@@ -2387,6 +2387,24 @@ cc_library(
         ":hlo",
         ":hlo_proto",
         "//tensorflow/compiler/xla:status",
+        "//tensorflow/compiler/xla:util",
+    ],
+)
+
+tf_cc_test(
+    name = "hlo_proto_util_test",
+    srcs = ["hlo_proto_util_test.cc"],
+    deps = [
+        ":hlo",
+        ":hlo_proto",
+        ":hlo_proto_util",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:lib",
     ],
 )
 
index 78e6a10..f75c452 100644 (file)
@@ -15,8 +15,112 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/hlo_proto_util.h"
 
+#include <string>
+
+#include "tensorflow/compiler/xla/util.h"
+
 namespace xla {
 
+namespace {
+
+// Returns the entry computation of the HLO module in the given HloProto.
+StatusOr<const HloComputationProto*> GetEntryComputation(
+    const HloProto& hlo_proto) {
+  if (!hlo_proto.has_hlo_module()) {
+    return NotFound("HloProto missing HloModuleProto.");
+  }
+
+  if (hlo_proto.hlo_module().entry_computation_name().empty()) {
+    return NotFound("HloProto has empty entry computation name.");
+  }
+
+  const string& entry_computation_name =
+      hlo_proto.hlo_module().entry_computation_name();
+  const HloComputationProto* entry_computation = nullptr;
+  for (const HloComputationProto& computation :
+       hlo_proto.hlo_module().computations()) {
+    if (computation.name() == entry_computation_name) {
+      if (entry_computation == nullptr) {
+        entry_computation = &computation;
+      } else {
+        return InvalidArgument(
+            "HloProto has multiple computations with entry computation named "
+            "%s.",
+            entry_computation_name.c_str());
+      }
+    }
+  }
+  if (entry_computation == nullptr) {
+    return InvalidArgument("HloProto has no entry computation named %s.",
+                           entry_computation_name.c_str());
+  }
+  return entry_computation;
+}
+
+// Returns the root instruction of the given computation proto.
+StatusOr<const HloInstructionProto*> GetRootInstruction(
+    const HloComputationProto& computation) {
+  if (computation.root_name().empty()) {
+    return InvalidArgument("Missing root instruction name.");
+  }
+
+  const HloInstructionProto* root = nullptr;
+  for (const HloInstructionProto& instruction : computation.instructions()) {
+    if (instruction.name() == computation.root_name()) {
+      if (root == nullptr) {
+        root = &instruction;
+      } else {
+        return InvalidArgument(
+            "Computation has multiple instructions named %s.",
+            computation.root_name().c_str());
+      }
+    }
+  }
+  if (root == nullptr) {
+    return InvalidArgument("Computation has no instruction named %s.",
+                           computation.root_name().c_str());
+  }
+  return root;
+}
+
+// Returns the parameters of the given computation. Parameter numbers are
+// checked for validity and contiguousness.
+StatusOr<std::vector<const HloInstructionProto*>> GetParameters(
+    const HloComputationProto& computation) {
+  std::vector<const HloInstructionProto*> parameters;
+  for (const HloInstructionProto& instruction : computation.instructions()) {
+    if (instruction.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
+      parameters.push_back(&instruction);
+    }
+  }
+
+  // Verify the uniqueness and validity of the parameter numbers.
+  tensorflow::gtl::FlatSet<int64> parameter_numbers;
+  for (const HloInstructionProto* parameter : parameters) {
+    if (parameter->parameter_number() < 0 ||
+        parameter->parameter_number() >= parameters.size()) {
+      return InvalidArgument(
+          "Parameter instruction %s has invalid parameter number %lld.",
+          parameter->name().c_str(), parameter->parameter_number());
+    }
+    if (parameter_numbers.count(parameter->parameter_number()) != 0) {
+      return InvalidArgument(
+          "Multiple parameter instructions have parameter number %lld.",
+          parameter->parameter_number());
+    }
+    parameter_numbers.insert(parameter->parameter_number());
+  }
+
+  std::sort(parameters.begin(), parameters.end(),
+            [](const HloInstructionProto* a, const HloInstructionProto* b) {
+              return a->parameter_number() < b->parameter_number();
+            });
+
+  return parameters;
+}
+
+}  // namespace
+
 HloProto MakeHloProto(const HloModule& module,
                       const BufferAssignment& assignment) {
   HloOrderingProto proto_ordering =
@@ -35,4 +139,35 @@ HloProto MakeHloProto(const HloModule& module) {
   return proto;
 }
 
+StatusOr<std::vector<const Shape*>> EntryComputationParameterShapes(
+    const HloProto& hlo_proto) {
+  TF_ASSIGN_OR_RETURN(const HloComputationProto* entry_computation,
+                      GetEntryComputation(hlo_proto));
+  TF_ASSIGN_OR_RETURN(std::vector<const HloInstructionProto*> parameters,
+                      GetParameters(*entry_computation));
+  std::vector<const Shape*> parameter_shapes;
+  for (const HloInstructionProto* parameter : parameters) {
+    if (!parameter->has_shape()) {
+      return InvalidArgument("Parameter instruction %s is missing shape.",
+                             parameter->name().c_str());
+    }
+    parameter_shapes.push_back(&parameter->shape());
+  }
+  return parameter_shapes;
+}
+
+StatusOr<const Shape*> EntryComputationOutputShape(const HloProto& hlo_proto) {
+  TF_ASSIGN_OR_RETURN(const HloComputationProto* entry_computation,
+                      GetEntryComputation(hlo_proto));
+
+  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
+                      GetRootInstruction(*entry_computation));
+  if (!root->has_shape()) {
+    return InvalidArgument("Instruction %s is missing shape.",
+                           root->name().c_str());
+  }
+
+  return &root->shape();
+}
+
 }  // namespace xla
index 320288f..3d9c375 100644 (file)
@@ -35,6 +35,15 @@ HloProto MakeHloProto(const HloModule& module,
 // will not be included in the output.
 HloProto MakeHloProto(const HloModule& module);
 
+// Returns the shapes of the parameters of the entry computation. Shape pointers
+// refer to shapes inside of the given HloProto.
+StatusOr<std::vector<const Shape*>> EntryComputationParameterShapes(
+    const HloProto& hlo_proto);
+
+// Returns the shape of the output of the entry computation. The shape pointer
+// refers to the output shape inside of the given HloProto.
+StatusOr<const Shape*> EntryComputationOutputShape(const HloProto& hlo_proto);
+
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROTO_UTIL_H_
diff --git a/tensorflow/compiler/xla/service/hlo_proto_util_test.cc b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc
new file mode 100644 (file)
index 0000000..0c0abf1
--- /dev/null
@@ -0,0 +1,161 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
+
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace xla {
+namespace {
+
+class HloProtoUtilTest : public ::testing::Test {};
+
+TEST_F(HloProtoUtilTest, ParamsAndOutputShape) {
+  HloProto hlo_proto;
+  HloModuleProto* module = hlo_proto.mutable_hlo_module();
+  module->set_entry_computation_name("entry");
+  HloComputationProto* computation = module->add_computations();
+  computation->set_name("entry");
+  computation->set_root_name("root");
+
+  HloInstructionProto* param0 = computation->add_instructions();
+  param0->set_opcode(HloOpcodeString(HloOpcode::kParameter));
+  param0->set_parameter_number(0);
+  *param0->mutable_shape() = ShapeUtil::MakeShape(F32, {42});
+
+  HloInstructionProto* param2 = computation->add_instructions();
+  param2->set_opcode(HloOpcodeString(HloOpcode::kParameter));
+  param2->set_parameter_number(2);
+  *param2->mutable_shape() = ShapeUtil::MakeShape(S32, {1, 2, 3});
+
+  HloInstructionProto* param1 = computation->add_instructions();
+  param1->set_opcode(HloOpcodeString(HloOpcode::kParameter));
+  param1->set_parameter_number(1);
+  *param1->mutable_shape() = ShapeUtil::MakeShape(F64, {});
+
+  HloInstructionProto* root = computation->add_instructions();
+  root->set_opcode(HloOpcodeString(HloOpcode::kAdd));
+  root->set_name("root");
+  *root->mutable_shape() = ShapeUtil::MakeShape(U8, {2});
+
+  VLOG(1) << hlo_proto.DebugString();
+
+  TF_ASSERT_OK_AND_ASSIGN(std::vector<const Shape*> parameter_shapes,
+                          EntryComputationParameterShapes(hlo_proto));
+  ASSERT_EQ(parameter_shapes.size(), 3);
+  EXPECT_TRUE(
+      ShapeUtil::Equal(*parameter_shapes[0], ShapeUtil::MakeShape(F32, {42})));
+  EXPECT_TRUE(
+      ShapeUtil::Equal(*parameter_shapes[1], ShapeUtil::MakeShape(F64, {})));
+  EXPECT_TRUE(ShapeUtil::Equal(*parameter_shapes[2],
+                               ShapeUtil::MakeShape(S32, {1, 2, 3})));
+
+  TF_ASSERT_OK_AND_ASSIGN(const Shape* output_shape,
+                          EntryComputationOutputShape(hlo_proto));
+  EXPECT_TRUE(ShapeUtil::Equal(*output_shape, ShapeUtil::MakeShape(U8, {2})));
+}
+
+TEST_F(HloProtoUtilTest, ParamsAndOutputShapeNoParameters) {
+  HloProto hlo_proto;
+  HloModuleProto* module = hlo_proto.mutable_hlo_module();
+  module->set_entry_computation_name("entry");
+  HloComputationProto* computation = module->add_computations();
+  computation->set_name("entry");
+  computation->set_root_name("root");
+
+  HloInstructionProto* root = computation->add_instructions();
+  root->set_opcode(HloOpcodeString(HloOpcode::kAdd));
+  root->set_name("root");
+  *root->mutable_shape() = ShapeUtil::MakeShape(U8, {2});
+
+  TF_ASSERT_OK_AND_ASSIGN(std::vector<const Shape*> parameter_shapes,
+                          EntryComputationParameterShapes(hlo_proto));
+  ASSERT_EQ(parameter_shapes.size(), 0);
+}
+
+TEST_F(HloProtoUtilTest, ParamsAndOutputShapeMissingModule) {
+  HloProto hlo_proto;
+
+  auto status = EntryComputationParameterShapes(hlo_proto).status();
+  ASSERT_FALSE(status.ok());
+  ASSERT_THAT(status.error_message(),
+              ::testing::HasSubstr("missing HloModuleProto"));
+}
+
+TEST_F(HloProtoUtilTest, ParamsAndOutputShapeMissingEntryComputation) {
+  HloProto hlo_proto;
+  HloModuleProto* module = hlo_proto.mutable_hlo_module();
+  module->set_entry_computation_name("entry");
+  HloComputationProto* computation = module->add_computations();
+  computation->set_name("not_entry");
+
+  auto status = EntryComputationParameterShapes(hlo_proto).status();
+  ASSERT_FALSE(status.ok());
+  ASSERT_THAT(status.error_message(),
+              ::testing::HasSubstr("has no entry computation named"));
+}
+
+TEST_F(HloProtoUtilTest, OutputShapeMissingEntryRoot) {
+  HloProto hlo_proto;
+  HloModuleProto* module = hlo_proto.mutable_hlo_module();
+  module->set_entry_computation_name("entry");
+  HloComputationProto* computation = module->add_computations();
+  computation->set_name("entry");
+  computation->set_root_name("root");
+
+  auto status = EntryComputationOutputShape(hlo_proto).status();
+  ASSERT_FALSE(status.ok());
+  ASSERT_THAT(status.error_message(),
+              ::testing::HasSubstr("has no instruction named"));
+}
+
+TEST_F(HloProtoUtilTest, ParamsShapesMissingParameterNumbers) {
+  HloProto hlo_proto;
+  HloModuleProto* module = hlo_proto.mutable_hlo_module();
+  module->set_entry_computation_name("entry");
+  HloComputationProto* computation = module->add_computations();
+  computation->set_name("entry");
+  computation->set_root_name("root");
+
+  HloInstructionProto* param0 = computation->add_instructions();
+  param0->set_opcode(HloOpcodeString(HloOpcode::kParameter));
+  param0->set_parameter_number(0);
+  *param0->mutable_shape() = ShapeUtil::MakeShape(F32, {42});
+
+  HloInstructionProto* param2 = computation->add_instructions();
+  param2->set_opcode(HloOpcodeString(HloOpcode::kParameter));
+  param2->set_parameter_number(2);
+  *param2->mutable_shape() = ShapeUtil::MakeShape(S32, {1, 2, 3});
+
+  HloInstructionProto* root = computation->add_instructions();
+  root->set_opcode(HloOpcodeString(HloOpcode::kAdd));
+  root->set_name("root");
+  *root->mutable_shape() = ShapeUtil::MakeShape(U8, {2});
+
+  auto status = EntryComputationParameterShapes(hlo_proto).status();
+  ASSERT_FALSE(status.ok());
+  ASSERT_THAT(status.error_message(),
+              ::testing::HasSubstr("invalid parameter number"));
+}
+
+}  // namespace
+}  // namespace xla