Migrate HloExecutionProfileTest to textual HLO
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 16 May 2018 14:01:04 +0000 (07:01 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 16 May 2018 14:03:53 +0000 (07:03 -0700)
Also add lhs_contracting_dims and rhs_contracting_dims to make the test more realistic.
Before, the dot operation was created with CreateBinary instead of CreateCanonicalDot.

PiperOrigin-RevId: 196822255

tensorflow/compiler/xla/service/BUILD
tensorflow/compiler/xla/service/hlo_execution_profile_test.cc

index 6c40a3a..457768c 100644 (file)
@@ -1757,6 +1757,7 @@ tf_cc_test(
         ":hlo_execution_profile",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/compiler/xla/tools/parser:hlo_parser",
         "//tensorflow/core:lib",
     ],
 )
index dcc4583..4900c81 100644 (file)
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 
 namespace xla {
@@ -28,22 +29,19 @@ using ::testing::ContainsRegex;
 class HloExecutionProfileTest : public HloTestBase {};
 
 TEST_F(HloExecutionProfileTest, Basic) {
-  std::unique_ptr<HloModule> hlo_module = CreateNewModule();
-
-  HloComputation::Builder builder(TestName());
+  auto hlo_module = tools::Parse(R"(
+  HloModule test_module
+  ENTRY entry_computation {
+    lhs = f32[30,30]{1,0} parameter(0)
+    rhs = f32[30,30]{1,0} parameter(1)
+    add = f32[30,30]{1,0} add(lhs, rhs)
+    ROOT dot = f32[30,30]{1,0} dot(lhs, add), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  })")
+                        .ValueOrDie();
+  const HloInstruction* dot_instruction =
+      hlo_module->entry_computation()->root_instruction();
+  const HloInstruction* add_instruction = dot_instruction->operand(1);
   Shape shape = ShapeUtil::MakeShape(F32, {30, 30});
-  HloInstruction* param_lhs =
-      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs"));
-  HloInstruction* param_rhs =
-      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs"));
-  HloInstruction* add_instruction =
-      builder.AddInstruction(HloInstruction::CreateBinary(
-          shape, HloOpcode::kAdd, param_lhs, param_rhs));
-  HloInstruction* dot_instruction =
-      builder.AddInstruction(HloInstruction::CreateBinary(
-          shape, HloOpcode::kDot, param_lhs, add_instruction));
-
-  hlo_module->AddEntryComputation(builder.Build());
 
   auto shape_size_function = [&](const Shape& shape) {
     const int64 pointer_size = 8;