)
# test split example
-test('nntrainer_product_ratings', e, args: ['train',
+test('app_product_ratings', e, args: ['train',
nntr_app_resdir / 'ProductRatings' / 'product_ratings_model.ini',
nntr_app_resdir / 'ProductRatings' / 'sample_product_ratings.txt']
)
 * together and all the dimensions after the split_dimension to facilitate
* easier splitting of the data.
*/
- input_reshape_helper = {1, 1, 1, 1};
- for (unsigned int idx = 0; idx < split_dimension; ++idx) {
- input_reshape_helper.batch(input_reshape_helper.batch() *
- in_dim.getTensorDim(idx));
+ leading_helper_dim = 1;
+ input_reshape_helper.channel(1);
+ input_reshape_helper.height(1);
+ input_reshape_helper.width(1);
+ for (unsigned int idx = 1; idx < split_dimension; ++idx) {
+ leading_helper_dim *= in_dim.getTensorDim(idx);
}
input_reshape_helper.height(in_dim.getTensorDim(split_dimension));
output_reshape_helper = input_reshape_helper;
output_reshape_helper.height(1);
+ setBatch(in_dim.batch());
+
return status;
}
+/**
+ * @brief Update the effective batch size of the layer and its reshape helpers.
+ *
+ * @param batch new batch size from the network
+ *
+ * The reshape helpers fold every dimension before the split dimension into
+ * their batch axis (see leading_helper_dim, the product of those leading
+ * dims), so the helpers' batch must be batch * leading_helper_dim rather
+ * than the raw batch value.
+ */
+void SplitLayer::setBatch(unsigned int batch) {
+ Layer::setBatch(batch);
+ input_reshape_helper.batch(batch * leading_helper_dim);
+ output_reshape_helper.batch(batch * leading_helper_dim);
+}
+
void SplitLayer::forwarding(bool training) {
Tensor &input_ = net_input[0]->getVariableRef();
case PropertyType::split_dimension: {
if (!value.empty()) {
status = setUint(split_dimension, value);
+ NNTR_THROW_IF(split_dimension == 0, std::invalid_argument)
+ << "[Split] Batch dimension cannot be split dimension";
throw_status(status);
}
} break;
*/
void calcDerivative() override;
+ /**
+ * @copydoc Layer::setBatch(unsigned int batch)
+ */
+ void setBatch(unsigned int batch) override;
+
using Layer::setProperty;
/**
private:
unsigned int split_dimension; /** dimension along which to split the input */
+ unsigned int leading_helper_dim; /**< batch dimension of helper dimension not
+ containing the actual batch */
TensorDim input_reshape_helper; /** helper dimension to reshape input */
TensorDim output_reshape_helper; /** helper dimension to reshape outputs */
};
}
void Tensor::reshape(const TensorDim &d) {
- if (d.getDataLen() != dim.getDataLen()) {
- throw std::invalid_argument("Error: reshape cannot change the tensor size");
- }
+ NNTR_THROW_IF(d.getDataLen() != dim.getDataLen(), std::invalid_argument)
+ << "[Tensor]: reshape cannot change the buffer size, trying reshaping "
+ "\nfrom "
+ << getDim() << " to " << d;
dim = d;
strides = d.computeStrides();