"//tensorflow/core:lib",
],
)
+
+cc_library(
+ name = "indexed_array_analysis",
+ srcs = ["indexed_array_analysis.cc"],
+ hdrs = ["indexed_array_analysis.h"],
+ deps = [
+ ":hlo",
+ ":hlo_pass",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:ptr_util",
+ ],
+)
+
+tf_cc_test(
+ name = "indexed_array_analysis_test",
+ srcs = ["indexed_array_analysis_test.cc"],
+ deps = [
+ ":hlo_matchers",
+ ":indexed_array_analysis",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/compiler/xla/tools/parser:hlo_parser",
+ "//tensorflow/core:test",
+ ],
+)
"//tensorflow/compiler/xla/service:hlo_scheduling",
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
"//tensorflow/compiler/xla/service:hlo_verifier",
+ "//tensorflow/compiler/xla/service:indexed_array_analysis",
"//tensorflow/compiler/xla/service:inliner",
"//tensorflow/compiler/xla/service:llvm_compiler",
"//tensorflow/compiler/xla/service:reduce_precision_insertion",
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
+#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
#include "tensorflow/compiler/xla/service/inliner.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
pass.AddPass<HloConstantFolding>();
pass.AddPass<ConditionalSimplifier>();
}
+ pipeline.AddPass<IndexedArrayAnalysisPrinterPass>();
pipeline.AddPass<TransposeFolding>(
[&target_machine_features](
const HloInstruction& dot,
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace xla {
+namespace gtl = ::tensorflow::gtl;
+
+namespace {
+using Analysis = IndexedArrayAnalysis;
+using UnknownArray = Analysis::UnknownArray;
+using ConstantArray = Analysis::ConstantArray;
+using ScalarIndexedArray = Analysis::ScalarIndexedArray;
+} // namespace
+
+// Pretty-prints the Array expression rooted at `root` in an
+// s-expression-like syntax, recursing into the source and indices of
+// scalar-indexed arrays (see the unit tests for example output).
+string IndexedArrayAnalysis::ToString(Array* root) {
+  switch (root->kind()) {
+    case Array::kUnknown: {
+      // Unanalyzable arrays print as a reference to the underlying HLO
+      // instruction, e.g. "%operand".
+      auto* unknown_tensor = root->as<UnknownArray>();
+      return tensorflow::strings::StrCat("%",
+                                         unknown_tensor->instruction().name());
+    }
+
+    case Array::kConstant: {
+      return tensorflow::strings::StrCat(
+          "(constant ", ShapeUtil::HumanString(root->shape()), ")");
+    }
+
+    case Array::kScalarIndexedConstant:
+    case Array::kScalarIndexed: {
+      // Both flavors share the layout of ScalarIndexedArray and print as
+      //   (<tag> <source> <indices> <source_dim>->[<output_dims>]).
+      auto* indexed_array = root->as<ScalarIndexedArray>();
+      string name = root->kind() == Array::kScalarIndexedConstant
+                        ? "scalar-indexed-const"
+                        : "scalar-indexed";
+      return tensorflow::strings::StrCat(
+          "(", name, " ", ToString(indexed_array->source()), " ",
+          ToString(indexed_array->indices()), " ", indexed_array->source_dim(),
+          "->[", tensorflow::str_util::Join(indexed_array->output_dims(), ","),
+          "])");
+    }
+  }
+}
+
+// Returns the (possibly cached) Array for `instr`.  On a cache miss this
+// populates `cache_` for `instr` and all of its transitive operands before
+// looking the result up again.
+Analysis::Array* IndexedArrayAnalysis::GetArrayFor(
+    const HloInstruction* instr) {
+  auto it = cache_.find(instr);
+  if (it != cache_.end()) {
+    return it->second;
+  }
+
+  // TraverseAndPopulateCache guarantees an entry for `instr` exists
+  // afterwards, so FindOrDie cannot fail here.
+  TraverseAndPopulateCache(instr);
+  return FindOrDie(cache_, instr);
+}
+
+// Iterative post-order DFS over the operand DAG rooted at `root`.  Every
+// instruction reachable from `root` has an entry in `cache_` by the time this
+// returns.
+void IndexedArrayAnalysis::TraverseAndPopulateCache(
+    const HloInstruction* root) {
+  // Depth first search over the DAG, invoking ComputeArrayFor in post order.
+  // The HLO instructions already in the cache are considered leaves.
+
+  gtl::InlinedVector<const HloInstruction*, 4> stack;
+
+  // kDiscovered: on the stack, operands not yet pushed.
+  // kVisited: operands have been pushed; the next pop computes the Array.
+  enum DfsState { kDiscovered, kVisited };
+  gtl::FlatMap<const HloInstruction*, DfsState> dfs_state_map;
+
+  stack.push_back(root);
+  InsertOrDie(&dfs_state_map, root, kDiscovered);
+
+  do {
+    const HloInstruction* instr = stack.back();
+    if (cache_.count(instr)) {
+      // Already computed (a leaf for the purposes of this traversal).
+      stack.pop_back();
+      continue;
+    }
+
+    switch (FindOrDie(dfs_state_map, instr)) {
+      case kDiscovered: {
+        // First visit: push the uncached operands so they get computed before
+        // `instr` itself.
+        for (const HloInstruction* operand : instr->operands()) {
+          if (!cache_.count(operand)) {
+            stack.push_back(operand);
+            // A shared operand may be discovered more than once, but it must
+            // never already be in the kVisited state at this point.
+            CHECK(!dfs_state_map.count(operand) ||
+                  dfs_state_map[operand] == kDiscovered);
+            dfs_state_map[operand] = kDiscovered;
+          }
+        }
+        dfs_state_map[instr] = kVisited;
+        break;
+      }
+
+      case kVisited:
+        // Second visit: all operands are now cached, so the Array for `instr`
+        // can be computed.
+        stack.pop_back();
+        InsertOrDie(&cache_, instr, ComputeArrayFor(instr));
+        break;
+    }
+  } while (!stack.empty());
+}
+
+// Computes the Array for `instr`, assuming Arrays for all of its operands are
+// already present in `cache_`.  Opcodes the analysis does not understand map
+// to UnknownArray.
+Analysis::Array* IndexedArrayAnalysis::ComputeArrayFor(
+    const HloInstruction* instr) {
+  Array* computed_array;
+  switch (instr->opcode()) {
+    default:
+      computed_array = nullptr;
+      break;
+    case HloOpcode::kConstant:
+      computed_array = ComputeArrayForConstant(instr->literal());
+      break;
+    case HloOpcode::kGather:
+      computed_array = ComputeArrayForGather(
+          instr->shape(), instr->gather_dimension_numbers(),
+          instr->gather_window_bounds(), FindOrDie(cache_, instr->operand(0)),
+          FindOrDie(cache_, instr->operand(1)));
+      break;
+  }
+
+  // nullptr from the per-opcode helpers means "could not analyze"; fall back
+  // to wrapping the instruction in an UnknownArray.
+  if (!computed_array) {
+    computed_array = Construct<UnknownArray>(instr);
+  }
+
+  return computed_array;
+}
+
+// Wraps an HLO constant literal in a ConstantArray.  The literal is not
+// copied, so it must outlive this analysis instance.
+Analysis::Array* IndexedArrayAnalysis::ComputeArrayForConstant(
+    const Literal& literal) {
+  return Construct<ConstantArray>(&literal);
+}
+
+// Rewrites Gather(Gather(A, X), Y) into Gather(A, Gather(X, Y)).  See the
+// declaration in the header for a worked one-dimensional example.
+ScalarIndexedArray* IndexedArrayAnalysis::FoldGatherOfGather(
+    ScalarIndexedArray* source, Array* indices, int64 source_dim,
+    tensorflow::gtl::ArraySlice<int64> output_dims, Shape shape) {
+  // We want to transform Gather(Gather(A, X), Y) => Gather(A, Gather(X, Y)).
+  // `source` is the inner Gather(A, X).
+
+  Array* a = source->source();
+  Array* x = source->indices();
+  Array* y = indices;
+
+  // This bit is slightly tricky, so we do a naive "simulation" of the two
+  // consecutive gather operations to infer what the composed gather should look
+  // like.
+
+  // Each dimension of the composed result either passes through from A
+  // (Ungathered) or was produced by one of the two gathers.
+  enum class IndexComponent { Ungathered, GatheredFirst, GatheredSecond };
+
+  std::vector<IndexComponent> simulated_index(a->shape().dimensions_size(),
+                                              IndexComponent::Ungathered);
+
+  // Simulate the first gather.
+  simulated_index.erase(simulated_index.begin() + source->source_dim());
+  for (int64 gather_dim : source->output_dims()) {
+    simulated_index.insert(simulated_index.begin() + gather_dim,
+                           IndexComponent::GatheredFirst);
+  }
+
+  // Simulate the second gather.
+  simulated_index.erase(simulated_index.begin() + source_dim);
+  for (int64 output_dim : output_dims) {
+    simulated_index.insert(simulated_index.begin() + output_dim,
+                           IndexComponent::GatheredSecond);
+  }
+
+  // The dimension of X that the composed index array Gather(X, Y) gathers
+  // from.  The caller guarantees `source_dim` is one of `source`'s output
+  // dims, which the CHECK below re-asserts.
+  int64 source_dim_for_index_array =
+      FindIndex(source->output_dims(), source_dim);
+  CHECK_NE(source_dim_for_index_array, source->output_dims().size());
+
+  // Output dims of Gather(X, Y), expressed as positions among the gathered
+  // (non-Ungathered) components of the simulated index.
+  std::vector<int64> output_dims_for_index_array;
+  int64 gathered_index_components_seen = 0;
+  for (IndexComponent simulation_dim : simulated_index) {
+    if (simulation_dim == IndexComponent::GatheredSecond) {
+      output_dims_for_index_array.push_back(gathered_index_components_seen);
+    }
+    if (simulation_dim != IndexComponent::Ungathered) {
+      gathered_index_components_seen++;
+    }
+  }
+
+  // Shape of Gather(X, Y) and output dims of the composed gather: one entry
+  // per gathered component of the simulated index.
+  std::vector<int64> dim_sizes_for_composed_index;
+  std::vector<int64> output_dims_for_new_gather;
+  for (int64 i = 0, e = simulated_index.size(); i < e; i++) {
+    if (simulated_index[i] != IndexComponent::Ungathered) {
+      dim_sizes_for_composed_index.push_back(shape.dimensions(i));
+      output_dims_for_new_gather.push_back(i);
+    }
+  }
+
+  Array* inner_indices = ConstructScalarIndexedArray(
+      x, y, source_dim_for_index_array, output_dims_for_index_array,
+      ShapeUtil::MakeShape(x->shape().element_type(),
+                           dim_sizes_for_composed_index));
+  return ConstructScalarIndexedArray(a, inner_indices, source->source_dim(),
+                                     output_dims_for_new_gather,
+                                     std::move(shape));
+}
+
+// Tries to map a gather HLO onto a ScalarIndexedArray.  Returns nullptr when
+// the gather is not of the restricted form this analysis understands: scalar
+// indices gathered along a single operand dimension.
+Analysis::Array* IndexedArrayAnalysis::ComputeArrayForGather(
+    const Shape& shape, const GatherDimensionNumbers& dim_numbers,
+    tensorflow::gtl::ArraySlice<int64> window_bounds, Array* source,
+    Array* indices) {
+  // Only handle gathers whose index tensor holds scalar indices, i.e. the
+  // index vector dimension is the degenerate trailing dimension.
+  if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) {
+    return nullptr;
+  }
+
+  CHECK_EQ(dim_numbers.gather_dims_to_operand_dims_size(), 1);
+  // The gathered-from operand dimension must be elided from the window.
+  // NOTE(review): `window_bounds` is accepted but never validated here --
+  // presumably callers only reach this with full-slice windows; confirm.
+  if (!c_binary_search(dim_numbers.elided_window_dims(),
+                       dim_numbers.gather_dims_to_operand_dims(0))) {
+    return nullptr;
+  }
+
+  int64 source_dim = dim_numbers.gather_dims_to_operand_dims(0);
+  // The output dims are the non-window dims of the gather's result shape.
+  std::vector<int64> output_dims;
+  for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) {
+    if (!c_binary_search(dim_numbers.output_window_dims(), i)) {
+      output_dims.push_back(i);
+    }
+  }
+
+  if (auto* indexed = dynamic_cast<ScalarIndexedArray*>(source)) {
+    // Gather of a gather: fold when this gather indexes a dimension that the
+    // inner gather produced.
+    auto it = c_find(indexed->output_dims(), source_dim);
+    if (it != indexed->output_dims().end()) {
+      return FoldGatherOfGather(indexed, indices, source_dim, output_dims,
+                                shape);
+    }
+  } else if (auto* constant = dynamic_cast<ConstantArray*>(source)) {
+    return Construct<ScalarIndexedConstantArray>(constant, indices, source_dim,
+                                                 output_dims, shape);
+  }
+
+  return Construct<ScalarIndexedArray>(source, indices, source_dim, output_dims,
+                                       shape);
+}
+
+// Identifier used for this pass in the HLO pass pipeline.
+tensorflow::StringPiece IndexedArrayAnalysisPrinterPass::name() const {
+  return "indexed-array-analysis-printer-pass";
+}
+
+// VLOG(2)s the analysis result for every instruction that maps to something
+// more interesting than an unknown or constant array.  Never mutates the
+// module, so it always returns false ("no change").
+StatusOr<bool> IndexedArrayAnalysisPrinterPass::Run(HloModule* module) {
+  // Output is gated on VLOG level 2; skip the analysis entirely otherwise.
+  if (!VLOG_IS_ON(2)) {
+    return false;
+  }
+
+  IndexedArrayAnalysis analysis;
+  for (auto* computation : module->MakeNonfusionComputations()) {
+    for (auto* instr : computation->instructions()) {
+      auto* t = analysis.GetArrayFor(instr);
+      // Unknown and plain-constant results are trivial; don't log them.
+      if (!dynamic_cast<UnknownArray*>(t) && !dynamic_cast<ConstantArray*>(t)) {
+        VLOG(2) << instr->ToString() << " -> " << analysis.ToString(t);
+      }
+    }
+  }
+
+  return false;
+}
+
+} // namespace xla
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_
+
+#include <type_traits>
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace xla {
+
+// IndexedArrayAnalysis decides if an HLO instruction can be rewritten as a
+// gather from another array. It does this by mapping HLO instructions to
+// instances of IndexedArrayAnalysis::Array, which can be inspected to discover
+// whether said HLO is equivalent to a gather.
+class IndexedArrayAnalysis {
+ public:
+  // IndexedArrayAnalysis maps each HLO instruction to an instance of an Array.
+  // Array is really just a sum type of the classes that inherit from it.  The
+  // meaning of each of the subtypes is documented on the subtype declaration.
+  //
+  // Array instances are immutable once created.
+  class Array {
+   public:
+    enum Kind { kUnknown, kConstant, kScalarIndexedConstant, kScalarIndexed };
+
+    virtual Kind kind() const = 0;
+    virtual const Shape& shape() const = 0;
+
+    // Does a checked downcast from `Array` to `T` which must be one of its
+    // subtypes.
+    template <typename T>
+    T* as() {
+      static_assert((std::is_base_of<Array, T>::value),
+                    "target type not derived from source type");
+      // We skip the CHECK and hence the dynamic_cast if RTTI is disabled.
+#if !defined(__GNUC__) || defined(__GXX_RTTI)
+      CHECK_NE(dynamic_cast<T*>(this), nullptr);
+#endif  // !defined(__GNUC__) || defined(__GXX_RTTI)
+
+      return static_cast<T*>(this);
+    }
+
+    virtual ~Array() = default;
+
+    Array& operator=(const Array& other) = delete;
+  };
+
+  // Represents an HLO instruction that was not analyzable by this
+  // IndexedArrayAnalysis.  Instances of UnknownArray just wrap an existing
+  // HloInstruction.
+  class UnknownArray : public Array {
+   public:
+    Kind kind() const override { return kUnknown; }
+    const Shape& shape() const override { return instruction().shape(); }
+    const HloInstruction& instruction() const { return instruction_; }
+
+   private:
+    explicit UnknownArray(const HloInstruction* instr) : instruction_(*instr) {}
+
+    const HloInstruction& instruction_;
+
+    friend class IndexedArrayAnalysis;
+  };
+
+  // Represents a constant value.  This constant value may be present in the
+  // HLO module being analyzed, or it could have been created on the fly by the
+  // analysis.  The wrapped literal is not owned.
+  class ConstantArray : public Array {
+   public:
+    Kind kind() const override { return kConstant; }
+    const Shape& shape() const override { return literal()->shape(); }
+    const Literal* literal() const { return literal_; }
+
+   private:
+    explicit ConstantArray(const Literal* literal) : literal_(literal) {}
+    const Literal* literal_;
+
+    friend class IndexedArrayAnalysis;
+  };
+
+  // ---------------------------------------------------------------------------
+  // Indexed Array Overview
+  // ---------------------------------------------------------------------------
+  //
+  // ScalarIndexedArray and ScalarIndexedConstantArray form the core of this
+  // analysis.  ScalarIndexedConstantArray is just a specialization of
+  // ScalarIndexedArray so we will only discuss ScalarIndexedArray in this
+  // overview.
+  //
+  // A ScalarIndexedArray represents an array that can be computed by indexing
+  // into a "source" array using an "indices" tensor.  A simple example is a
+  // gather operation gathering 12 rows out of a [100,100] matrix -- such an
+  // operation will be represented by an instance of a ScalarIndexedArray with
+  // the [100,100] matrix as the "source" array and the [12]-shaped indices
+  // array as the "indices" tensor.  The ScalarIndexedArray operation itself
+  // will be of shape [12,100] (assuming we were gathering with axis=0).
+  //
+  // Gather operations are not the only operation that maps to
+  // ScalarIndexedArray instances (if that were true there would be little
+  // point in having a separate analysis).  We can often infer
+  // ScalarIndexedArrays for other operations too.  For instance, consider:
+  //
+  //   %source = f32[100,100] constant
+  //   %indices = s32[12] ...
+  //   %gather = f32[12,100] ... gather from %source using %indices at axis 0
+  //   %dot = dot(%gather, other_constant) [canonical contracting dims]
+  //
+  // The dot operation itself is also a ScalarIndexedArray with source =
+  // dot(constant, other_constant) and indices = %indices.  A reshape of
+  // %gather to [12,5,20] too is a ScalarIndexedArray with source = an
+  // appropriately reshaped constant and indices = %indices.
+
+  // Represents the result of a gather operation.  This gather operation may
+  // explicitly be present in the HLO module being analyzed, or it could have
+  // been created on the fly by the analysis.
+  //
+  // An instance of ScalarIndexedArray represents an array whose I'th element
+  // can be mapped to the J'th element of the `source` array (where I and J are
+  // multidimensional indices) in this way:
+  //
+  //   I' = remove components at positions `output_dims` from I
+  //   G' = remove components not at positions `output_dims` from I
+  //   T  = indices[G']
+  //   J  = I' with T inserted at position `source_dim`
+  //
+  // For example, if source is of shape [11,13,17,19], indices is of shape
+  // [23,29], output_dims is [0,2] and source_dim is 2 then the output is of
+  // shape [23,11,29,19] and the output index [A,B,C,D,E] is mapped to the
+  // input index [B,D,indices[A,C],E].
+  class ScalarIndexedArray : public Array {
+   public:
+    Kind kind() const override { return kScalarIndexed; }
+    const Shape& shape() const override { return shape_; }
+
+    Array* source() const { return source_; }
+    Array* indices() const { return indices_; }
+    int64 source_dim() const { return source_dim_; }
+    tensorflow::gtl::ArraySlice<int64> output_dims() const {
+      return output_dims_;
+    }
+
+   private:
+    explicit ScalarIndexedArray(Array* source, Array* indices, int64 source_dim,
+                                std::vector<int64> output_dims, Shape shape)
+        : source_(source),
+          indices_(indices),
+          source_dim_(source_dim),
+          output_dims_(std::move(output_dims)),
+          shape_(std::move(shape)) {}
+
+    Array* source_;
+    Array* indices_;
+    int64 source_dim_;
+    std::vector<int64> output_dims_;
+    Shape shape_;
+
+    friend class IndexedArrayAnalysis;
+  };
+
+  // A ScalarIndexedConstantArray is just a ScalarIndexedArray constrained to
+  // have a ConstantArray instance as the source.  This is an ergonomic
+  // concession -- in theory it is possible to just keep ScalarIndexedArray and
+  // check source()->kind().
+  class ScalarIndexedConstantArray : public ScalarIndexedArray {
+   public:
+    Kind kind() const override { return kScalarIndexedConstant; }
+
+    const Literal& literal() const {
+      return *source()->as<ConstantArray>()->literal();
+    }
+
+   private:
+    explicit ScalarIndexedConstantArray(Array* source, Array* indices,
+                                        int64 source_dim,
+                                        std::vector<int64> output_dims,
+                                        Shape shape)
+        : ScalarIndexedArray(source, indices, source_dim,
+                             std::move(output_dims), std::move(shape)) {
+      CHECK(dynamic_cast<ConstantArray*>(source));
+    }
+
+    friend class IndexedArrayAnalysis;
+  };
+
+  // Returns an Array instance for `instr`.  The IndexedArrayAnalysis instance
+  // keeps ownership of the returned Array instance.
+  //
+  // Caching Behavior: IndexedArrayAnalysis has a cache mapping HLO
+  // instructions to IndexedArrayAnalysis::Array instances.  This entire cache
+  // becomes stale and may cause the analysis to return incorrect results if
+  // any transitive operand (stopping at the containing computation) is
+  // modified for any HLO instruction on which GetArrayFor has been invoked.
+  //
+  // NB!  By inspecting the implementation, you may be able to infer a stronger
+  // caching guarantee than what is mentioned above.  Nevertheless, what is
+  // stated above is the contract.
+  Array* GetArrayFor(const HloInstruction* instr);
+
+  // Pretty-prints the expression rooted at `root`.
+  string ToString(Array* root);
+
+ private:
+  // Helper function that ensures that every HLO instruction that is
+  // transitively used by `root` has an entry in `cache_`.
+  void TraverseAndPopulateCache(const HloInstruction* root);
+
+  // Creates an Array instance for `instr` under the assumption that all
+  // operands of `instr` are present in `cache_`.
+  Array* ComputeArrayFor(const HloInstruction* instr);
+
+  Array* ComputeArrayForConstant(const Literal& literal);
+
+  Array* ComputeArrayForGather(const Shape& shape,
+                               const GatherDimensionNumbers& dim_numbers,
+                               tensorflow::gtl::ArraySlice<int64> window_bounds,
+                               Array* source, Array* indices);
+
+  // This tries to fold a ScalarIndexedArray which has another
+  // ScalarIndexedArray as a source into a ScalarIndexedArray that instead has
+  // a ScalarIndexedArray as indices.  If `source` happened to be a
+  // ScalarIndexedConstantArray this can result in an expression that is more
+  // canonical.
+  //
+  // As an example, consider a gather operation, G0, gathering 7 elements from
+  // an array "Arr" of shape [100] resulting in an array of shape [7], and a
+  // second gather operation, G1, which gathers 3 elements out of the result of
+  // G0 resulting in an array of shape [3].  Let the indices used by G0 be I0
+  // (of shape [7]) and the indices used by G1 be I1 (of shape [3]).  We can
+  // instead rewrite G1 to gather directly from "Arr" with the three indices
+  // from I0 as per I1.  In other words, we can rewrite:
+  //
+  //    G0 = [Arr[i] for i in I0]
+  //    G1 = [G0[i]  for i in I1]
+  //
+  // into
+  //
+  //    I2 = [I0[i]  for i in I1]
+  //    G1 = [Arr[i] for i in I2]
+  ScalarIndexedArray* FoldGatherOfGather(
+      ScalarIndexedArray* source, Array* indices, int64 source_dim,
+      tensorflow::gtl::ArraySlice<int64> output_dims, Shape shape);
+
+  // Allocates a T (a subtype of Array) and transfers its ownership to
+  // `owned_tensors_`, returning a raw non-owning pointer.
+  template <typename T, typename... Args>
+  T* Construct(Args&&... args) {
+    T* new_tensor = new T(std::forward<Args>(args)...);
+    owned_tensors_.push_back(std::unique_ptr<T>(new_tensor));
+    return new_tensor;
+  }
+
+  // Constructs either a ScalarIndexedArray or, when the source is a constant,
+  // the more specific ScalarIndexedConstantArray.
+  ScalarIndexedArray* ConstructScalarIndexedArray(
+      Array* source, Array* indices, int64 source_dim,
+      std::vector<int64> output_dims, Shape shape) {
+    if (source->kind() == Array::kConstant) {
+      return Construct<ScalarIndexedConstantArray>(source, indices, source_dim,
+                                                   std::move(output_dims),
+                                                   std::move(shape));
+    } else {
+      return Construct<ScalarIndexedArray>(source, indices, source_dim,
+                                           std::move(output_dims),
+                                           std::move(shape));
+    }
+  }
+
+  // Owns every Array instance handed out by this analysis.
+  std::vector<std::unique_ptr<Array>> owned_tensors_;
+
+  // Owns literals created on the fly by the analysis.
+  // NOTE(review): nothing in this file populates this yet -- presumably
+  // reserved for upcoming folds that synthesize constants; confirm.
+  std::vector<std::unique_ptr<Literal>> owned_literals_;
+
+  // Memoizes the Array computed for each HLO instruction seen so far.
+  tensorflow::gtl::FlatMap<const HloInstruction*, Array*> cache_;
+};
+
+// A pass that prints all non-trivial results returned by IndexedArrayAnalysis.
+// This pass is a no-op if !VLOG_IS_ON(2) so it should be fine to
+// unconditionally add to the regular HLO pass pipeline.
+class IndexedArrayAnalysisPrinterPass : public HloPassInterface {
+ public:
+  tensorflow::StringPiece name() const override;
+  // Never mutates `module`; always returns false.
+  StatusOr<bool> Run(HloModule* module) override;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+
+namespace xla {
+namespace {
+class IndexedArrayAnalysisTest : public HloVerifiedTestBase {
+ protected:
+  // Parses and verifies `hlo_text`, runs IndexedArrayAnalysis on the entry
+  // computation's root instruction, and asserts that the pretty-printed
+  // analysis result equals `root_expression`.
+  void AssertArrayForRootExpressionIs(const string& hlo_text,
+                                      const string& root_expression) {
+    IndexedArrayAnalysis indexed_tensor_analysis;
+    ParseAndVerifyModule(hlo_text);
+
+    string result =
+        indexed_tensor_analysis.ToString(indexed_tensor_analysis.GetArrayFor(
+            module().entry_computation()->root_instruction()));
+    LOG(INFO) << result;
+    ASSERT_EQ(result, root_expression);
+  }
+};
+
+// A gather of whole rows from a parameter maps to a plain scalar-indexed
+// array over that parameter and its indices.
+TEST_F(IndexedArrayAnalysisTest, SimpleOneToOneGather) {
+  string hlo_text = R"(
+HloModule SimpleGather
+
+ENTRY main {
+  operand = s32[3,3] parameter(0)
+  indices = s32[5] parameter(1)
+  ROOT gather = s32[5,3] gather(operand, indices),
+      output_window_dims={1},
+      elided_window_dims={0},
+      gather_dims_to_operand_dims={0},
+      index_vector_dim=1,
+      window_bounds={1,3}
+}
+)";
+
+  AssertArrayForRootExpressionIs(hlo_text,
+                                 "(scalar-indexed %operand %indices 0->[0])");
+}
+
+// The same gather over a constant operand maps to the more specific
+// scalar-indexed-const form.
+TEST_F(IndexedArrayAnalysisTest, SimpleOneToOneConstantGather) {
+  string hlo_text = R"(
+HloModule SimpleGather
+
+ENTRY main {
+  operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}})
+  indices = s32[5] parameter(0)
+  ROOT gather = s32[5,3] gather(operand, indices),
+      output_window_dims={1},
+      elided_window_dims={0},
+      gather_dims_to_operand_dims={0},
+      index_vector_dim=1,
+      window_bounds={1,3}
+}
+)";
+
+  AssertArrayForRootExpressionIs(
+      hlo_text, "(scalar-indexed-const (constant s32[3,3]) %indices 0->[0])");
+}
+
+// Gather-of-gather with 1D indices on both levels folds into a single gather
+// from the constant, with a composed (scalar-indexed) index array.
+TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOne) {
+  string hlo_text = R"(
+HloModule SimpleGather
+
+ENTRY main {
+  operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}})
+  indices_a = s32[5] parameter(0)
+  indices_b = s32[2] parameter(1)
+  gather_a = s32[5,3] gather(operand, indices_a),
+      output_window_dims={1},
+      elided_window_dims={0},
+      gather_dims_to_operand_dims={0},
+      index_vector_dim=1,
+      window_bounds={1,3}
+  ROOT gather_b = s32[2,3] gather(gather_a, indices_b),
+      output_window_dims={1},
+      elided_window_dims={0},
+      gather_dims_to_operand_dims={0},
+      index_vector_dim=1,
+      window_bounds={1,3}
+}
+)";
+
+  AssertArrayForRootExpressionIs(
+      hlo_text,
+      "(scalar-indexed-const (constant s32[3,3]) (scalar-indexed %indices_a "
+      "%indices_b 0->[0]) 0->[0])");
+}
+
+// Inner gather uses 2D indices (many-to-one), outer gather uses 1D indices;
+// the fold composes the index arrays and remaps the output dims.
+TEST_F(IndexedArrayAnalysisTest, GatherOfGather_ManyToOneWithOneToOne) {
+  string hlo_text = R"(
+HloModule SimpleGather
+
+ENTRY main {
+  operand = s32[3,2] parameter(0)
+  indices_a = s32[5,7] parameter(1)
+  indices_b = s32[2] parameter(2)
+  gather_a = s32[5,3,7] gather(operand, indices_a),
+      output_window_dims={1},
+      elided_window_dims={1},
+      gather_dims_to_operand_dims={1},
+      index_vector_dim=2,
+      window_bounds={3,1}
+  ROOT gather_b = s32[5,3,2] gather(gather_a, indices_b),
+      output_window_dims={0,1},
+      elided_window_dims={2},
+      gather_dims_to_operand_dims={2},
+      index_vector_dim=1,
+      window_bounds={5,3,1}
+}
+)";
+
+  AssertArrayForRootExpressionIs(hlo_text,
+                                 "(scalar-indexed %operand (scalar-indexed "
+                                 "%indices_a %indices_b 1->[1]) 1->[0,2])");
+}
+
+// Inner gather uses 1D indices, outer gather uses 2D indices (many-to-one);
+// the composed index array carries both outer output dims.
+TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOneWithManyToOne) {
+  string hlo_text = R"(
+HloModule SimpleGather
+
+ENTRY main {
+  operand = s32[3,6] parameter(0)
+  indices_a = s32[2] parameter(1)
+  indices_b = s32[5,7] parameter(2)
+  gather_a = s32[2,6] gather(operand, indices_a),
+      output_window_dims={1},
+      elided_window_dims={0},
+      gather_dims_to_operand_dims={0},
+      index_vector_dim=1,
+      window_bounds={1,6}
+  ROOT gather_b = s32[5,6,7] gather(gather_a, indices_b),
+      output_window_dims={1},
+      elided_window_dims={0},
+      gather_dims_to_operand_dims={0},
+      index_vector_dim=2,
+      window_bounds={1,6}
+}
+)";
+
+  AssertArrayForRootExpressionIs(hlo_text,
+                                 "(scalar-indexed %operand (scalar-indexed "
+                                 "%indices_a %indices_b 0->[0,1]) 0->[0,2])");
+}
+
+// Both gathers use 2D (many-to-one) indices; the fold still composes them
+// into a single gather with multi-dimensional output dims.
+TEST_F(IndexedArrayAnalysisTest, GatherOfGather_ManyToOneWithManyToOne) {
+  string hlo_text = R"(
+HloModule SimpleGather
+
+ENTRY main {
+  operand = s32[3,2] parameter(0)
+  indices_a = s32[5,7] parameter(1)
+  indices_b = s32[4,8] parameter(2)
+  gather_a = s32[5,3,7] gather(operand, indices_a),
+      output_window_dims={1},
+      elided_window_dims={1},
+      gather_dims_to_operand_dims={1},
+      index_vector_dim=2,
+      window_bounds={3,1}
+  ROOT gather_b = s32[4,5,3,8] gather(gather_a, indices_b),
+      output_window_dims={1,2},
+      elided_window_dims={2},
+      gather_dims_to_operand_dims={2},
+      index_vector_dim=2,
+      window_bounds={5,3,1}
+}
+)";
+
+  AssertArrayForRootExpressionIs(
+      hlo_text,
+      "(scalar-indexed %operand (scalar-indexed %indices_a %indices_b "
+      "1->[0,2]) 1->[0,1,3])");
+}
+} // namespace
+} // namespace xla