[ncc/interpreter] Fix Reshape operation for quantization (#8246)
authorPavel Iliutchenko/AI Tools Lab /SRR/Engineer/Samsung Electronics <p.iliutchenk@samsung.com>
Thu, 17 Oct 2019 18:45:38 +0000 (21:45 +0300)
committerAlexander Efimov/./AI Tools Lab/Samsung Electronics <a.efimov@samsung.com>
Thu, 17 Oct 2019 18:45:38 +0000 (21:45 +0300)
* Made Reshape independent from DataType
* Fixed setting output quantization

Signed-off-by: Pavel Iliutchenko <p.iliutchenk@samsung.com>
compiler/nnc/backends/interpreter/Interpreter.cpp
compiler/nnc/backends/interpreter/ops/Reshape.h

index 00a32ac..b7e159a 100644 (file)
@@ -117,7 +117,7 @@ void NNInterpreter::visit(ops::MaxPool2DOp &op)
 void NNInterpreter::visit(ops::ReshapeOp &op)
 {
   auto inputs = getInputTensors(op);
-  auto outputs = Reshape<float>(inputs[0], op.getOutputShape(0))();
+  auto outputs = Reshape(inputs[0], op.getOutputShape(0));
   setOutputTensors(op, std::move(outputs));
 }
 
@@ -215,7 +215,7 @@ void NNInterpreter::visit(ops::SqueezeOp &op)
 {
   auto inputs = getInputTensors(op);
   // Squeeze is just a special case of reshape.
-  auto outputs = Reshape<float>(inputs[0], op.getOutputShape(0))();
+  auto outputs = Reshape(inputs[0], op.getOutputShape(0));
   setOutputTensors(op, std::move(outputs));
 }
 
index 1979f5c..83f0140 100644 (file)
 #ifndef _NNC_CORE_BACKEND_INTERPRETER_RESHAPE_IMPL_
 #define _NNC_CORE_BACKEND_INTERPRETER_RESHAPE_IMPL_
 
-#include "mir/ops/ReshapeOp.h"
+#include "mir/ShapeRange.h"
+#include "mir/TensorVariant.h"
 
-#include "OperationImpl.h"
-#include "Fill.h"
+#include <cstring>
 
 namespace nnc
 {
 
-template <typename T> class Reshape : public OperationImpl<T>
+std::vector<mir::TensorVariant> Reshape(const mir::TensorVariant &input,
+                                        const mir::Shape &output_shape)
 {
-public:
-  Reshape(const mir::TensorVariant &input, const mir::Shape &output_shape)
-      : _input(input), _output_shape(output_shape)
-  {
-
-    assert(input.getShape().numElements() == _output_shape.numElements());
-  }
-
-  std::vector<mir::TensorVariant> operator()() override
-  {
-    mir::ShapeRange inRange(_input.getShape());
-    auto inIter = inRange.begin();
-
-    auto out = OperationImpl<T>::allocate_tensor(_output_shape);
-
-    // Shapes element count compared in Reshape ctor
-    return Fill<T>(_output_shape,
-                   [this, &inIter](const mir::Index &) -> T { return _input.at(*inIter++); })();
-  }
-
-private:
-  mir::Tensor<T> _input;
-  const mir::Shape &_output_shape;
-};
+  assert(input.getShape().numElements() == output_shape.numElements());
+  mir::TensorType type(input.getElementType(), output_shape);
+  if (input.getType().isQuantized())
+    type.setQuantization(input.getType().getQuantization());
+
+  mir::TensorVariant result(type);
+  mir::ShapeRange input_range(input.getShape());
+  auto in_iter = input_range.begin();
+  const size_t elem_size = input.getElementSize();
+
+  for (const auto &out_index : mir::ShapeRange(output_shape))
+    std::memcpy(result.at(out_index), input.at(*in_iter++), elem_size);
+
+  return {result};
+}
 
 } // namespace nnc