return cnnNetwork->getBatchSize();
}
auto params = _ngraph_function->get_parameters();
- if (params.empty() || !params[0]->get_partial_shape().is_static()) return 0;
-
- auto shape = _ngraph_function->get_parameters()[0]->get_shape();
-
- // WA: for speech recognition layouts (copy-past from CNNNetwork)
- if (shape.size() == 3 || shape.size() == 1) {
- return 1;
+ for (const auto& param : params) {
+ if (param->get_partial_shape().is_dynamic())
+ continue;
+ auto shape = param->get_shape();
+ // WA: for speech recognition and scalar layouts (copy-past from CNNNetwork)
+ if (!shape.empty() && shape.size() != 3 && shape.size() != 1)
+ return shape[0];
}
- return shape[0];
+ return 1;
}
std::shared_ptr<ngraph::Function> CNNNetworkNGraphImpl::cloneFunction(bool constFolding) const {
SizeVector dims = _inputData.cbegin()->second->getTensorDesc().getDims();
- // 3D input layout doesn't have batch notation
- if (dims.size() == 3 || dims.size() == 1) {
- return DescriptionBuffer(PARAMETER_MISMATCH, responseDesc) << "Cannot set batch for 1D/3D input";
+ // 3D/1D/0D input layouts don't have batch notation
+ if (dims.size() == 3 || dims.size() == 1 || dims.empty()) {
+ return DescriptionBuffer(PARAMETER_MISMATCH, responseDesc) << "Cannot set batch for 0D/1D/3D input";
}
const std::map<CNNLayer*, bool> layersMap = getConstLayersMap(*this);
ASSERT_EQ(nullptr, cnnNet.getFunction());
}
+TEST(CNNNGraphImplTests, TestGetBatchScalar) {
+ std::shared_ptr<ngraph::Function> ngraph;
+ {
+ ngraph::Shape shape({});
+ ngraph::element::Type type(ngraph::element::Type_t::f32);
+ auto param = std::make_shared<ngraph::op::Parameter>(type, shape);
+ auto relu = std::make_shared<ngraph::op::Relu>(param);
+ auto result = std::make_shared<ngraph::op::Result>(relu);
+
+ ngraph::ParameterVector params = {param};
+ ngraph::ResultVector results = {result};
+
+ ngraph = std::make_shared<ngraph::Function>(results, params);
+ }
+
+ InferenceEngine::details::CNNNetworkNGraphImpl cnnNet(ngraph);
+ ASSERT_EQ(1, cnnNet.getBatchSize());
+}
+
+TEST(CNNNGraphImplTests, TestSetBatchScalar) {
+ std::shared_ptr<ngraph::Function> ngraph;
+ {
+ ngraph::Shape shape({});
+ ngraph::element::Type type(ngraph::element::Type_t::f32);
+ auto param = std::make_shared<ngraph::op::Parameter>(type, shape);
+ auto relu = std::make_shared<ngraph::op::Relu>(param);
+ auto result = std::make_shared<ngraph::op::Result>(relu);
+
+ ngraph::ParameterVector params = {param};
+ ngraph::ResultVector results = {result};
+
+ ngraph = std::make_shared<ngraph::Function>(results, params);
+ }
+
+ InferenceEngine::details::CNNNetworkNGraphImpl cnnNet(ngraph);
+ ASSERT_EQ(1, cnnNet.getBatchSize());
+ ASSERT_EQ(PARAMETER_MISMATCH, cnnNet.setBatchSize(2, nullptr)); // triggers conversion
+}
+
TEST(CNNNGraphImplTests, TestSaveAffinity) {
const std::string testAffinity = "testAffinity";
std::shared_ptr<ngraph::Function> ngraph;