From 131c96cab8a51347bfe0a1582728f38d2d01f9d3 Mon Sep 17 00:00:00 2001 From: seongwoo chae Date: Tue, 25 Aug 2020 15:28:12 +0900 Subject: [PATCH] [luci] Introduce FuseBatchNormWithTConv pass (#3971) This commit introduces FuseBatchNormWithTConv pass to luci. ONE-DCO-1.0-Signed-off-by: seongwoo --- compiler/luci/pass/include/luci/CircleOptimizer.h | 1 + .../include/luci/Pass/FuseBatchNormWithTConv.h | 37 +++++ compiler/luci/pass/src/CircleOptimizer.cpp | 5 + compiler/luci/pass/src/FuseBatchNormWithTConv.cpp | 159 +++++++++++++++++++++ 4 files changed, 202 insertions(+) create mode 100644 compiler/luci/pass/include/luci/Pass/FuseBatchNormWithTConv.h create mode 100644 compiler/luci/pass/src/FuseBatchNormWithTConv.cpp diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h index 883943c..a832844 100644 --- a/compiler/luci/pass/include/luci/CircleOptimizer.h +++ b/compiler/luci/pass/include/luci/CircleOptimizer.h @@ -32,6 +32,7 @@ public: { enum Algorithm { + FuseBatchNormWithTConv, FuseBCQ, FuseInstanceNorm, ResolveCustomOpAdd, diff --git a/compiler/luci/pass/include/luci/Pass/FuseBatchNormWithTConv.h b/compiler/luci/pass/include/luci/Pass/FuseBatchNormWithTConv.h new file mode 100644 index 0000000..d3e930a --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FuseBatchNormWithTConv.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FUSE_BATCH_NORM_WITH_TCONV_PASS_H__ +#define __LUCI_FUSE_BATCH_NORM_WITH_TCONV_PASS_H__ + +#include + +namespace luci +{ + +/** + * @brief Class to fuse Batch Normalization into CircleTransposeConv + * + * The pass matches a CircleTransposeConv whose output goes through a Mul and then an Add (the + * form TF's fusedBatchNorm takes in Circle, per the note in FuseBatchNormWithTConv.cpp) and + * folds both nodes into the TransposeConv. + * + * name() returns the fixed identifier "luci::FuseBatchNormWithTConvPass". run() visits the + * active nodes reachable from the outputs of the given graph and returns true when at least + * one CircleTransposeConv was fused, false when nothing changed. + */ +struct FuseBatchNormWithTConvPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FuseBatchNormWithTConvPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FUSE_BATCH_NORM_WITH_TCONV_PASS_H__ diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp index 4701fbf..5788294 100644 --- a/compiler/luci/pass/src/CircleOptimizer.cpp +++ b/compiler/luci/pass/src/CircleOptimizer.cpp @@ -16,6 +16,7 @@ #include "luci/CircleOptimizer.h" +#include "luci/Pass/FuseBatchNormWithTConv.h" #include "luci/Pass/FuseBCQPass.h" #include "luci/Pass/FuseInstanceNormPass.h" #include "luci/Pass/ResolveCustomOpAddPass.h" @@ -126,6 +127,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique()); } + if (_options->query(Options::Algorithm::FuseBatchNormWithTConv)) + { + phase.emplace_back(std::make_unique()); + } // Shape inference is needed for added nodes doing above transformations phase.emplace_back(std::make_unique()); diff --git a/compiler/luci/pass/src/FuseBatchNormWithTConv.cpp b/compiler/luci/pass/src/FuseBatchNormWithTConv.cpp new file mode 100644 index 0000000..e39455b --- /dev/null +++ b/compiler/luci/pass/src/FuseBatchNormWithTConv.cpp @@ -0,0 +1,159 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FuseBatchNormWithTConv.h"
+
+#include <luci/IR/CircleNodes.h>
+
+namespace
+{
+/**
+ * @brief Fold the Mul -> Add pair that follows a CircleTransposeConv into the TransposeConv
+ *
+ * NOTE TF's fusedBatchNorm is converted to mul and add of Circle.
+ *
+ *  BEFORE
+ *
+ *         [CircleTransposeConv]
+ *                  |
+ *                [mul]
+ *                  |
+ *                [add]
+ *  AFTER
+ *
+ *         [CircleTransposeConv]
+ *
+ * (followed by a [CircleRelu6] when the add carried a fused RELU6)
+ *
+ * The scale (constant operand of mul) is multiplied into the filter, one factor per output
+ * channel, and the shift (constant operand of add) becomes the bias of the TransposeConv.
+ * Nothing is modified unless the whole pattern matches.
+ *
+ * @param tconv  candidate TransposeConv node
+ * @return true if the graph was changed, false if the pattern did not match
+ */
+bool fused_batch_norm_with_tconv(luci::CircleTransposeConv *tconv)
+{
+  // The shift of batchnorm is installed as the bias below, so this optimization works only
+  // if tconv does not have a bias yet.
+  auto bias = dynamic_cast<luci::CircleOutputExclude *>(tconv->bias());
+  if (not bias)
+    return false;
+
+  // get weight of tconv
+  auto filter = dynamic_cast<luci::CircleConst *>(tconv->filter());
+  if (not filter)
+    return false;
+  if (filter->dtype() != loco::DataType::FLOAT32)
+    return false;
+  // The filter is rescaled in place; give up if another node reads the same constant.
+  if (loco::succs(filter).size() != 1)
+    return false;
+
+  // get mul node
+  // It has to be the only consumer of tconv: any other consumer would observe the rescaled
+  // output. This is a run-time check (not an assert) so that release builds are safe too.
+  auto tconv_output = loco::succs(tconv);
+  if (tconv_output.size() != 1)
+    return false;
+  auto mul = dynamic_cast<luci::CircleMul *>(*tconv_output.begin());
+  if (not mul)
+    return false;
+  if (mul->dtype() != loco::DataType::FLOAT32)
+    return false;
+  // An activation fused into mul sits between scale and shift and cannot be folded away.
+  if (mul->fusedActivationFunction() != luci::FusedActFunc::NONE)
+    return false;
+
+  // get add node, which has to be the only consumer of mul
+  auto mul_output = loco::succs(mul);
+  if (mul_output.size() != 1)
+    return false;
+  auto add = dynamic_cast<luci::CircleAdd *>(*mul_output.begin());
+  if (not add)
+    return false;
+  if (add->dtype() != loco::DataType::FLOAT32)
+    return false;
+  if (add->fusedActivationFunction() != luci::FusedActFunc::NONE &&
+      add->fusedActivationFunction() != luci::FusedActFunc::RELU6)
+    return false;
+
+  // get scale of batchnorm
+  auto scale = dynamic_cast<luci::CircleConst *>(mul->y());
+  if (not scale)
+    return false;
+  if (scale->dtype() != loco::DataType::FLOAT32)
+    return false;
+
+  // scale dim(0) == tconv output channel dim
+  // The filter of TransposeConv is laid out as [O, H, W, I]. mul/add work on the output of
+  // tconv, so their constants are indexed by the output channel, which is dim(0).
+  if (filter->rank() != 4)
+    return false;
+  const uint32_t filter_out_dim = filter->dim(0).value();
+  if (scale->rank() != 1)
+    return false;
+  if (scale->dim(0).value() != filter_out_dim)
+    return false;
+
+  // get shift of batchnorm
+  auto shift = dynamic_cast<luci::CircleConst *>(add->y());
+  if (not shift)
+    return false;
+  if (shift->dtype() != loco::DataType::FLOAT32)
+    return false;
+
+  // shift dim(0) == tconv output channel dim
+  if (shift->rank() != 1)
+    return false;
+  if (shift->dim(0).value() != filter_out_dim)
+    return false;
+
+  // filter weight = filter weight * scale, per output channel
+  // With the [O, H, W, I] layout all weights of one output channel are contiguous.
+  const uint32_t filter_height_dim = filter->dim(1).value();
+  const uint32_t filter_width_dim = filter->dim(2).value();
+  const uint32_t filter_in_dim = filter->dim(3).value();
+  const uint32_t weights_per_out = filter_height_dim * filter_width_dim * filter_in_dim;
+  for (uint32_t o = 0; o < filter_out_dim; o++)
+  {
+    const float factor = scale->at<loco::DataType::FLOAT32>(o);
+    const uint32_t base = o * weights_per_out;
+    for (uint32_t i = 0; i < weights_per_out; i++)
+    {
+      filter->at<loco::DataType::FLOAT32>(base + i) *= factor;
+    }
+  }
+
+  // fuse shift with transposed conv
+  tconv->bias(shift);
+
+  if (add->fusedActivationFunction() == luci::FusedActFunc::RELU6)
+  {
+    // separate relu op from add op
+    auto relu = add->graph()->nodes()->create<luci::CircleRelu6>();
+    relu->features(tconv);
+
+    // users of add now read relu(tconv); mul and add are left without consumers
+    replace(add).with(relu);
+  }
+  else
+  {
+    // users of add now read tconv directly; mul and add are left without consumers
+    replace(add).with(tconv);
+  }
+
+  return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+/**
+ * @brief Try the fusion on every CircleTransposeConv reachable from the outputs of the graph
+ *
+ * @return true if at least one TransposeConv was fused
+ */
+bool FuseBatchNormWithTConvPass::run(loco::Graph *g)
+{
+  bool changed = false;
+  for (auto node : loco::active_nodes(loco::output_nodes(g)))
+  {
+    auto tconv = dynamic_cast<luci::CircleTransposeConv *>(node);
+    if (not tconv)
+      continue;
+
+    changed |= fused_batch_norm_with_tconv(tconv);
+  }
+
+  return changed;
+}
+
+} // namespace luci
-- 
2.7.4