IVGCVSW-3962 Return 0 for Neon GetExportFlags()
authorJames Conroy <james.conroy@arm.com>
Fri, 25 Oct 2019 08:44:14 +0000 (09:44 +0100)
committerMatteo Martincigh <matteo.martincigh@arm.com>
Fri, 25 Oct 2019 12:03:51 +0000 (12:03 +0000)
* Fixes issue where MemImport workload was being
  inserted into a graph when changing from a NEON
  to Ref workload. A MemCopy will now be performed
  instead.
* Improves existing ImportAlignedPointerTest by
  adding check for expected output.

Signed-off-by: James Conroy <james.conroy@arm.com>
Change-Id: I606dbbe0166731c62fbe4cc1966c558ade66d6bb

src/backends/backendsCommon/test/EndToEndTestImpl.hpp
src/backends/neon/NeonTensorHandleFactory.cpp
src/backends/neon/NeonTensorHandleFactory.hpp

index efaffb9b67f490fafea9b93377fbca52a2054a43..ee9d2bc026b778178f74d20f9e3cf47efc64f2cf 100644 (file)
@@ -369,6 +369,11 @@ inline void ImportAlignedPointerTest(std::vector<BackendId> backends)
 
     std::vector<float> outputData(4);
 
+    std::vector<float> expectedOutput
+    {
+        1.0f, 4.0f, 9.0f, 16.0f
+    };
+
     InputTensors inputTensors
     {
         {0,armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), inputData.data())},
@@ -378,8 +383,6 @@ inline void ImportAlignedPointerTest(std::vector<BackendId> backends)
         {0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data())}
     };
 
-    // The result of the inference is not important, just the fact that there
-    // should not be CopyMemGeneric workloads.
     runtime->GetProfiler(netId)->EnableProfiling(true);
 
     // Do the inference
@@ -394,12 +397,17 @@ inline void ImportAlignedPointerTest(std::vector<BackendId> backends)
     // Contains ActivationWorkload
     std::size_t found = dump.find("ActivationWorkload");
     BOOST_TEST(found != std::string::npos);
+
     // Contains SyncMemGeneric
     found = dump.find("SyncMemGeneric");
     BOOST_TEST(found != std::string::npos);
+
     // Does not contain CopyMemGeneric
     found = dump.find("CopyMemGeneric");
     BOOST_TEST(found == std::string::npos);
+
+    // Check output is as expected
+    BOOST_TEST(outputData == expectedOutput);
 }
 
 inline void ImportOnlyWorkload(std::vector<BackendId> backends)
index 4ccbb7b64f06f1545d025e90f2c53459aa53b3a9..80f46d2237437b17bbbac2e22625949f920f74b3 100644 (file)
@@ -60,7 +60,7 @@ std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateTensorHandle(const
     }
     // If we are not Managing the Memory then we must be importing
     tensorHandle->SetImportEnabledFlag(!IsMemoryManaged);
-    tensorHandle->SetImportFlags(m_ImportFlags);
+    tensorHandle->SetImportFlags(GetImportFlags());
 
     return tensorHandle;
 }
@@ -76,7 +76,7 @@ std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateTensorHandle(const
     }
     // If we are not Managing the Memory then we must be importing
     tensorHandle->SetImportEnabledFlag(!IsMemoryManaged);
-    tensorHandle->SetImportFlags(m_ImportFlags);
+    tensorHandle->SetImportFlags(GetImportFlags());
 
     return tensorHandle;
 }
@@ -99,12 +99,12 @@ bool NeonTensorHandleFactory::SupportsSubTensors() const
 
 MemorySourceFlags NeonTensorHandleFactory::GetExportFlags() const
 {
-    return m_ExportFlags;
+    return 0;
 }
 
 MemorySourceFlags NeonTensorHandleFactory::GetImportFlags() const
 {
-    return m_ImportFlags;
+    return static_cast<MemorySourceFlags>(MemorySource::Malloc);
 }
 
 } // namespace armnn
index d9b64045e6e865d9c6e4c15b177ea6e1dba49df6..8a8ac5cdcba9c18394b6c5862c74b8a9fc7f87cc 100644 (file)
@@ -17,9 +17,7 @@ class NeonTensorHandleFactory : public ITensorHandleFactory
 {
 public:
     NeonTensorHandleFactory(std::weak_ptr<NeonMemoryManager> mgr)
-                            : m_MemoryManager(mgr),
-                              m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)),
-                              m_ExportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc))
+                            : m_MemoryManager(mgr)
     {}
 
     std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
@@ -50,8 +48,6 @@ public:
 
 private:
     mutable std::shared_ptr<NeonMemoryManager> m_MemoryManager;
-    MemorySourceFlags m_ImportFlags;
-    MemorySourceFlags m_ExportFlags;
 };
 
 } // namespace armnn