Code for generic tensor shape in acl/src/shape.cpp (#1404)
author윤현식/동작제어Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Thu, 31 May 2018 05:26:22 +0000 (14:26 +0900)
committer서상민/동작제어Lab(SR)/Staff Engineer/삼성전자 <sangmin7.seo@samsung.com>
Thu, 31 May 2018 05:26:22 +0000 (14:26 +0900)
Parent issue: #1402

This code follows the explanation in https://arm-software.github.io/ComputeLibrary/latest/architecture.xhtml
```
Tensors are defined by a DataType plus a number of channels (Always expected to be 1 for now)
their dimensions are expressed as [width, height, feature_maps, batch].
In other words, the lower three dimensions of a tensor specify a single input in
[width, height, feature_maps], while any other specified dimension represents a batch
in the appropriate dimension space. For example, a tensor with dimensions [128, 128, 64, 16]
represents a 1D batch space with 16 batches of 128 elements in width and height
and 64 feature maps each.
```

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
libs/kernel/acl/src/Mul.h
libs/kernel/acl/src/Mul.test.data.h
libs/kernel/acl/src/Mul.test.h
libs/kernel/acl/src/shape.cpp

index 7eda5b3..376bac7 100644 (file)
@@ -52,9 +52,9 @@ bool mulFloat32(const float *inputData1, const nnfw::rt::Shape &inputShape1,
                 const float *inputData2, const nnfw::rt::Shape &inputShape2, int32_t activation,
                 float *outputData, const nnfw::rt::Shape &outputShape, sync_scheduler_f sync_func)
 {
-  auto input_shape1 = util::fromNNShape(inputShape1);
-  auto input_shape2 = util::fromNNShape(inputShape2);
-  auto output_shape = util::fromNNShape(outputShape);
+  auto input_shape1 = util::fromNNShape(inputShape1, false);
+  auto input_shape2 = util::fromNNShape(inputShape2, false);
+  auto output_shape = util::fromNNShape(outputShape, false);
 
   TensorT input1(arm_compute::TensorInfo(input_shape1, arm_compute::Format::F32));
   TensorT input2(arm_compute::TensorInfo(input_shape2, arm_compute::Format::F32));
@@ -82,23 +82,23 @@ bool mulFloat32(const float *inputData1, const nnfw::rt::Shape &inputShape1,
   {
     TensorAccess<VectorInputAccessor>(input1.ref(), inputData1, inputShape1);
   }
-  else if ((inputShape1.dimensions.size() <= 3))
+  else if ((inputShape1.dimensions.size() <= 4))
   {
     TensorAccess<MatrixInputAccessor>(input1.ref(), inputData1, inputShape1);
   }
   else
-    assert(inputShape1.dimensions.size() <= 3);
+    assert(inputShape1.dimensions.size() <= 4);
 
   if (inputShape2.dimensions.size() == 1)
   {
     TensorAccess<VectorInputAccessor>(input2.ref(), inputData2, inputShape2);
   }
-  else if ((inputShape2.dimensions.size() <= 3))
+  else if ((inputShape2.dimensions.size() <= 4))
   {
     TensorAccess<MatrixInputAccessor>(input2.ref(), inputData2, inputShape2);
   }
   else
-    assert(inputShape2.dimensions.size() <= 3);
+    assert(inputShape2.dimensions.size() <= 4);
 
   for (const auto &fn : fns)
   {
@@ -111,12 +111,12 @@ bool mulFloat32(const float *inputData1, const nnfw::rt::Shape &inputShape1,
   {
     TensorAccess<VectorOutputAccessor>(output.ref(), outputData, outputShape);
   }
-  else if ((outputShape.dimensions.size() <= 3))
+  else if ((outputShape.dimensions.size() <= 4))
   {
     TensorAccess<MatrixOutputAccessor>(output.ref(), outputData, outputShape);
   }
   else
-    assert(outputShape.dimensions.size() <= 3);
+    assert(outputShape.dimensions.size() <= 4);
 
   return true;
 }
index 9110dd9..617aafc 100644 (file)
@@ -189,6 +189,197 @@ static float expected2[2][4][6] = {
     },
 };
 
+// 3. elementwise-multiplying simple 4d x 1d
+static float x3[3][2][4][6] = {
+    {
+        {
+            {
+                3.4511616, 6.8213983, -3.9032097, 2.7949853, -2.4810624, -5.193684,
+            },
+            {
+                0.08306229, 1.8435066, 0.71155137, 0.57163835, 3.6964777, 0.8721923,
+            },
+            {
+                -1.3706003, -2.6251526, 6.111269, 3.9835145, 4.6476684, 1.7110837,
+            },
+            {
+                0.20046586, -9.296765, -0.38201705, -6.524978, -3.4010968, 0.8397062,
+            },
+        },
+        {
+            {
+                -1.1077878, -1.912447, 3.370302, -10.548304, -13.630229, 5.6805444,
+            },
+            {
+                0.1377167, 6.2926893, -1.3399599, -2.8300138, 4.136174, 5.701481,
+            },
+            {
+                1.8934447, -4.3057623, 5.4859633, 6.9906974, -2.743602, 0.0060951854,
+            },
+            {
+                5.2084804, 0.7307493, 0.041380864, 3.9017618, -2.9675317, 0.893882,
+            },
+        },
+    },
+    {
+        {
+            {
+                2.7736564, -4.883692, 2.724194, 3.2103822, -9.412777, -0.9386832,
+            },
+            {
+                -3.0339835, -8.912085, -8.830975, -2.013668, -2.9297779, 2.408302,
+            },
+            {
+                -2.0810814, -1.8345542, -2.1508193, -4.6043878, 3.0493782, 2.9568095,
+            },
+            {
+                -4.0822353, -4.3395967, 4.084664, 5.4317946, 4.325478, 3.6764784,
+            },
+        },
+        {
+            {
+                -2.5416138, 7.879944, -2.0247207, -1.1500132, -4.063577, 0.99201775,
+            },
+            {
+                -1.0726405, -6.3343916, 8.285111, 0.8598841, 1.5183163, 7.9626045,
+            },
+            {
+                -6.54306, 5.261826, 6.2108326, 0.6986546, 7.9932504, -8.734413,
+            },
+            {
+                -2.4487484, 4.8250856, -6.518466, -2.252397, 3.8628614, -1.763003,
+            },
+        },
+    },
+    {
+        {
+            {
+                -4.7506614, -6.3858204, 1.1295259, -9.164337, -4.920489, -4.547884,
+            },
+            {
+                12.074501, -1.0219653, 2.562501, 4.4603024, 11.272025, 1.2186266,
+            },
+            {
+                0.8928604, 5.2597437, -3.8921394, -1.4161685, -1.9687729, -3.1143188,
+            },
+            {
+                -1.3726944, -2.17756, 3.4230003, 2.4563243, -6.8160734, -1.6609626,
+            },
+        },
+        {
+            {
+                4.12099, 6.351284, -3.1314368, 2.039052, -5.265438, 0.085810244,
+            },
+            {
+                -2.081704, -2.0682046, -8.809668, -1.9327109, 11.162933, -5.473809,
+            },
+            {
+                6.6081295, -4.427154, -3.326314, 4.311129, -0.778096, -5.855744,
+            },
+            {
+                7.0940695, -5.3400326, -0.1266769, 0.20553468, 1.2000599, 0.098438516,
+            },
+        },
+    },
+};
+
+static float y3[6] = {
+    2.4239943, -2.4815967, -2.246438, 0.35926288, -0.39192855, 9.577583,
+};
+
+static float expected3[3][2][4][6] = {
+    {
+        {
+            {
+                8.365596, -16.92796, 8.768319, 1.0041345, 0.9723992, -49.742943,
+            },
+            {
+                0.20134252, -4.57484, -1.598456, 0.20536844, -1.4487551, 8.353495,
+            },
+            {
+                -3.3223274, 6.51457, -13.728587, 1.431129, -1.821554, 16.388046,
+            },
+            {
+                0.4859281, 23.070822, 0.8581776, -2.3441825, 1.332987, 8.042356,
+            },
+        },
+        {
+            {
+                -2.6852715, 4.745922, -7.5711746, -3.789614, 5.342076, 54.405888,
+            },
+            {
+                0.3338245, -15.615917, 3.0101368, -1.0167189, -1.6210848, 54.606407,
+            },
+            {
+                4.5896993, 10.685165, -12.323876, 2.5114982, 1.0752959, 0.058377147,
+            },
+            {
+                12.625327, -1.8134251, -0.092959546, 1.4017582, 1.1630604, 8.561229,
+            },
+        },
+    },
+    {
+        {
+            {
+                6.723327, 12.119353, -6.1197333, 1.1533712, 3.689136, -8.990316,
+            },
+            {
+                -7.3543587, 22.1162, 19.838238, -0.7234362, 1.1482636, 23.065714,
+            },
+            {
+                -5.0445294, 4.5526237, 4.831682, -1.6541857, -1.1951383, 28.31909,
+            },
+            {
+                -9.895315, 10.769129, -9.175944, 1.9514422, -1.6952784, 35.211777,
+            },
+        },
+        {
+            {
+                -6.160857, -19.554842, 4.5484095, -0.41315708, 1.5926319, 9.501133,
+            },
+            {
+                -2.6000745, 15.719405, -18.61199, 0.30892444, -0.5950715, 76.262505,
+            },
+            {
+                -15.86034, -13.05773, -13.9522505, 0.25100067, -3.1327832, -83.65457,
+            },
+            {
+                -5.935752, -11.973917, 14.64333, -0.8092027, -1.5139657, -16.885307,
+            },
+        },
+    },
+    {
+        {
+            {
+                -11.515576, 15.847031, -2.53741, -3.292406, 1.92848, -43.55774,
+            },
+            {
+                29.268522, 2.5361056, -5.7565, 1.602421, -4.4178286, 11.671498,
+            },
+            {
+                2.1642885, -13.052563, 8.74345, -0.5087768, 0.7716183, -29.827648,
+            },
+            {
+                -3.3274033, 5.4038258, -7.689558, 0.88246614, 2.671414, -15.908008,
+            },
+        },
+        {
+            {
+                9.989256, -15.761326, 7.034579, 0.7325557, 2.0636756, 0.82185477,
+            },
+            {
+                -5.0460386, 5.1324496, 19.790373, -0.6943513, -4.3750725, -52.42586,
+            },
+            {
+                16.018068, 10.986411, 7.472358, 1.5488287, 0.30495805, -56.083874,
+            },
+            {
+                17.195984, 13.251807, 0.2845718, 0.07384098, -0.47033775, 0.9428031,
+            },
+        },
+    },
+};
+
 } // end of data
 } // end of namespace elementwise_mul_test
 } // end of namespace nnfw
index 1bf0b37..303ad7d 100644 (file)
@@ -81,7 +81,7 @@ ACL_TEST(KernelACL_TC, mulFloat32_2x4x6_2x4x6)
   EXPECT_EQ(bret, true);
 }
 
-// when the sape of a, b of Mul(a, b) are different
+// when the shape of a, b of Mul(a, b) are different
 
 // Note: neon/Mul.test.h fails with this test.
 // Unlike cl/Mul.test.h, arm_compute::NEPixelWiseMultiplication.config() in neon/Mul.test.h fails
@@ -133,4 +133,25 @@ ACL_TEST(KernelACL_TC, mulFloat32_6x1x4_4)
   EXPECT_EQ(bret, true);
 }
 
+ACL_TEST(KernelACL_TC, mulFloat32_3x2x4x6_6)
+{
+  const nnfw::rt::Shape x3Shape = {OperandType::FLOAT32, {3, 2, 4, 6}, 1.0, 0};
+  const nnfw::rt::Shape y3Shape = {OperandType::FLOAT32, {6}, 1.0, 0};
+
+  float actual[3][2][4][6];
+  const nnfw::rt::Shape actualShape = {OperandType::FLOAT32, {3, 2, 4, 6}, 1.0, 0};
+  bool bret;
+
+  util::initData((float *)actual, sizeof(actual) / sizeof(actual[0]), 0.0);
+
+  bret = ACL_CORE_FUNC_NAME((float *)td::x3, x3Shape, (float *)td::y3, y3Shape,
+                            static_cast<int32_t>(FusedActivationFunc::NONE), (float *)actual,
+                            actualShape);
+
+  EXPECT_EQ(bret, true);
+
+  bret = util::compareData((float *)actual, (float *)td::expected3, actualShape);
+  EXPECT_EQ(bret, true);
+}
+
 #endif // GTEST_EXCLUDE_TEST
index a566a02..1b45c3a 100644 (file)
@@ -61,24 +61,38 @@ arm_compute::TensorShape fromMatrixNNShape(const nnfw::rt::Shape &shape)
 {
   assert(shape.dimensions.size() <= 4);
 
+  // in https://arm-software.github.io/ComputeLibrary/latest/architecture.xhtml
+  // sample code was written like "const TensorShape shape(width, height, batch);"
+  // also the above site mensions,
+  // "Tensors are defined by a DataType plus a number of channels (Always expected to be 1 for now)
+  // and their dimensions are expressed as [width, height, feature_maps, batch].
+  // In other words, the lower three dimensions of a tensor specify a single input in
+  // [width, height, feature_maps], while any other specified dimension represents a batch
+  // in the appropriate dimension space. For example, a tensor with dimensions [128, 128, 64, 16]
+  // represents a 1D batch space with 16 batches of 128 elements in width and height
+  // and 64 feature maps each.
+
   if (shape.dimensions.size() == 2)
   {
-    const uint32_t n = nnfw::rt::getSizeOfDimension(shape, 0);
-    const uint32_t c = nnfw::rt::getSizeOfDimension(shape, 1);
+    const uint32_t h = nnfw::rt::getSizeOfDimension(shape, 0);
+    const uint32_t w = nnfw::rt::getSizeOfDimension(shape, 1);
 
-    return arm_compute::TensorShape(c, n);
+    return arm_compute::TensorShape(w, h);
   }
   else if (shape.dimensions.size() == 3)
   {
-    return arm_compute::TensorShape(nnfw::rt::getSizeOfDimension(shape, 2),
-                                    nnfw::rt::getSizeOfDimension(shape, 1),
-                                    nnfw::rt::getSizeOfDimension(shape, 0));
+    const uint32_t w = nnfw::rt::getSizeOfDimension(shape, 2);
+    const uint32_t h = nnfw::rt::getSizeOfDimension(shape, 1);
+    const uint32_t feature_maps = nnfw::rt::getSizeOfDimension(shape, 0);
+    return arm_compute::TensorShape(w, h, feature_maps);
   }
   else if (shape.dimensions.size() == 4)
   {
-    return arm_compute::TensorShape(
-        nnfw::rt::getSizeOfDimension(shape, 3), nnfw::rt::getSizeOfDimension(shape, 2),
-        nnfw::rt::getSizeOfDimension(shape, 1), nnfw::rt::getSizeOfDimension(shape, 0));
+    const uint32_t batch = nnfw::rt::getSizeOfDimension(shape, 0);
+    const uint32_t feature_maps = nnfw::rt::getSizeOfDimension(shape, 1);
+    const uint32_t h = nnfw::rt::getSizeOfDimension(shape, 2);
+    const uint32_t w = nnfw::rt::getSizeOfDimension(shape, 3);
+    return arm_compute::TensorShape(w, h, feature_maps, batch);
   }
 }