static constexpr size_t SINGLE_INOUT_IDX = 0;
-enum AttentionParams { query = 0, value = 1 };
+enum AttentionParams { query = 0, value = 1, score, weights };
void AttentionLayer::finalize(InitLayerContext &context) {
if (context.getNumInputs() != 2)
sm.setActiFunc(ActivationType::ACT_SOFTMAX);
- auto const &all_shapes = context.getInputDimensions();
- auto const &query_shape = all_shapes[AttentionParams::query];
+ auto const &all_dims = context.getInputDimensions();
+ auto const &query_dim = all_dims[AttentionParams::query];
+ auto const &value_dim = all_dims[AttentionParams::value];
- context.setOutputDimensions({query_shape});
+ wt_idx[AttentionParams::query] = query;
+ wt_idx[AttentionParams::value] = value;
+
+ auto weights_dim = query_dim;
+ weights_dim.width(value_dim.width());
+ wt_idx[AttentionParams::weights] = context.requestTensor(
+ weights_dim, context.getName() + ":weights", Tensor::Initializer::NONE,
+ false, TensorLifespan::ITERATION_LIFESPAN);
+
+ wt_idx[AttentionParams::score] = context.requestTensor(
+ weights_dim, context.getName() + ":score", Tensor::Initializer::NONE, false,
+ TensorLifespan::FORWARD_FUNC_LIFESPAN);
+
+ context.setOutputDimensions({query_dim});
}
void AttentionLayer::forwarding(RunLayerContext &context, bool training) {
Tensor &value = context.getInput(AttentionParams::value);
Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
- Tensor distribution;
+ Tensor &weights = context.getTensor(wt_idx[AttentionParams::weights]);
+ Tensor &score = context.getTensor(wt_idx[AttentionParams::weights]);
- Tensor score = query.dot(value, false, true);
- sm.run_fn(score, distribution);
- distribution.dot(value, output);
+ query.dot(value, score, false, true);
+ sm.run_fn(score, weights);
+ weights.dot(value, output);
}
void AttentionLayer::calcDerivative(RunLayerContext &context) {
- /**
- * Not yet implemented
- */
+ Tensor &derivative_ = context.getIncomingDerivative(SINGLE_INOUT_IDX);
+ Tensor &query = context.getInput(AttentionParams::query);
+ Tensor &value = context.getInput(AttentionParams::value);
+
+ Tensor &dquery = context.getOutgoingDerivative(AttentionParams::query);
+ Tensor &dvalue = context.getOutgoingDerivative(AttentionParams::value);
+ Tensor &weights = context.getTensor(wt_idx[AttentionParams::weights]);
+
+ Tensor t1;
+ sm.run_prime_fn(weights, t1, derivative_);
+
+ Tensor t2 = value.dot(t1);
+ dquery = t2.dot(value).dot(derivative_);
+
+ dvalue = t2.dot(query).add(weights).dot(derivative_);
}
void AttentionLayer::setProperty(const std::vector<std::string> &values) {
/**
* Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
*
- * @file unittest_layers_addition.cpp
+ * @file unittest_layers_attention.cpp
* @date 1 October 2021
* @brief Attention Layer Test
* @see https://github.com/nnstreamer/nntrainer
LayerSemanticsParamType(nntrainer::createLayer<nntrainer::AttentionLayer>,
nntrainer::AttentionLayer::type, {}, 0, false, 2);
-INSTANTIATE_TEST_CASE_P(Addition, LayerSemantics,
+INSTANTIATE_TEST_CASE_P(Attention, LayerSemantics,
::testing::Values(semantic_attention));