From 242d558658cd5a480b02883e2982d7246342e0d0 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Tue, 20 Sep 2022 11:03:37 -0400 Subject: [PATCH] [mlir][arith] Add test pass for wide integer emulation The new test pass allows for running wide integer emulation conversion within specified functions only. I intend to use it in integration tests in a way that allows me print both original and emulated results in the same format, or even compare both results at runtime and print on mismatch only. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D134120 --- .../Arithmetic/test-emulate-wide-int-pass.mlir | 39 +++++++++ .../CPU/test-wide-int-emulation-constants-i16.mlir | 25 +++++- .../CPU/test-wide-int-emulation-muli-i16.mlir | 68 +++++++--------- mlir/test/lib/Dialect/Arithmetic/CMakeLists.txt | 14 ++++ .../lib/Dialect/Arithmetic/TestEmulateWideInt.cpp | 95 ++++++++++++++++++++++ mlir/test/lib/Dialect/CMakeLists.txt | 1 + mlir/tools/mlir-opt/CMakeLists.txt | 1 + mlir/tools/mlir-opt/mlir-opt.cpp | 2 + 8 files changed, 203 insertions(+), 42 deletions(-) create mode 100644 mlir/test/Dialect/Arithmetic/test-emulate-wide-int-pass.mlir create mode 100644 mlir/test/lib/Dialect/Arithmetic/CMakeLists.txt create mode 100644 mlir/test/lib/Dialect/Arithmetic/TestEmulateWideInt.cpp diff --git a/mlir/test/Dialect/Arithmetic/test-emulate-wide-int-pass.mlir b/mlir/test/Dialect/Arithmetic/test-emulate-wide-int-pass.mlir new file mode 100644 index 0000000..bc6151e --- /dev/null +++ b/mlir/test/Dialect/Arithmetic/test-emulate-wide-int-pass.mlir @@ -0,0 +1,39 @@ +// Check that the test version of the wide integer emulation pass applies +// conversion to functions whose name start with a given prefix only, and that +// the function signatures are preserved. + +// RUN: mlir-opt %s --test-arith-emulate-wide-int="function-prefix=emulate_me_" | FileCheck %s + +// CHECK-LABEL: func.func @entry() +// CHECK: {{%.+}} = call @emulate_me_please({{.+}}) : (i64) -> i64 +// CHECK-NEXT: {{%.+}} = call @foo({{.+}}) : (i64) -> i64 +func.func @entry() { + %cst0 = arith.constant 0 : i64 + func.call @emulate_me_please(%cst0) : (i64) -> (i64) + func.call @foo(%cst0) : (i64) -> (i64) + return +} + +// CHECK-LABEL: func.func @emulate_me_please +// CHECK-SAME: ([[ARG:%.+]]: i64) -> i64 { +// CHECK-NEXT: [[BCAST0:%.+]] = llvm.bitcast [[ARG]] : i64 to vector<2xi32> +// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[BCAST0]][0] : vector<2xi32> +// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[BCAST0]][1] : vector<2xi32> +// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[BCAST0]][0] : vector<2xi32> +// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[BCAST0]][1] : vector<2xi32> +// CHECK-NEXT: {{%.+}}, {{%.+}} = arith.addui_carry [[LOW0]], [[LOW1]] : i32, i1 +// CHECK: [[RES:%.+]] = llvm.bitcast {{%.+}} : vector<2xi32> to i64 +// CHECK-NEXt: return [[RES]] : i64 +func.func @emulate_me_please(%x : i64) -> i64 { + %r = arith.addi %x, %x : i64 + return %r : i64 +} + +// CHECK-LABEL: func.func @foo +// CHECK-SAME: ([[ARG:%.+]]: i64) -> i64 { +// CHECK-NEXT: [[RES:%.+]] = arith.addi [[ARG]], [[ARG]] : i64 +// CHECK-NEXT: return [[RES]] : i64 +func.func @foo(%x : i64) -> i64 { + %r = arith.addi %x, %x : i64 + return %r : i64 +} diff --git a/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-constants-i16.mlir b/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-constants-i16.mlir index 8cc5ceb..22ef5d4 100644 --- a/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-constants-i16.mlir +++ b/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-constants-i16.mlir @@ -1,7 +1,7 @@ // Check that the wide integer constant emulation produces the same result as wide // constants and that printing works. Emulate i16 ops with i8 ops. -// RUN: mlir-opt %s --arith-emulate-wide-int="widest-int-supported=8" \ +// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=8" \ // RUN: --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \ // RUN: --convert-func-to-llvm --convert-arith-to-llvm | \ // RUN: mlir-cpu-runner -e entry -entry-point-result=void \ @@ -10,6 +10,16 @@ func.func @entry() { %cst0 = arith.constant 0 : i16 + func.call @emulate_constant(%cst0) : (i16) -> () + func.call @foo(%cst0) : (i16) -> () + return +} + +func.func @emulate_constant(%first : i16) { + // EMULATED: ( 0, 0 ) + vector.print %first : i16 + + %cst0 = arith.constant 0 : i16 %cst1 = arith.constant 1 : i16 %cst_1 = arith.constant -1 : i16 %cst_3 = arith.constant -3 : i16 @@ -20,7 +30,7 @@ func.func @entry() { %cst_i16_max = arith.constant 32767 : i16 %cst_i16_min = arith.constant -32768 : i16 - // EMULATED: ( 0, 0 ) + // EMULATED-NEXT: ( 0, 0 ) vector.print %cst0 : i16 // EMULATED-NEXT: ( 1, 0 ) vector.print %cst1 : i16 @@ -39,6 +49,17 @@ func.func @entry() { vector.print %cst_i16_max : i16 // EMULATED-NEXT: ( 0, -128 ) vector.print %cst_i16_min : i16 + return +} +func.func @foo(%first: i16) { + // These should not be emulated because the function name does not start with + // 'emulated_'. + + // EMULATED-NEXT: 0 + vector.print %first : i16 + // EMULATED-NEXT: 1 + %cst1 = arith.constant 1 : i16 + vector.print %cst1 : i16 return } diff --git a/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-muli-i16.mlir b/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-muli-i16.mlir index 7a56ed9..976e28f 100644 --- a/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-muli-i16.mlir +++ b/mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-muli-i16.mlir @@ -5,17 +5,23 @@ // RUN: --convert-func-to-llvm --convert-arith-to-llvm | \ // RUN: mlir-cpu-runner -e entry -entry-point-result=void \ // RUN: --shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ -// RUN: FileCheck %s --match-full-lines --check-prefix=WIDE +// RUN: FileCheck %s --match-full-lines -// RUN: mlir-opt %s --arith-emulate-wide-int="widest-int-supported=8" \ +// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=8" \ // RUN: --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \ // RUN: --convert-func-to-llvm --convert-arith-to-llvm | \ // RUN: mlir-cpu-runner -e entry -entry-point-result=void \ // RUN: --shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ -// RUN: FileCheck %s --match-full-lines --check-prefix=EMULATED +// RUN: FileCheck %s --match-full-lines -func.func @check_muli(%lhs : i16, %rhs : i16) -> () { +// Ops in this function *only* will be emulated using i8 types. +func.func @emulate_muli(%lhs : i16, %rhs : i16) -> (i16) { %res = arith.muli %lhs, %rhs : i16 + return %res : i16 +} + +func.func @check_muli(%lhs : i16, %rhs : i16) -> () { + %res = func.call @emulate_muli(%lhs, %rhs) : (i16, i16) -> (i16) vector.print %res : i16 return } @@ -34,63 +40,45 @@ func.func @entry() { %cst_i16_max = arith.constant 32767 : i16 %cst_i16_min = arith.constant -32768 : i16 - // WIDE: 0 - // EMULATED: ( 0, 0 ) + // CHECK: 0 func.call @check_muli(%cst0, %cst0) : (i16, i16) -> () - // WIDE-NEXT: 0 - // EMULATED-NEXT: ( 0, 0 ) + // CHECK-NEXT: 0 func.call @check_muli(%cst0, %cst1) : (i16, i16) -> () - // WIDE-NEXT: 1 - // EMULATED-NEXT: ( 1, 0 ) + // CHECK-NEXT: 1 func.call @check_muli(%cst1, %cst1) : (i16, i16) -> () - // WIDE-NEXT: -1 - // EMULATED-NEXT: ( -1, -1 ) + // CHECK-NEXT: -1 func.call @check_muli(%cst1, %cst_1) : (i16, i16) -> () - // WIDE-NEXT: 1 - // EMULATED-NEXT: ( 1, 0 ) + // CHECK-NEXT: 1 func.call @check_muli(%cst_1, %cst_1) : (i16, i16) -> () - // WIDE-NEXT: -3 - // EMULATED-NEXT: ( -3, -1 ) + // CHECK-NEXT: -3 func.call @check_muli(%cst1, %cst_3) : (i16, i16) -> () - // WIDE-NEXT: 169 - // EMULATED-NEXT: ( -87, 0 ) + // CHECK-NEXT: 169 func.call @check_muli(%cst13, %cst13) : (i16, i16) -> () - // WIDE-NEXT: 481 - // EMULATED-NEXT: ( -31, 1 ) + // CHECK-NEXT: 481 func.call @check_muli(%cst13, %cst37) : (i16, i16) -> () - // WIDE-NEXT: 1554 - // EMULATED-NEXT: ( 18, 6 ) + // CHECK-NEXT: 1554 func.call @check_muli(%cst37, %cst42) : (i16, i16) -> () - // WIDE-NEXT: -256 - // EMULATED-NEXT: ( 0, -1 ) + // CHECK-NEXT: -256 func.call @check_muli(%cst_1, %cst256) : (i16, i16) -> () - // WIDE-NEXT: 3328 - // EMULATED-NEXT: ( 0, 13 ) + // CHECK-NEXT: 3328 func.call @check_muli(%cst256, %cst13) : (i16, i16) -> () - // WIDE-NEXT: 9472 - // EMULATED-NEXT: ( 0, 37 ) + // CHECK-NEXT: 9472 func.call @check_muli(%cst256, %cst37) : (i16, i16) -> () - // WIDE-NEXT: -768 - // EMULATED-NEXT: ( 0, -3 ) + // CHECK-NEXT: -768 func.call @check_muli(%cst256, %cst_3) : (i16, i16) -> () - // WIDE-NEXT: 32755 - // EMULATED-NEXT: ( -13, 127 ) + // CHECK-NEXT: 32755 func.call @check_muli(%cst13, %cst_i16_max) : (i16, i16) -> () - // WIDE-NEXT: -32768 - // EMULATED-NEXT: ( 0, -128 ) + // CHECK-NEXT: -32768 func.call @check_muli(%cst_i16_min, %cst37) : (i16, i16) -> () - // WIDE-NEXT: 1 - // EMULATED-NEXT: ( 1, 0 ) + // CHECK-NEXT: 1 func.call @check_muli(%cst_i16_max, %cst_i16_max) : (i16, i16) -> () - // WIDE-NEXT: -32768 - // EMULATED-NEXT: ( 0, -128 ) + // CHECK-NEXT: -32768 func.call @check_muli(%cst_i16_min, %cst13) : (i16, i16) -> () - // WIDE-NEXT: 0 - // EMULATED-NEXT: ( 0, 0 ) + // CHECK-NEXT: 0 func.call @check_muli(%cst_i16_min, %cst_i16_min) : (i16, i16) -> () return diff --git a/mlir/test/lib/Dialect/Arithmetic/CMakeLists.txt b/mlir/test/lib/Dialect/Arithmetic/CMakeLists.txt new file mode 100644 index 0000000..17d288e --- /dev/null +++ b/mlir/test/lib/Dialect/Arithmetic/CMakeLists.txt @@ -0,0 +1,14 @@ +# Exclude tests from libMLIR.so +add_mlir_library(MLIRArithmeticTestPasses + TestEmulateWideInt.cpp + + EXCLUDE_FROM_LIBMLIR + + LINK_LIBS PUBLIC + MLIRArithmeticDialect + MLIRArithmeticTransforms + MLIRFuncDialect + MLIRLLVMDialect + MLIRPass + MLIRVectorDialect +) diff --git a/mlir/test/lib/Dialect/Arithmetic/TestEmulateWideInt.cpp b/mlir/test/lib/Dialect/Arithmetic/TestEmulateWideInt.cpp new file mode 100644 index 0000000..7cf76ad --- /dev/null +++ b/mlir/test/lib/Dialect/Arithmetic/TestEmulateWideInt.cpp @@ -0,0 +1,95 @@ +//===- TestWideIntEmulation.cpp - Test Wide Int Emulation ------*- c++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass for integration testing of wide integer +// emulation patterns. Applies conversion patterns only to functions whose +// names start with a specified prefix. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arithmetic/Transforms/Passes.h" +#include "mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; + +namespace { +struct TestEmulateWideIntPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEmulateWideIntPass) + + TestEmulateWideIntPass() = default; + TestEmulateWideIntPass(const TestEmulateWideIntPass &pass) + : PassWrapper(pass) {} + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + StringRef getArgument() const final { return "test-arith-emulate-wide-int"; } + StringRef getDescription() const final { + return "Function pass to test Wide Integer Emulation"; + } + + void runOnOperation() override { + if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) { + signalPassFailure(); + return; + } + + func::FuncOp op = getOperation(); + if (!op.getSymName().startswith(testFunctionPrefix)) + return; + + MLIRContext *ctx = op.getContext(); + arith::WideIntEmulationConverter typeConverter(widestIntSupported); + + // Use `llvm.bitcast` as the bridge so that we can use preserve the + // function argument and return types of the processed function. + // TODO: Consider extending `arith.bitcast` to support scalar-to-1D-vector + // casts (and vice versa) and using it insted of `llvm.bitcast`. + auto addBitcast = [](OpBuilder &builder, Type type, ValueRange inputs, + Location loc) -> Optional { + auto cast = builder.create(loc, type, inputs); + return cast->getResult(0); + }; + typeConverter.addSourceMaterialization(addBitcast); + typeConverter.addTargetMaterialization(addBitcast); + + ConversionTarget target(*ctx); + target.addDynamicallyLegalDialect( + [&typeConverter](Operation *op) { return typeConverter.isLegal(op); }); + + RewritePatternSet patterns(ctx); + arith::populateWideIntEmulationPatterns(typeConverter, patterns); + if (failed(applyPartialConversion(op, target, std::move(patterns)))) + signalPassFailure(); + } + + Option testFunctionPrefix{ + *this, "function-prefix", + llvm::cl::desc("Prefix of functions to run the emulation pass on"), + llvm::cl::init("emulate_")}; + Option widestIntSupported{ + *this, "widest-int-supported", + llvm::cl::desc("Maximum integer bit width supported by the target"), + llvm::cl::init(32)}; +}; +} // namespace + +namespace mlir::test { +void registerTestArithmeticEmulateWideIntPass() { + PassRegistration(); +} +} // namespace mlir::test diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt index 46b38dc..002e484 100644 --- a/mlir/test/lib/Dialect/CMakeLists.txt +++ b/mlir/test/lib/Dialect/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(Affine) +add_subdirectory(Arithmetic) add_subdirectory(DLTI) add_subdirectory(Func) add_subdirectory(GPU) diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt index e1a9095..3b27cfd 100644 --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -14,6 +14,7 @@ if(MLIR_INCLUDE_TESTS) set(test_libs MLIRTestFuncToLLVM MLIRAffineTransformsTestPasses + MLIRArithmeticTestPasses MLIRDLTITestPasses MLIRFuncTestPasses MLIRGPUTestPasses diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 05bb9e4..58e6598 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -64,6 +64,7 @@ void registerMemRefBoundCheck(); void registerPatternsTestPass(); void registerSimpleParametricTilingPass(); void registerTestAffineLoopParametricTilingPass(); +void registerTestArithmeticEmulateWideIntPass(); void registerTestAliasAnalysisPass(); void registerTestBuiltinAttributeInterfaces(); void registerTestCallGraphPass(); @@ -161,6 +162,7 @@ void registerTestPasses() { mlir::test::registerSimpleParametricTilingPass(); mlir::test::registerTestAffineLoopParametricTilingPass(); mlir::test::registerTestAliasAnalysisPass(); + mlir::test::registerTestArithmeticEmulateWideIntPass(); mlir::test::registerTestBuiltinAttributeInterfaces(); mlir::test::registerTestCallGraphPass(); mlir::test::registerTestConstantFold(); -- 2.7.4