Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / nodes / mkldnn_fullyconnected_node.h
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include <ie_common.h>
8 #include <mkldnn_node.h>
9 #include <memory>
10 #include <string>
11 #include <vector>
12
13 namespace MKLDNNPlugin {
14
15 class MKLDNNFullyConnectedNode : public MKLDNNNode {
16 public:
17     MKLDNNFullyConnectedNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng);
18     ~MKLDNNFullyConnectedNode() override = default;
19
20     void getSupportedDescriptors() override;
21     void createPrimitive() override;
22     bool created() const override;
23     bool canBeInPlace() const override {
24         return false;
25     }
26
27     const std::vector<impl_desc_type>& getPrimitivesPriority() override;
28     void createDescriptor(const std::vector<InferenceEngine::TensorDesc>& inputDesc,
29                           const std::vector<InferenceEngine::TensorDesc>& outputDesc) override;
30
31 protected:
32     std::shared_ptr<mkldnn::primitive_attr> initPrimitiveAttr() const override;
33
34 private:
35     static Register<MKLDNNFullyConnectedNode> reg;
36     InferenceEngine::SizeVector weightsDims;
37     InferenceEngine::SizeVector biasesDims;
38     mkldnn::memory::format weightsFormatForSrcFormat(mkldnn::memory::format sourceFormat);
39
40     InferenceEngine::Blob::Ptr wScale, oScale;
41 };
42
43 }  // namespace MKLDNNPlugin
44