// TODO(b/62904716): Bias array should become 1-D when padding removed.
const int depth = bias_shape.dims(bias_shape.dimensions_count() - 1);
- CHECK_EQ(depth, operand_shape.dims(operand_shape.dimensions_count() - 1));
+ int operand_channel_increment = 0;
+ if (operand_shape.dimensions_count() >= 1 &&
+ operand_shape.dims(operand_shape.dimensions_count() - 1) ==
+ bias_shape.dims(bias_shape.dimensions_count() - 1)) {
+ operand_channel_increment = 1;
+ } else if (operand_shape.dimensions_count() == 0 ||
+ operand_shape.dims(operand_shape.dimensions_count() - 1) == 1) {
+ operand_channel_increment = 0;
+ } else {
+ LOG(FATAL) << "Operand shape mismatch.";
+ }
enum class OpType { BiasPlusOperand, BiasMinusOperand, OperandMinusBias };
? OpType::BiasMinusOperand
: OpType::OperandMinusBias;
+ int operand_channel = 0;
for (int i = 0; i < depth; i++) {
float& bias_val = bias_data[i];
- const float operand_val = operand_data[i];
+ const float operand_val = operand_data[operand_channel];
if (optype == OpType::BiasPlusOperand) {
bias_val += operand_val;
} else if (optype == OpType::BiasMinusOperand) {
} else {
LOG(FATAL) << "Should not get here.";
}
+ operand_channel += operand_channel_increment;
}
}