Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / nodes / mkldnn_conv_node.h
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include <ie_common.h>
8 #include <mkldnn_node.h>
9 #include <memory>
10 #include <string>
11 #include <vector>
12
13 namespace MKLDNNPlugin {
14
15 class MKLDNNConvolutionNode : public MKLDNNNode {
16 public:
17     MKLDNNConvolutionNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng);
18     ~MKLDNNConvolutionNode() override = default;
19
20     void getSupportedDescriptors() override;
21     void createDescriptor(const std::vector<InferenceEngine::TensorDesc>& inputDesc,
22                           const std::vector<InferenceEngine::TensorDesc>& outputDesc) override;
23     void initDescriptor(const InferenceEngine::LayerConfig& config) override;
24     void createPrimitive() override;
25     void initSupportedPrimitiveDescriptors() override;
26     bool created() const override;
27     bool canBeInPlace() const override {
28         return false;
29     }
30     void setPostOps(mkldnn::primitive_attr &attr, bool initWeights);
31
32 protected:
33     void addScaleToPrimitiveAttr(mkldnn::primitive_attr attr) const;
34
35 private:
36     static Register<MKLDNNConvolutionNode> reg;
37     bool withBiases;
38     bool withActivation;
39     bool withSum;
40     bool isDW;
41     bool isMerged;
42     bool isGrouped;
43     std::vector<ptrdiff_t> stride;
44     std::vector<ptrdiff_t> dilation;
45     std::vector<ptrdiff_t> paddingL;
46     std::vector<ptrdiff_t> paddingR;
47     InferenceEngine::SizeVector weightDims;
48     InferenceEngine::SizeVector biasesDims;
49
50     ptrdiff_t dw_conv_oc;
51     ptrdiff_t dw_conv_ih;
52     ptrdiff_t dw_conv_iw;
53     std::vector<ptrdiff_t> dw_conv_kernel;
54     std::vector<ptrdiff_t> dw_conv_strides;
55     std::vector<MKLDNNMemoryPtr> PostOpsIntBlobMemory;
56
57     InferenceEngine::ConvolutionLayer* convLayer;
58     InferenceEngine::Blob::Ptr wScale, oScale;
59 };
60
61 }  // namespace MKLDNNPlugin
62