Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / mkldnn / cpu_engine.h
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include "inference_engine.hpp"
8 #include "desc_layer.h"
9 #include "desc_tensor.h"
10 #include "desc_tensor_comb.h"
11
12 #include "cpu_prim_layer.h"
13 #include "cpu_prim_tensor.h"
14
15 #include "mkldnn.hpp"
16 #include <memory>
17 #include <vector>
18
19 using namespace InferenceEngine;
20
21 namespace MKLDNNPlugin {
22 class CpuEngine;
23
24 using CpuEnginePtr = std::shared_ptr<CpuEngine>;
25
26 class CpuEngine : public details::no_copy {
27 public:
28     CpuEngine() : eng(mkldnn::engine(mkldnn::engine::kind::cpu, 0)) {}
29
30     void bindThreads();
31
32     void createDescription(DescTensorPtr tns, bool isWeights = false);
33
34     void createDescription(DescLayerPtr layer);
35
36     void setFlatFormat(DescTensorPtr tns);
37
38     void createPrimitive(DescTensorPtr tns);
39
40     void createPrimitive(DescLayerPtr tns);
41
42     void setData(const TBlob<float> &src, DescTensorPtr dst);
43
44     void getData(const DescTensorPtr src, TBlob<float> &dst);
45
46     void subtraction(DescTensorPtr dst, DescTensorPtr sub);
47
48     void subtraction(DescTensorPtr dst, std::vector<float> sub);
49
50     void score(std::vector<DescLayerPtr> layers);
51
52     void score(DescLayerPtr layer);
53
54     void process(std::vector<mkldnn::primitive> exec_queue);
55
56     mkldnn::engine eng;  // TODO: Move me back to private section
57
58 private:
59     static inline mkldnn::memory::desc *get_desc(std::vector<DescTensorPtr> tensors, size_t indx = 0);
60
61     static inline mkldnn::memory::desc *get_desc(DescTensorPtr tns);
62
63     static inline mkldnn::memory *get_prim(std::vector<DescTensorPtr> tns, size_t indx = 0);
64
65     static inline mkldnn::memory *get_prim(DescTensorPtr tns);
66
67     void createPrimitiveCombined(DescTensorComb &tns, void *data);
68 };
69 }  // namespace MKLDNNPlugin