From dcdf4ffd215450966fed767c85326fadda4f032b Mon Sep 17 00:00:00 2001
From: Parichay Kapoor
Date: Mon, 6 Dec 2021 17:00:36 +0900
Subject: [PATCH] [layer] Bug fix for embedding layer

This patch adds a bug fix for the embedding layer related to the
indexing of the data.

Signed-off-by: Parichay Kapoor
---
 nntrainer/layers/embedding.cpp | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/nntrainer/layers/embedding.cpp b/nntrainer/layers/embedding.cpp
index 80d1edf..1049096 100644
--- a/nntrainer/layers/embedding.cpp
+++ b/nntrainer/layers/embedding.cpp
@@ -91,14 +91,17 @@ void EmbeddingLayer::forwarding(RunLayerContext &context, bool training) {
     uint *in_data = input_.getAddress(b * input_.getDim().getFeatureLen());
 
+    Tensor batchsliced_hidden = hidden_.getBatchSlice(b, 1);
     for (unsigned int i = 0; i < input_.width(); ++i) {
       uint embed_idx = in_data[i];
       if (embed_idx >= in_dim) {
         throw std::invalid_argument("input word index is greater than in_dim");
       }
 
-      Tensor cur_weight = weight.getSharedDataTensor(out_tensor_dim, embed_idx);
-      Tensor out_tensor = hidden_.getSharedDataTensor(out_tensor_dim, i);
+      Tensor cur_weight =
+        weight.getSharedDataTensor(out_tensor_dim, embed_idx * out_dim);
+      Tensor out_tensor =
+        batchsliced_hidden.getSharedDataTensor(out_tensor_dim, i * out_dim);
 
       /** if zero_mask_idx matches the given index, set the output to zero */
       if (!zero_mask_idx.empty() && embed_idx == zero_mask_idx.get()) {
@@ -132,11 +135,14 @@ void EmbeddingLayer::calcGradient(RunLayerContext &context) {
     uint *in_data = input_.getAddress(b * input_.getDim().getFeatureLen());
 
+    Tensor batchsliced_derivative = derivative_.getBatchSlice(b, 1);
     for (unsigned int i = 0; i < input_.width(); ++i) {
       uint embed_idx = in_data[i];
 
-      Tensor cur_dw = djdw.getSharedDataTensor(out_tensor_dim, embed_idx);
-      Tensor in_derv = derivative_.getSharedDataTensor(out_tensor_dim, i);
+      Tensor cur_dw =
+        djdw.getSharedDataTensor(out_tensor_dim, embed_idx * out_dim);
+      Tensor in_derv =
+        batchsliced_derivative.getSharedDataTensor(out_tensor_dim, i * out_dim);
 
       /** if zero_mask_idx matches the given index, set the grad to zero */
       if (!zero_mask_idx.empty() && embed_idx == zero_mask_idx.get()) {
-- 
2.7.4
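
Note (not part of the patch): the fix scales the offsets passed to getSharedDataTensor by out_dim, so that each lookup starts at the beginning of an embedding row rather than at an element index. Below is a minimal standalone sketch of that row-major offset arithmetic, using plain std::vector instead of nntrainer's Tensor and hypothetical in_dim/out_dim values; it is an illustration of the indexing only, not nntrainer code.

// Standalone sketch: embedding lookup over a flat [in_dim x out_dim] buffer.
// Row r starts at offset r * out_dim; output slot i starts at i * out_dim.
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

int main() {
  const unsigned int in_dim = 4;  // vocabulary size (hypothetical)
  const unsigned int out_dim = 3; // embedding width (hypothetical)

  // Flat row-major weight matrix: row r holds the embedding of word index r.
  std::vector<float> weight(in_dim * out_dim);
  for (unsigned int r = 0; r < in_dim; ++r)
    for (unsigned int c = 0; c < out_dim; ++c)
      weight[r * out_dim + c] = static_cast<float>(r) + 0.1f * c;

  // One batch of input word indices and its flat output buffer.
  std::vector<uint32_t> in_data = {2, 0, 3};
  std::vector<float> hidden(in_data.size() * out_dim, 0.0f);

  for (size_t i = 0; i < in_data.size(); ++i) {
    uint32_t embed_idx = in_data[i];
    if (embed_idx >= in_dim)
      throw std::invalid_argument("input word index is greater than in_dim");

    // Offsets scaled by out_dim; using embed_idx or i directly (as before the
    // fix) would point into the middle of the wrong row.
    const float *row = &weight[embed_idx * out_dim];
    float *out = &hidden[i * out_dim];
    for (unsigned int c = 0; c < out_dim; ++c)
      out[c] = row[c];
  }

  for (float v : hidden)
    std::cout << v << ' ';
  std::cout << '\n';
  return 0;
}

The batch-slice part of the patch follows the same idea: taking getBatchSlice(b, 1) first means the subsequent i * out_dim offset is relative to the current sample rather than to the whole batch buffer.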