Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / mkldnn_dims.h
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include "perf_count.h"
8 #include <vector>
9 #include <utility>
10 #include <mkldnn_types.h>
11 #include <ie_common.h>
12 #include <mkldnn.hpp>
13
14 namespace MKLDNNPlugin {
15
16 class MKLDNNDims {
17 public:
18     MKLDNNDims() = default;
19
20     explicit MKLDNNDims(const InferenceEngine::SizeVector& size) {
21         dims = std::vector<ptrdiff_t>(size.begin(), size.end());
22     }
23
24     explicit MKLDNNDims(const std::vector<ptrdiff_t>& dim) {
25         dims = dim;
26     }
27
28     MKLDNNDims(const mkldnn_dims_t dnn_dims, int dnn_ndims) {
29         dims = std::vector<ptrdiff_t>(dnn_dims, dnn_dims + dnn_ndims);
30     }
31
32     explicit MKLDNNDims(std::initializer_list<ptrdiff_t> ilist) : dims(ilist) {}
33     explicit MKLDNNDims(std::initializer_list<size_t > ilist) : dims(ilist.begin(), ilist.end()) {}
34
35     InferenceEngine::SizeVector ToSizeVector() const {
36         InferenceEngine::SizeVector size;
37         for (auto i : dims) {
38             size.push_back(i);
39         }
40
41         return size;
42     }
43
44     int ndims() const {
45         return dims.size();
46     }
47
48     ptrdiff_t size() const {
49         return size(0);
50     }
51
52     ptrdiff_t size(int start) const {
53         ptrdiff_t size = 1;
54
55         for (int i = start; i < dims.size(); i++) {
56             size *= dims[i];
57         }
58
59         return size;
60     }
61
62     void push_back(int val) {
63         dims.push_back(val);
64     }
65
66     operator mkldnn::memory::dims() const {
67         return dims;
68     }
69
70     bool operator == (const MKLDNNDims& rhs) const {
71         if (dims.size() != rhs.dims.size()) {
72             return false;
73         }
74
75         return std::equal(rhs.dims.begin(), rhs.dims.end(), dims.begin());
76     }
77
78     bool operator != (const MKLDNNDims& rhs) const {
79         return !(*this == rhs);
80     }
81
82     ptrdiff_t& operator[](int idx) {
83         return dims[idx];
84     }
85
86     ptrdiff_t operator[](int idx) const {
87         return dims[idx];
88     }
89
90 private:
91     std::vector<ptrdiff_t> dims;
92 };
93
94 }  // namespace MKLDNNPlugin