COMPMID-3241: Add Layer Normalization to NEQLSTMLayer
author    Sang-Hoon Park <sang-hoon.park@arm.com>
          Fri, 17 Apr 2020 23:46:34 +0000 (00:46 +0100)
committer Sang-Hoon Park <sang-hoon.park@arm.com>
          Wed, 22 Apr 2020 12:29:06 +0000 (12:29 +0000)
- Add output quantization calculation to Layer Normalization (see the sketch below the bullets)
- Add members for Layer Normalization to NEQLSTMLayer
- Add configure/validate/run of Layer Normalization to NEQLSTMLayer
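
A minimal standalone sketch of the new output quantization rule (it mirrors the
compute_output_qinfo() added to NEQLSTMLayerNormalizationKernel.cpp further down):
the output scale is the product of the input and weight scales multiplied by a
fixed factor of 1024. The plain-float variables and the example scale values here
are illustrative assumptions, not part of the library API.

    #include <cstdio>

    // Illustrative only: reproduces the scale arithmetic of
    // NEQLSTMLayerNormalizationKernel::compute_output_qinfo() with plain floats.
    // The input/weight scale values are assumed, purely for demonstration.
    int main()
    {
        const float input_scale  = 1.f / 32768.f; // assumed QSYMM16 input scale
        const float weight_scale = 1.f / 1024.f;  // assumed QSYMM16 weight scale

        const float output_scale = (weight_scale * input_scale) * 1024.f;
        std::printf("layer norm output scale: %g\n", output_scale);
        return 0;
    }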

Change-Id: I278c8e0edbb21212f3afa4d4a336df0f1a4c1bfb
Signed-off-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3059
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>

arm_compute/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.h
arm_compute/runtime/NEON/functions/NEQLSTMLayer.h
src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp
src/runtime/NEON/functions/NEQLSTMLayer.cpp

diff --git a/arm_compute/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.h b/arm_compute/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.h
index 631de66cc2f7bd5dcd871faa0345129b6572c002..f5e8da7febcaf0e52abb159091ce2c82bc7f7ba0 100644
--- a/arm_compute/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.h
+++ b/arm_compute/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.h
@@ -130,6 +130,8 @@ private:
                             const int16_t *weight_ptr,
                             const int32_t *bias_ptr,
                             int32_t mean, int32_t inv_std_mul, int32_t inv_std_shift);
+    /** Function to compute output quantization information */
+    QuantizationInfo compute_output_qinfo();
 };
 } // namespace arm_compute
 #endif /* ARM_COMPUTE_NEQLSTMLAYERNORMALIZATIONKERNEL_H */
diff --git a/arm_compute/runtime/NEON/functions/NEQLSTMLayer.h b/arm_compute/runtime/NEON/functions/NEQLSTMLayer.h
index a37909b775a4bc62bf13ebf8eb6a34435b39b001..312a8984b5d4ba85d9f36a7018f431f0ccfc4d85 100644
--- a/arm_compute/runtime/NEON/functions/NEQLSTMLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEQLSTMLayer.h
@@ -28,6 +28,7 @@
 #include "arm_compute/core/NEON/kernels/NEArithmeticSubtractionKernel.h"
 #include "arm_compute/core/NEON/kernels/NEGEMMLowpReductionKernel.h"
 #include "arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h"
+#include "arm_compute/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.h"
 #include "arm_compute/core/Types.h"
 #include "arm_compute/runtime/NEON/functions/NEActivationLayer.h"
 #include "arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h"
@@ -169,6 +170,16 @@ public:
     void prepare() override;
 
 private:
+    enum class LayerNormGate : uint8_t
+    {
+        Forget,
+        Cell,
+        Input,
+        Output,
+        Count
+    };
+    static constexpr uint8_t _layer_norm_count = static_cast<uint8_t>(LayerNormGate::Count);
+
     /** Internal method to configure matrix multiplication plus output stage of each gate.
      *
      * @param[in] mm             Matrix multiplication function to use.
@@ -254,12 +265,10 @@ private:
     NEGEMMLowpOutputStage            _projection_outstage{};
     NEArithmeticAdditionKernel       _accumulate_projection{};
     NEActivationLayer                _projection_clip{};
+    std::array<NEQLSTMLayerNormalizationKernel, _layer_norm_count> _layer_norms{};
 
     // Tensor pointers
-    const ITensor *_input_to_input_weights
-    {
-        nullptr
-    };
+    const ITensor *_input_to_input_weights{ nullptr };
     const ITensor *_recurrent_to_input_weights{ nullptr };
     const ITensor *_projection_bias{ nullptr };
     const ITensor *_input_to_forget_weights{ nullptr };
@@ -269,6 +278,58 @@ private:
     const ITensor *_recurrent_to_cell_weights{ nullptr };
     const ITensor *_recurrent_to_output_weights{ nullptr };
     const ITensor *_projection_weights{ nullptr };
+    std::array<const ITensor *, _layer_norm_count> _layer_norm_weights{};
+    std::array<const ITensor *, _layer_norm_count> _layer_norm_bias{};
+
+    using LayerNormIndexType = typename std::underlying_type<LayerNormGate>::type;
+    inline LayerNormIndexType getGateIndex(LayerNormGate g)
+    {
+        return static_cast<LayerNormIndexType>(g);
+    }
+
+    inline void set_layer_norm_weight(const ITensor *t, LayerNormGate g)
+    {
+        _layer_norm_weights[getGateIndex(g)] = t;
+    }
+
+    inline void set_layer_norm_bias(const ITensor *t, LayerNormGate g)
+    {
+        _layer_norm_bias[getGateIndex(g)] = t;
+    }
+
+    inline const ITensor *get_layer_norm_weight(LayerNormGate g)
+    {
+        return _layer_norm_weights[getGateIndex(g)];
+    }
+
+    inline const ITensor *get_layer_norm_bias(LayerNormGate g)
+    {
+        return _layer_norm_bias[getGateIndex(g)];
+    }
+
+    inline NEQLSTMLayerNormalizationKernel &get_layer_norm(LayerNormGate g)
+    {
+        return _layer_norms[getGateIndex(g)];
+    }
+
+    inline void configure_layer_norm(LayerNormGate g, const ITensor *in)
+    {
+        ARM_COMPUTE_ERROR_ON(!_has_layer_norm);
+
+        Tensor &out = get_layer_norm_output(g);
+        _memory_group.manage(&out);
+        out.allocator()->init(*(in->info()));
+
+        get_layer_norm(g).configure(in, &out, get_layer_norm_weight(g), get_layer_norm_bias(g));
+    }
+
+    inline static Status validate_layer_norm(const ITensorInfo &in, const ITensorInfo &weight, const ITensorInfo &bias)
+    {
+        // Output quantization scale will be different, but ignored here
+        // since it will be configured at configure() stage.
+        const TensorInfo out{ in };
+        return NEQLSTMLayerNormalizationKernel::validate(&in, &out, &weight, &bias);
+    }
 
     // Temporary tensors
     Tensor _input_to_forget_weights_transposed{ nullptr };
@@ -320,6 +381,12 @@ private:
     Tensor _mm_projection_res{ nullptr };
     Tensor _projection_outstage_res{ nullptr };
     Tensor _ones{ nullptr };
+    std::array<Tensor, _layer_norm_count> _layer_norm_output{};
+
+    inline Tensor &get_layer_norm_output(LayerNormGate g)
+    {
+        return _layer_norm_output[getGateIndex(g)];
+    }
 
     bool _is_prepared{ false };
     bool _has_cifg{ false };
@@ -327,6 +394,7 @@ private:
     bool _has_projection{ false };
     bool _has_projection_clipping{ false };
     bool _has_peephole{ false };
+    bool _has_layer_norm{ false };
 };
 } // namespace arm_compute
 #endif /* ARM_COMPUTE_NEQLSTMLAYER_H */
diff --git a/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp b/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp
index db2ff85db90d675e1896ff14307a4c62042392c5..e966c6bdbab5d0bbdf3a4c4f67f3a3d811e4dcd3 100644
--- a/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp
+++ b/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp
@@ -80,11 +80,9 @@ inline int64x2x2_t mul_add(const int32x4_t &a, const int32x4_t &b, const int32x4
 
 void NEQLSTMLayerNormalizationKernel::configure(const ITensor *input, ITensor *output, const ITensor *weight, const ITensor *bias)
 {
-    ARM_COMPUTE_ERROR_ON_NULLPTR(input, weight);
-    ARM_COMPUTE_ERROR_THROW_ON(validate(input->info(),
-                                        output ? output->info() : nullptr,
-                                        weight->info(),
-                                        bias ? bias->info() : nullptr));
+    ARM_COMPUTE_ERROR_ON_NULLPTR(input, weight, bias, output);
+    ARM_COMPUTE_ERROR_ON(input == output);
+    ARM_COMPUTE_ERROR_THROW_ON(validate(input->info(), output->info(), weight->info(), bias->info()));
 
     static const std::map<DataType, ComputeFuncType> fn_map =
     {
@@ -98,6 +96,7 @@ void NEQLSTMLayerNormalizationKernel::configure(const ITensor *input, ITensor *o
     _fn     = fn_map.at(_input->info()->data_type());
 
     auto_init_if_empty(*_output->info(), *_input->info());
+    _output->info()->set_quantization_info(compute_output_qinfo());
 
     const UniformQuantizationInfo wq_info = _weight->info()->quantization_info().uniform();
     const Status                  s       = quantization::calculate_quantized_multiplier(wq_info.scale, &_output_multiplier, &_output_shift);
@@ -171,6 +170,14 @@ void NEQLSTMLayerNormalizationKernel::run(const Window &window, const ThreadInfo
     _fn(*this);
 }
 
+inline QuantizationInfo NEQLSTMLayerNormalizationKernel::compute_output_qinfo()
+{
+    const UniformQuantizationInfo iq_info      = _input->info()->quantization_info().uniform();
+    const UniformQuantizationInfo wq_info      = _weight->info()->quantization_info().uniform();
+    const float                   output_scale = (wq_info.scale * iq_info.scale) * 1024;
+    return QuantizationInfo(output_scale);
+}
+
 inline std::pair<int64_t, int64_t> NEQLSTMLayerNormalizationKernel::sum_qsymm16(const int16_t *input_ptr)
 {
     ARM_COMPUTE_ERROR_ON(!input_ptr);
diff --git a/src/runtime/NEON/functions/NEQLSTMLayer.cpp b/src/runtime/NEON/functions/NEQLSTMLayer.cpp
index b02fab227bfedeb81e6e640480dc6cbd8082d174..a279bba2ab4c8a93d1fb4f64c6cdab3417219138 100644
--- a/src/runtime/NEON/functions/NEQLSTMLayer.cpp
+++ b/src/runtime/NEON/functions/NEQLSTMLayer.cpp
@@ -79,9 +79,6 @@ void NEQLSTMLayer::configure(const ITensor *input,
                              ITensor *cell_state_out, ITensor *output_state_out,
                              const LSTMParams<ITensor> &lstm_params)
 {
-    ARM_COMPUTE_UNUSED(forget_gate_bias);
-    ARM_COMPUTE_UNUSED(cell_bias);
-    ARM_COMPUTE_UNUSED(output_gate_bias);
     ARM_COMPUTE_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
                                  recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
                                  forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, cell_state_out, output_state_out);
@@ -112,6 +109,21 @@ void NEQLSTMLayer::configure(const ITensor *input,
     _recurrent_to_output_weights = recurrent_to_output_weights;
     _projection_weights          = lstm_params.projection_weights();
 
+    // Layer normalization
+    _has_layer_norm = lstm_params.use_layer_norm();
+    if(_has_layer_norm)
+    {
+        set_layer_norm_weight(lstm_params.forget_layer_norm_weights(), LayerNormGate::Forget);
+        set_layer_norm_weight(lstm_params.cell_layer_norm_weights(), LayerNormGate::Cell);
+        set_layer_norm_weight(lstm_params.input_layer_norm_weights(), LayerNormGate::Input);
+        set_layer_norm_weight(lstm_params.output_layer_norm_weights(), LayerNormGate::Output);
+
+        set_layer_norm_bias(forget_gate_bias, LayerNormGate::Forget);
+        set_layer_norm_bias(cell_bias, LayerNormGate::Cell);
+        set_layer_norm_bias(lstm_params.input_gate_bias(), LayerNormGate::Input);
+        set_layer_norm_bias(output_gate_bias, LayerNormGate::Output);
+    }
+
     _has_cifg       = lstm_params.has_cifg_opt();
     _has_projection = lstm_params.has_projection();
     _has_peephole   = lstm_params.has_peephole_opt();
@@ -203,14 +215,23 @@ void NEQLSTMLayer::configure(const ITensor *input,
         _cell_to_forget_outstage_res.allocator()->allocate();
     }
 
+    Tensor *forget_activation_input = &_recurrent_to_forget_outstage_res;
+
+    if(_has_layer_norm)
+    {
+        configure_layer_norm(LayerNormGate::Forget, forget_activation_input);
+        forget_activation_input->allocator()->allocate();
+        forget_activation_input = &get_layer_norm_output(LayerNormGate::Forget);
+    }
+
     // Output quantization info of Sigmoid and Tanh activations
     const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
+    const TensorInfo       forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
 
-    const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
     _memory_group.manage(&_forget_gate);
     _forget_gate.allocator()->init(forget_gate_info);
-    _forget_gate_sigmoid.configure(&_recurrent_to_forget_outstage_res, &_forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
-    _recurrent_to_forget_outstage_res.allocator()->allocate();
+    _forget_gate_sigmoid.configure(forget_activation_input, &_forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
+    forget_activation_input->allocator()->allocate();
 
     // Modulation gate.
     const TensorInfo cell_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
@@ -229,11 +250,21 @@ void NEQLSTMLayer::configure(const ITensor *input,
     _accumulate_input_recurrent_modulation.configure(&_input_to_cell_outstage_res, &_recurrent_to_cell_outstage_res, &_recurrent_to_cell_outstage_res, ConvertPolicy::SATURATE);
     _input_to_cell_outstage_res.allocator()->allocate();
 
+    Tensor *cell_activation_input = &_recurrent_to_cell_outstage_res;
+
+    if(_has_layer_norm)
+    {
+        configure_layer_norm(LayerNormGate::Cell, cell_activation_input);
+        cell_activation_input->allocator()->allocate();
+        cell_activation_input = &get_layer_norm_output(LayerNormGate::Cell);
+    }
+
     const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
+
     _memory_group.manage(&_cell_gate);
     _cell_gate.allocator()->init(cell_gate_info);
-    _cell_gate_tanh.configure(&_recurrent_to_cell_outstage_res, &_cell_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
-    _recurrent_to_cell_outstage_res.allocator()->allocate();
+    _cell_gate_tanh.configure(cell_activation_input, &_cell_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
+    cell_activation_input->allocator()->allocate();
 
     // Input gate.
     const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
@@ -276,8 +307,17 @@ void NEQLSTMLayer::configure(const ITensor *input,
             _cell_to_input_outstage_res.allocator()->allocate();
         }
 
-        _input_gate_tanh.configure(&_recurrent_to_input_outstage_res, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
-        _recurrent_to_input_outstage_res.allocator()->allocate();
+        Tensor *input_activation_input = &_recurrent_to_input_outstage_res;
+
+        if(_has_layer_norm)
+        {
+            configure_layer_norm(LayerNormGate::Input, input_activation_input);
+            input_activation_input->allocator()->allocate();
+            input_activation_input = &get_layer_norm_output(LayerNormGate::Input);
+        }
+
+        _input_gate_tanh.configure(input_activation_input, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
+        input_activation_input->allocator()->allocate();
     }
     // Cell.
     // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplicationKernel
@@ -325,11 +365,20 @@ void NEQLSTMLayer::configure(const ITensor *input,
         _mul_cell_to_output_res.allocator()->allocate();
     }
 
+    Tensor *output_activation_input = &_recurrent_to_output_outstage_res;
+
+    if(_has_layer_norm)
+    {
+        configure_layer_norm(LayerNormGate::Output, output_activation_input);
+        output_activation_input->allocator()->allocate();
+        output_activation_input = &get_layer_norm_output(LayerNormGate::Output);
+    }
     const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
+
     _memory_group.manage(&_output_gate);
     _output_gate.allocator()->init(output_gate_info);
-    _output_gate_sigmoid.configure(&_recurrent_to_output_outstage_res, &_output_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
-    _recurrent_to_output_outstage_res.allocator()->allocate();
+    _output_gate_sigmoid.configure(output_activation_input, &_output_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
+    output_activation_input->allocator()->allocate();
 
     // Hidden.
     _hidden_tanh.configure(cell_state_out, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
@@ -505,6 +554,8 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
     gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
     gemmlowp_info.output_data_type   = DataType::QSYMM16;
 
+    const bool has_layer_norm = lstm_params.use_layer_norm();
+
     // Forget gate.
     const TensorInfo forget_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
     const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
@@ -527,10 +578,17 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
         ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
     }
 
+    if(has_layer_norm)
+    {
+        const ITensorInfo *w_info = lstm_params.forget_layer_norm_weights();
+        const ITensorInfo *b_info = forget_gate_bias;
+        ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(forget_outstage_info, *w_info, *b_info));
+    }
+
     // Output quantization info of Sigmoid and Tanh activations
     const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
+    const TensorInfo       forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
 
-    const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
     ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&forget_outstage_info, &forget_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
 
     // Modulation gate.
@@ -543,7 +601,14 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
 
     ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&cell_outstage_info, &cell_outstage_info, &cell_outstage_info, ConvertPolicy::SATURATE));
 
+    if(has_layer_norm)
+    {
+        const ITensorInfo *w_info = lstm_params.cell_layer_norm_weights();
+        const ITensorInfo *b_info = cell_bias;
+        ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(cell_outstage_info, *w_info, *b_info));
+    }
     const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
+
     ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&cell_outstage_info, &cell_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
 
     // Input gate.
@@ -582,6 +647,13 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
             ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
         }
 
+        if(has_layer_norm)
+        {
+            const ITensorInfo *w_info = lstm_params.input_layer_norm_weights();
+            const ITensorInfo *b_info = lstm_params.input_gate_bias();
+            ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(input_outstage_info, *w_info, *b_info));
+        }
+
         ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&input_outstage_info, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
     }
     // Cell.
@@ -614,6 +686,13 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
         ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
     }
 
+    if(has_layer_norm)
+    {
+        const ITensorInfo *w_info = lstm_params.output_layer_norm_weights();
+        const ITensorInfo *b_info = output_gate_bias;
+        ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(output_outstage_info, *w_info, *b_info));
+    }
+
     const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
     ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&output_outstage_info, &output_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
 
@@ -695,6 +774,11 @@ void NEQLSTMLayer::run()
         NEScheduler::get().schedule(&_accumulate_cell_forget, Window::DimY);
     }
 
+    if(_has_layer_norm)
+    {
+        NEScheduler::get().schedule(&get_layer_norm(LayerNormGate::Forget), Window::DimY);
+    }
+
     _forget_gate_sigmoid.run();
 
     // Modulation gate.
@@ -705,6 +789,11 @@ void NEQLSTMLayer::run()
     _recurrent_to_cell_outstage.run();
     NEScheduler::get().schedule(&_accumulate_input_recurrent_modulation, Window::DimY);
 
+    if(_has_layer_norm)
+    {
+        NEScheduler::get().schedule(&get_layer_norm(LayerNormGate::Cell), Window::DimY);
+    }
+
     _cell_gate_tanh.run();
 
     // Input gate
@@ -727,6 +816,11 @@ void NEQLSTMLayer::run()
             NEScheduler::get().schedule(&_accumulate_cell_input, Window::DimY);
         }
 
+        if(_has_layer_norm)
+        {
+            NEScheduler::get().schedule(&get_layer_norm(LayerNormGate::Input), Window::DimY);
+        }
+
         _input_gate_tanh.run();
     }
 
@@ -751,6 +845,11 @@ void NEQLSTMLayer::run()
         NEScheduler::get().schedule(&_accumulate_cell_to_output, Window::DimY);
     }
 
+    if(_has_layer_norm)
+    {
+        NEScheduler::get().schedule(&get_layer_norm(LayerNormGate::Output), Window::DimY);
+    }
+
     _output_gate_sigmoid.run();
 
     // Hidden.