COMPMID-3100 Fuse bias addition with fully connected layer NEON
author    SiCong Li <sicong.li@arm.com>
          Mon, 17 Feb 2020 16:39:27 +0000 (16:39 +0000)
committer Giorgio Arena <giorgio.arena@arm.com>
          Tue, 3 Mar 2020 09:55:55 +0000 (09:55 +0000)
NEGEMM and NEGEMMLowpMatrixMultiplyCore already support fused bias
addition. Expose this fusion to NEFullyConnectedLayer.
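
For illustration, a minimal usage sketch (shapes, tensor names and the single-batch setup are hypothetical, not taken from the patch): the bias tensor handed to NEFullyConnectedLayer::configure() is now forwarded to the underlying NEGEMM (or NEGEMMLowpMatrixMultiplyCore for quantized types) instead of being applied by a separate accumulate-biases kernel.

    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/core/Types.h"
    #include "arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h"
    #include "arm_compute/runtime/Tensor.h"

    using namespace arm_compute;

    int main()
    {
        // Hypothetical FP32 layer: 128 inputs, 16 outputs, no batches.
        Tensor src{}, weights{}, bias{}, dst{};
        src.allocator()->init(TensorInfo(TensorShape(128U), 1, DataType::F32));
        weights.allocator()->init(TensorInfo(TensorShape(128U, 16U), 1, DataType::F32));
        bias.allocator()->init(TensorInfo(TensorShape(16U), 1, DataType::F32));
        dst.allocator()->init(TensorInfo(TensorShape(16U), 1, DataType::F32));

        NEFullyConnectedLayer fc{};
        // The bias is fused into the GEMM; no separate bias-accumulation kernel is scheduled.
        fc.configure(&src, &weights, &bias, &dst);

        src.allocator()->allocate();
        weights.allocator()->allocate();
        bias.allocator()->allocate();
        dst.allocator()->allocate();
        // ... fill src, weights and bias here ...
        fc.run();
        return 0;
    }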

Change-Id: I42a909565bf49de1a019a07dc4dca11ae0981ada
Signed-off-by: SiCongLi <sicong.li@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2769
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
examples/graph_deepspeech_v0_4_1.cpp
src/runtime/NEON/functions/NEFullyConnectedLayer.cpp

arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
index 78f12daf9cd354321eec759b4ec7eb1cc9d10420..db09da45eec479d50da312c3037ca0facd4e7d81 100644 (file)
 #include "arm_compute/runtime/IFunction.h"
 
 #include "arm_compute/core/NEON/kernels/NEFlattenLayerKernel.h"
-#include "arm_compute/core/NEON/kernels/NEGEMMMatrixAccumulateBiasesKernel.h"
 #include "arm_compute/core/NEON/kernels/NETransposeKernel.h"
 #include "arm_compute/runtime/MemoryGroup.h"
 #include "arm_compute/runtime/NEON/functions/NEConvertFullyConnectedWeights.h"
 #include "arm_compute/runtime/NEON/functions/NEGEMM.h"
 #include "arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h"
-#include "arm_compute/runtime/NEON/functions/NEGEMMLowpOutputStage.h"
 #include "arm_compute/runtime/Tensor.h"
 
 namespace arm_compute
@@ -107,7 +105,7 @@ private:
  *  -# @ref NEIm2ColKernel (called when the input comes from a convolutional layer)
  *  -# @ref NEFullyConnectedLayerReshapeWeights (if @p are_weights_reshaped is set to false and transpose_weights is set to true ) (called once)
  *  -# @ref NEGEMMMatrixMultiplyKernel or @ref NEGEMMLowpMatrixMultiplyCore (if quantized asymmetric)
- *  -# @ref NEGEMMMatrixAccumulateBiasesKernel or @ref NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint (if quantized asymmetric) (if @p biases is not equal to nullptr)
+ *  -# @ref NEGEMMMatrixAdditionKernel or @ref NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint (if quantized asymmetric) (if @p biases is not equal to nullptr)
  *
  * @note  The fully connected layer accepts "weights" tensors only with 2 dimensions.
  */
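
For reference, the fixed-point output stage named above follows the standard gemmlowp requantization scheme; a hedged summary (the symbols below are descriptive, not identifiers from the library): with uniform input/weights/output scales s_i, s_w, s_o and output offset z_o,

\[
M = \frac{s_i\, s_w}{s_o} \approx \frac{m}{2^{31}} \cdot 2^{-n}, \qquad
q_\mathrm{out} = \operatorname{clamp}\!\left(z_o + \operatorname{round}\!\big(M\,(\mathrm{acc}_{32} + \mathrm{bias}_{32})\big),\; q_{\min},\; q_{\max}\right)
\]

where m and n are the gemmlowp_multiplier and gemmlowp_shift produced by quantization::calculate_quantized_multiplier(), and q_min / q_max come from get_min_max() for the input data type. With this patch the whole stage, bias included, runs inside NEGEMMLowpMatrixMultiplyCore rather than in a standalone NEGEMMLowpOutputStage.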
@@ -164,9 +162,9 @@ public:
     void prepare() override;
 
 private:
-    void configure_fc_fc(const ITensor *input, const ITensor *weights, ITensor *output);
-    void configure_conv_fc(const ITensor *input, const ITensor *weights, ITensor *output);
-    void configure_mm(const ITensor *input, const ITensor *weights, ITensor *output);
+    void configure_fc_fc(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output);
+    void configure_conv_fc(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output);
+    void configure_mm(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output);
 
     MemoryGroup                                                         _memory_group;
     IWeightsManager                                                    *_weights_manager;
@@ -177,17 +175,13 @@ private:
     weights_transformations::NEFullyConnectedLayerReshapeWeightsManaged _reshape_weights_managed_function;
     NEGEMM                                                              _mm_gemm;
     NEGEMMLowpMatrixMultiplyCore                                        _mm_gemmlowp;
-    NEGEMMLowpOutputStage                                               _gemmlowp_output_stage;
-    NEGEMMMatrixAccumulateBiasesKernel                                  _accumulate_biases_kernel;
     Tensor                                                              _flatten_output;
-    Tensor                                                              _gemmlowp_output;
     Tensor                                                              _converted_weights_output;
     Tensor                                                              _reshape_weights_output;
     const ITensor                                                      *_original_weights;
     bool                                                                _are_weights_converted;
     bool                                                                _are_weights_reshaped;
     bool                                                                _is_fc_after_conv;
-    bool                                                                _accumulate_biases;
     bool                                                                _is_quantized;
     bool                                                                _is_prepared;
 };
examples/graph_deepspeech_v0_4_1.cpp
index d2a4832bd16eed0ccc7484e996011c76972e0b8b..ed44ffbee2fb650b79211e32efb2c68ea26a2bc2 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2019-2020 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -57,9 +57,6 @@ public:
             return false;
         }
 
-        // Checks
-        ARM_COMPUTE_EXIT_ON_MSG(arm_compute::is_data_type_quantized_asymmetric(common_params.data_type), "QASYMM8 not supported for this graph");
-
         // Print parameter values
         std::cout << common_params << std::endl;
 
src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
index 92ccd5d1cc03ea495c7a36a06d11129b35acf197..b5f406da8d0a4f49ad1fb634bf637fa99a78a060 100644 (file)
@@ -39,24 +39,46 @@ using namespace arm_compute::misc::shape_calculator;
 
 namespace
 {
-Status validate_mm(const ITensorInfo &input, const ITensorInfo &weights, const ITensorInfo &output)
+Status validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output)
 {
-    if(is_data_type_quantized_asymmetric(input.data_type()))
+    if(is_data_type_quantized_asymmetric(input->data_type()))
     {
         // Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
         // Extract and negate input and weights offset
-        const QuantizationInfo input_quantization_info(input.quantization_info().uniform().scale, -input.quantization_info().uniform().offset);
-        const QuantizationInfo weights_quantization_info(weights.quantization_info().uniform().scale, -weights.quantization_info().uniform().offset);
+        const QuantizationInfo input_quantization_info(input->quantization_info().uniform().scale, -input->quantization_info().uniform().offset);
+        const QuantizationInfo weights_quantization_info(weights->quantization_info().uniform().scale, -weights->quantization_info().uniform().offset);
+
+        const UniformQuantizationInfo iq_info = input->quantization_info().uniform();
+        const UniformQuantizationInfo wq_info = weights->quantization_info().uniform();
+        const UniformQuantizationInfo oq_info = output->quantization_info().uniform();
+
+        float   multiplier = (iq_info.scale * wq_info.scale) / oq_info.scale;
+        int32_t output_multiplier;
+        int32_t output_shift;
+        ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift));
+
+        GEMMLowpOutputStageInfo gemmlowp_output_stage_info;
+        gemmlowp_output_stage_info.gemmlowp_multiplier = output_multiplier;
+        gemmlowp_output_stage_info.gemmlowp_shift      = output_shift;
+        gemmlowp_output_stage_info.gemmlowp_offset     = oq_info.offset;
+        gemmlowp_output_stage_info.type                = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
+        const auto min_max_bound                       = get_min_max(input->data_type());
+        gemmlowp_output_stage_info.gemmlowp_min_bound  = (std::get<0>(min_max_bound)).get<int32_t>();
+        gemmlowp_output_stage_info.gemmlowp_max_bound  = (std::get<1>(min_max_bound)).get<int32_t>();
+
+        GEMMInfo gemm_info;
+        gemm_info.set_gemmlowp_output_stage(gemmlowp_output_stage_info);
 
         // Validate gemmlowp function
-        ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixMultiplyCore::validate(&input.clone()->set_quantization_info(input_quantization_info),
-                                                                           &weights.clone()->set_quantization_info(weights_quantization_info),
-                                                                           nullptr,
-                                                                           &output));
+        ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixMultiplyCore::validate(&input->clone()->set_quantization_info(input_quantization_info),
+                                                                           &weights->clone()->set_quantization_info(weights_quantization_info),
+                                                                           biases,
+                                                                           output,
+                                                                           gemm_info));
     }
     else
     {
-        ARM_COMPUTE_RETURN_ON_ERROR(NEGEMM::validate(&input, &weights, nullptr, &output, 1.f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run */)));
+        ARM_COMPUTE_RETURN_ON_ERROR(NEGEMM::validate(input, weights, biases, output, 1.f, 1.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run */)));
     }
 
     return Status{};
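
A hedged note on the non-quantized branch above: NEGEMM computes output = alpha * A * B + beta * C, and its fused-bias path takes the bias as the C operand added unscaled, which is why biases replaces the former nullptr argument and beta moves from 0.0f to 1.0f. In short:

\[
\mathrm{output} = \alpha \,(A \times B) + \beta \, C, \qquad \alpha = 1,\; \beta = 1,\; C = \text{bias vector, broadcast across the output rows}
\]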
@@ -77,13 +99,12 @@ Status NEFullyConnectedLayerReshapeWeights::validate(const ITensorInfo *input, c
 
 NEFullyConnectedLayer::NEFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
     : _memory_group(std::move(memory_manager)), _weights_manager(weights_manager), _flatten_kernel(), _convert_weights(), _convert_weights_managed(), _reshape_weights_function(),
-      _reshape_weights_managed_function(), _mm_gemm(nullptr, weights_manager), _mm_gemmlowp(), _gemmlowp_output_stage(), _accumulate_biases_kernel(), _flatten_output(), _gemmlowp_output(),
-      _converted_weights_output(), _reshape_weights_output(), _original_weights(nullptr), _are_weights_converted(true), _are_weights_reshaped(false), _is_fc_after_conv(false), _accumulate_biases(false),
-      _is_quantized(false), _is_prepared(false)
+      _reshape_weights_managed_function(), _mm_gemm(nullptr, weights_manager), _mm_gemmlowp(), _flatten_output(), _converted_weights_output(), _reshape_weights_output(), _original_weights(nullptr),
+      _are_weights_converted(true), _are_weights_reshaped(false), _is_fc_after_conv(false), _is_quantized(false), _is_prepared(false)
 {
 }
 
-void NEFullyConnectedLayer::configure_mm(const ITensor *input, const ITensor *weights, ITensor *output)
+void NEFullyConnectedLayer::configure_mm(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output)
 {
     if(_is_quantized)
     {
@@ -95,8 +116,27 @@ void NEFullyConnectedLayer::configure_mm(const ITensor *input, const ITensor *we
         input->info()->set_quantization_info(QuantizationInfo(input_quantization_info.uniform().scale, -input_quantization_info.uniform().offset));
         weights->info()->set_quantization_info(QuantizationInfo(weights_quantization_info.uniform().scale, -weights_quantization_info.uniform().offset));
 
-        // Configure gemmlowp function
-        _mm_gemmlowp.configure(input, weights, nullptr, output);
+        // Configure gemmlowp function and output stage for asymmetric quantized types
+        const UniformQuantizationInfo iq_info = input->info()->quantization_info().uniform();
+        const UniformQuantizationInfo wq_info = weights->info()->quantization_info().uniform();
+        const UniformQuantizationInfo oq_info = output->info()->quantization_info().uniform();
+
+        float   multiplier = (iq_info.scale * wq_info.scale) / oq_info.scale;
+        int32_t output_multiplier;
+        int32_t output_shift;
+        quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift);
+
+        GEMMLowpOutputStageInfo gemmlowp_output_stage_info;
+        gemmlowp_output_stage_info.gemmlowp_multiplier = output_multiplier;
+        gemmlowp_output_stage_info.gemmlowp_shift      = output_shift;
+        gemmlowp_output_stage_info.gemmlowp_offset     = oq_info.offset;
+        gemmlowp_output_stage_info.type                = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
+        const auto min_max_bound                       = get_min_max(input->info()->data_type());
+        gemmlowp_output_stage_info.gemmlowp_min_bound  = (std::get<0>(min_max_bound)).get<int32_t>();
+        gemmlowp_output_stage_info.gemmlowp_max_bound  = (std::get<1>(min_max_bound)).get<int32_t>();
+        GEMMInfo gemm_info;
+        gemm_info.set_gemmlowp_output_stage(gemmlowp_output_stage_info);
+        _mm_gemmlowp.configure(input, weights, biases, output, gemm_info);
 
         // Revert back QuantizationInfo as input and weights could be used in other fully connected layers
         input->info()->set_quantization_info(input_quantization_info);
@@ -105,11 +145,11 @@ void NEFullyConnectedLayer::configure_mm(const ITensor *input, const ITensor *we
     else
     {
         // Configure matrix multiply kernel
-        _mm_gemm.configure(input, weights, nullptr, output, 1.f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run */));
+        _mm_gemm.configure(input, weights, biases, output, 1.f, 1.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run */));
     }
 }
 
-void NEFullyConnectedLayer::configure_conv_fc(const ITensor *input, const ITensor *weights, ITensor *output)
+void NEFullyConnectedLayer::configure_conv_fc(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output)
 {
     ARM_COMPUTE_ERROR_ON((weights->info()->dimension(1) != (input->info()->dimension(0) * input->info()->dimension(1) * input->info()->dimension(2))));
 
@@ -124,18 +164,18 @@ void NEFullyConnectedLayer::configure_conv_fc(const ITensor *input, const ITenso
     _flatten_kernel.configure(input, &_flatten_output);
 
     // Configure matrix multiply kernel
-    configure_mm(&_flatten_output, weights, output);
+    configure_mm(&_flatten_output, weights, biases, output);
 
     // Allocate the output tensor for flatten once all the configure methods have been called
     _flatten_output.allocator()->allocate();
 }
 
-void NEFullyConnectedLayer::configure_fc_fc(const ITensor *input, const ITensor *weights, ITensor *output)
+void NEFullyConnectedLayer::configure_fc_fc(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output)
 {
     ARM_COMPUTE_ERROR_ON(input->info()->dimension(0) != weights->info()->dimension(1));
 
     // Configure matrix multiply kernel
-    configure_mm(input, weights, output);
+    configure_mm(input, weights, biases, output);
 }
 
 void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output,
@@ -152,7 +192,6 @@ void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weigh
     _are_weights_converted = true;
     _are_weights_reshaped  = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
     _is_fc_after_conv      = true;
-    _accumulate_biases     = false;
     _is_quantized          = is_data_type_quantized_asymmetric(input->info()->data_type());
     _original_weights      = weights;
 
@@ -161,21 +200,6 @@ void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weigh
         _weights_manager->manage(weights);
     }
 
-    // Configure gemmlowp output
-    if(_is_quantized)
-    {
-        _gemmlowp_output.allocator()->init(output->info()->clone()->set_is_resizable(true).reset_padding().set_data_type(DataType::S32));
-    }
-
-    // Configure accumulate biases kernel for non quantized asymmetric types
-    if(biases != nullptr && !_is_quantized)
-    {
-        _accumulate_biases = true;
-
-        // Configure accumulate biases kernel
-        _accumulate_biases_kernel.configure(output, biases);
-    }
-
     // With the Fully Connected layer we can have 4 different cases:
     //  1) Convolution layer -> Fully Connected layer without batches
     //  2) Fully Connected layer -> Fully Connected layer without batches
@@ -236,37 +260,15 @@ void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weigh
         _are_weights_converted = false;
     }
 
-    ITensor *tmp_output = (_is_quantized) ? &_gemmlowp_output : output;
     if(_is_fc_after_conv)
     {
         // Fully Connected layer after a Convolution Layer without batches
-        configure_conv_fc(input, weights_to_use, tmp_output);
+        configure_conv_fc(input, weights_to_use, biases, output);
     }
     else
     {
         // Fully Connected layer after a Fully Connected Layer without batches
-        configure_fc_fc(input, weights_to_use, tmp_output);
-    }
-
-    // Configure output stage for asymmetric quantized types
-    if(_is_quantized)
-    {
-        const UniformQuantizationInfo iq_info = input->info()->quantization_info().uniform();
-        const UniformQuantizationInfo wq_info = weights->info()->quantization_info().uniform();
-        const UniformQuantizationInfo oq_info = output->info()->quantization_info().uniform();
-
-        float   multiplier = (iq_info.scale * wq_info.scale) / oq_info.scale;
-        int32_t output_multiplier;
-        int32_t output_shift;
-        quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift);
-
-        GEMMLowpOutputStageInfo gemmlowp_output_stage_info;
-        gemmlowp_output_stage_info.gemmlowp_multiplier = output_multiplier;
-        gemmlowp_output_stage_info.gemmlowp_shift      = output_shift;
-        gemmlowp_output_stage_info.gemmlowp_offset     = oq_info.offset;
-        gemmlowp_output_stage_info.type                = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
-        _gemmlowp_output_stage.configure(&_gemmlowp_output, biases, output, gemmlowp_output_stage_info);
-        _gemmlowp_output.allocator()->allocate();
+        configure_fc_fc(input, weights_to_use, biases, output);
     }
 
     _are_weights_reshaped = _are_weights_reshaped || fc_info.retain_internal_weights;
@@ -283,19 +285,10 @@ Status NEFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorIn
 
     bool weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
     bool is_fc_after_conv = true;
-    bool is_quantized     = is_data_type_quantized_asymmetric(input->data_type());
 
     const ITensorInfo &flatten_input     = TensorInfo(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_flatten_shape(input)));
     const ITensorInfo &reshaped_weights  = TensorInfo(weights->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_transposed_shape(*weights)));
     const ITensorInfo &converted_weights = weights_reshaped ? TensorInfo(weights->clone()->set_is_resizable(true).reset_padding()) : TensorInfo(*reshaped_weights.clone());
-    const ITensorInfo &gemmlowp_output   = TensorInfo(output->clone()->set_is_resizable(true).reset_padding().set_data_type(DataType::S32));
-
-    // Configure accumulate biases kernel for non quantized asymmetric types
-    if(biases != nullptr && !is_quantized)
-    {
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases);
-        ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixAccumulateBiasesKernel::validate(output, biases));
-    }
 
     // With the Fully Connected layer we can have 4 different cases:
     //  1) Convolution layer -> Fully Connected layer without batches
@@ -305,7 +298,6 @@ Status NEFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorIn
 
     const ITensorInfo *input_to_use   = input;
     const ITensorInfo *weights_to_use = weights;
-    const ITensorInfo *tmp_output     = (is_quantized) ? &gemmlowp_output : output;
 
     // Check if we have a fully connected layer with batches
     const bool is_batched_fc_layer = output->dimension(1) > 1;
@@ -353,27 +345,7 @@ Status NEFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorIn
         ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(0) != weights_to_use->dimension(1));
     }
     // Validate matrix multiply kernel
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(*input_to_use, *weights_to_use, *tmp_output));
-
-    // Validate output stage for asymmetric quantized types
-    if(is_quantized)
-    {
-        const UniformQuantizationInfo iq_info = input->quantization_info().uniform();
-        const UniformQuantizationInfo wq_info = weights->quantization_info().uniform();
-        const UniformQuantizationInfo oq_info = output->quantization_info().uniform();
-
-        float   multiplier = (iq_info.scale * wq_info.scale) / oq_info.scale;
-        int32_t output_multiplier;
-        int32_t output_shift;
-        ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift));
-
-        GEMMLowpOutputStageInfo gemmlowp_output_stage_info;
-        gemmlowp_output_stage_info.gemmlowp_multiplier = output_multiplier;
-        gemmlowp_output_stage_info.gemmlowp_shift      = output_shift;
-        gemmlowp_output_stage_info.gemmlowp_offset     = oq_info.offset;
-        gemmlowp_output_stage_info.type                = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
-        ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(&gemmlowp_output, biases, output, gemmlowp_output_stage_info));
-    }
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(input_to_use, weights_to_use, biases, output));
 
     return Status{};
 }
@@ -399,19 +371,6 @@ void NEFullyConnectedLayer::run()
     {
         _mm_gemm.run();
     }
-
-    // Accumulate biases if provided
-    if(_is_quantized)
-    {
-        _gemmlowp_output_stage.run();
-    }
-    else
-    {
-        if(_accumulate_biases)
-        {
-            NEScheduler::get().schedule(&_accumulate_biases_kernel, Window::DimY);
-        }
-    }
 }
 
 void NEFullyConnectedLayer::prepare()