Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / nodes / mkldnn_bin_conv_node.h
1 // Copyright (C) 2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include <ie_common.h>
8 #include <mkldnn_node.h>
9 #include <memory>
10 #include <string>
11 #include <vector>
12
13 namespace MKLDNNPlugin {
14
15 class MKLDNNBinaryConvolutionNode : public MKLDNNNode {
16 public:
17     MKLDNNBinaryConvolutionNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng);
18     ~MKLDNNBinaryConvolutionNode() override = default;
19
20     void getSupportedDescriptors() override;
21     void createDescriptor(const std::vector<InferenceEngine::TensorDesc>& inputDesc,
22                           const std::vector<InferenceEngine::TensorDesc>& outputDesc) override;
23     void initDescriptor(const InferenceEngine::LayerConfig& config) override;
24     void createPrimitive() override;
25     void initSupportedPrimitiveDescriptors() override;
26     bool created() const override;
27     bool canBeInPlace() const override {
28         return false;
29     }
30     void setPostOps(mkldnn::primitive_attr &attr, bool initWeights);
31     void pushBinarizationThreshold(float value);
32
33 private:
34     static Register<MKLDNNBinaryConvolutionNode> reg;
35     bool withSum;
36     bool withBinarization;
37     bool isDW;
38     bool isMerged;
39     bool isGrouped;
40     std::vector<ptrdiff_t> stride;
41     std::vector<ptrdiff_t> dilation;
42     std::vector<ptrdiff_t> paddingL;
43     std::vector<ptrdiff_t> paddingR;
44     InferenceEngine::SizeVector weightDims;
45     InferenceEngine::SizeVector biasesDims;
46
47     ptrdiff_t dw_conv_oc;
48     ptrdiff_t dw_conv_ih;
49     ptrdiff_t dw_conv_iw;
50     std::vector<ptrdiff_t> dw_conv_kernel;
51     std::vector<ptrdiff_t> dw_conv_strides;
52     std::vector<MKLDNNMemoryPtr> PostOpsIntBlobMemory;
53
54     float pad_value;
55
56     std::vector<float> binarizationThresholds;
57 };
58
59 }  // namespace MKLDNNPlugin
60