#include "passes/acl_soft_backend/AclCppGenerator.h"
#include "passes/optimizations/CombineTransposes.h"
+#include "passes/optimizations/ConstantFoldTranspose.h"
#include "passes/optimizations/RemoveDeadEnds.h"
#include "passes/optimizations/FuseArithmeticOps.h"
#include "passes/optimizations/SinkRelu.h"
void Driver::registerOptimizationPass()
{
+ // TODO For now this optimization is mandatory. Make it optional when the ACL backend is able to
+ // handle all transposes that come from importers.
+ _passManager.registerPass(std::unique_ptr<Pass>(new ConstantFoldTranspose()));
+
if (cli::doOptimizationPass)
{
// TODO: maybe we should start managing the optimizations more intelligently?
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNCC_CONSTANT_FOLD_TRANSPOSE_H
+#define NNCC_CONSTANT_FOLD_TRANSPOSE_H
+
+#include "pass/Pass.h"
+
+namespace nnc
+{
+
+/**
+ * @brief Optimization pass that folds constant -> transpose edges.
+ *
+ * Wherever a ConstantOp feeds a TransposeOp, the transpose is applied to the
+ * constant data at compile time and the pair is replaced by a single
+ * pre-transposed ConstantOp (see the implementation of run()).
+ */
+class ConstantFoldTranspose : public Pass
+{
+public:
+  /**
+   * @brief Apply the folding to the given graph.
+   * @param data The graph to transform, passed as PassData.
+   * @return The same graph, with constant->transpose pairs folded.
+   */
+  PassData run(PassData data) override;
+
+  /// @brief Identifier of this pass.
+  std::string getName() override
+  {
+    static const std::string name("opt_constant_fold_transpose");
+    return name;
+  }
+};
+
+} // namespace nnc
+
+#endif // NNCC_CONSTANT_FOLD_TRANSPOSE_H
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "passes/optimizations/ConstantFoldTranspose.h"
+#include "passes/optimizations/OptimizationUtils.h"
+#include "mir/GraphPatternMatcher.h"
+#include "mir/ShapeRange.h"
+#include "mir/Tensor.h"
+#include "mir/ops/ConstantOp.h"
+#include "mir/ops/TransposeOp.h"
+
+using namespace nnc;
+using namespace mir;
+
+// Copy & paste from interpreter backend.
+// TODO Extract this to a common place and use in both interpreter and optimizations.
+static void transpose(const TensorVariant &arg, TensorVariant &res,
+ const std::vector<std::size_t> &axis_order)
+{
+ Tensor<float> arg_accessor(arg);
+ Tensor<float> res_accessor(res);
+
+ const auto &input_shape = arg.getShape();
+ const int num_axes = static_cast<int>(axis_order.size());
+ assert(num_axes == input_shape.rank());
+
+ ShapeRange in_range(input_shape);
+ Index out_index(input_shape.rank());
+
+ for (const auto &in_index : in_range)
+ {
+ for (int i = 0; i < num_axes; ++i)
+ {
+ out_index.at(i) = in_index.at(axis_order[i]);
+ }
+ res_accessor.at(out_index) = arg_accessor.at(in_index);
+ }
+}
+
+// Fold every constant -> transpose edge in the graph into a single
+// pre-transposed constant, repeating until no such edge remains.
+PassData ConstantFoldTranspose::run(PassData data)
+{
+  auto graph = static_cast<Graph *>(data);
+
+  // Predicates for the matcher: an edge whose producer is a ConstantOp and
+  // whose consumer is a TransposeOp.
+  GraphPatternMatcher matcher(graph);
+  auto is_constant = [](const Operation *op) { return op->getType() == Operation::Type::constant; };
+  auto is_transpose = [](const Operation *op) {
+    return op->getType() == Operation::Type::transpose;
+  };
+
+  // Iterate to a fixed point: folding one pair produces a new ConstantOp,
+  // which may itself feed another transpose (e.g. transpose chains), so the
+  // match is recomputed after each batch is rewritten.
+  auto matches = matcher.matchEdge(is_constant, is_transpose);
+  while (!matches.empty())
+  {
+    for (const auto match : matches)
+    {
+      // Casts are guarded by the predicates above; the matcher only returns
+      // (constant, transpose) pairs.
+      auto constant_op = dynamic_cast<ops::ConstantOp *>(match.first);
+      auto transpose_op = dynamic_cast<ops::TransposeOp *>(match.second);
+
+      // FIXME Revise this when we've got type information in operations.
+      TensorVariant res(DataType::FLOAT32, transpose_op->getOutputShape(0));
+      transpose(constant_op->getValue(), res, transpose_op->getAxisOrder());
+
+      auto new_op = graph->create<ops::ConstantOp>("", res);
+
+      // Redirect the transpose's consumers to the folded constant, then drop
+      // the original constant if nothing else still reads it.
+      graph->replaceNode(transpose_op, new_op);
+      opt_util::removeNodeIfUnused(graph, constant_op);
+    }
+    matches = matcher.matchEdge(is_constant, is_transpose);
+  }
+  return graph;
+}