Implement CAST operation for TENSOR_QUANT8_ASYMM in CPU fallback (#1609)
author서상민/동작제어Lab(SR)/Staff Engineer/삼성전자 <sangmin7.seo@samsung.com>
Mon, 11 Jun 2018 07:48:01 +0000 (16:48 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Mon, 11 Jun 2018 07:48:01 +0000 (16:48 +0900)
For issue #1431 and #1199

This patch implements CAST operation for TENSOR_QUANT8_ASYMM in CPU
fallback.  This implementation will be used for comparison with that of
the Pure CL runtime.

Signed-off-by: Sangmin Seo <sangmin7.seo@samsung.com>
runtimes/nn/common/CpuExecutor.cpp
runtimes/nn/common/OperationsUtils.cpp
runtimes/nn/common/include/OperationsUtils.h
runtimes/nn/common/include/Utils.h

index 8f9061d..6c97a63 100644 (file)
@@ -1416,7 +1416,11 @@ int CpuExecutor::executeOperation(const Operation &operation)
           success = castToOperand(reinterpret_cast<const int32_t *>(input.buffer), output.type,
                                   output.buffer, numberOfElements);
           break;
-        // TODO-NNRT implement other operands. (ex: TENSOR_QUANT8_ASYMM)
+        case OperandType::TENSOR_QUANT8_ASYMM:
+          success =
+              castQuant8ToOperand(reinterpret_cast<const uint8_t *>(input.buffer), input.scale,
+                                  input.zeroPoint, output.type, output.buffer, numberOfElements);
+          break;
         default:
           // Unsupported type.
           LOG(ERROR) << getOperandTypeName(input.type) << " is unsupported type.";
index 7b60f30..e08527c 100644 (file)
@@ -758,5 +758,39 @@ bool stridedSlicePrepare(const Shape &input, const int32_t *beginData, const Sha
   return true;
 }
 
+bool castQuant8ToOperand(const uint8_t *inputData, float scale, int32_t zeroPoint,
+                         const OperandType outType, uint8_t *outputData, int numElements)
+{
+  switch (outType)
+  {
+    case OperandType::FLOAT32:
+      copyQuant8Cast(inputData, inputData + 1, scale, zeroPoint,
+                     reinterpret_cast<float *>(outputData));
+      break;
+    case OperandType::INT32:
+      copyQuant8Cast(inputData, inputData + 1, scale, zeroPoint,
+                     reinterpret_cast<int32_t *>(outputData));
+      break;
+    case OperandType::UINT32:
+      copyQuant8Cast(inputData, inputData + 1, scale, zeroPoint,
+                     reinterpret_cast<uint32_t *>(outputData));
+      break;
+    case OperandType::TENSOR_FLOAT32:
+      copyQuant8Cast(inputData, inputData + numElements, scale, zeroPoint,
+                     reinterpret_cast<float *>(outputData));
+      break;
+    case OperandType::TENSOR_INT32:
+      copyQuant8Cast(inputData, inputData + numElements, scale, zeroPoint,
+                     reinterpret_cast<int32_t *>(outputData));
+      break;
+    // TODO-NNRT implement other operands. (ex: TENSOR_QUANT8_ASYMM)
+    default:
+      // Unsupported type.
+      LOG(ERROR) << getOperandTypeName(outType) << "is unsupported type.";
+      return false;
+  }
+  return true;
+}
+
 } // namespace rt
 } // namespace nnfw
index 682aeb9..74601e8 100644 (file)
@@ -223,6 +223,10 @@ bool castToOperand(const FromT *inputData, const OperandType outType, uint8_t *o
   return true;
 }
 
+// Cast operation from TENSOR_QUANT8_ASYMM to outType
+bool castQuant8ToOperand(const uint8_t *inputData, float scale, int32_t zeroPoint,
+                         const OperandType outType, uint8_t *outputData, int numElements);
+
 // TODO: add more documentation from upstream.
 // Reverse order of bits in the mask to match the expected order in kernel
 inline int ReverseMaskBits(int mask, int num_dimensions)
index caa4735..8d09bcd 100644 (file)
@@ -155,6 +155,15 @@ void copyCast(const FromT *inFirst, const FromT *inLast, ToT *out)
   std::transform(inFirst, inLast, out, [](FromT a) { return static_cast<ToT>(a); });
 }
 
+// Dequantize-and-cast helper for CAST on TENSOR_QUANT8_ASYMM input:
+// for each quantized value q in [inFirst, inLast), computes
+// real = scale * (q - zeroPoint) and writes static_cast<ToT>(real)
+// to the output range starting at out.
+template <typename ToT>
+void copyQuant8Cast(const uint8_t *inFirst, const uint8_t *inLast, float scale, int32_t zeroPoint,
+                    ToT *out)
+{
+  std::transform(inFirst, inLast, out, [scale, zeroPoint](uint8_t a) {
+    // a - zeroPoint is computed in int, then converted with a named
+    // cast (project style) instead of a C-style (float) cast.
+    return static_cast<ToT>(static_cast<float>(a - zeroPoint) * scale);
+  });
+}
+
 } // namespace rt
 } // namespace nnfw