Make the elemental ir emitter for dot operations respect contraction dims
author Sanjoy Das <sanjoy@google.com>
Fri, 11 May 2018 21:05:38 +0000 (14:05 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Fri, 11 May 2018 21:07:50 +0000 (14:07 -0700)
PiperOrigin-RevId: 196305803

tensorflow/compiler/xla/service/BUILD
tensorflow/compiler/xla/service/elemental_ir_emitter.cc
tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc [new file with mode: 0644]

index f6af816..f1e57f3 100644 (file)
@@ -12,6 +12,7 @@ package_group(
     ],
 )
 
+load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
 load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
 load("//tensorflow:tensorflow.bzl", "tf_cc_test")
 load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
@@ -2371,6 +2372,24 @@ cc_library(
     ],
 )
 
+xla_test(  # Test target for elemental_ir_emitter_test.cc; lists the cpu and gpu backends.
+    name = "elemental_ir_emitter_test",
+    srcs = ["elemental_ir_emitter_test.cc"],
+    backends = [
+        "cpu",
+        "gpu",
+    ],
+    deps = [
+        "//tensorflow/compiler/xla:execution_options_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla/tests:client_library_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/compiler/xla/tools/parser:hlo_parser",
+    ],
+)
+
 cc_library(
     name = "hlo_module_config",
     srcs = ["hlo_module_config.cc"],
index f2ad6ea..0a400e9 100644 (file)
@@ -1863,8 +1863,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
     const llvm_ir::IrArray::Index& dot_result_index) const {
   auto lhs_generator = operand_to_generator.at(hlo->operand(0));
   auto rhs_generator = operand_to_generator.at(hlo->operand(1));
-  int64 contracted_dim_size = hlo->operand(0)->shape().dimensions(
-      hlo->operand(0)->shape().dimensions_size() - 1);
+
+  const DotDimensionNumbers& dim_numbers = hlo->dot_dimension_numbers();
+  int64 lhs_contracting_dim = dim_numbers.lhs_contracting_dimensions(0);
+  int64 rhs_contracting_dim = dim_numbers.rhs_contracting_dimensions(0);
+
+  int64 contracted_dim_size =
+      hlo->operand(0)->shape().dimensions(lhs_contracting_dim);
   int64 lhs_dims = hlo->operand(0)->shape().dimensions_size();
   int64 rhs_dims = hlo->operand(1)->shape().dimensions_size();
 
@@ -1895,13 +1900,12 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
   for (int64 i = 0; i < lhs_dims - 1; i++) {
     lhs_index.push_back(dot_result_index[i]);
   }
-  lhs_index.push_back(inner_loop->GetIndVarValue());
+  lhs_index.InsertAt(lhs_contracting_dim, inner_loop->GetIndVarValue());
 
-  for (int64 i = 0; i < rhs_dims - 2; i++) {
+  for (int64 i = 0; i < rhs_dims - 1; i++) {
     rhs_index.push_back(dot_result_index[lhs_dims - 1 + i]);
   }
-  rhs_index.push_back(inner_loop->GetIndVarValue());
-  rhs_index.push_back(dot_result_index.back());
+  rhs_index.InsertAt(rhs_contracting_dim, inner_loop->GetIndVarValue());
 
   llvm::Value* current_accumulator =
       ir_builder_->CreateLoad(accumulator_alloca);
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
new file mode 100644 (file)
index 0000000..b43dc0c
--- /dev/null
@@ -0,0 +1,65 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/execution_options_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+
+namespace xla {
+namespace {
+
+using tensorflow::gtl::nullopt;
+
+class ElementalIrEmitterExecutionTest : public HloTestBase {  // Fixture: parses HLO text and checks it via RunAndCompareNoHloPasses.
+ protected:
+  void RunTest(const string& hlo_text,  // Parses `hlo_text` into a module and checks it with `args` as arguments.
+               tensorflow::gtl::ArraySlice<Literal*> args) {
+    HloModuleConfig config;
+    config.set_debug_options(GetDebugOptionsForTest());  // Parse with the test's debug options.
+    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,  // Parsing is expected to succeed.
+                            tools::Parse(hlo_text, config));
+    EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), args, nullopt));  // nullopt: presumably "no error tolerance" -- TODO confirm.
+  }
+};
+
+XLA_TEST_F(ElementalIrEmitterExecutionTest, DotFusion) {  // Dot contracting dimension 0 of both operands, inside a kLoop fusion.
+  const string hlo_text = R"(
+HloModule FusedDot
+
+fused_computation {
+  arg0 = s32[1,2,1]{2,1,0} parameter(0)
+  reshape.lhs = s32[2,1]{1,0} reshape(arg0)
+  arg1 = s32[1,2,1]{2,1,0} parameter(1)
+  reshape.rhs = s32[2,1]{1,0} reshape(arg1)
+  ROOT dot = s32[1,1]{1,0} dot(reshape.lhs, reshape.rhs), lhs_contracting_dims={0}, rhs_contracting_dims={0}
+}
+
+ENTRY main {
+  entry_arg0 = s32[1,2,1]{2,1,0} parameter(0)
+  entry_arg1 = s32[1,2,1]{2,1,0} parameter(1)
+  ROOT fusion = s32[1,1]{1,0} fusion(entry_arg0, entry_arg1), kind=kLoop, calls=fused_computation
+}
+)";
+
+  std::unique_ptr<Literal> lhs = Literal::CreateR3<int32>({{{1}, {2}}});  // Rank-3 literal for entry_arg0 (declared s32[1,2,1]).
+  std::unique_ptr<Literal> rhs = Literal::CreateR3<int32>({{{3}, {4}}});  // Rank-3 literal for entry_arg1 (declared s32[1,2,1]).
+  RunTest(hlo_text, {lhs.get(), rhs.get()});  // Non-owning pointers; the literals outlive the call.
+}
+}  // namespace
+}  // namespace xla