namespace nntrainer {
-AttentionLayer::AttentionLayer() {
+AttentionLayer::AttentionLayer() : sm(ActivationType::ACT_SOFTMAX) {
wt_idx.fill(std::numeric_limits<unsigned>::max());
}
static constexpr size_t SINGLE_INOUT_IDX = 0;
-enum AttentionParams { query = 0, value = 1, key = 2, score, weights };
+enum AttentionParams { query = 0, value = 1, key = 2, weights };
void AttentionLayer::finalizeCommon(InitLayerContext &context) {
if (context.getNumInputs() < 2 || context.getNumInputs() > 3)
void AttentionLayer::finalize(InitLayerContext &context) {
finalizeCommon(context);
- sm.setActiFunc(ActivationType::ACT_SOFTMAX);
-
auto const &all_dims = context.getInputDimensions();
auto const &query_dim = all_dims[AttentionParams::query];
auto const &value_dim = all_dims[AttentionParams::value];
context.requestTensor(weights_dim, "weights", Tensor::Initializer::NONE,
false, TensorLifespan::ITERATION_LIFESPAN);
- wt_idx[AttentionParams::score] =
- context.requestTensor(weights_dim, "score", Tensor::Initializer::NONE,
- false, TensorLifespan::FORWARD_FUNC_LIFESPAN);
-
context.setOutputDimensions({query_dim});
}
Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
Tensor &weights = context.getTensor(wt_idx[AttentionParams::weights]);
- Tensor &score = context.getTensor(wt_idx[AttentionParams::score]);
- query.dotBatched(key, score, false, true); /** dot 1 */
- sm.run_fn(score, weights); /** softmax */
- weights.dotBatched(value, output); /** dot 2 */
+ query.dotBatched(key, weights, false, true); /** dot 1: score into weights */
+ sm.run_fn(weights, weights); /** softmax applied in-place */
+ weights.dotBatched(value, output); /** dot 2 */
}
void AttentionLayer::calcDerivative(RunLayerContext &context) {
weights.dot_batched_deriv_wrt_2(dvalue, derivative);
/** derivative for softmax */
- Tensor dscore;
- sm.run_prime_fn(weights, dscore, dweight);
+ sm.run_prime_fn(weights, dweight, dweight);
/** derivative for dot 1 */
- dquery.dot_batched_deriv_wrt_1(key, dscore, false, true);
- query.dot_batched_deriv_wrt_2(dkey, dscore, false, true,
+ dquery.dot_batched_deriv_wrt_1(key, dweight, false, true);
+ query.dot_batched_deriv_wrt_2(dkey, dweight, false, true,
context.getNumInputs() == 2);
}
}
void AttentionLayer::setBatch(RunLayerContext &context, unsigned int batch) {
- context.updateTensor(wt_idx[AttentionParams::score], batch);
context.updateTensor(wt_idx[AttentionParams::weights], batch);
}
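For reference, the in-place reuse above depends only on the raw score Q.K^T and the softmaxed weights having the same shape, so a single tensor can hold both. Below is a minimal standalone sketch of that pattern in plain C++ with std::vector; it is not nntrainer's Tensor/ActiFunc API, and every name in it is illustrative only. The 1/sqrt(d_k) scaling used by the multi-head layer further down is omitted, since the plain AttentionLayer does not apply it.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

/* Single-head attention for one batch item: query is (q_len x d), key and
 * value are (k_len x d). The (q_len x k_len) score matrix is softmaxed in
 * place and then reused as the weight matrix, mirroring the reuse of the
 * "weights" tensor in the diff above. */
std::vector<float> attention(const std::vector<float> &query,
                             const std::vector<float> &key,
                             const std::vector<float> &value, size_t q_len,
                             size_t k_len, size_t d) {
  std::vector<float> weights(q_len * k_len, 0.0f); /* holds score, then softmax */
  for (size_t i = 0; i < q_len; ++i)               /* dot 1: Q . K^T */
    for (size_t j = 0; j < k_len; ++j)
      for (size_t c = 0; c < d; ++c)
        weights[i * k_len + j] += query[i * d + c] * key[j * d + c];

  for (size_t i = 0; i < q_len; ++i) { /* row-wise softmax, in place */
    float max = weights[i * k_len];
    for (size_t j = 1; j < k_len; ++j)
      max = std::max(max, weights[i * k_len + j]);
    float sum = 0.0f;
    for (size_t j = 0; j < k_len; ++j) {
      weights[i * k_len + j] = std::exp(weights[i * k_len + j] - max);
      sum += weights[i * k_len + j];
    }
    for (size_t j = 0; j < k_len; ++j)
      weights[i * k_len + j] /= sum;
  }

  std::vector<float> output(q_len * d, 0.0f); /* dot 2: weights . V */
  for (size_t i = 0; i < q_len; ++i)
    for (size_t c = 0; c < d; ++c)
      for (size_t j = 0; j < k_len; ++j)
        output[i * d + c] += weights[i * k_len + j] * value[j * d + c];
  return output;
}

int main() {
  std::vector<float> q{1, 0, 0, 1}, k{1, 0, 0, 1}, v{1, 2, 3, 4};
  for (float o : attention(q, k, v, 2, 2, 2))
    std::cout << o << ' ';
  std::cout << '\n';
  return 0;
}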
u_neg_div,
u_pos_div,
dstate,
- updated_state
};
void MoLAttentionLayer::finalize(InitLayerContext &context) {
fc_out_dim.width(fc_w_dim.width());
wt_idx[AttentionParams::fc_out] =
context.requestTensor(fc_out_dim, "fc_out", Tensor::Initializer::NONE,
- false, TensorLifespan::ITERATION_LIFESPAN);
+ false, TensorLifespan::FORWARD_FUNC_LIFESPAN); /** fc_out is needed only while forwarding */
wt_idx[AttentionParams::fc_tanh] =
context.requestTensor(fc_out_dim, "fc_tanh", Tensor::Initializer::NONE,
Tensor &fc_w = context.getWeight(wt_idx[AttentionParams::fc_w]);
Tensor &fc_proj_w = context.getWeight(wt_idx[AttentionParams::fc_proj_w]);
- Tensor &fc_out = context.getTensor(wt_idx[AttentionParams::fc_out]);
Tensor &fc_tanh = context.getTensor(wt_idx[AttentionParams::fc_tanh]);
Tensor &dfc_proj_out =
context.getTensor(wt_idx[AttentionParams::fc_proj_out]);
else
dstate.copyData(dstate_local);
- Tensor dfc_tanh = Tensor(fc_out.getDim());
+ Tensor dfc_tanh = Tensor(fc_tanh.getDim());
dfc_tanh.dot_deriv_wrt_1(fc_proj_w, dfc_proj_out, false, false);
Tensor dfc_out;
Tensor &dfc_bias = context.getWeightGrad(wt_idx[AttentionParams::fc_bias]);
Tensor &dfc_proj_w =
context.getWeightGrad(wt_idx[AttentionParams::fc_proj_w]);
- Tensor &fc_out = context.getTensor(wt_idx[AttentionParams::fc_out]);
Tensor &fc_tanh = context.getTensor(wt_idx[AttentionParams::fc_tanh]);
Tensor &dfc_proj_out =
context.getTensor(wt_idx[AttentionParams::fc_proj_out]);
if (!helper_exec)
calcDerivativeHelper(context, dstate);
- Tensor dfc_tanh = Tensor(fc_out.getDim());
+ Tensor dfc_tanh = Tensor(fc_tanh.getDim());
fc_tanh.dot_deriv_wrt_2(
dfc_proj_w, dfc_proj_out, false, false,
!context.isGradientFirstAccess(wt_idx[AttentionParams::fc_proj_w]));
props::NumHeads(), props::ProjectedKeyDim(), props::ProjectedValueDim(),
props::OutputShape(), props::DropOutRate(), props::ReturnAttentionWeight(),
props::AverageAttentionWeight()),
+ sm(ActivationType::ACT_SOFTMAX),
epsilon(1e-3) {
weight_idx.fill(std::numeric_limits<unsigned>::max());
}
projected_query,
projected_key,
projected_value,
- attention_score,
- d_attention_score,
/** intended comment for later use of attention_mask */
// attention_mask,
attention_weight,
projected_value_dim, "projected_value", Tensor::Initializer::NONE, true,
TensorLifespan::ITERATION_LIFESPAN);
- /** tensor for attention score */
- TensorDim attention_score_dim(
- {batch_size, num_heads, query_height, key_height});
- weight_idx[AttentionParams::attention_score] = context.requestTensor(
- attention_score_dim, "attention_score", Tensor::Initializer::NONE, false,
- TensorLifespan::FORWARD_FUNC_LIFESPAN);
- weight_idx[AttentionParams::d_attention_score] = context.requestTensor(
- attention_score_dim, "d_attention_score", Tensor::Initializer::NONE, false,
- TensorLifespan::BACKWARD_FUNC_LIFESPAN);
if (provide_attention_mask) {
/** Intended comment for bool type mask */
// TensorDim attention_mask_dim(
Tensor &projected_value =
context.getTensor(weight_idx[AttentionParams::projected_value]);
- Tensor &attention_score =
- context.getTensor(weight_idx[AttentionParams::attention_score]);
Tensor &attention_weight =
context.getTensor(weight_idx[AttentionParams::attention_weight]);
Tensor &attention_output =
projected_value.reshape(TensorDim(
{batch_size * num_heads, 1, value_height, projected_value_dim_prop}));
- attention_score.reshape(
- TensorDim({batch_size * num_heads, 1, query_height, key_height}));
attention_weight.reshape(
TensorDim({batch_size * num_heads, 1, query_height, key_height}));
attention_output.reshape(TensorDim(
{batch_size * num_heads, 1, query_height, projected_value_dim_prop}));
/** scaled dot product attention */
- projected_query.dotBatched(projected_key, attention_score, false, true);
- attention_score.multiply_i(1 / sqrt((float)projected_query_dim_prop));
+ projected_query.dotBatched(projected_key, attention_weight, false, true);
+ attention_weight.multiply_i(1 / sqrt((float)projected_query_dim_prop));
if (provide_attention_mask) {
// Tensor &attention_mask =
// attention_mask.multiply_i(-1e9);
// }
// attention_mask.multiply_i(mask);
- // attention_score.add_i(attention_mask);
+ // attention_weight.add_i(attention_mask);
- attention_score.reshape(
+ attention_weight.reshape(
TensorDim({batch_size, num_heads, query_height, key_height}));
- attention_score.add_i(mask);
- attention_score.reshape(
+ attention_weight.add_i(mask);
+ attention_weight.reshape(
TensorDim({batch_size * num_heads, 1, query_height, key_height}));
}
- sm.run_fn(attention_score, attention_weight);
+ sm.run_fn(attention_weight, attention_weight); /** softmax applied in-place */
if (return_attention_weight ==
props::ReturnAttentionWeightInfo::Enum::before) {
projected_value.reshape(TensorDim(
{batch_size, 1, value_height, num_heads * projected_value_dim_prop}));
- attention_score.reshape(
- TensorDim({batch_size, num_heads, query_height, key_height}));
attention_weight.reshape(
TensorDim({batch_size, num_heads, query_height, key_height}));
attention_output.reshape(TensorDim(
Tensor &d_projected_value =
context.getTensorGrad(weight_idx[AttentionParams::projected_value]);
- Tensor &d_attention_score =
- context.getTensor(weight_idx[AttentionParams::d_attention_score]);
Tensor &attention_weight =
context.getTensor(weight_idx[AttentionParams::attention_weight]);
Tensor &d_attention_weight =
d_projected_value.reshape(TensorDim(
{batch_size * num_heads, 1, value_height, projected_value_dim_prop}));
- d_attention_score.reshape(
- TensorDim({batch_size * num_heads, 1, query_height, key_height}));
attention_weight.reshape(
TensorDim({batch_size * num_heads, 1, query_height, key_height}));
d_attention_weight.reshape(
d_attention_weight.add_i(d_ret_attention_weight);
}
- sm.run_prime_fn(attention_weight, d_attention_score, d_attention_weight);
+ /** softmax derivative computed in-place into d_attention_weight */
+ sm.run_prime_fn(attention_weight, d_attention_weight, d_attention_weight);
if (provide_attention_mask) {
Tensor &d_mask = context.getOutgoingDerivative(INOUT_INDEX::MASK);
- d_mask.copyData(d_attention_score);
+ d_mask.copyData(d_attention_weight);
}
- d_attention_score.multiply_i(
+ d_attention_weight.multiply_i(
1 / sqrt((float)projected_query_dim_prop)); /** scale */
- d_projected_query.dot_batched_deriv_wrt_1(projected_key, d_attention_score,
+ d_projected_query.dot_batched_deriv_wrt_1(projected_key, d_attention_weight,
false, true);
- projected_query.dot_batched_deriv_wrt_2(d_projected_key, d_attention_score,
+ projected_query.dot_batched_deriv_wrt_2(d_projected_key, d_attention_weight,
false, true);
d_projected_query.reshape(
d_projected_value.reshape(TensorDim(
{batch_size * value_height, 1, 1, num_heads * projected_value_dim_prop}));
- d_attention_score.reshape(
- TensorDim({batch_size, num_heads, query_height, key_height}));
attention_weight.reshape(
TensorDim({batch_size, num_heads, query_height, key_height}));
d_attention_weight.reshape(
context.updateTensor(weight_idx[AttentionParams::projected_query], batch);
context.updateTensor(weight_idx[AttentionParams::projected_key], batch);
context.updateTensor(weight_idx[AttentionParams::projected_value], batch);
- context.updateTensor(weight_idx[AttentionParams::attention_score], batch);
- context.updateTensor(weight_idx[AttentionParams::d_attention_score], batch);
context.updateTensor(weight_idx[AttentionParams::attention_weight], batch);
if (dropout_rate > epsilon) {
context.updateTensor(weight_idx[AttentionParams::dropout_mask], batch);