const layout& weights_layout,
std::shared_ptr<const convolution> conv) {
// A set of rules that define when bs_fs_yx_bsv16_fsv16 mem format can be used
- bool correct_batch = input_layout.size.batch[0] >= 16;
+ bool correct_batch = input_layout.size.batch[0] > 16;
bool correct_feature = (input_layout.size.feature[0] % 16 == 0 || input_layout.size.feature[0] == 3) && conv->output_size.feature[0] % 16 == 0;
bool fp16_ver = input_layout.data_type == data_types::f16 && input_layout.size.batch[0] % 32 == 0;
bool fp32_ver = input_layout.data_type == data_types::f32 && input_layout.size.batch[0] % 16 == 0;
bool single_group = conv->groups == 1;
- bool int8_sup = (input_layout.data_type == data_types::i8 || input_layout.data_type == data_types::u8) &&
- input_layout.size.batch[0] % 16 == 0 && weights_layout.data_type == data_types::i8 &&
- (conv->activations_zero_points.empty() && conv->weights_zero_points.empty());
+
+ bool int8_sup = (input_layout.data_type == data_types::i8 || input_layout.data_type == data_types::u8);
+ if (int8_sup)
+ correct_batch = input_layout.size.batch[0] >= 16;
+ int8_sup &= (input_layout.size.batch[0] % 16 == 0 && weights_layout.data_type == data_types::i8 &&
+ conv->activations_zero_points.empty() && conv->weights_zero_points.empty());
auto ks_x = weights_layout.size.spatial[0];
auto ks_y = weights_layout.size.spatial[1];
int8_sup &= (input_layout.size.spatial[2] == 1 && ((ks_x == 1 && ks_y == 1) || (ks_x == 3 && ks_y == 3) || (ks_x == 7 && ks_y == 7)) &&
- input_layout.size.batch[0] % 16 == 0 && weights_layout.size.batch[0] % 32 == 0 && conv->groups == 1 &&
- conv->split() == 1 && conv->dilation == tensor{1});
+ weights_layout.size.batch[0] % 32 == 0 && conv->split() == 1 && conv->dilation == tensor{1});
return (int8_sup || fp16_ver || fp32_ver) && correct_feature && correct_batch && single_group;
}