const HloInstruction* instr) {
Array* computed_array;
if (instr->IsElementwise() && instr->operand_count() == 1) {
- TF_ASSIGN_OR_RETURN(computed_array,
- ComputeArrayForElementwiseUnaryOp(
- instr, FindOrDie(cache_, instr->operand(0))));
+ TF_ASSIGN_OR_RETURN(
+ computed_array,
+ ComputeArrayForElementwiseUnaryOp(
+ instr->opcode(), FindOrDie(cache_, instr->operand(0))));
} else if (instr->IsElementwise() && instr->operand_count() == 2) {
- TF_ASSIGN_OR_RETURN(computed_array,
- ComputeArrayForElementwiseBinaryOp(
- instr, FindOrDie(cache_, instr->operand(0)),
- FindOrDie(cache_, instr->operand(1))));
+ TF_ASSIGN_OR_RETURN(
+ computed_array,
+ ComputeArrayForElementwiseBinaryOp(
+ instr->opcode(), FindOrDie(cache_, instr->operand(0)),
+ FindOrDie(cache_, instr->operand(1))));
} else if (instr->opcode() == HloOpcode::kConstant) {
TF_ASSIGN_OR_RETURN(computed_array,
ComputeArrayForConstant(instr->literal()));
}
StatusOr<Analysis::Array*>
-IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(
- const HloInstruction* instr, Array* lhs, Array* rhs) {
+IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
+ Array* lhs,
+ Array* rhs) {
// Try to fold BinaryOp(Broadcast(Const0), ScalarIndexed(Const1, Indices))
// => ScalarIndexed(BinaryOp(Broadcast'(Const0), Const1), Indices)
//
TF_ASSIGN_OR_RETURN(
literal_for_new_source,
TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
- instr->opcode(), scalar_indexed_const->literal(),
- *inner_broadcast_result)));
+ opcode, scalar_indexed_const->literal(), *inner_broadcast_result)));
} else {
TF_ASSIGN_OR_RETURN(
literal_for_new_source,
TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
- instr->opcode(), *inner_broadcast_result,
- scalar_indexed_const->literal())));
+ opcode, *inner_broadcast_result, scalar_indexed_const->literal())));
}
ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
}
StatusOr<Analysis::Array*>
-IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(
- const HloInstruction* instr, Array* operand) {
+IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(HloOpcode opcode,
+ Array* operand) {
auto* scalar_indexed_const =
dynamic_cast<ScalarIndexedConstantArray*>(operand);
if (operand == nullptr) {
TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source,
TakeOwnership(HloEvaluator{}.EvaluateElementwiseUnaryOp(
- instr->opcode(), scalar_indexed_const->literal())));
+ opcode, scalar_indexed_const->literal())));
ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
return Construct<ScalarIndexedConstantArray>(
new_source, scalar_indexed_const->indices(),
StatusOr<Array*> ComputeArrayForReshape(const Shape& shape, Array* operand);
- StatusOr<Array*> ComputeArrayForElementwiseBinaryOp(
- const HloInstruction* instr, Array* lhs, Array* rhs);
- StatusOr<Array*> ComputeArrayForElementwiseUnaryOp(
- const HloInstruction* instr, Array* operand);
+ StatusOr<Array*> ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
+ Array* lhs, Array* rhs);
+ StatusOr<Array*> ComputeArrayForElementwiseUnaryOp(HloOpcode opcode,
+ Array* operand);
template <typename T, typename... Args>
T* Construct(Args&&... args) {