arm_compute v17.04
[platform/upstream/armcl.git] / src / core / CL / kernels / CLGEMMMatrixMultiplyKernel.cpp
index 8edaf93..b22f059 100644 (file)
@@ -63,7 +63,7 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
 
     if(output->info()->dimension(1) == 196)
     {
-        _lws_hint = cl::NDRange(2, 7);
+        _lws_hint = cl::NDRange(1, 7);
     }
     else
     {
@@ -85,15 +85,19 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
         std::string data_type_name = lower_string(string_from_data_type(input0->info()->data_type()));
         _kernel                    = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(("gemm_vm_" + data_type_name), build_opts));
 
-        const unsigned int processed_elements_x = max_cl_vector_width / data_size_from_type(input0->info()->data_type());
-
         // Configure window kernel
-        Window                win = calculate_max_window(*output->info(), Steps(processed_elements_x));
-        AccessWindowRectangle input0_access(input0->info(), 0, 0, processed_elements_x, 1);
-        AccessWindowRectangle input1_access(input1->info(), 0, 0, processed_elements_x, 1);
-        AccessWindowRectangle output_access(output->info(), 0, 0, processed_elements_x, 1);
+        const unsigned int num_elems_processed_per_iteration_x = max_cl_vector_width / data_size_from_type(input0->info()->data_type());
+
+        Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x));
+
+        AccessWindowRectangle input0_access(input0->info(), 0, 0, num_elems_processed_per_iteration_x, 1);
+        AccessWindowRectangle input1_access(input1->info(), 0, 0, num_elems_processed_per_iteration_x, 1);
+        AccessWindowRectangle output_access(output->info(), 0, 0, num_elems_processed_per_iteration_x, 1);
+
         update_window_and_padding(win, input0_access, input1_access, output_access);
+
         output_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), output->info()->tensor_shape()));
+
         ICLKernel::configure(win);
     }
     else
@@ -104,16 +108,20 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
         std::string data_type_name = lower_string(string_from_data_type(input0->info()->data_type()));
         _kernel                    = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(("gemm_mm_" + data_type_name), build_opts));
 
-        const unsigned int     processed_elements_x = max_cl_vector_width / data_size_from_type(input0->info()->data_type());
-        constexpr unsigned int processed_elements_y = 4;
-
         // Configure window kernel
-        Window                win = calculate_max_window(*output->info(), Steps(processed_elements_x, processed_elements_y));
-        AccessWindowRectangle input0_access(input0->info(), 0, 0, processed_elements_y, 1);
-        AccessWindowRectangle input1_access(input1->info(), 0, 0, processed_elements_x, 1);
-        AccessWindowRectangle output_access(output->info(), 0, 0, processed_elements_x, processed_elements_y);
+        const unsigned int     num_elems_processed_per_iteration_x = max_cl_vector_width / data_size_from_type(input0->info()->data_type());
+        constexpr unsigned int num_elems_processed_per_iteration_y = 4;
+
+        Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+
+        AccessWindowRectangle input0_access(input0->info(), 0, 0, num_elems_processed_per_iteration_y, 1);
+        AccessWindowRectangle input1_access(input1->info(), 0, 0, num_elems_processed_per_iteration_x, 1);
+        AccessWindowRectangle output_access(output->info(), 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);
+
         update_window_and_padding(win, input0_access, input1_access, output_access);
+
         output_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), output->info()->tensor_shape()));
+
         ICLKernel::configure(win);
     }
 }