let name = "llvm";
let cppNamespace = "::mlir::LLVM";
+ let hasConstantMaterializer = 1;
let useDefaultAttributePrinterParser = 1;
let hasRegionArgAttrVerify = 1;
let hasRegionResultAttrVerify = 1;
def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">;
def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "SRem">;
def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "And">;
-def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or">;
+def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or"> {
+ let hasFolder = 1;
+}
def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">;
-def LLVM_ShlOp : LLVM_IntArithmeticOp<"shl", "Shl">;
+def LLVM_ShlOp : LLVM_IntArithmeticOp<"shl", "Shl"> {
+ let hasFolder = 1;
+}
def LLVM_LShrOp : LLVM_IntArithmeticOp<"lshr", "LShr">;
def LLVM_AShrOp : LLVM_IntArithmeticOp<"ashr", "AShr">;
LLVM_ScalarOrVectorOf<AnyInteger>>;
def LLVM_SExtOp : LLVM_CastOp<"sext", "SExt",
LLVM_ScalarOrVectorOf<AnyInteger>,
- LLVM_ScalarOrVectorOf<AnyInteger>>;
+ LLVM_ScalarOrVectorOf<AnyInteger>> {
+ let hasVerifier = 1;
+}
def LLVM_ZExtOp : LLVM_CastOp<"zext", "ZExt",
LLVM_ScalarOrVectorOf<AnyInteger>,
- LLVM_ScalarOrVectorOf<AnyInteger>>;
+ LLVM_ScalarOrVectorOf<AnyInteger>> {
+ let hasFolder = 1;
+ let hasVerifier = 1;
+}
def LLVM_TruncOp : LLVM_CastOp<"trunc", "Trunc",
LLVM_ScalarOrVectorOf<AnyInteger>,
LLVM_ScalarOrVectorOf<AnyInteger>>;
return false;
}
+/// Returns the width of an integer or of the element type of an integer vector,
+/// if applicable.
+static std::optional<uint64_t> getIntegerOrVectorElementWidth(Type type) {
+ if (auto intType = dyn_cast<IntegerType>(type))
+ return intType.getWidth();
+ if (auto vecType = dyn_cast<VectorType>(type))
+ if (auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
+ return intType.getWidth();
+ return std::nullopt;
+}
+
/// Returns the bit width of integer, float or vector of float or integer values
static unsigned getBitWidth(Type type) {
assert((type.isIntOrFloat() || isa<VectorType>(type)) &&
return success();
}
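+  // Determine the bit widths of the destination type and of the second
+  // operand; they are only defined for integers and integer vectors.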
+ std::optional<uint64_t> dstTypeWidth =
+ getIntegerOrVectorElementWidth(dstType);
+ std::optional<uint64_t> op2TypeWidth =
+ getIntegerOrVectorElementWidth(op2Type);
+
+ if (!dstTypeWidth || !op2TypeWidth)
+ return failure();
+
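+  // Widen the second operand to the destination width when it is narrower,
+  // use it unchanged when the widths match, and fail when it is wider.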
Location loc = operation.getLoc();
Value extended;
- if (isUnsignedIntegerOrVector(op2Type)) {
- extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType,
- adaptor.getOperand2());
+ if (op2TypeWidth < dstTypeWidth) {
+ if (isUnsignedIntegerOrVector(op2Type)) {
+ extended = rewriter.template create<LLVM::ZExtOp>(
+ loc, dstType, adaptor.getOperand2());
+ } else {
+ extended = rewriter.template create<LLVM::SExtOp>(
+ loc, dstType, adaptor.getOperand2());
+ }
+ } else if (op2TypeWidth == dstTypeWidth) {
+ extended = adaptor.getOperand2();
} else {
- extended = rewriter.template create<LLVM::SExtOp>(loc, dstType,
- adaptor.getOperand2());
+ return failure();
}
+
Value result = rewriter.template create<LLVMOp>(
loc, dstType, adaptor.getOperand1(), extended);
rewriter.replaceOp(operation, result);
}
//===----------------------------------------------------------------------===//
+// Verifier for extension ops
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the given extension operation operates on consistent scalars
+/// or vectors, and that the target width is larger than the input width.
+template <class ExtOp>
+static LogicalResult verifyExtOp(ExtOp op) {
+ IntegerType inputType, outputType;
+ if (isCompatibleVectorType(op.getArg().getType())) {
+ if (!isCompatibleVectorType(op.getResult().getType()))
+ return op.emitError(
+ "input type is a vector but output type is an integer");
+ if (getVectorNumElements(op.getArg().getType()) !=
+ getVectorNumElements(op.getResult().getType()))
+ return op.emitError("input and output vectors are of incompatible shape");
+    // Because this is a CastOp, the element type of the vectors is
+    // guaranteed to be an integer.
+ inputType = cast<IntegerType>(getVectorElementType(op.getArg().getType()));
+ outputType =
+ cast<IntegerType>(getVectorElementType(op.getResult().getType()));
+ } else {
+ // Because this is a CastOp and arg is not a vector, arg is guaranteed to be
+ // an integer.
+ inputType = cast<IntegerType>(op.getArg().getType());
+ outputType = dyn_cast<IntegerType>(op.getResult().getType());
+ if (!outputType)
+ return op.emitError(
+ "input type is an integer but output type is a vector");
+ }
+
+ if (outputType.getWidth() <= inputType.getWidth())
+ return op.emitError("integer width of the output type is smaller or "
+ "equal to the integer width of the input type");
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ZExtOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ZExtOp::verify() { return verifyExtOp<ZExtOp>(*this); }
+
+OpFoldResult LLVM::ZExtOp::fold(FoldAdaptor adaptor) {
+ auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
+ if (!arg)
+ return {};
+
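+  // Zero-extend the constant value to the bit width of the result type.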
+ size_t targetSize = cast<IntegerType>(getType()).getWidth();
+ return IntegerAttr::get(getType(), arg.getValue().zext(targetSize));
+}
+
+//===----------------------------------------------------------------------===//
+// SExtOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult SExtOp::verify() { return verifyExtOp<SExtOp>(*this); }
+
+//===----------------------------------------------------------------------===//
// Folder and verifier for LLVM::BitcastOp
//===----------------------------------------------------------------------===//
}
//===----------------------------------------------------------------------===//
+// ShlOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult LLVM::ShlOp::fold(FoldAdaptor adaptor) {
+ auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs());
+ if (!rhs)
+ return {};
+
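+  // Shifting by at least the bit width of the operand produces a poison
+  // value in LLVM, so do not fold such shifts.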
+ if (rhs.getValue().getZExtValue() >=
+ getLhs().getType().getIntOrFloatBitWidth())
+ return {}; // TODO: Fold into poison.
+
+ auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs());
+ if (!lhs)
+ return {};
+
+ return IntegerAttr::get(getType(), lhs.getValue().shl(rhs.getValue()));
+}
+
+//===----------------------------------------------------------------------===//
+// OrOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult LLVM::OrOp::fold(FoldAdaptor adaptor) {
+ auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs());
+ if (!lhs)
+ return {};
+
+ auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs());
+ if (!rhs)
+ return {};
+
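+  // Both operands are constant integers: fold to their bitwise OR.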
+ return IntegerAttr::get(getType(), lhs.getValue() | rhs.getValue());
+}
+
+//===----------------------------------------------------------------------===//
// Utilities for LLVM::MetadataOp
//===----------------------------------------------------------------------===//
return verifyParameterAttribute(op, resType, resAttr);
}
+Operation *LLVMDialect::materializeConstant(OpBuilder &builder, Attribute value,
+ Type type, Location loc) {
+ // TODO: Accept more possible attributes. So far, only IntegerAttr may come
+ // up.
+ if (!isa<IntegerAttr>(value))
+ return nullptr;
+ return builder.create<LLVM::ConstantOp>(loc, type, value);
+}
+
//===----------------------------------------------------------------------===//
// Utility functions.
//===----------------------------------------------------------------------===//
--- /dev/null
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(canonicalize))" --split-input-file | FileCheck %s
+
+// CHECK-LABEL: llvm.func @zext_basic
+llvm.func @zext_basic() -> i64 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.zext %0 : i32 to i64
+ // CHECK: %[[RES:.*]] = llvm.mlir.constant(1 : i64) : i64
+ // CHECK: llvm.return %[[RES]] : i64
+ llvm.return %1 : i64
+}
+
+// CHECK-LABEL: llvm.func @zext_neg
+llvm.func @zext_neg() -> i64 {
+ %0 = llvm.mlir.constant(-1 : i32) : i32
+ %1 = llvm.zext %0 : i32 to i64
+ // CHECK: %[[RES:.*]] = llvm.mlir.constant(4294967295 : i64) : i64
+ // CHECK: llvm.return %[[RES]] : i64
+ llvm.return %1 : i64
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @shl_basic
+llvm.func @shl_basic() -> i32 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.mlir.constant(1 : i32) : i32
+ %2 = llvm.shl %0, %1 : i32
+ // CHECK: %[[RES:.*]] = llvm.mlir.constant(2 : i32) : i32
+ // CHECK: llvm.return %[[RES]] : i32
+ llvm.return %2 : i32
+}
+
+// CHECK-LABEL: llvm.func @shl_multiple
+llvm.func @shl_multiple() -> i32 {
+ %0 = llvm.mlir.constant(45 : i32) : i32
+ %1 = llvm.mlir.constant(7 : i32) : i32
+ %2 = llvm.shl %0, %1 : i32
+ // CHECK: %[[RES:.*]] = llvm.mlir.constant(5760 : i32) : i32
+ // CHECK: llvm.return %[[RES]] : i32
+ llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @or_basic
+llvm.func @or_basic() -> i32 {
+ %0 = llvm.mlir.constant(5 : i32) : i32
+ %1 = llvm.mlir.constant(9 : i32) : i32
+ %2 = llvm.or %0, %1 : i32
+ // CHECK: %[[RES:.*]] = llvm.mlir.constant(13 : i32) : i32
+ // CHECK: llvm.return %[[RES]] : i32
+ llvm.return %2 : i32
+}
llvm.mlir.global @not_comdat(0 : i32) : i32
// expected-error@+1 {{expected comdat symbol}}
llvm.mlir.global @invalid_comdat_use(0 : i32) comdat(@not_comdat) : i32
+
+// -----
+
+func.func @invalid_zext_target_size_equal(%arg: i32) {
+ // expected-error@+1 {{integer width of the output type is smaller or equal to the integer width of the input type}}
+ %0 = llvm.zext %arg : i32 to i32
+}
+
+// -----
+
+func.func @invalid_zext_target_size(%arg: i32) {
+ // expected-error@+1 {{integer width of the output type is smaller or equal to the integer width of the input type}}
+ %0 = llvm.zext %arg : i32 to i16
+}
+
+// -----
+
+func.func @invalid_zext_target_size_vector(%arg: vector<1xi32>) {
+ // expected-error@+1 {{integer width of the output type is smaller or equal to the integer width of the input type}}
+ %0 = llvm.zext %arg : vector<1xi32> to vector<1xi16>
+}
+
+// -----
+
+func.func @invalid_zext_target_shape(%arg: vector<1xi32>) {
+ // expected-error@+1 {{input and output vectors are of incompatible shape}}
+ %0 = llvm.zext %arg : vector<1xi32> to vector<2xi64>
+}
+
+// -----
+
+func.func @invalid_zext_target_type(%arg: i32) {
+ // expected-error@+1 {{input type is an integer but output type is a vector}}
+ %0 = llvm.zext %arg : i32 to vector<1xi64>
+}
+
+// -----
+
+func.func @invalid_zext_target_type_two(%arg: vector<1xi32>) {
+ // expected-error@+1 {{input type is a vector but output type is an integer}}
+ %0 = llvm.zext %arg : vector<1xi32> to i64
+}
// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(mem2reg{region-simplify=false}))" --split-input-file | FileCheck %s
// CHECK-LABEL: llvm.func @basic_memset
-llvm.func @basic_memset() -> i32 {
+// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
+llvm.func @basic_memset(%memset_value: i8) -> i32 {
%0 = llvm.mlir.constant(1 : i32) : i32
%1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
- %memset_value = llvm.mlir.constant(42 : i8) : i8
%memset_len = llvm.mlir.constant(4 : i32) : i32
- // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
// CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
// -----
-// CHECK-LABEL: llvm.func @allow_dynamic_value_memset
-// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
-llvm.func @allow_dynamic_value_memset(%memset_value: i8) -> i32 {
+// CHECK-LABEL: llvm.func @basic_memset_constant
+llvm.func @basic_memset_constant() -> i32 {
%0 = llvm.mlir.constant(1 : i32) : i32
%1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ %memset_value = llvm.mlir.constant(42 : i8) : i8
%memset_len = llvm.mlir.constant(4 : i32) : i32
- // CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
- // CHECK-NOT: "llvm.intr.memset"
- // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32
- // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
- // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
- // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
- // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
- // CHECK-NOT: "llvm.intr.memset"
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
- // CHECK: llvm.return %[[VALUE_32]] : i32
+ // CHECK: %[[RES:.*]] = llvm.mlir.constant(707406378 : i32) : i32
+ // CHECK: llvm.return %[[RES]] : i32
llvm.return %2 : i32
}
// -----
// CHECK-LABEL: llvm.func @exotic_target_memset
-llvm.func @exotic_target_memset() -> i40 {
+// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
+llvm.func @exotic_target_memset(%memset_value: i8) -> i40 {
%0 = llvm.mlir.constant(1 : i32) : i32
%1 = llvm.alloca %0 x i40 {alignment = 4 : i64} : (i32) -> !llvm.ptr
- %memset_value = llvm.mlir.constant(42 : i8) : i8
%memset_len = llvm.mlir.constant(5 : i32) : i32
- // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
// CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i40) : i40
// CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i40) : i40
// CHECK-DAG: %[[C32:.*]] = llvm.mlir.constant(32 : i40) : i40
// -----
+// CHECK-LABEL: llvm.func @exotic_target_memset_constant
+llvm.func @exotic_target_memset_constant() -> i40 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x i40 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ %memset_value = llvm.mlir.constant(42 : i8) : i8
+ %memset_len = llvm.mlir.constant(5 : i32) : i32
+ "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i40
+ // CHECK: %[[RES:.*]] = llvm.mlir.constant(181096032810 : i40) : i40
+ // CHECK: llvm.return %[[RES]] : i40
+ llvm.return %2 : i40
+}
+
+// -----
+
// CHECK-LABEL: llvm.func @no_volatile_memset
llvm.func @no_volatile_memset() -> i32 {
// CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32