#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
using tensorflow::str_util::Join;
} // namespace
-string IndexedArrayAnalysis::ToString(Array* root) {
+// TODO(sanjoy): Make this pass StatusOr safe.
+
+string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) {
switch (root->kind()) {
case Array::kUnknown: {
auto* unknown_tensor = root->as<UnknownArray>();
}
case Array::kConstant: {
+ if (print_constants) {
+ string contents = root->as<ConstantArray>()->literal()->ToString();
+ return tensorflow::strings::StrCat(
+ "(constant ", ShapeUtil::HumanString(root->shape()), " ", contents,
+ ")");
+ }
return tensorflow::strings::StrCat(
"(constant ", ShapeUtil::HumanString(root->shape()), ")");
}
? "scalar-indexed-const"
: "scalar-indexed";
return tensorflow::strings::StrCat(
- "(", name, " ", ToString(indexed_array->source()), " ",
- ToString(indexed_array->indices()), " ", indexed_array->source_dim(),
- "->[", Join(indexed_array->output_dims(), ","), "])");
+ "(", name, " ", ToString(indexed_array->source(), print_constants),
+ " ", ToString(indexed_array->indices(), print_constants), " ",
+ indexed_array->source_dim(), "->[",
+ Join(indexed_array->output_dims(), ","), "])");
}
}
}
Analysis::Array* IndexedArrayAnalysis::ComputeArrayFor(
const HloInstruction* instr) {
Array* computed_array;
- switch (instr->opcode()) {
- default:
- computed_array = nullptr;
- break;
- case HloOpcode::kConstant:
- computed_array = ComputeArrayForConstant(instr->literal());
- break;
- case HloOpcode::kGather:
- computed_array = ComputeArrayForGather(
- instr->shape(), instr->gather_dimension_numbers(),
- instr->gather_window_bounds(), FindOrDie(cache_, instr->operand(0)),
- FindOrDie(cache_, instr->operand(1)));
- break;
- case HloOpcode::kReshape:
- computed_array = ComputeArrayForReshape(
- instr->shape(), FindOrDie(cache_, instr->operand(0)));
- break;
+ if (instr->IsElementwise() && instr->operand_count() == 1) {
+ computed_array = ComputeArrayForElementwiseUnaryOp(
+ instr, FindOrDie(cache_, instr->operand(0)));
+ } else if (instr->IsElementwise() && instr->operand_count() == 2) {
+ computed_array = ComputeArrayForElementwiseBinaryOp(
+ instr, FindOrDie(cache_, instr->operand(0)),
+ FindOrDie(cache_, instr->operand(1)));
+ } else if (instr->opcode() == HloOpcode::kConstant) {
+ computed_array = ComputeArrayForConstant(instr->literal());
+ } else if (instr->opcode() == HloOpcode::kGather) {
+ computed_array = ComputeArrayForGather(
+ instr->shape(), instr->gather_dimension_numbers(),
+ instr->gather_window_bounds(), FindOrDie(cache_, instr->operand(0)),
+ FindOrDie(cache_, instr->operand(1)));
+ } else if (instr->opcode() == HloOpcode::kReshape) {
+ computed_array = ComputeArrayForReshape(
+ instr->shape(), FindOrDie(cache_, instr->operand(0)));
+ } else {
+ computed_array = nullptr;
}
if (!computed_array) {
IndexComponent::Ungathered);
// Simulate the first gather.
- simulated_index.erase(simulated_index.begin() + source->source_dim());
+ EraseAt(&simulated_index, source->source_dim());
for (int64 gather_dim : source->output_dims()) {
simulated_index.insert(simulated_index.begin() + gather_dim,
IndexComponent::GatheredFirst);
}
// Simulate the second gather.
- simulated_index.erase(simulated_index.begin() + source_dim);
+ EraseAt(&simulated_index, source_dim);
for (int64 output_dim : output_dims) {
simulated_index.insert(simulated_index.begin() + output_dim,
IndexComponent::GatheredSecond);
int64 output_dim = scalar_indexed->output_dims()[i];
int64 output_dim_after_reshape = MapPassthroughOperandDimToResultDim(
reshape_passthrough_dims, output_dim);
- new_scalar_indexed_source_shape.erase(
- new_scalar_indexed_source_shape.begin() + output_dim_after_reshape);
+ EraseAt(&new_scalar_indexed_source_shape, output_dim_after_reshape);
}
// After this, we need to add in the dimension that will be the source
output_dims_for_new_scalar_indexed_node, shape);
}
+// Attempts to rewrite
+//   BinaryOp(Broadcast(Const0), ScalarIndexed(Const1, Indices))
+// as
+//   ScalarIndexed(BinaryOp(Broadcast'(Const0), Const1), Indices)
+//
+// `instr` is the elementwise binary instruction; `lhs` and `rhs` are the
+// arrays already computed for its two operands.  Returns nullptr when the
+// pattern does not match.
+//
+// The rewrite is legal when every output dimension of the scalar-indexed node
+// is a broadcasted dimension of the broadcast node.  Informally, this means
+// Broadcast(Const0)[IDX] depends only on the components of IDX that are not
+// output dims of the scalar-indexed node, so for each assignment to those
+// remaining components the broadcast operand is a "constant" that can be
+// pushed into the source of the scalar-indexed node.
+//
+// NOTE: literal broadcasting and evaluation results are unwrapped with
+// ConsumeValueOrDie, so failures there are not propagated to the caller.
+Analysis::Array* IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(
+    const HloInstruction* instr, Array* lhs, Array* rhs) {
+  auto* indexed_lhs = dynamic_cast<ScalarIndexedConstantArray*>(lhs);
+  auto* indexed_rhs = dynamic_cast<ScalarIndexedConstantArray*>(rhs);
+
+  // Exactly one operand must be scalar-indexed; the other one is the
+  // candidate broadcast of a constant.
+  if ((indexed_lhs != nullptr) == (indexed_rhs != nullptr)) {
+    return nullptr;
+  }
+  const bool lhs_is_indexed = indexed_lhs != nullptr;
+  ScalarIndexedConstantArray* scalar_indexed_const =
+      lhs_is_indexed ? indexed_lhs : indexed_rhs;
+
+  auto* other_operand = dynamic_cast<UnknownArray*>(lhs_is_indexed ? rhs : lhs);
+  if (other_operand == nullptr) {
+    return nullptr;
+  }
+
+  const HloInstruction* broadcast_instr = &other_operand->instruction();
+  if (broadcast_instr->opcode() != HloOpcode::kBroadcast) {
+    return nullptr;
+  }
+
+  const HloInstruction* broadcast_const_operand = broadcast_instr->operand(0);
+  if (broadcast_const_operand->opcode() != HloOpcode::kConstant) {
+    return nullptr;
+  }
+
+  ArraySlice<int64> broadcast_dims = broadcast_instr->dimensions();
+  ArraySlice<int64> output_dims = scalar_indexed_const->output_dims();
+
+  // A dimension listed in broadcast_dims is fed by the broadcast operand, so it
+  // is not a "broadcasted" dimension.  Bail out if any output dim of the
+  // scalar-indexed node is such a dimension.
+  for (int64 output_dim : output_dims) {
+    if (c_find(broadcast_dims, output_dim) != broadcast_dims.end()) {
+      return nullptr;
+    }
+  }
+
+  // To find the broadcast dimensions for the (constant) source of the
+  // scalar-indexed node, "simulate" the index transformation performed by the
+  // existing broadcast.
+  enum class IndexComponent { Broadcasted, NotBroadcasted };
+  std::vector<IndexComponent> simulated_index(
+      broadcast_instr->shape().dimensions_size(), IndexComponent::Broadcasted);
+  for (int64 fed_dim : broadcast_dims) {
+    simulated_index[fed_dim] = IndexComponent::NotBroadcasted;
+  }
+
+  // The scalar-indexed node "removes" the source dim and "inserts" the output
+  // dims.  Undo that here: drop the output dims (last to first, so earlier
+  // positions stay valid) and re-insert the source dim.
+  for (int64 i = output_dims.size(); i-- > 0;) {
+    CHECK(simulated_index[output_dims[i]] == IndexComponent::Broadcasted);
+    EraseAt(&simulated_index, output_dims[i]);
+  }
+  InsertAt(&simulated_index, scalar_indexed_const->source_dim(),
+           IndexComponent::Broadcasted);
+
+  // Translate simulated_index into the broadcast dimensions used by the inner
+  // BinaryOp(Broadcast'(Const0), Const1).
+  std::vector<int64> new_inner_broadcast_dims;
+  const int64 simulated_rank = simulated_index.size();
+  for (int64 dim = 0; dim < simulated_rank; ++dim) {
+    if (simulated_index[dim] == IndexComponent::NotBroadcasted) {
+      new_inner_broadcast_dims.push_back(dim);
+    }
+  }
+
+  // inner_broadcast_result is the Broadcast'(Const0) part of
+  // BinaryOp(Broadcast'(Const0), Const1).
+  std::unique_ptr<Literal> inner_broadcast_result =
+      broadcast_const_operand->literal()
+          .Broadcast(scalar_indexed_const->source()->shape(),
+                     new_inner_broadcast_dims)
+          .ConsumeValueOrDie();
+
+  // Keep the original operand order when evaluating the inner binary op.
+  const Literal& indexed_literal = scalar_indexed_const->literal();
+  const Literal& broadcast_literal = *inner_broadcast_result;
+  const Literal& eval_lhs =
+      lhs_is_indexed ? indexed_literal : broadcast_literal;
+  const Literal& eval_rhs =
+      lhs_is_indexed ? broadcast_literal : indexed_literal;
+
+  // literal_for_new_source is BinaryOp(Broadcast'(Const0), Const1).
+  const Literal* literal_for_new_source = TakeOwnership(
+      HloEvaluator{}
+          .EvaluateElementwiseBinaryOp(instr->opcode(), eval_lhs, eval_rhs)
+          .ConsumeValueOrDie());
+
+  ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
+  return Construct<ScalarIndexedConstantArray>(
+      new_source, scalar_indexed_const->indices(),
+      scalar_indexed_const->source_dim(),
+      std::vector<int64>(output_dims.begin(), output_dims.end()),
+      scalar_indexed_const->shape());
+}
+
+// Folds UnaryOp(ScalarIndexed(Const, Indices)) into
+// ScalarIndexed(UnaryOp(Const), Indices): the unary op is evaluated once on the
+// constant source literal, and the indices, source dim, output dims and shape
+// of the scalar-indexed node are reused unchanged.
+//
+// `instr` supplies the opcode to evaluate; `operand` is the array already
+// computed for its single operand.  Returns nullptr when `operand` is not a
+// ScalarIndexedConstantArray, i.e. when this fold does not apply.
+//
+// NOTE: the evaluator result is unwrapped with ConsumeValueOrDie, so an
+// evaluation error is not propagated to the caller.
+Analysis::Array* IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(
+    const HloInstruction* instr, Array* operand) {
+  auto* scalar_indexed_const =
+      dynamic_cast<ScalarIndexedConstantArray*>(operand);
+  // Test the result of the cast, not `operand` itself: a non-null operand of
+  // any other array kind makes the cast yield nullptr, which would otherwise
+  // be dereferenced below.
+  if (scalar_indexed_const == nullptr) {
+    return nullptr;
+  }
+
+  Literal* literal_for_new_source =
+      TakeOwnership(HloEvaluator{}
+                        .EvaluateElementwiseUnaryOp(
+                            instr->opcode(), scalar_indexed_const->literal())
+                        .ConsumeValueOrDie());
+  ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
+  return Construct<ScalarIndexedConstantArray>(
+      new_source, scalar_indexed_const->indices(),
+      scalar_indexed_const->source_dim(),
+      std::vector<int64>(scalar_indexed_const->output_dims().begin(),
+                         scalar_indexed_const->output_dims().end()),
+      scalar_indexed_const->shape());
+}
+
// Returns the fixed name string of this pass.
tensorflow::StringPiece IndexedArrayAnalysisPrinterPass::name() const {
  static constexpr char kPassName[] = "indexed-array-analysis-printer-pass";
  return kPassName;
}
protected:
// Asserts that the array computed for the root instruction of `hlo_text`
// prints as `root_expression`, with constants printed without their contents.
void AssertArrayForRootExpressionIs(const string& hlo_text,
const string& root_expression) {
+ AssertArrayForRootExpressionIsImpl(hlo_text, root_expression,
+ /*print_constants=*/false);
+ }
+
+ void AssertArrayWithConstantsForRootExpressionIs(  // Like AssertArrayForRootExpressionIs, but constant contents are printed too.
+ const string& hlo_text, const string& root_expression) {
+ AssertArrayForRootExpressionIsImpl(hlo_text, root_expression,
+ /*print_constants=*/true);
+ }
+
+ private:
+ void AssertArrayForRootExpressionIsImpl(const string& hlo_text,  // Shared body of the Assert* helpers above.
+ const string& root_expression,
+ bool print_constants) {  // Forwarded to IndexedArrayAnalysis::ToString.
IndexedArrayAnalysis indexed_tensor_analysis;
ParseAndVerifyModule(hlo_text);
- string result =
- indexed_tensor_analysis.ToString(indexed_tensor_analysis.GetArrayFor(
- module().entry_computation()->root_instruction()));
+ string result = indexed_tensor_analysis.ToString(
+ indexed_tensor_analysis.GetArrayFor(
+ module().entry_computation()->root_instruction()),
+ print_constants);  // Stringify the array computed for the entry computation's root.
LOG(INFO) << result;
ASSERT_EQ(result, root_expression);  // The printed form must match exactly.
}
AssertArrayForRootExpressionIs(hlo_text, "%reshape");
}
+
+TEST_F(IndexedArrayAnalysisTest, UnaryOpOfGather) {  // tanh of a gather from a constant folds into a scalar-indexed constant whose source holds the tanh values.
+ string hlo_text = R"(
+HloModule UnaryOpOfGather
+
+ENTRY main {
+ operand = f32[3,4] constant(f32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}})
+ indices = s32[5] parameter(0)
+ gather = f32[5,4] gather(operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={1,4}
+ ROOT tanh = f32[5,4] tanh(gather)
+}
+)";
+
+ AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"(
+(scalar-indexed-const (constant f32[3,4] f32[3,4] {
+ { 0.761594176, 0.964027584, 0.995054781, 0.999329329 },
+ { 0.761594176, 0.995054781, 0.964027584, 0.999329329 },
+ { 0.999329329, 0.995054781, 0.964027584, 0.761594176 }
+}) %indices 0->[0]))");
+}
+
+TEST_F(IndexedArrayAnalysisTest, AddBroadcastedScalarWithGather) {  // Adding a broadcast scalar constant to a gather folds the addition into the gather's constant source.
+ string hlo_text = R"(
+HloModule AddBroadcastedScalarWithGather
+
+ENTRY main {
+ gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}})
+ constant = s32[] constant(5)
+ constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
+ indices = s32[5] parameter(0)
+ gather = s32[5,4] gather(gather_operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={1,4}
+ ROOT add = s32[5,4] add(gather, constant_broadcasted)
+}
+)";
+
+ AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"(
+(scalar-indexed-const (constant s32[3,4] s32[3,4] {
+ { 6, 7, 8, 9 },
+ { 6, 8, 7, 9 },
+ { 9, 8, 7, 6 }
+}) %indices 0->[0]))");
+}
+
+TEST_F(IndexedArrayAnalysisTest,  // gather - broadcast(5): the subtraction folds into the constant source.
+ SubtractBroadcastedScalarWithGather_GatherIsLhs) {
+ string hlo_text = R"(
+HloModule SubtractBroadcastedScalarWithGather
+
+ENTRY main {
+ gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}})
+ constant = s32[] constant(5)
+ constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
+ indices = s32[5] parameter(0)
+ gather = s32[5,4] gather(gather_operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={1,4}
+ ROOT sub = s32[5,4] subtract(gather, constant_broadcasted)
+}
+)";
+
+ AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"(
+(scalar-indexed-const (constant s32[3,4] s32[3,4] {
+ { -4, -3, -2, -1 },
+ { -4, -2, -3, -1 },
+ { -1, -2, -3, -4 }
+}) %indices 0->[0]))");
+}
+
+TEST_F(IndexedArrayAnalysisTest,  // broadcast(5) - gather: operand order is kept when folding into the constant source.
+ SubtractBroadcastedScalarWithGather_GatherIsRhs) {
+ string hlo_text = R"(
+HloModule SubtractBroadcastedScalarWithGather
+
+ENTRY main {
+ gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}})
+ constant = s32[] constant(5)
+ constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
+ indices = s32[5] parameter(0)
+ gather = s32[5,4] gather(gather_operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={1,4}
+ ROOT sub = s32[5,4] subtract(constant_broadcasted, gather)
+}
+)";
+
+ AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"(
+(scalar-indexed-const (constant s32[3,4] s32[3,4] {
+ { 4, 3, 2, 1 },
+ { 4, 2, 3, 1 },
+ { 1, 2, 3, 4 }
+}) %indices 0->[0]))");
+}
+
+TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather) {  // A vector broadcast along dim 1 (not the gather's output dim 0) is folded into the constant source.
+ string hlo_text = R"(
+HloModule AddBroadcastedVectorWithGather
+
+ENTRY main {
+ gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}})
+ constant_vect = s32[4] constant({10,11,12,13})
+ constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={1}
+ indices = s32[5] parameter(0)
+ gather = s32[5,4] gather(gather_operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={1,4}
+ ROOT add = s32[5,4] add(gather, constant_broadcasted)
+}
+)";
+
+ AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"(
+(scalar-indexed-const (constant s32[3,4] s32[3,4] {
+ { 11, 13, 15, 17 },
+ { 11, 14, 14, 17 },
+ { 14, 14, 14, 14 }
+}) %indices 0->[0]))");
+}
+
+TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather_Negative) {  // The vector is broadcast along dim 0, so the add is expected to stay unfolded ("%add").
+ string hlo_text = R"(
+HloModule AddBroadcastedVectorWithGather
+
+ENTRY main {
+ gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}})
+ constant_vect = s32[5] constant({10,11,12,13,14})
+ constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={0}
+ indices = s32[5] parameter(0)
+ gather = s32[5,4] gather(gather_operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={1,4}
+ ROOT add = s32[5,4] add(gather, constant_broadcasted)
+}
+)";
+
+ AssertArrayForRootExpressionIs(hlo_text, "%add");
+}
} // namespace
} // namespace xla