From 5e3c65305435ff942fad78f0e7c21df6018137ed Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=B2=9C=EA=B5=90/On-Device=20Lab=28SR=29/Enginee?= =?utf8?q?r/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Fri, 1 Nov 2019 14:36:01 +0900 Subject: [PATCH] [exo] Fuse Squared Difference (#8654) This commit introduces a pass to fuse Squared Difference pattern Signed-off-by: Cheongyo Bahk --- .../exo/src/Pass/FuseSquaredDifferencePass.cpp | 88 ++++++++++++++++++++++ compiler/exo/src/Pass/FuseSquaredDifferencePass.h | 49 ++++++++++++ 2 files changed, 137 insertions(+) create mode 100644 compiler/exo/src/Pass/FuseSquaredDifferencePass.cpp create mode 100644 compiler/exo/src/Pass/FuseSquaredDifferencePass.h diff --git a/compiler/exo/src/Pass/FuseSquaredDifferencePass.cpp b/compiler/exo/src/Pass/FuseSquaredDifferencePass.cpp new file mode 100644 index 0000000..497c3ee --- /dev/null +++ b/compiler/exo/src/Pass/FuseSquaredDifferencePass.cpp @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "FuseSquaredDifferencePass.h" + +#include "Check.h" + +#include "Dialect/IR/TFLNodes.h" + +#include + +namespace +{ + +/** + * @return Casted TFLMul for fusable candidate, nullptr otherwise + * + * This helper checkes fusability with following conditions: + * - TFLMul has no activation + * - TFLMul's first and second arguments are equal and TFLSub + */ +locoex::TFLMul *as_candidate(loco::Node *node) +{ + auto mul = dynamic_cast(node); + if (not mul) + return nullptr; + + // Cannot fuse mul with activation function + if (mul->fusedActivationFunction() != locoex::FusedActFunc::NONE) + return nullptr; + + if (mul->x() != mul->y()) + return nullptr; + + if (not dynamic_cast(mul->x())) + return nullptr; + + return mul; +} + +void fuse_squared_difference(locoex::TFLMul *mul) +{ + auto sub = dynamic_cast(mul->x()); + EXO_ASSERT(sub, "sub should be valid at this point"); + + // TFLSquaredDifference to replace + auto sq_diff = mul->graph()->nodes()->create(); + sq_diff->x(sub->x()); + sq_diff->y(sub->y()); + + // replace + loco::replace(mul).with(sq_diff); +} + +} // namespace + +namespace exo +{ + +bool FuseSquaredDifferencePass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto mul = as_candidate(node)) + { + fuse_squared_difference(mul); + changed = true; + } + } + + return changed; +} + +} // namespace exo diff --git a/compiler/exo/src/Pass/FuseSquaredDifferencePass.h b/compiler/exo/src/Pass/FuseSquaredDifferencePass.h new file mode 100644 index 0000000..dbc1514 --- /dev/null +++ b/compiler/exo/src/Pass/FuseSquaredDifferencePass.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __FUSE_SQUARED_DIFFERENCE_PASS_H__ +#define __FUSE_SQUARED_DIFFERENCE_PASS_H__ + +#include + +namespace exo +{ + +/** + * @brief Class to fuse SquaredDifference pattern + * + * + * + * A --- TFLSub --- TFLMul --- C + * / \ / + * B ---- ----- + * + * + * + * A --- TFLSquaredDifference --- C + * / + * B ---- + */ +struct FuseSquaredDifferencePass final : public logo::Pass +{ + const char *name(void) const final { return "exo::FuseSquaredDifferencePass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace exo + +#endif // __FUSE_SQUARED_DIFFERENCE_PASS_H__ -- 2.7.4