This PR uses GEMM to compute the initial (prompt) sequence in a single pass, instead of running incremental inference one token at a time.
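
In short, `incremental_inference` now takes a `(from, to)` step range, so the whole prompt is processed in one call (the per-layer matrix multiplications become GEMMs over the full initial sequence) and only the generated tokens are run step by step. A minimal sketch of the call pattern, reusing identifiers from this diff (not a standalone program):

```cpp
// Prefill: run the whole prompt [0, input_len) in a single call, so the
// per-layer matrix multiplications cover the full initial sequence at once.
auto output = g_model->incremental_inference(
    /*batch=*/1, input, label, MAX_SEQ_LEN, /*from=*/0, /*to=*/input_len);

// Decode: generate one token per step, sliding the [from, to) window by one.
for (unsigned int i = input_len + 1; i < input_len + NUM_TO_GENERATE; ++i) {
  auto step_output = g_model->incremental_inference(
      /*batch=*/1, input, label, MAX_SEQ_LEN, /*from=*/i - 1, /*to=*/i);
  // take the argmax over the NUM_VOCAB logits and feed it back as the next input
}
```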
**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped
Signed-off-by: jijoong.moon <jijoong.moon@samsung.com>
#include <chrono>
#include <ctime>
#include <iostream>
+#include <iterator>
#include <memory>
#include <sstream>
#include <string>
using LayerHandle = std::shared_ptr<ml::train::Layer>;
using ModelHandle = std::unique_ptr<ml::train::Model>;
+ModelHandle g_model;
+
// Hyper params for LLaMA
int const DIM = 2304;
int const NUM_LAYERS = 28;
}
}
+/**
+ * @brief Create Attention Layer for the separate implementation
+ */
std::vector<LayerHandle> createAttentionLayer(const int layer_id, int seq_len,
int n_heads, int head_dim,
std::string query_name,
return layers;
}
+/**
+ * @brief Create FF Layers
+ */
std::vector<LayerHandle> createFeedForwardLayer(const int layer_id, int dim,
int hidden_dim,
std::string input_name,
return layers;
}
+/**
+ * @brief Create Decoder
+ */
std::vector<LayerHandle> createTransformerDecoder(const int layer_id,
std::string input_name) {
using ml::train::createLayer;
return layers;
}
+/**
+ * @brief Create LLaMA2 Model
+ */
ModelHandle createLLaMA() {
using ml::train::createLayer;
withKey("input_shape", "1:1:" + std::to_string(INIT_SEQ_LEN))}));
} else {
layers.push_back(createLayer(
- "input", {withKey("name", "input0"), withKey("input_shape", "1:1:1")}));
+ "input",
+ {withKey("name", "input0"),
+ withKey("input_shape", "1:1:" + std::to_string(INIT_SEQ_LEN))}));
}
layers.push_back(ml::train::layer::Embedding(
return model;
}
-void createAndRun(unsigned int epochs, unsigned int batch_size,
- std::wstring text) {
- // setup model
- ModelHandle model = createLLaMA();
- model->setProperty({withKey("batch_size", batch_size),
- withKey("epochs", epochs),
- // #ifdef ENABLE_FP16
- // withKey("model_tensor_type", "FP16-FP16"),
- // #endif
- withKey("save_path", "test_model.bin")});
-
- auto optimizer = ml::train::createOptimizer("sgd", {"learning_rate=0.001"});
- model->setOptimizer(std::move(optimizer));
-
- int status = model->compile();
- if (status) {
- throw std::invalid_argument("model compilation failed!");
- }
-
- status = model->initialize();
- if (status) {
- throw std::invalid_argument("model initialization failed!");
- }
-
- std::string weight_path = "./llama_fp16.bin";
-
- model->load(weight_path);
+/**
+ * @brief Run inference for a given text sequence
+ */
+void run(std::string text) {
std::vector<float *> input;
std::vector<float *> label;
float *input_sample = (float *)malloc(sizeof(float) * data_size);
+ unsigned int input_len = INIT_SEQ_LEN;
+
+ unsigned int init_len;
+
#if defined(ENABLE_ENCODER2)
std::string vocab_file_name = "../Applications/LLaMA/jni/vocab.json";
std::string merge_file_name = "../Applications/LLaMA/jni/merges.txt";
auto tokenizer = unwrap(GPT2Encoder::load(vocab_file_name, merge_file_name),
- "Error initialising GPT2 tokenizer\n");
+ "Error initializising GPT2 tokenizer\n");
+
+ std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
+
+ auto init_input = tokenizer.encode(converter.from_bytes(text));
+ init_len = init_input.size();
+
+ input_len = (init_len > INIT_SEQ_LEN) ? INIT_SEQ_LEN : init_len;
+
+ for (unsigned int i = 0; i < input_len; ++i) {
+ input_sample[i] = static_cast<float>(init_input[i]);
+ }
- auto init_input = tokenizer.encode(text);
- INIT_SEQ_LEN = init_input.size();
- ((uint *)(input_sample))[0] = init_input[0];
input.push_back(input_sample);
#else
- float init_data[INIT_SEQ_LEN] = {
+ float init_input[INIT_SEQ_LEN] = {
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 40,
50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900};
- ((uint *)(input_sample))[0] = init_data[0];
+ ((uint *)(input_sample))[0] = init_input[0];
input.push_back(input_sample);
+ init_len = 18;
#endif
- for (unsigned int i = 1; i < INIT_SEQ_LEN + NUM_TO_GENERATE; ++i) {
- auto output =
- model->incremental_inference(1, input, label, INIT_SEQ_LEN, i - 1);
+ std::vector<int64_t> token_ids;
+
+ auto output =
+ g_model->incremental_inference(1, input, label, MAX_SEQ_LEN, 0, input_len);
- std::vector<int64_t> tokens;
- nntrainer::Tensor output_tensor({batch_size, 1, 1, NUM_VOCAB}, output[0]);
+ unsigned int ids = std::distance(
+ output[0], std::max_element(output[0], output[0] + NUM_VOCAB));
- tokens.push_back(static_cast<int64_t>(output_tensor.argmax()[0]));
+ input_sample[0] = static_cast<float>(ids);
+
+#ifdef ENABLE_FP16
+ for (auto o : output) {
+ delete[] o;
+ }
+#endif
+ std::cout << " Progress Reading: 100 % " << std::endl;
+ std::cout << std::endl << "### Output : " << std::endl;
+ if (init_len < INIT_SEQ_LEN) {
#if defined(ENABLE_ENCODER2)
- auto decoded_str = tokenizer.decode(tokens);
- std::cerr << decoded_str << std::flush;
+ auto decoded_str = tokenizer.decode({static_cast<int64_t>(ids)});
+ std::cout << decoded_str << " ";
+ std::cout.flush();
#endif
+ }
+
+ for (unsigned int i = input_len + 1; i < input_len + NUM_TO_GENERATE; ++i) {
+ auto output_interval =
+ g_model->incremental_inference(1, input, label, MAX_SEQ_LEN, i - 1, i);
- if (i < INIT_SEQ_LEN) {
+ ids = std::distance(
+ output_interval[0],
+ std::max_element(output_interval[0], output_interval[0] + NUM_VOCAB));
+
+ if (i < input_len) {
+ input_sample[0] = static_cast<float>(init_input[i]);
+ } else {
+ input_sample[0] = static_cast<float>(ids);
#if defined(ENABLE_ENCODER2)
- ((uint *)(input_sample))[0] = init_input[i];
-#else
- ((uint *)(input_sample))[0] = init_data[i];
+ auto decoded_str = tokenizer.decode({static_cast<int64_t>(ids)});
+ std::cout << decoded_str << " ";
+ std::cout.flush();
#endif
- } else {
- ((uint *)(input_sample))[0] = output_tensor.argmax()[0];
}
+
+#ifdef ENABLE_FP16
+ for (auto o : output_interval) {
+ delete[] o;
+ }
+#endif
}
+
+ std::cout << std::endl;
+ free(input_sample);
+}
+
+/**
+ * @brief Create and set up the model
+ */
+void createAndRun(unsigned int epochs, unsigned int batch_size) {
+ // setup model
+ g_model = createLLaMA();
+ g_model->setProperty({withKey("batch_size", batch_size),
+ withKey("epochs", epochs),
+ // #ifdef ENABLE_FP16
+ withKey("model_tensor_type", "FP16-FP16"),
+ // #endif
+ withKey("save_path", "test_model.bin")});
+
+ auto optimizer = ml::train::createOptimizer("sgd", {"learning_rate=0.001"});
+ g_model->setOptimizer(std::move(optimizer));
+
+ int status = g_model->compile();
+ if (status) {
+ throw std::invalid_argument("model compilation failed!");
+ }
+
+ status = g_model->initialize();
+ if (status) {
+ throw std::invalid_argument("model initialization failed!");
+ }
+
+ std::string weight_path = "./llama_fp16.bin";
+
+ g_model->load(weight_path);
}
#if defined(ENABLE_ENCODER2)
return result.str();
}
#endif
-
int main(int argc, char *argv[]) {
// Setting locale
std::locale::global(std::locale("ko_KR.UTF-8"));
#if defined(ENABLE_ENCODER2)
-
// Getting arguments From terminal
std::wstring input;
std::getline(std::wcin, input);
std::wstring test = decodeUnicodeEscape(input);
std::wstring_convert<std::codecvt_utf16<wchar_t>> converter;
std::string text = converter.to_bytes(test);
- std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
#else
- std::wstring text = L"This is sample input for LLaMA.";
+ std::string text = "This is smaple input for LLaMA.";
#endif
auto &app_context = nntrainer::AppContext::Global();
return 1;
}
- createAndRun(epoch, batch_size, text);
+ try {
+ const std::vector<std::string> args(argv + 1, argv + argc);
+
+ createAndRun(epoch, batch_size);
+
+ run(text);
+ } catch (const std::exception &e) {
+ std::cerr << "uncaught error while running! details: " << e.what()
+ << std::endl;
+ return EXIT_FAILURE;
+ }
int status = EXIT_SUCCESS;
return status;
)
if get_option('platform') != 'tizen'
- extra_defines += '-DENABLE_ENCODER2=1'
- run_command(meson.source_root() / 'jni' / 'prepare_encoder.sh', meson.build_root(), '0.2' ,check: true)
+ extra_defines += '-DENABLE_ENCODER2=1'
+ run_command(meson.source_root() / 'jni' / 'prepare_encoder.sh', meson.build_root(), '0.2' ,check: true)
endif
rms_norm_src = files('rms_norm.cpp')
}
}
+void RMSNormLayer::incremental_forwarding(nntrainer::RunLayerContext &context,
+ unsigned int from, unsigned int to,
+ bool training) {
+ nntrainer::Tensor &in = context.getInput(SINGLE_INOUT_IDX);
+ nntrainer::Tensor &out = context.getOutput(SINGLE_INOUT_IDX);
+ nntrainer::Tensor &gamma = context.getWeight(wt_idx[RMSParams::gamma]);
+ ml::train::TensorDim in_dim = in.getDim();
+ ml::train::TensorDim out_dim = out.getDim();
+
+ ml::train::TensorDim in_step_dim = in_dim;
+ ml::train::TensorDim out_step_dim = out_dim;
+
+ if (from) {
+ NNTR_THROW_IF(to - from != 1, std::invalid_argument)
+ << "incremental step size is not 1";
+ from = 0;
+ to = 1;
+ }
+
+ in_step_dim.height(to - from);
+ out_step_dim.height(to - from);
+
+ nntrainer::Tensor in_step = in.getSharedDataTensor(in_step_dim, 0, true);
+ nntrainer::Tensor out_step = out.getSharedDataTensor(out_step_dim, 0, true);
+
+ auto &epsilon = std::get<nntrainer::props::Epsilon>(rms_props).get();
+
+ if (in_step.getDataType() == ml::train::TensorDim::DataType::FP32) {
+ std::function<float(float)> f = [](float x) { return 1 / std::sqrt(x); };
+ auto t = in_step.multiply(in_step).average(3).add(epsilon);
+ t.apply_i(f);
+ in_step.multiply(t, out_step);
+ out_step.multiply_i(gamma);
+
+ } else if (in_step.getDataType() == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+ ml::train::TensorDim d = in_step.getDim();
+ d.width(1);
+ nntrainer::Tensor t(d, true);
+
+ unsigned int axis_dim = in_step.getDim()[3];
+ for (unsigned int i = from; i < to; ++i) {
+ float sum = 0.0;
+ _FP16 *data = in_step.getAddress<_FP16>(0, 0, i, 0);
+ for (unsigned int j = 0; j < axis_dim; ++j) {
+ sum += powf(static_cast<float>(data[j]), 2.0f);
+ }
+ t.setValue(0, 0, i, 0,
+ static_cast<_FP16>(1.0 / sqrt(sum / axis_dim + epsilon)));
+ }
+ in_step.multiply(t, out_step);
+ out_step.multiply_i(gamma);
+
+#else
+ throw std::invalid_argument("Error: enable-fp16 is not set");
+#endif
+ }
+}
+
void RMSNormLayer::calcDerivative(nntrainer::RunLayerContext &context) {
// std::throw_with_nested(std::runtime_error("Training is not supported
// yet."));
*/
void forwarding(nntrainer::RunLayerContext &context, bool training) override;
+ /**
+ * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
+ * int from, unsigned int to, bool training)
+ */
+ void incremental_forwarding(nntrainer::RunLayerContext &context,
+ unsigned int from, unsigned int to,
+ bool training) override;
+
/**
* @copydoc Layer::calcDerivative(RunLayerContext &context)
*/
}
}
+void SwiGLULayer::incremental_forwarding(nntrainer::RunLayerContext &context,
+ unsigned int from, unsigned int to,
+ bool training) {
+ nntrainer::Tensor &in1 = context.getInput(INPUT_IDX_1);
+ nntrainer::Tensor &in2 = context.getInput(INPUT_IDX_2);
+ nntrainer::Tensor &out = context.getOutput(OUT_IDX);
+
+ if (from) {
+ NNTR_THROW_IF(to - from != 1, std::invalid_argument)
+ << "incremental step size is not 1";
+ from = 0;
+ to = 1;
+ }
+
+ if (in1.getDataType() == ml::train::TensorDim::DataType::FP32) {
+ for (unsigned int b = 0; b < in1.batch(); b++) {
+ for (unsigned int c = 0; c < in1.channel(); c++) {
+ for (unsigned int h = from; h < to; h++) {
+ for (unsigned int w = 0; w < in1.width(); w++) {
+ out.setValue(b, c, h, w,
+ ActivationOp::swish(in1.getValue<float>(b, c, h, w)) *
+ in2.getValue<float>(b, c, h, w));
+ }
+ }
+ }
+ }
+ } else if (in1.getDataType() == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+ for (unsigned int b = 0; b < in1.batch(); b++) {
+ for (unsigned int c = 0; c < in1.channel(); c++) {
+ for (unsigned int h = from; h < to; h++) {
+ for (unsigned int w = 0; w < in1.width(); w++) {
+ out.setValue(
+ b, c, h, w,
+ static_cast<_FP16>(
+ ActivationOp::swish(
+ static_cast<float>(in1.getValue<_FP16>(b, c, h, w))) *
+ static_cast<float>(in2.getValue<_FP16>(b, c, h, w))));
+ }
+ }
+ }
+ }
+#else
+ NNTR_THROW_IF(true, std::invalid_argument) << "enable-fp16 is not set!";
+#endif
+ }
+}
+
void SwiGLULayer::calcDerivative(nntrainer::RunLayerContext &context) {
// std::throw_with_nested(std::runtime_error("Training is not supported
// yet."));
*/
void forwarding(nntrainer::RunLayerContext &context, bool training) override;
+ /**
+ * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
+ * int from, unsigned int to, bool training)
+ */
+ void incremental_forwarding(nntrainer::RunLayerContext &context,
+ unsigned int from, unsigned int to,
+ bool training) override;
/**
* @copydoc Layer::calcDerivative(RunLayerContext &context)
for (unsigned int i = 1; i < init_input_seq_len + NUM_TOKENS_TO_GENERATE;
++i) {
output_bufs = model->incremental_inference(
- BATCH_SIZE, {wte_input, wpe_input}, {}, init_input_seq_len, i - 1);
+ BATCH_SIZE, {wte_input, wpe_input}, {}, init_input_seq_len, i - 1, i);
nntrainer::Tensor output({BATCH_SIZE, 1, i, MODEL_DIM}, output_bufs[0]);
subdir('Resnet/jni')
subdir('YOLO/jni')
subdir('YOLOv3/jni')
+if get_option('platform') != 'tizen'
subdir('LLaMA/jni')
+endif
+
subdir('ReinforcementLearning/DeepQ/jni')
subdir('TransferLearning/CIFAR_Classification/jni')
# if enable_capi
* @param[in] input inputs as a list of each input data
* @param[in] label labels as a list of each label data
* @param[in] init_seq_len initial sequence length
- * @param[in] cur_step current working step index (zero based index)
+ * @param[in] from first step index to process (zero-based, inclusive)
+ * @param[in] to last step index to process (exclusive)
* @retval list of output as float *
* @note The output memory must not be freed by the caller
*/
virtual std::vector<float *>
incremental_inference(unsigned int batch, const std::vector<float *> &input,
const std::vector<float *> &label,
- unsigned int init_seq_len, unsigned int cur_step) = 0;
+ unsigned int init_seq_len, unsigned int from,
+ unsigned int to) = 0;
/**
* @brief Summarize the model
extra_defines += '-DENABLE_FP16=1'
extra_defines += '-DUSE__FP16=1'
extra_defines += '-DUSE_NEON=1'
- elif arch == 'aarch64' or arch =='arm'
+ elif arch == 'aarch64'
add_project_arguments('-mfp16-format=ieee', language: ['c', 'cpp'])
extra_defines += '-DENABLE_FP16=1'
extra_defines += '-DUSE__FP16=1'
- extra_defines += '-DUSE_NEON=0'
+ extra_defines += '-DUSE_NEON=1'
else
has_avx512fp16 = cc.has_argument('-mavx512fp16')
if (has_avx512fp16)
hidden_.add_i(input_);
}
}
+}
+
+void AdditionLayer::incremental_forwarding(RunLayerContext &context,
+ unsigned int from, unsigned int to,
+ bool training) {
+ Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
+ TensorDim hidden_dim = hidden_.getDim();
+ TensorDim hidden_step_dim = hidden_dim;
+
+ if (from) {
+ NNTR_THROW_IF(to - from != 1, std::invalid_argument)
+ << "incremental step size is not 1";
+ from = 0;
+ to = 1;
+ }
+
+ hidden_step_dim.height(to - from);
- bool print = std::get<props::Print>(add_props).get();
- if (print) {
- // std::cerr << input_ << "\n";
- // std::cerr << weight << "\n";
- // std::cerr << hidden_ << "\n";
+ Tensor hidden_step = hidden_.getSharedDataTensor(hidden_step_dim, 0, true);
+
+ /** @todo check whether the addition layer can run in-place */
+ for (unsigned int idx = 0; idx < context.getNumInputs(); ++idx) {
+ const Tensor &input_ = context.getInput(idx);
+ TensorDim input_dim = input_.getDim();
+
+ TensorDim input_step_dim = input_dim;
+ input_step_dim.height(to - from);
+
+ Tensor input_step = input_.getSharedDataTensor(input_step_dim, 0, true);
+ if (!idx) {
+ hidden_step.copy(input_step);
+ } else {
+ hidden_step.add_i(input_step);
+ }
}
}
*/
void forwarding(RunLayerContext &context, bool training) override;
+ /**
+ * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
+ * int from, unsigned int to, bool training)
+ */
+ void incremental_forwarding(RunLayerContext &context, unsigned int from,
+ unsigned int to, bool training) override;
+
/**
* @copydoc Layer::calcDerivative(RunLayerContext &context)
*/
}
}
+void EmbeddingLayer::incremental_forwarding(RunLayerContext &context,
+ unsigned int from, unsigned int to,
+ bool training) {
+
+ /// @todo get input and output dimension from input_ and hidden itself
+ unsigned int in_dim = std::get<props::InDim>(embedding_props);
+ unsigned int out_dim = std::get<props::OutDim>(embedding_props);
+
+ if (from) {
+ NNTR_THROW_IF(to - from != 1, std::invalid_argument)
+ << "incremental step size is not 1";
+ from = 0;
+ to = 1;
+ }
+
+ Tensor &weight = context.getWeight(weight_idx);
+ Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
+ Tensor &input_ = context.getInput(SINGLE_INOUT_IDX);
+
+ TensorDim out_tensor_dim =
+ TensorDim({1, 1, 1, out_dim}, hidden_.getTensorType());
+
+ for (unsigned int b = 0; b < input_.batch(); ++b) {
+ float *in_data =
+ input_.getAddress<float>(b * input_.getDim().getFeatureLen());
+
+ Tensor batchsliced_hidden = hidden_.getBatchSlice(b, 1);
+ for (unsigned int i = from; i < to; ++i) {
+ uint embed_idx = static_cast<uint>(in_data[i]);
+ if (embed_idx >= in_dim) {
+ throw std::invalid_argument("input word index is greater than in_dim");
+ }
+
+ Tensor cur_weight =
+ weight.getSharedDataTensor(out_tensor_dim, out_dim * embed_idx);
+
+ Tensor out_tensor = batchsliced_hidden.getSharedDataTensor(
+ out_tensor_dim, out_dim * (i - from));
+
+ out_tensor.copyData(cur_weight);
+ }
+ }
+}
+
void EmbeddingLayer::calcDerivative(RunLayerContext &context) {
throw exception::not_supported(
"calcDerivative for Embedding layer is not supported");
*/
void forwarding(RunLayerContext &context, bool training) override;
+ /**
+ * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
+ * int from, unsigned int to, bool training)
+ */
+ void incremental_forwarding(RunLayerContext &context, unsigned int from,
+ unsigned int to, bool training) override;
+
/**
* @copydoc Layer::calcDerivative(RunLayerContext &context)
*/
}
}
+void FullyConnectedLayer::incremental_forwarding(RunLayerContext &context,
+ unsigned int from,
+ unsigned int to,
+ bool training) {
+ Tensor &weight = context.getWeight(weight_idx[FCParams::weight]);
+
+ Tensor &input_ = context.getInput(SINGLE_INOUT_IDX);
+ Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
+
+ TensorDim input_dim = input_.getDim();
+ TensorDim hidden_dim = hidden_.getDim();
+
+ TensorDim input_step_dim = input_dim;
+ TensorDim hidden_step_dim = hidden_dim;
+
+ if (from) {
+ NNTR_THROW_IF(to - from != 1, std::invalid_argument)
+ << "incremental step size is not 1";
+ from = 0;
+ to = 1;
+ }
+
+ input_step_dim.height(to - from);
+ hidden_step_dim.height(to - from);
+
+ // @todo: set reset stride as false. This implementation only works when batch
+ // size is 1
+ Tensor input_step = input_.getSharedDataTensor(input_step_dim, 0, true);
+ Tensor hidden_step = hidden_.getSharedDataTensor(hidden_step_dim, 0, true);
+
+ input_step.dot(weight, hidden_step, false, false);
+
+ if (auto &disable_bias = std::get<props::DisableBias>(*layer_impl_props);
+ disable_bias.empty() || disable_bias.get() == false) {
+ Tensor &bias = context.getWeight(weight_idx[FCParams::bias]);
+ hidden_step.add_i(bias);
+ }
+}
+
void FullyConnectedLayer::calcDerivative(RunLayerContext &context) {
Tensor &weight = context.getWeight(weight_idx[FCParams::weight]);
*/
void forwarding(RunLayerContext &context, bool training) override;
+ /**
+ * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
+ * int from, unsigned int to, bool training)
+ */
+ void incremental_forwarding(RunLayerContext &context, unsigned int from,
+ unsigned int to, bool training) override;
+
/**
* @copydoc Layer::calcDerivative(RunLayerContext &context)
*/
epsilon(1e-3),
cache_index(0) {
weight_idx.fill(std::numeric_limits<unsigned>::max());
+ layer_progress = 0;
}
MultiHeadAttentionLayer::~MultiHeadAttentionLayer() {}
{batch_size, 1, query_height, num_heads * projected_value_dim_prop}));
}
+void MultiHeadAttentionLayer::initial_incremental_forwarding(
+ RunLayerContext &context, unsigned int _from, unsigned int _to,
+ bool training) {
+ unsigned int max_timestep =
+ std::get<props::MaxTimestep>(multi_head_attention_props).get();
+
+ bool cache_shift = false;
+ unsigned int from = _from;
+ unsigned int to = _to;
+ if (to > max_timestep) {
+ throw std::invalid_argument("to shouldn't greater than max_timestep");
+ }
+
+ const bool disable_bias =
+ std::get<props::DisableBias>(*layer_impl_props).get();
+
+ const unsigned int num_heads =
+ std::get<props::NumHeads>(multi_head_attention_props).get();
+ const unsigned int projected_key_dim_prop =
+ std::get<props::ProjectedKeyDim>(multi_head_attention_props).get();
+ const unsigned int projected_value_dim_prop =
+ std::get<props::ProjectedValueDim>(multi_head_attention_props).get();
+ const float dropout_rate =
+ std::get<props::DropOutRate>(multi_head_attention_props).get();
+ const props::ReturnAttentionWeightInfo::Enum return_attention_weight =
+ std::get<props::ReturnAttentionWeight>(multi_head_attention_props).get();
+ const bool average_attention_weight =
+ std::get<props::AverageAttentionWeight>(multi_head_attention_props).get();
+
+ const bool provide_attention_mask = context.getNumInputs() == 4;
+ const unsigned int projected_query_dim_prop = projected_key_dim_prop;
+ const bool enable_dropout = dropout_rate > epsilon;
+
+ /** get inputs/outputs */
+ Tensor &query = context.getInput(INOUT_INDEX::QUERY);
+ Tensor &key = context.getInput(INOUT_INDEX::KEY);
+ Tensor &value = context.getInput(INOUT_INDEX::VALUE);
+
+ Tensor empty_tensor;
+
+ empty_tensor.setTensorType(value.getTensorType());
+
+ Tensor &mask =
+ provide_attention_mask ? context.getInput(INOUT_INDEX::MASK) : empty_tensor;
+
+ TensorDim query_dim = query.getDim();
+ TensorDim key_dim = key.getDim();
+ TensorDim value_dim = value.getDim();
+
+ TensorDim query_step_dim = query_dim;
+ TensorDim key_step_dim = key_dim;
+ TensorDim value_step_dim = value_dim;
+
+ query_step_dim.height(to);
+ key_step_dim.height(to);
+ value_step_dim.height(to);
+
+ Tensor query_step = query.getSharedDataTensor(query_step_dim, 0, true);
+ Tensor key_step = key.getSharedDataTensor(key_step_dim, 0, true);
+ Tensor value_step = value.getSharedDataTensor(value_step_dim, 0, true);
+
+ Tensor &output = context.getOutput(INOUT_INDEX::OUTPUT);
+
+ TensorDim output_dim = output.getDim();
+ TensorDim output_step_dim = output_dim;
+ output_step_dim.height(to);
+ Tensor output_step = output.getSharedDataTensor(output_step_dim, 0, true);
+
+ Tensor &ret_attention_weight =
+ return_attention_weight != props::ReturnAttentionWeightInfo::Enum::none
+ ? context.getOutput(INOUT_INDEX::RETURN_ATTENTION_WEIGHT)
+ : empty_tensor;
+
+ /** get weights */
+ Tensor &query_fc_weight =
+ context.getWeight(weight_idx[AttentionParams::query_fc_weight]);
+ Tensor &query_fc_bias =
+ disable_bias
+ ? empty_tensor
+ : context.getWeight(weight_idx[AttentionParams::query_fc_bias]);
+ Tensor &key_fc_weight =
+ context.getWeight(weight_idx[AttentionParams::key_fc_weight]);
+ Tensor &key_fc_bias =
+ disable_bias ? empty_tensor
+ : context.getWeight(weight_idx[AttentionParams::key_fc_bias]);
+ Tensor &value_fc_weight =
+ context.getWeight(weight_idx[AttentionParams::value_fc_weight]);
+ Tensor &value_fc_bias =
+ disable_bias
+ ? empty_tensor
+ : context.getWeight(weight_idx[AttentionParams::value_fc_bias]);
+ Tensor &fc_weight = context.getWeight(weight_idx[AttentionParams::fc_weight]);
+ Tensor &fc_bias = disable_bias
+ ? empty_tensor
+ : context.getWeight(weight_idx[AttentionParams::fc_bias]);
+
+ /** get tensors */
+ Tensor &projected_query =
+ context.getTensor(weight_idx[AttentionParams::projected_query]);
+ Tensor &projected_key =
+ context.getTensor(weight_idx[AttentionParams::projected_key]);
+ Tensor &projected_value =
+ context.getTensor(weight_idx[AttentionParams::projected_value]);
+ Tensor &cache_key = context.getTensor(weight_idx[AttentionParams::cache_key]);
+ Tensor &cache_value =
+ context.getTensor(weight_idx[AttentionParams::cache_value]);
+
+ TensorDim projected_query_dim = projected_query.getDim();
+ TensorDim projected_key_dim = projected_key.getDim();
+ TensorDim projected_value_dim = projected_value.getDim();
+ TensorDim cache_key_dim = cache_key.getDim();
+ TensorDim cache_value_dim = cache_value.getDim();
+
+ TensorDim projected_query_step_dim = projected_query_dim;
+
+ TensorDim projected_key_step_dim = projected_key_dim;
+ TensorDim projected_value_step_dim = projected_value_dim;
+ TensorDim cache_key_step_dim = cache_key_dim;
+ TensorDim cache_value_step_dim = cache_value_dim;
+ projected_query_step_dim.height(to);
+
+ projected_key_step_dim.height(to);
+ projected_value_step_dim.height(to);
+ cache_key_step_dim.height(to);
+ cache_value_step_dim.height(to);
+
+ Tensor projected_query_step =
+ projected_query.getSharedDataTensor(projected_query_step_dim, 0, true);
+ Tensor projected_key_step =
+ projected_key.getSharedDataTensor(projected_key_step_dim, 0, true);
+ Tensor projected_value_step =
+ projected_value.getSharedDataTensor(projected_value_step_dim, 0, true);
+
+ Tensor cache_key_step =
+ cache_key.getSharedDataTensor(cache_key_step_dim, 0, true);
+ Tensor cache_value_step =
+ cache_value.getSharedDataTensor(cache_value_step_dim, 0, true);
+
+ TensorDim cached_key_dim = {cache_key_dim.batch(), cache_key_dim.channel(),
+ to, cache_key_dim.width(),
+ cache_key.getTensorType()};
+ TensorDim cached_value_dim = {
+ cache_value_dim.batch(), cache_value_dim.channel(), to,
+ cache_value_dim.width(), cache_value.getTensorType()};
+ Tensor cached_key = cache_key.getSharedDataTensor(cached_key_dim, 0, true);
+ Tensor cached_value =
+ cache_value.getSharedDataTensor(cached_value_dim, 0, true);
+
+ Tensor &attention_weight =
+ context.getTensor(weight_idx[AttentionParams::attention_weight]);
+ Tensor &attention_output =
+ context.getTensor(weight_idx[AttentionParams::attention_output]);
+ TensorDim attention_weight_dim = attention_weight.getDim();
+
+ TensorDim attention_weight_step_dim = attention_weight_dim;
+ attention_weight_step_dim.height(to);
+ attention_weight_step_dim.width(to);
+
+ Tensor attention_weight_step =
+ attention_weight.getSharedDataTensor(attention_weight_step_dim, 0, true);
+
+ TensorDim attention_output_dim = attention_output.getDim();
+ TensorDim attention_output_step_dim = attention_output_dim;
+ attention_output_step_dim.height(to);
+
+ Tensor attention_output_step =
+ attention_output.getSharedDataTensor(attention_output_step_dim, 0, true);
+
+ const unsigned int batch_size = query_dim.batch();
+ const unsigned int query_height = query_dim.height();
+ const unsigned int key_height = key_dim.height();
+ const unsigned int value_height = value_dim.height();
+
+ query_step.dot(query_fc_weight, projected_query_step);
+ if (!disable_bias) {
+ projected_query_step.add_i(query_fc_bias);
+ }
+ key_step.dot(key_fc_weight, cache_key_step);
+ if (!disable_bias) {
+ cache_key_step.add_i(key_fc_bias);
+ }
+ value_step.dot(value_fc_weight, cache_value_step);
+ if (!disable_bias) {
+ cache_value_step.add_i(value_fc_bias);
+ }
+
+ apply_rotary_emb_tensor(projected_query_step, projected_query_dim_prop, from);
+ apply_rotary_emb_tensor(cache_key_step, projected_key_dim_prop, from);
+
+ projected_query_step.reshape(
+ TensorDim({batch_size, to, num_heads, projected_query_dim_prop}));
+ cached_key.reshape(
+ TensorDim({batch_size, to, num_heads, projected_key_dim_prop}));
+ cached_value.reshape(
+ TensorDim({batch_size, to, num_heads, projected_value_dim_prop}));
+
+ projected_query_step.transpose("1:0:2", projected_query_step);
+ cached_key.transpose("1:0:2", projected_key_step);
+ cached_value.transpose("1:0:2", projected_value_step);
+
+ projected_query_step.reshape(
+ TensorDim({batch_size * num_heads, 1, to, projected_query_dim_prop}));
+ projected_key_step.reshape(
+ TensorDim({batch_size * num_heads, 1, to, projected_key_dim_prop}));
+ projected_value_step.reshape(
+ TensorDim({batch_size * num_heads, 1, to, projected_value_dim_prop}));
+
+ attention_weight_step.reshape(TensorDim({batch_size * num_heads, 1, to, to}));
+ attention_output_step.reshape(
+ TensorDim({batch_size * num_heads, 1, to, projected_value_dim_prop}));
+
+ /** scaled dot product attention */
+ projected_query_step.dotBatched(projected_key_step, attention_weight_step,
+ false, true);
+ attention_weight_step.multiply_i(1 / sqrt((float)projected_query_dim_prop));
+
+ if (!from) {
+ unsigned int mask_size = attention_weight_step.getDim().width();
+ unsigned int mask_dim_height = mask_size;
+ unsigned int mask_dim_width = mask_size;
+
+ Tensor causal_mask(TensorDim{1, 1, mask_size, mask_size,
+ attention_weight_step.getTensorType()});
+
+ causal_mask.setZero();
+
+#ifdef ENABLE_FP16
+#define _MASK_NUM -1e4
+#else
+#define _MASK_NUM -1e10
+#endif
+
+ for (unsigned int i = 0; i < mask_dim_height; ++i) {
+ for (unsigned int j = i + 1; j < mask_dim_width; ++j) {
+ causal_mask.setValue(0, 0, i, j, _MASK_NUM);
+ }
+ }
+
+ attention_weight_step.add_i(causal_mask);
+ }
+
+ sm.run_fn(attention_weight_step, attention_weight_step);
+
+ attention_weight_step.dotBatched(projected_value_step, attention_output_step);
+
+ attention_output_step.reshape(
+ TensorDim({batch_size, num_heads, to, projected_value_dim_prop}));
+
+ attention_output_step = attention_output_step.transpose("1:0:2");
+
+ attention_output_step.reshape(
+ TensorDim({batch_size * to, 1, 1, num_heads * projected_value_dim_prop}));
+
+ attention_output_step.dot(fc_weight, output_step);
+ if (!disable_bias) {
+ output_step.add_i(fc_bias);
+ }
+
+ if (layer_progress == 28)
+ layer_progress = 0;
+ layer_progress++;
+
+ std::cout << "Process Reading: " << (int)((layer_progress / 28.0) * 100.0)
+ << " % \r";
+ std::cout.flush();
+}
+
void MultiHeadAttentionLayer::incremental_forwarding(RunLayerContext &context,
unsigned int _from,
unsigned int _to,
bool training) {
+ if (!_from) {
+ initial_incremental_forwarding(context, _from, _to, training);
+ return;
+ }
+
unsigned int max_timestep =
std::get<props::MaxTimestep>(multi_head_attention_props).get();
TensorDim key_dim = key.getDim();
TensorDim value_dim = value.getDim();
+ TensorDim query_step_dim = query_dim;
+ TensorDim key_step_dim = key_dim;
+ TensorDim value_step_dim = value_dim;
+
+ query_step_dim.height(to - from);
+ key_step_dim.height(to - from);
+ value_step_dim.height(to - from);
+
+ Tensor query_step = query.getSharedDataTensor(query_step_dim, 0, true);
+ Tensor key_step = key.getSharedDataTensor(key_step_dim, 0, true);
+ Tensor value_step = value.getSharedDataTensor(value_step_dim, 0, true);
+
Tensor &output = context.getOutput(INOUT_INDEX::OUTPUT);
TensorDim output_dim = output.getDim();
+
+ TensorDim output_step_dim = output_dim;
+ output_step_dim.height(to - from);
+ Tensor output_step = output.getSharedDataTensor(output_step_dim, 0, true);
+
Tensor &ret_attention_weight =
return_attention_weight != props::ReturnAttentionWeightInfo::Enum::none
? context.getOutput(INOUT_INDEX::RETURN_ATTENTION_WEIGHT)
const unsigned int key_height = key_dim.height();
const unsigned int value_height = value_dim.height();
- query.dot(query_fc_weight, projected_query_step);
+ query_step.dot(query_fc_weight, projected_query_step);
+
if (!disable_bias) {
projected_query_step.add_i(query_fc_bias);
}
- key.dot(key_fc_weight, cache_key_step);
+ key_step.dot(key_fc_weight, cache_key_step);
if (!disable_bias) {
cache_key_step.add_i(key_fc_bias);
}
- value.dot(value_fc_weight, cache_value_step);
+ value_step.dot(value_fc_weight, cache_value_step);
if (!disable_bias) {
cache_value_step.add_i(value_fc_bias);
}
attention_output_step.reshape(TensorDim(
{batch_size * (to - from), 1, 1, num_heads * projected_value_dim_prop}));
- attention_output_step.dot(fc_weight, output);
+ attention_output_step.dot(fc_weight, output_step);
if (!disable_bias) {
- output.add_i(fc_bias);
+ output_step.add_i(fc_bias);
}
if (cache_shift) {
*/
void forwarding(RunLayerContext &context, bool training) override;
+ /**
+ * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
+ * int from, unsigned int to, bool training)
+ */
+ void initial_incremental_forwarding(RunLayerContext &context,
+ unsigned int from, unsigned int to,
+ bool training);
+
/**
* @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
* int from, unsigned int to, bool training)
unsigned int cache_index;
- inline static std::vector<std::vector<float>> *freqs_cos = {};
+ inline static unsigned int layer_progress;
+ inline static std::vector<std::vector<float>> *freqs_cos = {};
inline static std::vector<std::vector<float>> *freqs_sin = {};
+ inline static std::vector<float> freqs;
/**
* @brief compute frequency for rotary embedding
void precompute_freqs(int dim, unsigned int seq_len, float theta = 10000.0) {
if (freqs_cos == nullptr) {
unsigned int half_ = dim / 2;
- std::vector<float> freqs(half_);
for (unsigned int i = 0; i < half_; ++i) {
- freqs[i] = 1.0 / (std::pow(theta, (2 * i) / static_cast<float>(dim)));
+ freqs.push_back(1.0 /
+ (std::pow(theta, (2 * i) / static_cast<float>(dim))));
}
auto cos = new std::vector<std::vector<float>>();
float value = 0;
float transformed_value = 0.0;
unsigned int half_ = dim / 2;
+ unsigned int max_timestep =
+ std::get<props::MaxTimestep>(multi_head_attention_props).get();
+
+ std::vector<float> *cos_;
+ std::vector<float> *sin_;
+
+ if (from >= max_timestep) {
+ std::cout << from << " " << max_timestep << std::endl;
+ cos_ = new std::vector<float>(dim);
+ sin_ = new std::vector<float>(dim);
+
+ for (unsigned int i = 0; i < half_; ++i) {
+ float angle = from * freqs[i];
+ (*cos_)[i] = std::cos(angle);
+ (*cos_)[i + half_] = std::cos(angle); // repeated 2 times
+
+ (*sin_)[i] = std::sin(angle);
+ (*sin_)[i + half_] = std::sin(angle); // repeated 2 times
+ }
+ }
if (in.getDataType() == ml::train::TensorDim::DataType::FP32) {
for (unsigned int b = 0; b < in.batch(); b++) {
for (unsigned int c = 0; c < in.channel(); c++) {
for (unsigned int h = 0; h < in.height(); h++) {
+ if (from < max_timestep) {
+ cos_ = &(*freqs_cos)[from + h];
+ sin_ = &(*freqs_sin)[from + h];
+ }
+
for (unsigned int w = 0; w < in.width(); w = w + dim) {
for (unsigned int k = 0; k < dim; k++) {
unsigned int span = w + k;
for (unsigned int b = 0; b < in.batch(); b++) {
for (unsigned int c = 0; c < in.channel(); c++) {
for (unsigned int h = 0; h < in.height(); h++) {
+ if (from < max_timestep) {
+ cos_ = &(*freqs_cos)[from + h];
+ sin_ = &(*freqs_sin)[from + h];
+ }
+
for (unsigned int w = 0; w < in.width(); w = w + dim) {
for (unsigned int k = 0; k < dim; k++) {
#ifdef ENABLE_FP16
unsigned int from, unsigned int to,
bool training) {
if (!context.executeInPlace()) {
+ if (from) {
+ NNTR_THROW_IF(to - from != 1, std::invalid_argument)
+ << "incremental step size is not 1";
+ from = 0;
+ to = 1;
+ }
+
const Tensor &input_ = context.getInput(SINGLE_INOUT_IDX);
TensorDim input_dim = input_.getDim();
TensorDim input_step_dim = {input_dim.batch(), input_dim.channel(),
to - from, input_dim.width()};
- Tensor input_step = input_.getSharedDataTensor(
- input_step_dim, from * input_dim.width(), true);
+ Tensor input_step = input_.getSharedDataTensor(input_step_dim, 0, true);
for (unsigned int idx = 0; idx < context.getNumOutputs(); ++idx) {
Tensor &output = context.getOutput(idx);
to - from, output_dim.width()};
// @todo: set reset stride as false. This implementation only works when
// batch size is 1
- Tensor output_step = output.getSharedDataTensor(
- output_step_dim, from * output_dim.width(), true);
+ Tensor output_step = output.getSharedDataTensor(output_step_dim, 0, true);
output_step.fill(input_step);
}
}
return output;
}
-sharedConstTensors NeuralNetwork::incremental_inference(
- sharedConstTensors X, unsigned int init_seq_len, unsigned int cur_step) {
- return incremental_inference(X, {}, init_seq_len, cur_step);
+sharedConstTensors
+NeuralNetwork::incremental_inference(sharedConstTensors X,
+ unsigned int init_seq_len,
+ unsigned int from, unsigned int to) {
+ return incremental_inference(X, {}, init_seq_len, from, to);
}
sharedConstTensors NeuralNetwork::incremental_inference(
sharedConstTensors X, sharedConstTensors label, unsigned int init_seq_len,
- unsigned int cur_step) {
+ unsigned int from, unsigned int to) {
if (model_graph.getBatchSize() != X[0]->batch()) {
model_graph.setBatchSize(X[0]->batch());
}
if (!validateInput(X))
throw std::invalid_argument("Input validation failed.");
- if (cur_step == 0) {
+ if (from == 0) {
allocate(ExecutionMode::INFERENCE);
}
PROFILE_TIME_REGISTER_EVENT(nn_foward, "nn_forward");
PROFILE_TIME_START(nn_foward);
- out = incremental_forwarding(cur_step, cur_step + 1, X, label, false);
+ out = incremental_forwarding(from, to, X, label, false);
PROFILE_TIME_END(nn_foward);
std::vector<float *> NeuralNetwork::incremental_inference(
unsigned int batch_size, const std::vector<float *> &input,
const std::vector<float *> &label, unsigned int init_seq_len,
- unsigned int cur_step) {
+ unsigned int from, unsigned int to) {
sharedConstTensors input_tensors, output_tensors;
auto in_dim = getInputDimension();
label_dim[idx], 0)));
}
output_tensors = incremental_inference(input_tensors, label_tensors,
- init_seq_len, cur_step);
+ init_seq_len, from, to);
} else {
output_tensors =
- incremental_inference(input_tensors, init_seq_len, cur_step);
+ incremental_inference(input_tensors, init_seq_len, from, to);
}
std::vector<float *> output;
- output.reserve(output_tensors.size());
unsigned int idx = 0;
+ if (!from) {
+ idx = to - 1;
+ }
for (auto &out : output_tensors) {
if (out->getDataType() == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
auto out_t = *out.get();
- _FP16 *vec_fp16 = out_t.getData<_FP16>();
- float *vec_fp32 = new float[out_t.size()]();
- output.push_back(vec_fp32);
- for (unsigned int i = 0; i < out_t.size(); ++i) {
- output[idx][i] = static_cast<float>(vec_fp16[i]);
+ float *vec_fp32 = new float[out_t.width()];
+ for (unsigned int i = 0; i < out_t.width(); ++i) {
+ (vec_fp32)[i] = static_cast<float>(out_t.getValue<_FP16>(0, 0, idx, i));
}
+ output.emplace_back(vec_fp32);
#else
throw std::invalid_argument("Errro: enable-fp16 is not set");
#endif
*/
sharedConstTensors incremental_inference(sharedConstTensors X,
unsigned int init_seq_len,
- unsigned int step);
+ unsigned int from, unsigned int to);
/**
* @brief Run NeuralNetwork incremental inference
sharedConstTensors incremental_inference(sharedConstTensors X,
sharedConstTensors label,
unsigned int init_seq_len,
- unsigned int step);
+ unsigned int from, unsigned int to);
/**
* @brief Run the incremental inference of the model
const std::vector<float *> &input,
const std::vector<float *> &label,
unsigned int init_seq_len,
- unsigned int step) override;
+ unsigned int from,
+ unsigned int to) override;
/**
* @brief Run NeuralNetwork train with callback function by user