Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / conv / conv_common.hpp
index d3969ec..624338e 100644 (file)
 #include "mkldnn_common.hpp"
 #include "mkldnn_memory.hpp"
 
+namespace deconv {
+/* some extra control parameters which shouldn't be placed in prb_t */
+extern const char *skip_impl; /* NULL or "" means do not skip anything */
+extern bool allow_unimpl; /* true means do not treat unimplemented as error */
+extern const char *perf_template; /* performance output template */
+}
+
 namespace conv {
 
-enum alg_t { DIRECT, WINO };
+enum alg_t { DIRECT, WINO, AUTO };
 alg_t str2alg(const char *str);
 const char *alg2str(alg_t alg);
-
-enum merge_t { NONE, RELU, };
-merge_t str2merge(const char *str);
-const char *merge2str(merge_t merge);
+alg_t alg_kind2alg(mkldnn_alg_kind_t alg);
 
 struct desc_t {
     int g, mb;
@@ -44,6 +48,7 @@ struct desc_t {
     int sd, sh, sw;
     int pd, ph, pw;
     int dd, dh, dw;
+    bool has_groups;
 
     const char *name;
 };
@@ -95,12 +100,13 @@ extern const _dt_conf_t conf_u8s8u8s32_wino;
 
 const dt_conf_t *str2cfg(const char *str);
 const char *cfg2str(const dt_conf_t *cfg);
+const dt_conf_t *auto_cfg(const alg_t alg, const dt_conf_t *cfg);
 
 struct prb_t: public desc_t {
     prb_t(const desc_t &desc, dir_t dir, const dt_conf_t *cfg, alg_t alg,
-            merge_t merge, const attr_t &attr, int mb = 0)
-        : desc_t(desc), dir(dir), cfg(cfg), alg(alg), merge(merge), attr(attr)
-        , ops(0), scales(NULL) {
+            const attr_t &attr, int mb = 0, bool is_deconv = false)
+        : desc_t(desc), dir(dir), cfg(cfg), alg(alg), attr(attr)
+        , ops(0), scales(NULL), is_deconv(is_deconv) {
         if (mb) this->mb = mb;
         count_ops();
         generate_oscales();
@@ -110,11 +116,11 @@ struct prb_t: public desc_t {
     dir_t dir;
     const dt_conf_t *cfg;
     alg_t alg;
-    merge_t merge;
     attr_t attr;
 
     double ops;
     float *scales;
+    bool is_deconv;
 
     void count_ops();
     void generate_oscales();