[layer] Add full support of MoL attention layer
authorParichay Kapoor <pk.kapoor@samsung.com>
Wed, 1 Dec 2021 06:33:08 +0000 (15:33 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Fri, 3 Dec 2021 05:46:00 +0000 (14:46 +0900)
Add MoL attention layer backwarding, exportTo, setbatch member
functions.

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
nntrainer/layers/mol_attention_layer.cpp
nntrainer/layers/mol_attention_layer.h

index b8a66b9..fcc10b9 100644 (file)
@@ -11,7 +11,7 @@
  *
  */
 
-#include <cmath>
+#include <math.h>
 
 #include <layer_context.h>
 #include <mol_attention_layer.h>
 
 namespace nntrainer {
 
-MoLAttentionLayer::MoLAttentionLayer() : wt_idx({0}) {}
+MoLAttentionLayer::MoLAttentionLayer() : helper_exec(false), wt_idx({0}) {}
 
 MoLAttentionLayer::~MoLAttentionLayer() {}
 
 static constexpr size_t SINGLE_INOUT_IDX = 0;
 
-enum AttentionParams { query = 0, value = 1, state = 2, mlp_w, mlp_proj_w };
+enum AttentionParams {
+  query = 0,
+  value = 1,
+  state = 2,
+  fc_w,
+  fc_bias,
+  fc_proj_w,
+  fc_out,
+  fc_tanh,
+  fc_proj_out,
+  scores,
+  prob,
+  prob_left,
+  prob_right,
+  u_neg_div,
+  u_pos_div
+};
 
 void MoLAttentionLayer::finalize(InitLayerContext &context) {
   if (context.getNumInputs() != 3)
@@ -35,6 +51,7 @@ void MoLAttentionLayer::finalize(InitLayerContext &context) {
 
   auto const &all_dims = context.getInputDimensions();
   auto const &query_dim = all_dims[AttentionParams::query];
+  auto const &value_dim = all_dims[AttentionParams::value];
 
   wt_idx[AttentionParams::query] = AttentionParams::query;
   wt_idx[AttentionParams::value] = AttentionParams::value;
@@ -53,15 +70,63 @@ void MoLAttentionLayer::finalize(InitLayerContext &context) {
     << "MoL_K property not provided for layer " << context.getName();
   auto mol_k = std::get<props::MoL_K>(mol_props).get();
 
-  TensorDim mlp_w_dim = {query_dim.width(), unit};
-  wt_idx[AttentionParams::mlp_w] =
-    context.requestWeight(mlp_w_dim, Tensor::Initializer::XAVIER_UNIFORM,
-                          WeightRegularizer::NONE, 0.0, "mlp_w", true);
-
-  TensorDim mlp_proj_w_dim = {unit, 3 * mol_k};
-  wt_idx[AttentionParams::mlp_proj_w] =
-    context.requestWeight(mlp_proj_w_dim, Tensor::Initializer::XAVIER_UNIFORM,
-                          WeightRegularizer::NONE, 0.0, "mlp_proj_w", true);
+  auto &weight_regularizer =
+    std::get<props::WeightRegularizer>(*layer_impl_props);
+  auto &weight_regularizer_constant =
+    std::get<props::WeightRegularizerConstant>(*layer_impl_props);
+  auto &weight_initializer =
+    std::get<props::WeightInitializer>(*layer_impl_props);
+  auto &bias_initializer = std::get<props::BiasInitializer>(*layer_impl_props);
+
+  TensorDim fc_w_dim = {query_dim.width(), unit};
+  wt_idx[AttentionParams::fc_w] =
+    context.requestWeight(fc_w_dim, weight_initializer, weight_regularizer,
+                          weight_regularizer_constant, "fc_w", true);
+  TensorDim fc_bias_dim = {unit};
+  wt_idx[AttentionParams::fc_bias] =
+    context.requestWeight(fc_bias_dim, bias_initializer, weight_regularizer,
+                          weight_regularizer_constant, "fc_bias", true);
+
+  TensorDim fc_proj_w_dim = {unit, 3 * mol_k};
+  wt_idx[AttentionParams::fc_proj_w] =
+    context.requestWeight(fc_proj_w_dim, weight_initializer, weight_regularizer,
+                          weight_regularizer_constant, "fc_proj_w", true);
+
+  TensorDim fc_out_dim = query_dim;
+  fc_out_dim.width(fc_w_dim.width());
+  wt_idx[AttentionParams::fc_out] =
+    context.requestTensor(fc_out_dim, "fc_out", Tensor::Initializer::NONE,
+                          false, TensorLifespan::ITERATION_LIFESPAN);
+
+  TensorDim fc_proj_out_dim = fc_out_dim;
+  fc_out_dim.width(fc_proj_w_dim.width());
+  wt_idx[AttentionParams::fc_proj_out] = context.requestTensor(
+    fc_proj_out_dim, "fc_proj_out", Tensor::Initializer::NONE, false,
+    TensorLifespan::ITERATION_LIFESPAN);
+
+  TensorDim scores_dim =
+    TensorDim({value_dim.batch(), 1, 1, value_dim.height()});
+  wt_idx[AttentionParams::scores] =
+    context.requestTensor(scores_dim, "scores", Tensor::Initializer::NONE,
+                          false, TensorLifespan::ITERATION_LIFESPAN);
+
+  TensorDim prob_dim = value_dim;
+  prob_dim.width(mol_k);
+  wt_idx[AttentionParams::prob] =
+    context.requestTensor(prob_dim, "prob", Tensor::Initializer::NONE, false,
+                          TensorLifespan::ITERATION_LIFESPAN);
+  wt_idx[AttentionParams::prob_left] =
+    context.requestTensor(prob_dim, "prob_left", Tensor::Initializer::NONE,
+                          false, TensorLifespan::ITERATION_LIFESPAN);
+  wt_idx[AttentionParams::prob_right] =
+    context.requestTensor(prob_dim, "prob_right", Tensor::Initializer::NONE,
+                          false, TensorLifespan::ITERATION_LIFESPAN);
+  wt_idx[AttentionParams::u_neg_div] =
+    context.requestTensor(prob_dim, "u_neg_div", Tensor::Initializer::NONE,
+                          false, TensorLifespan::ITERATION_LIFESPAN);
+  wt_idx[AttentionParams::u_pos_div] =
+    context.requestTensor(prob_dim, "u_pos_div", Tensor::Initializer::NONE,
+                          false, TensorLifespan::ITERATION_LIFESPAN);
 
   context.setOutputDimensions({query_dim});
 }
@@ -72,34 +137,55 @@ void MoLAttentionLayer::forwarding(RunLayerContext &context, bool training) {
   Tensor &state = context.getInput(wt_idx[AttentionParams::state]);
 
   Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
-  Tensor &mlp_w = context.getWeight(wt_idx[AttentionParams::mlp_w]);
-  Tensor &mlp_proj_w = context.getWeight(wt_idx[AttentionParams::mlp_proj_w]);
+  Tensor &fc_w = context.getWeight(wt_idx[AttentionParams::fc_w]);
+  Tensor &fc_bias = context.getWeight(wt_idx[AttentionParams::fc_bias]);
+  Tensor &fc_proj_w = context.getWeight(wt_idx[AttentionParams::fc_proj_w]);
+  Tensor &fc_out = context.getTensor(wt_idx[AttentionParams::fc_out]);
+  Tensor &fc_tanh = context.getTensor(wt_idx[AttentionParams::fc_tanh]);
+  Tensor &fc_proj_out = context.getTensor(wt_idx[AttentionParams::fc_proj_out]);
+  Tensor &scores = context.getTensor(wt_idx[AttentionParams::scores]);
+  Tensor &prob = context.getTensor(wt_idx[AttentionParams::prob]);
+  Tensor &prob_left = context.getTensor(wt_idx[AttentionParams::prob_left]);
+  Tensor &prob_right = context.getTensor(wt_idx[AttentionParams::prob_right]);
+  Tensor &u_neg_div = context.getTensor(wt_idx[AttentionParams::u_neg_div]);
+  Tensor &u_pos_div = context.getTensor(wt_idx[AttentionParams::u_pos_div]);
 
   const TensorDim &input_dim = query.getDim();
   unsigned int batch = input_dim.batch();
   auto mol_k = std::get<props::MoL_K>(mol_props).get();
 
-  Tensor mlp_out = query.dot(mlp_w);
+  /** reset helper state */
+  helper_exec = false;
+
+  fc_out = query.dot(fc_w);
+  fc_out.add_i(fc_bias);
 
-  Tensor mlp_tanh;
-  tanh.run_fn(mlp_out, mlp_tanh);
+  tanh.run_fn(fc_out, fc_tanh);
 
-  Tensor mlp_proj_out = mlp_tanh.dot(mlp_proj_w);
+  fc_proj_out = fc_tanh.dot(fc_proj_w);
 
-  /** @note kappa_hat, beta_hat are strided */
-  Tensor kappa_hat =
-    mlp_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, 0, false);
-  Tensor beta_hat =
-    mlp_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k, false);
-  Tensor alpha_hat;
-  alpha_hat.copy_with_stride(
-    mlp_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k * 2, false));
+  Tensor kappa_src, beta_src, alpha_src;
+  kappa_src.copy_with_stride(
+    fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, 0, false));
+  beta_src.copy_with_stride(
+    fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k, false));
+  alpha_src.copy_with_stride(
+    fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k * 2, false));
 
-  Tensor kappa = kappa_hat.apply(static_cast<float (*)(float)>(&std::exp));
-  Tensor beta = beta_hat.apply(static_cast<float (*)(float)>(&std::exp));
+  kappa_src.apply_i(&expf);
+  beta_src.apply_i(&expf);
+  Tensor kappa = kappa_src;
+  Tensor beta = beta_src;
 
   Tensor alpha;
-  softmax.run_fn(alpha_hat, alpha);
+  softmax.run_fn(alpha_src, alpha);
+
+  fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, 0, false)
+    .copy_with_stride(kappa);
+  fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k, false)
+    .copy_with_stride(beta);
+  fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k * 2, false)
+    .copy_with_stride(alpha);
 
   Tensor m = state.add(kappa);
 
@@ -116,19 +202,159 @@ void MoLAttentionLayer::forwarding(RunLayerContext &context, bool training) {
   u_base.add_i(-0.5);
   Tensor u_neg = u_base;
 
-  Tensor integral_left, integral_right;
-  sigmoid.run_fn(u_pos.subtract(state).divide(beta.add(1e-8f)), integral_left);
-  sigmoid.run_fn(u_neg.subtract(state).divide(beta.add(1e-8f)), integral_right);
-  Tensor integral = integral_left.subtract(integral_right);
+  Tensor beta_eps = beta.add(1e-8f);
+
+  Tensor u_pos_m = u_pos.subtract(m);
+  u_pos_m.divide(beta_eps, u_pos_div);
+  sigmoid.run_fn(u_pos_div, prob_left);
 
-  Tensor integral_scaled = integral.multiply(alpha);
-  Tensor scores = integral_scaled.sum(3);
-  scores.reshape(TensorDim({scores.batch(), 1, 1, scores.height()}));
+  Tensor u_neg_m = u_neg.subtract(m);
+  u_neg_m.divide(beta_eps, u_neg_div);
+  sigmoid.run_fn(u_neg_div, prob_right);
+
+  prob_left.subtract(prob_right, prob);
+
+  Tensor prob_scaled = prob.multiply(alpha);
+  prob_scaled.sum(3, scores);
 
   scores.dotBatched(value, output);
 }
 
-void MoLAttentionLayer::calcDerivative(RunLayerContext &context) { /** NYI */
+void MoLAttentionLayer::calcDerivativeHelper(RunLayerContext &context,
+                                             Tensor &dstate) {
+  Tensor &query = context.getInput(wt_idx[AttentionParams::query]);
+  Tensor &value = context.getInput(wt_idx[AttentionParams::value]);
+
+  Tensor &derivative = context.getIncomingDerivative(SINGLE_INOUT_IDX);
+
+  Tensor &fc_proj_out = context.getTensor(wt_idx[AttentionParams::fc_proj_out]);
+  Tensor &dfc_proj_out =
+    context.getTensor(wt_idx[AttentionParams::fc_proj_out]);
+  Tensor &scores = context.getTensor(wt_idx[AttentionParams::scores]);
+  Tensor &prob = context.getTensor(wt_idx[AttentionParams::prob]);
+  Tensor &prob_left = context.getTensor(wt_idx[AttentionParams::prob_left]);
+  Tensor &prob_right = context.getTensor(wt_idx[AttentionParams::prob_right]);
+  Tensor &u_neg_div = context.getTensor(wt_idx[AttentionParams::u_neg_div]);
+  Tensor &u_pos_div = context.getTensor(wt_idx[AttentionParams::u_pos_div]);
+
+  const TensorDim &input_dim = query.getDim();
+  unsigned int batch = input_dim.batch();
+  auto mol_k = std::get<props::MoL_K>(mol_props).get();
+
+  Tensor kappa, beta, alpha;
+  kappa.copy_with_stride(
+    fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, 0, false));
+  beta.copy_with_stride(
+    fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k, false));
+  alpha.copy_with_stride(
+    fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k * 2, false));
+
+  Tensor dscores = Tensor(TensorDim({value.batch(), 1, 1, value.height()}));
+  dscores.dot_batched_deriv_wrt_1(value, derivative);
+  dscores.reshape(TensorDim({scores.batch(), 1, scores.width(), 1}));
+
+  Tensor dprob_scaled = Tensor(TensorDim({batch, 1, value.height(), mol_k}));
+  dprob_scaled.setZero();
+  dprob_scaled.add_i(dscores);
+
+  Tensor dalpha = dprob_scaled.multiply(prob).sum(2);
+  Tensor dprob = dprob_scaled.multiply(alpha);
+
+  Tensor dprob_left = dprob;
+  Tensor dprob_right = dprob.multiply(-1);
+
+  Tensor beta_eps = beta.add(1e-8f);
+  Tensor du_neg_div, du_pos_div;
+  sigmoid.run_prime_fn(prob_right, du_neg_div, dprob_right);
+  Tensor du_neg_m = du_neg_div.divide(beta_eps);
+  Tensor dm_neg = du_neg_m.multiply(-1).sum(2);
+  Tensor dbeta_eps_neg = du_neg_m.multiply(u_neg_div).multiply(-1).sum(2);
+
+  sigmoid.run_prime_fn(prob_left, du_pos_div, dprob_left);
+  Tensor du_pos_m = du_pos_div.divide(beta_eps);
+  Tensor dm_pos = du_pos_m.multiply(-1).sum(2);
+  Tensor dbeta_eps_pos = du_pos_m.multiply(u_pos_div).multiply(-1).sum(2);
+
+  Tensor dbeta_eps = dbeta_eps_neg.add(dbeta_eps_pos);
+  dm_neg.add(dm_pos, dstate);
+  Tensor dkappa = dstate;
+  Tensor dbeta = dbeta_eps;
+
+  Tensor dalpha_src;
+  softmax.run_prime_fn(alpha, dalpha_src, dalpha);
+
+  Tensor dkappa_src = dkappa.multiply(kappa);
+  Tensor dbeta_src = dbeta.multiply(beta);
+
+  /** dfc_proj_out shares memory with fc_proj_out */
+  dfc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, 0, false)
+    .copy_with_stride(dkappa_src);
+  dfc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k, false)
+    .copy_with_stride(dbeta_src);
+  dfc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k * 2, false)
+    .copy_with_stride(dalpha_src);
+
+  /** update the helper state */
+  helper_exec = true;
+}
+
+void MoLAttentionLayer::calcDerivative(RunLayerContext &context) {
+  Tensor &dquery =
+    context.getOutgoingDerivative(wt_idx[AttentionParams::query]);
+  Tensor &dvalue =
+    context.getOutgoingDerivative(wt_idx[AttentionParams::value]);
+  Tensor &dstate =
+    context.getOutgoingDerivative(wt_idx[AttentionParams::state]);
+
+  Tensor &derivative = context.getIncomingDerivative(SINGLE_INOUT_IDX);
+
+  Tensor &fc_w = context.getWeight(wt_idx[AttentionParams::fc_w]);
+  Tensor &fc_proj_w = context.getWeight(wt_idx[AttentionParams::fc_proj_w]);
+  Tensor &fc_out = context.getTensor(wt_idx[AttentionParams::fc_out]);
+  Tensor &fc_tanh = context.getTensor(wt_idx[AttentionParams::fc_tanh]);
+  Tensor &dfc_proj_out =
+    context.getTensor(wt_idx[AttentionParams::fc_proj_out]);
+  Tensor &scores = context.getTensor(wt_idx[AttentionParams::scores]);
+
+  scores.dot_batched_deriv_wrt_1(dvalue, derivative);
+
+  if (!helper_exec)
+    calcDerivativeHelper(context, dstate);
+
+  Tensor dfc_tanh = Tensor(fc_out.getDim());
+  dfc_tanh.dot_deriv_wrt_1(fc_proj_w, dfc_proj_out);
+
+  Tensor dfc_out;
+  tanh.run_prime_fn(fc_tanh, dfc_out, dfc_tanh);
+  dquery.dot_deriv_wrt_1(fc_w, dfc_out);
+}
+
+void MoLAttentionLayer::calcGradient(RunLayerContext &context) {
+  Tensor &query = context.getInput(wt_idx[AttentionParams::query]);
+  Tensor &dstate =
+    context.getOutgoingDerivative(wt_idx[AttentionParams::state]);
+
+  Tensor &fc_proj_w = context.getWeight(wt_idx[AttentionParams::fc_proj_w]);
+  Tensor &dfc_w = context.getWeightGrad(wt_idx[AttentionParams::fc_w]);
+  Tensor &dfc_bias = context.getWeightGrad(wt_idx[AttentionParams::fc_bias]);
+  Tensor &dfc_proj_w =
+    context.getWeightGrad(wt_idx[AttentionParams::fc_proj_w]);
+  Tensor &fc_out = context.getTensor(wt_idx[AttentionParams::fc_out]);
+  Tensor &fc_tanh = context.getTensor(wt_idx[AttentionParams::fc_tanh]);
+  Tensor &dfc_proj_out =
+    context.getTensor(wt_idx[AttentionParams::fc_proj_out]);
+
+  if (!helper_exec)
+    calcDerivativeHelper(context, dstate);
+
+  Tensor dfc_tanh = Tensor(fc_out.getDim());
+  fc_tanh.dot_deriv_wrt_2(dfc_proj_w, dfc_proj_out);
+  dfc_tanh.dot_deriv_wrt_1(fc_proj_w, dfc_proj_out);
+
+  Tensor dfc_out;
+  tanh.run_prime_fn(fc_tanh, dfc_out, dfc_tanh);
+  query.dot_deriv_wrt_2(dfc_w, dfc_out);
+  dfc_out.sum({0, 1, 2}, dfc_bias);
 }
 
 void MoLAttentionLayer::setProperty(const std::vector<std::string> &values) {
@@ -137,7 +363,21 @@ void MoLAttentionLayer::setProperty(const std::vector<std::string> &values) {
 }
 
 void MoLAttentionLayer::setBatch(RunLayerContext &context, unsigned int batch) {
-  /** NYI */
+  context.updateTensor(wt_idx[AttentionParams::fc_out], batch);
+  context.updateTensor(wt_idx[AttentionParams::fc_proj_out], batch);
+  context.updateTensor(wt_idx[AttentionParams::scores], batch);
+  context.updateTensor(wt_idx[AttentionParams::prob], batch);
+  context.updateTensor(wt_idx[AttentionParams::prob_left], batch);
+  context.updateTensor(wt_idx[AttentionParams::prob_right], batch);
+  context.updateTensor(wt_idx[AttentionParams::u_neg_div], batch);
+  context.updateTensor(wt_idx[AttentionParams::u_pos_div], batch);
+}
+
+void MoLAttentionLayer::exportTo(Exporter &exporter,
+                                 const ExportMethods &method) const {
+  AttentionLayer::exportTo(exporter, method);
+  LayerImpl::exportTo(exporter, method);
+  exporter.saveResult(mol_props, method, this);
 }
 
 } /* namespace nntrainer */
index e0b5dbf..1f54502 100644 (file)
@@ -16,6 +16,7 @@
 #ifdef __cplusplus
 
 #include <attention_layer.h>
+#include <layer_impl.h>
 
 namespace nntrainer {
 
@@ -23,7 +24,7 @@ namespace nntrainer {
  * @class   MoL Attention Layer
  * @brief   Mixture of Logistics Attention Layer
  */
-class MoLAttentionLayer : public AttentionLayer {
+class MoLAttentionLayer : public AttentionLayer, public LayerImpl {
 public:
   /**
    * @brief     Constructor of MoL Attention Layer
@@ -63,6 +64,11 @@ public:
   void calcDerivative(RunLayerContext &context) override;
 
   /**
+   * @copydoc Layer::calcGradient(RunLayerContext &context)
+   */
+  void calcGradient(RunLayerContext &context) override;
+
+  /**
    * @copydoc bool supportBackwarding() const
    */
   bool supportBackwarding() const override { return true; };
@@ -70,8 +76,7 @@ public:
   /**
    * @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method)
    */
-  void exportTo(Exporter &exporter,
-                const ExportMethods &method) const override {}
+  void exportTo(Exporter &exporter, const ExportMethods &method) const override;
 
   /**
    * @copydoc Layer::setProperty(const std::vector<std::string> &values)
@@ -97,10 +102,20 @@ private:
     mol_props; /**< mol attention layer properties : unit - number of output
                   neurons */
 
-  ActiFunc softmax;                   /** softmax activation operation */
-  ActiFunc tanh;                      /** softmax activation operation */
-  ActiFunc sigmoid;                   /** softmax activation operation */
-  std::array<unsigned int, 5> wt_idx; /**< indices of the weights and tensors */
+  bool helper_exec; /** check if the helper function has already ran */
+  ActiFunc softmax; /** softmax activation operation */
+  ActiFunc tanh;    /** softmax activation operation */
+  ActiFunc sigmoid; /** softmax activation operation */
+  std::array<unsigned int, 15>
+    wt_idx; /**< indices of the weights and tensors */
+
+  /**
+   * @brief Helper function for calculation of the derivative
+   *
+   * @param context layer context
+   * @param dstate to store the derivative of the state
+   */
+  void calcDerivativeHelper(RunLayerContext &context, Tensor &dstate);
 };
 
 } // namespace nntrainer