[mlir][test] Fix how the number of flops is calculated
authorAndrzej Warzynski <andrzej.warzynski@gmail.com>
Mon, 21 Nov 2022 07:15:26 +0000 (12:45 +0530)
committerUday Bondhugula <uday@polymagelabs.com>
Mon, 21 Nov 2022 08:10:24 +0000 (13:40 +0530)
Make sure that the number of repetitions is correctly incorporated when
calculating the number of floating point operations.

Reviewed By: bondhugula

Differential Revision: https://reviews.llvm.org/D138382

mlir/test/mlir-cpu-runner/sgemm-naive-codegen.mlir

index 5ebafbe..8a427dd 100644 (file)
@@ -10,10 +10,10 @@ func.func @main() {
   linalg.fill ins(%cf1 : f32) outs(%A : memref<16x16xf32>)
   linalg.fill ins(%cf1 : f32) outs(%B : memref<16x16xf32>)
 
-  %reps = arith.constant 1 : index
+  %num_reps = arith.constant 5 : index
 
   %t_start = call @rtclock() : () -> f64
-  affine.for %arg0 = 0 to 5 {
+  affine.for %arg0 = 0 to %num_reps {
     linalg.fill ins(%cf1 : f32) outs(%C : memref<16x16xf32>)
     func.call @sgemm_naive(%A, %B, %C) : (memref<16x16xf32>, memref<16x16xf32>, memref<16x16xf32>) -> ()
   }
@@ -31,16 +31,19 @@ func.func @main() {
   %N = memref.dim %C, %c1 : memref<16x16xf32>
   %K = memref.dim %A, %c1 : memref<16x16xf32>
 
+  // num_flops_per_iter = 2*M*N*K
   %f1 = arith.muli %M, %N : index
   %f2 = arith.muli %f1, %K : index
+  %num_flops_per_iter = arith.muli %c2, %f2 : index
 
-  // 2*M*N*K.
-  %f3 = arith.muli %c2, %f2 : index
-  %num_flops = arith.muli %reps, %f3 : index
-  %num_flops_i = arith.index_cast %num_flops : index to i16
-  %num_flops_f = arith.sitofp %num_flops_i : i16 to f64
-  %flops = arith.divf %num_flops_f, %t : f64
-  call @printFlops(%flops) : (f64) -> ()
+  // num_flops_total = num_flops_per_iter * num_reps
+  %num_flops_total = arith.muli %num_flops_per_iter, %num_reps: index
+
+  // Print the number of flops per second
+  %num_flops_total_i = arith.index_cast %num_flops_total : index to i16
+  %num_flops_total_f = arith.uitofp %num_flops_total_i : i16 to f64
+  %flops_per_s = arith.divf %num_flops_total_f, %t : f64
+  call @printFlops(%flops_per_s) : (f64) -> ()
 
   memref.dealloc %A : memref<16x16xf32>
   memref.dealloc %B : memref<16x16xf32>