}
//===----------------------------------------------------------------------===//
+// spirv.UMod
+//===----------------------------------------------------------------------===//
+
+// Input:
+// %0 = spirv.UMod %arg0, %const32 : i32
+// %1 = spirv.UMod %0, %const4 : i32
+// Output:
+// %0 = spirv.UMod %arg0, %const32 : i32
+// %1 = spirv.UMod %arg0, %const4 : i32
+
+// The transformation is only applied if one divisor is a multiple of the other.
+
+// TODO(https://github.com/llvm/llvm-project/issues/63174): Add support for vector constants
+struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(spirv::UModOp umodOp,
+ PatternRewriter &rewriter) const override {
+ auto prevUMod = umodOp.getOperand(0).getDefiningOp<spirv::UModOp>();
+ if (!prevUMod)
+ return failure();
+
+ IntegerAttr prevValue;
+ IntegerAttr currValue;
+ if (!matchPattern(prevUMod.getOperand(1), m_Constant(&prevValue)) ||
+ !matchPattern(umodOp.getOperand(1), m_Constant(&currValue)))
+ return failure();
+
+ APInt prevConstValue = prevValue.getValue();
+ APInt currConstValue = currValue.getValue();
+
+ // Ensure that one divisor is a multiple of the other. If not, fail the
+ // transformation.
+ if (prevConstValue.urem(currConstValue) != 0 &&
+ currConstValue.urem(prevConstValue) != 0)
+ return failure();
+
+ // The transformation is safe. Replace the existing UMod operation with a
+ // new UMod operation, using the original dividend and the current divisor.
+ rewriter.replaceOpWithNewOp<spirv::UModOp>(
+ umodOp, umodOp.getType(), prevUMod.getOperand(0), umodOp.getOperand(1));
+
+ return success();
+ }
+};
+
+void spirv::UModOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.insert<UModSimplification>(context);
+}
+
+//===----------------------------------------------------------------------===//
// spirv.BitcastOp
//===----------------------------------------------------------------------===//
// -----
//===----------------------------------------------------------------------===//
+// spirv.UMod
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @umod_fold
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @umod_fold(%arg0: i32) -> (i32, i32) {
+ // CHECK: %[[CONST4:.*]] = spirv.Constant 4
+ // CHECK: %[[CONST32:.*]] = spirv.Constant 32
+ %const1 = spirv.Constant 32 : i32
+ %0 = spirv.UMod %arg0, %const1 : i32
+ %const2 = spirv.Constant 4 : i32
+ %1 = spirv.UMod %0, %const2 : i32
+ // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]]
+ // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[ARG]], %[[CONST4]]
+ // CHECK: return %[[UMOD0]], %[[UMOD1]]
+ return %0, %1: i32, i32
+}
+
+// CHECK-LABEL: @umod_fail_vector_fold
+// CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>)
+func.func @umod_fail_vector_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
+ // CHECK: %[[CONST4:.*]] = spirv.Constant dense<4> : vector<4xi32>
+ // CHECK: %[[CONST32:.*]] = spirv.Constant dense<32> : vector<4xi32>
+ %const1 = spirv.Constant dense<32> : vector<4xi32>
+ %0 = spirv.UMod %arg0, %const1 : vector<4xi32>
+ // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]]
+ %const2 = spirv.Constant dense<4> : vector<4xi32>
+ %1 = spirv.UMod %0, %const2 : vector<4xi32>
+ // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST4]]
+ // CHECK: return %[[UMOD0]], %[[UMOD1]]
+ return %0, %1: vector<4xi32>, vector<4xi32>
+}
+
+// CHECK-LABEL: @umod_fold_same_divisor
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @umod_fold_same_divisor(%arg0: i32) -> (i32, i32) {
+ // CHECK: %[[CONST1:.*]] = spirv.Constant 32
+ %const1 = spirv.Constant 32 : i32
+ %0 = spirv.UMod %arg0, %const1 : i32
+ %const2 = spirv.Constant 32 : i32
+ %1 = spirv.UMod %0, %const2 : i32
+ // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST1]]
+ // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[ARG]], %[[CONST1]]
+ // CHECK: return %[[UMOD0]], %[[UMOD1]]
+ return %0, %1: i32, i32
+}
+
+// CHECK-LABEL: @umod_fail_fold
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @umod_fail_fold(%arg0: i32) -> (i32, i32) {
+ // CHECK: %[[CONST5:.*]] = spirv.Constant 5
+ // CHECK: %[[CONST32:.*]] = spirv.Constant 32
+ %const1 = spirv.Constant 32 : i32
+ %0 = spirv.UMod %arg0, %const1 : i32
+ // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]]
+ %const2 = spirv.Constant 5 : i32
+ %1 = spirv.UMod %0, %const2 : i32
+ // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST5]]
+ // CHECK: return %[[UMOD0]], %[[UMOD1]]
+ return %0, %1: i32, i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
// spirv.LogicalAnd
//===----------------------------------------------------------------------===//