1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "mkldnn_reshape_node.h"
8 #include <mkldnn_types.h>
9 #include <mkldnn_extension_utils.h>
11 using namespace mkldnn;
12 using namespace MKLDNNPlugin;
13 using namespace InferenceEngine;
15 MKLDNNReshapeNode::MKLDNNReshapeNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng) : MKLDNNNode(layer, eng) {}
17 void MKLDNNReshapeNode::getSupportedDescriptors() {
18 if (getParentEdges().size() != 1)
19 THROW_IE_EXCEPTION << "Incorrect number of input edges for layer " << getName();
20 if (getChildEdges().empty())
21 THROW_IE_EXCEPTION << "Incorrect number of output edges for layer " << getName();
24 void MKLDNNReshapeNode::initSupportedPrimitiveDescriptors() {
25 if (!supportedPrimitiveDescriptors.empty())
28 InferenceEngine::Precision precision = getCnnLayer()->insData[0].lock()->getPrecision();
29 if (precision != InferenceEngine::Precision::FP32)
30 precision = InferenceEngine::Precision::FP32;
31 auto inputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(precision);
32 precision = getCnnLayer()->outData[0]->getPrecision();
33 if (precision != InferenceEngine::Precision::FP32)
34 precision = InferenceEngine::Precision::FP32;
35 auto outputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(precision);
37 auto& inDims = getParentEdgeAt(0)->getDims();
38 auto& outDims = getChildEdgeAt(0)->getDims();
39 memory::format outFormat = MKLDNNMemory::GetPlainFormat(outDims);
40 InferenceEngine::LayerConfig config;
41 config.dynBatchSupport = true;
42 config.inConfs.resize(1);
43 config.outConfs.resize(1);
44 config.inConfs[0].inPlace = -1;
45 config.inConfs[0].constant = false;
46 config.inConfs[0].desc = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType, MKLDNNMemory::GetPlainFormat(getParentEdgeAt(0)->getDims()));
47 config.outConfs[0].inPlace = 0;
48 config.outConfs[0].constant = false;
49 config.outConfs[0].desc = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, outFormat);
50 supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown);
53 void MKLDNNReshapeNode::createPrimitive() {
54 auto& dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
55 auto& srcMemPtr = getParentEdgeAt(0)->getMemoryPtr();
56 if (!dstMemPtr || !dstMemPtr->GetPrimitivePtr())
57 THROW_IE_EXCEPTION << "Destination memory didn't allocate.";
58 if (!srcMemPtr || !srcMemPtr->GetPrimitivePtr())
59 THROW_IE_EXCEPTION << "Input memory didn't allocate.";
60 if (getSelectedPrimitiveDescriptor() == nullptr)
61 THROW_IE_EXCEPTION << "Preferable primitive descriptor does not set.";
64 bool MKLDNNReshapeNode::created() const {
65 return getType() == Reshape || getType() == Flatten;