],
)
+load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
],
)
+xla_test(  # Test target built from elemental_ir_emitter_test.cc.
+ name = "elemental_ir_emitter_test",
+ srcs = ["elemental_ir_emitter_test.cc"],
+ backends = [  # Backends requested for this test: "cpu" and "gpu".
+ "cpu",
+ "gpu",
+ ],
+ deps = [  # Mirrors the headers the test source includes, plus a test main.
+ "//tensorflow/compiler/xla:execution_options_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/compiler/xla/tools/parser:hlo_parser",
+ ],
+)
+
cc_library(
name = "hlo_module_config",
srcs = ["hlo_module_config.cc"],
const llvm_ir::IrArray::Index& dot_result_index) const {
auto lhs_generator = operand_to_generator.at(hlo->operand(0));
auto rhs_generator = operand_to_generator.at(hlo->operand(1));
- int64 contracted_dim_size = hlo->operand(0)->shape().dimensions(
- hlo->operand(0)->shape().dimensions_size() - 1);
+
+ const DotDimensionNumbers& dim_numbers = hlo->dot_dimension_numbers();
+ int64 lhs_contracting_dim = dim_numbers.lhs_contracting_dimensions(0);
+ int64 rhs_contracting_dim = dim_numbers.rhs_contracting_dimensions(0);
+
+ int64 contracted_dim_size =
+ hlo->operand(0)->shape().dimensions(lhs_contracting_dim);
int64 lhs_dims = hlo->operand(0)->shape().dimensions_size();
int64 rhs_dims = hlo->operand(1)->shape().dimensions_size();
for (int64 i = 0; i < lhs_dims - 1; i++) {
lhs_index.push_back(dot_result_index[i]);
}
- lhs_index.push_back(inner_loop->GetIndVarValue());
+ lhs_index.InsertAt(lhs_contracting_dim, inner_loop->GetIndVarValue());
- for (int64 i = 0; i < rhs_dims - 2; i++) {
+ for (int64 i = 0; i < rhs_dims - 1; i++) {
rhs_index.push_back(dot_result_index[lhs_dims - 1 + i]);
}
- rhs_index.push_back(inner_loop->GetIndVarValue());
- rhs_index.push_back(dot_result_index.back());
+ rhs_index.InsertAt(rhs_contracting_dim, inner_loop->GetIndVarValue());
llvm::Value* current_accumulator =
ir_builder_->CreateLoad(accumulator_alloca);
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/execution_options_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+
+namespace xla {
+namespace {
+
+using tensorflow::gtl::nullopt;
+
+class ElementalIrEmitterExecutionTest : public HloTestBase {  // Fixture for tests that execute HLO supplied as text.
+ protected:
+ void RunTest(const string& hlo_text,  // Parses hlo_text into an HloModule, then runs it with args as inputs.
+ tensorflow::gtl::ArraySlice<Literal*> args) {
+ HloModuleConfig config;
+ config.set_debug_options(GetDebugOptionsForTest());  // The module is parsed with the test's debug options.
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,  // Presumably aborts this test if parsing fails -- confirm macro semantics.
+ tools::Parse(hlo_text, config));
+ EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), args, nullopt));  // NOTE(review): nullopt presumably means "no error tolerance" -- confirm in HloTestBase.
+ }
+};
+
+XLA_TEST_F(ElementalIrEmitterExecutionTest, DotFusion) {  // Runs a dot nested in a kLoop fusion; both operands contract dimension 0.
+ const string hlo_text = R"(
+HloModule FusedDot
+
+fused_computation {
+ arg0 = s32[1,2,1]{2,1,0} parameter(0)
+ reshape.lhs = s32[2,1]{1,0} reshape(arg0)
+ arg1 = s32[1,2,1]{2,1,0} parameter(1)
+ reshape.rhs = s32[2,1]{1,0} reshape(arg1)
+ ROOT dot = s32[1,1]{1,0} dot(reshape.lhs, reshape.rhs), lhs_contracting_dims={0}, rhs_contracting_dims={0}
+}
+
+ENTRY main {
+ entry_arg0 = s32[1,2,1]{2,1,0} parameter(0)
+ entry_arg1 = s32[1,2,1]{2,1,0} parameter(1)
+ ROOT fusion = s32[1,1]{1,0} fusion(entry_arg0, entry_arg1), kind=kLoop, calls=fused_computation
+}
+)";
+
+ std::unique_ptr<Literal> lhs = Literal::CreateR3<int32>({{{1}, {2}}});  // s32[1,2,1] value for entry parameter 0.
+ std::unique_ptr<Literal> rhs = Literal::CreateR3<int32>({{{3}, {4}}});  // s32[1,2,1] value for entry parameter 1.
+ RunTest(hlo_text, {lhs.get(), rhs.get()});  // lhs/rhs keep ownership; only raw pointers are passed.
+}
+} // namespace
+} // namespace xla