Publishing R3
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / mkldnn_dims.h
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
6 #pragma once
7
8 #include "perf_count.h"
9 #include <vector>
10 #include <utility>
11 #include <mkldnn_types.h>
12 #include <mkldnn.hpp>
13
14 namespace MKLDNNPlugin {
15
16 class MKLDNNDims {
17 public:
18     MKLDNNDims() {
19     }
20
21     explicit MKLDNNDims(const InferenceEngine::SizeVector& size) {
22         dims = std::vector<int>(size.begin(), size.end());
23     }
24
25     explicit MKLDNNDims(const std::vector<int>& dim) {
26         dims = dim;
27     }
28
29     MKLDNNDims(const mkldnn_dims_t dnn_dims, int dnn_ndims) {
30         dims = std::vector<int>(dnn_dims, dnn_dims + dnn_ndims);
31     }
32
33     explicit MKLDNNDims(std::initializer_list<int> ilist) : dims(ilist) {}
34
35     InferenceEngine::SizeVector ToSizeVector() const {
36         InferenceEngine::SizeVector size;
37         for (auto i : dims) {
38             size.push_back(i);
39         }
40
41         return size;
42     }
43
44     int ndims() const {
45         return dims.size();
46     }
47
48     int size() const {
49         return size(0);
50     }
51
52     int size(int start) const {
53         int size = 1;
54
55         for (int i = start; i < dims.size(); i++) {
56             size *= dims[i];
57         }
58
59         return size;
60     }
61
62     void insert(int at, int val) {
63         dims.insert(dims.begin() + at, val);
64     }
65
66     void push_back(int val) {
67         dims.push_back(val);
68     }
69
70     void swap(int from, int to) {
71         int tmp = dims[from];
72         dims[from] = dims[to];
73         dims[to] = tmp;
74     }
75
76     operator mkldnn::memory::dims() const {
77         return dims;
78     }
79
80     bool operator == (const MKLDNNDims& rhs) {
81         if (dims.size() != rhs.dims.size()) {
82             return false;
83         }
84
85         return std::equal(rhs.dims.begin(), rhs.dims.end(), dims.begin());
86     }
87
88     bool operator != (const MKLDNNDims& rhs) {
89         return !(*this == rhs);
90     }
91
92     int& operator[](int idx) {
93         return dims[idx];
94     }
95
96     int operator[](int idx) const {
97         return dims[idx];
98     }
99
100 private:
101     std::vector<int> dims;
102 };
103
104 }  // namespace MKLDNNPlugin