IVGCVSW-5010 Add GetCapabilities to ITensorHandleFactory
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>
Wed, 22 Jul 2020 11:46:51 +0000 (12:46 +0100)
committerTeresaARM <teresa.charlinreyes@arm.com>
Thu, 23 Jul 2020 21:17:37 +0000 (21:17 +0000)
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: Ie8acb9c729af4f95488aecf795f45ff12364f9ca

include/armnn/backends/ITensorHandleFactory.hpp
src/backends/reference/test/RefTensorHandleTests.cpp

index cd094d2..9d8f0cd 100644 (file)
 namespace armnn
 {
 
+/// Capability class to calculate in the GetCapabilities function
+/// so that only the capability in the scope can be choose to calculate
+enum class CapabilityClass
+{
+    PaddingRequired = 1,
+
+    // add new enum values here
+
+    CapabilityClassMax = 254
+};
+
+/// Capability of the TensorHandleFactory
+struct Capability
+{
+    Capability(CapabilityClass capabilityClass, bool value)
+        : m_CapabilityClass(capabilityClass)
+        , m_Value(value)
+    {}
+
+    CapabilityClass m_CapabilityClass;
+    bool            m_Value;
+};
+
 class ITensorHandleFactory
 {
 public:
@@ -59,6 +82,16 @@ public:
 
     virtual MemorySourceFlags GetExportFlags() const { return 0; }
     virtual MemorySourceFlags GetImportFlags() const { return 0; }
+
+    virtual std::vector<Capability> GetCapabilities(const IConnectableLayer* layer,
+                                                    const IConnectableLayer* connectedLayer,
+                                                    CapabilityClass capabilityClass)
+    {
+        IgnoreUnused(layer);
+        IgnoreUnused(connectedLayer);
+        IgnoreUnused(capabilityClass);
+        return std::vector<Capability>();
+    }
 };
 
 enum class EdgeStrategy
index 42f5664..3635a32 100644 (file)
@@ -141,6 +141,23 @@ BOOST_AUTO_TEST_CASE(RefTensorHandleImport)
     BOOST_CHECK(buffer[1] == 10.0f);
 }
 
+BOOST_AUTO_TEST_CASE(RefTensorHandleGetCapabilities)
+{
+    std::shared_ptr<RefMemoryManager> memoryManager = std::make_shared<RefMemoryManager>();
+    RefTensorHandleFactory handleFactory(memoryManager);
+
+    // Builds up the structure of the network.
+    INetworkPtr net(INetwork::Create());
+    IConnectableLayer* input = net->AddInputLayer(0);
+    IConnectableLayer* output = net->AddOutputLayer(0);
+    input->GetOutputSlot(0).Connect(output->GetInputSlot(0));
+
+    std::vector<Capability> capabilities = handleFactory.GetCapabilities(input,
+                                                                         output,
+                                                                         CapabilityClass::PaddingRequired);
+    BOOST_CHECK(capabilities.empty());
+}
+
 #if !defined(__ANDROID__)
 // Only run these tests on non Android platforms
 BOOST_AUTO_TEST_CASE(CheckSourceType)