/**
* @brief A constructor from ngraph::Function object
+ * This constructor wraps existing ngraph::Function
+ * If you want to avoid modification of original Function, please create a copy
* @param network Pointer to the ngraph::Function object
*/
- explicit CNNNetwork(const std::shared_ptr<const ngraph::Function>& network);
+ explicit CNNNetwork(const std::shared_ptr<ngraph::Function>& network);
/**
* @brief A destructor
void dumpGraph(InferenceEngine::ICNNNetwork& network,
- const std::vector<std::shared_ptr<const ngraph::Function>>& subFunctions,
+ const std::vector<std::shared_ptr<ngraph::Function>>& subFunctions,
std::ostream& stream) {
static const std::array<const char *, 9> colors{{"#FFC405",
"#20F608",
InputsDataMap externalInputsData;
network.getInputsInfo(externalInputsData);
networks.resize(orderedSubgraphs.size());
- std::vector<std::shared_ptr<const ngraph::Function>> subFunctions(orderedSubgraphs.size());
+ std::vector<std::shared_ptr<ngraph::Function>> subFunctions(orderedSubgraphs.size());
std::vector<bool> isInputSubnetwork(orderedSubgraphs.size());
int id = 0;
for (auto&& subgraph : orderedSubgraphs) {
networks[id]._device = subgraph._affinity;
subFunctions[id] =
- std::make_shared<const ngraph::Function>(subgraph._results, subgraph._parameters,
+ std::make_shared<ngraph::Function>(subgraph._results, subgraph._parameters,
_name + '_' + std::to_string(id));
networks[id]._clonedNetwork = CNNNetwork{subFunctions[id]};
// update of pre-processing info
return specialized_function;
}
-// WA: for cnnNetwork ngraph constructor
-CNNNetwork::CNNNetwork(const std::shared_ptr<const ngraph::Function>& graph) {
+CNNNetwork::CNNNetwork(const std::shared_ptr<ngraph::Function>& graph) {
if (graph == nullptr) {
THROW_IE_EXCEPTION << "CNNNetwork was not initialized: 'graph' object is empty";
}
- // Copy nGraph function
- network = std::make_shared<CNNNetworkNGraphImpl>(copyFunction(graph, false, {}));
+ // Create CNNNetworkNGraphImpl
+ network = std::make_shared<CNNNetworkNGraphImpl>(graph);
actual = network.get();
if (actual == nullptr) {
THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
}
}
+CNNNetworkNGraphImpl::CNNNetworkNGraphImpl(const ICNNNetwork& network) {
+ if (network.getFunction() == nullptr) {
+ THROW_IE_EXCEPTION << "Cannot create CNNNetwork with nGraph from legacy network format!";
+ }
+
+ _ngraph_function = copyFunction(network.getFunction(), false, {});
+ InputsDataMap inputs;
+ OutputsDataMap outputs;
+ network.getInputsInfo(inputs);
+ network.getOutputsInfo(outputs);
+
+ for (const auto& outputInfo : outputs) {
+ const auto& name = outputInfo.second->getName();
+ DataPtr output = std::make_shared<Data>(name, outputInfo.second->getTensorDesc());
+ _outputData[name] = output;
+ _data[name] = output;
+ }
+ for (const auto& inputInfo : inputs) {
+ InputInfo::Ptr info = std::make_shared<InputInfo>();
+ const auto& name = inputInfo.second->getInputData()->getName();
+ DataPtr input = std::make_shared<Data>(name, inputInfo.second->getInputData()->getTensorDesc());
+ _data[name] = input;
+ info->setInputData(input);
+ info->getPreProcess() = inputInfo.second->getPreProcess();
+ info->setPrecision(inputInfo.second->getPrecision());
+ info->setLayout(inputInfo.second->getLayout());
+ _inputData[name] = info;
+ }
+}
+
void CNNNetworkNGraphImpl::setInputInfo(InputInfo::Ptr data) {
if (cnnNetwork) cnnNetwork->setInputInfo(data);
_inputData[data->name()] = data;
class INFERENCE_ENGINE_API_CLASS(CNNNetworkNGraphImpl): public ICNNNetwork {
public:
CNNNetworkNGraphImpl(const std::shared_ptr<::ngraph::Function>& nGraph);
+ CNNNetworkNGraphImpl(const ICNNNetwork& nGraph);
~CNNNetworkNGraphImpl() override = default;
void getOutputsInfo(std::map<std::string, DataPtr>& out) const noexcept override;
#include "graph_tools.hpp"
#include "net_pass.h"
#include "precision_utils.h"
+#include "cnn_network_ngraph_impl.hpp"
using std::string;
}
std::shared_ptr<ICNNNetwork> cloneNetwork(const ICNNNetwork& network) {
- if (auto func = network.getFunction()) {
- CNNNetwork net(func);
-
- InputsDataMap originInputs;
- OutputsDataMap originOutputs;
- network.getInputsInfo(originInputs);
- network.getOutputsInfo(originOutputs);
- InputsDataMap clonedInputs = net.getInputsInfo();
- OutputsDataMap clonedOutputs = net.getOutputsInfo();
-
- for (const auto& outputInfo : originOutputs) {
- if (clonedOutputs.find(outputInfo.first) == clonedOutputs.end())
- THROW_IE_EXCEPTION << "Cannot clone network! Cloned network doesn't contain all outputs";
- clonedOutputs[outputInfo.first]->setPrecision(outputInfo.second->getPrecision());
- clonedOutputs[outputInfo.first]->setLayout(outputInfo.second->getLayout());
- }
- for (const auto& inputInfo : originInputs) {
- if (clonedInputs.find(inputInfo.first) == clonedInputs.end())
- THROW_IE_EXCEPTION << "Cannot clone network! Cloned network doesn't contain all inputs";
- clonedInputs[inputInfo.first]->setPrecision(inputInfo.second->getPrecision());
- clonedInputs[inputInfo.first]->setLayout(inputInfo.second->getLayout());
- clonedInputs[inputInfo.first]->getPreProcess() = inputInfo.second->getPreProcess();
- }
- return net;
+ if (network.getFunction()) {
+ return std::make_shared<details::CNNNetworkNGraphImpl>(network);
}
return cloneNet(network);
}
TEST_F(CNNNetworkTests, throwsOnInitWithNullNgraph) {
- std::shared_ptr<const ngraph::Function> nlptr = nullptr;
+ std::shared_ptr<ngraph::Function> nlptr = nullptr;
ASSERT_THROW(CNNNetwork network(nlptr), InferenceEngine::details::InferenceEngineException);
}
#include <ngraph/op/relu.hpp>
#include <ngraph/op/result.hpp>
#include <ngraph/opsets/opset.hpp>
+#include <ngraph/graph_util.hpp>
#include <ie_util_internal.hpp>
#include <ie_core.hpp>
}
TEST_F(NGraphReshapeTests, CNNReshapeSpatialReLU) {
+ std::shared_ptr<const ngraph::Function> ngraph;
+ {
+ ngraph::PartialShape shape({1, 3, 22, 22});
+ ngraph::element::Type type(ngraph::element::Type_t::f32);
+ auto param = std::make_shared<ngraph::op::Parameter>(type, shape);
+ param->set_friendly_name("data");
+ auto relu = std::make_shared<ngraph::op::Relu>(param);
+ auto result = std::make_shared<ngraph::op::Result>(relu);
+
+ ngraph::ParameterVector params = {param};
+ ngraph::ResultVector results = {result};
+
+ ngraph = std::make_shared<const ngraph::Function>(results, params);
+ }
+
+ ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
+ ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
+
+ CNNNetwork cnnNetwork(ngraph::clone_function(*ngraph));
+ std::map<std::string, std::vector<size_t>> shapes;
+ shapes["data"] = {1, 3, 25, 25};
+
+ ASSERT_NO_THROW(cnnNetwork.reshape(shapes));
+
+ auto changedFunction = cnnNetwork.getFunction();
+ ASSERT_NE(nullptr, changedFunction);
+ ASSERT_EQ(changedFunction->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
+ ASSERT_EQ(changedFunction->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
+ ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
+ ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
+}
+
+TEST_F(NGraphReshapeTests, CNNReshapeSpatialReLUWithoutCloneFunction) {
std::shared_ptr<ngraph::Function> ngraph;
{
ngraph::PartialShape shape({1, 3, 22, 22});
ASSERT_NE(nullptr, changedFunction);
ASSERT_EQ(changedFunction->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
ASSERT_EQ(changedFunction->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
- ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
- ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
+ ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
+ ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
}
class CustomTestOp: public ngraph::op::Op {