return std::move(literal);
}
-// Matches binary addition computations.
-bool LooksLikeSum(const HloComputation& computation) {
+// Kind of constant a computation needs as its initial value.
+enum class ConstantType { kUnknown, kZero, kOne };
+
+// Returns the constant type required by this computation, if known: kZero when
+// the root adds the two distinct parameters, kOne when it multiplies them, and
+// kUnknown for every other computation.
+ConstantType GetInitValue(const HloComputation& computation) {
const HloInstruction* const root = computation.root_instruction();
- return root->opcode() == HloOpcode::kAdd &&
- computation.num_parameters() == 2 &&
- root->operand(0)->opcode() == HloOpcode::kParameter &&
- root->operand(1)->opcode() == HloOpcode::kParameter &&
- root->operand(0) != root->operand(1);
+ // Verify the root is binary before reading operand(0)/operand(1); the opcode
+ // is only inspected afterwards, so the root may have fewer than two operands.
+ if (computation.num_parameters() != 2 || root->operand_count() != 2 ||
+ root->operand(0)->opcode() != HloOpcode::kParameter ||
+ root->operand(1)->opcode() != HloOpcode::kParameter ||
+ root->operand(0) == root->operand(1)) {
+ return ConstantType::kUnknown;
+ }
+
+ switch (root->opcode()) {
+ case HloOpcode::kAdd:
+ return ConstantType::kZero;
+ case HloOpcode::kMultiply:
+ return ConstantType::kOne;
+ default:
+ return ConstantType::kUnknown;
+ }
}
-// Reduce, ReduceWindow, and SelectAndScatter ops may use binary addition,
-// which requires an init_value of 0 rather than a random value.
-bool NeedsZeroInitValue(const HloUse& use) {
+// Returns true if `use` is operand 1 of a Reduce or ReduceWindow, or operand 2
+// of a SelectAndScatter; those operands may need a non-random initial value.
+bool NeedsInitValue(const HloUse& use) {
const HloInstruction* const instruction = use.instruction;
const HloOpcode opcode = instruction->opcode();
const int64 op_num = use.operand_number;
return (
((opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow) &&
- op_num == 1 && LooksLikeSum(*instruction->to_apply())) ||
- (opcode == HloOpcode::kSelectAndScatter && op_num == 2 &&
- LooksLikeSum(*instruction->scatter())));
+ op_num == 1) ||
+ (opcode == HloOpcode::kSelectAndScatter && op_num == 2));
}
// Generate random values that are constrained to the input_shape minus the
auto fused_uses = FindConstrainedUses(dataflow, *to_analyze);
constrained_uses.insert(constrained_uses.end(), fused_uses.begin(),
fused_uses.end());
- } else if (NeedsZeroInitValue(use)) {
+ } else if (NeedsInitValue(use)) {
constrained_uses.push_back(instruction);
} else if (opcode == HloOpcode::kConvert ||
opcode == HloOpcode::kReducePrecision) {
const tensorflow::gtl::ArraySlice<HloInstruction*> constrained_uses,
const HloInstruction& param, std::minstd_rand0* engine) {
HloInstruction* needs_index = nullptr;
- HloInstruction* needs_zero = nullptr;
+ HloInstruction* needs_constant = nullptr;
+ ConstantType constant_type = ConstantType::kUnknown;
for (HloInstruction* use : constrained_uses) {
switch (use->opcode()) {
case HloOpcode::kDynamicSlice:
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
+ needs_constant = use;
+ constant_type = GetInitValue(*use->to_apply());
+ break;
+
case HloOpcode::kSelectAndScatter:
- needs_zero = use;
+ needs_constant = use;
+ constant_type = GetInitValue(*use->scatter());
break;
default:
use->ToString().c_str());
}
}
- if (needs_index != nullptr && needs_zero != nullptr) {
+ if (needs_index != nullptr && needs_constant != nullptr) {
return Unimplemented(
"Conflicting operand generation constraints.\nNeeds index: %s\nNeeds "
- "zero: %s\n",
- needs_index->ToString().c_str(), needs_zero->ToString().c_str());
+ "constant: %s\n",
+ needs_index->ToString().c_str(), needs_constant->ToString().c_str());
}
if (needs_index != nullptr) {
return MakeRandomNonwrappingSliceIndex(needs_index->operand(0)->shape(),
needs_index->shape(), engine);
- } else if (needs_zero != nullptr) {
- return Literal::CreateFromShape(param.shape());
+ } else if (needs_constant != nullptr) {
+ switch (constant_type) {
+ case ConstantType::kZero:
+ return Literal::Zero(param.shape().element_type()).CloneToUnique();
+ case ConstantType::kOne:
+ return Literal::One(param.shape().element_type()).CloneToUnique();
+ case ConstantType::kUnknown:
+ // We want the identity element for the computation, but we don't really
+ // know what it is - so any value we generate will be just as wrong.
+ return MakeFakeLiteralInternal(param.shape(), engine);
+ }
} else {
return MakeFakeLiteralInternal(param.shape(), engine);
}