From 9eec9db9c405ef1e8b461ed367670e14ba5a8a57 Mon Sep 17 00:00:00 2001 From: Pavel Iliutchenko/AI Tools Lab /SRR/Engineer/Samsung Electronics Date: Thu, 17 Oct 2019 21:43:36 +0300 Subject: [PATCH] [nnc] Fix constant folding optimization for quantization (#8241) * Fixed transpose for different DataType * Added quantization setting on optimization Signed-off-by: Pavel Iliutchenko --- .../passes/optimizations/ConstantFoldTranspose.cpp | 26 +++++++++++++--------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/compiler/nnc/passes/optimizations/ConstantFoldTranspose.cpp b/compiler/nnc/passes/optimizations/ConstantFoldTranspose.cpp index f23d5b2..47a3147 100644 --- a/compiler/nnc/passes/optimizations/ConstantFoldTranspose.cpp +++ b/compiler/nnc/passes/optimizations/ConstantFoldTranspose.cpp @@ -18,35 +18,34 @@ #include "passes/optimizations/OptimizationUtils.h" #include "mir/GraphPatternMatcher.h" #include "mir/ShapeRange.h" -#include "mir/Tensor.h" #include "mir/ops/ConstantOp.h" #include "mir/ops/TransposeOp.h" +#include <cstring> + using namespace nnc; using namespace mir; // Copy & paste from interpreter backend. // TODO Extract this to a common place and use in both interpreter and optimizations. 
-static void transpose(const TensorVariant &arg, TensorVariant &res, +static void transpose(const TensorVariant &input, TensorVariant &res, const std::vector<std::size_t> &axis_order) { - Tensor<float> arg_accessor(arg); - Tensor<float> res_accessor(res); - - const auto &input_shape = arg.getShape(); + const auto &input_shape = input.getShape(); const int num_axes = static_cast<int>(axis_order.size()); assert(num_axes == input_shape.rank()); ShapeRange in_range(input_shape); Index out_index(input_shape.rank()); + const size_t elem_size = input.getElementSize(); + for (const auto &in_index : in_range) { for (int i = 0; i < num_axes; ++i) - { out_index.at(i) = in_index.at(axis_order[i]); - } - res_accessor.at(out_index) = arg_accessor.at(in_index); + + std::memcpy(res.at(out_index), input.at(in_index), elem_size); } } @@ -68,8 +67,13 @@ PassData ConstantFoldTranspose::run(PassData data) auto constant_op = dynamic_cast<ops::ConstantOp *>(match.first); auto transpose_op = dynamic_cast<ops::TransposeOp *>(match.second); - // FIXME Revise this when we've got type information in operations. - TensorVariant res(DataType::FLOAT32, transpose_op->getOutputShape(0)); + const auto elem_type = constant_op->getValue().getElementType(); + const auto &out_shape = transpose_op->getOutputShape(0); + TensorType res_type(elem_type, out_shape); + if (constant_op->getOutput(0)->getType().isQuantized()) + res_type.setQuantization(constant_op->getOutput(0)->getType().getQuantization()); + + TensorVariant res(res_type); transpose(constant_op->getValue(), res, transpose_op->getAxisOrder()); auto new_op = graph->create<ops::ConstantOp>(res); -- 2.7.4