/** set weight specifications */
// @todo : this NCHW format setting is only temporary; it should be driven
// by the global configuration
+
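+ /**
+  * Note: the trailing bitset passed to each TensorDim below is presumably
+  * TensorDim's effective-dimension flag: set bits mark the axes
+  * (batch, channel, height, width, MSB to LSB) that actually carry data,
+  * e.g. 0b0001 keeps only the width axis effective.
+  */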
+ /** Bias Dimension : (1, 1, 1, unit) */
TensorDim bias_dim(
1, is_nchw ? 1 : unit, 1, is_nchw ? unit : 1,
TensorDim::TensorType(context.getFormat(), context.getWeightDataType()),
is_nchw ? 0b0001 : 0b0100);
+ /** Weight Dimension : (1, 1, in_dim.width(), unit) */
TensorDim weight_dim(
1, is_nchw ? 1 : unit, is_nchw ? in_dim.width() : 1,
is_nchw ? unit : in_dim.channel(),
TensorDim::TensorType(context.getFormat(), context.getWeightDataType()),
is_nchw ? 0b0011 : 0b0101);
/** create weights for LoRA */
if (lora_rank) {
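+ /**
+  * LoRA learns a low-rank update to the frozen weight: with
+  * A of shape (in_dim.width() x lora_rank) and B of shape
+  * (lora_rank x unit), the output becomes
+  * y = x * W + b + (x * A) * B (up to an optional scaling factor),
+  * training only lora_rank * (in_dim.width() + unit) extra parameters.
+  */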
- /** loraA : (in_dim.width, lora_rank) */
+ /** loraA Dimension : (1, 1, in_dim.width(), lora_rank) */
TensorDim loraA_dim(
1, is_nchw ? 1 : lora_rank, is_nchw ? in_dim.width() : 1,
is_nchw ? lora_rank : in_dim.channel(),
TensorDim::TensorType(context.getFormat(), context.getWeightDataType()),
is_nchw ? 0b0011 : 0b0101);
- /** loraB: (lora_rank, out_dim) */
+ /** loraB Dimension : (1, 1, lora_rank, unit) */
TensorDim loraB_dim(
1, is_nchw ? 1 : unit, is_nchw ? lora_rank : 1,
is_nchw ? unit : lora_rank,
TensorDim::TensorType(context.getFormat(), context.getWeightDataType()),
is_nchw ? 0b0011 : 0b0101);
- /** loraTmp: (1, lora_rank) */
+ /** loraTmp Dimension : (B, 1, in_dim.height(), lora_rank) */
TensorDim loraTmp_dim(
- 1, is_nchw ? 1 : lora_rank, 1, is_nchw ? lora_rank : 1,
+ in_dim.batch(), is_nchw ? 1 : lora_rank, is_nchw ? in_dim.height() : 1,
+ is_nchw ? lora_rank : in_dim.width(),
+ TensorDim::TensorType(context.getFormat(), context.getWeightDataType()),
+ is_nchw ? 0b1011 : 0b1101);
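+ /* loraTmp caches the per-sample activation x * A, so unlike the weight
+  * dims above it carries real batch and height dimensions. */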
+
+ /** loraOut Dimension : (B, 1, in_dim.height(), unit) */
+ TensorDim loraOut_dim(
+ in_dim.batch(), is_nchw ? 1 : unit, is_nchw ? in_dim.height() : 1,
+ is_nchw ? unit : in_dim.width(),
TensorDim::TensorType(context.getFormat(), context.getWeightDataType()),
- is_nchw ? 0b0001 : 0b0100);
+ is_nchw ? 0b1011 : 0b1101);
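+ /**
+  * Lifespans below: loraTmp gets FORWARD_GRAD_LIFESPAN since the cached
+  * x * A is presumably reused during gradient computation, while loraOut
+  * only needs FORWARD_FUNC_LIFESPAN (consumed within forwarding itself).
+  */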
lora_idx[LORAParams::loraA] = context.requestWeight(
loraA_dim, Initializer::ZEROS, weight_regularizer,
weight_regularizer_constant, weight_decay, "loraA", true);
+ // note: loraB's initializer below is an assumption (the layer's
+ // configured weight_initializer); A starting at ZEROS already makes
+ // the initial LoRA delta zero.
lora_idx[LORAParams::loraB] = context.requestWeight(
loraB_dim, weight_initializer, weight_regularizer,
weight_regularizer_constant, weight_decay, "loraB", true);
lora_idx[LORAParams::loraTmp] =
context.requestTensor(loraTmp_dim, "hidden_tmp_lora", Initializer::NONE,
- true, TensorLifespan::FORWARD_DERIV_LIFESPAN);
+ true, TensorLifespan::FORWARD_GRAD_LIFESPAN);
lora_idx[LORAParams::loraOut] =
- context.requestTensor(bias_dim, "hidden_lora", Initializer::NONE, true,
+ context.requestTensor(loraOut_dim, "hidden_lora", Initializer::NONE, true,
TensorLifespan::FORWARD_FUNC_LIFESPAN);
}
}
LayerImpl::setProperty(remain_props);
}
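+
+/**
+ * loraTmp and loraOut were requested with a real batch dimension, so their
+ * backing tensors must be resized whenever the batch size changes; the
+ * trainable weights are batch-independent and need no update here.
+ */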
+void FullyConnectedLayer::setBatch(nntrainer::RunLayerContext &context,
+ unsigned int batch) {
+ if (!std::get<props::LoraRank>(fc_props).empty()) {
+ // update the LoRA temp tensors' batch size.
+ context.updateTensor(lora_idx[LORAParams::loraTmp], batch);
+ context.updateTensor(lora_idx[LORAParams::loraOut], batch);
+ }
+}
+
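+/**
+ * Forward-pass sketch (assuming the usual LoRA composition): hidden =
+ * input dot weight (+ bias); when lora_rank is set, loraTmp = input dot
+ * loraA and loraOut = loraTmp dot loraB are computed, and loraOut is
+ * added onto hidden.
+ */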
void FullyConnectedLayer::forwarding(RunLayerContext &context, bool training) {
Tensor &weight = context.getWeight(weight_idx[FCParams::weight]);
Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);