From: jijoong.moon Date: Tue, 27 Oct 2020 11:01:52 +0000 (+0900) Subject: [ Layer ] Multiple Input Dimensions X-Git-Tag: submit/tizen/20201119.063013~49 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=16c816d46d32dba8610f13b480c1ac18b4b0212f;p=platform%2Fcore%2Fml%2Fnntrainer.git [ Layer ] Multiple Input Dimensions Current implementaion only take one input. In order to take multiple input, input_dim / output_dim should be vector type. This PR includes this fixes except about addition layer which requires actual multiple input. This will be done consequtive PR. **Self evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: jijoong.moon --- diff --git a/nnstreamer/tensor_filter/tensor_filter_nntrainer.cc b/nnstreamer/tensor_filter/tensor_filter_nntrainer.cc index 09c09741..cd61a269 100644 --- a/nnstreamer/tensor_filter/tensor_filter_nntrainer.cc +++ b/nnstreamer/tensor_filter/tensor_filter_nntrainer.cc @@ -118,9 +118,9 @@ void NNTrainer::validateTensor(const GstTensorsInfo *tensorInfo, unsigned int order[3] = {1, 3, 2}; if (is_input) - dim = model->getInputDimension(); + dim = model->getInputDimension()[0]; else - dim = model->getOutputDimension(); + dim = model->getOutputDimension()[0]; if (tensorInfo->info[0].type != _NNS_FLOAT32) throw std::invalid_argument( diff --git a/nntrainer/include/layer_internal.h b/nntrainer/include/layer_internal.h index 04f9c420..83080d89 100644 --- a/nntrainer/include/layer_internal.h +++ b/nntrainer/include/layer_internal.h @@ -83,7 +83,10 @@ public: trainable(trainable_), num_weights(0), num_inputs(1), - num_outputs(1) {} + num_outputs(1) { + input_dim.resize(1); + output_dim.resize(1); + } /** * @brief Move constructor of Layer. @@ -192,13 +195,13 @@ public: * @brief Get the output dimension * @return TensorDim dimension of the output */ - TensorDim getOutputDimension() { return output_dim; } + std::vector getOutputDimension() { return output_dim; } /** * @brief Get the input dimension * @return TensorDim dimension of the input */ - TensorDim getInputDimension() { return input_dim; } + std::vector getInputDimension() { return input_dim; } /** * @brief get the loss value added by this layer @@ -311,12 +314,12 @@ protected: /** * @brief Dimension of input activation */ - TensorDim input_dim; + std::vector input_dim; /** * @brief Dimension of output activation */ - TensorDim output_dim; + std::vector output_dim; /** * @brief Optimizer for this layer @@ -498,7 +501,7 @@ private: * @brief Set the input dimension * @param[in] d dimension to be set */ - void setInputDimension(TensorDim d) { input_dim = d; } + void setInputDimension(std::vector d) { input_dim = d; } }; /** diff --git a/nntrainer/include/neuralnet.h b/nntrainer/include/neuralnet.h index 56cb6cc4..c1c3149a 100644 --- a/nntrainer/include/neuralnet.h +++ b/nntrainer/include/neuralnet.h @@ -279,15 +279,19 @@ public: /* * @brief get input dimension of neural network - * @retval TensorDim input dimension + * @retval std::vector input dimension */ - TensorDim getInputDimension() { return layers[0]->getInputDimension(); } + std::vector getInputDimension() { + return layers[0]->getInputDimension(); + } /* * @brief get output dimension of neural network - * @retval TensorDim output dimension + * @retval std::vector output dimension */ - TensorDim getOutputDimension() { return layers.back()->getOutputDimension(); } + std::vector getOutputDimension() { + return layers.back()->getOutputDimension(); + } /** * @brief get FlatGraph of current graph diff --git a/nntrainer/include/tensor_dim.h b/nntrainer/include/tensor_dim.h index a996d22f..42e0429b 100644 --- a/nntrainer/include/tensor_dim.h +++ b/nntrainer/include/tensor_dim.h @@ -99,6 +99,8 @@ public: bool isEmpty() const { return len == 0; } unsigned int rank() const; + unsigned int &operator[](unsigned int index); + /** * @brief Calculate standard strides * diff --git a/nntrainer/src/addition_layer.cpp b/nntrainer/src/addition_layer.cpp index 85bef1eb..6ef73f07 100644 --- a/nntrainer/src/addition_layer.cpp +++ b/nntrainer/src/addition_layer.cpp @@ -27,25 +27,29 @@ int AdditionLayer::initialize() { return ML_ERROR_INVALID_PARAMETER; } - if (input_dim.getDataLen() == 1) { - ml_logw("Warning: the length of previous layer dimension is one"); + for (unsigned int idx = 0; idx < num_inputs; ++idx) { + if (input_dim[idx].getDataLen() == 1) { + ml_logw("Warning: the length of previous layer dimension is one"); + } } /** input dimension indicates the dimension for all the inputs to follow */ - output_dim = input_dim; + output_dim[0] = input_dim[0]; return status; } sharedConstTensors AdditionLayer::forwarding(sharedConstTensors in) { - hidden = Tensor(input_dim); + hidden = Tensor(input_dim[0]); hidden.setZero(); + TensorDim &in_dim = input_dim[0]; + for (unsigned int idx = 0; idx < num_inputs; ++idx) { - if (input_dim != in[0].get()[idx].getDim()) + if (in_dim != in[idx]->getDim()) throw std::runtime_error("Error: addition layer requires same " "shape from all input layers"); - hidden.add_i(in[0].get()[idx]); + hidden.add_i(*in[idx]); } return {MAKE_SHARED_TENSOR(hidden)}; @@ -53,15 +57,16 @@ sharedConstTensors AdditionLayer::forwarding(sharedConstTensors in) { sharedConstTensors AdditionLayer::backwarding(sharedConstTensors derivative, int iteration) { - sharedTensor ret = std::shared_ptr(new Tensor[num_inputs], - std::default_delete()); - for (unsigned int idx = 0; idx < num_inputs; ++idx) { - Tensor &t = ret.get()[idx]; - t = *derivative[0]; + sharedConstTensors ret; + for (unsigned int i = 0; i < num_inputs; ++i) { + sharedTensor t = + std::shared_ptr(new Tensor(), std::default_delete()); + *t = *derivative[0]; + ret.push_back(t); } - return {ret}; + return ret; } void AdditionLayer::setProperty(const PropertyType type, diff --git a/nntrainer/src/bn_layer.cpp b/nntrainer/src/bn_layer.cpp index d1b19073..d5075e3e 100644 --- a/nntrainer/src/bn_layer.cpp +++ b/nntrainer/src/bn_layer.cpp @@ -38,14 +38,19 @@ enum class BNParams { mu, var, gamma, beta }; int BatchNormalizationLayer::initialize() { int status = ML_ERROR_NONE; - output_dim = input_dim; + if (num_inputs != 1) { + throw std::invalid_argument( + "Only one input is allowed for batch normalization layer"); + } + + output_dim[0] = input_dim[0]; TensorDim dim; /// @note this logic cannot tell channel is actually 1 or it is just not used. if (axis == -1) - axis = input_dim.channel() > 1 ? 1 : 3; + axis = input_dim[0].channel() > 1 ? 1 : 3; - dim.setTensorDim(axis, input_dim.getTensorDim(axis)); + dim.setTensorDim(axis, input_dim[0].getTensorDim(axis)); for (int i = 0; i < 4; ++i) { if (axis != i) @@ -161,7 +166,7 @@ BatchNormalizationLayer::backwarding(sharedConstTensors derivative, int N = 1; for (auto &axis : axes_to_reduce) { - N *= input_dim.getTensorDim(axis); + N *= input_dim[0].getTensorDim(axis); } dbeta = deriv.sum(axes_to_reduce); diff --git a/nntrainer/src/conv2d_layer.cpp b/nntrainer/src/conv2d_layer.cpp index 436cb92b..2bcc149f 100644 --- a/nntrainer/src/conv2d_layer.cpp +++ b/nntrainer/src/conv2d_layer.cpp @@ -27,12 +27,19 @@ namespace nntrainer { int Conv2DLayer::initialize() { int status = ML_ERROR_NONE; - if (input_dim.getDataLen() == 1) { + if (input_dim.size() != 1 || output_dim.size() != 1) { + throw std::invalid_argument("Convolution layer only takes one input"); + } + + TensorDim &in_dim = input_dim[0]; + TensorDim &out_dim = output_dim[0]; + + if (in_dim.getDataLen() == 1) { ml_logw("Warning: the length of previous layer dimension is one"); } TensorDim dim = - TensorDim(1, input_dim.channel(), kernel_size[0], kernel_size[1]); + TensorDim(1, in_dim.channel(), kernel_size[0], kernel_size[1]); TensorDim bias_dim = TensorDim(1, 1, 1, 1); std::string kernelPrefix = "Conv2d:filter"; @@ -50,12 +57,12 @@ int Conv2DLayer::initialize() { } // this output_dim should be the same with dimension of hidden - output_dim.batch(input_dim.batch()); - output_dim.channel(filter_size); - output_dim.height( - (input_dim.height() - kernel_size[0] + 2 * padding[0]) / stride[0] + 1); - output_dim.width( - (input_dim.width() - kernel_size[1] + 2 * padding[1]) / stride[1] + 1); + out_dim.batch(in_dim.batch()); + out_dim.channel(filter_size); + out_dim.height( + (in_dim.height() - kernel_size[0] + 2 * padding[0]) / stride[0] + 1); + out_dim.width((in_dim.width() - kernel_size[1] + 2 * padding[1]) / stride[1] + + 1); return status; } @@ -66,8 +73,15 @@ void Conv2DLayer::save(std::ofstream &file) { Layer::save(file); } sharedConstTensors Conv2DLayer::forwarding(sharedConstTensors in) { int status = ML_ERROR_NONE; + + if (num_inputs != 1) + throw std::invalid_argument("Convolution layer only takes one input"); + input = *in[0]; + TensorDim &in_dim = input_dim[0]; + TensorDim &out_dim = output_dim[0]; + if (normalization) { input = input.normalization(); } @@ -76,7 +90,7 @@ sharedConstTensors Conv2DLayer::forwarding(sharedConstTensors in) { input = input.standardization(); } - TensorDim hidden_dim = output_dim; + TensorDim hidden_dim = output_dim[0]; hidden = Tensor(hidden_dim); hidden.setZero(); @@ -117,8 +131,7 @@ sharedConstTensors Conv2DLayer::forwarding(sharedConstTensors in) { * x [output_dim.height x output_dim.width] */ - TensorDim kdim(filter_size, input_dim.channel(), kernel_size[0], - kernel_size[1]); + TensorDim kdim(filter_size, in_dim.channel(), kernel_size[0], kernel_size[1]); std::vector imkernel(kdim.getFeatureLen() * filter_size); @@ -129,13 +142,13 @@ sharedConstTensors Conv2DLayer::forwarding(sharedConstTensors in) { kdim.getFeatureLen() * sizeof(float)); } - for (unsigned int b = 0; b < input_dim.batch(); ++b) { - std::vector out(output_dim.getFeatureLen()); + for (unsigned int b = 0; b < in_dim.batch(); ++b) { + std::vector out(out_dim.getFeatureLen()); Tensor inSub(TensorDim(1, input.channel(), input.height(), input.width()), input.getAddress(b * input.getDim().getFeatureLen())); - status = conv2d_gemm(imkernel.data(), kdim, inSub, output_dim, stride, - padding, out.data(), out.size(), true); + status = conv2d_gemm(imkernel.data(), kdim, inSub, out_dim, stride, padding, + out.data(), out.size(), true); if (status != ML_ERROR_NONE) throw std::runtime_error("Forwarding Convolution failed."); memcpy(hidden.getAddress(b * hidden.getDim().getFeatureLen()), out.data(), @@ -167,6 +180,8 @@ sharedConstTensors Conv2DLayer::forwarding(sharedConstTensors in) { sharedConstTensors Conv2DLayer::backwarding(sharedConstTensors derivatives, int iteration) { + TensorDim &in_dim = input_dim[0]; + std::array same_pad; sharedConstTensor derivative = derivatives[0]; @@ -180,9 +195,8 @@ sharedConstTensors Conv2DLayer::backwarding(sharedConstTensors derivatives, delBias.setZero(); } - Tensor ret(input_dim.batch(), input_dim.channel(), - input_dim.height() + padding[0] * 2, - input_dim.width() + padding[1] * 2); + Tensor ret(in_dim.batch(), in_dim.channel(), in_dim.height() + padding[0] * 2, + in_dim.width() + padding[1] * 2); ret.setZero(); /** Calculate DelK @@ -221,12 +235,12 @@ sharedConstTensors Conv2DLayer::backwarding(sharedConstTensors derivatives, */ int status = ML_ERROR_NONE; - for (unsigned int b = 0; b < input_dim.batch(); ++b) { - std::vector out(kernel_size[0] * kernel_size[1] * - input_dim.channel() * filter_size); + for (unsigned int b = 0; b < in_dim.batch(); ++b) { + std::vector out(kernel_size[0] * kernel_size[1] * in_dim.channel() * + filter_size); Tensor inSub( - TensorDim(1, input_dim.channel(), input_dim.height(), input_dim.width()), + TensorDim(1, in_dim.channel(), in_dim.height(), in_dim.width()), input.getAddress(b * input.getDim().getFeatureLen())); status = conv2d_gemm( @@ -235,7 +249,7 @@ sharedConstTensors Conv2DLayer::backwarding(sharedConstTensors derivatives, derivative->width()), inSub, TensorDim(1, 1, filter_size, - kernel_size[0] * kernel_size[1] * input_dim.channel()), + kernel_size[0] * kernel_size[1] * in_dim.channel()), stride, padding, out.data(), out.size(), false); if (status != ML_ERROR_NONE) throw std::runtime_error("Backwarding Convolution failed."); @@ -244,7 +258,7 @@ sharedConstTensors Conv2DLayer::backwarding(sharedConstTensors derivatives, Tensor &delK = weightAt(i).getGradientRef(); Tensor &delBias = weightAt(i + filter_size).getGradientRef(); float *del = delK.getData(); - unsigned int s = kernel_size[0] * kernel_size[1] * input_dim.channel(); + unsigned int s = kernel_size[0] * kernel_size[1] * in_dim.channel(); for (unsigned int k = 0; k < s; ++k) { del[k] += out[i * s + k]; @@ -318,11 +332,11 @@ sharedConstTensors Conv2DLayer::backwarding(sharedConstTensors derivatives, } } - TensorDim input_dim_padded(1, input_dim.channel(), - input_dim.height() + padding[0] * 2, - input_dim.width() + padding[1] * 2); + TensorDim input_dim_padded(1, in_dim.channel(), + in_dim.height() + padding[0] * 2, + in_dim.width() + padding[1] * 2); - for (unsigned int b = 0; b < input_dim.batch(); ++b) { + for (unsigned int b = 0; b < in_dim.batch(); ++b) { Tensor inSub( TensorDim(1, derivative->channel(), derivative->height(), derivative->width()), diff --git a/nntrainer/src/fc_layer.cpp b/nntrainer/src/fc_layer.cpp index 85e5f9f2..e194980f 100644 --- a/nntrainer/src/fc_layer.cpp +++ b/nntrainer/src/fc_layer.cpp @@ -36,14 +36,18 @@ enum class FCParams { weight, bias }; int FullyConnectedLayer::initialize() { int status = ML_ERROR_NONE; - output_dim = input_dim; - output_dim.width(unit); + if (num_inputs != 1) { + throw std::invalid_argument("Fully connected layer takes only one input"); + } + + output_dim[0] = input_dim[0]; + output_dim[0].width(unit); TensorDim bias_dim = TensorDim(); bias_dim.setTensorDim(3, unit); - TensorDim dim = output_dim; - dim.height(input_dim.width()); + TensorDim dim = output_dim[0]; + dim.height(input_dim[0].width()); dim.batch(1); setNumWeights(2); @@ -61,7 +65,7 @@ void FullyConnectedLayer::setProperty(const PropertyType type, if (!value.empty()) { status = setUint(unit, value); throw_status(status); - output_dim.width(unit); + output_dim[0].width(unit); } } break; default: diff --git a/nntrainer/src/flatten_layer.cpp b/nntrainer/src/flatten_layer.cpp index bfd9345d..c0f794be 100644 --- a/nntrainer/src/flatten_layer.cpp +++ b/nntrainer/src/flatten_layer.cpp @@ -21,15 +21,20 @@ namespace nntrainer { int FlattenLayer::initialize() { + if (num_inputs != 1) { + throw std::invalid_argument("input_shape keyword is only for one input"); + } + + TensorDim &out_dim = output_dim[0]; int status = ML_ERROR_NONE; - if (input_dim.getDataLen() == 1) { + if (input_dim[0].getDataLen() == 1) { ml_logw("Warning: the length of previous layer dimension is one"); } - output_dim.batch(input_dim.batch()); - output_dim.channel(1); - output_dim.height(1); - output_dim.width(input_dim.getFeatureLen()); + out_dim.batch(input_dim[0].batch()); + out_dim.channel(1); + out_dim.height(1); + out_dim.width(input_dim[0].getFeatureLen()); return status; } @@ -38,7 +43,7 @@ sharedConstTensors FlattenLayer::forwarding(sharedConstTensors in) { input = *in[0]; hidden = input; - hidden.reshape(output_dim); + hidden.reshape(output_dim[0]); return {MAKE_SHARED_TENSOR(hidden)}; } @@ -46,7 +51,7 @@ sharedConstTensors FlattenLayer::forwarding(sharedConstTensors in) { sharedConstTensors FlattenLayer::backwarding(sharedConstTensors in, int iteration) { Tensor temp = *in[0]; - temp.reshape(input_dim); + temp.reshape(input_dim[0]); return {MAKE_SHARED_TENSOR(std::move(temp))}; } diff --git a/nntrainer/src/layer.cpp b/nntrainer/src/layer.cpp index 4438e0ad..9280da19 100644 --- a/nntrainer/src/layer.cpp +++ b/nntrainer/src/layer.cpp @@ -58,8 +58,11 @@ int Layer::checkValidation() { } void Layer::setBatch(unsigned int batch) { - input_dim.setTensorDim(0, batch); - output_dim.setTensorDim(0, batch); + for (unsigned int idx = 0; idx < num_inputs; ++idx) + input_dim[idx].setTensorDim(0, batch); + + for (unsigned int idx = 0; idx < num_outputs; ++idx) + output_dim[idx].setTensorDim(0, batch); } void Layer::copy(std::shared_ptr l) { @@ -134,25 +137,30 @@ void Layer::setProperty(const PropertyType type, const std::string &value) { throw_status(status); } break; - case PropertyType::input_shape: + case PropertyType::input_shape: { + if (num_inputs != 1) { + throw std::invalid_argument("input_shape keyword is only for one input"); + } + + TensorDim &in_dim = input_dim[0]; if (!value.empty()) { unsigned int cache_batch_size = 1; /** cache original value of batch size */ - if (input_dim.batch()) { - cache_batch_size = input_dim.batch(); - input_dim.batch(1); + if (in_dim.batch()) { + cache_batch_size = in_dim.batch(); + in_dim.batch(1); } - status = input_dim.setTensorDim(value.c_str()); - if (input_dim.batch() > 1) { + status = in_dim.setTensorDim(value.c_str()); + if (in_dim.batch() > 1) { ml_logw("Batch size set with input dimension %d is ignored." "Set batchsize property for the model to update batchsize.", - input_dim.batch()); + in_dim.batch()); } /** set back to cache value of dimension */ - input_dim.batch(cache_batch_size); + in_dim.batch(cache_batch_size); throw_status(status); } - break; + } break; case PropertyType::activation: if (!value.empty()) { setActivation((ActivationType)parseType(value, TOKEN_ACTI)); @@ -219,10 +227,14 @@ void Layer::printIfValid(std::ostream &out, const PropertyType type, } void Layer::printShapeInfo(std::ostream &out) { - out << "input " << input_dim; - for (unsigned int i = 0; i < num_weights; i++) - out << "inner" << i << " " << weightAt(i).var.getDim(); - out << "output " << output_dim; + for (unsigned int idx = 0; idx < num_inputs; ++idx) { + out << "input " << input_dim[idx]; + for (unsigned int i = 0; i < num_weights; i++) + out << "inner" << i << " " << weightAt(i).var.getDim(); + } + for (unsigned int idx = 0; idx < num_outputs; ++idx) { + out << "output " << output_dim[idx]; + } } void Layer::printPropertiesMeta(std::ostream &out) { diff --git a/nntrainer/src/neuralnet.cpp b/nntrainer/src/neuralnet.cpp index 75856912..37c53b66 100644 --- a/nntrainer/src/neuralnet.cpp +++ b/nntrainer/src/neuralnet.cpp @@ -184,7 +184,7 @@ int NeuralNetwork::setTrainConfig(std::vector values) { int NeuralNetwork::init() { int status = ML_ERROR_NONE; - TensorDim previous_dim; + std::vector previous_dim; status = isInitializable(); NN_RETURN_STATUS(); @@ -203,7 +203,7 @@ int NeuralNetwork::init() { ml_loge("double activation is not allowed"); return ML_ERROR_INVALID_PARAMETER; } - if (l.getInputDimension().isEmpty()) { + if (l.getInputDimension().size()) { l.setInputDimension(previous_dim); } else if (previous_dim != l.getInputDimension()) { ml_loge("Dimension mismatch between layers."); @@ -412,10 +412,10 @@ int NeuralNetwork::train(std::vector values) { setBatchSize(batch_size); /** Setup data buffer properties */ - status = data_buffer->setClassNum(getOutputDimension().width()); + status = data_buffer->setClassNum(getOutputDimension()[0].width()); NN_RETURN_STATUS(); - status = data_buffer->setFeatureSize(layers[0]->getInputDimension()); + status = data_buffer->setFeatureSize(layers[0]->getInputDimension()[0]); NN_RETURN_STATUS(); status = data_buffer->init(); @@ -448,8 +448,8 @@ int NeuralNetwork::train_run() { int count = 0; - sharedTensor in = MAKE_SHARED_TENSOR(getInputDimension()); - sharedTensor label = MAKE_SHARED_TENSOR(getOutputDimension()); + sharedTensor in = MAKE_SHARED_TENSOR(getInputDimension()[0]); + sharedTensor label = MAKE_SHARED_TENSOR(getOutputDimension()[0]); while (true) { if (data_buffer->getDataFromBuffer(nntrainer::BufferType::BUF_TRAIN, @@ -534,7 +534,7 @@ int NeuralNetwork::isInitializable() { Layer &l = *layers[0]; /** Dimension of first layer must be known */ - if (l.getInputDimension().isEmpty()) { + if (l.getInputDimension().size() == 0) { ml_loge("InputDimension of first layer is not set"); return ML_ERROR_INVALID_PARAMETER; } diff --git a/nntrainer/src/pooling2d_layer.cpp b/nntrainer/src/pooling2d_layer.cpp index 8d81efcb..e9edec08 100644 --- a/nntrainer/src/pooling2d_layer.cpp +++ b/nntrainer/src/pooling2d_layer.cpp @@ -26,30 +26,38 @@ namespace nntrainer { int Pooling2DLayer::initialize() { int status = ML_ERROR_NONE; - if (input_dim.getDataLen() == 1) { + + if (input_dim.size() != 1 || output_dim.size() != 1) { + throw std::invalid_argument("Convolution layer only takes one input"); + } + + TensorDim &in_dim = input_dim[0]; + TensorDim &out_dim = output_dim[0]; + + if (in_dim.getDataLen() == 1) { ml_logw("Warning: the length of previous layer dimension is one"); } - output_dim.batch(input_dim.batch()); - output_dim.channel(input_dim.channel()); + out_dim.batch(in_dim.batch()); + out_dim.channel(in_dim.channel()); if (pooling_type == PoolingType::max || pooling_type == PoolingType::average) { - output_dim.height( - (input_dim.height() - pool_size[0] + 2 * padding[0]) / stride[0] + 1); - output_dim.width( - (input_dim.width() - pool_size[1] + 2 * padding[1]) / stride[1] + 1); + out_dim.height( + (in_dim.height() - pool_size[0] + 2 * padding[0]) / stride[0] + 1); + out_dim.width((in_dim.width() - pool_size[1] + 2 * padding[1]) / stride[1] + + 1); } else { - output_dim.height(1); - output_dim.width(1); + out_dim.height(1); + out_dim.width(1); } if (pooling_type == PoolingType::max) { - max_idx.resize(output_dim.getDataLen()); + max_idx.resize(out_dim.getDataLen()); } if (pooling_type == PoolingType::global_max) { - max_idx_global.resize(output_dim.getDataLen()); + max_idx_global.resize(out_dim.getDataLen()); } return status; @@ -58,11 +66,13 @@ int Pooling2DLayer::initialize() { sharedConstTensors Pooling2DLayer::forwarding(sharedConstTensors in) { input = *in[0]; - TensorDim hidden_dim = output_dim; + TensorDim hidden_dim = output_dim[0]; + TensorDim &in_dim = input_dim[0]; + hidden = Tensor(hidden_dim); hidden.setZero(); - for (unsigned int b = 0; b < input_dim.batch(); ++b) { + for (unsigned int b = 0; b < in_dim.batch(); ++b) { Tensor in_padded = zero_pad(b, input, padding.data()); Tensor result = pooling2d(b, in_padded); memcpy(hidden.getAddress(b * hidden.getDim().getFeatureLen()), @@ -74,16 +84,16 @@ sharedConstTensors Pooling2DLayer::forwarding(sharedConstTensors in) { sharedConstTensors Pooling2DLayer::backwarding(sharedConstTensors derivative, int iteration) { - unsigned int batch = input_dim.batch(); - unsigned int channel = input_dim.channel(); - unsigned int height = input_dim.height(); - unsigned int width = input_dim.width(); + unsigned int batch = input_dim[0].batch(); + unsigned int channel = input_dim[0].channel(); + unsigned int height = input_dim[0].height(); + unsigned int width = input_dim[0].width(); unsigned int p_height = pool_size[0]; unsigned int p_width = pool_size[1]; unsigned int p_size = p_height * p_width; unsigned int J, K; - Tensor result = Tensor(input_dim); + Tensor result = Tensor(input_dim[0]); result.setZero(); float *out = result.getData(); switch (pooling_type) { @@ -172,9 +182,9 @@ void Pooling2DLayer::setBatch(unsigned int batch) { Layer::setBatch(batch); if (pooling_type == PoolingType::max) { - max_idx.resize(output_dim.getDataLen()); + max_idx.resize(output_dim[0].getDataLen()); } else if (pooling_type == PoolingType::global_max) { - max_idx_global.resize(output_dim.getDataLen()); + max_idx_global.resize(output_dim[0].getDataLen()); } } @@ -248,9 +258,10 @@ Tensor Pooling2DLayer::pooling2d(unsigned int batch, Tensor &in) { unsigned int width = in.width(); unsigned int p_height = pool_size[0]; unsigned int p_width = pool_size[1]; - unsigned int base_idx = batch * output_dim.getFeatureLen(); + TensorDim &out_dim = output_dim[0]; + unsigned int base_idx = batch * out_dim.getFeatureLen(); - Tensor output(output_dim.channel(), output_dim.height(), output_dim.width()); + Tensor output(out_dim.channel(), out_dim.height(), out_dim.width()); unsigned int J, K; switch (pooling_type) { @@ -265,10 +276,9 @@ Tensor Pooling2DLayer::pooling2d(unsigned int batch, Tensor &in) { for (unsigned int pj = 0; pj < p_width; ++pj) { float val = in.getValue(0, i, j + pi, k + pj); if (max < val) { - max_idx[base_idx + - i * output_dim.height() * output_dim.width() + - J * output_dim.width() + K] = - batch * input_dim.getFeatureLen() + i * height * width + + max_idx[base_idx + i * out_dim.height() * out_dim.width() + + J * out_dim.width() + K] = + batch * input_dim[0].getFeatureLen() + i * height * width + (j + pi) * width + (k + pj); max = val; } @@ -305,7 +315,8 @@ Tensor Pooling2DLayer::pooling2d(unsigned int batch, Tensor &in) { case PoolingType::global_max: { output.setZero(); for (unsigned int i = 0; i < channel; ++i) { - unsigned int idx = batch * input_dim.getFeatureLen() + i * height * width; + unsigned int idx = + batch * input_dim[0].getFeatureLen() + i * height * width; float max = std::numeric_limits::lowest(); max_idx_global[base_idx + i].clear(); for (unsigned int j = 0; j < height; ++j) { diff --git a/nntrainer/src/tensor_dim.cpp b/nntrainer/src/tensor_dim.cpp index 7ce6cae6..236e9b82 100644 --- a/nntrainer/src/tensor_dim.cpp +++ b/nntrainer/src/tensor_dim.cpp @@ -114,6 +114,13 @@ unsigned int TensorDim::rank() const { return rank; } +unsigned int &TensorDim::operator[](unsigned int index) { + if (index >= MAXDIM) + throw std::out_of_range( + "[TensorDim] Tensor Dimension index should be between 0 and 4"); + return dim[index]; +} + std::ostream &operator<<(std::ostream &out, TensorDim const &d) { out << "Shape: " << d.batch() << ":" << d.channel() << ":" << d.height() << ":" << d.width() << std::endl; diff --git a/test/unittest/unittest_nntrainer_layers.cpp b/test/unittest/unittest_nntrainer_layers.cpp index 44933b68..3b92bde8 100644 --- a/test/unittest/unittest_nntrainer_layers.cpp +++ b/test/unittest/unittest_nntrainer_layers.cpp @@ -54,8 +54,8 @@ protected: virtual int reinitialize() { int status = layer.initialize(); EXPECT_EQ(status, ML_ERROR_NONE); - in = nntrainer::Tensor(layer.getInputDimension()); - out = nntrainer::Tensor(layer.getOutputDimension()); + in = nntrainer::Tensor(layer.getInputDimension()[0]); + out = nntrainer::Tensor(layer.getOutputDimension()[0]); return status; } @@ -209,7 +209,7 @@ TEST_F(nntrainer_InputLayer, set_property_02_p) { int status = setProperty("input_shape=3:2:1"); EXPECT_EQ(status, ML_ERROR_NONE); - dim = layer.getInputDimension(); + dim = layer.getInputDimension()[0]; EXPECT_EQ(dim.getTensorDim(0), 1); EXPECT_EQ(dim.getTensorDim(1), 3); EXPECT_EQ(dim.getTensorDim(2), 2); @@ -221,7 +221,7 @@ TEST_F(nntrainer_InputLayer, set_property_03_p) { int status = setProperty("input_shape=1:3:2:1"); EXPECT_EQ(status, ML_ERROR_NONE); - dim = layer.getInputDimension(); + dim = layer.getInputDimension()[0]; EXPECT_EQ(dim.getTensorDim(0), 1); EXPECT_EQ(dim.getTensorDim(1), 3); EXPECT_EQ(dim.getTensorDim(2), 2); @@ -234,7 +234,7 @@ TEST_F(nntrainer_InputLayer, set_property_04_p) { EXPECT_EQ(status, ML_ERROR_NONE); /** Set input shape ignores batch size */ - dim = layer.getInputDimension(); + dim = layer.getInputDimension()[0]; EXPECT_EQ(dim.getTensorDim(0), 1); EXPECT_EQ(dim.getTensorDim(1), 3); EXPECT_EQ(dim.getTensorDim(2), 2); @@ -248,7 +248,7 @@ TEST_F(nntrainer_InputLayer, set_property_05_p) { setBatch(5); EXPECT_EQ(status, ML_ERROR_NONE); - dim = layer.getInputDimension(); + dim = layer.getInputDimension()[0]; EXPECT_EQ(dim.getTensorDim(0), 5); EXPECT_EQ(dim.getTensorDim(1), 3); EXPECT_EQ(dim.getTensorDim(2), 28); @@ -258,7 +258,7 @@ TEST_F(nntrainer_InputLayer, set_property_05_p) { status = setProperty("input_shape=1:3:2:1"); EXPECT_EQ(status, ML_ERROR_NONE); - dim = layer.getInputDimension(); + dim = layer.getInputDimension()[0]; EXPECT_EQ(dim.getTensorDim(0), 5); EXPECT_EQ(dim.getTensorDim(1), 3); EXPECT_EQ(dim.getTensorDim(2), 2); @@ -268,7 +268,7 @@ TEST_F(nntrainer_InputLayer, set_property_05_p) { status = setProperty("input_shape=4:3:2:1"); EXPECT_EQ(status, ML_ERROR_NONE); - dim = layer.getInputDimension(); + dim = layer.getInputDimension()[0]; EXPECT_EQ(dim.getTensorDim(0), 5); EXPECT_EQ(dim.getTensorDim(1), 3); EXPECT_EQ(dim.getTensorDim(2), 2); @@ -443,7 +443,8 @@ protected: virtual int reinitialize() { int status = super::reinitialize(); - label = MAKE_SHARED_TENSOR(nntrainer::Tensor(layer.getOutputDimension())); + label = + MAKE_SHARED_TENSOR(nntrainer::Tensor(layer.getOutputDimension()[0])); loadFile("tc_fc_1_FCLayer.in", in); loadFile("tc_fc_1_FCKernel.in", layer); @@ -458,7 +459,7 @@ protected: std::make_shared(type); status = act_layer->setProperty( - {"input_shape=" + getDimensionString(layer.getOutputDimension())}); + {"input_shape=" + getDimensionString(layer.getOutputDimension()[0])}); EXPECT_EQ(status, ML_ERROR_NONE); status = act_layer->initialize(); @@ -471,7 +472,7 @@ protected: std::make_shared(); status = loss_layer->setProperty( - {"input_shape=" + getDimensionString(layer.getOutputDimension())}); + {"input_shape=" + getDimensionString(layer.getOutputDimension()[0])}); EXPECT_EQ(status, ML_ERROR_NONE); status = loss_layer->initialize(); @@ -880,7 +881,7 @@ TEST_F(nntrainer_BatchNormalizationLayer, forward_backward_training_01_p) { layer.forwarding({MAKE_SHARED_TENSOR(in)})[0]); matchOutput(*forward_result, "tc_bn_fc_1_goldenBNResultForward.out"); - nntrainer::Tensor backward_in(layer.getOutputDimension()); + nntrainer::Tensor backward_in(layer.getOutputDimension()[0]); loadFile("tc_bn_fc_1_goldenBNLayerBackwardDxIn.out", backward_in); nntrainer::Tensor backward_result = @@ -915,7 +916,7 @@ TEST_F(nntrainer_BatchNormalizationLayer_Conv, forward_backward_training_01_p) { forward_result = layer.forwarding({MAKE_SHARED_TENSOR(in)})[0]; matchOutput(*forward_result, "tc_bn_conv_1_goldenBNResultForward.out"); - nntrainer::Tensor backward_in(layer.getOutputDimension()); + nntrainer::Tensor backward_in(layer.getOutputDimension()[0]); loadFile("tc_bn_conv_1_goldenBNLayerBackwardDxIn.out", backward_in); nntrainer::Tensor backward_result = @@ -951,7 +952,7 @@ TEST_F(nntrainer_BatchNormalizationLayer_Conv2, forward_result = layer.forwarding({MAKE_SHARED_TENSOR(in)})[0]; matchOutput(*forward_result, "tc_bn_conv_2_goldenBNResultForward.out"); - nntrainer::Tensor backward_in(layer.getOutputDimension()); + nntrainer::Tensor backward_in(layer.getOutputDimension()[0]); loadFile("tc_bn_conv_2_goldenBNLayerBackwardDxIn.out", backward_in); nntrainer::Tensor backward_result = @@ -1286,7 +1287,7 @@ TEST_F(nntrainer_Conv2DLayer, backwarding_03_p) { EXPECT_EQ(status, ML_ERROR_NONE); layer2.setBatch(1); status = layer2.setProperty( - {"input_shape=" + getDimensionString(layer1.getOutputDimension())}); + {"input_shape=" + getDimensionString(layer1.getOutputDimension()[0])}); EXPECT_EQ(status, ML_ERROR_NONE); status = layer2.initialize(); EXPECT_EQ(status, ML_ERROR_NONE); @@ -1814,9 +1815,9 @@ class nntrainer_AdditionLayer : public nntrainer_abstractLayer { protected: virtual void prepareLayer() { + setProperty("num_inputs=1"); setInputDim("3:28:28"); setBatch(32); - setProperty("num_inputs=1"); } }; @@ -1860,25 +1861,29 @@ TEST_F(nntrainer_AdditionLayer, forwarding_01_n) { EXPECT_THROW(layer.forwarding({input}), std::runtime_error); } -TEST_F(nntrainer_AdditionLayer, forwarding_02_n) { +/* + *Disabled until input_layer keyward is enabled. + */ + +TEST_F(nntrainer_AdditionLayer, DISABLED_forwarding_02_n) { setProperty("num_inputs=2"); sharedTensor input = std::shared_ptr( new nntrainer::Tensor[1], std::default_delete()); nntrainer::Tensor &in = *input; - in = nntrainer::Tensor(layer.getInputDimension()); + in = nntrainer::Tensor(layer.getInputDimension()[0]); EXPECT_THROW(layer.forwarding({input}), std::runtime_error); } -TEST_F(nntrainer_AdditionLayer, forwarding_03_p) { +TEST_F(nntrainer_AdditionLayer, DISABLED_forwarding_03_p) { setProperty("num_inputs=2"); sharedTensor input = std::shared_ptr( new nntrainer::Tensor[2], std::default_delete()); nntrainer::Tensor &in = *input; - in = nntrainer::Tensor(layer.getInputDimension()); + in = nntrainer::Tensor(layer.getInputDimension()[0]); input.get()[1] = *input; diff --git a/test/unittest/unittest_nntrainer_models.cpp b/test/unittest/unittest_nntrainer_models.cpp index f79daf61..81566580 100644 --- a/test/unittest/unittest_nntrainer_models.cpp +++ b/test/unittest/unittest_nntrainer_models.cpp @@ -69,15 +69,15 @@ public: unsigned int num_weights = node->getNumWeights(); node->setTrainable(false); - expected_input = nntrainer::Tensor(node->getInputDimension()); + expected_input = nntrainer::Tensor(node->getInputDimension()[0]); for (unsigned int i = 0; i < num_weights; ++i) { const nntrainer::Weight &w = node->weightAt(i); expected_weights.push_back(w); } - expected_output = nntrainer::Tensor(node->getOutputDimension()); - expected_dx = nntrainer::Tensor(node->getInputDimension()); + expected_output = nntrainer::Tensor(node->getOutputDimension()[0]); + expected_dx = nntrainer::Tensor(node->getInputDimension()[0]); } /** @@ -282,7 +282,7 @@ void GraphWatcher::compareFor(const std::string &reference, throw std::runtime_error("ref is bad!"); } - nntrainer::Tensor in(nn.getInputDimension()); + nntrainer::Tensor in(nn.getInputDimension()[0]); nntrainer::Tensor lb(label_shape); in.read(ref);