Fix problem with HandleElementwiseUnary/Binary in DfsHloVisitorWithDefault.
authorMark Heffernan <meheff@google.com>
Wed, 28 Mar 2018 01:31:55 +0000 (18:31 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 28 Mar 2018 01:34:20 +0000 (18:34 -0700)
DfsHloVisitorWithDefault incorrectly included some overrides for handling
several elementwise binary and unary opcodes. These overrides explicitly
called DefaultAction which meant that these opcodes were not handled by
HandleElementwiseUnary/Binary. This CL removes these overrides and adds a
comment describing the potential problem. Unfortunately, I don't see a way
of automatically catching these issues when new opcodes are added, so the
comment will have to do.

PiperOrigin-RevId: 190708245

tensorflow/compiler/xla/service/BUILD
tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc [new file with mode: 0644]

index f0bf68a..bde749d 100644 (file)
@@ -285,6 +285,23 @@ cc_library(
     ],
 )
 
+tf_cc_test(
+    name = "dfs_hlo_visitor_with_default_test",
+    srcs = ["dfs_hlo_visitor_with_default_test.cc"],
+    deps = [
+        ":hlo",
+        ":hlo_runner",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla:test_helpers",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:test",
+    ],
+)
+
 cc_library(
     name = "hlo_reachability",
     srcs = ["hlo_reachability.cc"],
index ecda528..240faeb 100644 (file)
@@ -35,6 +35,12 @@ class HloInstruction;
 // DfsHloVisitor with default action based on the HloInstruction being visited.
 // Users should not use this class directly, but use the type aliases
 // DfsHloVisitorWithDefault/ConstDfsHloVisitorWithDefault instead.
+//
+// Do *not* add an override to this class if the opcode is covered by
+// HandleElementwiseUnary/Binary. These opcode handlers dispatch to
+// HandleElementwiseUnary/Binary in DfsHloVisitorBase. Adding such a handler
+// here will break passes which rely on the HandleElementwiseUnary/Binary
+// handling these opcodes.
 template <typename HloInstructionPtr>
 class DfsHloVisitorWithDefaultBase
     : public DfsHloVisitorBase<HloInstructionPtr> {
@@ -70,12 +76,6 @@ class DfsHloVisitorWithDefaultBase
   Status HandleConcatenate(HloInstructionPtr concatenate) override {
     return DefaultAction(concatenate);
   }
-  Status HandleConvert(HloInstructionPtr convert) override {
-    return DefaultAction(convert);
-  }
-  Status HandleCopy(HloInstructionPtr copy) override {
-    return DefaultAction(copy);
-  }
   Status HandleSelect(HloInstructionPtr select) override {
     return DefaultAction(select);
   }
@@ -91,9 +91,6 @@ class DfsHloVisitorWithDefaultBase
   Status HandleCrossReplicaSum(HloInstructionPtr crs) override {
     return DefaultAction(crs);
   }
-  Status HandleCompare(HloInstructionPtr compare) override {
-    return DefaultAction(compare);
-  }
   Status HandleRng(HloInstructionPtr random) override {
     return DefaultAction(random);
   }
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc
new file mode 100644 (file)
index 0000000..825e143
--- /dev/null
@@ -0,0 +1,90 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_runner.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+class DfsHloVisitorWithDefaultTest : public HloTestBase {};
+
+TEST_F(DfsHloVisitorWithDefaultTest, DefaultElementwiseTest) {
+  // Verify that HandleElementwiseBinary and HandleElementwiseUnary are called
+  // on the appropriate HLO ops (elementwise binary/unary ops).
+
+  class ElementwiseTestVisitor : public DfsHloVisitorWithDefault {
+   public:
+    Status DefaultAction(HloInstruction* hlo) override {
+      // The HLO should be neither an elementwise unary nor binary op. These
+      // cases are handled in HandleElementwiseBinary/Unary.
+      TF_RET_CHECK(!(hlo->IsElementwise() && hlo->operand_count() == 2))
+          << hlo->ToString();
+      TF_RET_CHECK(!(hlo->IsElementwise() && hlo->operand_count() == 1))
+          << hlo->ToString();
+      return Status::OK();
+    }
+
+    Status HandleElementwiseBinary(HloInstruction* hlo) override {
+      // HLO should be elementwise binary.
+      TF_RET_CHECK(hlo->IsElementwise() && hlo->operand_count() == 2)
+          << hlo->ToString();
+      return Status::OK();
+    }
+    Status HandleElementwiseUnary(HloInstruction* hlo) override {
+      // HLO should be elementwise unary.
+      TF_RET_CHECK(hlo->IsElementwise() && hlo->operand_count() == 1)
+          << hlo->ToString();
+      return Status::OK();
+    }
+  };
+
+  // HLO module contains are arbitrary mix of elementwise and non-elementwise
+  // operations.
+  const string& hlo_string = R"(
+HloModule TestModule
+
+ENTRY TestComputation {
+  arg = f32[] parameter(0)
+  tuple = (f32[]) tuple(arg)
+  gte = f32[] get-tuple-element(tuple), index=0
+  abs = f32[] abs(arg)
+  add = f32[] add(arg, gte)
+  broadcast = f32[42] broadcast(add), dimensions={}
+  slice = f32[0] slice(broadcast), slice={[1:2]}
+  copy = f32[] copy(arg)
+  eq = pred[] equal-to(arg, gte)
+  neg = f32[] negate(arg)
+  ROOT convert = f64[] convert(f32[] arg)
+})";
+  std::unique_ptr<HloModule> module =
+      HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())
+          .ConsumeValueOrDie();
+  ElementwiseTestVisitor visitor;
+  TF_EXPECT_OK(module->entry_computation()->Accept(&visitor));
+}
+
+}  // namespace
+}  // namespace xla