--- /dev/null
+#ifndef _NNC_CORE_LINALG_TENSOR_VARIANT_H_
+#define _NNC_CORE_LINALG_TENSOR_VARIANT_H_
+
+#include <utility>
+#include <memory>
+#include <cassert>
+
+#include "nncc/core/ADT/tensor/Index.h"
+#include "nncc/core/ADT/tensor/Shape.h"
+
+namespace nncc {
+namespace contrib {
+namespace core {
+namespace ADT {
+
+using nncc::core::ADT::tensor::Shape;
+using nncc::core::ADT::tensor::Index;
+
+constexpr int MAX_DIMENSIONS = 32;
+
+class TensorVariant {
+public:
+ enum class DTYPE {
+ UNKNOWN,
+ FLOAT,
+ INT
+ };
+
+ explicit TensorVariant(const Shape& shape, const std::shared_ptr<char>& data, DTYPE dtype, size_t element_size);
+
+ template<typename T>
+ explicit TensorVariant(const Shape& shape, const std::shared_ptr<T>& data, DTYPE dtype) :
+ TensorVariant(
+ shape,
+ std::shared_ptr<char>(data, (char*)data.get()),
+ dtype,
+ sizeof(typename std::remove_extent<T>::type))
+ {
+ }
+ ~TensorVariant() = default;
+
+ char *at(const Index &idx) const;
+ size_t getOffset(const Index &idx) const;
+
+ virtual const Shape &getShape() const { return _shape; }
+ const DTYPE getDataType() const { return _dtype; }
+ size_t getElementSize() const { return _element_size; }
+
+ private:
+ const DTYPE _dtype;
+ const std::shared_ptr<char> _data;
+ uint_fast32_t _strides[MAX_DIMENSIONS];
+ size_t _rank;
+ Shape _shape;
+
+ const size_t _element_size;
+};
+
+} // namespace ADT
+} // namespace core
+} // namespace contrib
+} // namespace nncc
+
+#endif //_NNC_CORE_LINALG_TENSOR_VARIANT_H_
--- /dev/null
+#include "nnc/core/linalg/TensorVariant.h"
+
+namespace nncc
+{
+namespace contrib
+{
+namespace core
+{
+namespace ADT
+{
+
+TensorVariant::TensorVariant(const Shape& shape, const std::shared_ptr<char>& data, TensorVariant::DTYPE dtype, size_t element_size)
+ : _shape(shape), _data(data), _dtype(dtype), _strides{0}, _element_size(element_size)
+{
+ int stride = 1;
+ _rank = _shape.rank();
+ for (int d = _rank - 1; d >= 0; --d)
+ {
+ _strides[d] = stride;
+ stride *= _shape.dim(d);
+ }
+}
+
+char *TensorVariant::at(const Index &idx) const
+{
+ return _data.get() + getOffset(idx) * _element_size;
+}
+
+size_t TensorVariant::getOffset(const Index &idx) const {
+ assert(idx.rank() == getShape().rank());
+ std::size_t offset = 0;
+ for (size_t i = 0; i < _rank; ++i)
+ {
+ offset += idx.at(i) * _strides[i];
+ }
+ return offset;
+}
+
+} // namespace ADT
+} // namespace core
+} // namespace contrib
+} // namespace nncc
--- /dev/null
+#include "nnc/core/linalg/TensorVariant.h"
+#include "nncc/core/ADT/feature/Shape.h"
+
+#include <gtest/gtest.h>
+
+using namespace nncc::contrib::core::ADT;
+using namespace nncc::core::ADT::tensor;
+
+TEST(TensorVariant, BasicTest) {
+ Shape shape{2,2};
+ char* ptr = (char*)(new float[4]);
+ std::shared_ptr<char> mem(ptr, [](char* d){ delete[] (float*)d; } );
+
+ TensorVariant t(shape, mem, TensorVariant::DTYPE::FLOAT, sizeof(float));
+
+ ASSERT_EQ(t.getShape(), shape);
+ ASSERT_EQ(t.getOffset({0,0}), 0);
+}
+
+TEST(TensorVariant, ElementSizeDeductionTest) {
+ Shape shape{2, 2, 2};
+
+ std::shared_ptr<float> mem(new float[8], [](float* f){ delete[] f; });
+
+ TensorVariant t(shape, mem, TensorVariant::DTYPE::FLOAT);
+
+ ASSERT_EQ(t.getElementSize(), sizeof(float));
+ ASSERT_EQ((float*)t.at({1,1,1}), mem.get() + 7);
+}
+
+TEST(TensorVariant, DeletionTest) {
+ struct Indicator {
+ Indicator() : val(true) {
+ }
+
+ ~Indicator() {
+ val = false;
+ }
+
+ bool val;
+ };
+
+ TensorVariant* t;
+ auto raw_indicator = new Indicator[1];
+ {
+ Shape shape{1,1};
+ auto mem = std::shared_ptr<Indicator>(raw_indicator, [](Indicator*& p){ delete[] p; });
+ t = new TensorVariant(shape, mem, TensorVariant::DTYPE::UNKNOWN);
+ //mem gets destroyed here
+ }
+
+ ASSERT_EQ(raw_indicator->val, true);
+ delete t;
+ ASSERT_EQ(raw_indicator->val, false);
+}