fd74408834f421ab2aa9b33b22f8f64ad68dc2bb
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / gtests / in / convolution_simple_small.h
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 INST_TEST_CASE(SimpleSmall_ZeroDim,
18     PARAMS(nchw, oihw, FMT_BIAS, nchw, 0, 1, 4, 4, 4, 6, 4, 4, 3, 3, 1, 1, 1, 1),
19     PARAMS(nchw, oihw, FMT_BIAS, nchw, 0, 1, 4, 0, 4, 6, 0, 4, 3, 3, 1, 1, 1, 1),
20     PARAMS(nchw, oihw, FMT_BIAS, nchw, 0, 1, 4, 0, 4, 6, 2, 4, 1, 3, 1, 1, 1, 1),
21     PARAMS(nchw, oihw, FMT_BIAS, nchw, 0, 1, 4, 2, 4, 6, 2, 4, 3, 3, 0, 1, 1, 1)
22 );
23
24 INST_TEST_CASE(SimpleSmall_NCHW_expected_failures,
25     PARAMS_EXPECT_FAIL(nchw, oihw, FMT_BIAS, nchw, mkldnn_invalid_arguments, 1, 1, 0, 4, 4, 6, 4, 4, 3, 3, 1, 1, 1, 1),
26     PARAMS_EXPECT_FAIL(nchw, oihw, FMT_BIAS, nchw, mkldnn_invalid_arguments, 1, 1, 4, 4, 4, 0, 4, 4, 3, 3, 1, 1, 1, 1),
27     PARAMS_EXPECT_FAIL(nchw, oihw, FMT_BIAS, nchw, mkldnn_invalid_arguments, 1, 1, 4, 4, 4, 6, 4, 4, 0, 3, 1, 1, 1, 1),
28     PARAMS_EXPECT_FAIL(nchw, oihw, FMT_BIAS, nchw, mkldnn_invalid_arguments, 1, 1, -4, 4, 4, 6, 4, 4, 3, 3, 1, 1, 1, 1),
29     PARAMS_EXPECT_FAIL(nchw, oihw, FMT_BIAS, nchw, mkldnn_invalid_arguments, 1, 1, 4, 4, 4, 6, 4, 4, 3, 3, -1, 1, 1, 1),
30     PARAMS_EXPECT_FAIL(nchw, oihw, FMT_BIAS, nchw, mkldnn_invalid_arguments, 1, 1, 4, 4, 4, 6, 4, 4, 3, 3, 1, 1, 0, 0)
31 );
32
33 INST_TEST_CASE(SimpleSmall_Blocked16_padded,
34     // non-1x1 (all)
35     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16, 2, 1, 17, 13, 13, 23, 12, 12, 3, 3, 0, 0, 1, 1),
36     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16, 2, 1, 21, 13, 13, 16, 12, 12, 3, 3, 0, 0, 1, 1),
37     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16, 2, 1, 23, 13, 13, 19, 12, 12, 3, 3, 0, 0, 1, 1),
38     // 1x1 (fwd, bwd_w)
39     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16, 2, 1, 17, 13, 13, 23, 13, 13, 1, 1, 0, 0, 1, 1),
40     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16, 2, 1, 21, 13, 13, 16, 13, 13, 1, 1, 0, 0, 1, 1),
41     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16, 2, 1, 23, 13, 13, 19, 13, 13, 1, 1, 0, 0, 1, 1),
42     // 1x1 (bwd_d)
43     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16_IOhw16o16i, FMT_BIAS, FMT_DATA_BLOCKED16, 2, 1, 17, 13, 13, 23, 13, 13, 1, 1, 0, 0, 1, 1),
44     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16_IOhw16o16i, FMT_BIAS, FMT_DATA_BLOCKED16, 2, 1, 21, 13, 13, 16, 13, 13, 1, 1, 0, 0, 1, 1),
45     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16_IOhw16o16i, FMT_BIAS, FMT_DATA_BLOCKED16, 2, 1, 23, 13, 13, 19, 13, 13, 1, 1, 0, 0, 1, 1)
46 );
47
48 INST_TEST_CASE(SimpleSmall_Blocked8_padded,
49     // non-1x1 (all)
50     PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED, 2, 1, 17, 13, 13, 23, 12, 12, 3, 3, 0, 0, 1, 1),
51     PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED, 2, 1, 21, 13, 13, 16, 12, 12, 3, 3, 0, 0, 1, 1),
52     PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED, 2, 1, 23, 13, 13, 19, 12, 12, 3, 3, 0, 0, 1, 1),
53     // 1x1 (all)
54     PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED, 2, 1, 17, 13, 13, 23, 13, 13, 1, 1, 0, 0, 1, 1),
55     PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED, 2, 1, 21, 13, 13, 16, 13, 13, 1, 1, 0, 0, 1, 1),
56     PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED, 2, 1, 23, 13, 13, 19, 13, 13, 1, 1, 0, 0, 1, 1)
57 );
58
59 INST_TEST_CASE(SimpleSmall_NCHW,
60     PARAMS(nchw, oihw, FMT_BIAS, nchw,
61         2, 1, 4, 4, 4, 6, 4, 4, 3, 3, 1, 1, 1, 1),
62     PARAMS(nchw, oihw, FMT_BIAS, nchw,
63         2, 1, 4, 4, 4, 6, 2, 2, 3, 3, 0, 0, 1, 1),
64     PARAMS(nhwc, oihw, FMT_BIAS, nhwc,
65         2, 1, 4, 4, 4, 6, 4, 4, 3, 3, 1, 1, 1, 1),
66     PARAMS(nhwc, oihw, FMT_BIAS, nhwc,
67         2, 1, 4, 4, 4, 6, 2, 2, 3, 3, 0, 0, 1, 1),
68     PARAMS(nhwc, hwio, FMT_BIAS, nhwc,
69         2, 1, 4, 4, 4, 6, 4, 4, 3, 3, 1, 1, 1, 1),
70     PARAMS(nhwc, hwio, FMT_BIAS, nhwc,
71         2, 1, 4, 4, 4, 6, 2, 2, 3, 3, 0, 0, 1, 1),
72     PARAMS(nhwc, hwigo, FMT_BIAS, nhwc,
73         2, 2, 4, 4, 4, 6, 4, 4, 3, 3, 0, 0, 1, 1),
74     PARAMS(nhwc, hwigo, FMT_BIAS, nhwc,
75         2, 2, 4, 4, 4, 6, 4, 4, 3, 3, 1, 1, 1, 1)
76 );
77
78 INST_TEST_CASE(SimpleSmall_Blocked,
79     PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
80         2, 1, 32, 13, 13, 32, 12, 12, 3, 3, 0, 0, 1, 1),
81     PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
82         2, 1, 32, 3, 3, 32, 4, 4, 3, 3, 1, 1, 1, 1),
83     PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
84         2, 1, 32, 4, 4, 32, 4, 4, 3, 3, 0, 0, 1, 1),
85     PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
86         2, 1, 32, 3, 3, 32, 2, 2, 3, 3, 0, 0, 1, 1),
87     PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
88         2, 1, 32, 2, 2, 32, 2, 2, 3, 3, 1, 1, 1, 1),
89     PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
90         2, 1, 32, 13, 13, 48, 13, 13, 3, 3, 1, 1, 1, 1),
91     PARAMS(FMT_DATA_BLOCKED, FMT_WEIGHTS_BLOCKED, FMT_BIAS, FMT_DATA_BLOCKED,
92         2, 1, 32, 13, 13, 48, 11, 11, 3, 3, 0, 0, 1, 1)
93 );
94
95 INST_TEST_CASE(SimpleSmall_Blocked16,
96     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16,
97         2, 1, 32, 13, 13, 32, 12, 12, 3, 3, 0, 0, 1, 1),
98     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16,
99         2, 1, 32, 3, 3, 32, 4, 4, 3, 3, 1, 1, 1, 1),
100     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16,
101         2, 1, 32, 4, 4, 32, 4, 4, 3, 3, 0, 0, 1, 1),
102     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16,
103         2, 1, 32, 3, 3, 32, 2, 2, 3, 3, 0, 0, 1, 1),
104     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16,
105         2, 1, 32, 2, 2, 32, 2, 2, 3, 3, 1, 1, 1, 1),
106     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16,
107         2, 1, 32, 13, 13, 48, 13, 13, 3, 3, 1, 1, 1, 1),
108     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16,
109         2, 1, 32, 13, 13, 48, 11, 11, 3, 3, 0, 0, 1, 1),
110     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16,
111         2, 1, 32, 8, 8, 48, 5, 5, 4, 4, 0, 0, 1, 1),
112     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16,
113         2, 1, 32, 7, 7, 48, 10, 10, 4, 4, 3, 3, 1, 1),
114     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16,
115         2, 1, 32, 1, 1, 48, 2, 2, 4, 4, 2, 2, 1, 1),
116     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16,
117         2, 1, 32, 28, 28, 48, 13, 13, 4, 4, 0, 0, 2, 2),
118     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16,
119         2, 1, 32, 28, 28, 48, 14, 14, 4, 4, 1, 1, 2, 2),
120     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16,
121         2, 1, 32, 26, 26, 48, 13, 13, 4, 4, 1, 1, 2, 2),
122     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16,
123         2, 1, 32, 84, 84, 48, 28, 28, 5, 5, 1, 1, 3, 3),
124     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16,
125         2, 1, 32, 21, 21, 48, 7, 7, 5, 5, 1, 1, 3, 3),
126     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16,
127         2, 1, 32, 18, 18, 48, 5, 5, 6, 6, 2, 2, 4, 4),
128     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16,
129         2, 1, 32, 34, 71, 48, 11, 23, 7, 8, 2, 1, 3, 3),
130     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16,
131         2, 1, 32, 6, 6, 48, 2, 2, 3, 3, 0, 0, 2, 2),
132     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16,
133         2, 1, 32, 9, 9, 48, 2, 2, 5, 5, 0, 0, 3, 3)
134 );
135
136 INST_TEST_CASE(SimpleSmall_Regression,
137     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16,
138         2, 1, 32, 16, 16, 32, 16, 16, 3, 3, 0, 0, 1, 1),
139     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16,
140         2, 1, 32, 28, 28, 32, 28, 28, 3, 3, 0, 0, 1, 1),
141     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16,
142         2, 1, 32, 32, 32, 32, 32, 32, 3, 3, 0, 0, 1, 1),
143     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16,
144         2, 1, 32, 3, 3, 32, 2, 2, 3, 3, 1, 1, 1, 1),
145     PARAMS(FMT_DATA_BLOCKED16, FMT_WEIGHTS_BLOCKED16, FMT_BIAS, FMT_DATA_BLOCKED16,
146         2, 1, 32, 34, 34, 32, 34, 34, 5, 5, 2, 2, 1, 1)
147 );
148
149 INST_TEST_CASE(SimpleSmall_Depthwise_Blocked,
150     PARAMS(FMT_DATA_BLOCKED, Goihw8g, FMT_BIAS, FMT_DATA_BLOCKED,
151         2, 8, 8, 16, 16, 8, 16, 16, 3, 3, 0, 0, 1, 1),
152     PARAMS(FMT_DATA_BLOCKED, Goihw8g, FMT_BIAS, FMT_DATA_BLOCKED,
153         2, 32, 32, 9, 9, 32, 2, 2, 5, 5, 0, 0, 3, 3),
154     PARAMS(FMT_DATA_BLOCKED, Goihw8g, FMT_BIAS, FMT_DATA_BLOCKED,
155         2, 64, 64, 26, 26, 64, 13, 13, 4, 4, 1, 1, 2, 2),
156     PARAMS(FMT_DATA_BLOCKED, Goihw8g, FMT_BIAS, FMT_DATA_BLOCKED,
157         2, 32, 32, 111, 111, 32, 112, 112, 1, 1, 0, 0, 1, 1),
158     PARAMS(FMT_DATA_BLOCKED, Goihw8g, FMT_BIAS, FMT_DATA_BLOCKED,
159         1, 64, 64, 1, 2, 64, 1, 1, 3, 3, 1, 1, 1, 2),
160     PARAMS(FMT_DATA_BLOCKED, Goihw8g, FMT_BIAS, FMT_DATA_BLOCKED,
161         1, 16, 16, 16, 32, 16, 16, 18, 3, 3, 1, 2, 1, 2),
162     PARAMS(FMT_DATA_BLOCKED, Goihw8g, FMT_BIAS, FMT_DATA_BLOCKED,
163         1, 24, 24, 32, 16, 24, 16, 14, 3, 3, 1, 0, 2, 1),
164     PARAMS(FMT_DATA_BLOCKED, Goihw8g, FMT_BIAS, FMT_DATA_BLOCKED,
165         1, 16, 16, 32, 16, 16, 18, 16, 3, 3, 2, 1, 2, 1),
166     PARAMS(FMT_DATA_BLOCKED, Goihw8g, FMT_BIAS, FMT_DATA_BLOCKED,
167         1, 8, 8, 500, 500, 8, 698, 698, 3, 3, 100, 100, 1, 1)
168 );
169
170 INST_TEST_CASE(SimpleSmall_Depthwise_Blocked16,
171     PARAMS(FMT_DATA_BLOCKED16, Goihw16g, FMT_BIAS, FMT_DATA_BLOCKED16,
172         2, 16, 16, 16, 16, 16, 16, 16, 3, 3, 0, 0, 1, 1),
173     PARAMS(FMT_DATA_BLOCKED16, Goihw16g, FMT_BIAS, FMT_DATA_BLOCKED16,
174         2, 32, 32, 9, 9, 32, 2, 2, 5, 5, 0, 0, 3, 3),
175     PARAMS(FMT_DATA_BLOCKED16, Goihw16g, FMT_BIAS, FMT_DATA_BLOCKED16,
176         2, 64, 64, 26, 26, 64, 13, 13, 4, 4, 1, 1, 2, 2),
177     PARAMS(FMT_DATA_BLOCKED16, Goihw16g, FMT_BIAS, FMT_DATA_BLOCKED16,
178         2, 32, 32, 111, 111, 32, 112, 112, 1, 1, 0, 0, 1, 1),
179     PARAMS(FMT_DATA_BLOCKED16, Goihw16g, FMT_BIAS, FMT_DATA_BLOCKED16,
180         2, 64, 64, 1, 2, 64, 1, 1, 3, 3, 1, 1, 1, 2),
181     PARAMS(FMT_DATA_BLOCKED16, Goihw16g, FMT_BIAS, FMT_DATA_BLOCKED16,
182         1, 32, 32, 16, 32, 32, 14, 16, 3, 3, 0, 1, 1, 2),
183     PARAMS(FMT_DATA_BLOCKED16, Goihw16g, FMT_BIAS, FMT_DATA_BLOCKED16,
184         1, 16, 16, 16, 32, 16, 16, 18, 3, 3, 1, 2, 1, 2),
185     PARAMS(FMT_DATA_BLOCKED16, Goihw16g, FMT_BIAS, FMT_DATA_BLOCKED16,
186         1, 32, 32, 32, 16, 32, 16, 14, 3, 3, 1, 0, 2, 1),
187     PARAMS(FMT_DATA_BLOCKED16, Goihw16g, FMT_BIAS, FMT_DATA_BLOCKED16,
188         1, 16, 16, 32, 16, 16, 18, 16, 3, 3, 2, 1, 2, 1),
189     PARAMS(FMT_DATA_BLOCKED16, Goihw16g, FMT_BIAS, FMT_DATA_BLOCKED16,
190         1, 16, 16, 500, 500, 16, 698, 698, 3, 3, 100, 100, 1, 1)
191 );