[ layer ] Enable mha gtest and match version
author skykongkong8 <ss.kong@samsung.com>
Mon, 3 Jun 2024 09:31:10 +0000 (18:31 +0900)
committer jijoong.moon <jijoong.moon@samsung.com>
Mon, 10 Jun 2024 22:54:32 +0000 (07:54 +0900)
- The current mha layer under nntrainer/layers is not for general use; it was implemented solely for LLaMA support.
- To run the unit test for the mha layer, revert to the previous version of the mha layer and move the current implementation under Applications/LLaMA, where it is built and registered as a custom layer (see the sketch below).
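
The relocated layer is built as its own shared module (custom_multi_head_attention_layer, see the Android.mk hunk) and is expected to be registered from the application side the same way the other LLaMA custom layers (rms_norm, swiglu) are. The snippet below is only a minimal sketch of such a registration, assuming the usual AppContext factory API; the helper name, registration key, and error handling are illustrative and not part of this patch.

    // Hypothetical registration sketch -- not part of this patch.
    #include <stdexcept>

    #include <app_context.h>
    #include <custom_multi_head_attention_layer.h>

    void register_custom_mha() {
      auto &app_context = nntrainer::AppContext::Global();
      try {
        // Register under a distinct key so it does not clash with the
        // restored in-tree multi_head_attention layer (key is illustrative).
        app_context.registerFactory(
          nntrainer::createLayer<nntrainer::MultiHeadAttentionLayer>,
          "custom_multi_head_attention");
      } catch (std::invalid_argument &e) {
        // Already registered; safe to ignore on repeated calls.
      }
    }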

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <ss.kong@samsung.com>
Applications/LLaMA/jni/Android.mk
Applications/LLaMA/jni/custom_multi_head_attention_layer.cpp [new file with mode: 0644]
Applications/LLaMA/jni/custom_multi_head_attention_layer.h [new file with mode: 0644]
Applications/LLaMA/jni/main.cpp
Applications/LLaMA/jni/meson.build
nntrainer/layers/multi_head_attention_layer.cpp
nntrainer/layers/multi_head_attention_layer.h
test/unittest/layers/meson.build

index e095e7049bafffb485e134611e6bfac33e5527a9..f1a9c2f117ef2ee68857bd3efc8b1783a500a066 100644 (file)
@@ -79,6 +79,26 @@ LOCAL_C_INCLUDES += $(NNTRAINER_INCLUDES)
 
 include $(BUILD_SHARED_LIBRARY)
 
+include $(CLEAR_VARS)
+
+LOCAL_ARM_NEON := true
+LOCAL_CFLAGS += -std=c++17 -Ofast -mcpu=cortex-a53 -Ilz4-nougat/lib -DENABLE_FP16=1 -DUSE__FP16=1
+LOCAL_LDFLAGS += -Llz4-nougat/lib/obj/local/$(TARGET_ARCH_ABI)/
+LOCAL_CXXFLAGS += -std=c++17 -frtti -DENABLE_FP16=1 -DUSE__FP16=1
+LOCAL_CFLAGS += -pthread -fexceptions -fopenmp -static-openmp -DENABLE_FP16=1 -DUSE__FP16=1
+LOCAL_LDFLAGS += -fexceptions -fopenmp -static-openmp -DENABLE_FP16=1 -DUSE__FP16=1
+LOCAL_MODULE_TAGS := optional
+LOCAL_ARM_MODE := arm
+LOCAL_MODULE := custom_multi_head_attention_layer
+LOCAL_LDLIBS := -llog -landroid -fopenmp -static-openmp -DENABLE_FP16=1 -DUSE__FP16=1
+
+LOCAL_SRC_FILES := custom_multi_head_attention_layer.cpp
+
+LOCAL_SHARED_LIBRARIES := nntrainer ccapi-nntrainer
+
+LOCAL_C_INCLUDES += $(NNTRAINER_INCLUDES)
+
+include $(BUILD_SHARED_LIBRARY)
 
 
 include $(CLEAR_VARS)
@@ -96,7 +116,7 @@ LOCAL_LDLIBS := -llog -landroid -fopenmp -static-openmp -DENABLE_FP16=1 -DUSE__F
 
 LOCAL_SRC_FILES := main.cpp 
 
-LOCAL_SHARED_LIBRARIES := nntrainer ccapi-nntrainer rms_norm_layer swiglu_layer
+LOCAL_SHARED_LIBRARIES := nntrainer ccapi-nntrainer rms_norm_layer swiglu_layer custom_multi_head_attention_layer
 
 LOCAL_C_INCLUDES += $(NNTRAINER_INCLUDES)
 
diff --git a/Applications/LLaMA/jni/custom_multi_head_attention_layer.cpp b/Applications/LLaMA/jni/custom_multi_head_attention_layer.cpp
new file mode 100644 (file)
index 0000000..2a7bcba
--- /dev/null
@@ -0,0 +1,1537 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2022 hyeonseok Lee <hs89.lee@samsung.com>
+ *
+ * @file   custom_multi_head_attention_layer.cpp
+ * @date   08 July 2022
+ * @see    https://github.com/nnstreamer/nntrainer
+ *         https://arxiv.org/abs/1706.03762
+ * @author hyeonseok Lee <hs89.lee@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is MultiHeadAttention Layer Class for Neural Network
+ *
+ */
+
+#include <algorithm>
+#include <cmath>
+#include <custom_multi_head_attention_layer.h>
+#include <layer_context.h>
+#include <nntrainer_error.h>
+#include <nntrainer_log.h>
+#include <node_exporter.h>
+#include <thread>
+#include <vector>
+
+namespace nntrainer {
+
+MultiHeadAttentionLayer::MultiHeadAttentionLayer() :
+  multi_head_attention_props(
+    props::NumHeads(), props::ProjectedKeyDim(), props::ProjectedValueDim(),
+    props::OutputShape(), props::DropOutRate(), props::ReturnAttentionWeight(),
+    props::AverageAttentionWeight(), props::MaxTimestep()),
+  sm(ActivationType::ACT_SOFTMAX),
+  epsilon(1e-3),
+  cache_index(0) {
+  weight_idx.fill(std::numeric_limits<unsigned>::max());
+  layer_progress = 0;
+}
+
+MultiHeadAttentionLayer::~MultiHeadAttentionLayer() {}
+
+enum INOUT_INDEX {
+  /** input index */
+  QUERY = 0,
+  KEY = 1,
+  VALUE = 2,
+  MASK = 3,
+  /** output index */
+  OUTPUT = 0,
+  RETURN_ATTENTION_WEIGHT = 1,
+};
+
+enum AttentionParams {
+  query_fc_weight,
+  query_fc_bias,
+  key_fc_weight,
+  key_fc_bias,
+  value_fc_weight,
+  value_fc_bias,
+  fc_weight,
+  fc_bias,
+  projected_query,
+  projected_key,
+  projected_value,
+  cache_key,
+  cache_value,
+  /** reserved for later use of attention_mask */
+  // attention_mask,
+  attention_weight,
+  dropout_mask,
+  attention_output,
+};
+
+void MultiHeadAttentionLayer::finalize(InitLayerContext &context) {
+  NNTR_THROW_IF(context.getNumInputs() < 3 || context.getNumInputs() > 4,
+                std::invalid_argument)
+    << "Multi head Attention layer needs 3 or 4 inputs. (query, key, value and "
+       "mask is optional";
+  const bool provide_attention_mask = context.getNumInputs() == 4;
+
+  TensorDim::TensorType weight_type = {context.getFormat(),
+                                       context.getWeightDataType()};
+
+  TensorDim::TensorType activation_type = {context.getFormat(),
+                                           context.getActivationDataType()};
+
+  TensorDim empty_dim(activation_type);
+
+  const std::vector<TensorDim> &input_dims = context.getInputDimensions();
+  const TensorDim &query_dim = input_dims[INOUT_INDEX::QUERY];
+  const TensorDim &key_dim = input_dims[INOUT_INDEX::KEY];
+  const TensorDim &value_dim = input_dims[INOUT_INDEX::VALUE];
+  const TensorDim &mask_dim =
+    provide_attention_mask ? input_dims[INOUT_INDEX::MASK] : empty_dim;
+
+  const unsigned int batch_size = query_dim.batch();
+  const unsigned int query_height = query_dim.height();
+  const unsigned int query_width = query_dim.width();
+  // const unsigned int key_height = key_dim.height();
+  const unsigned int key_width = key_dim.width();
+  // const unsigned int value_height = value_dim.height();
+  const unsigned int value_width = value_dim.width();
+
+  const bool disable_bias =
+    std::get<props::DisableBias>(*layer_impl_props).get();
+  auto &weight_initializer =
+    std::get<props::WeightInitializer>(*layer_impl_props).get();
+  auto &weight_regularizer =
+    std::get<props::WeightRegularizer>(*layer_impl_props);
+  auto &weight_regularizer_constant =
+    std::get<props::WeightRegularizerConstant>(*layer_impl_props);
+  const float &weight_decay =
+    std::get<props::WeightDecay>(*layer_impl_props).get();
+
+  NNTR_THROW_IF(std::get<props::NumHeads>(multi_head_attention_props).empty(),
+                std::invalid_argument)
+    << "num_heads property is not provided for layer " << context.getName();
+  const unsigned int num_heads =
+    std::get<props::NumHeads>(multi_head_attention_props).get();
+
+  if (std::get<props::ProjectedKeyDim>(multi_head_attention_props).empty()) {
+    NNTR_THROW_IF(query_width % num_heads, std::invalid_argument)
+      << "query_width: " << query_width
+      << " is not divisible by num_heads: " << num_heads << " for layer "
+      << context.getName();
+
+    ml_logw("[multi head attention] ProjectedKeyDim property is not given. "
+            "Default value(query_width / num_heads) is set");
+
+    std::get<props::ProjectedKeyDim>(multi_head_attention_props)
+      .set(query_width / num_heads);
+  }
+  const unsigned int projected_key_dim_prop =
+    std::get<props::ProjectedKeyDim>(multi_head_attention_props).get();
+
+  if (std::get<props::ProjectedValueDim>(multi_head_attention_props).empty()) {
+    std::get<props::ProjectedValueDim>(multi_head_attention_props)
+      .set(projected_key_dim_prop);
+  }
+  const unsigned int projected_value_dim_prop =
+    std::get<props::ProjectedValueDim>(multi_head_attention_props).get();
+
+  if (std::get<props::OutputShape>(multi_head_attention_props).empty()) {
+    std::get<props::OutputShape>(multi_head_attention_props).set(query_width);
+  }
+  const unsigned int output_shape =
+    std::get<props::OutputShape>(multi_head_attention_props).get();
+
+  const float dropout_rate =
+    std::get<props::DropOutRate>(multi_head_attention_props).get();
+
+  if (std::get<props::AverageAttentionWeight>(multi_head_attention_props)
+        .empty()) {
+    std::get<props::AverageAttentionWeight>(multi_head_attention_props)
+      .set(true);
+  }
+  const bool average_attention_weight =
+    std::get<props::AverageAttentionWeight>(multi_head_attention_props).get();
+
+  const props::ReturnAttentionWeightInfo::Enum return_attention_weight =
+    std::get<props::ReturnAttentionWeight>(multi_head_attention_props).get();
+
+  const unsigned int max_timestep =
+    std::get<props::MaxTimestep>(multi_head_attention_props).get();
+
+  // @todo: fix me
+  const unsigned int key_height = max_timestep;
+  const unsigned int value_height = max_timestep;
+
+  const unsigned int projected_query_dim_prop = projected_key_dim_prop;
+
+  if (activation_type.data_type == TensorDim::DataType::FP32) {
+    sm.setActiFunc(ActivationType::ACT_SOFTMAX);
+  } else if (activation_type.data_type == TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+    sm.setActiFunc<_FP16>(ActivationType::ACT_SOFTMAX);
+#else
+    throw std::invalid_argument("Error: enable-fp16 is not enabled");
+#endif
+  }
+
+  // sm.setActiFunc(ActivationType::ACT_SOFTMAX);
+
+  NNTR_THROW_IF(query_dim.channel() != 1, std::invalid_argument)
+    << "Dimension of input query channel: " << query_dim.channel()
+    << " is not 1 for layer " << context.getName();
+  NNTR_THROW_IF(key_dim.channel() != 1, std::invalid_argument)
+    << "Dimension of input key channel: " << key_dim.channel()
+    << " is not 1 for layer " << context.getName();
+  NNTR_THROW_IF(value_dim.channel() != 1, std::invalid_argument)
+    << "Dimension of input value channel: " << value_dim.channel()
+    << " is not 1 for layer " << context.getName();
+  NNTR_THROW_IF(provide_attention_mask && mask_dim.channel() != num_heads,
+                std::invalid_argument)
+    << "Dimension of input mask channel: " << mask_dim.channel()
+    << " is not matched with num_heads property: " << num_heads << " for layer "
+    << context.getName();
+
+  NNTR_THROW_IF(key_height != value_height, std::invalid_argument)
+    << "Dimension of input key height: " << key_height
+    << " is not matched with Dimension of input value height: " << value_height
+    << " for layer " << context.getName();
+  NNTR_THROW_IF(provide_attention_mask && mask_dim.height() != query_height,
+                std::invalid_argument)
+    << "Dimension of input mask height: " << mask_dim.height()
+    << " is not matched with Dimension of input query height: " << query_height
+    << " for layer " << context.getName();
+
+  NNTR_THROW_IF(provide_attention_mask && mask_dim.width() != key_height,
+                std::invalid_argument)
+    << "Dimension of input mask width: " << mask_dim.width()
+    << " is not matched with Dimension of input key height: " << key_height
+    << " for layer " << context.getName();
+
+  /** weight/bias for query fc */
+  TensorDim query_fc_weight_dim(
+    {1, 1, query_width, num_heads * projected_query_dim_prop}, weight_type);
+
+  weight_idx[AttentionParams::query_fc_weight] = context.requestWeight(
+    query_fc_weight_dim, weight_initializer, weight_regularizer,
+    weight_regularizer_constant, weight_decay, "query_fc_weight", true);
+  if (!disable_bias) {
+    TensorDim query_fc_bias_dim({1, 1, 1, num_heads * projected_query_dim_prop},
+                                weight_type);
+    weight_idx[AttentionParams::query_fc_bias] = context.requestWeight(
+      query_fc_bias_dim, weight_initializer, weight_regularizer,
+      weight_regularizer_constant, weight_decay, "query_fc_bias", true);
+  }
+
+  /** weight/bias for key fc */
+  TensorDim key_fc_weight_dim(
+    {1, 1, key_width, num_heads * projected_key_dim_prop}, weight_type);
+  weight_idx[AttentionParams::key_fc_weight] = context.requestWeight(
+    key_fc_weight_dim, weight_initializer, weight_regularizer,
+    weight_regularizer_constant, weight_decay, "key_fc_weight", true);
+  if (!disable_bias) {
+    TensorDim key_fc_bias_dim({1, 1, 1, num_heads * projected_key_dim_prop},
+                              weight_type);
+    weight_idx[AttentionParams::key_fc_bias] = context.requestWeight(
+      key_fc_bias_dim, weight_initializer, weight_regularizer,
+      weight_regularizer_constant, weight_decay, "key_fc_bias", true);
+  }
+
+  /** weight/bias for value fc */
+  TensorDim value_fc_weight_dim(
+    {1, 1, value_width, num_heads * projected_value_dim_prop}, weight_type);
+  weight_idx[AttentionParams::value_fc_weight] = context.requestWeight(
+    value_fc_weight_dim, weight_initializer, weight_regularizer,
+    weight_regularizer_constant, weight_decay, "value_fc_weight", true);
+  if (!disable_bias) {
+    TensorDim value_fc_bias_dim({1, 1, 1, num_heads * projected_value_dim_prop},
+                                weight_type);
+    weight_idx[AttentionParams::value_fc_bias] = context.requestWeight(
+      value_fc_bias_dim, weight_initializer, weight_regularizer,
+      weight_regularizer_constant, weight_decay, "value_fc_bias", true);
+  }
+
+  /** weight/bias for out fc */
+  TensorDim fc_weight_dim(
+    {1, 1, num_heads * projected_value_dim_prop, output_shape}, weight_type);
+  weight_idx[AttentionParams::fc_weight] = context.requestWeight(
+    fc_weight_dim, weight_initializer, weight_regularizer,
+    weight_regularizer_constant, weight_decay, "fc_weight", true);
+  if (!disable_bias) {
+    TensorDim fc_bias_dim({1, 1, 1, output_shape}, weight_type);
+    weight_idx[AttentionParams::fc_bias] = context.requestWeight(
+      fc_bias_dim, weight_initializer, weight_regularizer,
+      weight_regularizer_constant, weight_decay, "fc_bias", true);
+  }
+
+  /** tensor for output of query fc */
+  TensorDim projected_query_dim(
+    {batch_size, 1, query_height, num_heads * projected_query_dim_prop},
+    activation_type);
+  weight_idx[AttentionParams::projected_query] = context.requestTensor(
+    projected_query_dim, "projected_query", Tensor::Initializer::NONE, true,
+    TensorLifespan::ITERATION_LIFESPAN);
+  /** tensor for output of key fc */
+  TensorDim projected_key_dim(
+    {batch_size, 1, key_height, num_heads * projected_key_dim_prop},
+    activation_type);
+  weight_idx[AttentionParams::projected_key] = context.requestTensor(
+    projected_key_dim, "projected_key", Tensor::Initializer::NONE, true,
+    TensorLifespan::ITERATION_LIFESPAN);
+  /** tensor for output of value fc */
+  TensorDim projected_value_dim(
+    {batch_size, 1, value_height, num_heads * projected_value_dim_prop},
+    activation_type);
+  weight_idx[AttentionParams::projected_value] = context.requestTensor(
+    projected_value_dim, "projected_value", Tensor::Initializer::NONE, true,
+    TensorLifespan::ITERATION_LIFESPAN);
+
+  TensorDim cache_key_dim(
+    {batch_size, 1, max_timestep, num_heads * projected_key_dim_prop},
+    activation_type);
+  weight_idx[AttentionParams::cache_key] =
+    context.requestTensor(cache_key_dim, "cache_key", Tensor::Initializer::NONE,
+                          true, TensorLifespan::MAX_LIFESPAN);
+
+  TensorDim cache_value_dim(
+    {batch_size, 1, max_timestep, num_heads * projected_value_dim_prop},
+    activation_type);
+  weight_idx[AttentionParams::cache_value] = context.requestTensor(
+    cache_value_dim, "cache_value", Tensor::Initializer::NONE, true,
+    TensorLifespan::MAX_LIFESPAN);
+
+  if (provide_attention_mask) {
+    /** kept for later use of a bool-type attention mask */
+    // TensorDim attention_mask_dim(
+    //   {batch_size, num_heads, query_height, key_height});
+    // weight_idx[AttentionParams::attention_mask] = context.requestTensor(
+    //   attention_mask_dim, "attention_mask", Tensor::Initializer::NONE, false,
+    //   TensorLifespan::FORWARD_FUNC_LIFESPAN);
+  }
+  /** tensor for attention weight */
+  TensorDim attention_weight_dim(
+    {batch_size, num_heads, query_height, key_height}, activation_type);
+  weight_idx[AttentionParams::attention_weight] = context.requestTensor(
+    attention_weight_dim, "attention_weight", Tensor::Initializer::NONE, true,
+    TensorLifespan::ITERATION_LIFESPAN);
+  if (dropout_rate > epsilon) {
+    /** tensor for dropout mask */
+    TensorDim dropout_mask_dim(
+      {batch_size, num_heads, query_height, key_height}, activation_type);
+    weight_idx[AttentionParams::dropout_mask] = context.requestTensor(
+      dropout_mask_dim, "dropout_mask", Tensor::Initializer::NONE, false,
+      TensorLifespan::ITERATION_LIFESPAN);
+  }
+
+  /** tensor for attention output */
+  TensorDim attention_output_dim(
+    {batch_size, 1, query_height, num_heads * projected_value_dim_prop},
+    activation_type);
+  weight_idx[AttentionParams::attention_output] = context.requestTensor(
+    attention_output_dim, "attention_output", Tensor::Initializer::NONE, true,
+    TensorLifespan::ITERATION_LIFESPAN);
+
+  TensorDim output_dim({batch_size, 1, query_height, output_shape},
+                       activation_type);
+  if (return_attention_weight != props::ReturnAttentionWeightInfo::Enum::none) {
+    TensorDim return_attention_weight_dim(
+      {batch_size, average_attention_weight ? 1 : num_heads, query_height,
+       key_height},
+      activation_type);
+    context.setOutputDimensions({output_dim, return_attention_weight_dim});
+  } else {
+    context.setOutputDimensions({output_dim});
+  }
+
+  /**
+   * @todo
+   * check query width and key width
+   *
+   */
+  if (freqs_cos == nullptr)
+    precompute_freqs(projected_key_dim_prop, max_timestep);
+}
+
+#define _MASK_NUM(datatype) \
+  (((datatype) == ml::train::TensorDim::DataType::FP16) ? (-1e4) : (-1e10))
+
+void MultiHeadAttentionLayer::forwarding(RunLayerContext &context,
+                                         bool training) {
+
+  const bool disable_bias =
+    std::get<props::DisableBias>(*layer_impl_props).get();
+
+  const unsigned int num_heads =
+    std::get<props::NumHeads>(multi_head_attention_props).get();
+  const unsigned int projected_key_dim_prop =
+    std::get<props::ProjectedKeyDim>(multi_head_attention_props).get();
+  const unsigned int projected_value_dim_prop =
+    std::get<props::ProjectedValueDim>(multi_head_attention_props).get();
+  const float dropout_rate =
+    std::get<props::DropOutRate>(multi_head_attention_props).get();
+  const props::ReturnAttentionWeightInfo::Enum return_attention_weight =
+    std::get<props::ReturnAttentionWeight>(multi_head_attention_props).get();
+  const bool average_attention_weight =
+    std::get<props::AverageAttentionWeight>(multi_head_attention_props).get();
+
+  const bool provide_attention_mask = context.getNumInputs() == 4;
+  const unsigned int projected_query_dim_prop = projected_key_dim_prop;
+  const bool enable_dropout = dropout_rate > epsilon;
+
+  Tensor empty_tensor;
+
+  /** get inputs/outputs */
+  Tensor &query = context.getInput(INOUT_INDEX::QUERY);
+  Tensor &key = context.getInput(INOUT_INDEX::KEY);
+  Tensor &value = context.getInput(INOUT_INDEX::VALUE);
+  Tensor &mask =
+    provide_attention_mask ? context.getInput(INOUT_INDEX::MASK) : empty_tensor;
+
+  Tensor &output = context.getOutput(INOUT_INDEX::OUTPUT);
+  Tensor &ret_attention_weight =
+    return_attention_weight != props::ReturnAttentionWeightInfo::Enum::none
+      ? context.getOutput(INOUT_INDEX::RETURN_ATTENTION_WEIGHT)
+      : empty_tensor;
+
+  /** get weights */
+  Tensor &query_fc_weight =
+    context.getWeight(weight_idx[AttentionParams::query_fc_weight]);
+  Tensor &query_fc_bias =
+    disable_bias
+      ? empty_tensor
+      : context.getWeight(weight_idx[AttentionParams::query_fc_bias]);
+  Tensor &key_fc_weight =
+    context.getWeight(weight_idx[AttentionParams::key_fc_weight]);
+  Tensor &key_fc_bias =
+    disable_bias ? empty_tensor
+                 : context.getWeight(weight_idx[AttentionParams::key_fc_bias]);
+  Tensor &value_fc_weight =
+    context.getWeight(weight_idx[AttentionParams::value_fc_weight]);
+  Tensor &value_fc_bias =
+    disable_bias
+      ? empty_tensor
+      : context.getWeight(weight_idx[AttentionParams::value_fc_bias]);
+  Tensor &fc_weight = context.getWeight(weight_idx[AttentionParams::fc_weight]);
+  Tensor &fc_bias = disable_bias
+                      ? empty_tensor
+                      : context.getWeight(weight_idx[AttentionParams::fc_bias]);
+
+  /** get tensors */
+  Tensor &projected_query =
+    context.getTensor(weight_idx[AttentionParams::projected_query]);
+  Tensor &projected_key =
+    context.getTensor(weight_idx[AttentionParams::projected_key]);
+  Tensor &projected_value =
+    context.getTensor(weight_idx[AttentionParams::projected_value]);
+
+  Tensor &attention_weight =
+    context.getTensor(weight_idx[AttentionParams::attention_weight]);
+  Tensor &attention_output =
+    context.getTensor(weight_idx[AttentionParams::attention_output]);
+
+  const TensorDim query_dim = query.getDim();
+  const unsigned int batch_size = query_dim.batch();
+  const unsigned int query_height = query_dim.height();
+  const TensorDim key_dim = key.getDim();
+  const unsigned int key_height = key_dim.height();
+  const TensorDim value_dim = value.getDim();
+  const unsigned int value_height = value_dim.height();
+
+  query.dot(query_fc_weight, projected_query);
+  if (!disable_bias) {
+    projected_query.add_i(query_fc_bias);
+  }
+  key.dot(key_fc_weight, projected_key);
+  if (!disable_bias) {
+    projected_key.add_i(key_fc_bias);
+  }
+  value.dot(value_fc_weight, projected_value);
+  if (!disable_bias) {
+    projected_value.add_i(value_fc_bias);
+  }
+
+  apply_rotary_emb_tensor(projected_query, projected_query_dim_prop, 0);
+  apply_rotary_emb_tensor(projected_key, projected_key_dim_prop, 0);
+
+  projected_query.reshape(
+    TensorDim({batch_size, query_height, num_heads, projected_query_dim_prop}));
+  projected_key.reshape(
+    TensorDim({batch_size, key_height, num_heads, projected_key_dim_prop}));
+  projected_value.reshape(
+    TensorDim({batch_size, value_height, num_heads, projected_value_dim_prop}));
+
+  projected_query = projected_query.transpose("1:0:2");
+  projected_key = projected_key.transpose("1:0:2");
+  projected_value = projected_value.transpose("1:0:2");
+
+  /** set tensor name to restore the original name, which was removed
+   * during transpose */
+  projected_query.setName("multi_head_attention:projected_query");
+  projected_key.setName("multi_head_attention:projected_key");
+  projected_value.setName("multi_head_attention:projected_value");
+
+  projected_query.reshape(TensorDim(
+    {batch_size * num_heads, 1, query_height, projected_query_dim_prop}));
+  projected_key.reshape(
+    TensorDim({batch_size * num_heads, 1, key_height, projected_key_dim_prop}));
+  projected_value.reshape(TensorDim(
+    {batch_size * num_heads, 1, value_height, projected_value_dim_prop}));
+
+  attention_weight.reshape(
+    TensorDim({batch_size * num_heads, 1, query_height, key_height}));
+  attention_output.reshape(TensorDim(
+    {batch_size * num_heads, 1, query_height, projected_value_dim_prop}));
+
+  /** scaled dot product attention */
+  projected_query.dotBatched(projected_key, attention_weight, false, true);
+  attention_weight.multiply_i(1 / sqrt((float)projected_query_dim_prop));
+
+  unsigned int mask_size = attention_weight.getDim().width();
+  unsigned int mask_dim_height = mask_size;
+  unsigned int mask_dim_width = mask_size;
+
+  Tensor causal_mask(
+    TensorDim{1, 1, mask_size, mask_size, attention_weight.getTensorType()});
+
+  causal_mask.setZero();
+
+  for (unsigned int i = 0; i < mask_dim_height; ++i) {
+    for (unsigned int j = i + 1; j < mask_dim_width; ++j) {
+      causal_mask.setValue(0, 0, i, j,
+                           _MASK_NUM(attention_weight.getDataType()));
+    }
+  }
+
+  attention_weight.add_i(causal_mask);
+
+  if (provide_attention_mask) {
+    // Tensor &attention_mask =
+    //   context.getTensor(weight_idx[AttentionParams::attention_mask]);
+    /** @todo: enable bool type tensor */
+    // if (torch_ref) {
+    //   attention_mask.setValue(-1e9);
+    // } else {
+    //   // flip
+    //   attention_mask.setValue(1);
+    //   attention_mask.subtract_i(mask);
+
+    //   attention_mask.multiply_i(-1e9);
+    // }
+    // attention_mask.multiply_i(mask);
+    // attention_weight.add_i(attention_mask);
+
+    attention_weight.reshape(
+      TensorDim({batch_size, num_heads, query_height, key_height}));
+    attention_weight.add_i(mask);
+    attention_weight.reshape(
+      TensorDim({batch_size * num_heads, 1, query_height, key_height}));
+  }
+
+  sm.run_fn(attention_weight, attention_weight);
+
+  if (return_attention_weight ==
+      props::ReturnAttentionWeightInfo::Enum::before) {
+    ret_attention_weight.copyData(attention_weight);
+  }
+
+  if (enable_dropout) {
+    Tensor &dropout_mask =
+      context.getTensor(weight_idx[AttentionParams::dropout_mask]);
+    dropout_mask.dropout_mask(dropout_rate);
+    attention_weight.multiply_i(dropout_mask);
+  }
+
+  if (return_attention_weight ==
+      props::ReturnAttentionWeightInfo::Enum::after) {
+    if (average_attention_weight) {
+      attention_weight.reshape(
+        TensorDim({batch_size, num_heads, query_height, key_height}));
+      attention_weight.sum(1, ret_attention_weight, 1, 0);
+      ret_attention_weight.divide_i(num_heads);
+      attention_weight.reshape(
+        TensorDim({batch_size * num_heads, 1, query_height, key_height}));
+    } else {
+      ret_attention_weight.copyData(attention_weight);
+    }
+  }
+
+  attention_weight.dotBatched(projected_value, attention_output);
+
+  attention_output.reshape(
+    TensorDim({batch_size, num_heads, query_height, projected_value_dim_prop}));
+
+  attention_output = attention_output.transpose("1:0:2");
+
+  /** set tensor name to restore the original name, which was removed during
+   * transpose */
+  attention_output.setName("multi_head_attention:attention_output");
+
+  attention_output.reshape(TensorDim(
+    {batch_size * query_height, 1, 1, num_heads * projected_value_dim_prop}));
+
+  attention_output.dot(fc_weight, output);
+  if (!disable_bias) {
+    output.add_i(fc_bias);
+  }
+
+  /** restore shape */
+  projected_query.reshape(TensorDim(
+    {batch_size, 1, query_height, num_heads * projected_query_dim_prop}));
+  projected_key.reshape(
+    TensorDim({batch_size, 1, key_height, num_heads * projected_key_dim_prop}));
+  projected_value.reshape(TensorDim(
+    {batch_size, 1, value_height, num_heads * projected_value_dim_prop}));
+
+  attention_weight.reshape(
+    TensorDim({batch_size, num_heads, query_height, key_height}));
+  attention_output.reshape(TensorDim(
+    {batch_size, 1, query_height, num_heads * projected_value_dim_prop}));
+}
+
+void MultiHeadAttentionLayer::initial_incremental_forwarding(
+  RunLayerContext &context, unsigned int _from, unsigned int _to,
+  bool training) {
+  unsigned int max_timestep =
+    std::get<props::MaxTimestep>(multi_head_attention_props).get();
+
+  bool cache_shift = false;
+  unsigned int from = _from;
+  unsigned int to = _to;
+  if (to > max_timestep) {
+    throw std::invalid_argument("to shouldn't be greater than max_timestep");
+  }
+
+  const bool disable_bias =
+    std::get<props::DisableBias>(*layer_impl_props).get();
+
+  const unsigned int num_heads =
+    std::get<props::NumHeads>(multi_head_attention_props).get();
+  const unsigned int projected_key_dim_prop =
+    std::get<props::ProjectedKeyDim>(multi_head_attention_props).get();
+  const unsigned int projected_value_dim_prop =
+    std::get<props::ProjectedValueDim>(multi_head_attention_props).get();
+  const float dropout_rate =
+    std::get<props::DropOutRate>(multi_head_attention_props).get();
+  const props::ReturnAttentionWeightInfo::Enum return_attention_weight =
+    std::get<props::ReturnAttentionWeight>(multi_head_attention_props).get();
+  const bool average_attention_weight =
+    std::get<props::AverageAttentionWeight>(multi_head_attention_props).get();
+
+  const bool provide_attention_mask = context.getNumInputs() == 4;
+  const unsigned int projected_query_dim_prop = projected_key_dim_prop;
+  const bool enable_dropout = dropout_rate > epsilon;
+
+  /** get inputs/outputs */
+  Tensor &query = context.getInput(INOUT_INDEX::QUERY);
+  Tensor &key = context.getInput(INOUT_INDEX::KEY);
+  Tensor &value = context.getInput(INOUT_INDEX::VALUE);
+
+  Tensor empty_tensor =
+    Tensor("empty_tensor", value.getFormat(), value.getDataType());
+
+  Tensor &mask =
+    provide_attention_mask ? context.getInput(INOUT_INDEX::MASK) : empty_tensor;
+
+  TensorDim query_dim = query.getDim();
+  TensorDim key_dim = key.getDim();
+  TensorDim value_dim = value.getDim();
+
+  TensorDim query_step_dim = query_dim;
+  TensorDim key_step_dim = key_dim;
+  TensorDim value_step_dim = value_dim;
+
+  query_step_dim.height(to);
+  key_step_dim.height(to);
+  value_step_dim.height(to);
+
+  Tensor query_step = query.getSharedDataTensor(query_step_dim, 0, true);
+  Tensor key_step = key.getSharedDataTensor(key_step_dim, 0, true);
+  Tensor value_step = value.getSharedDataTensor(value_step_dim, 0, true);
+
+  Tensor &output = context.getOutput(INOUT_INDEX::OUTPUT);
+
+  TensorDim output_dim = output.getDim();
+  TensorDim output_step_dim = output_dim;
+  output_step_dim.height(to);
+  Tensor output_step = output.getSharedDataTensor(output_step_dim, 0, true);
+
+  Tensor &ret_attention_weight =
+    return_attention_weight != props::ReturnAttentionWeightInfo::Enum::none
+      ? context.getOutput(INOUT_INDEX::RETURN_ATTENTION_WEIGHT)
+      : empty_tensor;
+
+  /** get weights */
+
+  Tensor qWeight, kWeight, vWeight, fWeight, qbias, kbias, vbias, fcWeight;
+
+  Tensor &query_fc_weight = qWeight;
+  Tensor &key_fc_weight = kWeight;
+  Tensor &value_fc_weight = vWeight;
+  Tensor &fc_weight = fcWeight;
+  Tensor &query_fc_bias = qbias;
+  Tensor &key_fc_bias = kbias;
+  Tensor &value_fc_bias = vbias;
+
+  context.getWeight(query_fc_weight,
+                    weight_idx[AttentionParams::query_fc_weight]);
+  context.getWeight(key_fc_weight, weight_idx[AttentionParams::key_fc_weight]);
+  context.getWeight(value_fc_weight,
+                    weight_idx[AttentionParams::value_fc_weight]);
+
+  context.getWeight(fc_weight, weight_idx[AttentionParams::fc_weight]);
+
+  if (!disable_bias)
+    context.getWeight(query_fc_bias,
+                      weight_idx[AttentionParams::query_fc_bias]);
+  if (!disable_bias)
+    context.getWeight(key_fc_bias, weight_idx[AttentionParams::key_fc_bias]);
+
+  if (!disable_bias)
+    context.getWeight(value_fc_bias,
+                      weight_idx[AttentionParams::value_fc_bias]);
+
+  /** get tensors */
+  Tensor &projected_query =
+    context.getTensor(weight_idx[AttentionParams::projected_query]);
+  Tensor &projected_key =
+    context.getTensor(weight_idx[AttentionParams::projected_key]);
+  Tensor &projected_value =
+    context.getTensor(weight_idx[AttentionParams::projected_value]);
+  Tensor &cache_key = context.getTensor(weight_idx[AttentionParams::cache_key]);
+  Tensor &cache_value =
+    context.getTensor(weight_idx[AttentionParams::cache_value]);
+
+  TensorDim projected_query_dim = projected_query.getDim();
+  TensorDim projected_key_dim = projected_key.getDim();
+  TensorDim projected_value_dim = projected_value.getDim();
+  TensorDim cache_key_dim = cache_key.getDim();
+  TensorDim cache_value_dim = cache_value.getDim();
+
+  TensorDim projected_query_step_dim = projected_query_dim;
+
+  TensorDim projected_key_step_dim = projected_key_dim;
+  TensorDim projected_value_step_dim = projected_value_dim;
+  TensorDim cache_key_step_dim = cache_key_dim;
+  TensorDim cache_value_step_dim = cache_value_dim;
+  projected_query_step_dim.height(to);
+
+  projected_key_step_dim.height(to);
+  projected_value_step_dim.height(to);
+  cache_key_step_dim.height(to);
+  cache_value_step_dim.height(to);
+
+  Tensor projected_query_step =
+    projected_query.getSharedDataTensor(projected_query_step_dim, 0, true);
+  Tensor projected_key_step =
+    projected_key.getSharedDataTensor(projected_key_step_dim, 0, true);
+  Tensor projected_value_step =
+    projected_value.getSharedDataTensor(projected_value_step_dim, 0, true);
+
+  Tensor cache_key_step =
+    cache_key.getSharedDataTensor(cache_key_step_dim, 0, true);
+  Tensor cache_value_step =
+    cache_value.getSharedDataTensor(cache_value_step_dim, 0, true);
+
+  TensorDim cached_key_dim = {cache_key_dim.batch(), cache_key_dim.channel(),
+                              to, cache_key_dim.width(),
+                              cache_key.getTensorType()};
+  TensorDim cached_value_dim = {
+    cache_value_dim.batch(), cache_value_dim.channel(), to,
+    cache_value_dim.width(), cache_value.getTensorType()};
+  Tensor cached_key = cache_key.getSharedDataTensor(cached_key_dim, 0, true);
+  Tensor cached_value =
+    cache_value.getSharedDataTensor(cached_value_dim, 0, true);
+
+  Tensor &attention_weight =
+    context.getTensor(weight_idx[AttentionParams::attention_weight]);
+  Tensor &attention_output =
+    context.getTensor(weight_idx[AttentionParams::attention_output]);
+  TensorDim attention_weight_dim = attention_weight.getDim();
+
+  TensorDim attention_weight_step_dim = attention_weight_dim;
+  attention_weight_step_dim.height(to);
+  attention_weight_step_dim.width(to);
+
+  Tensor attention_weight_step =
+    attention_weight.getSharedDataTensor(attention_weight_step_dim, 0, true);
+
+  TensorDim attention_output_dim = attention_output.getDim();
+  TensorDim attention_output_step_dim = attention_output_dim;
+  attention_output_step_dim.height(to);
+
+  Tensor attention_output_step =
+    attention_output.getSharedDataTensor(attention_output_step_dim, 0, true);
+
+  const unsigned int batch_size = query_dim.batch();
+  const unsigned int query_height = query_dim.height();
+  const unsigned int key_height = key_dim.height();
+  const unsigned int value_height = value_dim.height();
+
+  query_step.dot(query_fc_weight, projected_query_step);
+  if (!disable_bias) {
+    projected_query_step.add_i(query_fc_bias);
+  }
+  key_step.dot(key_fc_weight, cache_key_step);
+  if (!disable_bias) {
+    cache_key_step.add_i(key_fc_bias);
+  }
+  value_step.dot(value_fc_weight, cache_value_step);
+  if (!disable_bias) {
+    cache_value_step.add_i(value_fc_bias);
+  }
+
+  apply_rotary_emb_tensor(projected_query_step, projected_query_dim_prop,
+                          _from);
+  apply_rotary_emb_tensor(cache_key_step, projected_key_dim_prop, _from);
+
+  projected_query_step.reshape(
+    TensorDim({batch_size, to, num_heads, projected_query_dim_prop}));
+
+  cached_key.reshape(
+    TensorDim({batch_size, to, num_heads, projected_key_dim_prop}));
+  cached_value.reshape(
+    TensorDim({batch_size, to, num_heads, projected_value_dim_prop}));
+
+  projected_query_step.transpose("1:0:2", projected_query_step);
+  cached_key.transpose("1:0:2", projected_key_step);
+  cached_value.transpose("1:0:2", projected_value_step);
+
+  projected_query_step.reshape(
+    TensorDim({batch_size * num_heads, 1, to, projected_query_dim_prop}));
+  projected_key_step.reshape(
+    TensorDim({batch_size * num_heads, 1, to, projected_key_dim_prop}));
+  projected_value_step.reshape(
+    TensorDim({batch_size * num_heads, 1, to, projected_value_dim_prop}));
+
+  attention_weight_step.reshape(TensorDim({batch_size * num_heads, 1, to, to}));
+  attention_output_step.reshape(
+    TensorDim({batch_size * num_heads, 1, to, projected_value_dim_prop}));
+
+  /** scaled dot product attention */
+  projected_query_step.dotBatched(projected_key_step, attention_weight_step,
+                                  false, true);
+  attention_weight_step.multiply_i(1 / sqrt((float)projected_query_dim_prop));
+
+  if (!from) {
+    unsigned int mask_size = attention_weight_step.getDim().width();
+    unsigned int mask_dim_height = mask_size;
+    unsigned int mask_dim_width = mask_size;
+
+    Tensor causal_mask(TensorDim{1, 1, mask_size, mask_size,
+                                 attention_weight_step.getTensorType()});
+
+    causal_mask.setZero();
+
+    for (unsigned int i = 0; i < mask_dim_height; ++i) {
+      for (unsigned int j = i + 1; j < mask_dim_width; ++j) {
+        causal_mask.setValue(
+          0, 0, i, j, _MASK_NUM(attention_weight.getTensorType().data_type));
+      }
+    }
+
+    attention_weight_step.add_i(causal_mask);
+  }
+
+  sm.run_fn(attention_weight_step, attention_weight_step);
+
+  attention_weight_step.dotBatched(projected_value_step, attention_output_step);
+
+  attention_output_step.reshape(
+    TensorDim({batch_size, num_heads, to, projected_value_dim_prop}));
+
+  attention_output_step = attention_output_step.transpose("1:0:2");
+
+  attention_output_step.reshape(
+    TensorDim({batch_size * to, 1, 1, num_heads * projected_value_dim_prop}));
+
+  attention_output_step.dot(fc_weight, output_step);
+  if (!disable_bias) {
+    output_step.add_i(fc_bias);
+  }
+
+  // if (layer_progress == 28)
+  //   layer_progress = 0;
+  // layer_progress++;
+
+  // std::cout << "Process Reading: " << (int)((layer_progress / 28.0) * 100.0)
+  //           << " % \r";
+  // std::cout.flush();
+}
+
+void MultiHeadAttentionLayer::incremental_forwarding(RunLayerContext &context,
+                                                     unsigned int _from,
+                                                     unsigned int _to,
+                                                     bool training) {
+
+  if (!_from) {
+    initial_incremental_forwarding(context, _from, _to, training);
+    return;
+  }
+
+  unsigned int max_timestep =
+    std::get<props::MaxTimestep>(multi_head_attention_props).get();
+
+  bool cache_shift = false;
+  unsigned int from = _from;
+  unsigned int to = _to;
+  if (to >= max_timestep) {
+    cache_shift = true;
+    from = max_timestep - 1;
+    to = max_timestep;
+  }
+
+  const bool disable_bias =
+    std::get<props::DisableBias>(*layer_impl_props).get();
+
+  const unsigned int num_heads =
+    std::get<props::NumHeads>(multi_head_attention_props).get();
+  const unsigned int projected_key_dim_prop =
+    std::get<props::ProjectedKeyDim>(multi_head_attention_props).get();
+  const unsigned int projected_value_dim_prop =
+    std::get<props::ProjectedValueDim>(multi_head_attention_props).get();
+  const float dropout_rate =
+    std::get<props::DropOutRate>(multi_head_attention_props).get();
+  const props::ReturnAttentionWeightInfo::Enum return_attention_weight =
+    std::get<props::ReturnAttentionWeight>(multi_head_attention_props).get();
+  const bool average_attention_weight =
+    std::get<props::AverageAttentionWeight>(multi_head_attention_props).get();
+
+  const bool provide_attention_mask = context.getNumInputs() == 4;
+  const unsigned int projected_query_dim_prop = projected_key_dim_prop;
+  const bool enable_dropout = dropout_rate > epsilon;
+
+  /** get inputs/outputs */
+  Tensor &query = context.getInput(INOUT_INDEX::QUERY);
+  Tensor &key = context.getInput(INOUT_INDEX::KEY);
+  Tensor &value = context.getInput(INOUT_INDEX::VALUE);
+
+  Tensor empty_tensor =
+    Tensor("empty_tensor", value.getFormat(), value.getDataType());
+
+  Tensor &mask =
+    provide_attention_mask ? context.getInput(INOUT_INDEX::MASK) : empty_tensor;
+
+  TensorDim query_dim = query.getDim();
+  TensorDim key_dim = key.getDim();
+  TensorDim value_dim = value.getDim();
+
+  TensorDim query_step_dim = query_dim;
+  TensorDim key_step_dim = key_dim;
+  TensorDim value_step_dim = value_dim;
+
+  query_step_dim.height(to - from);
+  key_step_dim.height(to - from);
+  value_step_dim.height(to - from);
+
+  Tensor query_step = query.getSharedDataTensor(query_step_dim, 0, true);
+  Tensor key_step = key.getSharedDataTensor(key_step_dim, 0, true);
+  Tensor value_step = value.getSharedDataTensor(value_step_dim, 0, true);
+
+  Tensor &output = context.getOutput(INOUT_INDEX::OUTPUT);
+
+  TensorDim output_dim = output.getDim();
+
+  TensorDim output_step_dim = output_dim;
+  output_step_dim.height(to - from);
+  Tensor output_step = output.getSharedDataTensor(output_step_dim, 0, true);
+
+  Tensor &ret_attention_weight =
+    return_attention_weight != props::ReturnAttentionWeightInfo::Enum::none
+      ? context.getOutput(INOUT_INDEX::RETURN_ATTENTION_WEIGHT)
+      : empty_tensor;
+
+  /** get weights */
+  Tensor qWeight, kWeight, vWeight, fWeight, qbias, kbias, vbias, fcWeight;
+  Tensor &query_fc_weight = qWeight;
+  Tensor &key_fc_weight = kWeight;
+  Tensor &value_fc_weight = vWeight;
+  Tensor &fc_weight = fcWeight;
+  Tensor &query_fc_bias = qbias;
+  Tensor &key_fc_bias = kbias;
+  Tensor &value_fc_bias = vbias;
+
+  // auto getWeight_Job = [&](Tensor &t, unsigned int idx) {
+  //   context.getWeight(t, idx);
+  // };
+
+  // auto get_key = std::async(std::launch::async, &RunLayerContext::getWeight,
+  // &context, key_fc_weight, weight_idx[AttentionParams::key_fc_weight]);
+
+  // auto get_key = std::async(std::launch::async, getWeight_Job,
+  // std::ref(key_fc_weight),weight_idx[AttentionParams::key_fc_weight] );
+
+  // start = clock();
+  context.getWeight(key_fc_weight, weight_idx[AttentionParams::key_fc_weight]);
+  // auto get_value = std::async(std::launch::async,
+  // &RunLayerContext::getWeight, &context, value_fc_weight,
+  // weight_idx[AttentionParams::value_fc_weight]);
+
+  // auto get_value = std::async(std::launch::async, getWeight_Job,
+  // std::ref(value_fc_weight),weight_idx[AttentionParams::value_fc_weight]);
+
+  // auto get_fc = std::async(std::launch::async, getWeight_Job,
+  // std::ref(fc_weight),weight_idx[AttentionParams::fc_weight]);
+
+  // auto get_fc = std::async(std::launch::async, &RunLayerContext::getWeight,
+  // &context, fc_weight, weight_idx[AttentionParams::fc_weight]);
+
+  context.getWeight(query_fc_weight,
+                    weight_idx[AttentionParams::query_fc_weight]);
+  context.getWeight(value_fc_weight,
+                    weight_idx[AttentionParams::value_fc_weight]);
+
+  context.getWeight(fc_weight, weight_idx[AttentionParams::fc_weight]);
+  // finish=clock();
+  // std::cout << "dequanized :" << (double)(finish-start)<<std::endl;
+  //   disable_bias
+  //     ? empty_tensor
+  //     : context.getWeight(weight_idx[AttentionParams::query_fc_bias]);
+
+  if (!disable_bias)
+    context.getWeight(query_fc_bias,
+                      weight_idx[AttentionParams::query_fc_bias]);
+  if (!disable_bias)
+    context.getWeight(key_fc_bias, weight_idx[AttentionParams::key_fc_bias]);
+  if (!disable_bias)
+    context.getWeight(value_fc_bias,
+                      weight_idx[AttentionParams::value_fc_bias]);
+
+  /** get tensors */
+  Tensor &projected_query =
+    context.getTensor(weight_idx[AttentionParams::projected_query]);
+  Tensor &projected_key =
+    context.getTensor(weight_idx[AttentionParams::projected_key]);
+  Tensor &projected_value =
+    context.getTensor(weight_idx[AttentionParams::projected_value]);
+  Tensor &cache_key = context.getTensor(weight_idx[AttentionParams::cache_key]);
+  Tensor &cache_value =
+    context.getTensor(weight_idx[AttentionParams::cache_value]);
+
+  TensorDim projected_query_dim = projected_query.getDim();
+  TensorDim projected_key_dim = projected_key.getDim();
+  TensorDim projected_value_dim = projected_value.getDim();
+  TensorDim cache_key_dim = cache_key.getDim();
+  TensorDim cache_value_dim = cache_value.getDim();
+
+  TensorDim projected_query_step_dim = projected_query_dim;
+
+  TensorDim projected_key_step_dim = projected_key_dim;
+  TensorDim projected_value_step_dim = projected_value_dim;
+  TensorDim cache_key_step_dim = cache_key_dim;
+  TensorDim cache_value_step_dim = cache_value_dim;
+  projected_query_step_dim.height(to - from);
+
+  projected_key_step_dim.height(to);
+  projected_value_step_dim.height(to);
+  cache_key_step_dim.height(to - from);
+  cache_value_step_dim.height(to - from);
+
+  Tensor projected_query_step =
+    projected_query.getSharedDataTensor(projected_query_step_dim, 0, true);
+  Tensor projected_key_step =
+    projected_key.getSharedDataTensor(projected_key_step_dim, 0, true);
+  Tensor projected_value_step =
+    projected_value.getSharedDataTensor(projected_value_step_dim, 0, true);
+
+  Tensor cache_key_step = cache_key.getSharedDataTensor(
+    cache_key_step_dim, from * cache_key_dim.width(), true);
+  Tensor cache_value_step = cache_value.getSharedDataTensor(
+    cache_value_step_dim, from * cache_value_dim.width(), true);
+
+  TensorDim cached_key_dim = {cache_key_dim.batch(), cache_key_dim.channel(),
+                              to, cache_key_dim.width(),
+                              cache_key.getTensorType()};
+  TensorDim cached_value_dim = {
+    cache_value_dim.batch(), cache_value_dim.channel(), to,
+    cache_value_dim.width(), cache_value.getTensorType()};
+  Tensor cached_key = cache_key.getSharedDataTensor(cached_key_dim, 0, true);
+  Tensor cached_value =
+    cache_value.getSharedDataTensor(cached_value_dim, 0, true);
+
+  Tensor &attention_weight =
+    context.getTensor(weight_idx[AttentionParams::attention_weight]);
+  Tensor &attention_output =
+    context.getTensor(weight_idx[AttentionParams::attention_output]);
+  TensorDim attention_weight_dim = attention_weight.getDim();
+
+  TensorDim attention_weight_step_dim = attention_weight_dim;
+  attention_weight_step_dim.height(to - from);
+  attention_weight_step_dim.width(to);
+
+  Tensor attention_weight_step =
+    attention_weight.getSharedDataTensor(attention_weight_step_dim, 0, true);
+
+  TensorDim attention_output_dim = attention_output.getDim();
+  TensorDim attention_output_step_dim = attention_output_dim;
+  attention_output_step_dim.height(to - from);
+
+  Tensor attention_output_step =
+    attention_output.getSharedDataTensor(attention_output_step_dim, 0, true);
+
+  const unsigned int batch_size = query_dim.batch();
+  const unsigned int query_height = query_dim.height();
+  const unsigned int key_height = key_dim.height();
+  const unsigned int value_height = value_dim.height();
+
+  query_step.dot(query_fc_weight, projected_query_step);
+
+  if (!disable_bias) {
+    projected_query_step.add_i(query_fc_bias);
+  }
+  key_step.dot(key_fc_weight, cache_key_step);
+  if (!disable_bias) {
+    cache_key_step.add_i(key_fc_bias);
+  }
+  value_step.dot(value_fc_weight, cache_value_step);
+  if (!disable_bias) {
+    cache_value_step.add_i(value_fc_bias);
+  }
+
+  apply_rotary_emb_tensor(projected_query_step, projected_query_dim_prop,
+                          _from);
+  apply_rotary_emb_tensor(cache_key_step, projected_key_dim_prop, _from);
+
+  projected_query_step.reshape(
+    TensorDim({batch_size, 1, num_heads, projected_query_dim_prop}));
+  cached_key.reshape(
+    TensorDim({batch_size, to, num_heads, projected_key_dim_prop}));
+  cached_value.reshape(
+    TensorDim({batch_size, to, num_heads, projected_value_dim_prop}));
+
+  projected_query_step.transpose("1:0:2", projected_query_step);
+  cached_key.transpose("1:0:2", projected_key_step);
+  cached_value.transpose("1:0:2", projected_value_step);
+
+  projected_query_step.reshape(
+    TensorDim({batch_size * num_heads, 1, 1, projected_query_dim_prop}));
+  projected_key_step.reshape(
+    TensorDim({batch_size * num_heads, 1, to, projected_key_dim_prop}));
+  projected_value_step.reshape(
+    TensorDim({batch_size * num_heads, 1, to, projected_value_dim_prop}));
+
+  attention_weight_step.reshape(TensorDim({batch_size * num_heads, 1, 1, to}));
+  attention_output_step.reshape(
+    TensorDim({batch_size * num_heads, 1, 1, projected_value_dim_prop}));
+
+  /** scaled dot product attention */
+  projected_query_step.dotBatched(projected_key_step, attention_weight_step,
+                                  false, true);
+  attention_weight_step.multiply_i(1 / sqrt((float)projected_query_dim_prop));
+
+  if (!from) {
+    unsigned int mask_size = attention_weight_step.getDim().width();
+    unsigned int mask_dim_height = mask_size;
+    unsigned int mask_dim_width = mask_size;
+
+    Tensor causal_mask(TensorDim{1, 1, mask_size, mask_size,
+                                 attention_weight_step.getTensorType()});
+
+    causal_mask.setZero();
+
+    for (unsigned int i = 0; i < mask_dim_height; ++i) {
+      for (unsigned int j = i + 1; j < mask_dim_width; ++j) {
+        causal_mask.setValue(
+          0, 0, i, j, _MASK_NUM(attention_weight.getTensorType().data_type));
+      }
+    }
+
+    attention_weight_step.add_i(causal_mask);
+  }
+
+  sm.run_fn(attention_weight_step, attention_weight_step);
+
+  attention_weight_step.dotBatched(projected_value_step, attention_output_step);
+
+  attention_output_step.reshape(
+    TensorDim({batch_size, num_heads, to - from, projected_value_dim_prop}));
+
+  attention_output_step = attention_output_step.transpose("1:0:2");
+
+  attention_output_step.reshape(TensorDim(
+    {batch_size * (to - from), 1, 1, num_heads * projected_value_dim_prop}));
+
+  attention_output_step.dot(fc_weight, output_step);
+  if (!disable_bias) {
+    output_step.add_i(fc_bias);
+  }
+
+  if (cache_shift) {
+    if (cache_key.getDataType() == ml::train::TensorDim::DataType::FP32) {
+      float *buf = cache_key.getAddress<float>(0, 0, 1, 0);
+      float *dbuf = cache_key.getAddress<float>(0, 0, 0, 0);
+      memcpy(dbuf, buf, (cache_key.size() - cache_key.width()) * sizeof(float));
+      buf = cache_value.getAddress<float>(0, 0, 1, 0);
+      dbuf = cache_value.getAddress<float>(0, 0, 0, 0);
+      memcpy(dbuf, buf,
+             (cache_value.size() - cache_value.width()) * sizeof(float));
+    } else if (cache_key.getDataType() ==
+               ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+
+      _FP16 *buf = cache_key.getAddress<_FP16>(0, 0, 1, 0);
+      _FP16 *dbuf = cache_key.getAddress<_FP16>(0, 0, 0, 0);
+      memcpy(dbuf, buf, (cache_key.size() - cache_key.width()) * sizeof(_FP16));
+      buf = cache_value.getAddress<_FP16>(0, 0, 1, 0);
+      dbuf = cache_value.getAddress<_FP16>(0, 0, 0, 0);
+      memcpy(dbuf, buf,
+             (cache_value.size() - cache_value.width()) * sizeof(_FP16));
+#else
+      throw std::invalid_argument("enable-fp16 is not set");
+#endif
+    }
+  }
+}
+
+void MultiHeadAttentionLayer::calcCommonDerivative(RunLayerContext &context) {
+  const unsigned int num_heads =
+    std::get<props::NumHeads>(multi_head_attention_props).get();
+  const unsigned int projected_key_dim_prop =
+    std::get<props::ProjectedKeyDim>(multi_head_attention_props).get();
+  const unsigned int projected_value_dim_prop =
+    std::get<props::ProjectedValueDim>(multi_head_attention_props).get();
+  const float dropout_rate =
+    std::get<props::DropOutRate>(multi_head_attention_props).get();
+  const props::ReturnAttentionWeightInfo::Enum return_attention_weight =
+    std::get<props::ReturnAttentionWeight>(multi_head_attention_props).get();
+  const bool average_attention_weight =
+    std::get<props::AverageAttentionWeight>(multi_head_attention_props).get();
+
+  const bool provide_attention_mask = context.getNumInputs() == 4;
+  const unsigned int projected_query_dim_prop = projected_key_dim_prop;
+
+  Tensor empty_tensor;
+
+  Tensor &query = context.getInput(INOUT_INDEX::QUERY);
+  Tensor &key = context.getInput(INOUT_INDEX::KEY);
+  Tensor &value = context.getInput(INOUT_INDEX::VALUE);
+  const Tensor &incoming_derivative =
+    context.getIncomingDerivative(INOUT_INDEX::OUTPUT);
+  const Tensor &d_ret_attention_weight =
+    return_attention_weight != props::ReturnAttentionWeightInfo::Enum::none
+      ? context.getIncomingDerivative(INOUT_INDEX::RETURN_ATTENTION_WEIGHT)
+      : empty_tensor;
+
+  Tensor &fc_weight = context.getWeight(weight_idx[AttentionParams::fc_weight]);
+
+  Tensor &projected_query =
+    context.getTensor(weight_idx[AttentionParams::projected_query]);
+  Tensor &d_projected_query =
+    context.getTensorGrad(weight_idx[AttentionParams::projected_query]);
+  Tensor &projected_key =
+    context.getTensor(weight_idx[AttentionParams::projected_key]);
+  Tensor &d_projected_key =
+    context.getTensorGrad(weight_idx[AttentionParams::projected_key]);
+  Tensor &projected_value =
+    context.getTensor(weight_idx[AttentionParams::projected_value]);
+  Tensor &d_projected_value =
+    context.getTensorGrad(weight_idx[AttentionParams::projected_value]);
+
+  Tensor &attention_weight =
+    context.getTensor(weight_idx[AttentionParams::attention_weight]);
+  Tensor &d_attention_weight =
+    context.getTensorGrad(weight_idx[AttentionParams::attention_weight]);
+  Tensor &d_attention_output =
+    context.getTensorGrad(weight_idx[AttentionParams::attention_output]);
+
+  const TensorDim query_dim = query.getDim();
+  const unsigned int batch_size = query_dim.batch();
+  const unsigned int query_height = query_dim.height();
+  const TensorDim key_dim = key.getDim();
+  const unsigned int key_height = key_dim.height();
+  const TensorDim value_dim = value.getDim();
+  const unsigned int value_height = value_dim.height();
+
+  d_attention_output.dot_deriv_wrt_1(fc_weight, incoming_derivative);
+
+  d_attention_output.reshape(
+    TensorDim({batch_size, query_height, num_heads, projected_value_dim_prop}));
+
+  d_attention_output = d_attention_output.transpose("1:0:2");
+
+  /** set tensor name to restore the original name, which was removed
+   * during transpose */
+  d_attention_output.setName("multi_head_attention:attention_output:grad");
+
+  projected_query.reshape(TensorDim(
+    {batch_size * num_heads, 1, query_height, projected_query_dim_prop}));
+  d_projected_query.reshape(TensorDim(
+    {batch_size * num_heads, 1, query_height, projected_query_dim_prop}));
+  projected_key.reshape(
+    TensorDim({batch_size * num_heads, 1, key_height, projected_key_dim_prop}));
+  d_projected_key.reshape(
+    TensorDim({batch_size * num_heads, 1, key_height, projected_key_dim_prop}));
+  projected_value.reshape(TensorDim(
+    {batch_size * num_heads, 1, value_height, projected_value_dim_prop}));
+  d_projected_value.reshape(TensorDim(
+    {batch_size * num_heads, 1, value_height, projected_value_dim_prop}));
+
+  attention_weight.reshape(
+    TensorDim({batch_size * num_heads, 1, query_height, key_height}));
+  d_attention_weight.reshape(
+    TensorDim({batch_size * num_heads, 1, query_height, key_height}));
+  d_attention_output.reshape(TensorDim(
+    {batch_size * num_heads, 1, query_height, projected_value_dim_prop}));
+
+  d_attention_weight.dot_batched_deriv_wrt_1(projected_value,
+                                             d_attention_output);
+  attention_weight.dot_batched_deriv_wrt_2(d_projected_value,
+                                           d_attention_output);
+
+  if (return_attention_weight ==
+      props::ReturnAttentionWeightInfo::Enum::after) {
+    const float scale = average_attention_weight ? 1 / (float)num_heads : 1;
+    d_attention_weight.add_i(d_ret_attention_weight, scale);
+  }
+
+  if (dropout_rate > epsilon) {
+    Tensor &dropout_mask =
+      context.getTensor(weight_idx[AttentionParams::dropout_mask]);
+    d_attention_weight.multiply_i(dropout_mask);
+  }
+
+  if (return_attention_weight ==
+      props::ReturnAttentionWeightInfo::Enum::before) {
+    d_attention_weight.add_i(d_ret_attention_weight);
+  }
+
+  sm.run_prime_fn(attention_weight, d_attention_weight, d_attention_weight);
+  if (provide_attention_mask) {
+    Tensor &d_mask = context.getOutgoingDerivative(INOUT_INDEX::MASK);
+    d_mask.copyData(d_attention_weight);
+  }
+  d_attention_weight.multiply_i(
+    1 / sqrt((float)projected_query_dim_prop)); /** scale */
+
+  d_projected_query.dot_batched_deriv_wrt_1(projected_key, d_attention_weight,
+                                            false, true);
+  projected_query.dot_batched_deriv_wrt_2(d_projected_key, d_attention_weight,
+                                          false, true);
+
+  d_projected_query.reshape(
+    TensorDim({batch_size, num_heads, query_height, projected_query_dim_prop}));
+  d_projected_key.reshape(
+    TensorDim({batch_size, num_heads, key_height, projected_key_dim_prop}));
+  d_projected_value.reshape(
+    TensorDim({batch_size, num_heads, value_height, projected_value_dim_prop}));
+
+  d_projected_query = d_projected_query.transpose("1:0:2");
+  d_projected_key = d_projected_key.transpose("1:0:2");
+  d_projected_value = d_projected_value.transpose("1:0:2");
+
+  /** set tensor name to restore the original name, which was removed
+   * during transpose */
+  d_projected_query.setName("multi_head_attention:projected_query:grad");
+  d_projected_key.setName("multi_head_attention:projected_key:grad");
+  d_projected_value.setName("multi_head_attention:projected_value:grad");
+
+  /** restore shape */
+  projected_query.reshape(TensorDim(
+    {batch_size, 1, query_height, num_heads * projected_query_dim_prop}));
+  d_projected_query.reshape(TensorDim(
+    {batch_size * query_height, 1, 1, num_heads * projected_query_dim_prop}));
+  projected_key.reshape(
+    TensorDim({batch_size, 1, key_height, num_heads * projected_key_dim_prop}));
+  d_projected_key.reshape(TensorDim(
+    {batch_size * key_height, 1, 1, num_heads * projected_key_dim_prop}));
+  projected_value.reshape(TensorDim(
+    {batch_size, 1, value_height, num_heads * projected_value_dim_prop}));
+  d_projected_value.reshape(TensorDim(
+    {batch_size * value_height, 1, 1, num_heads * projected_value_dim_prop}));
+
+  attention_weight.reshape(
+    TensorDim({batch_size, num_heads, query_height, key_height}));
+  d_attention_weight.reshape(
+    TensorDim({batch_size, num_heads, query_height, key_height}));
+  d_attention_output.reshape(TensorDim(
+    {batch_size, 1, query_height, num_heads * projected_value_dim_prop}));
+}
+
+void MultiHeadAttentionLayer::calcDerivative(RunLayerContext &context) {
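+  /** when the layer is trainable, calcCommonDerivative runs in calcGradient
+   * instead, so it is invoked exactly once per backward pass */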
+  if (!context.getTrainable()) {
+    calcCommonDerivative(context);
+  }
+
+  Tensor &query = context.getInput(INOUT_INDEX::QUERY);
+  Tensor &d_query = context.getOutgoingDerivative(INOUT_INDEX::QUERY);
+  Tensor &key = context.getInput(INOUT_INDEX::KEY);
+  Tensor &d_key = context.getOutgoingDerivative(INOUT_INDEX::KEY);
+  Tensor &value = context.getInput(INOUT_INDEX::VALUE);
+  Tensor &d_value = context.getOutgoingDerivative(INOUT_INDEX::VALUE);
+  /** d_mask will be calculated in calcCommonDerivative */
+
+  Tensor &query_fc_weight =
+    context.getWeight(weight_idx[AttentionParams::query_fc_weight]);
+  Tensor &key_fc_weight =
+    context.getWeight(weight_idx[AttentionParams::key_fc_weight]);
+  Tensor &value_fc_weight =
+    context.getWeight(weight_idx[AttentionParams::value_fc_weight]);
+
+  Tensor &d_projected_query =
+    context.getTensorGrad(weight_idx[AttentionParams::projected_query]);
+  Tensor &d_projected_key =
+    context.getTensorGrad(weight_idx[AttentionParams::projected_key]);
+  Tensor &d_projected_value =
+    context.getTensorGrad(weight_idx[AttentionParams::projected_value]);
+
+  const TensorDim query_dim = query.getDim();
+  const TensorDim key_dim = key.getDim();
+  const TensorDim value_dim = value.getDim();
+
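+  /** input gradients through the query/key/value projections:
+   * d_input = d_projected (dot) fc_weight^T */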
+  d_query.dot_deriv_wrt_1(query_fc_weight, d_projected_query);
+  d_key.dot_deriv_wrt_1(key_fc_weight, d_projected_key);
+  d_value.dot_deriv_wrt_1(value_fc_weight, d_projected_value, false, false);
+}
+
+void MultiHeadAttentionLayer::calcGradient(RunLayerContext &context) {
+  calcCommonDerivative(context);
+
+  const bool disable_bias =
+    std::get<props::DisableBias>(*layer_impl_props).get();
+
+  const unsigned int num_heads =
+    std::get<props::NumHeads>(multi_head_attention_props).get();
+  const unsigned int projected_key_dim_prop =
+    std::get<props::ProjectedKeyDim>(multi_head_attention_props).get();
+  const unsigned int projected_value_dim_prop =
+    std::get<props::ProjectedValueDim>(multi_head_attention_props).get();
+  const unsigned int output_shape =
+    std::get<props::OutputShape>(multi_head_attention_props).get();
+
+  const unsigned int projected_query_dim_prop = projected_key_dim_prop;
+
+  Tensor &query = context.getInput(INOUT_INDEX::QUERY);
+  Tensor &key = context.getInput(INOUT_INDEX::KEY);
+  Tensor &value = context.getInput(INOUT_INDEX::VALUE);
+  const Tensor &incoming_derivative =
+    context.getIncomingDerivative(INOUT_INDEX::OUTPUT);
+
+  Tensor &d_query_fc_weight =
+    context.getWeightGrad(weight_idx[AttentionParams::query_fc_weight]);
+  Tensor &d_key_fc_weight =
+    context.getWeightGrad(weight_idx[AttentionParams::key_fc_weight]);
+  Tensor &d_value_fc_weight =
+    context.getWeightGrad(weight_idx[AttentionParams::value_fc_weight]);
+  Tensor &d_fc_weight =
+    context.getWeightGrad(weight_idx[AttentionParams::fc_weight]);
+
+  Tensor empty_tensor;
+  Tensor &d_query_fc_bias =
+    disable_bias
+      ? empty_tensor
+      : context.getWeightGrad(weight_idx[AttentionParams::query_fc_bias]);
+  Tensor &d_key_fc_bias =
+    disable_bias
+      ? empty_tensor
+      : context.getWeightGrad(weight_idx[AttentionParams::key_fc_bias]);
+  Tensor &d_value_fc_bias =
+    disable_bias
+      ? empty_tensor
+      : context.getWeightGrad(weight_idx[AttentionParams::value_fc_bias]);
+  Tensor &d_fc_bias =
+    disable_bias ? empty_tensor
+                 : context.getWeightGrad(weight_idx[AttentionParams::fc_bias]);
+
+  Tensor &d_projected_query =
+    context.getTensorGrad(weight_idx[AttentionParams::projected_query]);
+  Tensor &d_projected_key =
+    context.getTensorGrad(weight_idx[AttentionParams::projected_key]);
+  Tensor &d_projected_value =
+    context.getTensorGrad(weight_idx[AttentionParams::projected_value]);
+
+  Tensor &attention_output =
+    context.getTensor(weight_idx[AttentionParams::attention_output]);
+
+  const TensorDim query_dim = query.getDim();
+  const unsigned int batch_size = query_dim.batch();
+  const unsigned int query_height = query_dim.height();
+  const TensorDim key_dim = key.getDim();
+  const unsigned int key_height = key_dim.height();
+  const TensorDim value_dim = value.getDim();
+  const unsigned int value_height = value_dim.height();
+
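+  /** output fc weight grad = attention_output^T (dot) incoming_derivative,
+   * accumulated when this is not the first gradient access; d_fc_bias is the
+   * sum of incoming_derivative over the (batch * height) axis */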
+  attention_output.dot_deriv_wrt_2(
+    d_fc_weight, incoming_derivative, false, false,
+    !context.isGradientFirstAccess(weight_idx[AttentionParams::fc_weight]));
+
+  if (!disable_bias) {
+    Tensor incoming_derivative_ = incoming_derivative;
+    incoming_derivative_.reshape(
+      TensorDim({batch_size * query_height, 1, 1, output_shape}));
+    incoming_derivative_.sum(
+      0, d_fc_bias, 1,
+      !context.isGradientFirstAccess(weight_idx[AttentionParams::fc_bias]));
+  }
+
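+  /** the q/k/v projection grads below follow the same pattern:
+   * d_W = input^T (dot) d_projected, and the bias grads are the sum of
+   * d_projected over the (batch * height) axis */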
+  query.dot_deriv_wrt_2(d_query_fc_weight, d_projected_query, false, false,
+                        !context.isGradientFirstAccess(
+                          weight_idx[AttentionParams::query_fc_weight]));
+  if (!disable_bias) {
+    d_projected_query.reshape(TensorDim(
+      {batch_size * query_height, 1, 1, num_heads * projected_query_dim_prop}));
+    d_projected_query.sum(0, d_query_fc_bias, 1,
+                          !context.isGradientFirstAccess(
+                            weight_idx[AttentionParams::query_fc_bias]));
+    d_projected_query.reshape(TensorDim(
+      {batch_size, 1, query_height, num_heads * projected_query_dim_prop}));
+  }
+
+  key.dot_deriv_wrt_2(
+    d_key_fc_weight, d_projected_key, false, false,
+    !context.isGradientFirstAccess(weight_idx[AttentionParams::key_fc_weight]));
+  if (!disable_bias) {
+    d_projected_key.reshape(TensorDim(
+      {batch_size * key_height, 1, 1, num_heads * projected_key_dim_prop}));
+    d_projected_key.sum(
+      0, d_key_fc_bias, 1,
+      !context.isGradientFirstAccess(weight_idx[AttentionParams::key_fc_bias]));
+    d_projected_key.reshape(TensorDim(
+      {batch_size, 1, key_height, num_heads * projected_key_dim_prop}));
+  }
+
+  value.dot_deriv_wrt_2(d_value_fc_weight, d_projected_value, false, false,
+                        !context.isGradientFirstAccess(
+                          weight_idx[AttentionParams::value_fc_weight]));
+  if (!disable_bias) {
+    d_projected_value.reshape(TensorDim(
+      {batch_size * value_height, 1, 1, num_heads * projected_value_dim_prop}));
+    d_projected_value.sum(0, d_value_fc_bias, 1,
+                          !context.isGradientFirstAccess(
+                            weight_idx[AttentionParams::value_fc_bias]));
+    d_projected_value.reshape(TensorDim(
+      {batch_size, 1, value_height, num_heads * projected_value_dim_prop}));
+  }
+}
+
+void MultiHeadAttentionLayer::setProperty(
+  const std::vector<std::string> &values) {
+  auto remain_props = loadProperties(values, multi_head_attention_props);
+  LayerImpl::setProperty(remain_props);
+}
+
+void MultiHeadAttentionLayer::setBatch(RunLayerContext &context,
+                                       unsigned int batch) {
+  const float dropout_rate =
+    std::get<props::DropOutRate>(multi_head_attention_props).get();
+
+  context.updateTensor(weight_idx[AttentionParams::projected_query], batch);
+  context.updateTensor(weight_idx[AttentionParams::projected_key], batch);
+  context.updateTensor(weight_idx[AttentionParams::projected_value], batch);
+  context.updateTensor(weight_idx[AttentionParams::cache_key], batch);
+  context.updateTensor(weight_idx[AttentionParams::cache_value], batch);
+  context.updateTensor(weight_idx[AttentionParams::attention_weight], batch);
+  if (dropout_rate > epsilon) {
+    context.updateTensor(weight_idx[AttentionParams::dropout_mask], batch);
+  }
+  context.updateTensor(weight_idx[AttentionParams::attention_output], batch);
+}
+
+void MultiHeadAttentionLayer::exportTo(
+  Exporter &exporter, const ml::train::ExportMethods &method) const {
+  LayerImpl::exportTo(exporter, method);
+  exporter.saveResult(multi_head_attention_props, method, this);
+}
+
+} /* namespace nntrainer */
diff --git a/Applications/LLaMA/jni/custom_multi_head_attention_layer.h b/Applications/LLaMA/jni/custom_multi_head_attention_layer.h
new file mode 100644 (file)
index 0000000..f6f5e10
--- /dev/null
@@ -0,0 +1,309 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2022 hyeonseok Lee <hs89.lee@samsung.com>
+ *
+ * @file   custom_multi_head_attention_layer.h
+ * @date   08 July 2022
+ * @see    https://github.com/nnstreamer/nntrainer
+ *         https://arxiv.org/abs/1706.03762
+ * @author hyeonseok Lee <hs89.lee@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is MultiHeadAttention Layer Class for Neural Network
+ *
+ */
+
+#ifndef __MULTI_HEAD_ATTENTION_LAYER_H__
+#define __MULTI_HEAD_ATTENTION_LAYER_H__
+#ifdef __cplusplus
+
+#include <acti_func.h>
+#include <complex>
+#include <layer_impl.h>
+#include <util_simd.h>
+#include <utility>
+
+namespace nntrainer {
+
+/**
+ * @class   Multi Head Attention Layer
+ * @brief   Implementation of multi head attention which is described in paper
+ * "Attention is all you need"
+ */
+class MultiHeadAttentionLayer : public LayerImpl {
+public:
+  /**
+   * @brief     Constructor of MultiHeadAttention Layer
+   */
+  MultiHeadAttentionLayer();
+
+  /**
+   * @brief     Destructor of MultiHeadAttention Layer
+   */
+  ~MultiHeadAttentionLayer();
+
+  /**
+   *  @brief  Move constructor of MultiHeadAttentionLayer.
+   *  @param[in] MultiHeadAttentionLayer &&
+   */
+  MultiHeadAttentionLayer(MultiHeadAttentionLayer &&rhs) noexcept = default;
+
+  /**
+   * @brief  Move assignment operator.
+   * @param[in] rhs MultiHeadAttentionLayer to be moved.
+   */
+  MultiHeadAttentionLayer &operator=(MultiHeadAttentionLayer &&rhs) = default;
+
+  /**
+   * @copydoc Layer::finalize(InitLayerContext &context)
+   */
+  void finalize(InitLayerContext &context) override;
+
+  /**
+   * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
+   */
+  void forwarding(RunLayerContext &context, bool training) override;
+
+  /**
+   * @brief incremental forwarding for the very first step (from == 0), which
+   * fills the key/value caches before token-by-token incremental_forwarding
+   */
+  void initial_incremental_forwarding(RunLayerContext &context,
+                                      unsigned int from, unsigned int to,
+                                      bool training);
+
+  /**
+   * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
+   * int from, unsigned int to, bool training)
+   */
+  void incremental_forwarding(RunLayerContext &context, unsigned int from,
+                              unsigned int to, bool training) override;
+
+  /**
+   * @copydoc Layer::calcDerivative(RunLayerContext &context)
+   */
+  void calcDerivative(RunLayerContext &context) override;
+
+  /**
+   * @copydoc Layer::calcGradient(RunLayerContext &context)
+   */
+  void calcGradient(RunLayerContext &context) override;
+
+  /**
+   * @copydoc bool supportBackwarding() const
+   */
+  bool supportBackwarding() const override { return true; };
+
+  /**
+   * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods
+   * method)
+   */
+  void exportTo(Exporter &exporter,
+                const ml::train::ExportMethods &method) const override;
+
+  /**
+   * @copydoc Layer::setProperty(const std::vector<std::string> &values)
+   */
+  void setProperty(const std::vector<std::string> &values) override;
+
+  /**
+   * @copydoc Layer::getType()
+   */
+  const std::string getType() const override {
+    return MultiHeadAttentionLayer::type;
+  };
+
+  /**
+   * @copydoc Layer::setBatch(RunLayerContext &context, unsigned int batch)
+   */
+  void setBatch(RunLayerContext &context, unsigned int batch) override;
+
+  inline static const std::string type = "multi_head_attention";
+
+private:
+  std::tuple<props::NumHeads, props::ProjectedKeyDim, props::ProjectedValueDim,
+             props::OutputShape, props::DropOutRate,
+             props::ReturnAttentionWeight, props::AverageAttentionWeight,
+             props::MaxTimestep>
+    multi_head_attention_props; /**< multi_head_attention layer properties */
+
+  ActiFunc sm; /** softmax activation operation */
+  std::array<unsigned int, 16>
+    weight_idx; /**< indices of the weights and tensors */
+
+  /**
+   * @brief     small threshold value (e.g. to decide whether dropout is enabled)
+   */
+  float epsilon;
+
+  unsigned int cache_index;
+
+  inline static unsigned int layer_progress;
+
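+  /** RoPE frequency and cos/sin lookup tables; inline static, so they are
+   * shared by every MultiHeadAttentionLayer instance in the process */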
+  inline static std::vector<std::vector<float>> *freqs_cos = {};
+  inline static std::vector<std::vector<float>> *freqs_sin = {};
+  inline static std::vector<float> freqs;
+
+  /**
+   * @brief     compute frequency for rotary embedding
+   * @param[in] dim hidden dim size
+   * @param[in] seq_len sequence length
+   * @param[in] theta rotary angle
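+   * @note      freqs[i] = theta^(-2i / dim) for i in [0, dim / 2); the cos/sin
+   *            tables duplicate each value across both halves of dim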
+   */
+  void precompute_freqs(int dim, unsigned int seq_len, float theta = 10000.0) {
+    if (freqs_cos == nullptr) {
+      unsigned int half_ = dim / 2;
+      for (unsigned int i = 0; i < half_; ++i) {
+        freqs.push_back(1.0 /
+                        (std::pow(theta, (2 * i) / static_cast<float>(dim))));
+      }
+
+      auto cos = new std::vector<std::vector<float>>();
+      cos->assign(seq_len, std::vector<float>(dim, 0));
+
+      auto sin = new std::vector<std::vector<float>>();
+      sin->assign(seq_len, std::vector<float>(dim, 0));
+
+      for (unsigned int i = 0; i < seq_len; ++i) {
+#ifdef USE_NEON
+        calc_trigonometric_vals_dup(half_, freqs.data(), (*cos)[i].data(),
+                                    (*sin)[i].data(), i);
+#else
+        for (unsigned int j = 0; j < half_; ++j) {
+          float angle = i * freqs[j];
+          (*cos)[i][j] = std::cos(angle);
+          (*cos)[i][j + half_] = std::cos(angle); // repeated 2 times
+
+          (*sin)[i][j] = std::sin(angle);
+          (*sin)[i][j + half_] = std::sin(angle); // repeated 2 times
+        }
+#endif
+      }
+      freqs_cos = cos;
+      freqs_sin = sin;
+    }
+  }
+
+  /**
+   * @brief     apply rotary embedding
+   * @param[in] in input tensor
+   * @param[in] dim hidden dim size
+   * @param[in] from sequence order
+   */
+  void apply_rotary_emb_tensor(Tensor &in, unsigned int dim,
+                               unsigned int from) {
+    Tensor out(in.getDim());
+    float value = 0;
+    float transformed_value = 0.0;
+    unsigned int half_ = dim / 2;
+    unsigned int max_timestep =
+      std::get<props::MaxTimestep>(multi_head_attention_props).get();
+
+    std::vector<float> *cos_;
+    std::vector<float> *sin_;
+
+    if (from >= max_timestep) {
+      cos_ = new std::vector<float>(dim);
+      sin_ = new std::vector<float>(dim);
+#ifdef USE_NEON
+      calc_trigonometric_vals_dup(half_, freqs.data(), cos_->data(),
+                                  sin_->data(), from);
+#else
+      for (unsigned int i = 0; i < half_; ++i) {
+        float angle = from * freqs[i];
+        (*cos_)[i] = std::cos(angle);
+        (*cos_)[i + half_] = std::cos(angle); // repeated 2 times
+
+        (*sin_)[i] = std::sin(angle);
+        (*sin_)[i + half_] = std::sin(angle); // repeated 2 times
+      }
+#endif
+    }
+
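+    /** rotate each half-pair: out = in * cos + rotate_half(in) * sin, where
+     * rotate_half([x1, x2]) = [-x2, x1] */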
+    if (in.getDataType() == ml::train::TensorDim::DataType::FP32) {
+      for (unsigned int b = 0; b < in.batch(); b++) {
+        for (unsigned int c = 0; c < in.channel(); c++) {
+          for (unsigned int h = 0; h < in.height(); h++) {
+            if (from < max_timestep) {
+              cos_ = &(*freqs_cos)[from + h];
+              sin_ = &(*freqs_sin)[from + h];
+            }
+
+            for (unsigned int w = 0; w < in.width(); w = w + dim) {
+              for (unsigned int k = 0; k < dim; k++) {
+                unsigned int span = w + k;
+                value = in.getValue<float>(b, c, h, span);
+
+                if (k < half_) {
+                  transformed_value =
+                    -1.0 * in.getValue<float>(b, c, h, span + half_);
+                } else {
+                  transformed_value = in.getValue<float>(b, c, h, span - half_);
+                }
+                value = value * (*cos_)[k] + transformed_value * (*sin_)[k];
+                out.setValue(b, c, h, span, value);
+              }
+            }
+          }
+        }
+      }
+    } else if (in.getDataType() == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+      for (unsigned int b = 0; b < in.batch(); b++) {
+        for (unsigned int c = 0; c < in.channel(); c++) {
+          for (unsigned int h = 0; h < in.height(); h++) {
+            if (from < max_timestep) {
+              cos_ = &(*freqs_cos)[from + h];
+              sin_ = &(*freqs_sin)[from + h];
+            }
+            for (unsigned int w = 0; w < in.width(); w = w + dim) {
+#ifdef USE_NEON
+              compute_rotary_embedding_value(
+                dim, half_, w, in.getData<_FP16>() + in.getIndex(b, c, h, 0),
+                out.getData<_FP16>() + out.getIndex(b, c, h, 0), cos_->data(),
+                sin_->data());
+#else
+              for (unsigned int k = 0; k < dim; k++) {
+                unsigned int span = w + k;
+                value = static_cast<float>(in.getValue<_FP16>(b, c, h, span));
+
+                if (k < half_) {
+                  transformed_value =
+                    -1.0 * static_cast<float>(
+                             in.getValue<_FP16>(b, c, h, half_ + span));
+                } else {
+                  transformed_value = static_cast<float>(
+                    in.getValue<_FP16>(b, c, h, span - half_));
+                }
+                out.setValue(
+                  b, c, h, span,
+                  static_cast<_FP16>(value * (*cos_)[k] +
+                                     transformed_value * (*sin_)[k]));
+              }
+#endif
+            }
+          }
+        }
+      }
+#else
+      throw std::invalid_argument("Error: enable-fp16 is not enabled");
+#endif
+    }
+
+    if (from >= max_timestep) {
+      delete cos_;
+      delete sin_;
+    }
+    in.copy(out);
+  }
+
+  /**
+   * @brief calculate common derivative
+   * @param context Context of the layer
+   */
+  void calcCommonDerivative(RunLayerContext &context);
+};
+
+} // namespace nntrainer
+
+#endif /* __cplusplus */
+#endif /* __MULTI_HEAD_ATTENTION_LAYER_H__ */
index 985d82a79efc813a826efeafdbf357eb89b8723c..c90cb7430552623d6dc61e231ed635b2e4e59579 100644 (file)
@@ -25,6 +25,7 @@
 #include <optimizer.h>
 
 #include <app_context.h>
+#include <custom_multi_head_attention_layer.h>
 #include <rms_norm.h>
 #include <rotary_embedding.h>
 #include <swiglu.h>
index 24ebf6593f46adee53f1e50e4276a7c65b3a4496..9e6b02dbf94ade10eb7f35fca5fdd61b9b688563 100644 (file)
@@ -59,12 +59,27 @@ rotary_emb_dep = declare_dependency(
   include_directories: include_directories('./')
 )
 
+mha_src = files('custom_multi_head_attention_layer.cpp')
+mha_layer = shared_library('custom_multi_head_attention_layer',
+  mha_src,
+  dependencies: [nntrainer_dep, nntrainer_ccapi_dep],
+  include_directories: include_directories('./'),
+  install: true,
+  install_dir: application_install_dir,
+  cpp_args: '-DPLUGGABLE'
+)
+mha_dep = declare_dependency(
+  link_with: mha_layer,
+  include_directories: include_directories('./')
+)
+
 llama_sources = [
   'main.cpp',
   cifar_path / 'cifar_dataloader.cpp',
   rms_norm_src,
   swiglu_src,
-  rotary_emb_src
+  rotary_emb_src,
+  mha_src
 ]
 
 llama_dependencies = [app_utils_dep,
@@ -73,7 +88,8 @@ llama_dependencies = [app_utils_dep,
   transpose_dep,
   rms_norm_dep,
   swiglu_dep,
-  rotary_emb_dep
+  rotary_emb_dep,
+  mha_dep
 ]
 
 e = executable('nntrainer_llama',
index 622459a41bbad0271174400a3d4a6747c51659ae..0d4b73b67fbb427dffcaf6d0bb01d357fee178ca 100644 (file)
  *
  */
 
-#include <algorithm>
 #include <cmath>
+
 #include <layer_context.h>
 #include <multi_head_attention_layer.h>
 #include <nntrainer_error.h>
 #include <nntrainer_log.h>
 #include <node_exporter.h>
-#include <thread>
-#include <vector>
 
 namespace nntrainer {
 
@@ -28,12 +26,10 @@ MultiHeadAttentionLayer::MultiHeadAttentionLayer() :
   multi_head_attention_props(
     props::NumHeads(), props::ProjectedKeyDim(), props::ProjectedValueDim(),
     props::OutputShape(), props::DropOutRate(), props::ReturnAttentionWeight(),
-    props::AverageAttentionWeight(), props::MaxTimestep()),
+    props::AverageAttentionWeight()),
   sm(ActivationType::ACT_SOFTMAX),
-  epsilon(1e-3),
-  cache_index(0) {
+  epsilon(1e-3) {
   weight_idx.fill(std::numeric_limits<unsigned>::max());
-  layer_progress = 0;
 }
 
 MultiHeadAttentionLayer::~MultiHeadAttentionLayer() {}
@@ -95,9 +91,9 @@ void MultiHeadAttentionLayer::finalize(InitLayerContext &context) {
   const unsigned int batch_size = query_dim.batch();
   const unsigned int query_height = query_dim.height();
   const unsigned int query_width = query_dim.width();
-  // const unsigned int key_height = key_dim.height();
+  const unsigned int key_height = key_dim.height();
   const unsigned int key_width = key_dim.width();
-  // const unsigned int value_height = value_dim.height();
+  const unsigned int value_height = value_dim.height();
   const unsigned int value_width = value_dim.width();
 
   const bool disable_bias =
@@ -159,13 +155,6 @@ void MultiHeadAttentionLayer::finalize(InitLayerContext &context) {
   const props::ReturnAttentionWeightInfo::Enum return_attention_weight =
     std::get<props::ReturnAttentionWeight>(multi_head_attention_props).get();
 
-  const unsigned int max_timestep =
-    std::get<props::MaxTimestep>(multi_head_attention_props).get();
-
-  // @todo: fix me
-  const unsigned int key_height = max_timestep;
-  const unsigned int value_height = max_timestep;
-
   const unsigned int projected_query_dim_prop = projected_key_dim_prop;
 
   if (activation_type.data_type == TensorDim::DataType::FP32) {
@@ -289,18 +278,12 @@ void MultiHeadAttentionLayer::finalize(InitLayerContext &context) {
     projected_value_dim, "projected_value", Tensor::Initializer::NONE, true,
     TensorLifespan::ITERATION_LIFESPAN);
 
-  TensorDim cache_key_dim(
-    {batch_size, 1, max_timestep, num_heads * projected_key_dim_prop},
-    activation_type);
-  weight_idx[AttentionParams::cache_key] =
-    context.requestTensor(cache_key_dim, "cache_key", Tensor::Initializer::NONE,
-                          true, TensorLifespan::MAX_LIFESPAN);
+  weight_idx[AttentionParams::cache_key] = context.requestTensor(
+    projected_key_dim, "cache_key", Tensor::Initializer::NONE, true,
+    TensorLifespan::MAX_LIFESPAN);
 
-  TensorDim cache_value_dim(
-    {batch_size, 1, max_timestep, num_heads * projected_value_dim_prop},
-    activation_type);
   weight_idx[AttentionParams::cache_value] = context.requestTensor(
-    cache_value_dim, "cache_value", Tensor::Initializer::NONE, true,
+    projected_value_dim, "cache_value", Tensor::Initializer::NONE, true,
     TensorLifespan::MAX_LIFESPAN);
 
   if (provide_attention_mask) {
@@ -345,22 +328,10 @@ void MultiHeadAttentionLayer::finalize(InitLayerContext &context) {
   } else {
     context.setOutputDimensions({output_dim});
   }
-
-  /**
-   * @todo
-   * check query width and key width
-   *
-   */
-  if (freqs_cos == nullptr)
-    precompute_freqs(projected_key_dim_prop, max_timestep);
 }
 
-#define _MASK_NUM(datatype) \
-  (((datatype) == ml::train::TensorDim::DataType::FP16) ? (-1e4) : (-1e10))
-
 void MultiHeadAttentionLayer::forwarding(RunLayerContext &context,
                                          bool training) {
-
   const bool disable_bias =
     std::get<props::DisableBias>(*layer_impl_props).get();
 
@@ -453,9 +424,6 @@ void MultiHeadAttentionLayer::forwarding(RunLayerContext &context,
     projected_value.add_i(value_fc_bias);
   }
 
-  apply_rotary_emb_tensor(projected_query, projected_query_dim_prop, 0);
-  apply_rotary_emb_tensor(projected_key, projected_key_dim_prop, 0);
-
   projected_query.reshape(
     TensorDim({batch_size, query_height, num_heads, projected_query_dim_prop}));
   projected_key.reshape(
@@ -467,8 +435,8 @@ void MultiHeadAttentionLayer::forwarding(RunLayerContext &context,
   projected_key = projected_key.transpose("1:0:2");
   projected_value = projected_value.transpose("1:0:2");
 
-  /** set tensor name to restore origin name cause origin name was remove
-   * during transpose */
+  /** set tensor name to restore origin name cause origin name was remove during
+   * transpose */
   projected_query.setName("multi_head_attention:projected_query");
   projected_key.setName("multi_head_attention:projected_key");
   projected_value.setName("multi_head_attention:projected_value");
@@ -489,24 +457,6 @@ void MultiHeadAttentionLayer::forwarding(RunLayerContext &context,
   projected_query.dotBatched(projected_key, attention_weight, false, true);
   attention_weight.multiply_i(1 / sqrt((float)projected_query_dim_prop));
 
-  unsigned int mask_size = attention_weight.getDim().width();
-  unsigned int mask_dim_height = mask_size;
-  unsigned int mask_dim_width = mask_size;
-
-  Tensor causal_mask(
-    TensorDim{1, 1, mask_size, mask_size, attention_weight.getTensorType()});
-
-  causal_mask.setZero();
-
-  for (unsigned int i = 0; i < mask_dim_height; ++i) {
-    for (unsigned int j = i + 1; j < mask_dim_width; ++j) {
-      causal_mask.setValue(0, 0, i, j,
-                           _MASK_NUM(attention_weight.getDataType()));
-    }
-  }
-
-  attention_weight.add_i(causal_mask);
-
   if (provide_attention_mask) {
     // Tensor &attention_mask =
     //   context.getTensor(weight_idx[AttentionParams::attention_mask]);
@@ -591,298 +541,10 @@ void MultiHeadAttentionLayer::forwarding(RunLayerContext &context,
     {batch_size, 1, query_height, num_heads * projected_value_dim_prop}));
 }
 
-void MultiHeadAttentionLayer::initial_incremental_forwarding(
-  RunLayerContext &context, unsigned int _from, unsigned int _to,
-  bool training) {
-  unsigned int max_timestep =
-    std::get<props::MaxTimestep>(multi_head_attention_props).get();
-
-  bool cache_shift = false;
-  unsigned int from = _from;
-  unsigned int to = _to;
-  if (to > max_timestep) {
-    throw std::invalid_argument("to shouldn't greater than max_timestep");
-  }
-
-  const bool disable_bias =
-    std::get<props::DisableBias>(*layer_impl_props).get();
-
-  const unsigned int num_heads =
-    std::get<props::NumHeads>(multi_head_attention_props).get();
-  const unsigned int projected_key_dim_prop =
-    std::get<props::ProjectedKeyDim>(multi_head_attention_props).get();
-  const unsigned int projected_value_dim_prop =
-    std::get<props::ProjectedValueDim>(multi_head_attention_props).get();
-  const float dropout_rate =
-    std::get<props::DropOutRate>(multi_head_attention_props).get();
-  const props::ReturnAttentionWeightInfo::Enum return_attention_weight =
-    std::get<props::ReturnAttentionWeight>(multi_head_attention_props).get();
-  const bool average_attention_weight =
-    std::get<props::AverageAttentionWeight>(multi_head_attention_props).get();
-
-  const bool provide_attention_mask = context.getNumInputs() == 4;
-  const unsigned int projected_query_dim_prop = projected_key_dim_prop;
-  const bool enable_dropout = dropout_rate > epsilon;
-
-  /** get inputs/outputs */
-  Tensor &query = context.getInput(INOUT_INDEX::QUERY);
-  Tensor &key = context.getInput(INOUT_INDEX::KEY);
-  Tensor &value = context.getInput(INOUT_INDEX::VALUE);
-
-  Tensor empty_tensor =
-    Tensor("empty_tensor", value.getFormat(), value.getDataType());
-
-  Tensor &mask =
-    provide_attention_mask ? context.getInput(INOUT_INDEX::MASK) : empty_tensor;
-
-  TensorDim query_dim = query.getDim();
-  TensorDim key_dim = key.getDim();
-  TensorDim value_dim = value.getDim();
-
-  TensorDim query_step_dim = query_dim;
-  TensorDim key_step_dim = key_dim;
-  TensorDim value_step_dim = value_dim;
-
-  query_step_dim.height(to);
-  key_step_dim.height(to);
-  value_step_dim.height(to);
-
-  Tensor query_step = query.getSharedDataTensor(query_step_dim, 0, true);
-  Tensor key_step = key.getSharedDataTensor(key_step_dim, 0, true);
-  Tensor value_step = value.getSharedDataTensor(value_step_dim, 0, true);
-
-  Tensor &output = context.getOutput(INOUT_INDEX::OUTPUT);
-
-  TensorDim output_dim = output.getDim();
-  TensorDim output_step_dim = output_dim;
-  output_step_dim.height(to);
-  Tensor output_step = output.getSharedDataTensor(output_step_dim, 0, true);
-
-  Tensor &ret_attention_weight =
-    return_attention_weight != props::ReturnAttentionWeightInfo::Enum::none
-      ? context.getOutput(INOUT_INDEX::RETURN_ATTENTION_WEIGHT)
-      : empty_tensor;
-
-  /** get weights */
-
-  Tensor qWeight, kWeight, vWeight, fWeight, qbias, kbias, vbias, fcWeight;
-
-  Tensor &query_fc_weight = qWeight;
-  Tensor &key_fc_weight = kWeight;
-  Tensor &value_fc_weight = vWeight;
-  Tensor &fc_weight = fcWeight;
-  Tensor &query_fc_bias = qbias;
-  Tensor &key_fc_bias = kbias;
-  Tensor &value_fc_bias = vbias;
-
-  context.getWeight(query_fc_weight,
-                    weight_idx[AttentionParams::query_fc_weight]);
-  context.getWeight(key_fc_weight, weight_idx[AttentionParams::key_fc_weight]);
-  context.getWeight(value_fc_weight,
-                    weight_idx[AttentionParams::value_fc_weight]);
-
-  context.getWeight(fc_weight, weight_idx[AttentionParams::fc_weight]);
-
-  if (!disable_bias)
-    context.getWeight(query_fc_bias,
-                      weight_idx[AttentionParams::query_fc_bias]);
-  if (!disable_bias)
-    context.getWeight(key_fc_bias, weight_idx[AttentionParams::key_fc_bias]);
-
-  if (!disable_bias)
-    context.getWeight(value_fc_bias,
-                      weight_idx[AttentionParams::value_fc_bias]);
-
-  /** get tensors */
-  Tensor &projected_query =
-    context.getTensor(weight_idx[AttentionParams::projected_query]);
-  Tensor &projected_key =
-    context.getTensor(weight_idx[AttentionParams::projected_key]);
-  Tensor &projected_value =
-    context.getTensor(weight_idx[AttentionParams::projected_value]);
-  Tensor &cache_key = context.getTensor(weight_idx[AttentionParams::cache_key]);
-  Tensor &cache_value =
-    context.getTensor(weight_idx[AttentionParams::cache_value]);
-
-  TensorDim projected_query_dim = projected_query.getDim();
-  TensorDim projected_key_dim = projected_key.getDim();
-  TensorDim projected_value_dim = projected_value.getDim();
-  TensorDim cache_key_dim = cache_key.getDim();
-  TensorDim cache_value_dim = cache_value.getDim();
-
-  TensorDim projected_query_step_dim = projected_query_dim;
-
-  TensorDim projected_key_step_dim = projected_key_dim;
-  TensorDim projected_value_step_dim = projected_value_dim;
-  TensorDim cache_key_step_dim = cache_key_dim;
-  TensorDim cache_value_step_dim = cache_value_dim;
-  projected_query_step_dim.height(to);
-
-  projected_key_step_dim.height(to);
-  projected_value_step_dim.height(to);
-  cache_key_step_dim.height(to);
-  cache_value_step_dim.height(to);
-
-  Tensor projected_query_step =
-    projected_query.getSharedDataTensor(projected_query_step_dim, 0, true);
-  Tensor projected_key_step =
-    projected_key.getSharedDataTensor(projected_key_step_dim, 0, true);
-  Tensor projected_value_step =
-    projected_value.getSharedDataTensor(projected_value_step_dim, 0, true);
-
-  Tensor cache_key_step =
-    cache_key.getSharedDataTensor(cache_key_step_dim, 0, true);
-  Tensor cache_value_step =
-    cache_value.getSharedDataTensor(cache_value_step_dim, 0, true);
-
-  TensorDim cached_key_dim = {cache_key_dim.batch(), cache_key_dim.channel(),
-                              to, cache_key_dim.width(),
-                              cache_key.getTensorType()};
-  TensorDim cached_value_dim = {
-    cache_value_dim.batch(), cache_value_dim.channel(), to,
-    cache_value_dim.width(), cache_value.getTensorType()};
-  Tensor cached_key = cache_key.getSharedDataTensor(cached_key_dim, 0, true);
-  Tensor cached_value =
-    cache_value.getSharedDataTensor(cached_value_dim, 0, true);
-
-  Tensor &attention_weight =
-    context.getTensor(weight_idx[AttentionParams::attention_weight]);
-  Tensor &attention_output =
-    context.getTensor(weight_idx[AttentionParams::attention_output]);
-  TensorDim attention_weight_dim = attention_weight.getDim();
-
-  TensorDim attention_weight_step_dim = attention_weight_dim;
-  attention_weight_step_dim.height(to);
-  attention_weight_step_dim.width(to);
-
-  Tensor attention_weight_step =
-    attention_weight.getSharedDataTensor(attention_weight_step_dim, 0, true);
-
-  TensorDim attention_output_dim = attention_output.getDim();
-  TensorDim attention_output_step_dim = attention_output_dim;
-  attention_output_step_dim.height(to);
-
-  Tensor attention_output_step =
-    attention_output.getSharedDataTensor(attention_output_step_dim, 0, true);
-
-  const unsigned int batch_size = query_dim.batch();
-  const unsigned int query_height = query_dim.height();
-  const unsigned int key_height = key_dim.height();
-  const unsigned int value_height = value_dim.height();
-
-  query_step.dot(query_fc_weight, projected_query_step);
-  if (!disable_bias) {
-    projected_query_step.add_i(query_fc_bias);
-  }
-  key_step.dot(key_fc_weight, cache_key_step);
-  if (!disable_bias) {
-    cache_key_step.add_i(key_fc_bias);
-  }
-  value_step.dot(value_fc_weight, cache_value_step);
-  if (!disable_bias) {
-    cache_value_step.add_i(value_fc_bias);
-  }
-
-  apply_rotary_emb_tensor(projected_query_step, projected_query_dim_prop,
-                          _from);
-  apply_rotary_emb_tensor(cache_key_step, projected_key_dim_prop, _from);
-
-  projected_query_step.reshape(
-    TensorDim({batch_size, to, num_heads, projected_query_dim_prop}));
-
-  cached_key.reshape(
-    TensorDim({batch_size, to, num_heads, projected_key_dim_prop}));
-  cached_value.reshape(
-    TensorDim({batch_size, to, num_heads, projected_value_dim_prop}));
-
-  projected_query_step.transpose("1:0:2", projected_query_step);
-  cached_key.transpose("1:0:2", projected_key_step);
-  cached_value.transpose("1:0:2", projected_value_step);
-
-  projected_query_step.reshape(
-    TensorDim({batch_size * num_heads, 1, to, projected_query_dim_prop}));
-  projected_key_step.reshape(
-    TensorDim({batch_size * num_heads, 1, to, projected_key_dim_prop}));
-  projected_value_step.reshape(
-    TensorDim({batch_size * num_heads, 1, to, projected_value_dim_prop}));
-
-  attention_weight_step.reshape(TensorDim({batch_size * num_heads, 1, to, to}));
-  attention_output_step.reshape(
-    TensorDim({batch_size * num_heads, 1, to, projected_value_dim_prop}));
-
-  /** scaled dot product attention */
-  projected_query_step.dotBatched(projected_key_step, attention_weight_step,
-                                  false, true);
-  attention_weight_step.multiply_i(1 / sqrt((float)projected_query_dim_prop));
-
-  if (!from) {
-    unsigned int mask_size = attention_weight_step.getDim().width();
-    unsigned int mask_dim_height = mask_size;
-    unsigned int mask_dim_width = mask_size;
-
-    Tensor causal_mask(TensorDim{1, 1, mask_size, mask_size,
-                                 attention_weight_step.getTensorType()});
-
-    causal_mask.setZero();
-
-    for (unsigned int i = 0; i < mask_dim_height; ++i) {
-      for (unsigned int j = i + 1; j < mask_dim_width; ++j) {
-        causal_mask.setValue(
-          0, 0, i, j, _MASK_NUM(attention_weight.getTensorType().data_type));
-      }
-    }
-
-    attention_weight_step.add_i(causal_mask);
-  }
-
-  sm.run_fn(attention_weight_step, attention_weight_step);
-
-  attention_weight_step.dotBatched(projected_value_step, attention_output_step);
-
-  attention_output_step.reshape(
-    TensorDim({batch_size, num_heads, to, projected_value_dim_prop}));
-
-  attention_output_step = attention_output_step.transpose("1:0:2");
-
-  attention_output_step.reshape(
-    TensorDim({batch_size * to, 1, 1, num_heads * projected_value_dim_prop}));
-
-  attention_output_step.dot(fc_weight, output_step);
-  if (!disable_bias) {
-    output_step.add_i(fc_bias);
-  }
-
-  // if (layer_progress == 28)
-  //   layer_progress = 0;
-  // layer_progress++;
-
-  // std::cout << "Process Reading: " << (int)((layer_progress / 28.0) * 100.0)
-  //           << " % \r";
-  // std::cout.flush();
-}
-
 void MultiHeadAttentionLayer::incremental_forwarding(RunLayerContext &context,
-                                                     unsigned int _from,
-                                                     unsigned int _to,
+                                                     unsigned int from,
+                                                     unsigned int to,
                                                      bool training) {
-
-  if (!_from) {
-    initial_incremental_forwarding(context, _from, _to, training);
-    return;
-  }
-
-  unsigned int max_timestep =
-    std::get<props::MaxTimestep>(multi_head_attention_props).get();
-
-  bool cache_shift = false;
-  unsigned int from = _from;
-  unsigned int to = _to;
-  if (to >= max_timestep) {
-    cache_shift = true;
-    from = max_timestep - 1;
-    to = max_timestep;
-  }
-
   const bool disable_bias =
     std::get<props::DisableBias>(*layer_impl_props).get();
 
@@ -908,8 +570,9 @@ void MultiHeadAttentionLayer::incremental_forwarding(RunLayerContext &context,
   Tensor &key = context.getInput(INOUT_INDEX::KEY);
   Tensor &value = context.getInput(INOUT_INDEX::VALUE);
 
-  Tensor empty_tensor =
-    Tensor("empty_tensor", value.getFormat(), value.getDataType());
+  Tensor empty_tensor;
+
+  empty_tensor.setTensorType(value.getTensorType());
 
   Tensor &mask =
     provide_attention_mask ? context.getInput(INOUT_INDEX::MASK) : empty_tensor;
@@ -918,86 +581,36 @@ void MultiHeadAttentionLayer::incremental_forwarding(RunLayerContext &context,
   TensorDim key_dim = key.getDim();
   TensorDim value_dim = value.getDim();
 
-  TensorDim query_step_dim = query_dim;
-  TensorDim key_step_dim = key_dim;
-  TensorDim value_step_dim = value_dim;
-
-  query_step_dim.height(to - from);
-  key_step_dim.height(to - from);
-  value_step_dim.height(to - from);
-
-  Tensor query_step = query.getSharedDataTensor(query_step_dim, 0, true);
-  Tensor key_step = key.getSharedDataTensor(key_step_dim, 0, true);
-  Tensor value_step = value.getSharedDataTensor(value_step_dim, 0, true);
-
   Tensor &output = context.getOutput(INOUT_INDEX::OUTPUT);
 
   TensorDim output_dim = output.getDim();
-
-  TensorDim output_step_dim = output_dim;
-  output_step_dim.height(to - from);
-  Tensor output_step = output.getSharedDataTensor(output_step_dim, 0, true);
-
   Tensor &ret_attention_weight =
     return_attention_weight != props::ReturnAttentionWeightInfo::Enum::none
       ? context.getOutput(INOUT_INDEX::RETURN_ATTENTION_WEIGHT)
       : empty_tensor;
 
   /** get weights */
-  Tensor qWeight, kWeight, vWeight, fWeight, qbias, kbias, vbias, fcWeight;
-  Tensor &query_fc_weight = qWeight;
-  Tensor &key_fc_weight = kWeight;
-  Tensor &value_fc_weight = vWeight;
-  Tensor &fc_weight = fcWeight;
-  Tensor &query_fc_bias = qbias;
-  Tensor &key_fc_bias = kbias;
-  Tensor &value_fc_bias = vbias;
-
-  // auto getWeight_Job = [&](Tensor &t, unsigned int idx) {
-  //   context.getWeight(t, idx);
-  // };
-
-  // auto get_key = std::async(std::launch::async, &RunLayerContext::getWeight,
-  // &context, key_fc_weight, weight_idx[AttentionParams::key_fc_weight]);
-
-  // auto get_key = std::async(std::launch::async, getWeight_Job,
-  // std::ref(key_fc_weight),weight_idx[AttentionParams::key_fc_weight] );
-
-  // start = clock();
-  context.getWeight(key_fc_weight, weight_idx[AttentionParams::key_fc_weight]);
-  // auto get_value = std::async(std::launch::async,
-  // &RunLayerContext::getWeight, &context, value_fc_weight,
-  // weight_idx[AttentionParams::value_fc_weight]);
-
-  // auto get_value = std::async(std::launch::async, getWeight_Job,
-  // std::ref(value_fc_weight),weight_idx[AttentionParams::value_fc_weight]);
-
-  // auto get_fc = std::async(std::launch::async, getWeight_Job,
-  // std::ref(fc_weight),weight_idx[AttentionParams::fc_weight]);
-
-  // auto get_fc = std::async(std::launch::async, &RunLayerContext::getWeight,
-  // &context, fc_weight, weight_idx[AttentionParams::fc_weight]);
-
-  context.getWeight(query_fc_weight,
-                    weight_idx[AttentionParams::query_fc_weight]);
-  context.getWeight(value_fc_weight,
-                    weight_idx[AttentionParams::value_fc_weight]);
-
-  context.getWeight(fc_weight, weight_idx[AttentionParams::fc_weight]);
-  // finish=clock();
-  // std::cout << "dequanized :" << (double)(finish-start)<<std::endl;
-  //   disable_bias
-  //     ? empty_tensor
-  //     : context.getWeight(weight_idx[AttentionParams::query_fc_bias]);
-
-  if (!disable_bias)
-    context.getWeight(query_fc_bias,
-                      weight_idx[AttentionParams::query_fc_bias]);
-  if (!disable_bias)
-    context.getWeight(key_fc_bias, weight_idx[AttentionParams::key_fc_bias]);
-  if (!disable_bias)
-    context.getWeight(value_fc_bias,
-                      weight_idx[AttentionParams::value_fc_bias]);
+  Tensor &query_fc_weight =
+    context.getWeight(weight_idx[AttentionParams::query_fc_weight]);
+  Tensor &query_fc_bias =
+    disable_bias
+      ? empty_tensor
+      : context.getWeight(weight_idx[AttentionParams::query_fc_bias]);
+  Tensor &key_fc_weight =
+    context.getWeight(weight_idx[AttentionParams::key_fc_weight]);
+  Tensor &key_fc_bias =
+    disable_bias ? empty_tensor
+                 : context.getWeight(weight_idx[AttentionParams::key_fc_bias]);
+  Tensor &value_fc_weight =
+    context.getWeight(weight_idx[AttentionParams::value_fc_weight]);
+  Tensor &value_fc_bias =
+    disable_bias
+      ? empty_tensor
+      : context.getWeight(weight_idx[AttentionParams::value_fc_bias]);
+  Tensor &fc_weight = context.getWeight(weight_idx[AttentionParams::fc_weight]);
+  Tensor &fc_bias = disable_bias
+                      ? empty_tensor
+                      : context.getWeight(weight_idx[AttentionParams::fc_bias]);
 
   /** get tensors */
   Tensor &projected_query =
@@ -1076,24 +689,19 @@ void MultiHeadAttentionLayer::incremental_forwarding(RunLayerContext &context,
   const unsigned int key_height = key_dim.height();
   const unsigned int value_height = value_dim.height();
 
-  query_step.dot(query_fc_weight, projected_query_step);
-
+  query.dot(query_fc_weight, projected_query_step);
   if (!disable_bias) {
     projected_query_step.add_i(query_fc_bias);
   }
-  key_step.dot(key_fc_weight, cache_key_step);
+  key.dot(key_fc_weight, cache_key_step);
   if (!disable_bias) {
     cache_key_step.add_i(key_fc_bias);
   }
-  value_step.dot(value_fc_weight, cache_value_step);
+  value.dot(value_fc_weight, cache_value_step);
   if (!disable_bias) {
     cache_value_step.add_i(value_fc_bias);
   }
 
-  apply_rotary_emb_tensor(projected_query_step, projected_query_dim_prop,
-                          _from);
-  apply_rotary_emb_tensor(cache_key_step, projected_key_dim_prop, _from);
-
   projected_query_step.reshape(
     TensorDim({batch_size, 1, num_heads, projected_query_dim_prop}));
   cached_key.reshape(
@@ -1131,10 +739,15 @@ void MultiHeadAttentionLayer::incremental_forwarding(RunLayerContext &context,
 
     causal_mask.setZero();
 
+#ifdef ENABLE_FP16
+#define _MASK_NUM -1e4
+#else
+#define _MASK_NUM -1e10
+#endif
+
     for (unsigned int i = 0; i < mask_dim_height; ++i) {
       for (unsigned int j = i + 1; j < mask_dim_width; ++j) {
-        causal_mask.setValue(
-          0, 0, i, j, _MASK_NUM(attention_weight.getTensorType().data_type));
+        causal_mask.setValue(0, 0, i, j, _MASK_NUM);
       }
     }
 
@@ -1153,35 +766,9 @@ void MultiHeadAttentionLayer::incremental_forwarding(RunLayerContext &context,
   attention_output_step.reshape(TensorDim(
     {batch_size * (to - from), 1, 1, num_heads * projected_value_dim_prop}));
 
-  attention_output_step.dot(fc_weight, output_step);
+  attention_output_step.dot(fc_weight, output);
   if (!disable_bias) {
-    output_step.add_i(fc_bias);
-  }
-
-  if (cache_shift) {
-    if (cache_key.getDataType() == ml::train::TensorDim::DataType::FP32) {
-      float *buf = cache_key.getAddress<float>(0, 0, 1, 0);
-      float *dbuf = cache_key.getAddress<float>(0, 0, 0, 0);
-      memcpy(dbuf, buf, (cache_key.size() - cache_key.width()) * sizeof(float));
-      buf = cache_value.getAddress<float>(0, 0, 1, 0);
-      dbuf = cache_value.getAddress<float>(0, 0, 0, 0);
-      memcpy(dbuf, buf,
-             (cache_value.size() - cache_value.width()) * sizeof(float));
-    } else if (cache_key.getDataType() ==
-               ml::train::TensorDim::DataType::FP16) {
-#ifdef ENABLE_FP16
-
-      _FP16 *buf = cache_key.getAddress<_FP16>(0, 0, 1, 0);
-      _FP16 *dbuf = cache_key.getAddress<_FP16>(0, 0, 0, 0);
-      memcpy(dbuf, buf, (cache_key.size() - cache_key.width()) * sizeof(_FP16));
-      buf = cache_value.getAddress<_FP16>(0, 0, 1, 0);
-      dbuf = cache_value.getAddress<_FP16>(0, 0, 0, 0);
-      memcpy(dbuf, buf,
-             (cache_key.size() - cache_value.width()) * sizeof(_FP16));
-#else
-      throw std::invalid_argument("enable-fp16 is not set");
-#endif
-    }
+    output.add_i(fc_bias);
   }
 }
 
@@ -1521,6 +1108,7 @@ void MultiHeadAttentionLayer::setBatch(RunLayerContext &context,
   context.updateTensor(weight_idx[AttentionParams::projected_value], batch);
   context.updateTensor(weight_idx[AttentionParams::cache_key], batch);
   context.updateTensor(weight_idx[AttentionParams::cache_value], batch);
   context.updateTensor(weight_idx[AttentionParams::attention_weight], batch);
   if (dropout_rate > epsilon) {
     context.updateTensor(weight_idx[AttentionParams::dropout_mask], batch);
index f6f5e10bf496a349b29734960a83422410420dd9..01f4ca39791e8cc9abdca77f4e51a27e6f2fccee 100644 (file)
 #ifdef __cplusplus
 
 #include <acti_func.h>
-#include <complex>
 #include <layer_impl.h>
-#include <util_simd.h>
-#include <utility>
 
 namespace nntrainer {
 
@@ -63,14 +60,6 @@ public:
    */
   void forwarding(RunLayerContext &context, bool training) override;
 
-  /**
-   * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
-   * int from, unsigned int to, bool training)
-   */
-  void initial_incremental_forwarding(RunLayerContext &context,
-                                      unsigned int from, unsigned int to,
-                                      bool training);
-
   /**
    * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
    * int from, unsigned int to, bool training)
@@ -122,8 +111,7 @@ public:
 private:
   std::tuple<props::NumHeads, props::ProjectedKeyDim, props::ProjectedValueDim,
              props::OutputShape, props::DropOutRate,
-             props::ReturnAttentionWeight, props::AverageAttentionWeight,
-             props::MaxTimestep>
+             props::ReturnAttentionWeight, props::AverageAttentionWeight>
     multi_head_attention_props; /**< multi_head_attention layer properties */
 
   ActiFunc sm; /** softmax activation operation */
@@ -135,167 +123,6 @@ private:
    */
   float epsilon;
 
-  unsigned int cache_index;
-
-  inline static unsigned int layer_progress;
-
-  inline static std::vector<std::vector<float>> *freqs_cos = {};
-  inline static std::vector<std::vector<float>> *freqs_sin = {};
-  inline static std::vector<float> freqs;
-
-  /**
-   * @brief     compute frequency for rotary embedding
-   * @param[in] dim hidden dim size
-   * @param[in] seq_len sequency length
-   * @param[in] theta rotary angle
-   */
-  void precompute_freqs(int dim, unsigned int seq_len, float theta = 10000.0) {
-    if (freqs_cos == nullptr) {
-      unsigned int half_ = dim / 2;
-      for (unsigned int i = 0; i < half_; ++i) {
-        freqs.push_back(1.0 /
-                        (std::pow(theta, (2 * i) / static_cast<float>(dim))));
-      }
-
-      auto cos = new std::vector<std::vector<float>>();
-      cos->assign(seq_len, std::vector<float>(dim, 0));
-
-      auto sin = new std::vector<std::vector<float>>();
-      sin->assign(seq_len, std::vector<float>(dim, 0));
-
-      for (unsigned int i = 0; i < seq_len; ++i) {
-#ifdef USE_NEON
-        calc_trigonometric_vals_dup(half_, freqs.data(), (*cos)[i].data(),
-                                    (*sin)[i].data(), i);
-#else
-        for (unsigned int j = 0; j < half_; ++j) {
-          float angle = i * freqs[j];
-          (*cos)[i][j] = std::cos(angle);
-          (*cos)[i][j + half_] = std::cos(angle); // repeated 2 times
-
-          (*sin)[i][j] = std::sin(angle);
-          (*sin)[i][j + half_] = std::sin(angle); // repeated 2 times
-        }
-#endif
-      }
-      freqs_cos = cos;
-      freqs_sin = sin;
-    }
-  }
-
-  /**
-   * @brief     apply rotary embedding
-   * @param[in] in input tensor
-   * @param[in] dim hidden dim size
-   * @param[in] from sequence order
-   */
-  void apply_rotary_emb_tensor(Tensor &in, unsigned int dim,
-                               unsigned int from) {
-    Tensor out(in.getDim());
-    float value = 0;
-    float transformed_value = 0.0;
-    unsigned int half_ = dim / 2;
-    unsigned int max_timestep =
-      std::get<props::MaxTimestep>(multi_head_attention_props).get();
-
-    std::vector<float> *cos_;
-    std::vector<float> *sin_;
-
-    if (from >= max_timestep) {
-      cos_ = new std::vector<float>(dim);
-      sin_ = new std::vector<float>(dim);
-#ifdef USE_NEON
-      calc_trigonometric_vals_dup(half_, freqs.data(), cos_->data(),
-                                  sin_->data(), from);
-#else
-      for (unsigned int i = 0; i < half_; ++i) {
-        float angle = from * freqs[i];
-        (*cos_)[i] = std::cos(angle);
-        (*cos_)[i + half_] = std::cos(angle); // repeated 2 times
-
-        (*sin_)[i] = std::sin(angle);
-        (*sin_)[i + half_] = std::sin(angle); // repeated 2 times
-      }
-#endif
-    }
-
-    if (in.getDataType() == ml::train::TensorDim::DataType::FP32) {
-      for (unsigned int b = 0; b < in.batch(); b++) {
-        for (unsigned int c = 0; c < in.channel(); c++) {
-          for (unsigned int h = 0; h < in.height(); h++) {
-            if (from < max_timestep) {
-              cos_ = &(*freqs_cos)[from + h];
-              sin_ = &(*freqs_sin)[from + h];
-            }
-
-            for (unsigned int w = 0; w < in.width(); w = w + dim) {
-              for (unsigned int k = 0; k < dim; k++) {
-                unsigned int span = w + k;
-                value = in.getValue<float>(b, c, h, span);
-
-                if (k < half_) {
-                  transformed_value =
-                    -1.0 * in.getValue<float>(b, c, h, span + half_);
-                } else {
-                  transformed_value = in.getValue<float>(b, c, h, span - half_);
-                }
-                value = value * (*cos_)[k] + transformed_value * (*sin_)[k];
-                out.setValue(b, c, h, span, value);
-              }
-            }
-          }
-        }
-      }
-    } else if (in.getDataType() == ml::train::TensorDim::DataType::FP16) {
-#ifdef ENABLE_FP16
-      for (unsigned int b = 0; b < in.batch(); b++) {
-        for (unsigned int c = 0; c < in.channel(); c++) {
-          for (unsigned int h = 0; h < in.height(); h++) {
-            if (from < max_timestep) {
-              cos_ = &(*freqs_cos)[from + h];
-              sin_ = &(*freqs_sin)[from + h];
-            }
-            for (unsigned int w = 0; w < in.width(); w = w + dim) {
-#ifdef USE_NEON
-              compute_rotary_embedding_value(
-                dim, half_, w, in.getData<_FP16>() + in.getIndex(b, c, h, 0),
-                out.getData<_FP16>() + out.getIndex(b, c, h, 0), cos_->data(),
-                sin_->data());
-#else
-              for (unsigned int k = 0; k < dim; k++) {
-                unsigned int span = w + k;
-                value = static_cast<float>(in.getValue<_FP16>(b, c, h, span));
-
-                if (k < half_) {
-                  transformed_value =
-                    -1.0 * static_cast<float>(
-                             in.getValue<_FP16>(b, c, h, half_ + span));
-                } else {
-                  transformed_value = static_cast<float>(
-                    in.getValue<_FP16>(b, c, h, span - half_));
-                }
-                out.setValue(
-                  b, c, h, span,
-                  static_cast<_FP16>(value * (*cos_)[k] +
-                                     transformed_value * (*sin_)[k]));
-              }
-#endif
-            }
-          }
-        }
-      }
-#else
-      throw std::invalid_argument("Error: enable-fp16 is not enabled");
-#endif
-    }
-
-    if (from >= max_timestep) {
-      delete cos_;
-      delete sin_;
-    }
-    in.copy(out);
-  }
-
   /**
    * @brief calculate common derivative
    * @param context Context of the layer
index 5dc4ca5c558281c7b261385674d5b4d27e6acc58..5aae748adfae852cdedabf5c0a0524c05cde7d84 100644 (file)
@@ -62,7 +62,7 @@ test_target = [
   'unittest_layers_dropout.cpp',
   'unittest_layers_reshape.cpp',
   # 'unittest_layers_mol_attention.cpp',
-  # 'unittest_layers_multi_head_attention.cpp',
+  'unittest_layers_multi_head_attention.cpp',
   'unittest_layers_positional_encoding.cpp',
 ]