void NNInterpreter::visit(ops::ReshapeOp &op)
{
auto inputs = getInputTensors(op);
- auto outputs = Reshape<float>(inputs[0], op.getOutputShape(0))();
+ auto outputs = Reshape(inputs[0], op.getOutputShape(0));
setOutputTensors(op, std::move(outputs));
}
{
auto inputs = getInputTensors(op);
// Squeeze is just a special case of reshape.
- auto outputs = Reshape<float>(inputs[0], op.getOutputShape(0))();
+ auto outputs = Reshape(inputs[0], op.getOutputShape(0));
setOutputTensors(op, std::move(outputs));
}
#ifndef _NNC_CORE_BACKEND_INTERPRETER_RESHAPE_IMPL_
#define _NNC_CORE_BACKEND_INTERPRETER_RESHAPE_IMPL_
-#include "mir/ops/ReshapeOp.h"
+#include "mir/ShapeRange.h"
+#include "mir/TensorVariant.h"
-#include "OperationImpl.h"
-#include "Fill.h"
+#include <cstring>
namespace nnc
{
-template <typename T> class Reshape : public OperationImpl<T>
+std::vector<mir::TensorVariant> Reshape(const mir::TensorVariant &input,
+ const mir::Shape &output_shape)
{
-public:
- Reshape(const mir::TensorVariant &input, const mir::Shape &output_shape)
- : _input(input), _output_shape(output_shape)
- {
-
- assert(input.getShape().numElements() == _output_shape.numElements());
- }
-
- std::vector<mir::TensorVariant> operator()() override
- {
- mir::ShapeRange inRange(_input.getShape());
- auto inIter = inRange.begin();
-
- auto out = OperationImpl<T>::allocate_tensor(_output_shape);
-
- // Shapes element count compared in Reshape ctor
- return Fill<T>(_output_shape,
- [this, &inIter](const mir::Index &) -> T { return _input.at(*inIter++); })();
- }
-
-private:
- mir::Tensor<T> _input;
- const mir::Shape &_output_shape;
-};
+ assert(input.getShape().numElements() == output_shape.numElements());
+ mir::TensorType type(input.getElementType(), output_shape);
+ if (input.getType().isQuantized())
+ type.setQuantization(input.getType().getQuantization());
+
+ mir::TensorVariant result(type);
+ mir::ShapeRange input_range(input.getShape());
+ auto in_iter = input_range.begin();
+ const size_t elem_size = input.getElementSize();
+
+ for (const auto &out_index : mir::ShapeRange(output_shape))
+ std::memcpy(result.at(out_index), input.at(*in_iter++), elem_size);
+
+ return {result};
+}
} // namespace nnc