Add broadcasting support for fused add or sub.
authorMingxing Tan <tanmingxing@google.com>
Tue, 20 Mar 2018 19:54:01 +0000 (12:54 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 20 Mar 2018 19:57:00 +0000 (12:57 -0700)
PiperOrigin-RevId: 189792542

tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc

index 5b57178..76c6be0 100644 (file)
@@ -50,7 +50,17 @@ void FuseAddOrSubParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
 
   // TODO(b/62904716): Bias array should become 1-D when padding removed.
   const int depth = bias_shape.dims(bias_shape.dimensions_count() - 1);
-  CHECK_EQ(depth, operand_shape.dims(operand_shape.dimensions_count() - 1));
+  int operand_channel_increment = 0;
+  if (operand_shape.dimensions_count() >= 1 &&
+      operand_shape.dims(operand_shape.dimensions_count() - 1) ==
+          bias_shape.dims(bias_shape.dimensions_count() - 1)) {
+    operand_channel_increment = 1;
+  } else if (operand_shape.dimensions_count() == 0 ||
+             operand_shape.dims(operand_shape.dimensions_count() - 1) == 1) {
+    operand_channel_increment = 0;
+  } else {
+    LOG(FATAL) << "Operand shape mismatch.";
+  }
 
   enum class OpType { BiasPlusOperand, BiasMinusOperand, OperandMinusBias };
 
@@ -60,9 +70,10 @@ void FuseAddOrSubParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
                                   ? OpType::BiasMinusOperand
                                   : OpType::OperandMinusBias;
 
+  int operand_channel = 0;
   for (int i = 0; i < depth; i++) {
     float& bias_val = bias_data[i];
-    const float operand_val = operand_data[i];
+    const float operand_val = operand_data[operand_channel];
     if (optype == OpType::BiasPlusOperand) {
       bias_val += operand_val;
     } else if (optype == OpType::BiasMinusOperand) {
@@ -72,6 +83,7 @@ void FuseAddOrSubParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
     } else {
       LOG(FATAL) << "Should not get here.";
     }
+    operand_channel += operand_channel_increment;
   }
 }