[layer] fix for mol attention layer
authorParichay Kapoor <pk.kapoor@samsung.com>
Fri, 3 Dec 2021 12:09:51 +0000 (21:09 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Fri, 3 Dec 2021 12:51:57 +0000 (21:51 +0900)
Bug fix for the MoL attention layer: getOutgoingDerivative() is not available
in calcGradient(), so its usage there is replaced with a temporary tensor
shared with calcDerivative().

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
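
For context, below is a minimal standalone sketch of the pattern this commit adopts:
request a layer-owned temporary tensor during finalize(), fill it while computing the
state derivative, and read it back in calcGradient() instead of calling
getOutgoingDerivative(). The mock context, class, and function names here are
illustrative assumptions and not the nntrainer API; the derivative math is a placeholder.

    // Minimal sketch (assumed names, not the nntrainer API) of caching a
    // derivative in a layer-owned temporary so calcGradient() can reuse it.
    #include <cassert>
    #include <cstddef>
    #include <string>
    #include <unordered_map>
    #include <vector>

    using Tensor = std::vector<float>; // stand-in for nntrainer::Tensor

    struct MockRunContext {
      // stand-in for tensors requested from the layer context
      std::unordered_map<std::string, Tensor> tensors;
      Tensor &getTensor(const std::string &name) { return tensors[name]; }
    };

    struct MoLAttentionSketch {
      // finalize(): request the extra "dstate" temporary; in the real layer it
      // uses TensorLifespan::BACKWARD_FUNC_LIFESPAN so it spans both backward calls.
      void finalize(MockRunContext &ctx, std::size_t state_size) {
        ctx.tensors["dstate"] = Tensor(state_size, 0.0f);
      }

      // calcDerivative(): compute the state derivative and keep a copy in the
      // temporary so calcGradient() does not need getOutgoingDerivative().
      void calcDerivative(MockRunContext &ctx, const Tensor &incoming) {
        Tensor &dstate_local = ctx.getTensor("dstate");
        for (std::size_t i = 0; i < incoming.size(); ++i)
          dstate_local[i] = incoming[i] * 0.5f; // placeholder derivative math
      }

      // calcGradient(): read the cached derivative instead of the (unavailable)
      // outgoing derivative of the state input.
      void calcGradient(MockRunContext &ctx, Tensor &dfc_w) {
        const Tensor &dstate = ctx.getTensor("dstate");
        for (std::size_t i = 0; i < dstate.size() && i < dfc_w.size(); ++i)
          dfc_w[i] += dstate[i]; // placeholder gradient accumulation
      }
    };

    int main() {
      MockRunContext ctx;
      MoLAttentionSketch layer;
      layer.finalize(ctx, 4);
      layer.calcDerivative(ctx, Tensor{1, 2, 3, 4});
      Tensor dfc_w(4, 0.0f);
      layer.calcGradient(ctx, dfc_w);
      assert(dfc_w[3] == 2.0f); // 4 * 0.5 accumulated from the cached dstate
      return 0;
    }
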
nntrainer/layers/mol_attention_layer.cpp
nntrainer/layers/mol_attention_layer.h

index a857a38..f90fe7d 100644 (file)
@@ -43,7 +43,8 @@ enum AttentionParams {
   prob_left,
   prob_right,
   u_neg_div,
-  u_pos_div
+  u_pos_div,
+  dstate
 };
 
 void MoLAttentionLayer::finalize(InitLayerContext &context) {
@@ -137,6 +138,9 @@ void MoLAttentionLayer::finalize(InitLayerContext &context) {
   wt_idx[AttentionParams::u_pos_div] =
     context.requestTensor(prob_dim, "u_pos_div", Tensor::Initializer::NONE,
                           false, TensorLifespan::ITERATION_LIFESPAN);
+  wt_idx[AttentionParams::dstate] =
+    context.requestTensor(state_dim, "dstate", Tensor::Initializer::NONE, false,
+                          TensorLifespan::BACKWARD_FUNC_LIFESPAN);
 
   context.setOutputDimensions({query_dim, state_dim});
 }
@@ -331,6 +335,7 @@ void MoLAttentionLayer::calcDerivative(RunLayerContext &context) {
     context.getOutgoingDerivative(wt_idx[AttentionParams::value]);
   Tensor &dstate =
     context.getOutgoingDerivative(wt_idx[AttentionParams::state]);
+  Tensor &dstate_local = context.getTensor(wt_idx[AttentionParams::dstate]);
 
   Tensor &derivative = context.getIncomingDerivative(SINGLE_INOUT_IDX);
 
@@ -346,6 +351,8 @@ void MoLAttentionLayer::calcDerivative(RunLayerContext &context) {
 
   if (!helper_exec)
     calcDerivativeHelper(context, dstate);
+  else
+    dstate.copyData(dstate_local);
 
   Tensor dfc_tanh = Tensor(fc_out.getDim());
   dfc_tanh.dot_deriv_wrt_1(fc_proj_w, dfc_proj_out);
@@ -357,8 +364,7 @@ void MoLAttentionLayer::calcDerivative(RunLayerContext &context) {
 
 void MoLAttentionLayer::calcGradient(RunLayerContext &context) {
   Tensor &query = context.getInput(wt_idx[AttentionParams::query]);
-  Tensor &dstate =
-    context.getOutgoingDerivative(wt_idx[AttentionParams::state]);
+  Tensor &dstate = context.getTensor(wt_idx[AttentionParams::dstate]);
 
   Tensor &fc_proj_w = context.getWeight(wt_idx[AttentionParams::fc_proj_w]);
   Tensor &dfc_w = context.getWeightGrad(wt_idx[AttentionParams::fc_w]);
index e41e5f4..141c7c7 100644 (file)
@@ -106,7 +106,7 @@ private:
   ActiFunc softmax; /** softmax activation operation */
   ActiFunc tanh;    /** tanh activation operation */
   ActiFunc sigmoid; /** sigmoid activation operation */
-  std::array<unsigned int, 15>
+  std::array<unsigned int, 16>
     wt_idx; /**< indices of the weights and tensors */
 
   /**