],
)
+tf_cc_test(
+ name = "dfs_hlo_visitor_with_default_test",
+ srcs = ["dfs_hlo_visitor_with_default_test.cc"],
+ deps = [
+ ":hlo",
+ ":hlo_runner",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:test",
+ ],
+)
+
cc_library(
name = "hlo_reachability",
srcs = ["hlo_reachability.cc"],
// DfsHloVisitor with default action based on the HloInstruction being visited.
// Users should not use this class directly, but use the type aliases
// DfsHloVisitorWithDefault/ConstDfsHloVisitorWithDefault instead.
+//
+// Do *not* add an override to this class if the opcode is covered by
+// HandleElementwiseUnary/Binary. These opcode handlers dispatch to
+// HandleElementwiseUnary/Binary in DfsHloVisitorBase. Adding such a handler
+// here will break passes which rely on the HandleElementwiseUnary/Binary
+// handling these opcodes.
template <typename HloInstructionPtr>
class DfsHloVisitorWithDefaultBase
: public DfsHloVisitorBase<HloInstructionPtr> {
Status HandleConcatenate(HloInstructionPtr concatenate) override {
return DefaultAction(concatenate);
}
- Status HandleConvert(HloInstructionPtr convert) override {
- return DefaultAction(convert);
- }
- Status HandleCopy(HloInstructionPtr copy) override {
- return DefaultAction(copy);
- }
Status HandleSelect(HloInstructionPtr select) override {
return DefaultAction(select);
}
Status HandleCrossReplicaSum(HloInstructionPtr crs) override {
return DefaultAction(crs);
}
- Status HandleCompare(HloInstructionPtr compare) override {
- return DefaultAction(compare);
- }
Status HandleRng(HloInstructionPtr random) override {
return DefaultAction(random);
}
--- /dev/null
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_runner.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+class DfsHloVisitorWithDefaultTest : public HloTestBase {};
+
+TEST_F(DfsHloVisitorWithDefaultTest, DefaultElementwiseTest) {
+ // Verify that HandleElementwiseBinary and HandleElementwiseUnary are called
+ // on the appropriate HLO ops (elementwise binary/unary ops).
+
+ class ElementwiseTestVisitor : public DfsHloVisitorWithDefault {
+ public:
+ Status DefaultAction(HloInstruction* hlo) override {
+ // The HLO should be neither an elementwise unary nor binary op. These
+ // cases are handled in HandleElementwiseBinary/Unary.
+ TF_RET_CHECK(!(hlo->IsElementwise() && hlo->operand_count() == 2))
+ << hlo->ToString();
+ TF_RET_CHECK(!(hlo->IsElementwise() && hlo->operand_count() == 1))
+ << hlo->ToString();
+ return Status::OK();
+ }
+
+ Status HandleElementwiseBinary(HloInstruction* hlo) override {
+ // HLO should be elementwise binary.
+ TF_RET_CHECK(hlo->IsElementwise() && hlo->operand_count() == 2)
+ << hlo->ToString();
+ return Status::OK();
+ }
+ Status HandleElementwiseUnary(HloInstruction* hlo) override {
+ // HLO should be elementwise unary.
+ TF_RET_CHECK(hlo->IsElementwise() && hlo->operand_count() == 1)
+ << hlo->ToString();
+ return Status::OK();
+ }
+ };
+
+ // HLO module contains are arbitrary mix of elementwise and non-elementwise
+ // operations.
+ const string& hlo_string = R"(
+HloModule TestModule
+
+ENTRY TestComputation {
+ arg = f32[] parameter(0)
+ tuple = (f32[]) tuple(arg)
+ gte = f32[] get-tuple-element(tuple), index=0
+ abs = f32[] abs(arg)
+ add = f32[] add(arg, gte)
+ broadcast = f32[42] broadcast(add), dimensions={}
+ slice = f32[0] slice(broadcast), slice={[1:2]}
+ copy = f32[] copy(arg)
+ eq = pred[] equal-to(arg, gte)
+ neg = f32[] negate(arg)
+ ROOT convert = f64[] convert(f32[] arg)
+})";
+ std::unique_ptr<HloModule> module =
+ HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())
+ .ConsumeValueOrDie();
+ ElementwiseTestVisitor visitor;
+ TF_EXPECT_OK(module->entry_computation()->Accept(&visitor));
+}
+
+} // namespace
+} // namespace xla