From 13b993095f155bd4dd7fc3b057a7b5043ef0a06c Mon Sep 17 00:00:00 2001 From: Mingxing Tan Date: Tue, 20 Mar 2018 12:54:01 -0700 Subject: [PATCH] Add broadcasting support for fused add or sub. PiperOrigin-RevId: 189792542 --- .../fuse_binary_into_preceding_affine.cc | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc index 5b57178..76c6be0 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc @@ -50,7 +50,17 @@ void FuseAddOrSubParamsIntoPrecedingAffine(Model* model, Operator* preceding_op, // TODO(b/62904716): Bias array should become 1-D when padding removed. const int depth = bias_shape.dims(bias_shape.dimensions_count() - 1); - CHECK_EQ(depth, operand_shape.dims(operand_shape.dimensions_count() - 1)); + int operand_channel_increment = 0; + if (operand_shape.dimensions_count() >= 1 && + operand_shape.dims(operand_shape.dimensions_count() - 1) == + bias_shape.dims(bias_shape.dimensions_count() - 1)) { + operand_channel_increment = 1; + } else if (operand_shape.dimensions_count() == 0 || + operand_shape.dims(operand_shape.dimensions_count() - 1) == 1) { + operand_channel_increment = 0; + } else { + LOG(FATAL) << "Operand shape mismatch."; + } enum class OpType { BiasPlusOperand, BiasMinusOperand, OperandMinusBias }; @@ -60,9 +70,10 @@ void FuseAddOrSubParamsIntoPrecedingAffine(Model* model, Operator* preceding_op, ? OpType::BiasMinusOperand : OpType::OperandMinusBias; + int operand_channel = 0; for (int i = 0; i < depth; i++) { float& bias_val = bias_data[i]; - const float operand_val = operand_data[i]; + const float operand_val = operand_data[operand_channel]; if (optype == OpType::BiasPlusOperand) { bias_val += operand_val; } else if (optype == OpType::BiasMinusOperand) { @@ -72,6 +83,7 @@ void FuseAddOrSubParamsIntoPrecedingAffine(Model* model, Operator* preceding_op, } else { LOG(FATAL) << "Should not get here."; } + operand_channel += operand_channel_increment; } } -- 2.7.4