IVGCVSW-5036 Do not allocate memory when import is enabled
author Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Fri, 26 Jun 2020 10:00:21 +0000 (11:00 +0100)
committer Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Tue, 30 Jun 2020 08:39:55 +0000 (08:39 +0000)
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: Ideaae5280702aae6c73f3b4e4cee9f71a8386fda

include/armnn/Exceptions.hpp
src/backends/backendsCommon/test/EndToEndTestImpl.hpp
src/backends/reference/RefTensorHandle.cpp
src/backends/reference/RefTensorHandle.hpp
src/backends/reference/RefTensorHandleFactory.cpp
src/backends/reference/RefTensorHandleFactory.hpp
src/backends/reference/test/RefTensorHandleTests.cpp

index 2f7b099..e3f086e 100644 (file)
@@ -149,6 +149,11 @@ public:
     using Exception::Exception;
 };
 
+class NullPointerException : public Exception
+{
+public:
+    using Exception::Exception;
+};
 
 template <typename ExceptionType>
 void ConditionalThrow(bool condition, const std::string& message)
index a4d0d50..0d16bcd 100644 (file)
@@ -719,6 +719,11 @@ inline void ExportOutputWithSeveralOutputSlotConnectionsTest(std::vector<Backend
     std::vector<float> outputData0(4);
     std::vector<float> outputData1(4);
 
+    std::vector<float> expectedOutput
+    {
+         1.0f, 4.0f, 9.0f, 16.0f
+    };
+
     InputTensors inputTensors
     {
         {0,armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), inputData.data())},
@@ -764,6 +769,12 @@ inline void ExportOutputWithSeveralOutputSlotConnectionsTest(std::vector<Backend
     // Contains CopyMemGeneric
     found = dump.find("CopyMemGeneric");
     BOOST_TEST(found != std::string::npos);
+
+    // Check that the outputs are correct
+    BOOST_CHECK_EQUAL_COLLECTIONS(outputData0.begin(), outputData0.end(),
+                                  expectedOutput.begin(), expectedOutput.end());
+    BOOST_CHECK_EQUAL_COLLECTIONS(outputData1.begin(), outputData1.end(),
+                                  expectedOutput.begin(), expectedOutput.end());
 }
 
 inline void StridedSliceInvalidSliceEndToEndTest(std::vector<BackendId> backends)
index 7d86b11..b9e566e 100644 (file)
@@ -13,19 +13,20 @@ RefTensorHandle::RefTensorHandle(const TensorInfo &tensorInfo, std::shared_ptr<R
     m_Pool(nullptr),
     m_UnmanagedMemory(nullptr),
     m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)),
-    m_Imported(false)
+    m_Imported(false),
+    m_IsImportEnabled(false)
 {
 
 }
 
-RefTensorHandle::RefTensorHandle(const TensorInfo& tensorInfo, std::shared_ptr<RefMemoryManager> &memoryManager,
+RefTensorHandle::RefTensorHandle(const TensorInfo& tensorInfo,
                                  MemorySourceFlags importFlags)
                                  : m_TensorInfo(tensorInfo),
-                                   m_MemoryManager(memoryManager),
                                    m_Pool(nullptr),
                                    m_UnmanagedMemory(nullptr),
                                    m_ImportFlags(importFlags),
-                                   m_Imported(false)
+                                   m_Imported(false),
+                                   m_IsImportEnabled(true)
 {
 
 }
@@ -44,31 +45,39 @@ RefTensorHandle::~RefTensorHandle()
 
 void RefTensorHandle::Manage()
 {
-    ARMNN_ASSERT_MSG(!m_Pool, "RefTensorHandle::Manage() called twice");
-    ARMNN_ASSERT_MSG(!m_UnmanagedMemory, "RefTensorHandle::Manage() called after Allocate()");
+    if (!m_IsImportEnabled)
+    {
+        ARMNN_ASSERT_MSG(!m_Pool, "RefTensorHandle::Manage() called twice");
+        ARMNN_ASSERT_MSG(!m_UnmanagedMemory, "RefTensorHandle::Manage() called after Allocate()");
 
-    m_Pool = m_MemoryManager->Manage(m_TensorInfo.GetNumBytes());
+        m_Pool = m_MemoryManager->Manage(m_TensorInfo.GetNumBytes());
+    }
 }
 
 void RefTensorHandle::Allocate()
 {
-    if (!m_UnmanagedMemory)
+    // If import is enabled, do not allocate the tensor
+    if (!m_IsImportEnabled)
     {
-        if (!m_Pool)
+
+        if (!m_UnmanagedMemory)
         {
-            // unmanaged
-            m_UnmanagedMemory = ::operator new(m_TensorInfo.GetNumBytes());
+            if (!m_Pool)
+            {
+                // unmanaged
+                m_UnmanagedMemory = ::operator new(m_TensorInfo.GetNumBytes());
+            }
+            else
+            {
+                m_MemoryManager->Allocate(m_Pool);
+            }
         }
         else
         {
-            m_MemoryManager->Allocate(m_Pool);
+            throw InvalidArgumentException("RefTensorHandle::Allocate Trying to allocate a RefTensorHandle"
+                                           "that already has allocated memory.");
         }
     }
-    else
-    {
-        throw InvalidArgumentException("RefTensorHandle::Allocate Trying to allocate a RefTensorHandle"
-                                       "that already has allocated memory.");
-    }
 }
 
 const void* RefTensorHandle::Map(bool /*unused*/) const
@@ -82,11 +91,14 @@ void* RefTensorHandle::GetPointer() const
     {
         return m_UnmanagedMemory;
     }
-    else
+    else if (m_Pool)
     {
-        ARMNN_ASSERT_MSG(m_Pool, "RefTensorHandle::GetPointer called on unmanaged, unallocated tensor handle");
         return m_MemoryManager->GetPointer(m_Pool);
     }
+    else
+    {
+        throw NullPointerException("RefTensorHandle::GetPointer called on unmanaged, unallocated tensor handle");
+    }
 }
 
 void RefTensorHandle::CopyOutTo(void* dest) const
@@ -105,10 +117,9 @@ void RefTensorHandle::CopyInFrom(const void* src)
 
 bool RefTensorHandle::Import(void* memory, MemorySource source)
 {
-
     if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
     {
-        if (source == MemorySource::Malloc)
+        if (m_IsImportEnabled && source == MemorySource::Malloc)
         {
             // Check memory alignment
             constexpr uintptr_t alignment = sizeof(size_t);
index 6cde326..8c64dfb 100644 (file)
@@ -17,8 +17,7 @@ class RefTensorHandle : public ITensorHandle
 public:
     RefTensorHandle(const TensorInfo& tensorInfo, std::shared_ptr<RefMemoryManager> &memoryManager);
 
-    RefTensorHandle(const TensorInfo& tensorInfo, std::shared_ptr<RefMemoryManager> &memoryManager,
-                    MemorySourceFlags importFlags);
+    RefTensorHandle(const TensorInfo& tensorInfo, MemorySourceFlags importFlags);
 
     ~RefTensorHandle();
 
@@ -73,9 +72,10 @@ private:
 
     std::shared_ptr<RefMemoryManager> m_MemoryManager;
     RefMemoryManager::Pool* m_Pool;
-    mutable void *m_UnmanagedMemory;
+    mutable void* m_UnmanagedMemory;
     MemorySourceFlags m_ImportFlags;
     bool m_Imported;
+    bool m_IsImportEnabled;
 };
 
 }
index d687c78..ade27dd 100644 (file)
@@ -29,14 +29,42 @@ std::unique_ptr<ITensorHandle> RefTensorHandleFactory::CreateSubTensorHandle(ITe
 
 std::unique_ptr<ITensorHandle> RefTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const
 {
-    return std::make_unique<RefTensorHandle>(tensorInfo, m_MemoryManager, m_ImportFlags);
+    return std::make_unique<RefTensorHandle>(tensorInfo, m_MemoryManager);
 }
 
 std::unique_ptr<ITensorHandle> RefTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
                                                                           DataLayout dataLayout) const
 {
     IgnoreUnused(dataLayout);
-    return std::make_unique<RefTensorHandle>(tensorInfo, m_MemoryManager, m_ImportFlags);
+    return std::make_unique<RefTensorHandle>(tensorInfo, m_MemoryManager);
+}
+
+std::unique_ptr<ITensorHandle> RefTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
+                                                                          const bool IsMemoryManaged) const
+{
+    if (IsMemoryManaged)
+    {
+        return std::make_unique<RefTensorHandle>(tensorInfo, m_MemoryManager);
+    }
+    else
+    {
+        return std::make_unique<RefTensorHandle>(tensorInfo, m_ImportFlags);
+    }
+}
+
+std::unique_ptr<ITensorHandle> RefTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
+                                                                          DataLayout dataLayout,
+                                                                          const bool IsMemoryManaged) const
+{
+    IgnoreUnused(dataLayout);
+    if (IsMemoryManaged)
+    {
+        return std::make_unique<RefTensorHandle>(tensorInfo, m_MemoryManager);
+    }
+    else
+    {
+        return std::make_unique<RefTensorHandle>(tensorInfo, m_ImportFlags);
+    }
 }
 
 const FactoryId& RefTensorHandleFactory::GetId() const
index 8ea02f5..7e572d7 100644 (file)
@@ -33,6 +33,13 @@ public:
     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
                                                       DataLayout dataLayout) const override;
 
+    std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
+                                                      const bool IsMemoryManaged) const override;
+
+    std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
+                                                      DataLayout dataLayout,
+                                                      const bool IsMemoryManaged) const override;
+
     static const FactoryId& GetIdStatic();
 
     const FactoryId& GetId() const override;
index be229bf..42f5664 100644 (file)
@@ -3,6 +3,7 @@
 // SPDX-License-Identifier: MIT
 //
 #include <reference/RefTensorHandle.hpp>
+#include <reference/RefTensorHandleFactory.hpp>
 
 #include <boost/test/unit_test.hpp>
 
@@ -13,7 +14,7 @@ BOOST_AUTO_TEST_CASE(AcquireAndRelease)
 {
     std::shared_ptr<RefMemoryManager> memoryManager = std::make_shared<RefMemoryManager>();
 
-    TensorInfo info({1,1,1,1}, DataType::Float32);
+    TensorInfo info({ 1, 1, 1, 1 }, DataType::Float32);
     RefTensorHandle handle(info, memoryManager);
 
     handle.Manage();
@@ -21,7 +22,7 @@ BOOST_AUTO_TEST_CASE(AcquireAndRelease)
 
     memoryManager->Acquire();
     {
-        float *buffer = reinterpret_cast<float *>(handle.Map());
+        float* buffer = reinterpret_cast<float*>(handle.Map());
 
         BOOST_CHECK(buffer != nullptr); // Yields a valid pointer
 
@@ -34,7 +35,7 @@ BOOST_AUTO_TEST_CASE(AcquireAndRelease)
 
     memoryManager->Acquire();
     {
-        float *buffer = reinterpret_cast<float *>(handle.Map());
+        float* buffer = reinterpret_cast<float*>(handle.Map());
 
         BOOST_CHECK(buffer != nullptr); // Yields a valid pointer
 
@@ -45,14 +46,107 @@ BOOST_AUTO_TEST_CASE(AcquireAndRelease)
     memoryManager->Release();
 }
 
+BOOST_AUTO_TEST_CASE(RefTensorHandleFactoryMemoryManaged)
+{
+    std::shared_ptr<RefMemoryManager> memoryManager = std::make_shared<RefMemoryManager>();
+    RefTensorHandleFactory handleFactory(memoryManager);
+    TensorInfo info({ 1, 1, 2, 1 }, DataType::Float32);
+
+    // create TensorHandle with memory managed
+    auto handle = handleFactory.CreateTensorHandle(info, true);
+    handle->Manage();
+    handle->Allocate();
+
+    memoryManager->Acquire();
+    {
+        float* buffer = reinterpret_cast<float*>(handle->Map());
+        BOOST_CHECK(buffer != nullptr); // Yields a valid pointer
+        buffer[0] = 1.5f;
+        buffer[1] = 2.5f;
+        BOOST_CHECK(buffer[0] == 1.5f); // Memory is writable and readable
+        BOOST_CHECK(buffer[1] == 2.5f); // Memory is writable and readable
+    }
+    memoryManager->Release();
+
+    memoryManager->Acquire();
+    {
+        float* buffer = reinterpret_cast<float*>(handle->Map());
+        BOOST_CHECK(buffer != nullptr); // Yields a valid pointer
+        buffer[0] = 3.5f;
+        buffer[1] = 4.5f;
+        BOOST_CHECK(buffer[0] == 3.5f); // Memory is writable and readable
+        BOOST_CHECK(buffer[1] == 4.5f); // Memory is writable and readable
+    }
+    memoryManager->Release();
+
+    float testPtr[2] = { 2.5f, 5.5f };
+    // Cannot import as import is disabled
+    BOOST_CHECK(!handle->Import(static_cast<void*>(testPtr), MemorySource::Malloc));
+}
+
+BOOST_AUTO_TEST_CASE(RefTensorHandleFactoryImport)
+{
+    std::shared_ptr<RefMemoryManager> memoryManager = std::make_shared<RefMemoryManager>();
+    RefTensorHandleFactory handleFactory(memoryManager);
+    TensorInfo info({ 1, 1, 2, 1 }, DataType::Float32);
+
+    // create TensorHandle without memory managed
+    auto handle = handleFactory.CreateTensorHandle(info, false);
+    handle->Manage();
+    handle->Allocate();
+    memoryManager->Acquire();
+
+    // No buffer allocated when import is enabled
+    BOOST_CHECK_THROW(handle->Map(), armnn::NullPointerException);
+
+    float testPtr[2] = { 2.5f, 5.5f };
+    // Correctly import
+    BOOST_CHECK(handle->Import(static_cast<void*>(testPtr), MemorySource::Malloc));
+    float* buffer = reinterpret_cast<float*>(handle->Map());
+    BOOST_CHECK(buffer != nullptr); // Yields a valid pointer after import
+    BOOST_CHECK(buffer == testPtr); // buffer is pointing to testPtr
+    // Memory is writable and readable with correct value
+    BOOST_CHECK(buffer[0] == 2.5f);
+    BOOST_CHECK(buffer[1] == 5.5f);
+    buffer[0] = 3.5f;
+    buffer[1] = 10.0f;
+    BOOST_CHECK(buffer[0] == 3.5f);
+    BOOST_CHECK(buffer[1] == 10.0f);
+    memoryManager->Release();
+}
+
+BOOST_AUTO_TEST_CASE(RefTensorHandleImport)
+{
+    TensorInfo info({ 1, 1, 2, 1 }, DataType::Float32);
+    RefTensorHandle handle(info, static_cast<unsigned int>(MemorySource::Malloc));
+
+    handle.Manage();
+    handle.Allocate();
+
+    // No buffer allocated when import is enabled
+    BOOST_CHECK_THROW(handle.Map(), armnn::NullPointerException);
+
+    float testPtr[2] = { 2.5f, 5.5f };
+    // Correctly import
+    BOOST_CHECK(handle.Import(static_cast<void*>(testPtr), MemorySource::Malloc));
+    float* buffer = reinterpret_cast<float*>(handle.Map());
+    BOOST_CHECK(buffer != nullptr); // Yields a valid pointer after import
+    BOOST_CHECK(buffer == testPtr); // buffer is pointing to testPtr
+    // Memory is writable and readable with correct value
+    BOOST_CHECK(buffer[0] == 2.5f);
+    BOOST_CHECK(buffer[1] == 5.5f);
+    buffer[0] = 3.5f;
+    buffer[1] = 10.0f;
+    BOOST_CHECK(buffer[0] == 3.5f);
+    BOOST_CHECK(buffer[1] == 10.0f);
+}
+
 #if !defined(__ANDROID__)
 // Only run these tests on non Android platforms
 BOOST_AUTO_TEST_CASE(CheckSourceType)
 {
-    std::shared_ptr<RefMemoryManager> memoryManager = std::make_shared<RefMemoryManager>();
-
     TensorInfo info({1}, DataType::Float32);
-    RefTensorHandle handle(info, memoryManager, static_cast<unsigned int>(MemorySource::Malloc));
+    RefTensorHandle handle(info, static_cast<unsigned int>(MemorySource::Malloc));
 
     int* testPtr = new int(4);
 
@@ -70,10 +164,8 @@ BOOST_AUTO_TEST_CASE(CheckSourceType)
 
 BOOST_AUTO_TEST_CASE(ReusePointer)
 {
-    std::shared_ptr<RefMemoryManager> memoryManager = std::make_shared<RefMemoryManager>();
-
     TensorInfo info({1}, DataType::Float32);
-    RefTensorHandle handle(info, memoryManager, static_cast<unsigned int>(MemorySource::Malloc));
+    RefTensorHandle handle(info, static_cast<unsigned int>(MemorySource::Malloc));
 
     int* testPtr = new int(4);
 
@@ -87,10 +179,8 @@ BOOST_AUTO_TEST_CASE(ReusePointer)
 
 BOOST_AUTO_TEST_CASE(MisalignedPointer)
 {
-    std::shared_ptr<RefMemoryManager> memoryManager = std::make_shared<RefMemoryManager>();
-
     TensorInfo info({2}, DataType::Float32);
-    RefTensorHandle handle(info, memoryManager, static_cast<unsigned int>(MemorySource::Malloc));
+    RefTensorHandle handle(info, static_cast<unsigned int>(MemorySource::Malloc));
 
     // Allocate a 2 int array
     int* testPtr = new int[2];