#include "mkldnn_common.hpp"
#include "mkldnn_memory.hpp"
+namespace deconv {
+/* some extra control parameters which shouldn't be placed in prb_t */
+extern const char *skip_impl; /* NULL or "" means do not skip anything */
+extern bool allow_unimpl; /* true means do not treat unimplemented as error */
+extern const char *perf_template; /* performance output template */
+}
+
namespace conv {
-enum alg_t { DIRECT, WINO };
+enum alg_t { DIRECT, WINO, AUTO };
alg_t str2alg(const char *str);
const char *alg2str(alg_t alg);
-
-enum merge_t { NONE, RELU, };
-merge_t str2merge(const char *str);
-const char *merge2str(merge_t merge);
+alg_t alg_kind2alg(mkldnn_alg_kind_t alg);
struct desc_t {
int g, mb;
int sd, sh, sw;
int pd, ph, pw;
int dd, dh, dw;
+ bool has_groups;
const char *name;
};
const dt_conf_t *str2cfg(const char *str);
const char *cfg2str(const dt_conf_t *cfg);
+const dt_conf_t *auto_cfg(const alg_t alg, const dt_conf_t *cfg);
struct prb_t: public desc_t {
prb_t(const desc_t &desc, dir_t dir, const dt_conf_t *cfg, alg_t alg,
- merge_t merge, const attr_t &attr, int mb = 0)
- : desc_t(desc), dir(dir), cfg(cfg), alg(alg), merge(merge), attr(attr)
- , ops(0), scales(NULL) {
+ const attr_t &attr, int mb = 0, bool is_deconv = false)
+ : desc_t(desc), dir(dir), cfg(cfg), alg(alg), attr(attr)
+ , ops(0), scales(NULL), is_deconv(is_deconv) {
if (mb) this->mb = mb;
count_ops();
generate_oscales();
dir_t dir;
const dt_conf_t *cfg;
alg_t alg;
- merge_t merge;
attr_t attr;
double ops;
float *scales;
+ bool is_deconv;
void count_ops();
void generate_oscales();