[nnc core] Add Tensor class template (#395)
authorVladimir Plazun/AI Tools Lab /SRR/Engineer/삼성전자 <v.plazun@partner.samsung.com>
Mon, 2 Jul 2018 15:03:52 +0000 (19:03 +0400)
committerSergey Vostokov/AI Tools Lab /SRR/Staff Engineer/삼성전자 <s.vostokov@samsung.com>
Mon, 2 Jul 2018 15:03:52 +0000 (00:03 +0900)
[nnc core] Add Tensor class template

Used as a TensorVariant data accessor

Signed-off-by: Vladimir Plazun <v.plazun@partner.samsung.com>
contrib/nnc/libs/core/include/nnc/core/linalg/Tensor.h [new file with mode: 0644]
contrib/nnc/libs/core/src/core/linalg/Tensor.cpp [new file with mode: 0644]

diff --git a/contrib/nnc/libs/core/include/nnc/core/linalg/Tensor.h b/contrib/nnc/libs/core/include/nnc/core/linalg/Tensor.h
new file mode 100644 (file)
index 0000000..a211428
--- /dev/null
@@ -0,0 +1,69 @@
+#pragma once
+
+#include "nncc/core/ADT/tensor/Shape.h"
+#include "nncc/core/ADT/tensor/Accessor.h"
+#include "nncc/core/ADT/tensor/Reader.h"
+
+#include "nncc/foundation/ExternalRegion.h"
+
+#include "nnc/core/linalg/TensorVariant.h"
+
+namespace nncc
+{
+namespace contrib
+{
+namespace core
+{
+namespace data
+{
+
+using nncc::core::ADT::tensor::Shape;
+using nncc::core::ADT::tensor::Index;
+using nncc::foundation::ExternalRegion;
+using nncc::core::ADT::tensor::Accessor;
+using nncc::core::ADT::tensor::Reader;
+
+template<typename T>
+class Tensor final : public Accessor<T>, public Reader<T> {
+ public:
+  Tensor() = delete;
+
+  explicit Tensor(const ADT::TensorVariant &t) : _proxy(t), _shape(t.getShape()) {
+  }
+
+  T at(const Index &id) const override {
+    return *reinterpret_cast<T *>(this->_proxy.at(id));
+  }
+
+  T &at(const Index &id) override {
+    return *reinterpret_cast<T *>(this->_proxy.at(id));
+  }
+
+  ExternalRegion<T> getRegion(const Index& idx) {
+    //Only last dimension is safe to process continiously
+    auto lastDim = _shape.rank()- 1;
+    auto base = reinterpret_cast<T *>(_proxy.at(idx));
+    auto length = _shape.dim(lastDim) - idx.at(lastDim);
+    return ExternalRegion<T>(base, length);
+  }
+
+  virtual const Shape &getShape() const { return _proxy.getShape(); };
+
+ private:
+  const ADT::TensorVariant& _proxy;
+  const Shape &_shape;
+};
+
+extern template
+class Tensor<float>;
+
+extern template
+class Tensor<double>;
+
+extern template
+class Tensor<int>;
+
+} // namespace data
+} // namespace core
+} // namespace contrib
+} // namespace nncc
diff --git a/contrib/nnc/libs/core/src/core/linalg/Tensor.cpp b/contrib/nnc/libs/core/src/core/linalg/Tensor.cpp
new file mode 100644 (file)
index 0000000..c2dcebe
--- /dev/null
@@ -0,0 +1,19 @@
+#include "nnc/core/linalg/Tensor.h"
+
+namespace nncc
+{
+namespace contrib
+{
+namespace core
+{
+namespace data
+{
+
+template class Tensor<float>;
+template class Tensor<double>;
+template class Tensor<int>;
+
+} // namespace data
+} // namespace core
+} // namespace contrib
+} // namespace nncc