From c521a052f9bf7e418011f98a662be2c434d350ba Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Tue, 20 Sep 2022 11:37:26 -0400 Subject: [PATCH] [mlir][arith] Add comparison-based integration tests Introduces a simple framework for runtime tests of the wide integer emulation. In these tests, we are only interested in checking that both wide and narrow calculations produce the same results, and do not check for exact results. This allows us to cover more of the input space, as we do not have to hardcode each of the expected outputs. Introduce common helper functions to check the results, print a message on mismatch, and sample the input space. Implement runtime comparison tests for `arith.muli` and `arith.shrui`. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D134184 --- ...est-wide-int-emulation-compare-results-i16.mlir | 167 +++++++++++++++++++++ 1 file changed, 167 insertions(+) create mode 100644 mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-compare-results-i16.mlir diff --git a/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-compare-results-i16.mlir b/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-compare-results-i16.mlir new file mode 100644 index 0000000..2b1afb3 --- /dev/null +++ b/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-compare-results-i16.mlir @@ -0,0 +1,167 @@ +// Check that the wide integer emulation produces the same result as wide +// calculations. Emulate i16 ops with i8 ops.
+
+// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=8" \
+// RUN:   --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN:   --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   --shared-libs="%mlir_lib_dir/libmlir_c_runner_utils%shlibext,%mlir_lib_dir/libmlir_runner_utils%shlibext" | \
+// RUN: FileCheck %s
+
+// CHECK-NOT: Mismatch
+
+//===----------------------------------------------------------------------===//
+// Common Utility Functions
+//===----------------------------------------------------------------------===//
+
+llvm.mlir.global internal constant @str_mismatch("Mismatch\0A")  // 9 bytes, no trailing \00.
+func.func private @printCString(!llvm.ptr<i8>) -> ()  // NOTE(review): external; confirm it needs no NUL terminator.
+// Prints 'Mismatch' to stdout.
+func.func @printMismatch() -> () {
+  %0 = llvm.mlir.addressof @str_mismatch : !llvm.ptr<array<9 x i8>>
+  %1 = llvm.mlir.constant(0 : index) : i64
+  %2 = llvm.getelementptr %0[%1, %1]  // Pointer to the first byte of the array.
+    : (!llvm.ptr<array<9 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  func.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
+  return
+}
+
+// Prints both binary op operands and the first result. If the second result
+// does not match, prints the second result and a 'Mismatch' message.
+func.func @check_results(%lhs : i16, %rhs : i16, %res0 : i16, %res1 : i16) -> () {
+  %vec_zero = arith.constant dense<0> : vector<2xi16>
+  %ins0 = vector.insert %lhs, %vec_zero[0] : i16 into vector<2xi16>
+  %operands = vector.insert %rhs, %ins0[1] : i16 into vector<2xi16>
+  vector.print %operands : vector<2xi16>
+  vector.print %res0 : i16
+  %mismatch = arith.cmpi ne, %res0, %res1 : i16
+  scf.if %mismatch -> () {
+    vector.print %res1 : i16
+    func.call @printMismatch() : () -> ()
+  }
+  return
+}
+
+func.func @xorshift(%i : i16) -> (i16) {
+  %cst8 = arith.constant 8 : i16
+  %shifted = arith.shrui %i, %cst8 : i16
+  %res = arith.xori %i, %shifted : i16  // %i ^ (%i >> 8)
+  return %res : i16
+}
+
+// Returns a hash of the input number. This is used when we want to sample a
+// bunch of i16 inputs with close to uniform distribution but without fixed
+// offsets between each sample.
+func.func @xhash(%i : i16) -> (i16) {
+  %pattern = arith.constant 21845 : i16 // Alternating ones and zeros.
+  %prime = arith.constant 25867 : i16 // Large i16 prime.
+  %xi = func.call @xorshift(%i) : (i16) -> (i16)
+  %inner = arith.muli %xi, %pattern : i16
+  %xinner = func.call @xorshift(%inner) : (i16) -> (i16)
+  %res = arith.muli %xinner, %prime : i16
+  return %res : i16
+}
+
+//===----------------------------------------------------------------------===//
+// Test arith.muli
+//===----------------------------------------------------------------------===//
+
+// Ops in this function will be emulated using i8 ops.
+func.func @emulate_muli(%lhs : i16, %rhs : i16) -> (i16) {
+  %res = arith.muli %lhs, %rhs : i16
+  return %res : i16
+}
+
+// Performs both wide and emulated `arith.muli`, and checks that the results
+// match.
+func.func @check_muli(%lhs : i16, %rhs : i16) -> () {
+  %wide = arith.muli %lhs, %rhs : i16
+  %emulated = func.call @emulate_muli(%lhs, %rhs) : (i16, i16) -> (i16)
+  func.call @check_results(%lhs, %rhs, %wide, %emulated) : (i16, i16, i16, i16) -> ()
+  return
+}
+
+// Checks that `arith.muli` is emulated properly by sampling the input space.
+// In total, this test function checks 500 * 500 = 250k input pairs.
+func.func @test_muli() -> () {
+  %zero = arith.constant 0 : i16
+  %one = arith.constant 1 : i16
+
+  %lb = arith.constant 0 : index
+  %step = arith.constant 1 : index
+  %ub = arith.constant 500 : index
+  // The i16 seeds travel through iter_args; %i and %j only count iterations.
+  scf.for %i = %lb to %ub step %step iter_args(%lhs_seed = %zero) -> (i16) {
+    %lhs_sample = func.call @xhash(%lhs_seed) : (i16) -> (i16)
+
+    scf.for %j = %lb to %ub step %step iter_args(%rhs_seed = %zero) -> (i16) {
+      %rhs_sample = func.call @xhash(%rhs_seed) : (i16) -> (i16)
+      func.call @check_muli(%lhs_sample, %rhs_sample) : (i16, i16) -> ()
+
+      %rhs_seed_next = arith.addi %rhs_seed, %one : i16
+      scf.yield %rhs_seed_next : i16
+    }
+
+    %lhs_seed_next = arith.addi %lhs_seed, %one : i16
+    scf.yield %lhs_seed_next : i16
+  }
+
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// Test arith.shrui
+//===----------------------------------------------------------------------===//
+
+// The ops inside this function are the ones emulated with i8 ops.
+func.func @emulate_shrui(%lhs : i16, %rhs : i16) -> (i16) {
+  %shifted = arith.shrui %lhs, %rhs : i16
+  return %shifted : i16
+}
+
+// Runs `arith.shrui` both wide and through @emulate_shrui, then passes the two
+// results to @check_results.
+func.func @check_shrui(%lhs : i16, %rhs : i16) -> () {
+  %actual = func.call @emulate_shrui(%lhs, %rhs) : (i16, i16) -> (i16)
+  %expected = arith.shrui %lhs, %rhs : i16
+  func.call @check_results(%lhs, %rhs, %expected, %actual) : (i16, i16, i16, i16) -> ()
+  return
+}
+
+// Samples the input space to check that `arith.shrui` is emulated properly.
+// Every valid shift amount for i16 is exercised: 0 through 15.
+// Overall, this test function checks 100 * 16 = 1.6k input pairs.
+func.func @test_shrui() -> () {
+  %zero = arith.constant 0 : i16
+  %one = arith.constant 1 : i16
+
+  %lb = arith.constant 0 : index
+  %step = arith.constant 1 : index
+  %num_shifts = arith.constant 16 : index
+  %num_samples = arith.constant 100 : index
+  // Outer loop: hashed i16 values to shift. Inner loop: shift amounts from 0.
+  scf.for %i = %lb to %num_samples step %step iter_args(%seed = %zero) -> (i16) {
+    %value = func.call @xhash(%seed) : (i16) -> (i16)
+
+    scf.for %j = %lb to %num_shifts step %step iter_args(%shift = %zero) -> (i16) {
+      func.call @check_shrui(%value, %shift) : (i16, i16) -> ()
+      %shift_next = arith.addi %shift, %one : i16
+      scf.yield %shift_next : i16
+    }
+
+    %seed_next = arith.addi %seed, %one : i16
+    scf.yield %seed_next : i16
+  }
+
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// Entry Point
+//===----------------------------------------------------------------------===//
+
+func.func @entry() {
+  func.call @test_muli() : () -> ()  // Multiplication checks run first.
+  func.call @test_shrui() : () -> ()  // Then the logical right-shift checks.
+  return
+}
-- 
2.7.4