From c6e5a6e9f146ecb95704d6fa80fae8465241f09e Mon Sep 17 00:00:00 2001 From: David Monahan Date: Wed, 2 Oct 2019 09:33:57 +0100 Subject: [PATCH] IVGCVSW-3925 Add Backward compatibility for ITensorHandle CreateTensorHandle functions Change-Id: I940b7ca706c9a8bc38743176eb7959aa629a6876 Signed-off-by: David Monahan --- src/armnn/test/TensorHandleStrategyTest.cpp | 12 ++++------- .../backendsCommon/ITensorHandleFactory.hpp | 23 ++++++++++++++++++++-- src/backends/cl/ClTensorHandleFactory.cpp | 11 +++++++++++ src/backends/cl/ClTensorHandleFactory.hpp | 5 +++++ src/backends/neon/NeonTensorHandleFactory.cpp | 11 +++++++++++ src/backends/neon/NeonTensorHandleFactory.hpp | 5 +++++ src/backends/reference/RefTensorHandleFactory.cpp | 9 +++------ src/backends/reference/RefTensorHandleFactory.hpp | 6 ++---- 8 files changed, 62 insertions(+), 20 deletions(-) diff --git a/src/armnn/test/TensorHandleStrategyTest.cpp b/src/armnn/test/TensorHandleStrategyTest.cpp index ceb6e4d..3c53b13 100644 --- a/src/armnn/test/TensorHandleStrategyTest.cpp +++ b/src/armnn/test/TensorHandleStrategyTest.cpp @@ -45,15 +45,13 @@ public: return nullptr; } - std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo, - const bool IsMemoryManaged) const override + std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo) const override { return nullptr; } std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo, - DataLayout dataLayout, - const bool IsMemoryManaged) const override + DataLayout dataLayout) const override { return nullptr; } @@ -85,15 +83,13 @@ public: return nullptr; } - std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo, - const bool IsMemoryManaged) const override + std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo) const override { return nullptr; } std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo, - DataLayout dataLayout, - const bool IsMemoryManaged) const override + DataLayout dataLayout) const override { return nullptr; } diff --git a/src/backends/backendsCommon/ITensorHandleFactory.hpp b/src/backends/backendsCommon/ITensorHandleFactory.hpp index c6deaef..2e47423 100644 --- a/src/backends/backendsCommon/ITensorHandleFactory.hpp +++ b/src/backends/backendsCommon/ITensorHandleFactory.hpp @@ -8,6 +8,9 @@ #include #include #include +#include "ITensorHandle.hpp" + +#include namespace armnn { @@ -25,12 +28,28 @@ public: TensorShape const& subTensorShape, unsigned int const* subTensorOrigin) const = 0; + virtual std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo) const = 0; + + virtual std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo, + DataLayout dataLayout) const = 0; + + // Utility Functions for backends which require TensorHandles to have unmanaged memory. + // These should be overloaded if required to facilitate direct import of input tensors + // and direct export of output tensors. virtual std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo, - const bool IsMemoryManaged = true) const = 0; + const bool IsMemoryManaged) const + { + boost::ignore_unused(IsMemoryManaged); + return CreateTensorHandle(tensorInfo); + } virtual std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout, - const bool IsMemoryManaged = true) const = 0; + const bool IsMemoryManaged) const + { + boost::ignore_unused(IsMemoryManaged); + return CreateTensorHandle(tensorInfo, dataLayout); + } virtual const FactoryId& GetId() const = 0; diff --git a/src/backends/cl/ClTensorHandleFactory.cpp b/src/backends/cl/ClTensorHandleFactory.cpp index 3d9908a..9df3f1a 100644 --- a/src/backends/cl/ClTensorHandleFactory.cpp +++ b/src/backends/cl/ClTensorHandleFactory.cpp @@ -45,6 +45,17 @@ std::unique_ptr ClTensorHandleFactory::CreateSubTensorHandle(ITen boost::polymorphic_downcast(&parent), shape, coords); } +std::unique_ptr ClTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const +{ + return ClTensorHandleFactory::CreateTensorHandle(tensorInfo, true); +} + +std::unique_ptr ClTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo, + DataLayout dataLayout) const +{ + return ClTensorHandleFactory::CreateTensorHandle(tensorInfo, dataLayout, true); +} + std::unique_ptr ClTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo, const bool IsMemoryManaged) const { diff --git a/src/backends/cl/ClTensorHandleFactory.hpp b/src/backends/cl/ClTensorHandleFactory.hpp index ea3728f..f0d427a 100644 --- a/src/backends/cl/ClTensorHandleFactory.hpp +++ b/src/backends/cl/ClTensorHandleFactory.hpp @@ -28,6 +28,11 @@ public: const TensorShape& subTensorShape, const unsigned int* subTensorOrigin) const override; + std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo) const override; + + std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo, + DataLayout dataLayout) const override; + std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo, const bool IsMemoryManaged = true) const override; diff --git a/src/backends/neon/NeonTensorHandleFactory.cpp b/src/backends/neon/NeonTensorHandleFactory.cpp index 8296b83..4ccbb7b 100644 --- a/src/backends/neon/NeonTensorHandleFactory.cpp +++ b/src/backends/neon/NeonTensorHandleFactory.cpp @@ -39,6 +39,17 @@ std::unique_ptr NeonTensorHandleFactory::CreateSubTensorHandle(IT boost::polymorphic_downcast(&parent), shape, coords); } +std::unique_ptr NeonTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const +{ + return NeonTensorHandleFactory::CreateTensorHandle(tensorInfo, true); +} + +std::unique_ptr NeonTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo, + DataLayout dataLayout) const +{ + return NeonTensorHandleFactory::CreateTensorHandle(tensorInfo, dataLayout, true); +} + std::unique_ptr NeonTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo, const bool IsMemoryManaged) const { diff --git a/src/backends/neon/NeonTensorHandleFactory.hpp b/src/backends/neon/NeonTensorHandleFactory.hpp index b034333..d9b6404 100644 --- a/src/backends/neon/NeonTensorHandleFactory.hpp +++ b/src/backends/neon/NeonTensorHandleFactory.hpp @@ -26,6 +26,11 @@ public: const TensorShape& subTensorShape, const unsigned int* subTensorOrigin) const override; + std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo) const override; + + std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo, + DataLayout dataLayout) const override; + std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo, const bool IsMemoryManaged = true) const override; diff --git a/src/backends/reference/RefTensorHandleFactory.cpp b/src/backends/reference/RefTensorHandleFactory.cpp index 089f5e3..c97a779 100644 --- a/src/backends/reference/RefTensorHandleFactory.cpp +++ b/src/backends/reference/RefTensorHandleFactory.cpp @@ -27,18 +27,15 @@ std::unique_ptr RefTensorHandleFactory::CreateSubTensorHandle(ITe return nullptr; } -std::unique_ptr RefTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo, - const bool IsMemoryManaged) const +std::unique_ptr RefTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const { - boost::ignore_unused(IsMemoryManaged); return std::make_unique(tensorInfo, m_MemoryManager, m_ImportFlags); } std::unique_ptr RefTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo, - DataLayout dataLayout, - const bool IsMemoryManaged) const + DataLayout dataLayout) const { - boost::ignore_unused(dataLayout, IsMemoryManaged); + boost::ignore_unused(dataLayout); return std::make_unique(tensorInfo, m_MemoryManager, m_ImportFlags); } diff --git a/src/backends/reference/RefTensorHandleFactory.hpp b/src/backends/reference/RefTensorHandleFactory.hpp index ca6af72..220e6fd 100644 --- a/src/backends/reference/RefTensorHandleFactory.hpp +++ b/src/backends/reference/RefTensorHandleFactory.hpp @@ -28,12 +28,10 @@ public: TensorShape const& subTensorShape, unsigned int const* subTensorOrigin) const override; - std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo, - const bool IsMemoryManaged = true) const override; + std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo) const override; std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo, - DataLayout dataLayout, - const bool IsMemoryManaged = true) const override; + DataLayout dataLayout) const override; static const FactoryId& GetIdStatic(); -- 2.7.4