From d6568772be726acd7dc2fc3e592f101c77d690a6 Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Wed, 22 Jul 2020 12:46:51 +0100 Subject: [PATCH] IVGCVSW-5010 Add GetCapabilities to ITensorHandleFactory Signed-off-by: Narumol Prangnawarat Change-Id: Ie8acb9c729af4f95488aecf795f45ff12364f9ca --- include/armnn/backends/ITensorHandleFactory.hpp | 33 ++++++++++++++++++++++ .../reference/test/RefTensorHandleTests.cpp | 17 +++++++++++ 2 files changed, 50 insertions(+) diff --git a/include/armnn/backends/ITensorHandleFactory.hpp b/include/armnn/backends/ITensorHandleFactory.hpp index cd094d2..9d8f0cd 100644 --- a/include/armnn/backends/ITensorHandleFactory.hpp +++ b/include/armnn/backends/ITensorHandleFactory.hpp @@ -15,6 +15,29 @@ namespace armnn { +/// Capability class to calculate in the GetCapabilities function +/// so that only the capability in the scope can be choose to calculate +enum class CapabilityClass +{ + PaddingRequired = 1, + + // add new enum values here + + CapabilityClassMax = 254 +}; + +/// Capability of the TensorHandleFactory +struct Capability +{ + Capability(CapabilityClass capabilityClass, bool value) + : m_CapabilityClass(capabilityClass) + , m_Value(value) + {} + + CapabilityClass m_CapabilityClass; + bool m_Value; +}; + class ITensorHandleFactory { public: @@ -59,6 +82,16 @@ public: virtual MemorySourceFlags GetExportFlags() const { return 0; } virtual MemorySourceFlags GetImportFlags() const { return 0; } + + virtual std::vector GetCapabilities(const IConnectableLayer* layer, + const IConnectableLayer* connectedLayer, + CapabilityClass capabilityClass) + { + IgnoreUnused(layer); + IgnoreUnused(connectedLayer); + IgnoreUnused(capabilityClass); + return std::vector(); + } }; enum class EdgeStrategy diff --git a/src/backends/reference/test/RefTensorHandleTests.cpp b/src/backends/reference/test/RefTensorHandleTests.cpp index 42f5664..3635a32 100644 --- a/src/backends/reference/test/RefTensorHandleTests.cpp +++ b/src/backends/reference/test/RefTensorHandleTests.cpp @@ -141,6 +141,23 @@ BOOST_AUTO_TEST_CASE(RefTensorHandleImport) BOOST_CHECK(buffer[1] == 10.0f); } +BOOST_AUTO_TEST_CASE(RefTensorHandleGetCapabilities) +{ + std::shared_ptr memoryManager = std::make_shared(); + RefTensorHandleFactory handleFactory(memoryManager); + + // Builds up the structure of the network. + INetworkPtr net(INetwork::Create()); + IConnectableLayer* input = net->AddInputLayer(0); + IConnectableLayer* output = net->AddOutputLayer(0); + input->GetOutputSlot(0).Connect(output->GetInputSlot(0)); + + std::vector capabilities = handleFactory.GetCapabilities(input, + output, + CapabilityClass::PaddingRequired); + BOOST_CHECK(capabilities.empty()); +} + #if !defined(__ANDROID__) // Only run these tests on non Android platforms BOOST_AUTO_TEST_CASE(CheckSourceType) -- 2.7.4