[ GEMM ] Using GEMM for initial sequence
author jijoong.moon <jijoong.moon@samsung.com>
Fri, 8 Sep 2023 12:56:03 +0000 (21:56 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
Fri, 20 Oct 2023 01:51:38 +0000 (10:51 +0900)
This PR computes the initial (prompt) sequence in a single prefill pass, so the per-layer matrix products over the prompt run as one GEMM instead of one vector-matrix product per token.
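A rough sketch of the resulting call pattern, abridged from the new run() in Applications/LLaMA/jni/main.cpp below (MAX_SEQ_LEN, input_len, and NUM_TO_GENERATE are names from that file):

    // prefill: one pass over the whole prompt; activations are
    // (input_len x dim), so layer matmuls run as GEMM
    auto out = g_model->incremental_inference(1, input, label, MAX_SEQ_LEN,
                                              /*from=*/0, /*to=*/input_len);

    // decode: one token per step; activations shrink to (1 x dim)
    for (unsigned int i = input_len + 1; i < input_len + NUM_TO_GENERATE; ++i)
      out = g_model->incremental_inference(1, input, label, MAX_SEQ_LEN,
                                           /*from=*/i - 1, /*to=*/i);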

**Self evaluation:**
1. Build test:  [X]Passed [ ]Failed [ ]Skipped
2. Run test:  [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <jijoong.moon@samsung.com>
21 files changed:
Applications/LLaMA/jni/main.cpp
Applications/LLaMA/jni/meson.build
Applications/LLaMA/jni/rms_norm.cpp
Applications/LLaMA/jni/rms_norm.h
Applications/LLaMA/jni/swiglu.cpp
Applications/LLaMA/jni/swiglu.h
Applications/PicoGPT/jni/main.cpp
Applications/meson.build
api/ccapi/include/model.h
meson.build
nntrainer/layers/addition_layer.cpp
nntrainer/layers/addition_layer.h
nntrainer/layers/embedding.cpp
nntrainer/layers/embedding.h
nntrainer/layers/fc_layer.cpp
nntrainer/layers/fc_layer.h
nntrainer/layers/multi_head_attention_layer.cpp
nntrainer/layers/multi_head_attention_layer.h
nntrainer/layers/multiout_layer.cpp
nntrainer/models/neuralnet.cpp
nntrainer/models/neuralnet.h

index 768b18c40307e3147aaab0d9eb8477114383f949..7c244ba4714b7a4af75d8bbf1ad44c705c4ff645 100644 (file)
@@ -13,6 +13,7 @@
 #include <chrono>
 #include <ctime>
 #include <iostream>
+#include <iterator>
 #include <memory>
 #include <sstream>
 #include <string>
@@ -39,6 +40,8 @@ using json = nlohmann::json;
 using LayerHandle = std::shared_ptr<ml::train::Layer>;
 using ModelHandle = std::unique_ptr<ml::train::Model>;
 
+ModelHandle g_model;
+
 // Hyper params for LLaMA
 int const DIM = 2304;
 int const NUM_LAYERS = 28;
@@ -105,6 +108,9 @@ T unwrap(std::optional<T> &&value, const std::string &error_msg) {
   }
 }
 
+/**
+ * @brief Create Attention Layer for the separate implementation
+ */
 std::vector<LayerHandle> createAttentionLayer(const int layer_id, int seq_len,
                                               int n_heads, int head_dim,
                                               std::string query_name,
@@ -250,6 +256,9 @@ std::vector<LayerHandle> createAttentionLayer(const int layer_id, int seq_len,
   return layers;
 }
 
+/**
+ * @brief Create FF Layers
+ */
 std::vector<LayerHandle> createFeedForwardLayer(const int layer_id, int dim,
                                                 int hidden_dim,
                                                 std::string input_name,
@@ -288,6 +297,9 @@ std::vector<LayerHandle> createFeedForwardLayer(const int layer_id, int dim,
   return layers;
 }
 
+/**
+ * @brief Create Decoder
+ */
 std::vector<LayerHandle> createTransformerDecoder(const int layer_id,
                                                   std::string input_name) {
   using ml::train::createLayer;
@@ -333,6 +345,9 @@ std::vector<LayerHandle> createTransformerDecoder(const int layer_id,
   return layers;
 }
 
+/**
+ * @brief Create LLaMA2 Model
+ */
 ModelHandle createLLaMA() {
   using ml::train::createLayer;
 
@@ -347,7 +362,9 @@ ModelHandle createLLaMA() {
        withKey("input_shape", "1:1:" + std::to_string(INIT_SEQ_LEN))}));
   } else {
     layers.push_back(createLayer(
-      "input", {withKey("name", "input0"), withKey("input_shape", "1:1:1")}));
+      "input",
+      {withKey("name", "input0"),
+       withKey("input_shape", "1:1:" + std::to_string(INIT_SEQ_LEN))}));
   }
 
   layers.push_back(ml::train::layer::Embedding(
@@ -384,33 +401,10 @@ ModelHandle createLLaMA() {
   return model;
 }
 
-void createAndRun(unsigned int epochs, unsigned int batch_size,
-                  std::wstring text) {
-  // setup model
-  ModelHandle model = createLLaMA();
-  model->setProperty({withKey("batch_size", batch_size),
-                      withKey("epochs", epochs),
-                      // #ifdef ENABLE_FP16
-                      // withKey("model_tensor_type", "FP16-FP16"),
-                      // #endif
-                      withKey("save_path", "test_model.bin")});
-
-  auto optimizer = ml::train::createOptimizer("sgd", {"learning_rate=0.001"});
-  model->setOptimizer(std::move(optimizer));
-
-  int status = model->compile();
-  if (status) {
-    throw std::invalid_argument("model compilation failed!");
-  }
-
-  status = model->initialize();
-  if (status) {
-    throw std::invalid_argument("model initialization failed!");
-  }
-
-  std::string weight_path = "./llama_fp16.bin";
-
-  model->load(weight_path);
+/**
+ * @brief Run inference for the given text sequence
+ */
+void run(std::string text) {
 
   std::vector<float *> input;
   std::vector<float *> label;
@@ -419,49 +413,123 @@ void createAndRun(unsigned int epochs, unsigned int batch_size,
 
   float *input_sample = (float *)malloc(sizeof(float) * data_size);
 
+  unsigned int input_len = INIT_SEQ_LEN;
+
+  unsigned int init_len;
+
 #if defined(ENABLE_ENCODER2)
   std::string vocab_file_name = "../Applications/LLaMA/jni/vocab.json";
   std::string merge_file_name = "../Applications/LLaMA/jni/merges.txt";
 
   auto tokenizer = unwrap(GPT2Encoder::load(vocab_file_name, merge_file_name),
-                          "Error initialising GPT2 tokenizer\n");
+                          "Error initializing GPT2 tokenizer\n");
+
+  std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
+
+  auto init_input = tokenizer.encode(converter.from_bytes(text));
+  init_len = init_input.size();
+
+  input_len = (init_len > INIT_SEQ_LEN) ? INIT_SEQ_LEN : init_len;
+
+  for (unsigned int i = 0; i < input_len; ++i) {
+    input_sample[i] = static_cast<float>(init_input[i]);
+  }
 
-  auto init_input = tokenizer.encode(text);
-  INIT_SEQ_LEN = init_input.size();
-  ((uint *)(input_sample))[0] = init_input[0];
   input.push_back(input_sample);
 
 #else
-  float init_data[INIT_SEQ_LEN] = {
+  float init_input[INIT_SEQ_LEN] = {
     0,  1,  2,  3,  4,  5,   6,   7,   8,   9,   10,  20,  30,  40,
     50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900};
-  ((uint *)(input_sample))[0] = init_data[0];
+  ((uint *)(input_sample))[0] = init_input[0];
   input.push_back(input_sample);
+  init_len = 18;
 #endif
 
-  for (unsigned int i = 1; i < INIT_SEQ_LEN + NUM_TO_GENERATE; ++i) {
-    auto output =
-      model->incremental_inference(1, input, label, INIT_SEQ_LEN, i - 1);
+  std::vector<int64_t> token_ids;
+
+  auto output =
+    g_model->incremental_inference(1, input, label, MAX_SEQ_LEN, 0, input_len);
 
-    std::vector<int64_t> tokens;
-    nntrainer::Tensor output_tensor({batch_size, 1, 1, NUM_VOCAB}, output[0]);
+  unsigned int ids = std::distance(
+    output[0], std::max_element(output[0], output[0] + NUM_VOCAB));
 
-    tokens.push_back(static_cast<int64_t>(output_tensor.argmax()[0]));
+  input_sample[0] = static_cast<float>(ids);
+
+#ifdef ENABLE_FP16
+  for (auto o : output) {
+    delete[] o;
+  }
+#endif
+  std::cout << " Progress Reading: 100 % " << std::endl;
+  std::cout << std::endl << "### Output : " << std::endl;
+  if (init_len < INIT_SEQ_LEN) {
 #if defined(ENABLE_ENCODER2)
-    auto decoded_str = tokenizer.decode(tokens);
-    std::cerr << decoded_str << std::flush;
+    auto decoded_str = tokenizer.decode({static_cast<int64_t>(ids)});
+    std::cout << decoded_str << " ";
+    std::cout.flush();
 #endif
+  }
+
+  for (unsigned int i = input_len + 1; i < input_len + NUM_TO_GENERATE; ++i) {
+    auto output_interval =
+      g_model->incremental_inference(1, input, label, MAX_SEQ_LEN, i - 1, i);
 
-    if (i < INIT_SEQ_LEN) {
+    ids = std::distance(
+      output_interval[0],
+      std::max_element(output_interval[0], output_interval[0] + NUM_VOCAB));
+
+    if (i < input_len) {
+      input_sample[0] = static_cast<float>(init_input[i]);
+    } else {
+      input_sample[0] = static_cast<float>(ids);
 #if defined(ENABLE_ENCODER2)
-      ((uint *)(input_sample))[0] = init_input[i];
-#else
-      ((uint *)(input_sample))[0] = init_data[i];
+      auto decoded_str = tokenizer.decode({static_cast<int64_t>(ids)});
+      std::cout << decoded_str << " ";
+      std::cout.flush();
 #endif
-    } else {
-      ((uint *)(input_sample))[0] = output_tensor.argmax()[0];
     }
+
+#ifdef ENABLE_FP16
+    for (auto o : output_interval) {
+      delete[] o;
+    }
+#endif
   }
+
+  std::cout << std::endl;
+  free(input_sample);
+}
+
+/**
+ * @brief Create and set up the model
+ */
+void createAndRun(unsigned int epochs, unsigned int batch_size) {
+  // setup model
+  g_model = createLLaMA();
+  g_model->setProperty({withKey("batch_size", batch_size),
+                        withKey("epochs", epochs),
+                        // #ifdef ENABLE_FP16
+                        withKey("model_tensor_type", "FP16-FP16"),
+                        // #endif
+                        withKey("save_path", "test_model.bin")});
+
+  auto optimizer = ml::train::createOptimizer("sgd", {"learning_rate=0.001"});
+  g_model->setOptimizer(std::move(optimizer));
+
+  int status = g_model->compile();
+  if (status) {
+    throw std::invalid_argument("model compilation failed!");
+  }
+
+  status = g_model->initialize();
+  if (status) {
+    throw std::invalid_argument("model initialization failed!");
+  }
+
+  std::string weight_path = "./llama_fp16.bin";
+
+  g_model->load(weight_path);
 }
 
 #if defined(ENABLE_ENCODER2)
@@ -488,22 +556,19 @@ std::wstring decodeUnicodeEscape(const std::wstring &input) {
   return result.str();
 }
 #endif
-
 int main(int argc, char *argv[]) {
   // Setting locale
   std::locale::global(std::locale("ko_KR.UTF-8"));
 
 #if defined(ENABLE_ENCODER2)
-
   // Getting arguments From terminal
   std::wstring input;
   std::getline(std::wcin, input);
   std::wstring test = decodeUnicodeEscape(input);
   std::wstring_convert<std::codecvt_utf16<wchar_t>> converter;
   std::string text = converter.to_bytes(test);
-  std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
 #else
-  std::wstring text = L"This is sample input for LLaMA.";
+  std::string text = "This is sample input for LLaMA.";
 #endif
 
   auto &app_context = nntrainer::AppContext::Global();
@@ -523,7 +588,17 @@ int main(int argc, char *argv[]) {
     return 1;
   }
 
-  createAndRun(epoch, batch_size, text);
+  try {
+    const std::vector<std::string> args(argv + 1, argv + argc);
+
+    createAndRun(epoch, batch_size);
+
+    run(text);
+  } catch (const std::exception &e) {
+    std::cerr << "uncaught error while running! details: " << e.what()
+              << std::endl;
+    return EXIT_FAILURE;
+  }
 
   int status = EXIT_SUCCESS;
   return status;
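
Note: the token selection in run() above is plain greedy decoding: take the argmax over the vocabulary logits of the newest position. A self-contained sketch of that step (hypothetical helper name):

    #include <algorithm>
    #include <iterator>

    // Greedy decoding: return the id of the highest-scoring logit.
    unsigned int argmax_token(const float *logits, unsigned int num_vocab) {
      return static_cast<unsigned int>(
        std::distance(logits, std::max_element(logits, logits + num_vocab)));
    }
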
index bd68baaa302bf916592ff7220305b6542e91ae49..24ebf6593f46adee53f1e50e4276a7c65b3a4496 100644 (file)
@@ -13,8 +13,8 @@ transpose_dep = declare_dependency(
 )
 
 if get_option('platform') != 'tizen'
-    extra_defines += '-DENABLE_ENCODER2=1'
-    run_command(meson.source_root() / 'jni' / 'prepare_encoder.sh', meson.build_root(), '0.2' ,check: true)
+   extra_defines += '-DENABLE_ENCODER2=1'
+   run_command(meson.source_root() / 'jni' / 'prepare_encoder.sh', meson.build_root(), '0.2' ,check: true)
 endif
 
 rms_norm_src = files('rms_norm.cpp')
index 6d913a1734b38dece2cc9bb1f43bd14ea64e9116..33e03424d08256ceaf483db45528a53a63f0d5af 100644 (file)
@@ -74,6 +74,65 @@ void RMSNormLayer::forwarding(nntrainer::RunLayerContext &context,
   }
 }
 
+void RMSNormLayer::incremental_forwarding(nntrainer::RunLayerContext &context,
+                                          unsigned int from, unsigned int to,
+                                          bool training) {
+  nntrainer::Tensor &in = context.getInput(SINGLE_INOUT_IDX);
+  nntrainer::Tensor &out = context.getOutput(SINGLE_INOUT_IDX);
+  nntrainer::Tensor &gamma = context.getWeight(wt_idx[RMSParams::gamma]);
+  ml::train::TensorDim in_dim = in.getDim();
+  ml::train::TensorDim out_dim = out.getDim();
+
+  ml::train::TensorDim in_step_dim = in_dim;
+  ml::train::TensorDim out_step_dim = out_dim;
+
+  if (from) {
+    NNTR_THROW_IF(to - from != 1, std::invalid_argument)
+      << "incremental step size is not 1";
+    from = 0;
+    to = 1;
+  }
+
+  in_step_dim.height(to - from);
+  out_step_dim.height(to - from);
+
+  nntrainer::Tensor in_step = in.getSharedDataTensor(in_step_dim, 0, true);
+  nntrainer::Tensor out_step = out.getSharedDataTensor(out_step_dim, 0, true);
+
+  auto &epsilon = std::get<nntrainer::props::Epsilon>(rms_props).get();
+
+  if (in_step.getDataType() == ml::train::TensorDim::DataType::FP32) {
+    std::function<float(float)> f = [](float x) { return 1 / std::sqrt(x); };
+    auto t = in_step.multiply(in_step).average(3).add(epsilon);
+    t.apply_i(f);
+    in_step.multiply(t, out_step);
+    out_step.multiply_i(gamma);
+
+  } else if (in_step.getDataType() == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+    ml::train::TensorDim d = in_step.getDim();
+    d.width(1);
+    nntrainer::Tensor t(d, true);
+
+    unsigned int axis_dim = in_step.getDim()[3];
+    for (unsigned int i = from; i < to; ++i) {
+      float sum = 0.0;
+      _FP16 *data = in_step.getAddress<_FP16>(0, 0, i, 0);
+      for (unsigned int j = 0; j < axis_dim; ++j) {
+        sum += powf(static_cast<float>(data[j]), 2.0f);
+      }
+      t.setValue(0, 0, i, 0,
+                 static_cast<_FP16>(1.0 / sqrt(sum / axis_dim + epsilon)));
+    }
+    in_step.multiply(t, out_step);
+    out_step.multiply_i(gamma);
+
+#else
+    throw std::invalid_argument("Error: enable-fp16 is not set");
+#endif
+  }
+}
+
 void RMSNormLayer::calcDerivative(nntrainer::RunLayerContext &context) {
   // std::throw_with_nested(std::runtime_error("Training is not supported
   // yet."));
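
Note: the new incremental path above normalizes only the active step rows. For reference, RMSNorm scales each row by the reciprocal root mean square of its elements, y[j] = gamma[j] * x[j] / sqrt(mean(x^2) + eps); a minimal scalar sketch of the same math (illustrative, not the nntrainer tensor code):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // RMSNorm over one row: y[j] = gamma[j] * x[j] / sqrt(mean(x^2) + eps)
    std::vector<float> rms_norm(const std::vector<float> &x,
                                const std::vector<float> &gamma, float eps) {
      float sum_sq = 0.0f;
      for (float v : x)
        sum_sq += v * v;
      const float inv_rms = 1.0f / std::sqrt(sum_sq / x.size() + eps);
      std::vector<float> y(x.size());
      for (std::size_t j = 0; j < x.size(); ++j)
        y[j] = gamma[j] * x[j] * inv_rms;
      return y;
    }
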
index 2e675059e04fb714bad5ebd99a0015779121ede4..cc7ae4be896f9126850069a958341206e1cd796e 100644 (file)
@@ -79,6 +79,14 @@ public:
    */
   void forwarding(nntrainer::RunLayerContext &context, bool training) override;
 
+  /**
+   * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
+   * int from, unsigned int to, bool training)
+   */
+  void incremental_forwarding(nntrainer::RunLayerContext &context,
+                              unsigned int from, unsigned int to,
+                              bool training) override;
+
   /**
    * @copydoc Layer::calcDerivative(RunLayerContext &context)
    */
index 02ef82c526800d7fa3f0c6a051174801aa44b095..bc056c4336c7d1ef187f42f73b2c6811ea8e91f9 100644 (file)
@@ -75,6 +75,54 @@ void SwiGLULayer::forwarding(nntrainer::RunLayerContext &context,
   }
 }
 
+void SwiGLULayer::incremental_forwarding(nntrainer::RunLayerContext &context,
+                                         unsigned int from, unsigned int to,
+                                         bool training) {
+  nntrainer::Tensor &in1 = context.getInput(INPUT_IDX_1);
+  nntrainer::Tensor &in2 = context.getInput(INPUT_IDX_2);
+  nntrainer::Tensor &out = context.getOutput(OUT_IDX);
+
+  if (from) {
+    NNTR_THROW_IF(to - from != 1, std::invalid_argument)
+      << "incremental step size is not 1";
+    from = 0;
+    to = 1;
+  }
+
+  if (in1.getDataType() == ml::train::TensorDim::DataType::FP32) {
+    for (unsigned int b = 0; b < in1.batch(); b++) {
+      for (unsigned int c = 0; c < in1.channel(); c++) {
+        for (unsigned int h = from; h < to; h++) {
+          for (unsigned int w = 0; w < in1.width(); w++) {
+            out.setValue(b, c, h, w,
+                         ActivationOp::swish(in1.getValue<float>(b, c, h, w)) *
+                           in2.getValue<float>(b, c, h, w));
+          }
+        }
+      }
+    }
+  } else if (in1.getDataType() == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+    for (unsigned int b = 0; b < in1.batch(); b++) {
+      for (unsigned int c = 0; c < in1.channel(); c++) {
+        for (unsigned int h = from; h < to; h++) {
+          for (unsigned int w = 0; w < in1.width(); w++) {
+            out.setValue(
+              b, c, h, w,
+              static_cast<_FP16>(
+                ActivationOp::swish(
+                  static_cast<float>(in1.getValue<_FP16>(b, c, h, w))) *
+                static_cast<float>(in2.getValue<_FP16>(b, c, h, w))));
+          }
+        }
+      }
+    }
+#else
+    NNTR_THROW_IF(true, std::invalid_argument) << "enable-fp16 is not set!";
+#endif
+  }
+}
+
 void SwiGLULayer::calcDerivative(nntrainer::RunLayerContext &context) {
   // std::throw_with_nested(std::runtime_error("Training is not supported
   // yet."));
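
Note: the nested loops above apply SwiGLU element by element; the operation is a swish-gated product of the two inputs. A per-element sketch (reference only):

    #include <cmath>

    // SwiGLU: out = swish(x1) * x2, where swish(x) = x * sigmoid(x)
    inline float swiglu(float x1, float x2) {
      const float sig = 1.0f / (1.0f + std::exp(-x1));
      return (x1 * sig) * x2;
    }
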
index a2690f83ec6f63e96f1dd96cd05828973a7c942b..eed9bbaf6f243fa0fdabd9739505b62091116f11 100644 (file)
@@ -49,6 +49,13 @@ public:
    */
   void forwarding(nntrainer::RunLayerContext &context, bool training) override;
 
+  /**
+   * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
+   * int from, unsigned int to, bool training)
+   */
+  void incremental_forwarding(nntrainer::RunLayerContext &context,
+                              unsigned int from, unsigned int to,
+                              bool training) override;
 
   /**
    * @copydoc Layer::calcDerivative(RunLayerContext &context)
index a5cc930b115bb7bb55dd04d3c4736bc22146cee1..7c990e3b81aa41c0dad36ad0ff2034f9e8e92fb3 100644 (file)
@@ -344,7 +344,7 @@ int main(int argc, char *argv[]) {
   for (unsigned int i = 1; i < init_input_seq_len + NUM_TOKENS_TO_GENERATE;
        ++i) {
     output_bufs = model->incremental_inference(
-      BATCH_SIZE, {wte_input, wpe_input}, {}, init_input_seq_len, i - 1);
+      BATCH_SIZE, {wte_input, wpe_input}, {}, init_input_seq_len, i - 1, i);
 
     nntrainer::Tensor output({BATCH_SIZE, 1, i, MODEL_DIM}, output_bufs[0]);
 
index fd5a7a38899d1227de2dcf81f473e542557978d5..afdf76520d1f4490294c60e5ee555e10b618a38c 100644 (file)
@@ -11,7 +11,10 @@ subdir('VGG/jni')
 subdir('Resnet/jni')
 subdir('YOLO/jni')
 subdir('YOLOv3/jni')
+if get_option('platform') != 'tizen'
 subdir('LLaMA/jni')
+endif
+
 subdir('ReinforcementLearning/DeepQ/jni')
 subdir('TransferLearning/CIFAR_Classification/jni')
 # if enable_capi
index b690a3955cdc0d5f780e086ead58397915ffa596..bff3f29526ff524bcaf4ced81724da8cd023a607 100644 (file)
@@ -305,14 +305,16 @@ public:
    * @param[in] input inputs as a list of each input data
    * @param[in] label labels as a list of each label data
    * @param[in] init_seq_len initial sequence length
-   * @param[in] cur_step current working step index (zero based index)
+   * @param[in] from start step index (zero-based, inclusive)
+   * @param[in] to end step index (exclusive)
    * @retval list of output as float *
    * @note The output memory must not be freed by the caller
    */
   virtual std::vector<float *>
   incremental_inference(unsigned int batch, const std::vector<float *> &input,
                         const std::vector<float *> &label,
-                        unsigned int init_seq_len, unsigned int cur_step) = 0;
+                        unsigned int init_seq_len, unsigned int from,
+                        unsigned int to) = 0;
 
   /**
    * @brief     Summarize the model
index 6ed7a3aa9bd3457e9809272fb9d92a192854e285..dc1ba3ad67d02627d9786365de87238f008d114f 100644 (file)
@@ -62,11 +62,11 @@ if get_option('enable-fp16')
      extra_defines += '-DENABLE_FP16=1'
      extra_defines += '-DUSE__FP16=1'
      extra_defines += '-DUSE_NEON=1'
-   elif arch == 'aarch64' or arch =='arm'
+   elif arch == 'aarch64'
      add_project_arguments('-mfp16-format=ieee', language: ['c', 'cpp'])
      extra_defines += '-DENABLE_FP16=1'
      extra_defines += '-DUSE__FP16=1'
-     extra_defines += '-DUSE_NEON=0'
+     extra_defines += '-DUSE_NEON=1'
    else
      has_avx512fp16 = cc.has_argument('-mavx512fp16')
      if (has_avx512fp16)
index 646d303c45862901fbe3dc190867663fbf70866f..e5399d331132a473ace4d082276715252ba2fc37 100644 (file)
@@ -39,12 +39,40 @@ void AdditionLayer::forwarding(RunLayerContext &context, bool training) {
       hidden_.add_i(input_);
     }
   }
+}
+
+void AdditionLayer::incremental_forwarding(RunLayerContext &context,
+                                           unsigned int from, unsigned int to,
+                                           bool training) {
+  Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
+  TensorDim hidden_dim = hidden_.getDim();
+  TensorDim hidden_step_dim = hidden_dim;
+
+  if (from) {
+    NNTR_THROW_IF(to - from != 1, std::invalid_argument)
+      << "incremental step size is not 1";
+    from = 0;
+    to = 1;
+  }
+
+  hidden_step_dim.height(to - from);
 
-  bool print = std::get<props::Print>(add_props).get();
-  if (print) {
-    // std::cerr << input_ << "\n";
-    // std::cerr << weight << "\n";
-    // std::cerr << hidden_ << "\n";
+  Tensor hidden_step = hidden_.getSharedDataTensor(hidden_step_dim, 0, true);
+
+  /** @todo check whether the addition layer can run in-place */
+  for (unsigned int idx = 0; idx < context.getNumInputs(); ++idx) {
+    const Tensor &input_ = context.getInput(idx);
+    TensorDim input_dim = input_.getDim();
+
+    TensorDim input_step_dim = input_dim;
+    input_step_dim.height(to - from);
+
+    Tensor input_step = input_.getSharedDataTensor(input_step_dim, 0, true);
+    if (!idx) {
+      hidden_step.copy(input_step);
+    } else {
+      hidden_step.add_i(input_step);
+    }
   }
 }
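
Note: this is the step-view idiom used throughout the patch: shrink the height of the tensor dimension to `to - from` and take a zero-offset shared view, so the arithmetic touches only the rows of the current step. In spirit (a hypothetical plain-array view, not the nntrainer getSharedDataTensor API):

    // A zero-copy view over the first `rows` rows of a (height x width)
    // buffer, analogous to the (to - from)-high step tensors above.
    struct StepView {
      float *data;       // shared with the parent buffer, not owned
      unsigned int rows; // to - from
      unsigned int width;
      float &at(unsigned int r, unsigned int c) { return data[r * width + c]; }
    };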
 
index 75d57a6591123e155acdeacf75991ee32d6ef616..80898259664c6e7d35066f3c6df656250e6cf671 100644 (file)
@@ -58,6 +58,13 @@ public:
    */
   void forwarding(RunLayerContext &context, bool training) override;
 
+  /**
+   * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
+   * int from, unsigned int to, bool training)
+   */
+  void incremental_forwarding(RunLayerContext &context, unsigned int from,
+                              unsigned int to, bool training) override;
+
   /**
    * @copydoc Layer::calcDerivative(RunLayerContext &context)
    */
index f3ab6ffae3dd62fba9e1a607c0deac361a9d2b17..bc203de0e132564bc448601538bfc77fdca42e74 100644 (file)
@@ -112,6 +112,50 @@ void EmbeddingLayer::forwarding(RunLayerContext &context, bool training) {
   }
 }
 
+void EmbeddingLayer::incremental_forwarding(RunLayerContext &context,
+                                            unsigned int from, unsigned int to,
+                                            bool training) {
+
+  /// @todo get input and output dimension from input_ and hidden itself
+  unsigned int in_dim = std::get<props::InDim>(embedding_props);
+  unsigned int out_dim = std::get<props::OutDim>(embedding_props);
+
+  if (from) {
+    NNTR_THROW_IF(to - from != 1, std::invalid_argument)
+      << "incremental step size is not 1";
+    from = 0;
+    to = 1;
+  }
+
+  Tensor &weight = context.getWeight(weight_idx);
+  Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
+  Tensor &input_ = context.getInput(SINGLE_INOUT_IDX);
+
+  TensorDim out_tensor_dim =
+    TensorDim({1, 1, 1, out_dim}, hidden_.getTensorType());
+
+  for (unsigned int b = 0; b < input_.batch(); ++b) {
+    float *in_data =
+      input_.getAddress<float>(b * input_.getDim().getFeatureLen());
+
+    Tensor batchsliced_hidden = hidden_.getBatchSlice(b, 1);
+    for (unsigned int i = from; i < to; ++i) {
+      uint embed_idx = static_cast<uint>(in_data[i]);
+      if (embed_idx >= in_dim) {
+        throw std::invalid_argument("input word index must be less than in_dim");
+      }
+
+      Tensor cur_weight =
+        weight.getSharedDataTensor(out_tensor_dim, out_dim * embed_idx);
+
+      Tensor out_tensor = batchsliced_hidden.getSharedDataTensor(
+        out_tensor_dim, out_dim * (i - from));
+
+      out_tensor.copyData(cur_weight);
+    }
+  }
+}
+
 void EmbeddingLayer::calcDerivative(RunLayerContext &context) {
   throw exception::not_supported(
     "calcDerivative for Embedding layer is not supported");
index 9cea924c847489af6ad48e4a23c8cda1fccd817a..50d51fc389c7fcc3a3ec3e33a2aee1e144e3df6c 100644 (file)
@@ -59,6 +59,13 @@ public:
    */
   void forwarding(RunLayerContext &context, bool training) override;
 
+  /**
+   * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
+   * int from, unsigned int to, bool training)
+   */
+  void incremental_forwarding(RunLayerContext &context, unsigned int from,
+                              unsigned int to, bool training) override;
+
   /**
    * @copydoc Layer::calcDerivative(RunLayerContext &context)
    */
index 189901f90bd324173f948c3f422ce49843beeb58..57bc0c1e71485a64385efe26bbc399d63b522963 100644 (file)
@@ -152,6 +152,45 @@ void FullyConnectedLayer::forwarding(RunLayerContext &context, bool training) {
   }
 }
 
+void FullyConnectedLayer::incremental_forwarding(RunLayerContext &context,
+                                                 unsigned int from,
+                                                 unsigned int to,
+                                                 bool training) {
+  Tensor &weight = context.getWeight(weight_idx[FCParams::weight]);
+
+  Tensor &input_ = context.getInput(SINGLE_INOUT_IDX);
+  Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
+
+  TensorDim input_dim = input_.getDim();
+  TensorDim hidden_dim = hidden_.getDim();
+
+  TensorDim input_step_dim = input_dim;
+  TensorDim hidden_step_dim = hidden_dim;
+
+  if (from) {
+    NNTR_THROW_IF(to - from != 1, std::invalid_argument)
+      << "incremental step size is not 1";
+    from = 0;
+    to = 1;
+  }
+
+  input_step_dim.height(to - from);
+  hidden_step_dim.height(to - from);
+
+  // @todo: set reset stride as false. This implementation only works when batch
+  // size is 1
+  Tensor input_step = input_.getSharedDataTensor(input_step_dim, 0, true);
+  Tensor hidden_step = hidden_.getSharedDataTensor(hidden_step_dim, 0, true);
+
+  input_step.dot(weight, hidden_step, false, false);
+
+  if (auto &disable_bias = std::get<props::DisableBias>(*layer_impl_props);
+      disable_bias.empty() || disable_bias.get() == false) {
+    Tensor &bias = context.getWeight(weight_idx[FCParams::bias]);
+    hidden_step.add_i(bias);
+  }
+}
+
 void FullyConnectedLayer::calcDerivative(RunLayerContext &context) {
   Tensor &weight = context.getWeight(weight_idx[FCParams::weight]);
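
Note: this layer is where the PR's title applies. During prefill `to - from` equals the prompt length, so input_step.dot(weight, hidden_step) is a ((to - from) x in_dim) by (in_dim x out_dim) matrix-matrix product, i.e. a GEMM; decode steps degenerate to a single-row product. A naive reference of the shapes involved (a sketch, not the optimized kernel nntrainer dispatches to):

    // C(M x N) = A(M x K) * B(K x N); here M = to - from (prompt length
    // during prefill, 1 during decode), K = in_dim, N = out_dim.
    void naive_gemm(const float *A, const float *B, float *C,
                    int M, int K, int N) {
      for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n) {
          float acc = 0.0f;
          for (int k = 0; k < K; ++k)
            acc += A[m * K + k] * B[k * N + n];
          C[m * N + n] = acc;
        }
    }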
 
index 25736016d7c5905cacdd9646767d43b6febc03b7..1d3cc28b38c5d159af06c5d5ab5a12a5fbdbcdd0 100644 (file)
@@ -58,6 +58,13 @@ public:
    */
   void forwarding(RunLayerContext &context, bool training) override;
 
+  /**
+   * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
+   * int from, unsigned int to, bool training)
+   */
+  void incremental_forwarding(RunLayerContext &context, unsigned int from,
+                              unsigned int to, bool training) override;
+
   /**
    * @copydoc Layer::calcDerivative(RunLayerContext &context)
    */
index 1f13c51765bc23569134ee5f4249926377e05d84..c5668f922a3d57bb5a5f32b607ffe4fd05f98d73 100644 (file)
@@ -33,6 +33,7 @@ MultiHeadAttentionLayer::MultiHeadAttentionLayer() :
   epsilon(1e-3),
   cache_index(0) {
   weight_idx.fill(std::numeric_limits<unsigned>::max());
+  layer_progress = 0;
 }
 
 MultiHeadAttentionLayer::~MultiHeadAttentionLayer() {}
@@ -592,11 +593,283 @@ void MultiHeadAttentionLayer::forwarding(RunLayerContext &context,
     {batch_size, 1, query_height, num_heads * projected_value_dim_prop}));
 }
 
+void MultiHeadAttentionLayer::initial_incremental_forwarding(
+  RunLayerContext &context, unsigned int _from, unsigned int _to,
+  bool training) {
+  unsigned int max_timestep =
+    std::get<props::MaxTimestep>(multi_head_attention_props).get();
+
+  bool cache_shift = false;
+  unsigned int from = _from;
+  unsigned int to = _to;
+  if (to > max_timestep) {
+    throw std::invalid_argument("to shouldn't greater than max_timestep");
+  }
+
+  const bool disable_bias =
+    std::get<props::DisableBias>(*layer_impl_props).get();
+
+  const unsigned int num_heads =
+    std::get<props::NumHeads>(multi_head_attention_props).get();
+  const unsigned int projected_key_dim_prop =
+    std::get<props::ProjectedKeyDim>(multi_head_attention_props).get();
+  const unsigned int projected_value_dim_prop =
+    std::get<props::ProjectedValueDim>(multi_head_attention_props).get();
+  const float dropout_rate =
+    std::get<props::DropOutRate>(multi_head_attention_props).get();
+  const props::ReturnAttentionWeightInfo::Enum return_attention_weight =
+    std::get<props::ReturnAttentionWeight>(multi_head_attention_props).get();
+  const bool average_attention_weight =
+    std::get<props::AverageAttentionWeight>(multi_head_attention_props).get();
+
+  const bool provide_attention_mask = context.getNumInputs() == 4;
+  const unsigned int projected_query_dim_prop = projected_key_dim_prop;
+  const bool enable_dropout = dropout_rate > epsilon;
+
+  /** get inputs/outputs */
+  Tensor &query = context.getInput(INOUT_INDEX::QUERY);
+  Tensor &key = context.getInput(INOUT_INDEX::KEY);
+  Tensor &value = context.getInput(INOUT_INDEX::VALUE);
+
+  Tensor empty_tensor;
+
+  empty_tensor.setTensorType(value.getTensorType());
+
+  Tensor &mask =
+    provide_attention_mask ? context.getInput(INOUT_INDEX::MASK) : empty_tensor;
+
+  TensorDim query_dim = query.getDim();
+  TensorDim key_dim = key.getDim();
+  TensorDim value_dim = value.getDim();
+
+  TensorDim query_step_dim = query_dim;
+  TensorDim key_step_dim = key_dim;
+  TensorDim value_step_dim = value_dim;
+
+  query_step_dim.height(to);
+  key_step_dim.height(to);
+  value_step_dim.height(to);
+
+  Tensor query_step = query.getSharedDataTensor(query_step_dim, 0, true);
+  Tensor key_step = key.getSharedDataTensor(key_step_dim, 0, true);
+  Tensor value_step = value.getSharedDataTensor(value_step_dim, 0, true);
+
+  Tensor &output = context.getOutput(INOUT_INDEX::OUTPUT);
+
+  TensorDim output_dim = output.getDim();
+  TensorDim output_step_dim = output_dim;
+  output_step_dim.height(to);
+  Tensor output_step = output.getSharedDataTensor(output_step_dim, 0, true);
+
+  Tensor &ret_attention_weight =
+    return_attention_weight != props::ReturnAttentionWeightInfo::Enum::none
+      ? context.getOutput(INOUT_INDEX::RETURN_ATTENTION_WEIGHT)
+      : empty_tensor;
+
+  /** get weights */
+  Tensor &query_fc_weight =
+    context.getWeight(weight_idx[AttentionParams::query_fc_weight]);
+  Tensor &query_fc_bias =
+    disable_bias
+      ? empty_tensor
+      : context.getWeight(weight_idx[AttentionParams::query_fc_bias]);
+  Tensor &key_fc_weight =
+    context.getWeight(weight_idx[AttentionParams::key_fc_weight]);
+  Tensor &key_fc_bias =
+    disable_bias ? empty_tensor
+                 : context.getWeight(weight_idx[AttentionParams::key_fc_bias]);
+  Tensor &value_fc_weight =
+    context.getWeight(weight_idx[AttentionParams::value_fc_weight]);
+  Tensor &value_fc_bias =
+    disable_bias
+      ? empty_tensor
+      : context.getWeight(weight_idx[AttentionParams::value_fc_bias]);
+  Tensor &fc_weight = context.getWeight(weight_idx[AttentionParams::fc_weight]);
+  Tensor &fc_bias = disable_bias
+                      ? empty_tensor
+                      : context.getWeight(weight_idx[AttentionParams::fc_bias]);
+
+  /** get tensors */
+  Tensor &projected_query =
+    context.getTensor(weight_idx[AttentionParams::projected_query]);
+  Tensor &projected_key =
+    context.getTensor(weight_idx[AttentionParams::projected_key]);
+  Tensor &projected_value =
+    context.getTensor(weight_idx[AttentionParams::projected_value]);
+  Tensor &cache_key = context.getTensor(weight_idx[AttentionParams::cache_key]);
+  Tensor &cache_value =
+    context.getTensor(weight_idx[AttentionParams::cache_value]);
+
+  TensorDim projected_query_dim = projected_query.getDim();
+  TensorDim projected_key_dim = projected_key.getDim();
+  TensorDim projected_value_dim = projected_value.getDim();
+  TensorDim cache_key_dim = cache_key.getDim();
+  TensorDim cache_value_dim = cache_value.getDim();
+
+  TensorDim projected_query_step_dim = projected_query_dim;
+
+  TensorDim projected_key_step_dim = projected_key_dim;
+  TensorDim projected_value_step_dim = projected_value_dim;
+  TensorDim cache_key_step_dim = cache_key_dim;
+  TensorDim cache_value_step_dim = cache_value_dim;
+  projected_query_step_dim.height(to);
+
+  projected_key_step_dim.height(to);
+  projected_value_step_dim.height(to);
+  cache_key_step_dim.height(to);
+  cache_value_step_dim.height(to);
+
+  Tensor projected_query_step =
+    projected_query.getSharedDataTensor(projected_query_step_dim, 0, true);
+  Tensor projected_key_step =
+    projected_key.getSharedDataTensor(projected_key_step_dim, 0, true);
+  Tensor projected_value_step =
+    projected_value.getSharedDataTensor(projected_value_step_dim, 0, true);
+
+  Tensor cache_key_step =
+    cache_key.getSharedDataTensor(cache_key_step_dim, 0, true);
+  Tensor cache_value_step =
+    cache_value.getSharedDataTensor(cache_value_step_dim, 0, true);
+
+  TensorDim cached_key_dim = {cache_key_dim.batch(), cache_key_dim.channel(),
+                              to, cache_key_dim.width(),
+                              cache_key.getTensorType()};
+  TensorDim cached_value_dim = {
+    cache_value_dim.batch(), cache_value_dim.channel(), to,
+    cache_value_dim.width(), cache_value.getTensorType()};
+  Tensor cached_key = cache_key.getSharedDataTensor(cached_key_dim, 0, true);
+  Tensor cached_value =
+    cache_value.getSharedDataTensor(cached_value_dim, 0, true);
+
+  Tensor &attention_weight =
+    context.getTensor(weight_idx[AttentionParams::attention_weight]);
+  Tensor &attention_output =
+    context.getTensor(weight_idx[AttentionParams::attention_output]);
+  TensorDim attention_weight_dim = attention_weight.getDim();
+
+  TensorDim attention_weight_step_dim = attention_weight_dim;
+  attention_weight_step_dim.height(to);
+  attention_weight_step_dim.width(to);
+
+  Tensor attention_weight_step =
+    attention_weight.getSharedDataTensor(attention_weight_step_dim, 0, true);
+
+  TensorDim attention_output_dim = attention_output.getDim();
+  TensorDim attention_output_step_dim = attention_output_dim;
+  attention_output_step_dim.height(to);
+
+  Tensor attention_output_step =
+    attention_output.getSharedDataTensor(attention_output_step_dim, 0, true);
+
+  const unsigned int batch_size = query_dim.batch();
+  const unsigned int query_height = query_dim.height();
+  const unsigned int key_height = key_dim.height();
+  const unsigned int value_height = value_dim.height();
+
+  query_step.dot(query_fc_weight, projected_query_step);
+  if (!disable_bias) {
+    projected_query_step.add_i(query_fc_bias);
+  }
+  key_step.dot(key_fc_weight, cache_key_step);
+  if (!disable_bias) {
+    cache_key_step.add_i(key_fc_bias);
+  }
+  value_step.dot(value_fc_weight, cache_value_step);
+  if (!disable_bias) {
+    cache_value_step.add_i(value_fc_bias);
+  }
+
+  apply_rotary_emb_tensor(projected_query_step, projected_query_dim_prop, from);
+  apply_rotary_emb_tensor(cache_key_step, projected_key_dim_prop, from);
+
+  projected_query_step.reshape(
+    TensorDim({batch_size, to, num_heads, projected_query_dim_prop}));
+  cached_key.reshape(
+    TensorDim({batch_size, to, num_heads, projected_key_dim_prop}));
+  cached_value.reshape(
+    TensorDim({batch_size, to, num_heads, projected_value_dim_prop}));
+
+  projected_query_step.transpose("1:0:2", projected_query_step);
+  cached_key.transpose("1:0:2", projected_key_step);
+  cached_value.transpose("1:0:2", projected_value_step);
+
+  projected_query_step.reshape(
+    TensorDim({batch_size * num_heads, 1, to, projected_query_dim_prop}));
+  projected_key_step.reshape(
+    TensorDim({batch_size * num_heads, 1, to, projected_key_dim_prop}));
+  projected_value_step.reshape(
+    TensorDim({batch_size * num_heads, 1, to, projected_value_dim_prop}));
+
+  attention_weight_step.reshape(TensorDim({batch_size * num_heads, 1, to, to}));
+  attention_output_step.reshape(
+    TensorDim({batch_size * num_heads, 1, to, projected_value_dim_prop}));
+
+  /** scaled dot product attention */
+  projected_query_step.dotBatched(projected_key_step, attention_weight_step,
+                                  false, true);
+  attention_weight_step.multiply_i(1 / sqrt((float)projected_query_dim_prop));
+
+  if (!from) {
+    unsigned int mask_size = attention_weight_step.getDim().width();
+    unsigned int mask_dim_height = mask_size;
+    unsigned int mask_dim_width = mask_size;
+
+    Tensor causal_mask(TensorDim{1, 1, mask_size, mask_size,
+                                 attention_weight_step.getTensorType()});
+
+    causal_mask.setZero();
+
+#ifdef ENABLE_FP16
+#define _MASK_NUM -1e4
+#else
+#define _MASK_NUM -1e10
+#endif
+
+    for (unsigned int i = 0; i < mask_dim_height; ++i) {
+      for (unsigned int j = i + 1; j < mask_dim_width; ++j) {
+        causal_mask.setValue(0, 0, i, j, _MASK_NUM);
+      }
+    }
+
+    attention_weight_step.add_i(causal_mask);
+  }
+
+  sm.run_fn(attention_weight_step, attention_weight_step);
+
+  attention_weight_step.dotBatched(projected_value_step, attention_output_step);
+
+  attention_output_step.reshape(
+    TensorDim({batch_size, num_heads, to, projected_value_dim_prop}));
+
+  attention_output_step = attention_output_step.transpose("1:0:2");
+
+  attention_output_step.reshape(
+    TensorDim({batch_size * to, 1, 1, num_heads * projected_value_dim_prop}));
+
+  attention_output_step.dot(fc_weight, output_step);
+  if (!disable_bias) {
+    output_step.add_i(fc_bias);
+  }
+
+  if (layer_progress == 28)
+    layer_progress = 0;
+  layer_progress++;
+
+  std::cout << "Process Reading: " << (int)((layer_progress / 28.0) * 100.0)
+            << " % \r";
+  std::cout.flush();
+}
+
 void MultiHeadAttentionLayer::incremental_forwarding(RunLayerContext &context,
                                                      unsigned int _from,
                                                      unsigned int _to,
                                                      bool training) {
 
+  if (!_from) {
+    initial_incremental_forwarding(context, _from, _to, training);
+    return;
+  }
+
   unsigned int max_timestep =
     std::get<props::MaxTimestep>(multi_head_attention_props).get();
 
@@ -645,9 +918,26 @@ void MultiHeadAttentionLayer::incremental_forwarding(RunLayerContext &context,
   TensorDim key_dim = key.getDim();
   TensorDim value_dim = value.getDim();
 
+  TensorDim query_step_dim = query_dim;
+  TensorDim key_step_dim = key_dim;
+  TensorDim value_step_dim = value_dim;
+
+  query_step_dim.height(to - from);
+  key_step_dim.height(to - from);
+  value_step_dim.height(to - from);
+
+  Tensor query_step = query.getSharedDataTensor(query_step_dim, 0, true);
+  Tensor key_step = key.getSharedDataTensor(key_step_dim, 0, true);
+  Tensor value_step = value.getSharedDataTensor(value_step_dim, 0, true);
+
   Tensor &output = context.getOutput(INOUT_INDEX::OUTPUT);
 
   TensorDim output_dim = output.getDim();
+
+  TensorDim output_step_dim = output_dim;
+  output_step_dim.height(to - from);
+  Tensor output_step = output.getSharedDataTensor(output_step_dim, 0, true);
+
   Tensor &ret_attention_weight =
     return_attention_weight != props::ReturnAttentionWeightInfo::Enum::none
       ? context.getOutput(INOUT_INDEX::RETURN_ATTENTION_WEIGHT)
@@ -753,15 +1043,16 @@ void MultiHeadAttentionLayer::incremental_forwarding(RunLayerContext &context,
   const unsigned int key_height = key_dim.height();
   const unsigned int value_height = value_dim.height();
 
-  query.dot(query_fc_weight, projected_query_step);
+  query_step.dot(query_fc_weight, projected_query_step);
+
   if (!disable_bias) {
     projected_query_step.add_i(query_fc_bias);
   }
-  key.dot(key_fc_weight, cache_key_step);
+  key_step.dot(key_fc_weight, cache_key_step);
   if (!disable_bias) {
     cache_key_step.add_i(key_fc_bias);
   }
-  value.dot(value_fc_weight, cache_value_step);
+  value_step.dot(value_fc_weight, cache_value_step);
   if (!disable_bias) {
     cache_value_step.add_i(value_fc_bias);
   }
@@ -833,9 +1124,9 @@ void MultiHeadAttentionLayer::incremental_forwarding(RunLayerContext &context,
   attention_output_step.reshape(TensorDim(
     {batch_size * (to - from), 1, 1, num_heads * projected_value_dim_prop}));
 
-  attention_output_step.dot(fc_weight, output);
+  attention_output_step.dot(fc_weight, output_step);
   if (!disable_bias) {
-    output.add_i(fc_bias);
+    output_step.add_i(fc_bias);
   }
 
   if (cache_shift) {
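
Note: in the prefill pass (`from == 0`) all prompt positions attend at once, so the `to x to` score matrix needs an explicit causal mask, which the one-token-at-a-time path got for free. A standalone sketch of that additive mask, using the same fill values as the _MASK_NUM macro above:

    #include <vector>

    // Strictly upper-triangular additive mask: position i must not see j > i.
    // mask_num is a large negative value (-1e4 within FP16 range, -1e10 for
    // FP32) so softmax drives masked scores to ~0.
    std::vector<float> causal_mask(unsigned int T, float mask_num) {
      std::vector<float> m(T * T, 0.0f);
      for (unsigned int i = 0; i < T; ++i)
        for (unsigned int j = i + 1; j < T; ++j)
          m[i * T + j] = mask_num;
      return m;
    }
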
index 27b98cc0c710d6bea16244a9103416baf3311e93..0619b372be624a98ac9aee8bf7d53d0f3b795e4c 100644 (file)
@@ -64,6 +64,14 @@ public:
    */
   void forwarding(RunLayerContext &context, bool training) override;
 
+  /**
+   * @brief Initial (prefill) pass of incremental forwarding; same arguments
+   * as Layer::incremental_forwarding(RunLayerContext &context, unsigned int
+   * from, unsigned int to, bool training)
+   */
+  void initial_incremental_forwarding(RunLayerContext &context,
+                                      unsigned int from, unsigned int to,
+                                      bool training);
+
   /**
    * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
    * int from, unsigned int to, bool training)
@@ -130,9 +138,11 @@ private:
 
   unsigned int cache_index;
 
-  inline static std::vector<std::vector<float>> *freqs_cos = {};
+  inline static unsigned int layer_progress;
 
+  inline static std::vector<std::vector<float>> *freqs_cos = {};
   inline static std::vector<std::vector<float>> *freqs_sin = {};
+  inline static std::vector<float> freqs;
 
   /**
    * @brief     compute frequency for rotary embedding
@@ -143,9 +153,9 @@ private:
   void precompute_freqs(int dim, unsigned int seq_len, float theta = 10000.0) {
     if (freqs_cos == nullptr) {
       unsigned int half_ = dim / 2;
-      std::vector<float> freqs(half_);
       for (unsigned int i = 0; i < half_; ++i) {
-        freqs[i] = 1.0 / (std::pow(theta, (2 * i) / static_cast<float>(dim)));
+        freqs.push_back(1.0 /
+                        (std::pow(theta, (2 * i) / static_cast<float>(dim))));
       }
 
       auto cos = new std::vector<std::vector<float>>();
@@ -182,11 +192,36 @@ private:
     float value = 0;
     float transformed_value = 0.0;
     unsigned int half_ = dim / 2;
+    unsigned int max_timestep =
+      std::get<props::MaxTimestep>(multi_head_attention_props).get();
+
+    std::vector<float> *cos_;
+    std::vector<float> *sin_;
+
+    if (from >= max_timestep) {
+      std::cout << from << " " << max_timestep << std::endl;
+      cos_ = new std::vector<float>(dim);
+      sin_ = new std::vector<float>(dim);
+
+      for (unsigned int i = 0; i < half_; ++i) {
+        float angle = from * freqs[i];
+        (*cos_)[i] = std::cos(angle);
+        (*cos_)[i + half_] = std::cos(angle); // repeated 2 times
+
+        (*sin_)[i] = std::sin(angle);
+        (*sin_)[i + half_] = std::sin(angle); // repeated 2 times
+      }
+    }
 
     if (in.getDataType() == ml::train::TensorDim::DataType::FP32) {
       for (unsigned int b = 0; b < in.batch(); b++) {
         for (unsigned int c = 0; c < in.channel(); c++) {
           for (unsigned int h = 0; h < in.height(); h++) {
+            if (from < max_timestep) {
+              cos_ = &(*freqs_cos)[from + h];
+              sin_ = &(*freqs_sin)[from + h];
+            }
+
             for (unsigned int w = 0; w < in.width(); w = w + dim) {
               for (unsigned int k = 0; k < dim; k++) {
                 unsigned int span = w + k;
@@ -212,6 +247,11 @@ private:
       for (unsigned int b = 0; b < in.batch(); b++) {
         for (unsigned int c = 0; c < in.channel(); c++) {
           for (unsigned int h = 0; h < in.height(); h++) {
+            if (from < max_timestep) {
+              cos_ = &(*freqs_cos)[from + h];
+              sin_ = &(*freqs_sin)[from + h];
+            }
+
             for (unsigned int w = 0; w < in.width(); w = w + dim) {
               for (unsigned int k = 0; k < dim; k++) {
 #ifdef ENABLE_FP16
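
Note: the cos/sin tables above ("repeated 2 times") follow the half-split rotary-embedding layout. A sketch of the rotation they feed, for a single vector at position `pos` (an assumption-level reference; the in-tree apply_rotary_emb_tensor() operates on tensors and may differ in layout details):

    #include <cmath>
    #include <vector>

    // Rotate pairs (x[i], x[i + dim/2]) by angle pos / theta^(2i/dim).
    void apply_rope(std::vector<float> &x, unsigned int pos,
                    float theta = 10000.0f) {
      const unsigned int half = x.size() / 2;
      for (unsigned int i = 0; i < half; ++i) {
        const float freq = 1.0f / std::pow(theta, (2.0f * i) / x.size());
        const float angle = pos * freq;
        const float c = std::cos(angle), s = std::sin(angle);
        const float x0 = x[i], x1 = x[i + half];
        x[i] = x0 * c - x1 * s;
        x[i + half] = x1 * c + x0 * s;
      }
    }
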
index 07ea7cbad0f93a3a39c81dd1cf0b1347f9b06138..a9c714e5497ba129fb6c501333339dccd52b47da 100644 (file)
@@ -43,12 +43,18 @@ void MultiOutLayer::incremental_forwarding(RunLayerContext &context,
                                            unsigned int from, unsigned int to,
                                            bool training) {
   if (!context.executeInPlace()) {
+    if (from) {
+      NNTR_THROW_IF(to - from != 1, std::invalid_argument)
+        << "incremental step size is not 1";
+      from = 0;
+      to = 1;
+    }
+
     const Tensor &input_ = context.getInput(SINGLE_INOUT_IDX);
     TensorDim input_dim = input_.getDim();
     TensorDim input_step_dim = {input_dim.batch(), input_dim.channel(),
                                 to - from, input_dim.width()};
-    Tensor input_step = input_.getSharedDataTensor(
-      input_step_dim, from * input_dim.width(), true);
+    Tensor input_step = input_.getSharedDataTensor(input_step_dim, 0, true);
 
     for (unsigned int idx = 0; idx < context.getNumOutputs(); ++idx) {
       Tensor &output = context.getOutput(idx);
@@ -58,8 +64,7 @@ void MultiOutLayer::incremental_forwarding(RunLayerContext &context,
                                    to - from, output_dim.width()};
       // @todo: set reset stride as false. This implementation only works when
       // batch size is 1
-      Tensor output_step = output.getSharedDataTensor(
-        output_step_dim, from * output_dim.width(), true);
+      Tensor output_step = output.getSharedDataTensor(output_step_dim, 0, true);
       output_step.fill(input_step);
     }
   }
index 108c1e79e361a60ffaad6976d3d49ecb22ee49e1..59ea78a05caf953b5f03cfd229dbac5eb8ca55fd 100644 (file)
@@ -766,14 +766,16 @@ NeuralNetwork::inference(unsigned int batch_size,
   return output;
 }
 
-sharedConstTensors NeuralNetwork::incremental_inference(
-  sharedConstTensors X, unsigned int init_seq_len, unsigned int cur_step) {
-  return incremental_inference(X, {}, init_seq_len, cur_step);
+sharedConstTensors
+NeuralNetwork::incremental_inference(sharedConstTensors X,
+                                     unsigned int init_seq_len,
+                                     unsigned int from, unsigned int to) {
+  return incremental_inference(X, {}, init_seq_len, from, to);
 }
 
 sharedConstTensors NeuralNetwork::incremental_inference(
   sharedConstTensors X, sharedConstTensors label, unsigned int init_seq_len,
-  unsigned int cur_step) {
+  unsigned int from, unsigned int to) {
   if (model_graph.getBatchSize() != X[0]->batch()) {
     model_graph.setBatchSize(X[0]->batch());
   }
@@ -782,7 +784,7 @@ sharedConstTensors NeuralNetwork::incremental_inference(
   if (!validateInput(X))
     throw std::invalid_argument("Input validation failed.");
 
-  if (cur_step == 0) {
+  if (from == 0) {
     allocate(ExecutionMode::INFERENCE);
   }
 
@@ -790,7 +792,7 @@ sharedConstTensors NeuralNetwork::incremental_inference(
   PROFILE_TIME_REGISTER_EVENT(nn_foward, "nn_forward");
   PROFILE_TIME_START(nn_foward);
 
-  out = incremental_forwarding(cur_step, cur_step + 1, X, label, false);
+  out = incremental_forwarding(from, to, X, label, false);
 
   PROFILE_TIME_END(nn_foward);
 
@@ -804,7 +806,7 @@ sharedConstTensors NeuralNetwork::incremental_inference(
 std::vector<float *> NeuralNetwork::incremental_inference(
   unsigned int batch_size, const std::vector<float *> &input,
   const std::vector<float *> &label, unsigned int init_seq_len,
-  unsigned int cur_step) {
+  unsigned int from, unsigned int to) {
   sharedConstTensors input_tensors, output_tensors;
   auto in_dim = getInputDimension();
 
@@ -826,26 +828,27 @@ std::vector<float *> NeuralNetwork::incremental_inference(
                     label_dim[idx], 0)));
     }
     output_tensors = incremental_inference(input_tensors, label_tensors,
-                                           init_seq_len, cur_step);
+                                           init_seq_len, from, to);
   } else {
     output_tensors =
-      incremental_inference(input_tensors, init_seq_len, cur_step);
+      incremental_inference(input_tensors, init_seq_len, from, to);
   }
 
   std::vector<float *> output;
-  output.reserve(output_tensors.size());
 
   unsigned int idx = 0;
+  if (!from) {
+    idx = to - 1;
+  }
   for (auto &out : output_tensors) {
     if (out->getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
       auto out_t = *out.get();
-      _FP16 *vec_fp16 = out_t.getData<_FP16>();
-      float *vec_fp32 = new float[out_t.size()]();
-      output.push_back(vec_fp32);
-      for (unsigned int i = 0; i < out_t.size(); ++i) {
-        output[idx][i] = static_cast<float>(vec_fp16[i]);
+      float *vec_fp32 = new float[out_t.width()];
+      for (unsigned int i = 0; i < out_t.width(); ++i) {
+        (vec_fp32)[i] = static_cast<float>(out_t.getValue<_FP16>(0, 0, idx, i));
       }
+      output.emplace_back(vec_fp32);
 #else
       throw std::invalid_argument("Errro: enable-fp16 is not set");
 #endif
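
Note: with the interval API, only one row of logits matters per call: row `to - 1` after prefill, row 0 during decode (the `idx` logic above). A sketch of the FP16-to-FP32 row copy this implements (illustrative; `_FP16` is assumed to be nntrainer's half-precision typedef, e.g. __fp16 on ARM builds):

    // Convert one row of a row-major (height x width) half buffer to float.
    void row_to_fp32(const _FP16 *data, unsigned int row, unsigned int width,
                     float *out) {
      for (unsigned int i = 0; i < width; ++i)
        out[i] = static_cast<float>(data[row * width + i]);
    }
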
index 18f6ae2f336a96b7e01aeccd9eb7387f046829f9..457b7d1e97e7d301ca0be6e9ba35ff3ab6190beb 100644 (file)
@@ -376,7 +376,7 @@ s   * @retval shared_ptr<const Tensor>
    */
   sharedConstTensors incremental_inference(sharedConstTensors X,
                                            unsigned int init_seq_len,
-                                           unsigned int step);
+                                           unsigned int from, unsigned int to);
 
   /**
    * @brief     Run NeuralNetwork incremental inference
@@ -389,7 +389,7 @@ s   * @retval shared_ptr<const Tensor>
   sharedConstTensors incremental_inference(sharedConstTensors X,
                                            sharedConstTensors label,
                                            unsigned int init_seq_len,
-                                           unsigned int step);
+                                           unsigned int from, unsigned int to);
 
   /**
    * @brief     Run the incremental inference of the model
@@ -405,7 +405,8 @@ s   * @retval shared_ptr<const Tensor>
                                              const std::vector<float *> &input,
                                              const std::vector<float *> &label,
                                              unsigned int init_seq_len,
-                                             unsigned int step) override;
+                                             unsigned int from,
+                                             unsigned int to) override;
 
   /**
    * @brief     Run NeuralNetwork train with callback function by user