IVGCVSW-5085 Updates to CL and NEON TensorHandleFactory
authorDavid Monahan <david.monahan@arm.com>
Tue, 21 Jul 2020 10:16:51 +0000 (11:16 +0100)
committerDavid Monahan <david.monahan@arm.com>
Wed, 29 Jul 2020 16:49:20 +0000 (16:49 +0000)
 * Update the CL and Neon TensorHandleFactories to not use SubTensors if
   Axis is on x or y

Signed-off-by: David Monahan <david.monahan@arm.com>
Change-Id: I782b89f50a92b21fdcbe68dab0281ad265fb3b63

src/backends/cl/ClTensorHandleFactory.cpp
src/backends/neon/NeonTensorHandleFactory.cpp

index 8af97f4..e92913f 100644 (file)
@@ -36,6 +36,18 @@ std::unique_ptr<ITensorHandle> ClTensorHandleFactory::CreateSubTensorHandle(ITen
 
     const arm_compute::TensorShape parentShape = armcomputetensorutils::BuildArmComputeTensorShape(
             parent.GetShape());
+
+    // In order for ACL to support subtensors the concat axis cannot be on x or y and the values of x and y
+    // must match the parent shapes
+    if (coords.x() != 0 || coords.y() != 0)
+    {
+        return nullptr;
+    }
+    if ((parentShape.x() != shape.x()) || (parentShape.y() != shape.y()))
+    {
+        return nullptr;
+    }
+
     if (!::arm_compute::error_on_invalid_subtensor(__func__, __FILE__, __LINE__, parentShape, coords, shape))
     {
         return nullptr;
index ec9e063..4e013a3 100644 (file)
@@ -33,6 +33,18 @@ std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateSubTensorHandle(IT
     }
 
     const arm_compute::TensorShape parentShape = armcomputetensorutils::BuildArmComputeTensorShape(parent.GetShape());
+
+    // In order for ACL to support subtensors the concat axis cannot be on x or y and the values of x and y
+    // must match the parent shapes
+    if (coords.x() != 0 || coords.y() != 0)
+    {
+        return nullptr;
+    }
+    if ((parentShape.x() != shape.x()) || (parentShape.y() != shape.y()))
+    {
+        return nullptr;
+    }
+
     if (!::arm_compute::error_on_invalid_subtensor(__func__, __FILE__, __LINE__, parentShape, coords, shape))
     {
         return nullptr;