[bugfix] Fix sgemv_cl function call from blas_kernel_interface
authorDebadri Samaddar <s.debadri@samsung.com>
Wed, 18 Sep 2024 11:00:20 +0000 (16:30 +0530)
committerJijoong Moon <jijoong.moon@samsung.com>
Mon, 14 Oct 2024 07:33:27 +0000 (16:33 +0900)
Fixed sgemv_cl function call. Failing unittest after recent changes.

Signed-off-by: Debadri Samaddar <s.debadri@samsung.com>
nntrainer/tensor/cl_operations/blas_kernel_interface.cpp

index 466cd66e3b61a2d461a05a5bd41c7552292713cc..23af3f9799575c910903af1262ba3860c6dece88 100644 (file)
@@ -138,8 +138,8 @@ void dotCl(Tensor const &input, Tensor const &m, Tensor &result, bool trans,
     /// = R^T = (K * N) ^T * (1 * K) ^T = (N * K) * (K * 1) = (N * K) * (1 * K)
     /// Effectively a translation of sgemv
     else if (M == 1) {
-      trans_m ? sgemv_cl(mdata, data, rdata, mdim2, mdim1, ldb)
-              : sgemv_cl(mdata, data, rdata, mdim1, mdim2, ldb);
+      trans_m ? sgemv_cl(mdata, data, rdata, mdim1, mdim2, ldb)
+              : sgemv_cl(mdata, data, rdata, mdim2, mdim1, ldb);
     }
     /// case others: use gemm
     else {
@@ -170,8 +170,8 @@ void dotCl(Tensor const &input, Tensor const &m, Tensor &result, bool trans,
     /// = R^T = (K * N) ^T * (1 * K) ^T = (N * K) * (K * 1) = (N * K) * (1 * K)
     /// Effectively a translation of sgemv
     else if (M == 1) {
-      trans_m ? sgemv_cl(mdata, data, rdata, mdim2, mdim1, ldb)
-              : sgemv_cl(mdata, data, rdata, mdim1, mdim2, ldb);
+      trans_m ? sgemv_cl(mdata, data, rdata, mdim1, mdim2, ldb)
+              : sgemv_cl(mdata, data, rdata, mdim2, mdim1, ldb);
     }
     /// case others: use sgemm
     else {