From 9319b13a23203f9f840ea3d65967a04a9034eaa6 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=D0=9F=D0=B0=D0=B2=D0=B5=D0=BB=20=D0=98=D0=BB=D1=8C=D1=8E?= =?utf8?q?=D1=82=D1=87=D0=B5=D0=BD=D0=BA=D0=BE/AI=20Tools=20Lab=20/SRR/Eng?= =?utf8?q?ineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 29 Aug 2019 22:02:44 +0300 Subject: [PATCH] [nnc] Implemented Data Format Switcher (#6932) * Implemented PoolOp visitor in switcher Signed-off-by: Pavel Iliutchenko --- compiler/nnc/driver/Driver.cpp | 15 +- compiler/nnc/driver/Driver.h | 2 +- .../passes/transformations/DataFormatSwitcher.h | 59 ++++++ compiler/nnc/passes/CMakeLists.txt | 1 + compiler/nnc/passes/transformations/CMakeLists.txt | 6 + .../passes/transformations/DataFormatSwitcher.cpp | 200 +++++++++++++++++++++ 6 files changed, 278 insertions(+), 5 deletions(-) create mode 100644 compiler/nnc/include/passes/transformations/DataFormatSwitcher.h create mode 100644 compiler/nnc/passes/transformations/CMakeLists.txt create mode 100644 compiler/nnc/passes/transformations/DataFormatSwitcher.cpp diff --git a/compiler/nnc/driver/Driver.cpp b/compiler/nnc/driver/Driver.cpp index cd49be5..b3cb4b4 100644 --- a/compiler/nnc/driver/Driver.cpp +++ b/compiler/nnc/driver/Driver.cpp @@ -18,6 +18,8 @@ #include "passes/common_frontend/NNImporter.h" +#include "passes/transformations/DataFormatSwitcher.h" + #include "passes/interpreter/InterpreterPass.h" #include "passes/soft_backend/CPPGenerator.h" #include "passes/dot_dumper/DumperPass.h" @@ -113,21 +115,25 @@ void Driver::registerFrontendPass() * @brief Register backend pass * @throw DriverException if errors occurred */ -void Driver::registerBackendPass() +void Driver::registerBackendPassWithSwitcher() { - + std::unique_ptr data_format_pass; std::unique_ptr pass; if (cli::target == NNC_TARGET_ARM_CPP || cli::target == NNC_TARGET_X86_CPP) { + data_format_pass = std::unique_ptr(new DataFormatSwitcher(mir::DataFormat::NHWC)); pass = std::unique_ptr(new CPPCodeGenerator()); } else if (cli::target == NNC_TARGET_ARM_GPU_CPP) { + // TODO Change to DataFormat::NCHW when fix it in ACL + data_format_pass = std::unique_ptr(new DataFormatSwitcher(mir::DataFormat::NHWC)); pass = std::unique_ptr(new AclCppCodeGenerator()); } else if (cli::target == NNC_TARGET_INTERPRETER) { + data_format_pass = std::unique_ptr(new DataFormatSwitcher(mir::DataFormat::NHWC)); pass = std::unique_ptr(new InterpreterPass()); } else @@ -135,9 +141,10 @@ void Driver::registerBackendPass() assert(false && "invalid option value"); } + _passManager.registerPass(std::move(data_format_pass)); _passManager.registerPass(std::move(pass)); -} // registerBackendPass +} // registerBackendPassWithSwitcher void Driver::registerOptimizationPass() { @@ -165,7 +172,7 @@ void Driver::runDriver() // register passes registerFrontendPass(); registerOptimizationPass(); - registerBackendPass(); + registerBackendPassWithSwitcher(); // run registered passes runPasses(); diff --git a/compiler/nnc/driver/Driver.h b/compiler/nnc/driver/Driver.h index 9de06e7..dc96f5d 100644 --- a/compiler/nnc/driver/Driver.h +++ b/compiler/nnc/driver/Driver.h @@ -56,7 +56,7 @@ public: private: void registerFrontendPass(); - void registerBackendPass(); + void registerBackendPassWithSwitcher(); void registerOptimizationPass(); void runPasses(); diff --git a/compiler/nnc/include/passes/transformations/DataFormatSwitcher.h b/compiler/nnc/include/passes/transformations/DataFormatSwitcher.h new file mode 100644 index 0000000..c385d97 --- /dev/null +++ b/compiler/nnc/include/passes/transformations/DataFormatSwitcher.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATA_FORMAT_SWITCHER_PASS_H +#define DATA_FORMAT_SWITCHER_PASS_H + +#include "mir/Graph.h" +#include "mir/DataFormat.h" + +#include "pass/Pass.h" + +namespace nnc +{ +class DataFormatSwitcher : public Pass +{ +public: + explicit DataFormatSwitcher(mir::DataFormat target_format); + + PassData run(PassData data) override; + + void cleanup() override; + + ~DataFormatSwitcher() override; + + std::string getName() { return "DataFormatSwitcher"; } + +private: + // operations with DataFormat dependency + void switchConv2D(mir::ops::Conv2DOp *op); + void switchDeConv2D(mir::ops::DeConv2DOp *op); + void switchDepthwiseConv2D(mir::ops::DepthwiseConv2DOp *op); + void switchPool(mir::ops::PoolOp *op); + + // helper functions + mir::Operation::Output *insertTransposeBefore(mir::Operation::Output *out); + mir::Operation::Output *insertTransposeAfter(mir::Operation::Output *out); + +private: + mir::Graph *_graph; + mir::DataFormat _target_format; + std::vector _candidates_for_switch; +}; + +} // namespace nnc + +#endif // DATA_FORMAT_SWITCHER_PASS_H diff --git a/compiler/nnc/passes/CMakeLists.txt b/compiler/nnc/passes/CMakeLists.txt index 5a413bc..a00c4e9 100644 --- a/compiler/nnc/passes/CMakeLists.txt +++ b/compiler/nnc/passes/CMakeLists.txt @@ -30,6 +30,7 @@ add_subdirectory(optimizations) # # BACKENDs # +add_subdirectory(transformations) # transformations used before backends add_subdirectory(interpreter) add_subdirectory(soft_backend) add_subdirectory(acl_soft_backend) diff --git a/compiler/nnc/passes/transformations/CMakeLists.txt b/compiler/nnc/passes/transformations/CMakeLists.txt new file mode 100644 index 0000000..3a12b15 --- /dev/null +++ b/compiler/nnc/passes/transformations/CMakeLists.txt @@ -0,0 +1,6 @@ +set(TRANSFORMATIONS_SRC + DataFormatSwitcher.cpp) + +nnc_add_library(nnc_transformations STATIC ${TRANSFORMATIONS_SRC}) +set_target_properties(nnc_transformations PROPERTIES POSITION_INDEPENDENT_CODE ON) +target_link_libraries(nnc_transformations PRIVATE mir nnc_support) diff --git a/compiler/nnc/passes/transformations/DataFormatSwitcher.cpp b/compiler/nnc/passes/transformations/DataFormatSwitcher.cpp new file mode 100644 index 0000000..74950ad --- /dev/null +++ b/compiler/nnc/passes/transformations/DataFormatSwitcher.cpp @@ -0,0 +1,200 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "passes/transformations/DataFormatSwitcher.h" + +#include "mir/ops/Conv2DOp.h" +#include "mir/ops/Deconv2DOp.h" +#include "mir/ops/DepthwiseConv2DOp.h" +#include "mir/ops/PoolOp.h" +#include "mir/ops/TransposeOp.h" + +namespace nnc +{ +DataFormatSwitcher::DataFormatSwitcher(const mir::DataFormat target_format) + : _target_format(target_format) +{ +} + +DataFormatSwitcher::~DataFormatSwitcher() = default; + +PassData DataFormatSwitcher::run(PassData data) +{ + _graph = static_cast(data); + assert(_graph); + + // Collect nodes which use DataFormat + for (auto *node : _graph->getNodes()) + { + switch (node->getType()) + { // nodes using DataFormat + case mir::Operation::Type::conv2D: + case mir::Operation::Type::deConv2D: + case mir::Operation::Type::depthwiseConv: + case mir::Operation::Type::pool: + _candidates_for_switch.push_back(node); + break; + default: + break; // not use DataFormat + } + } + // Switch collected ops + for (auto *op : _candidates_for_switch) + { + switch (op->getType()) + { + case mir::Operation::Type::conv2D: + switchConv2D(dynamic_cast(op)); + break; + case mir::Operation::Type::deConv2D: + switchDeConv2D(dynamic_cast(op)); + break; + case mir::Operation::Type::depthwiseConv: + switchDepthwiseConv2D(dynamic_cast(op)); + break; + case mir::Operation::Type::pool: + switchPool(dynamic_cast(op)); + break; + default: + assert(false && "Can't switch DataFormat for this operation!"); + } + } + + return _graph; +} + +void DataFormatSwitcher::cleanup() { _candidates_for_switch.clear(); } + +mir::Operation::Output *DataFormatSwitcher::insertTransposeBefore(mir::Operation::Output *out) +{ + if (_target_format == mir::DataFormat::NHWC) + return _graph->create(out, std::vector{0, 2, 3, 1}) + ->getOutput(0); // NCHW -> NHWC + else + return _graph->create(out, std::vector{0, 3, 1, 2}) + ->getOutput(0); // NHWC -> NCHW +} + +mir::Operation::Output *DataFormatSwitcher::insertTransposeAfter(mir::Operation::Output *out) +{ + if (_target_format == mir::DataFormat::NHWC) + return _graph->create(out, std::vector{0, 3, 1, 2}) + ->getOutput(0); // NHWC -> NCHW + else + return _graph->create(out, std::vector{0, 2, 3, 1}) + ->getOutput(0); // NCHW -> NHWC +} + +void DataFormatSwitcher::switchConv2D(mir::ops::Conv2DOp *op) +{ + if (op->getDataFormat() == _target_format) + return; + + assert(op->getNumInputs() == 2); + auto *input = op->getInput(0)->getProducer(); + auto *kernel = op->getInput(1)->getProducer(); + + const auto &strides = op->getStrides(); + const auto &padding_before = op->getPaddingBefore(); + const auto &padding_after = op->getPaddingAfter(); + + auto *trans_in = insertTransposeBefore(input); + + auto new_dw_conv = _graph->create(trans_in, kernel, strides, padding_before, + padding_after, _target_format); + + auto *trans_out = insertTransposeAfter(new_dw_conv->getOutput(0)); + + _graph->replaceNode(op, trans_out->getNode()); +} + +void DataFormatSwitcher::switchDeConv2D(mir::ops::DeConv2DOp *op) +{ + if (op->getDataFormat() == _target_format) + return; + + assert(op->getNumInputs() == 2); + auto *input = op->getInput(0)->getProducer(); + auto *kernel = op->getInput(1)->getProducer(); + + const auto &strides = op->getStrides(); + const auto &padding_after = op->getPaddingAfter(); + const auto padding_type = op->getPaddingType(); + + auto *trans_in = insertTransposeBefore(input); + + mir::Operation *new_deconv; + if (padding_type == mir::ops::PaddingType::Custom) + new_deconv = _graph->create(trans_in, kernel, strides, padding_after, + _target_format); + else + new_deconv = _graph->create(trans_in, kernel, strides, padding_type, + _target_format); + + auto *trans_out = insertTransposeAfter(new_deconv->getOutput(0)); + + _graph->replaceNode(op, trans_out->getNode()); +} + +void DataFormatSwitcher::switchDepthwiseConv2D(mir::ops::DepthwiseConv2DOp *op) +{ + if (op->getDataFormat() == _target_format) + return; + + assert(op->getNumInputs() == 2); + auto *input = op->getInput(0)->getProducer(); + auto *kernel = op->getInput(1)->getProducer(); + + const auto &strides = op->getStrides(); + const auto &padding_before = op->getPaddingBefore(); + const auto &padding_after = op->getPaddingAfter(); + + auto *trans_in = insertTransposeBefore(input); + + auto new_dw_conv = _graph->create( + trans_in, kernel, strides, padding_before, padding_after, _target_format); + + auto *trans_out = insertTransposeAfter(new_dw_conv->getOutput(0)); + + _graph->replaceNode(op, trans_out->getNode()); +} + +void DataFormatSwitcher::switchPool(mir::ops::PoolOp *op) +{ + if (op->getDataFormat() == _target_format) + return; + + auto *input = op->getInput(0)->getProducer(); + + const auto &window_shape = op->getWindowShape(); + const auto &strides = op->getStrides(); + const auto &padding_before = op->getPaddingBefore(); + const auto &padding_after = op->getPaddingAfter(); + const auto pooling_type = op->getPoolingType(); + const auto border_type = op->getBorderType(); + + auto *trans_in = insertTransposeBefore(input); + + auto new_pool = + _graph->create(trans_in, pooling_type, window_shape, strides, + padding_before, padding_after, border_type, _target_format); + + auto *trans_out = insertTransposeAfter(new_pool->getOutput(0)); + + _graph->replaceNode(op, trans_out->getNode()); +} + +} // namespace nnc -- 2.7.4