[TF:XLA] Fix PotentiallyImplementedAsEigenConvolution to use the correct shape as...
author A. Unique TensorFlower <gardener@tensorflow.org>
Fri, 23 Mar 2018 22:12:35 +0000 (15:12 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Sun, 25 Mar 2018 11:14:47 +0000 (04:14 -0700)
A small bug was found in how PotentiallyImplementedAsEigenConvolution accessed the kernel shape of the convolution instruction: it read the shape of operand 0 (the input) instead of operand 1 (the kernel). This change fixes the bug and adds a new test case that exposes it.

PiperOrigin-RevId: 190282385

tensorflow/compiler/xla/service/cpu/BUILD
tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc [new file with mode: 0644]

index 093db02..0faa9e9 100644 (file)
@@ -670,6 +670,22 @@ cc_library(
     ],
 )
 
+tf_cc_test(
+    name = "ir_emission_utils_test",
+    srcs = ["ir_emission_utils_test.cc"],
+    deps = [
+        ":ir_emission_utils",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla:test_helpers",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_matchers",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/compiler/xla/tools/parser:hlo_parser",
+    ],
+)
+
 cc_library(
     name = "cpu_layout_assignment",
     srcs = ["cpu_layout_assignment.cc"],
index 788217a..f209a69 100644 (file)
@@ -34,14 +34,16 @@ bool PotentiallyImplementedAsEigenConvolution(
   //
   // To be sufficient, certain layout constraints need to be satisfied as well.
   const Shape& input_shape = convolution.operand(0)->shape();
-  const Shape& kernel_shape = convolution.operand(0)->shape();
+  const Shape& kernel_shape = convolution.operand(1)->shape();
   if (ShapeUtil::HasZeroElements(input_shape) ||
       ShapeUtil::HasZeroElements(kernel_shape)) {
     return false;
   }
+  // Make sure input and kernel has the same data type.
+  CHECK(
+      ShapeUtil::SameElementTypeIgnoringFpPrecision(input_shape, kernel_shape));
   // TODO(b/65408531): Explore using Eigen dot for complex64 type.
-  if (ShapeUtil::ElementIsComplex(input_shape) ||
-      ShapeUtil::ElementIsComplex(kernel_shape)) {
+  if (ShapeUtil::ElementIsComplex(input_shape)) {
     return false;
   }
   if (window_util::HasWindowReversal(convolution.window())) {
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc
new file mode 100644 (file)
index 0000000..215f48c
--- /dev/null
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
+
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+
+namespace xla {
+namespace {
+
+TEST(IrEmitterTest, ConvWithZeroSizedKernelNotImplementedAsEigen) {
+  const char* const hlo_string = R"(
+HloModule ModuleWithConv
+
+ENTRY Conv {
+  input = f32[32,50,28,28]{3,2,1,0} parameter(0)
+  kernel = f32[0,32,5,5]{3,2,1,0} parameter(1)
+  ROOT convolution = f32[64,50,24,24]{3,2,1,0} convolution(input, kernel),
+    window={size=5x5},
+    dim_labels=b01f_01io->b01f
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          tools::Parse(hlo_string));
+
+  HloComputation* entry_computation = module->entry_computation();
+
+  HloInstruction* conv_instr = entry_computation->root_instruction();
+  EXPECT_FALSE(cpu::PotentiallyImplementedAsEigenConvolution(*conv_instr));
+}
+
+}  // namespace
+}  // namespace xla