: binarization_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
, src_pd_(engine_, &desc_.src_desc)
, dst_pd_(engine_, &desc_.dst_desc)
- , weights_pd_(engine_, &desc_.weights_desc) {}
+ , weights_pd_(engine_, &desc_.weights_desc)
+ , output_mask_pd_(engine_, &desc_.output_mask_desc) {}
virtual ~cpu_binarization_fwd_pd_t() {}
virtual const cpu_memory_pd_t *src_pd(int index = 0) const override
{ return index == 0 ? &dst_pd_ : nullptr; }
virtual const cpu_memory_pd_t *weights_pd(int index = 0) const override {
if (index == 0) return &weights_pd_;
+ if (index == 1) return &output_mask_pd_;
return nullptr;
}
protected:
- cpu_memory_pd_t src_pd_, dst_pd_, weights_pd_;
+ cpu_memory_pd_t src_pd_, dst_pd_, weights_pd_, output_mask_pd_;
inline memory_format_t src_format()
{
CHECK(dst_pd_.set_format(src_pd_.desc()->format));
if (weights_pd_.desc()->format == any)
CHECK(weights_pd_.set_format(wei_format()));
+ if (output_mask_pd_.desc()->format == any)
+ CHECK(output_mask_pd_.set_format(wei_format()));
return status::success;
}