From 5ae735516bb66183039338e3b9c1eefb5f52d775 Mon Sep 17 00:00:00 2001 From: Parichay Kapoor Date: Tue, 19 Jan 2021 23:15:20 +0900 Subject: [PATCH] [inference] Add input validation for inference Add input validation for inference of the neural network Signed-off-by: Parichay Kapoor --- nntrainer/models/neuralnet.cpp | 32 +++++++++++++++++++++++++++++++- nntrainer/models/neuralnet.h | 7 +++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/nntrainer/models/neuralnet.cpp b/nntrainer/models/neuralnet.cpp index 51cc04c..16e8228 100644 --- a/nntrainer/models/neuralnet.cpp +++ b/nntrainer/models/neuralnet.cpp @@ -438,6 +438,33 @@ void NeuralNetwork::setBatchSize(unsigned int batch) { throw std::invalid_argument("Error setting batchsize for the dataset"); } +bool NeuralNetwork::validateInput(sharedConstTensors X) { + + auto &first_layer = model_graph.getSortedLayerNode(0).layer; + auto input_dim = first_layer->getInputDimension(); + if (X.size() != input_dim.size()) { + ml_loge("Error: provided number of inputs %d, required %d", (int)X.size(), + (int)input_dim.size()); + return false; + } + + for (unsigned int dim = 0; dim < input_dim.size(); dim++) { + if (input_dim[dim] != X[dim]->getDim()) { + ml_loge("Error: provided input shape does not match required shape"); + std::stringstream ss; + ss << X[dim]->getDim(); + ml_loge("Provided tensor summary : %s", ss.str().c_str()); + + ss.str(std::string()); + ss << input_dim[dim]; + ml_loge("Required tensor summary : %s", ss.str().c_str()); + return false; + } + } + + return true; +} + sharedConstTensors NeuralNetwork::inference(sharedConstTensors X) { if (batch_size != X[0]->batch()) { /** @@ -447,9 +474,12 @@ sharedConstTensors NeuralNetwork::inference(sharedConstTensors X) { setBatchSize(X[0]->batch()); } + sharedConstTensors out; + if (!validateInput(X)) + return out; + assignMem(false); - sharedConstTensors out; try { START_PROFILE(profile::NN_FORWARD); forwarding(X, {}, false); diff --git a/nntrainer/models/neuralnet.h b/nntrainer/models/neuralnet.h index a360b02..28d8913 100644 --- a/nntrainer/models/neuralnet.h +++ b/nntrainer/models/neuralnet.h @@ -636,6 +636,13 @@ private: * @param path path to set as a save path */ void setSavePath(const std::string &path); + + /** + * @brief Match the given tensor shape with input shape of the model + * @param[in] X input tensor + * @retval true if matches, false is error + */ + bool validateInput(sharedConstTensors X); }; } /* namespace nntrainer */ -- 2.7.4