[ConstraintElimination] Fix variables used for pattern matching.
authorFlorian Hahn <flo@fhahn.com>
Sun, 14 Feb 2021 18:06:09 +0000 (18:06 +0000)
committerFlorian Hahn <flo@fhahn.com>
Sun, 14 Feb 2021 18:42:37 +0000 (18:42 +0000)
Re-using the matched variable in the pattern does not work as expected.
This patch fixes that by introducing a new variable for the 2nd level
match.

llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
llvm/test/Transforms/ConstraintElimination/geps.ll

index 09b0b4a..00f1c48 100644 (file)
@@ -54,20 +54,20 @@ static SmallVector<std::pair<int64_t, Value *>, 4> decompose(Value *V) {
   }
   auto *GEP = dyn_cast<GetElementPtrInst>(V);
   if (GEP && GEP->getNumOperands() == 2 && GEP->isInBounds()) {
-    Value *Op0;
+    Value *Op0, *Op1;
     ConstantInt *CI;
 
     // If the index is zero-extended, it is guaranteed to be positive.
     if (match(GEP->getOperand(GEP->getNumOperands() - 1),
               m_ZExt(m_Value(Op0)))) {
-      if (match(Op0, m_NUWShl(m_Value(Op0), m_ConstantInt(CI))))
+      if (match(Op0, m_NUWShl(m_Value(Op1), m_ConstantInt(CI))))
         return {{0, nullptr},
                 {1, GEP->getPointerOperand()},
-                {std::pow(int64_t(2), CI->getSExtValue()), Op0}};
-      if (match(Op0, m_NSWAdd(m_Value(Op0), m_ConstantInt(CI))))
+                {std::pow(int64_t(2), CI->getSExtValue()), Op1}};
+      if (match(Op0, m_NSWAdd(m_Value(Op1), m_ConstantInt(CI))))
         return {{CI->getSExtValue(), nullptr},
                 {1, GEP->getPointerOperand()},
-                {1, Op0}};
+                {1, Op1}};
       return {{0, nullptr}, {1, GEP->getPointerOperand()}, {1, Op0}};
     }
 
index 5c891f3..9141ace 100644 (file)
@@ -516,7 +516,7 @@ if.end:                                           ; preds = %entry
 }
 
 ; Test which requires decomposing GEP %ptr, SHL().
-define void @test.ult.gep.shl(i32* readonly %src, i32* readnone %max, i32 %idx, i32 %j) {
+define void @test.ult.gep.shl(i32* readonly %src, i32* readnone %max, i32 %idx) {
 ; CHECK-LABEL: @test.ult.gep.shl(
 ; CHECK-NEXT:  check.0.min:
 ; CHECK-NEXT:    [[ADD_10:%.*]] = getelementptr inbounds i32, i32* [[SRC:%.*]], i32 10
@@ -646,4 +646,57 @@ check.max:                                      ; preds = %check.0.min
   ret void
 }
 
+; Make sure non-constant shift amounts are handled correctly.
+define i1 @test.ult.gep.shl.nonconst.zext(i16 %B, i16* readonly %src, i16* readnone %max, i16 %idx, i16 %j) {
+; CHECK-LABEL: @test.ult.gep.shl.nonconst.zext(
+; CHECK-NEXT:  check.0.min:
+; CHECK-NEXT:    [[ADD_10:%.*]] = getelementptr inbounds i16, i16* [[SRC:%.*]], i16 10
+; CHECK-NEXT:    [[C_ADD_10_MAX:%.*]] = icmp ugt i16* [[ADD_10]], [[MAX:%.*]]
+; CHECK-NEXT:    br i1 [[C_ADD_10_MAX]], label [[EXIT_1:%.*]], label [[CHECK_IDX:%.*]]
+; CHECK:       exit.1:
+; CHECK-NEXT:    ret i1 true
+; CHECK:       check.idx:
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i16 [[IDX:%.*]], 5
+; CHECK-NEXT:    br i1 [[CMP]], label [[CHECK_MAX:%.*]], label [[TRAP:%.*]]
+; CHECK:       check.max:
+; CHECK-NEXT:    [[IDX_SHL:%.*]] = shl nuw i16 [[IDX]], [[B:%.*]]
+; CHECK-NEXT:    [[EXT:%.*]] = zext i16 [[IDX_SHL]] to i64
+; CHECK-NEXT:    [[ADD_PTR_SHL:%.*]] = getelementptr inbounds i16, i16* [[SRC]], i64 [[EXT]]
+; CHECK-NEXT:    [[C_MAX:%.*]] = icmp ult i16* [[ADD_PTR_SHL]], [[MAX]]
+; CHECK-NEXT:    ret i1 [[C_MAX]]
+; CHECK:       trap:
+; CHECK-NEXT:    [[IDX_SHL_1:%.*]] = shl nuw i16 [[IDX]], [[B]]
+; CHECK-NEXT:    [[EXT_1:%.*]] = zext i16 [[IDX_SHL_1]] to i64
+; CHECK-NEXT:    [[ADD_PTR_SHL_1:%.*]] = getelementptr inbounds i16, i16* [[SRC]], i64 [[EXT_1]]
+; CHECK-NEXT:    [[C_MAX_1:%.*]] = icmp ult i16* [[ADD_PTR_SHL_1]], [[MAX]]
+; CHECK-NEXT:    ret i1 [[C_MAX_1]]
+;
+check.0.min:
+  %add.10 = getelementptr inbounds i16, i16* %src, i16 10
+  %c.add.10.max = icmp ugt i16* %add.10, %max
+  br i1 %c.add.10.max, label %exit.1, label %check.idx
+
+exit.1:
+  ret i1 true
+
+
+check.idx:                                        ; preds = %check.0.min
+  %cmp = icmp ult i16 %idx, 5
+  br i1 %cmp, label %check.max, label %trap
+
+check.max:                                        ; preds = %check.idx
+  %idx.shl = shl nuw i16 %idx, %B
+  %ext = zext i16 %idx.shl to i64
+  %add.ptr.shl = getelementptr inbounds i16, i16* %src, i64 %ext
+  %c.max = icmp ult i16* %add.ptr.shl, %max
+  ret i1 %c.max
+
+trap:                                             ; preds = %check.idx, %check.0.min
+  %idx.shl.1 = shl nuw i16 %idx, %B
+  %ext.1 = zext i16 %idx.shl.1 to i64
+  %add.ptr.shl.1 = getelementptr inbounds i16, i16* %src, i64 %ext.1
+  %c.max.1 = icmp ult i16* %add.ptr.shl.1, %max
+  ret i1 %c.max.1
+}
+
 declare void @use(i1)