#include "arm_compute/core/NEON/kernels/NEArithmeticSubtractionKernel.h"
#include "arm_compute/core/NEON/kernels/NEGEMMLowpReductionKernel.h"
#include "arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h"
+#include "arm_compute/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/NEON/functions/NEActivationLayer.h"
#include "arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h"
void prepare() override;
private:
+ /** Identifies which LSTM gate a layer-normalization kernel / weight / bias / output belongs to. */
+ enum class LayerNormGate : uint8_t
+ {
+ Forget,
+ Cell,
+ Input,
+ Output,
+ Count // Not a real gate: number of gates, used to size the per-gate arrays
+ };
+ static constexpr uint8_t _layer_norm_count = static_cast<uint8_t>(LayerNormGate::Count);
+
/** Internal method to configure matrix multiplication plus output stage of each gate.
*
* @param[in] mm Matrix multiplication function to use.
NEGEMMLowpOutputStage _projection_outstage{};
NEArithmeticAdditionKernel _accumulate_projection{};
NEActivationLayer _projection_clip{};
+ // One layer-normalization kernel per gate, indexed via getGateIndex().
+ std::array<NEQLSTMLayerNormalizationKernel, _layer_norm_count> _layer_norms{};
// Tensor pointers
- const ITensor *_input_to_input_weights
- {
- nullptr
- };
+ const ITensor *_input_to_input_weights{ nullptr };
const ITensor *_recurrent_to_input_weights{ nullptr };
const ITensor *_projection_bias{ nullptr };
const ITensor *_input_to_forget_weights{ nullptr };
const ITensor *_recurrent_to_cell_weights{ nullptr };
const ITensor *_recurrent_to_output_weights{ nullptr };
const ITensor *_projection_weights{ nullptr };
+ // Borrowed (non-owning) per-gate layer-norm weight/bias tensors; populated in configure() when layer norm is enabled.
+ std::array<const ITensor *, _layer_norm_count> _layer_norm_weights{};
+ std::array<const ITensor *, _layer_norm_count> _layer_norm_bias{};
+
+ /** Integer type underlying @ref LayerNormGate; used to index the per-gate arrays. */
+ using LayerNormIndexType = typename std::underlying_type<LayerNormGate>::type;
+ /** Convert a gate enumerator to its array index.
+ *
+ * Stateless, so declared static constexpr: callable without an object,
+ * from const member functions, and in constant expressions (the original
+ * was a non-static, non-const member although it reads no instance state).
+ */
+ static constexpr LayerNormIndexType getGateIndex(LayerNormGate g)
+ {
+ return static_cast<LayerNormIndexType>(g);
+ }
+
+ /** Record the layer-norm weight tensor for gate @p g (borrowed pointer, not owned). */
+ inline void set_layer_norm_weight(const ITensor *t, LayerNormGate g)
+ {
+ _layer_norm_weights[getGateIndex(g)] = t;
+ }
+
+ /** Record the (layer-norm) bias tensor for gate @p g (borrowed pointer, not owned). */
+ inline void set_layer_norm_bias(const ITensor *t, LayerNormGate g)
+ {
+ _layer_norm_bias[getGateIndex(g)] = t;
+ }
+
+ /** Fetch the borrowed layer-norm weight tensor previously registered for gate @p g. */
+ inline const ITensor *get_layer_norm_weight(LayerNormGate g)
+ {
+ const auto index = getGateIndex(g);
+ return _layer_norm_weights[index];
+ }
+
+ /** Fetch the borrowed (layer-norm) bias tensor previously registered for gate @p g. */
+ inline const ITensor *get_layer_norm_bias(LayerNormGate g)
+ {
+ const auto index = getGateIndex(g);
+ return _layer_norm_bias[index];
+ }
+
+ /** Access the layer-normalization kernel owned for gate @p g (non-const: the kernel is configured and scheduled in place). */
+ inline NEQLSTMLayerNormalizationKernel &get_layer_norm(LayerNormGate g)
+ {
+ return _layer_norms[getGateIndex(g)];
+ }
+
+ /** Configure the layer-normalization kernel of gate @p g, reading from @p in.
+ *
+ * Asserts that layer normalization is enabled (_has_layer_norm). The gate's
+ * output tensor is registered with the memory group and initialised with a
+ * copy of @p in's info; the kernel's configure() is expected to adjust the
+ * output quantization (see the note in validate_layer_norm()).
+ */
+ inline void configure_layer_norm(LayerNormGate g, const ITensor *in)
+ {
+ ARM_COMPUTE_ERROR_ON(!_has_layer_norm);
+
+ Tensor &out = get_layer_norm_output(g);
+ _memory_group.manage(&out);
+ out.allocator()->init(*(in->info()));
+
+ get_layer_norm(g).configure(in, &out, get_layer_norm_weight(g), get_layer_norm_bias(g));
+ }
+
+ /** Static validation of one gate's layer normalization.
+ *
+ * @param[in] in     Info of the tensor fed into the layer norm
+ * @param[in] weight Info of the gate's layer-norm weights
+ * @param[in] bias   Info of the gate's bias
+ *
+ * @return A status
+ */
+ inline static Status validate_layer_norm(const ITensorInfo &in, const ITensorInfo &weight, const ITensorInfo &bias)
+ {
+ // Output quantization scale will be different, but ignored here
+ // since it will be configured at configure() stage.
+ const TensorInfo out{ in };
+ return NEQLSTMLayerNormalizationKernel::validate(&in, &out, &weight, &bias);
+ }
// Temporary tensors
Tensor _input_to_forget_weights_transposed{ nullptr };
Tensor _mm_projection_res{ nullptr };
Tensor _projection_outstage_res{ nullptr };
Tensor _ones{ nullptr };
+ // Per-gate layer-norm result tensors, indexed via getGateIndex(); memory-managed in configure_layer_norm().
+ std::array<Tensor, _layer_norm_count> _layer_norm_output{};
+
+ /** Access the memory-managed output tensor of gate @p g's layer normalization. */
+ inline Tensor &get_layer_norm_output(LayerNormGate g)
+ {
+ return _layer_norm_output[getGateIndex(g)];
+ }
bool _is_prepared{ false };
bool _has_cifg{ false };
bool _has_projection{ false };
bool _has_projection_clipping{ false };
bool _has_peephole{ false };
+ bool _has_layer_norm{ false }; // Set from lstm_params.use_layer_norm() in configure(); gates all layer-norm paths
};
} // namespace arm_compute
#endif /* ARM_COMPUTE_NEQLSTMLAYER_H */
ITensor *cell_state_out, ITensor *output_state_out,
const LSTMParams<ITensor> &lstm_params)
{
- ARM_COMPUTE_UNUSED(forget_gate_bias);
- ARM_COMPUTE_UNUSED(cell_bias);
- ARM_COMPUTE_UNUSED(output_gate_bias);
ARM_COMPUTE_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, cell_state_out, output_state_out);
_recurrent_to_output_weights = recurrent_to_output_weights;
_projection_weights = lstm_params.projection_weights();
+ // Layer normalization
+ _has_layer_norm = lstm_params.use_layer_norm();
+ if(_has_layer_norm)
+ {
+ set_layer_norm_weight(lstm_params.forget_layer_norm_weights(), LayerNormGate::Forget);
+ set_layer_norm_weight(lstm_params.cell_layer_norm_weights(), LayerNormGate::Cell);
+ set_layer_norm_weight(lstm_params.input_layer_norm_weights(), LayerNormGate::Input);
+ set_layer_norm_weight(lstm_params.output_layer_norm_weights(), LayerNormGate::Output);
+
+ set_layer_norm_bias(forget_gate_bias, LayerNormGate::Forget);
+ set_layer_norm_bias(cell_bias, LayerNormGate::Cell);
+ set_layer_norm_bias(lstm_params.input_gate_bias(), LayerNormGate::Input);
+ set_layer_norm_bias(output_gate_bias, LayerNormGate::Output);
+ }
+
_has_cifg = lstm_params.has_cifg_opt();
_has_projection = lstm_params.has_projection();
_has_peephole = lstm_params.has_peephole_opt();
_cell_to_forget_outstage_res.allocator()->allocate();
}
+ Tensor *forget_activation_input = &_recurrent_to_forget_outstage_res;
+
+ if(_has_layer_norm)
+ {
+ configure_layer_norm(LayerNormGate::Forget, forget_activation_input);
+ forget_activation_input->allocator()->allocate();
+ forget_activation_input = &get_layer_norm_output(LayerNormGate::Forget);
+ }
+
// Output quantization info of Sigmoid and Tanh activations
const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
+ const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
- const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
_memory_group.manage(&_forget_gate);
_forget_gate.allocator()->init(forget_gate_info);
- _forget_gate_sigmoid.configure(&_recurrent_to_forget_outstage_res, &_forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
- _recurrent_to_forget_outstage_res.allocator()->allocate();
+ _forget_gate_sigmoid.configure(forget_activation_input, &_forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
+ forget_activation_input->allocator()->allocate();
// Modulation gate.
const TensorInfo cell_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
_accumulate_input_recurrent_modulation.configure(&_input_to_cell_outstage_res, &_recurrent_to_cell_outstage_res, &_recurrent_to_cell_outstage_res, ConvertPolicy::SATURATE);
_input_to_cell_outstage_res.allocator()->allocate();
+ Tensor *cell_activation_input = &_recurrent_to_cell_outstage_res;
+
+ if(_has_layer_norm)
+ {
+ configure_layer_norm(LayerNormGate::Cell, cell_activation_input);
+ cell_activation_input->allocator()->allocate();
+ cell_activation_input = &get_layer_norm_output(LayerNormGate::Cell);
+ }
+
const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
+
_memory_group.manage(&_cell_gate);
_cell_gate.allocator()->init(cell_gate_info);
- _cell_gate_tanh.configure(&_recurrent_to_cell_outstage_res, &_cell_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
- _recurrent_to_cell_outstage_res.allocator()->allocate();
+ _cell_gate_tanh.configure(cell_activation_input, &_cell_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
+ cell_activation_input->allocator()->allocate();
// Input gate.
const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
_cell_to_input_outstage_res.allocator()->allocate();
}
- _input_gate_tanh.configure(&_recurrent_to_input_outstage_res, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
- _recurrent_to_input_outstage_res.allocator()->allocate();
+ Tensor *input_activation_input = &_recurrent_to_input_outstage_res;
+
+ if(_has_layer_norm)
+ {
+ configure_layer_norm(LayerNormGate::Input, input_activation_input);
+ input_activation_input->allocator()->allocate();
+ input_activation_input = &get_layer_norm_output(LayerNormGate::Input);
+ }
+
+ _input_gate_tanh.configure(input_activation_input, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
+ input_activation_input->allocator()->allocate();
}
// Cell.
// TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplicationKernel
_mul_cell_to_output_res.allocator()->allocate();
}
+ Tensor *output_activation_input = &_recurrent_to_output_outstage_res;
+
+ if(_has_layer_norm)
+ {
+ configure_layer_norm(LayerNormGate::Output, output_activation_input);
+ output_activation_input->allocator()->allocate();
+ output_activation_input = &get_layer_norm_output(LayerNormGate::Output);
+ }
const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
+
_memory_group.manage(&_output_gate);
_output_gate.allocator()->init(output_gate_info);
- _output_gate_sigmoid.configure(&_recurrent_to_output_outstage_res, &_output_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
- _recurrent_to_output_outstage_res.allocator()->allocate();
+ _output_gate_sigmoid.configure(output_activation_input, &_output_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
+ output_activation_input->allocator()->allocate();
// Hidden.
_hidden_tanh.configure(cell_state_out, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
gemmlowp_info.output_data_type = DataType::QSYMM16;
+ const bool has_layer_norm = lstm_params.use_layer_norm();
+
// Forget gate.
const TensorInfo forget_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
}
+ if(has_layer_norm)
+ {
+ const ITensorInfo *w_info = lstm_params.forget_layer_norm_weights();
+ const ITensorInfo *b_info = forget_gate_bias;
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(forget_outstage_info, *w_info, *b_info));
+ }
+
// Output quantization info of Sigmoid and Tanh activations
const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
+ const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
- const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&forget_outstage_info, &forget_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
// Modulation gate.
ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&cell_outstage_info, &cell_outstage_info, &cell_outstage_info, ConvertPolicy::SATURATE));
+ if(has_layer_norm)
+ {
+ const ITensorInfo *w_info = lstm_params.cell_layer_norm_weights();
+ const ITensorInfo *b_info = cell_bias;
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(cell_outstage_info, *w_info, *b_info));
+ }
const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
+
ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&cell_outstage_info, &cell_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
// Input gate.
ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
}
+ if(has_layer_norm)
+ {
+ const ITensorInfo *w_info = lstm_params.input_layer_norm_weights();
+ const ITensorInfo *b_info = lstm_params.input_gate_bias();
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(input_outstage_info, *w_info, *b_info));
+ }
+
ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&input_outstage_info, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
}
// Cell.
ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
}
+ if(has_layer_norm)
+ {
+ const ITensorInfo *w_info = lstm_params.output_layer_norm_weights();
+ const ITensorInfo *b_info = output_gate_bias;
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(output_outstage_info, *w_info, *b_info));
+ }
+
const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&output_outstage_info, &output_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
NEScheduler::get().schedule(&_accumulate_cell_forget, Window::DimY);
}
+ if(_has_layer_norm)
+ {
+ NEScheduler::get().schedule(&get_layer_norm(LayerNormGate::Forget), Window::DimY);
+ }
+
_forget_gate_sigmoid.run();
// Modulation gate.
_recurrent_to_cell_outstage.run();
NEScheduler::get().schedule(&_accumulate_input_recurrent_modulation, Window::DimY);
+ if(_has_layer_norm)
+ {
+ NEScheduler::get().schedule(&get_layer_norm(LayerNormGate::Cell), Window::DimY);
+ }
+
_cell_gate_tanh.run();
// Input gate
NEScheduler::get().schedule(&_accumulate_cell_input, Window::DimY);
}
+ if(_has_layer_norm)
+ {
+ NEScheduler::get().schedule(&get_layer_norm(LayerNormGate::Input), Window::DimY);
+ }
+
_input_gate_tanh.run();
}
NEScheduler::get().schedule(&_accumulate_cell_to_output, Window::DimY);
}
+ if(_has_layer_norm)
+ {
+ NEScheduler::get().schedule(&get_layer_norm(LayerNormGate::Output), Window::DimY);
+ }
+
_output_gate_sigmoid.run();
// Hidden.