1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "mkldnn_roi_pooling_node.h"
6 #include "desc_iterator.hpp"
11 #include <mkldnn_extension_utils.h>
13 using namespace mkldnn;
14 using namespace MKLDNNPlugin;
15 using namespace InferenceEngine;
// Constructs the ROI pooling node. The layer and the engine are forwarded
// unchanged to the MKLDNNNode base-class constructor; the body is empty.
17 MKLDNNROIPoolingNode::MKLDNNROIPoolingNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng) : MKLDNNNode(layer, eng) {}
// Reads the ROIPooling layer parameters (pooled_h, pooled_w, spatial_scale,
// method) into the node and calls createDescriptor() once for every memory
// format returned by getAvailableFormatsForDims() for the first input.
// Raises via THROW_IE_EXCEPTION when the layer pointer is null, when the node
// has no parent edges or no child edges, or when the method is unsupported.
// NOTE(review): the embedded line numbers skip values inside this function, so
// some lines (for example closing braces) appear to be absent from this view —
// confirm against the full file.
19 void MKLDNNROIPoolingNode::getSupportedDescriptors() {
// Input precision: anything other than FP32 is replaced by FP32 before the
// conversion to an MKLDNN data type.
23 InferenceEngine::Precision precision = getCnnLayer()->insData[0].lock()->getPrecision();
24 if (precision != InferenceEngine::Precision::FP32)
25 precision = InferenceEngine::Precision::FP32;
26 auto inputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(precision);
// Output precision: same FP32 override, reusing the `precision` variable.
27 precision = getCnnLayer()->outData[0]->getPrecision();
28 if (precision != InferenceEngine::Precision::FP32)
29 precision = InferenceEngine::Precision::FP32;
30 auto outputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(precision);
// Raw, non-owning pointer to the layer; used below to read its parameters.
32 GenericLayer* genericLayer = getCnnLayer().get();
34 if (genericLayer == nullptr)
35 THROW_IE_EXCEPTION << "Cannot convert ROIPooling layer.";
// At least one parent edge and one child edge are required.
37 if (getParentEdges().empty())
38 THROW_IE_EXCEPTION << "Incorrect number of input edges for layer " << getName();
39 if (getChildEdges().empty())
40 THROW_IE_EXCEPTION << "Incorrect number of output edges for layer " << getName();
// Layer parameters are stored in members that createDescriptor() and
// createPrimitive() read later.
42 pooled_h = genericLayer->GetParamAsInt("pooled_h");
43 pooled_w = genericLayer->GetParamAsInt("pooled_w");
44 spatial_scale = genericLayer->GetParamAsFloat("spatial_scale");
// Presumably "max" is the fallback used when "method" is absent — confirm
// against GetParamAsString.
45 std::string m = genericLayer->GetParamAsString("method", "max");
// Translate the method string into an MKLDNN algorithm; "bilinear" selects
// roi_pooling_bilinear and an unrecognized value raises.
// NOTE(review): the condition guarding the roi_pooling_max assignment is not
// visible in this view — presumably `m == "max"`; confirm.
47 method = mkldnn::algorithm::roi_pooling_max;
48 } else if (m == "bilinear") {
49 method = mkldnn::algorithm::roi_pooling_bilinear;
51 THROW_IE_EXCEPTION << "Unsupported roi pooling method";
// One descriptor candidate per available format of the first input's dims.
54 auto parentDims = getParentEdgeAt(0)->getDims();
55 for (auto format : getAvailableFormatsForDims(parentDims)) {
56 std::vector<InferenceEngine::TensorDesc> srcs;
// The first input and the output share the candidate format; the second input
// always uses memory::nc (presumably the ROI list — TODO confirm).
57 srcs.push_back(MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType, format));
58 srcs.push_back(MKLDNNMemoryDesc(getParentEdgeAt(1)->getDims(), inputDataType, memory::nc));
59 MKLDNNMemoryDesc out_candidate(getChildEdgeAt(0)->getDims(), outputDataType, format);
61 createDescriptor(srcs, {out_candidate});
// Creates the roi_pooling_forward primitive and stores it in `prim`, using the
// memory objects attached to the node's parent and child edges.
// Raises via THROW_IE_EXCEPTION when no primitive descriptor has been selected
// for this node.
// NOTE(review): the embedded line numbers skip values inside this function
// (e.g. the end of the descriptor constructor call), so some lines appear to be
// absent from this view — confirm against the full file.
65 void MKLDNNROIPoolingNode::createPrimitive() {
// Collect the memory descriptors of every parent edge.
69 std::vector<memory::desc> srcs;
70 for (size_t i = 0; i < getParentEdges().size(); i++) {
71 srcs.push_back(getParentEdgeAt(i)->getMemory().GetDescriptor());
74 memory::desc out_candidate = getChildEdgeAt(0)->getMemory().GetDescriptor();
// Build a descriptor from the edge memory descriptors and the stored layer
// parameters (the remaining constructor arguments are not visible here).
76 MKLDNNDescriptor desc(std::shared_ptr<roi_pooling_forward::desc>(
77 new roi_pooling_forward::desc(prop_kind::forward_scoring, method, srcs, out_candidate, pooled_h, pooled_w,
// The operation descriptor is taken from the first entry of `descs`.
81 std::shared_ptr<roi_pooling_forward::desc> selected_desc_ptr = descs[0];
// In the visible lines the selected descriptor is only checked for presence.
83 const PrimitiveDescInfo *selected_pd = getSelectedPrimitiveDescriptor();
84 if (selected_pd == nullptr)
85 THROW_IE_EXCEPTION << "Preferable primitive descriptor does not set for node " << getName() << ".";
87 auto prim_desc = roi_pooling_forward::primitive_desc(*selected_desc_ptr, getEngine());
// NOTE(review): `itpd` is not referenced by any later line visible in this view.
88 primitive_desc_iterator itpd = descs[0].createPrimitiveDescriptorIterator(getEngine());
// Gather the input primitives from every parent edge, in edge order.
90 std::vector<primitive::at> src_p;
91 for (size_t i = 0; i < getParentEdges().size(); i++) {
92 src_p.push_back(getParentEdgeAt(i)->getMemoryPtr()->GetPrimitive());
// The output is written to the memory of the first child edge.
94 prim.reset(new roi_pooling_forward(prim_desc, src_p, getChildEdgeAt(0)->getMemory().GetPrimitive()));
// Returns true when getType() reports ROIPooling for this node, false otherwise.
97 bool MKLDNNROIPoolingNode::created() const {
98 return getType() == ROIPooling;
// Builds a roi_pooling_forward descriptor from the given tensor descriptors and
// appends it to `descs`.
// inputDesc:  entries [0] and [1] are read.
// outputDesc: entry [0] is read.
// No size check on either vector is visible here, so callers are expected to
// pass at least two inputs and one output.
// NOTE(review): the end of the descriptor constructor call and the closing
// brace of this function are not visible in this view — confirm against the
// full file.
101 void MKLDNNROIPoolingNode::createDescriptor(const std::vector<InferenceEngine::TensorDesc> &inputDesc,
102 const std::vector<InferenceEngine::TensorDesc> &outputDesc) {
// Convert the Inference Engine tensor descriptors to MKLDNN memory descriptors.
103 std::vector<memory::desc> srcs;
104 srcs.push_back(MKLDNNMemoryDesc(inputDesc[0]));
105 srcs.push_back(MKLDNNMemoryDesc(inputDesc[1]));
106 MKLDNNMemoryDesc out_candidate(outputDesc[0]);
// The descriptor uses the method and pooling parameters stored on the node.
108 MKLDNNDescriptor desc(std::shared_ptr<roi_pooling_forward::desc>(
109 new roi_pooling_forward::desc(prop_kind::forward_scoring, method, srcs, out_candidate, pooled_h, pooled_w,
111 descs.push_back(desc);