Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / nodes / mkldnn_reshape_node.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "mkldnn_reshape_node.h"
6 #include <ie_layers.h>
7 #include <string>
8 #include <mkldnn_types.h>
9 #include <mkldnn_extension_utils.h>
10
11 using namespace mkldnn;
12 using namespace MKLDNNPlugin;
13 using namespace InferenceEngine;
14
15 MKLDNNReshapeNode::MKLDNNReshapeNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng) : MKLDNNNode(layer, eng) {}
16
17 void MKLDNNReshapeNode::getSupportedDescriptors() {
18     if (getParentEdges().size() != 1)
19         THROW_IE_EXCEPTION << "Incorrect number of input edges for layer " << getName();
20     if (getChildEdges().empty())
21         THROW_IE_EXCEPTION << "Incorrect number of output edges for layer " << getName();
22 }
23
24 void MKLDNNReshapeNode::initSupportedPrimitiveDescriptors() {
25     if (!supportedPrimitiveDescriptors.empty())
26         return;
27
28     InferenceEngine::Precision precision = getCnnLayer()->insData[0].lock()->getPrecision();
29     if (precision != InferenceEngine::Precision::FP32)
30         precision = InferenceEngine::Precision::FP32;
31     auto inputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(precision);
32     precision = getCnnLayer()->outData[0]->getPrecision();
33     if (precision != InferenceEngine::Precision::FP32)
34         precision = InferenceEngine::Precision::FP32;
35     auto outputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(precision);
36
37     auto& inDims = getParentEdgeAt(0)->getDims();
38     auto& outDims = getChildEdgeAt(0)->getDims();
39     memory::format outFormat = MKLDNNMemory::GetPlainFormat(outDims);
40     InferenceEngine::LayerConfig config;
41     config.dynBatchSupport = true;
42     config.inConfs.resize(1);
43     config.outConfs.resize(1);
44     config.inConfs[0].inPlace = -1;
45     config.inConfs[0].constant = false;
46     config.inConfs[0].desc = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType, MKLDNNMemory::GetPlainFormat(getParentEdgeAt(0)->getDims()));
47     config.outConfs[0].inPlace = 0;
48     config.outConfs[0].constant = false;
49     config.outConfs[0].desc = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, outFormat);
50     supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown);
51 }
52
53 void MKLDNNReshapeNode::createPrimitive() {
54     auto& dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
55     auto& srcMemPtr = getParentEdgeAt(0)->getMemoryPtr();
56     if (!dstMemPtr || !dstMemPtr->GetPrimitivePtr())
57         THROW_IE_EXCEPTION << "Destination memory didn't allocate.";
58     if (!srcMemPtr || !srcMemPtr->GetPrimitivePtr())
59         THROW_IE_EXCEPTION << "Input memory didn't allocate.";
60     if (getSelectedPrimitiveDescriptor() == nullptr)
61         THROW_IE_EXCEPTION << "Preferable primitive descriptor does not set.";
62 }
63
64 bool MKLDNNReshapeNode::created() const {
65     return getType() == Reshape || getType() == Flatten;
66 }