From 9319b13a23203f9f840ea3d65967a04a9034eaa6 Mon Sep 17 00:00:00 2001
From: =?utf8?q?=D0=9F=D0=B0=D0=B2=D0=B5=D0=BB=20=D0=98=D0=BB=D1=8C=D1=8E?=
=?utf8?q?=D1=82=D1=87=D0=B5=D0=BD=D0=BA=D0=BE/AI=20Tools=20Lab=20/SRR/Eng?=
=?utf8?q?ineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?=
Date: Thu, 29 Aug 2019 22:02:44 +0300
Subject: [PATCH] [nnc] Implemented Data Format Switcher (#6932)
* Implemented PoolOp visitor in switcher
Signed-off-by: Pavel Iliutchenko
---
compiler/nnc/driver/Driver.cpp | 15 +-
compiler/nnc/driver/Driver.h | 2 +-
.../passes/transformations/DataFormatSwitcher.h | 59 ++++++
compiler/nnc/passes/CMakeLists.txt | 1 +
compiler/nnc/passes/transformations/CMakeLists.txt | 6 +
.../passes/transformations/DataFormatSwitcher.cpp | 200 +++++++++++++++++++++
6 files changed, 278 insertions(+), 5 deletions(-)
create mode 100644 compiler/nnc/include/passes/transformations/DataFormatSwitcher.h
create mode 100644 compiler/nnc/passes/transformations/CMakeLists.txt
create mode 100644 compiler/nnc/passes/transformations/DataFormatSwitcher.cpp
diff --git a/compiler/nnc/driver/Driver.cpp b/compiler/nnc/driver/Driver.cpp
index cd49be5..b3cb4b4 100644
--- a/compiler/nnc/driver/Driver.cpp
+++ b/compiler/nnc/driver/Driver.cpp
@@ -18,6 +18,8 @@
#include "passes/common_frontend/NNImporter.h"
+#include "passes/transformations/DataFormatSwitcher.h"
+
#include "passes/interpreter/InterpreterPass.h"
#include "passes/soft_backend/CPPGenerator.h"
#include "passes/dot_dumper/DumperPass.h"
@@ -113,21 +115,25 @@ void Driver::registerFrontendPass()
* @brief Register backend pass
* @throw DriverException if errors occurred
*/
-void Driver::registerBackendPass()
+void Driver::registerBackendPassWithSwitcher()
{
-
+ std::unique_ptr data_format_pass;
std::unique_ptr pass;
if (cli::target == NNC_TARGET_ARM_CPP || cli::target == NNC_TARGET_X86_CPP)
{
+ data_format_pass = std::unique_ptr(new DataFormatSwitcher(mir::DataFormat::NHWC));
pass = std::unique_ptr(new CPPCodeGenerator());
}
else if (cli::target == NNC_TARGET_ARM_GPU_CPP)
{
+ // TODO Change to DataFormat::NCHW when fix it in ACL
+ data_format_pass = std::unique_ptr(new DataFormatSwitcher(mir::DataFormat::NHWC));
pass = std::unique_ptr(new AclCppCodeGenerator());
}
else if (cli::target == NNC_TARGET_INTERPRETER)
{
+ data_format_pass = std::unique_ptr(new DataFormatSwitcher(mir::DataFormat::NHWC));
pass = std::unique_ptr(new InterpreterPass());
}
else
@@ -135,9 +141,10 @@ void Driver::registerBackendPass()
assert(false && "invalid option value");
}
+ _passManager.registerPass(std::move(data_format_pass));
_passManager.registerPass(std::move(pass));
-} // registerBackendPass
+} // registerBackendPassWithSwitcher
void Driver::registerOptimizationPass()
{
@@ -165,7 +172,7 @@ void Driver::runDriver()
// register passes
registerFrontendPass();
registerOptimizationPass();
- registerBackendPass();
+ registerBackendPassWithSwitcher();
// run registered passes
runPasses();
diff --git a/compiler/nnc/driver/Driver.h b/compiler/nnc/driver/Driver.h
index 9de06e7..dc96f5d 100644
--- a/compiler/nnc/driver/Driver.h
+++ b/compiler/nnc/driver/Driver.h
@@ -56,7 +56,7 @@ public:
private:
void registerFrontendPass();
- void registerBackendPass();
+ void registerBackendPassWithSwitcher();
void registerOptimizationPass();
void runPasses();
diff --git a/compiler/nnc/include/passes/transformations/DataFormatSwitcher.h b/compiler/nnc/include/passes/transformations/DataFormatSwitcher.h
new file mode 100644
index 0000000..c385d97
--- /dev/null
+++ b/compiler/nnc/include/passes/transformations/DataFormatSwitcher.h
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATA_FORMAT_SWITCHER_PASS_H
+#define DATA_FORMAT_SWITCHER_PASS_H
+
+#include "mir/Graph.h"
+#include "mir/DataFormat.h"
+
+#include "pass/Pass.h"
+
+namespace nnc
+{
+class DataFormatSwitcher : public Pass
+{
+public:
+ explicit DataFormatSwitcher(mir::DataFormat target_format);
+
+ PassData run(PassData data) override;
+
+ void cleanup() override;
+
+ ~DataFormatSwitcher() override;
+
+ std::string getName() { return "DataFormatSwitcher"; }
+
+private:
+ // operations with DataFormat dependency
+ void switchConv2D(mir::ops::Conv2DOp *op);
+ void switchDeConv2D(mir::ops::DeConv2DOp *op);
+ void switchDepthwiseConv2D(mir::ops::DepthwiseConv2DOp *op);
+ void switchPool(mir::ops::PoolOp *op);
+
+ // helper functions
+ mir::Operation::Output *insertTransposeBefore(mir::Operation::Output *out);
+ mir::Operation::Output *insertTransposeAfter(mir::Operation::Output *out);
+
+private:
+ mir::Graph *_graph;
+ mir::DataFormat _target_format;
+ std::vector _candidates_for_switch;
+};
+
+} // namespace nnc
+
+#endif // DATA_FORMAT_SWITCHER_PASS_H
diff --git a/compiler/nnc/passes/CMakeLists.txt b/compiler/nnc/passes/CMakeLists.txt
index 5a413bc..a00c4e9 100644
--- a/compiler/nnc/passes/CMakeLists.txt
+++ b/compiler/nnc/passes/CMakeLists.txt
@@ -30,6 +30,7 @@ add_subdirectory(optimizations)
#
# BACKENDs
#
+add_subdirectory(transformations) # transformations used before backends
add_subdirectory(interpreter)
add_subdirectory(soft_backend)
add_subdirectory(acl_soft_backend)
diff --git a/compiler/nnc/passes/transformations/CMakeLists.txt b/compiler/nnc/passes/transformations/CMakeLists.txt
new file mode 100644
index 0000000..3a12b15
--- /dev/null
+++ b/compiler/nnc/passes/transformations/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(TRANSFORMATIONS_SRC
+ DataFormatSwitcher.cpp)
+
+nnc_add_library(nnc_transformations STATIC ${TRANSFORMATIONS_SRC})
+set_target_properties(nnc_transformations PROPERTIES POSITION_INDEPENDENT_CODE ON)
+target_link_libraries(nnc_transformations PRIVATE mir nnc_support)
diff --git a/compiler/nnc/passes/transformations/DataFormatSwitcher.cpp b/compiler/nnc/passes/transformations/DataFormatSwitcher.cpp
new file mode 100644
index 0000000..74950ad
--- /dev/null
+++ b/compiler/nnc/passes/transformations/DataFormatSwitcher.cpp
@@ -0,0 +1,200 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "passes/transformations/DataFormatSwitcher.h"
+
+#include "mir/ops/Conv2DOp.h"
+#include "mir/ops/Deconv2DOp.h"
+#include "mir/ops/DepthwiseConv2DOp.h"
+#include "mir/ops/PoolOp.h"
+#include "mir/ops/TransposeOp.h"
+
+namespace nnc
+{
+DataFormatSwitcher::DataFormatSwitcher(const mir::DataFormat target_format)
+ : _target_format(target_format)
+{
+}
+
+DataFormatSwitcher::~DataFormatSwitcher() = default;
+
+PassData DataFormatSwitcher::run(PassData data)
+{
+ _graph = static_cast(data);
+ assert(_graph);
+
+ // Collect nodes which use DataFormat
+ for (auto *node : _graph->getNodes())
+ {
+ switch (node->getType())
+ { // nodes using DataFormat
+ case mir::Operation::Type::conv2D:
+ case mir::Operation::Type::deConv2D:
+ case mir::Operation::Type::depthwiseConv:
+ case mir::Operation::Type::pool:
+ _candidates_for_switch.push_back(node);
+ break;
+ default:
+ break; // not use DataFormat
+ }
+ }
+ // Switch collected ops
+ for (auto *op : _candidates_for_switch)
+ {
+ switch (op->getType())
+ {
+ case mir::Operation::Type::conv2D:
+ switchConv2D(dynamic_cast(op));
+ break;
+ case mir::Operation::Type::deConv2D:
+ switchDeConv2D(dynamic_cast(op));
+ break;
+ case mir::Operation::Type::depthwiseConv:
+ switchDepthwiseConv2D(dynamic_cast(op));
+ break;
+ case mir::Operation::Type::pool:
+ switchPool(dynamic_cast(op));
+ break;
+ default:
+ assert(false && "Can't switch DataFormat for this operation!");
+ }
+ }
+
+ return _graph;
+}
+
+void DataFormatSwitcher::cleanup() { _candidates_for_switch.clear(); }
+
+mir::Operation::Output *DataFormatSwitcher::insertTransposeBefore(mir::Operation::Output *out)
+{
+ if (_target_format == mir::DataFormat::NHWC)
+ return _graph->create(out, std::vector{0, 2, 3, 1})
+ ->getOutput(0); // NCHW -> NHWC
+ else
+ return _graph->create(out, std::vector{0, 3, 1, 2})
+ ->getOutput(0); // NHWC -> NCHW
+}
+
+mir::Operation::Output *DataFormatSwitcher::insertTransposeAfter(mir::Operation::Output *out)
+{
+ if (_target_format == mir::DataFormat::NHWC)
+ return _graph->create(out, std::vector{0, 3, 1, 2})
+ ->getOutput(0); // NHWC -> NCHW
+ else
+ return _graph->create(out, std::vector{0, 2, 3, 1})
+ ->getOutput(0); // NCHW -> NHWC
+}
+
+void DataFormatSwitcher::switchConv2D(mir::ops::Conv2DOp *op)
+{
+ if (op->getDataFormat() == _target_format)
+ return;
+
+ assert(op->getNumInputs() == 2);
+ auto *input = op->getInput(0)->getProducer();
+ auto *kernel = op->getInput(1)->getProducer();
+
+ const auto &strides = op->getStrides();
+ const auto &padding_before = op->getPaddingBefore();
+ const auto &padding_after = op->getPaddingAfter();
+
+ auto *trans_in = insertTransposeBefore(input);
+
+ auto new_dw_conv = _graph->create(trans_in, kernel, strides, padding_before,
+ padding_after, _target_format);
+
+ auto *trans_out = insertTransposeAfter(new_dw_conv->getOutput(0));
+
+ _graph->replaceNode(op, trans_out->getNode());
+}
+
+void DataFormatSwitcher::switchDeConv2D(mir::ops::DeConv2DOp *op)
+{
+ if (op->getDataFormat() == _target_format)
+ return;
+
+ assert(op->getNumInputs() == 2);
+ auto *input = op->getInput(0)->getProducer();
+ auto *kernel = op->getInput(1)->getProducer();
+
+ const auto &strides = op->getStrides();
+ const auto &padding_after = op->getPaddingAfter();
+ const auto padding_type = op->getPaddingType();
+
+ auto *trans_in = insertTransposeBefore(input);
+
+ mir::Operation *new_deconv;
+ if (padding_type == mir::ops::PaddingType::Custom)
+ new_deconv = _graph->create(trans_in, kernel, strides, padding_after,
+ _target_format);
+ else
+ new_deconv = _graph->create(trans_in, kernel, strides, padding_type,
+ _target_format);
+
+ auto *trans_out = insertTransposeAfter(new_deconv->getOutput(0));
+
+ _graph->replaceNode(op, trans_out->getNode());
+}
+
+void DataFormatSwitcher::switchDepthwiseConv2D(mir::ops::DepthwiseConv2DOp *op)
+{
+ if (op->getDataFormat() == _target_format)
+ return;
+
+ assert(op->getNumInputs() == 2);
+ auto *input = op->getInput(0)->getProducer();
+ auto *kernel = op->getInput(1)->getProducer();
+
+ const auto &strides = op->getStrides();
+ const auto &padding_before = op->getPaddingBefore();
+ const auto &padding_after = op->getPaddingAfter();
+
+ auto *trans_in = insertTransposeBefore(input);
+
+ auto new_dw_conv = _graph->create(
+ trans_in, kernel, strides, padding_before, padding_after, _target_format);
+
+ auto *trans_out = insertTransposeAfter(new_dw_conv->getOutput(0));
+
+ _graph->replaceNode(op, trans_out->getNode());
+}
+
+void DataFormatSwitcher::switchPool(mir::ops::PoolOp *op)
+{
+ if (op->getDataFormat() == _target_format)
+ return;
+
+ auto *input = op->getInput(0)->getProducer();
+
+ const auto &window_shape = op->getWindowShape();
+ const auto &strides = op->getStrides();
+ const auto &padding_before = op->getPaddingBefore();
+ const auto &padding_after = op->getPaddingAfter();
+ const auto pooling_type = op->getPoolingType();
+ const auto border_type = op->getBorderType();
+
+ auto *trans_in = insertTransposeBefore(input);
+
+ auto new_pool =
+ _graph->create(trans_in, pooling_type, window_shape, strides,
+ padding_before, padding_after, border_type, _target_format);
+
+ auto *trans_out = insertTransposeAfter(new_pool->getOutput(0));
+
+ _graph->replaceNode(op, trans_out->getNode());
+}
+
+} // namespace nnc
--
2.7.4