[mlir][llvm] Introduce some constant folding.
authorThéo Degioanni <theo.degioanni@nextsilicon.com>
Mon, 26 Jun 2023 10:49:54 +0000 (12:49 +0200)
committerThéo Degioanni <theo.degioanni@nextsilicon.com>
Mon, 26 Jun 2023 10:52:48 +0000 (12:52 +0200)
This revision introduces some constant folding features to the LLVM
dialect. This specific choice of operations to cover is intended to
allow the elimination of logic generated by mem2reg with memset in the
common case of memsets of constant values.

This also introduces new verifiers for integer extension operations.
This lead to a fix in SPIRV to LLVM conversion, as it would sometimes
generate invalid ZExt and SExt operations.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D153135

mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/test/Dialect/LLVMIR/constant-folding.mlir [new file with mode: 0644]
mlir/test/Dialect/LLVMIR/invalid.mlir
mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir

index 6f0f8f7..e445ca8 100644 (file)
@@ -15,6 +15,7 @@ def LLVM_Dialect : Dialect {
   let name = "llvm";
   let cppNamespace = "::mlir::LLVM";
 
+  let hasConstantMaterializer = 1;
   let useDefaultAttributePrinterParser = 1;
   let hasRegionArgAttrVerify = 1;
   let hasRegionResultAttrVerify = 1;
index 8536cbf..e7aca6d 100644 (file)
@@ -98,9 +98,13 @@ def LLVM_SDivOp : LLVM_IntArithmeticOp<"sdiv", "SDiv">;
 def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">;
 def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "SRem">;
 def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "And">;
-def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or">;
+def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or"> {
+  let hasFolder = 1;
+}
 def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">;
-def LLVM_ShlOp : LLVM_IntArithmeticOp<"shl", "Shl">;
+def LLVM_ShlOp : LLVM_IntArithmeticOp<"shl", "Shl"> {
+  let hasFolder = 1;
+}
 def LLVM_LShrOp : LLVM_IntArithmeticOp<"lshr", "LShr">;
 def LLVM_AShrOp : LLVM_IntArithmeticOp<"ashr", "AShr">;
 
@@ -495,10 +499,15 @@ def LLVM_PtrToIntOp : LLVM_CastOp<"ptrtoint", "PtrToInt",
                                   LLVM_ScalarOrVectorOf<AnyInteger>>;
 def LLVM_SExtOp : LLVM_CastOp<"sext", "SExt",
                               LLVM_ScalarOrVectorOf<AnyInteger>,
-                              LLVM_ScalarOrVectorOf<AnyInteger>>;
+                              LLVM_ScalarOrVectorOf<AnyInteger>> {
+  let hasVerifier = 1;
+}
 def LLVM_ZExtOp : LLVM_CastOp<"zext", "ZExt",
                               LLVM_ScalarOrVectorOf<AnyInteger>,
-                              LLVM_ScalarOrVectorOf<AnyInteger>>;
+                              LLVM_ScalarOrVectorOf<AnyInteger>> {
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
 def LLVM_TruncOp : LLVM_CastOp<"trunc", "Trunc",
                                LLVM_ScalarOrVectorOf<AnyInteger>,
                                LLVM_ScalarOrVectorOf<AnyInteger>>;
index 8b43808..28e587a 100644 (file)
@@ -51,6 +51,17 @@ static bool isUnsignedIntegerOrVector(Type type) {
   return false;
 }
 
+/// Returns the width of an integer or of the element type of an integer vector,
+/// if applicable.
+static std::optional<uint64_t> getIntegerOrVectorElementWidth(Type type) {
+  if (auto intType = dyn_cast<IntegerType>(type))
+    return intType.getWidth();
+  if (auto vecType = dyn_cast<VectorType>(type))
+    if (auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
+      return intType.getWidth();
+  return std::nullopt;
+}
+
 /// Returns the bit width of integer, float or vector of float or integer values
 static unsigned getBitWidth(Type type) {
   assert((type.isIntOrFloat() || isa<VectorType>(type)) &&
@@ -1183,15 +1194,30 @@ public:
       return success();
     }
 
+    std::optional<uint64_t> dstTypeWidth =
+        getIntegerOrVectorElementWidth(dstType);
+    std::optional<uint64_t> op2TypeWidth =
+        getIntegerOrVectorElementWidth(op2Type);
+
+    if (!dstTypeWidth || !op2TypeWidth)
+      return failure();
+
     Location loc = operation.getLoc();
     Value extended;
-    if (isUnsignedIntegerOrVector(op2Type)) {
-      extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType,
-                                                        adaptor.getOperand2());
+    if (op2TypeWidth < dstTypeWidth) {
+      if (isUnsignedIntegerOrVector(op2Type)) {
+        extended = rewriter.template create<LLVM::ZExtOp>(
+            loc, dstType, adaptor.getOperand2());
+      } else {
+        extended = rewriter.template create<LLVM::SExtOp>(
+            loc, dstType, adaptor.getOperand2());
+      }
+    } else if (op2TypeWidth == dstTypeWidth) {
+      extended = adaptor.getOperand2();
     } else {
-      extended = rewriter.template create<LLVM::SExtOp>(loc, dstType,
-                                                        adaptor.getOperand2());
+      return failure();
     }
+
     Value result = rewriter.template create<LLVMOp>(
         loc, dstType, adaptor.getOperand1(), extended);
     rewriter.replaceOp(operation, result);
index b24e2ca..9b91304 100644 (file)
@@ -2532,6 +2532,64 @@ LogicalResult FenceOp::verify() {
 }
 
 //===----------------------------------------------------------------------===//
+// Verifier for extension ops
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the given extension operation operates on consistent scalars
+/// or vectors, and that the target width is larger than the input width.
+template <class ExtOp>
+static LogicalResult verifyExtOp(ExtOp op) {
+  IntegerType inputType, outputType;
+  if (isCompatibleVectorType(op.getArg().getType())) {
+    if (!isCompatibleVectorType(op.getResult().getType()))
+      return op.emitError(
+          "input type is a vector but output type is an integer");
+    if (getVectorNumElements(op.getArg().getType()) !=
+        getVectorNumElements(op.getResult().getType()))
+      return op.emitError("input and output vectors are of incompatible shape");
+    // Because this is a CastOp, the element of vectors is guaranteed to be an
+    // integer.
+    inputType = cast<IntegerType>(getVectorElementType(op.getArg().getType()));
+    outputType =
+        cast<IntegerType>(getVectorElementType(op.getResult().getType()));
+  } else {
+    // Because this is a CastOp and arg is not a vector, arg is guaranteed to be
+    // an integer.
+    inputType = cast<IntegerType>(op.getArg().getType());
+    outputType = dyn_cast<IntegerType>(op.getResult().getType());
+    if (!outputType)
+      return op.emitError(
+          "input type is an integer but output type is a vector");
+  }
+
+  if (outputType.getWidth() <= inputType.getWidth())
+    return op.emitError("integer width of the output type is smaller or "
+                        "equal to the integer width of the input type");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ZExtOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ZExtOp::verify() { return verifyExtOp<ZExtOp>(*this); }
+
+OpFoldResult LLVM::ZExtOp::fold(FoldAdaptor adaptor) {
+  auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
+  if (!arg)
+    return {};
+
+  size_t targetSize = cast<IntegerType>(getType()).getWidth();
+  return IntegerAttr::get(getType(), arg.getValue().zext(targetSize));
+}
+
+//===----------------------------------------------------------------------===//
+// SExtOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult SExtOp::verify() { return verifyExtOp<SExtOp>(*this); }
+
+//===----------------------------------------------------------------------===//
 // Folder and verifier for LLVM::BitcastOp
 //===----------------------------------------------------------------------===//
 
@@ -2648,6 +2706,42 @@ OpFoldResult LLVM::GEPOp::fold(FoldAdaptor adaptor) {
 }
 
 //===----------------------------------------------------------------------===//
+// ShlOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult LLVM::ShlOp::fold(FoldAdaptor adaptor) {
+  auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs());
+  if (!rhs)
+    return {};
+
+  if (rhs.getValue().getZExtValue() >=
+      getLhs().getType().getIntOrFloatBitWidth())
+    return {}; // TODO: Fold into poison.
+
+  auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs());
+  if (!lhs)
+    return {};
+
+  return IntegerAttr::get(getType(), lhs.getValue().shl(rhs.getValue()));
+}
+
+//===----------------------------------------------------------------------===//
+// OrOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult LLVM::OrOp::fold(FoldAdaptor adaptor) {
+  auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs());
+  if (!lhs)
+    return {};
+
+  auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs());
+  if (!rhs)
+    return {};
+
+  return IntegerAttr::get(getType(), lhs.getValue() | rhs.getValue());
+}
+
+//===----------------------------------------------------------------------===//
 // Utilities for LLVM::MetadataOp
 //===----------------------------------------------------------------------===//
 
@@ -3186,6 +3280,15 @@ LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op,
   return verifyParameterAttribute(op, resType, resAttr);
 }
 
+Operation *LLVMDialect::materializeConstant(OpBuilder &builder, Attribute value,
+                                            Type type, Location loc) {
+  // TODO: Accept more possible attributes. So far, only IntegerAttr may come
+  // up.
+  if (!isa<IntegerAttr>(value))
+    return nullptr;
+  return builder.create<LLVM::ConstantOp>(loc, type, value);
+}
+
 //===----------------------------------------------------------------------===//
 // Utility functions.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/constant-folding.mlir b/mlir/test/Dialect/LLVMIR/constant-folding.mlir
new file mode 100644 (file)
index 0000000..f800f26
--- /dev/null
@@ -0,0 +1,53 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(canonicalize))" --split-input-file | FileCheck %s
+
+// CHECK-LABEL: llvm.func @zext_basic
+llvm.func @zext_basic() -> i64 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.zext %0 : i32 to i64
+  // CHECK: %[[RES:.*]] = llvm.mlir.constant(1 : i64) : i64
+  // CHECK: llvm.return %[[RES]] : i64
+  llvm.return %1 : i64
+}
+
+// CHECK-LABEL: llvm.func @zext_neg
+llvm.func @zext_neg() -> i64 {
+  %0 = llvm.mlir.constant(-1 : i32) : i32
+  %1 = llvm.zext %0 : i32 to i64
+  // CHECK: %[[RES:.*]] = llvm.mlir.constant(4294967295 : i64) : i64
+  // CHECK: llvm.return %[[RES]] : i64
+  llvm.return %1 : i64
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @shl_basic
+llvm.func @shl_basic() -> i32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.mlir.constant(1 : i32) : i32
+  %2 = llvm.shl %0, %1 : i32
+  // CHECK: %[[RES:.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: llvm.return %[[RES]] : i32
+  llvm.return %2 : i32
+}
+
+// CHECK-LABEL: llvm.func @shl_multiple
+llvm.func @shl_multiple() -> i32 {
+  %0 = llvm.mlir.constant(45 : i32) : i32
+  %1 = llvm.mlir.constant(7 : i32) : i32
+  %2 = llvm.shl %0, %1 : i32
+  // CHECK: %[[RES:.*]] = llvm.mlir.constant(5760 : i32) : i32
+  // CHECK: llvm.return %[[RES]] : i32
+  llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @or_basic
+llvm.func @or_basic() -> i32 {
+  %0 = llvm.mlir.constant(5 : i32) : i32
+  %1 = llvm.mlir.constant(9 : i32) : i32
+  %2 = llvm.or %0, %1 : i32
+  // CHECK: %[[RES:.*]] = llvm.mlir.constant(13 : i32) : i32
+  // CHECK: llvm.return %[[RES]] : i32
+  llvm.return %2 : i32
+}
index c88ed03..e6e1192 100644 (file)
@@ -1447,3 +1447,45 @@ llvm.comdat @__llvm_comdat {
 llvm.mlir.global @not_comdat(0 : i32) : i32
 // expected-error@+1 {{expected comdat symbol}}
 llvm.mlir.global @invalid_comdat_use(0 : i32) comdat(@not_comdat) : i32
+
+// -----
+
+func.func @invalid_zext_target_size_equal(%arg: i32)  {
+  // expected-error@+1 {{integer width of the output type is smaller or equal to the integer width of the input type}}
+  %0 = llvm.zext %arg : i32 to i32
+}
+
+// -----
+
+func.func @invalid_zext_target_size(%arg: i32)  {
+  // expected-error@+1 {{integer width of the output type is smaller or equal to the integer width of the input type}}
+  %0 = llvm.zext %arg : i32 to i16
+}
+
+// -----
+
+func.func @invalid_zext_target_size_vector(%arg: vector<1xi32>)  {
+  // expected-error@+1 {{integer width of the output type is smaller or equal to the integer width of the input type}}
+  %0 = llvm.zext %arg : vector<1xi32> to vector<1xi16>
+}
+
+// -----
+
+func.func @invalid_zext_target_shape(%arg: vector<1xi32>)  {
+  // expected-error@+1 {{input and output vectors are of incompatible shape}}
+  %0 = llvm.zext %arg : vector<1xi32> to vector<2xi64>
+}
+
+// -----
+
+func.func @invalid_zext_target_type(%arg: i32)  {
+  // expected-error@+1 {{input type is an integer but output type is a vector}}
+  %0 = llvm.zext %arg : i32 to vector<1xi64>
+}
+
+// -----
+
+func.func @invalid_zext_target_type_two(%arg: vector<1xi32>)  {
+  // expected-error@+1 {{input type is a vector but output type is an integer}}
+  %0 = llvm.zext %arg : vector<1xi32> to i64
+}
index 4a262f7..ce6338f 100644 (file)
@@ -1,12 +1,11 @@
 // RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(mem2reg{region-simplify=false}))" --split-input-file | FileCheck %s
 
 // CHECK-LABEL: llvm.func @basic_memset
-llvm.func @basic_memset() -> i32 {
+// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
+llvm.func @basic_memset(%memset_value: i8) -> i32 {
   %0 = llvm.mlir.constant(1 : i32) : i32
   %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  %memset_value = llvm.mlir.constant(42 : i8) : i8
   %memset_len = llvm.mlir.constant(4 : i32) : i32
-  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
   // CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
   // CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
   "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
@@ -24,36 +23,27 @@ llvm.func @basic_memset() -> i32 {
 
 // -----
 
-// CHECK-LABEL: llvm.func @allow_dynamic_value_memset
-// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
-llvm.func @allow_dynamic_value_memset(%memset_value: i8) -> i32 {
+// CHECK-LABEL: llvm.func @basic_memset_constant
+llvm.func @basic_memset_constant() -> i32 {
   %0 = llvm.mlir.constant(1 : i32) : i32
   %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
   %memset_len = llvm.mlir.constant(4 : i32) : i32
-  // CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
-  // CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
   "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
-  // CHECK-NOT: "llvm.intr.memset"
-  // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32
-  // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
-  // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
-  // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
-  // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
-  // CHECK-NOT: "llvm.intr.memset"
   %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
-  // CHECK: llvm.return %[[VALUE_32]] : i32
+  // CHECK: %[[RES:.*]] = llvm.mlir.constant(707406378 : i32) : i32
+  // CHECK: llvm.return %[[RES]] : i32
   llvm.return %2 : i32
 }
 
 // -----
 
 // CHECK-LABEL: llvm.func @exotic_target_memset
-llvm.func @exotic_target_memset() -> i40 {
+// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
+llvm.func @exotic_target_memset(%memset_value: i8) -> i40 {
   %0 = llvm.mlir.constant(1 : i32) : i32
   %1 = llvm.alloca %0 x i40 {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  %memset_value = llvm.mlir.constant(42 : i8) : i8
   %memset_len = llvm.mlir.constant(5 : i32) : i32
-  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
   // CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i40) : i40
   // CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i40) : i40
   // CHECK-DAG: %[[C32:.*]] = llvm.mlir.constant(32 : i40) : i40
@@ -74,6 +64,21 @@ llvm.func @exotic_target_memset() -> i40 {
 
 // -----
 
+// CHECK-LABEL: llvm.func @exotic_target_memset_constant
+llvm.func @exotic_target_memset_constant() -> i40 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x i40 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  %memset_len = llvm.mlir.constant(5 : i32) : i32
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i40
+  // CHECK: %[[RES:.*]] = llvm.mlir.constant(181096032810 : i40) : i40
+  // CHECK: llvm.return %[[RES]] : i40
+  llvm.return %2 : i40
+}
+
+// -----
+
 // CHECK-LABEL: llvm.func @no_volatile_memset
 llvm.func @no_volatile_memset() -> i32 {
   // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32