ARM_COMPUTE_ERROR_ON(input->info()->dimension(0) != output->info()->dimension(0));
ARM_COMPUTE_ERROR_ON(input->info()->dimension(1) != output->info()->dimension(1));
- _input = input;
- _output = output;
- const unsigned int processed_elements = max_cl_vector_width / data_size_from_type(input->info()->data_type());
+ _input = input;
+ _output = output;
+ const unsigned int num_elems_processed_per_iteration = max_cl_vector_width / data_size_from_type(input->info()->data_type());
std::ostringstream ma_arguments;
ma_arguments << "-DBETA=" << beta;
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(("gemm_ma_" + data_type_name), build_opts));
// Configure kernel window
- Window win = calculate_max_window(*_input->info(), Steps(processed_elements));
- AccessWindowHorizontal input_access(input->info(), 0, processed_elements);
- AccessWindowHorizontal output_access(output->info(), 0, processed_elements);
+ Window win = calculate_max_window(*_input->info(), Steps(num_elems_processed_per_iteration));
+
+ AccessWindowHorizontal input_access(input->info(), 0, num_elems_processed_per_iteration);
+ AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);
+
update_window_and_padding(win, input_access, output_access);
+
output_access.set_valid_region(win, input->info()->valid_region());
+
ICLKernel::configure(win);
}