Correction in SpaceToBatch CPU version for quant8 datatype (#3686)
authorShubham Gupta/SNAP /SRI-Bangalore/Engineer/삼성전자 <shub98.gupta@samsung.com>
Fri, 30 Nov 2018 02:07:32 +0000 (07:37 +0530)
committer오형석/동작제어Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Fri, 30 Nov 2018 02:07:32 +0000 (11:07 +0900)
This patch will set correct zero_value for quant8 datatype in spaceToBatch.

Signed-off-by: shubham <shub98.gupta@samsung.com>
runtimes/pure_arm_compute/src/internal/layers/SimpleSpaceToBatchND.cc

index 2b221c8..4b040fa 100644 (file)
@@ -46,7 +46,7 @@ inline void
 SpaceToBatchND(const ::arm_compute::ITensor *input, const ::arm_compute::TensorShape &input_shape,
                const ::arm_compute::ITensor *block_size, const ::arm_compute::ITensor *padding_size,
                const ::arm_compute::ITensor *output, const ::arm_compute::TensorShape &output_shape,
-               const ::arm_compute::Coordinates &axises)
+               const ::arm_compute::Coordinates &axises, T zero_value)
 {
   const int input_batch = input_shape[axises[0]];
   const int input_height = input_shape[axises[1]];
@@ -97,7 +97,7 @@ SpaceToBatchND(const ::arm_compute::ITensor *input, const ::arm_compute::TensorS
           if (in_h < padding_height_left || in_h >= (input_height + padding_height_left) ||
               in_w < padding_width_left || in_w >= (input_width + padding_width_left))
           {
-            *reinterpret_cast<T *>(output->ptr_to_element(output_id)) = 0;
+            *reinterpret_cast<T *>(output->ptr_to_element(output_id)) = zero_value;
           }
           else
           {
@@ -130,11 +130,12 @@ void SimpleSpaceToBatchND::run()
     case ::arm_compute::DataType::U8:
     case ::arm_compute::DataType::QASYMM8:
       SpaceToBatchND<uint8_t>(_input, _input->info()->tensor_shape(), _block_size, _padding_size,
-                              _output, _output->info()->tensor_shape(), _axises);
+                              _output, _output->info()->tensor_shape(), _axises,
+                              _input->info()->quantization_info().offset);
       break;
     case ::arm_compute::DataType::F32:
       SpaceToBatchND<float>(_input, _input->info()->tensor_shape(), _block_size, _padding_size,
-                            _output, _output->info()->tensor_shape(), _axises);
+                            _output, _output->info()->tensor_shape(), _axises, 0.0f);
       break;
     default:
       ARM_COMPUTE_ERROR("DataType not supported");