From a08f3ba80ace5fe9b4673e6a9b0d4542b93ac94b Mon Sep 17 00:00:00 2001
From: Sergei Barannikov/AI Tools Lab /SRR/Engineer/Samsung Electronics
Date: Tue, 20 Aug 2019 19:48:24 +0900
Subject: [PATCH] [nnc] Add transpose folding optimization (#6724)

Constant-fold sequences of operations `Constant -> Transpose`.

Signed-off-by: Sergei Barannikov
---
 compiler/nnc/driver/Driver.cpp                     |  5 ++
 .../passes/optimizations/ConstantFoldTranspose.h   | 39 ++++++++++
 compiler/nnc/passes/optimizations/CMakeLists.txt   | 11 +--
 .../passes/optimizations/ConstantFoldTranspose.cpp | 83 ++++++++++++++++++++++
 4 files changed, 133 insertions(+), 5 deletions(-)
 create mode 100644 compiler/nnc/include/passes/optimizations/ConstantFoldTranspose.h
 create mode 100644 compiler/nnc/passes/optimizations/ConstantFoldTranspose.cpp

diff --git a/compiler/nnc/driver/Driver.cpp b/compiler/nnc/driver/Driver.cpp
index 117bb06..cd49be5 100644
--- a/compiler/nnc/driver/Driver.cpp
+++ b/compiler/nnc/driver/Driver.cpp
@@ -24,6 +24,7 @@
 #include "passes/acl_soft_backend/AclCppGenerator.h"
 
 #include "passes/optimizations/CombineTransposes.h"
+#include "passes/optimizations/ConstantFoldTranspose.h"
 #include "passes/optimizations/RemoveDeadEnds.h"
 #include "passes/optimizations/FuseArithmeticOps.h"
 #include "passes/optimizations/SinkRelu.h"
@@ -140,6 +141,10 @@ void Driver::registerBackendPass()
 
 void Driver::registerOptimizationPass()
 {
+  // TODO For now this optimization is mandatory. Make it optional when the ACL backend is able to
+  // handle all transposes that come from importers.
+  _passManager.registerPass(std::unique_ptr<Pass>(new ConstantFoldTranspose()));
+
   if (cli::doOptimizationPass)
   {
     // TODO: maybe we should start managing the optimizations more intelligently?
diff --git a/compiler/nnc/include/passes/optimizations/ConstantFoldTranspose.h b/compiler/nnc/include/passes/optimizations/ConstantFoldTranspose.h
new file mode 100644
index 0000000..96e2070
--- /dev/null
+++ b/compiler/nnc/include/passes/optimizations/ConstantFoldTranspose.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNCC_CONSTANT_FOLD_TRANSPOSE_H
+#define NNCC_CONSTANT_FOLD_TRANSPOSE_H
+
+#include "pass/Pass.h"
+
+namespace nnc
+{
+
+class ConstantFoldTranspose : public Pass
+{
+public:
+  PassData run(PassData data) override;
+
+  std::string getName() override
+  {
+    static const std::string name("opt_constant_fold_transpose");
+    return name;
+  };
+};
+
+} // namespace nnc
+
+#endif // NNCC_CONSTANT_FOLD_TRANSPOSE_H
diff --git a/compiler/nnc/passes/optimizations/CMakeLists.txt b/compiler/nnc/passes/optimizations/CMakeLists.txt
index 94ba075..00799ac 100644
--- a/compiler/nnc/passes/optimizations/CMakeLists.txt
+++ b/compiler/nnc/passes/optimizations/CMakeLists.txt
@@ -1,9 +1,10 @@
 set(OPTIMIZATIONS_SRC CombineTransposes.cpp
-    FuseArithmeticOps.cpp
-    RemoveDeadEnds.cpp
-    SinkRelu.cpp
-    SinkTranspose.cpp
-    OptimizationUtils.cpp)
+    ConstantFoldTranspose.cpp
+    FuseArithmeticOps.cpp
+    RemoveDeadEnds.cpp
+    SinkRelu.cpp
+    SinkTranspose.cpp
+    OptimizationUtils.cpp)
 
 nnc_add_library(nnc_optimizations SHARED ${OPTIMIZATIONS_SRC})
 target_link_libraries(nnc_optimizations PRIVATE mir nnc_support)
diff --git a/compiler/nnc/passes/optimizations/ConstantFoldTranspose.cpp b/compiler/nnc/passes/optimizations/ConstantFoldTranspose.cpp
new file mode 100644
index 0000000..69f2b21
--- /dev/null
+++ b/compiler/nnc/passes/optimizations/ConstantFoldTranspose.cpp
@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "passes/optimizations/ConstantFoldTranspose.h"
+#include "passes/optimizations/OptimizationUtils.h"
+#include "mir/GraphPatternMatcher.h"
+#include "mir/ShapeRange.h"
+#include "mir/Tensor.h"
+#include "mir/ops/ConstantOp.h"
+#include "mir/ops/TransposeOp.h"
+
+using namespace nnc;
+using namespace mir;
+
+// Copy & paste from the interpreter backend.
+// TODO Extract this to a common place and use it in both the interpreter and optimizations.
+static void transpose(const TensorVariant &arg, TensorVariant &res,
+                      const std::vector<std::size_t> &axis_order)
+{
+  Tensor<float> arg_accessor(arg);
+  Tensor<float> res_accessor(res);
+
+  const auto &input_shape = arg.getShape();
+  const int num_axes = static_cast<int>(axis_order.size());
+  assert(num_axes == input_shape.rank());
+
+  ShapeRange in_range(input_shape);
+  Index out_index(input_shape.rank());
+
+  for (const auto &in_index : in_range)
+  {
+    for (int i = 0; i < num_axes; ++i)
+    {
+      out_index.at(i) = in_index.at(axis_order[i]);
+    }
+    res_accessor.at(out_index) = arg_accessor.at(in_index);
+  }
+}
+
+PassData ConstantFoldTranspose::run(PassData data)
+{
+  auto graph = static_cast<Graph *>(data);
+
+  GraphPatternMatcher matcher(graph);
+  auto is_constant = [](const Operation *op) { return op->getType() == Operation::Type::constant; };
+  auto is_transpose = [](const Operation *op) {
+    return op->getType() == Operation::Type::transpose;
+  };
+
+  auto matches = matcher.matchEdge(is_constant, is_transpose);
+  while (!matches.empty())
+  {
+    for (const auto match : matches)
+    {
+      auto constant_op = dynamic_cast<ops::ConstantOp *>(match.first);
+      auto transpose_op = dynamic_cast<ops::TransposeOp *>(match.second);
+
+      // FIXME Revise this when we've got type information in operations.
+      TensorVariant res(DataType::FLOAT32, transpose_op->getOutputShape(0));
+      transpose(constant_op->getValue(), res, transpose_op->getAxisOrder());
+
+      auto new_op = graph->create<ops::ConstantOp>("", res);
+
+      graph->replaceNode(transpose_op, new_op);
+      opt_util::removeNodeIfUnused(graph, constant_op);
+    }
+    matches = matcher.matchEdge(is_constant, is_transpose);
+  }
+  return graph;
+}
-- 
2.7.4
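Note (not part of the patch): for readers who want to see the folding logic outside the mir framework, the sketch below reproduces the index remapping that the pass's transpose() helper performs, using plain std::vector in place of mir::TensorVariant and an explicit odometer loop in place of mir::ShapeRange. The names fold_transpose, offset, and Shape are illustrative only and not part of the nnc codebase.

// Standalone sketch of the index remapping the optimization performs once at
// compile time so the Transpose op can be removed from the graph.
#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

using Shape = std::vector<std::size_t>;

// Row-major offset of a multi-dimensional index within `shape`.
static std::size_t offset(const Shape &shape, const Shape &index)
{
  std::size_t off = 0;
  for (std::size_t i = 0; i < shape.size(); ++i)
    off = off * shape[i] + index[i];
  return off;
}

// Permute `input` (of shape `in_shape`) by `axis_order`: output dimension i
// takes its size from input dimension axis_order[i].
static std::vector<float> fold_transpose(const std::vector<float> &input, const Shape &in_shape,
                                         const std::vector<std::size_t> &axis_order)
{
  assert(axis_order.size() == in_shape.size());

  Shape out_shape(in_shape.size());
  for (std::size_t i = 0; i < in_shape.size(); ++i)
    out_shape[i] = in_shape[axis_order[i]];

  std::vector<float> output(input.size());
  Shape in_index(in_shape.size(), 0);
  Shape out_index(in_shape.size(), 0);

  // Walk every input index (the role of mir::ShapeRange) and scatter each
  // element to its permuted position: out_index[i] = in_index[axis_order[i]].
  for (std::size_t flat = 0; flat < input.size(); ++flat)
  {
    for (std::size_t i = 0; i < in_shape.size(); ++i)
      out_index[i] = in_index[axis_order[i]];
    output[offset(out_shape, out_index)] = input[offset(in_shape, in_index)];

    // Advance in_index like an odometer (row-major order).
    for (std::size_t i = in_shape.size(); i-- > 0;)
    {
      if (++in_index[i] < in_shape[i])
        break;
      in_index[i] = 0;
    }
  }
  return output;
}

int main()
{
  // A 2x3 constant transposed with axis order {1, 0} becomes 3x2.
  const std::vector<float> value = {1, 2, 3, 4, 5, 6};
  const auto folded = fold_transpose(value, {2, 3}, {1, 0});
  for (float v : folded)
    std::cout << v << ' '; // prints: 1 4 2 5 3 6
  std::cout << '\n';
}

Compiled with any C++11 compiler, this prints "1 4 2 5 3 6", the row-major contents of the transposed 3x2 tensor, matching what the pass bakes into the replacement ConstantOp.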