From 7d299333cbce18c37fdef94569720cf5c61c760a Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Tue, 20 Sep 2022 18:51:20 -0400 Subject: [PATCH] [mlir][arith] Add integration tests for addi emulation This includes tests with the exact expected values and comparison-based tests. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D134321 --- .../CPU/test-wide-int-emulation-addi-i16.mlir | 85 ++++++++++++++++++++++ ...est-wide-int-emulation-compare-results-i16.mlir | 48 ++++++++++++ 2 files changed, 133 insertions(+) create mode 100644 mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-addi-i16.mlir diff --git a/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-addi-i16.mlir b/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-addi-i16.mlir new file mode 100644 index 0000000..96a5469 --- /dev/null +++ b/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-addi-i16.mlir @@ -0,0 +1,85 @@ +// Check that the wide integer addition emulation produces the same result as +// wide addition. Emulate i16 ops with i8 ops. + +// RUN: mlir-opt %s --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \ +// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: --shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s --match-full-lines + +// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=8" \ +// RUN: --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \ +// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: --shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s --match-full-lines + +// Ops in this function *only* will be emulated using i8 types. +func.func @emulate_addi(%lhs : i16, %rhs : i16) -> (i16) { + %res = arith.addi %lhs, %rhs : i16 + return %res : i16 +} + +func.func @check_addi(%lhs : i16, %rhs : i16) -> () { + %res = func.call @emulate_addi(%lhs, %rhs) : (i16, i16) -> (i16) + vector.print %res : i16 + return +} + +func.func @entry() { + %cst0 = arith.constant 0 : i16 + %cst1 = arith.constant 1 : i16 + %cst_1 = arith.constant -1 : i16 + %cst_3 = arith.constant -3 : i16 + + %cst13 = arith.constant 13 : i16 + %cst37 = arith.constant 37 : i16 + %cst42 = arith.constant 42 : i16 + + %cst256 = arith.constant 256 : i16 + %cst_i16_max = arith.constant 32767 : i16 + %cst_i16_min = arith.constant -32768 : i16 + + // CHECK: 0 + func.call @check_addi(%cst0, %cst0) : (i16, i16) -> () + // CHECK-NEXT: 1 + func.call @check_addi(%cst0, %cst1) : (i16, i16) -> () + // CHECK-NEXT: 2 + func.call @check_addi(%cst1, %cst1) : (i16, i16) -> () + // CHECK-NEXT: 0 + func.call @check_addi(%cst1, %cst_1) : (i16, i16) -> () + // CHECK-NEXT: -2 + func.call @check_addi(%cst_1, %cst_1) : (i16, i16) -> () + // CHECK-NEXT: -2 + func.call @check_addi(%cst1, %cst_3) : (i16, i16) -> () + + // CHECK-NEXT: 26 + func.call @check_addi(%cst13, %cst13) : (i16, i16) -> () + // CHECK-NEXT: 50 + func.call @check_addi(%cst13, %cst37) : (i16, i16) -> () + // CHECK-NEXT: 79 + func.call @check_addi(%cst37, %cst42) : (i16, i16) -> () + + // CHECK-NEXT: 255 + func.call @check_addi(%cst_1, %cst256) : (i16, i16) -> () + // CHECK-NEXT: 269 + func.call @check_addi(%cst256, %cst13) : (i16, i16) -> () + // CHECK-NEXT: 293 + func.call @check_addi(%cst256, %cst37) : (i16, i16) -> () + // CHECK-NEXT: 253 + func.call @check_addi(%cst256, %cst_3) : (i16, i16) -> () + + // CHECK-NEXT: -32756 + func.call @check_addi(%cst13, %cst_i16_max) : (i16, i16) -> () + // CHECK-NEXT: -32731 + func.call @check_addi(%cst_i16_min, %cst37) : (i16, i16) -> () + + // CHECK-NEXT: -2 + func.call @check_addi(%cst_i16_max, %cst_i16_max) : (i16, i16) -> () + // CHECK-NEXT: -32755 + func.call @check_addi(%cst_i16_min, %cst13) : (i16, i16) -> () + // CHECK-NEXT: 0 + func.call @check_addi(%cst_i16_min, %cst_i16_min) : (i16, i16) -> () + + return +} diff --git a/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-compare-results-i16.mlir b/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-compare-results-i16.mlir index 2b1afb3..6ca2790 100644 --- a/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-compare-results-i16.mlir +++ b/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-compare-results-i16.mlir @@ -63,6 +63,53 @@ func.func @xhash(%i : i16) -> (i16) { } //===----------------------------------------------------------------------===// +// Test arith.addi +//===----------------------------------------------------------------------===// + +// Ops in this function will be emulated using i8 ops. +func.func @emulate_addi(%lhs : i16, %rhs : i16) -> (i16) { + %res = arith.addi %lhs, %rhs : i16 + return %res : i16 +} + +// Performs both wide and emulated `arith.muli`, and checks that the results +// match. +func.func @check_addi(%lhs : i16, %rhs : i16) -> () { + %wide = arith.addi %lhs, %rhs : i16 + %emulated = func.call @emulate_addi(%lhs, %rhs) : (i16, i16) -> (i16) + func.call @check_results(%lhs, %rhs, %wide, %emulated) : (i16, i16, i16, i16) -> () + return +} + +// Checks that `arith.addi` is emulated properly by sampling the input space. +// In total, this test function checks 500 * 500 = 250k input pairs. +func.func @test_addi() -> () { + %idx0 = arith.constant 0 : index + %idx1 = arith.constant 1 : index + %idx500 = arith.constant 500 : index + + %cst0 = arith.constant 0 : i16 + %cst1 = arith.constant 1 : i16 + + scf.for %lhs_idx = %idx0 to %idx500 step %idx1 iter_args(%lhs = %cst0) -> (i16) { + %arg_lhs = func.call @xhash(%lhs) : (i16) -> (i16) + + scf.for %rhs_idx = %idx0 to %idx500 step %idx1 iter_args(%rhs = %cst0) -> (i16) { + %arg_rhs = func.call @xhash(%rhs) : (i16) -> (i16) + func.call @check_addi(%arg_lhs, %arg_rhs) : (i16, i16) -> () + + %rhs_next = arith.addi %rhs, %cst1 : i16 + scf.yield %rhs_next : i16 + } + + %lhs_next = arith.addi %lhs, %cst1 : i16 + scf.yield %lhs_next : i16 + } + + return +} + +//===----------------------------------------------------------------------===// // Test arith.muli //===----------------------------------------------------------------------===// @@ -161,6 +208,7 @@ func.func @test_shrui() -> () { //===----------------------------------------------------------------------===// func.func @entry() { + func.call @test_addi() : () -> () func.call @test_muli() : () -> () func.call @test_shrui() : () -> () return -- 2.7.4