Release 18.08
[platform/upstream/armnn.git] / src / armnn / backends / CpuTensorHandle.hpp
index 4bf4439..3376650 100644 (file)
@@ -9,10 +9,12 @@
 
 #include "OutputHandler.hpp"
 
+#include <algorithm>
+
 namespace armnn
 {
 
-// Abstract tensor handle wrapping a CPU-readable region of memory, interpreting it as tensor data.
+// Abstract tensor handles wrapping a CPU-readable region of memory, interpreting it as tensor data.
 class ConstCpuTensorHandle : public ITensorHandle
 {
 public:
@@ -33,6 +35,30 @@ public:
         return ITensorHandle::Cpu;
     }
 
+    virtual void Manage() override {}
+
+    virtual ITensorHandle* GetParent() const override { return nullptr; }
+
+    virtual const void* Map(bool /* blocking = true */) const override { return m_Memory; }
+    virtual void Unmap() const override {}
+
+    TensorShape GetStrides() const override
+    {
+        TensorShape shape(m_TensorInfo.GetShape());
+        auto size = GetDataTypeSize(m_TensorInfo.GetDataType());
+        auto runningSize = size;
+        std::vector<unsigned int> strides(shape.GetNumDimensions());
+        auto lastIdx = shape.GetNumDimensions()-1;
+        for (unsigned int i=0; i < lastIdx ; i++)
+        {
+            strides[lastIdx-i] = runningSize;
+            runningSize *= shape[lastIdx-i];
+        }
+        strides[0] = runningSize;
+        return TensorShape(shape.GetNumDimensions(), strides.data());
+    }
+    TensorShape GetShape() const override { return m_TensorInfo.GetShape(); }
+
 protected:
     ConstCpuTensorHandle(const TensorInfo& tensorInfo);
 
@@ -46,7 +72,7 @@ private:
     const void* m_Memory;
 };
 
-// Abstract specialization of ConstCpuTensorHandle that allows write access to the same data
+// Abstract specialization of ConstCpuTensorHandle that allows write access to the same data.
 class CpuTensorHandle : public ConstCpuTensorHandle
 {
 public:
@@ -79,9 +105,12 @@ class ScopedCpuTensorHandle : public CpuTensorHandle
 public:
     explicit ScopedCpuTensorHandle(const TensorInfo& tensorInfo);
 
-    // Copies contents from Tensor
+    // Copies contents from Tensor.
     explicit ScopedCpuTensorHandle(const ConstTensor& tensor);
 
+    // Copies contents from ConstCpuTensorHandle
+    explicit ScopedCpuTensorHandle(const ConstCpuTensorHandle& tensorHandle);
+
     ScopedCpuTensorHandle(const ScopedCpuTensorHandle& other);
     ScopedCpuTensorHandle& operator=(const ScopedCpuTensorHandle& other);
     ~ScopedCpuTensorHandle();
@@ -98,7 +127,7 @@ private:
 // Clients must make sure the passed in memory region stays alive for the lifetime of
 // the PassthroughCpuTensorHandle instance.
 //
-// Note there is no polymorphism to/from ConstPassthroughCpuTensorHandle
+// Note there is no polymorphism to/from ConstPassthroughCpuTensorHandle.
 class PassthroughCpuTensorHandle : public CpuTensorHandle
 {
 public:
@@ -117,7 +146,7 @@ public:
 // Clients must make sure the passed in memory region stays alive for the lifetime of
 // the PassthroughCpuTensorHandle instance.
 //
-// Note there is no polymorphism to/from PassthroughCpuTensorHandle
+// Note there is no polymorphism to/from PassthroughCpuTensorHandle.
 class ConstPassthroughCpuTensorHandle : public ConstCpuTensorHandle
 {
 public:
@@ -131,7 +160,7 @@ public:
 };
 
 
-// template specializations
+// Template specializations.
 
 template <>
 const void* ConstCpuTensorHandle::GetConstTensor() const;