Do not crash on ROOT outfeed operations.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 4 May 2018 09:22:14 +0000 (02:22 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 4 May 2018 17:48:26 +0000 (10:48 -0700)
PiperOrigin-RevId: 195388075

tensorflow/compiler/xla/service/cpu/ir_emitter.cc
tensorflow/compiler/xla/service/cpu/tests/BUILD
tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc
tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc [new file with mode: 0644]

index e473389..6347ee2 100644 (file)
@@ -2563,8 +2563,12 @@ Status IrEmitter::FinishVisit(HloInstruction* root) {
   // nothing to do since the result was already written directly into the output
   // buffer.
   VLOG(2) << "FinishVisit root: " << root->ToString();
-  llvm::Value* root_value = GetEmittedValueFor(root);
-  VLOG(2) << "  value: " << llvm_ir::DumpToString(*root_value);
+  if (root->opcode() == HloOpcode::kOutfeed) {
+    VLOG(2) << "  outfeed with value: "
+            << llvm_ir::DumpToString(*GetEmittedValueFor(root->operand(0)));
+  } else {
+    VLOG(2) << "  value: " << llvm_ir::DumpToString(*GetEmittedValueFor(root));
+  }
 
   auto record_complete_computation = [&](llvm::Value* prof_counter) {
     if (prof_counter) {
index 4ddb7a8..18a915e 100644 (file)
@@ -161,3 +161,17 @@ tf_cc_test(
         "//tensorflow/core:test_main",
     ],
 )
+
+tf_cc_test(
+    name = "cpu_outfeed_test",
+    srcs = ["cpu_outfeed_test.cc"],
+    deps = [
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service/cpu:cpu_compiler",
+        "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test",
+        "//tensorflow/compiler/xla/tools/parser:hlo_parser",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)
index b10eb74..d6e0425 100644 (file)
@@ -50,16 +50,10 @@ ENTRY main {
   const_b = f32[2,3,2] while(f32[2,3,2] const_a), condition=while_cond, body=while_body
 
   out0 = () outfeed(f32[2,3,2] const_a)
-  out1 = () outfeed(f32[2,3,2] const_b)
-
-  ROOT root = f32[] constant(1)
+  ROOT out1 = () outfeed(f32[2,3,2] const_b)
 }
 )";
 
-  // TODO(b/78879738): The fake "f32[] constant(1)" root is only needed to work
-  // around b/78879738.  Once b/78879738 is fixed, we can set one of the
-  // outfeeds as the root.
-
   string filecheck_pattern = R"(
 CHECK: private constant [2 x [3 x [2 x float]]]
 CHECK-NOT: private constant [2 x [3 x [2 x float]]]
@@ -99,16 +93,10 @@ ENTRY main {
   const_b = (f32[2,1]{1,0}, f32[2]{0}) while((f32[2,1]{1,0}, f32[2]{0}) const_a), condition=while_cond, body=while_body
 
   out0 = () outfeed((f32[2,1]{1,0}, f32[2]{0}) const_a)
-  out1 = () outfeed((f32[2,1]{1,0}, f32[2]{0}) const_b)
-
-  ROOT root = f32[] constant(1)
+  ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[2]{0}) const_b)
 }
 )";
 
-  // TODO(b/78879738): The fake "f32[] constant(1)" root is only needed to work
-  // around b/78879738.  Once b/78879738 is fixed, we can set one of the
-  // outfeeds as the root.
-
   string filecheck_pattern = R"(
 CHECK: private constant [2 x float]
 CHECK: private constant [2 x [1 x float]]
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc
new file mode 100644 (file)
index 0000000..879372e
--- /dev/null
@@ -0,0 +1,57 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
+#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+
+namespace xla {
+namespace cpu {
+namespace {
+class CpuOutfeedTest : public CpuCodegenTest {};
+
+TEST_F(CpuOutfeedTest, OutfeedRoot) {
+  const string hlo_text = R"(
+HloModule Outfeed
+
+ENTRY main {
+  const_a = f32[2,3,2] constant(
+  f32[2,3,2]
+    {{{1, 2}, {1001, 1002}, {2001, 2002}},
+     {{2, 1}, {2001, 3002}, {2001, 2002}}})
+
+  ROOT out = () outfeed(f32[2,3,2] const_a)
+}
+)";
+
+  string filecheck_pattern = R"(
+CHECK: private constant [2 x [3 x [2 x float]]]
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          tools::Parse(hlo_text));
+
+  CpuAotCompilationOptions options{
+      /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"",
+      /*entry_point_name=*/"entry",
+      /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static};
+
+  CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern,
+                                /*match_optimized_ir=*/false);
+}
+
+}  // namespace
+}  // namespace cpu
+}  // namespace xla