cnnLayer->insData.resize(inputCount);
for (size_t i = 0; i < layer->get_output_size(); i++) {
+ // Memory node with index = 1 has no inputs according to the specification.
+ // For proper conversion, we must cut off all the layers and data nodes above ReadValue,
+ // if they are connected only with this layer.
+ // Now MO generates only constants or constant sub-graphs as input to ReadValue op.
+ if (std::dynamic_pointer_cast<::ngraph::op::Constant>(layer)) {
+ bool all_to_read_value = !layer->output(i).get_target_inputs().empty();
+ for (const auto &output_input : layer->output(i).get_target_inputs()) {
+ all_to_read_value
+ &= dynamic_cast<ngraph::op::ReadValue *>(output_input.get_node()) != nullptr;
+ }
+ if (all_to_read_value)
+ continue;
+ }
+
if (cnnLayer->type == "Memory" && cnnLayer->params["index"] == "0") {
cnnLayer->outData.clear();
continue;
std::string outName = layer->get_friendly_name();
if (layer->get_output_size() != 1) outName += "." + std::to_string(i);
DataPtr &ptr = cnnNetworkImpl->getData(outName.c_str());
-
SizeVector dims;
dims = layer->get_output_shape(i);
for (const auto &dim : dims) {
#include <ie_core.hpp>
#include <net_pass.h>
+#include <ngraph/opsets/opset3.hpp>
#include <ngraph/function.hpp>
#include <ngraph/variant.hpp>
#include <ngraph/op/maximum.hpp>
InferenceEngine::details::CNNNetworkNGraphImpl cnnNet(ngraph);
}
+TEST(CNNNGraphImplTests, CanSetBatchReadValue) {
+ std::shared_ptr<ngraph::Function> ngraph;
+ {
+ auto input = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 2});
+ auto constant = std::make_shared<ngraph::opset3::Constant>(ngraph::element::f32, ngraph::Shape{1, 2},
+ std::vector<float>{1, 2});
+
+ auto read_value = std::make_shared<ngraph::opset3::ReadValue>(constant, "variable_id");
+ auto add = std::make_shared<ngraph::opset3::Add>(input, read_value);
+ auto result = std::make_shared<ngraph::op::Result>(add);
+
+ ngraph::ParameterVector params = {input};
+ ngraph::ResultVector results = {result};
+
+ ngraph = std::make_shared<ngraph::Function>(results, params);
+ }
+
+ InferenceEngine::details::CNNNetworkNGraphImpl cnnNet(ngraph);
+ auto status = cnnNet.getCNNNetwork()->setBatchSize(4, nullptr);
+ EXPECT_EQ(status, StatusCode::OK);
+}
IE_SUPPRESS_DEPRECATED_END