namespace impl {
namespace cpu {
-template <bool with_relu>
-struct _cpu_convolution_fwd_pd_t: public _convolution_fwd_pd_t<with_relu> {
+struct cpu_convolution_fwd_pd_t: public convolution_fwd_pd_t {
using cpu_memory_pd_t = cpu_memory_t::pd_t;
- _cpu_convolution_fwd_pd_t(engine_t *engine,
- const typename _cpu_convolution_fwd_pd_t::base_desc_t *adesc,
+ cpu_convolution_fwd_pd_t(engine_t *engine, // forwards all four arguments to convolution_fwd_pd_t, then builds the src/dst/weights/bias memory pds
+ const convolution_desc_t *adesc,
const primitive_attr_t *attr,
- const typename _cpu_convolution_fwd_pd_t::base_class *hint_fwd_pd)
- : _convolution_fwd_pd_t<with_relu>(engine, adesc, attr, hint_fwd_pd)
- , src_pd_(this->engine_, &this->cdesc_().src_desc)
- , dst_pd_(this->engine_, &this->cdesc_().dst_desc)
- , weights_pd_(this->engine_, &this->cdesc_().weights_desc)
- , bias_pd_(this->engine_, &this->cdesc_().bias_desc) {}
- virtual ~_cpu_convolution_fwd_pd_t() {}
+ const typename cpu_convolution_fwd_pd_t::base_class *hint_fwd_pd)
+ : convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , src_pd_(this->engine_, &this->desc()->src_desc) // each member pd is constructed from this->engine_ and the matching field of desc()
+ , dst_pd_(this->engine_, &this->desc()->dst_desc)
+ , weights_pd_(this->engine_, &this->desc()->weights_desc)
+ , bias_pd_(this->engine_, &this->desc()->bias_desc) {}
+ virtual ~cpu_convolution_fwd_pd_t() {}
virtual const cpu_memory_pd_t *src_pd(int index = 0) const override
{ return index == 0 ? &src_pd_ : nullptr; }
return nullptr;
}
- bool want_padded_bias() const {
- if (!this->with_bias()) return false;
+ bool has_padded_dst() const { // true iff dst_pd_ is a blocking desc whose padding_dims[1] differs from OC(); the with_bias() check of the old helper is no longer part of this test
memory_desc_wrapper dst_d(&dst_pd_);
if (!dst_d.is_blocking_desc()) return false;
return this->OC() != dst_d.blocking_desc().padding_dims[1];
}
+ bool wants_padded_bias() const { // true iff with_bias() holds and has_padded_dst() reports a padded destination
+ if (!this->with_bias()) return false; // without a bias the answer is always false
+ return has_padded_dst();
+ }
+
+ bool wants_zero_pad_dst(bool jit_impl = true) const { // true iff dst is padded AND an eltwise post-op is found AND eltwise_fwd_preserves_zero() returns false for its alg
+ if (!has_padded_dst()) return false; // unpadded dst: always false
+ const auto &po = this->attr()->post_ops_;
+ int idx;
+ if ((idx = po.find(primitive_kind::eltwise)) == -1) return false; // idx is assigned inside the condition; -1 is handled as "no eltwise entry"
+ return !math::eltwise_fwd_preserves_zero(po.entry_[idx].eltwise.alg, // result of the helper is negated; jit_impl is passed through unchanged
+ jit_impl);
+ }
+
protected:
cpu_memory_pd_t src_pd_, dst_pd_;
cpu_memory_pd_t weights_pd_, bias_pd_;
inline memory_format_t src_format()
{
using namespace memory_format;
- return utils::pick(this->cdesc_().src_desc.ndims - 3, ncw, nchw, ncdhw);
+ return utils::pick(this->desc()->src_desc.ndims - 3, ncw, nchw, ncdhw); // selector is src ndims - 3; NOTE(review): presumably pick() is index-based, so ndims 3/4/5 -> ncw/nchw/ncdhw - confirm against utils::pick
}
inline memory_format_t wei_format()
{
using namespace memory_format;
return this->with_groups()
- ? utils::pick(this->cdesc_().src_desc.ndims - 3, goiw, goihw, goidhw)
- : utils::pick(this->cdesc_().src_desc.ndims - 3, oiw, oihw, oidhw);
+ ? utils::pick(this->desc()->src_desc.ndims - 3, goiw, goihw, goidhw) // with_groups() true: choose among the g-prefixed formats; selector is the SOURCE desc's ndims - 3, not the weights'
+ : utils::pick(this->desc()->src_desc.ndims - 3, oiw, oihw, oidhw); // with_groups() false: choose among oiw/oihw/oidhw; NOTE(review): assumes pick() is index-based - confirm
}
virtual status_t set_default_params() {
CHECK(weights_pd_.set_format(wei_format()));
if (bias_pd_.desc()->format == any)
CHECK(bias_pd_.set_format(x));
+ if (this->desc()->alg_kind == alg_kind::convolution_auto)
+ CHECK(this->set_alg_kind(alg_kind::convolution_direct));
return status::success;
}
};
-using cpu_convolution_fwd_pd_t = _cpu_convolution_fwd_pd_t<false>;
-using cpu_convolution_relu_fwd_pd_t = _cpu_convolution_fwd_pd_t<true>;
-
struct cpu_convolution_bwd_data_pd_t: public convolution_bwd_data_pd_t {
using cpu_memory_pd_t = cpu_memory_t::pd_t;
CHECK(weights_pd_.set_format(wei_format()));
if (bias_pd_.desc()->format == any)
CHECK(bias_pd_.set_format(x));
+ if (this->desc()->alg_kind == alg_kind::convolution_auto)
+ CHECK(this->set_alg_kind(alg_kind::convolution_direct));
return status::success;
}
};
return nullptr;
}
- bool want_padded_bias() const {
+ bool wants_padded_bias() const {
if (!this->with_bias()) return false;
memory_desc_wrapper diff_dst_d(&diff_dst_pd_);
if (!diff_dst_d.is_blocking_desc()) return false;
CHECK(diff_weights_pd_.set_format(wei_format()));
if (diff_bias_pd_.desc()->format == any)
CHECK(diff_bias_pd_.set_format(x));
+ if (this->desc()->alg_kind == alg_kind::convolution_auto)
+ CHECK(this->set_alg_kind(alg_kind::convolution_direct));
return status::success;
}
};