From 9eec9db9c405ef1e8b461ed367670e14ba5a8a57 Mon Sep 17 00:00:00 2001
From: Pavel Iliutchenko/AI Tools Lab /SRR/Engineer/Samsung Electronics
Date: Thu, 17 Oct 2019 21:43:36 +0300
Subject: [PATCH] [nnc] Fix constant folding optimization for quantization
(#8241)
* Fixed transpose for different DataType
* Added quantization setting on optimization
Signed-off-by: Pavel Iliutchenko <p.iliutchenko@samsung.com>
---
.../passes/optimizations/ConstantFoldTranspose.cpp | 26 +++++++++++++---------
1 file changed, 15 insertions(+), 11 deletions(-)
diff --git a/compiler/nnc/passes/optimizations/ConstantFoldTranspose.cpp b/compiler/nnc/passes/optimizations/ConstantFoldTranspose.cpp
index f23d5b2..47a3147 100644
--- a/compiler/nnc/passes/optimizations/ConstantFoldTranspose.cpp
+++ b/compiler/nnc/passes/optimizations/ConstantFoldTranspose.cpp
@@ -18,35 +18,34 @@
#include "passes/optimizations/OptimizationUtils.h"
#include "mir/GraphPatternMatcher.h"
#include "mir/ShapeRange.h"
-#include "mir/Tensor.h"
#include "mir/ops/ConstantOp.h"
#include "mir/ops/TransposeOp.h"
+#include <cstring>
+
using namespace nnc;
using namespace mir;
// Copy & paste from interpreter backend.
// TODO Extract this to a common place and use in both interpreter and optimizations.
-static void transpose(const TensorVariant &arg, TensorVariant &res,
+static void transpose(const TensorVariant &input, TensorVariant &res,
const std::vector<std::size_t> &axis_order)
{
- Tensor<float> arg_accessor(arg);
- Tensor<float> res_accessor(res);
-
- const auto &input_shape = arg.getShape();
+ const auto &input_shape = input.getShape();
const int num_axes = static_cast<int>(axis_order.size());
assert(num_axes == input_shape.rank());
ShapeRange in_range(input_shape);
Index out_index(input_shape.rank());
+ const size_t elem_size = input.getElementSize();
+
for (const auto &in_index : in_range)
{
for (int i = 0; i < num_axes; ++i)
- {
out_index.at(i) = in_index.at(axis_order[i]);
- }
- res_accessor.at(out_index) = arg_accessor.at(in_index);
+
+ std::memcpy(res.at(out_index), input.at(in_index), elem_size);
}
}
@@ -68,8 +67,13 @@ PassData ConstantFoldTranspose::run(PassData data)
auto constant_op = dynamic_cast<ops::ConstantOp *>(match.first);
auto transpose_op = dynamic_cast<ops::TransposeOp *>(match.second);
- // FIXME Revise this when we've got type information in operations.
- TensorVariant res(DataType::FLOAT32, transpose_op->getOutputShape(0));
+ const auto elem_type = constant_op->getValue().getElementType();
+ const auto &out_shape = transpose_op->getOutputShape(0);
+ TensorType res_type(elem_type, out_shape);
+ if (constant_op->getOutput(0)->getType().isQuantized())
+ res_type.setQuantization(constant_op->getOutput(0)->getType().getQuantization());
+
+ TensorVariant res(res_type);
transpose(constant_op->getValue(), res, transpose_op->getAxisOrder());
auto new_op = graph->create<ops::ConstantOp>(res);
--
2.7.4