virtual const memory_pd_t *input_pd(int index = 0) const override {
switch (index) {
case 0: return src_pd();
- case 1: return weights_pd(index - 1);
+ case 1: case 2: return weights_pd(index - 1);
default: return nullptr;
}
}
virtual const memory_pd_t *output_pd(int index = 0) const override
{ return index == 0 ? dst_pd() : nullptr; }
- virtual int n_inputs() const override { return 2; }
+ virtual int n_inputs() const override { return 3; }
virtual int n_outputs() const override { return 1; }
virtual status_t query(query_t what, int idx, void *result) const override