LoadStoreVectorizer: Match nested adds to prove vectorization is safe
authorVolkan Keles <vkeles@apple.com>
Mon, 18 May 2020 19:11:46 +0000 (12:11 -0700)
committerVolkan Keles <vkeles@apple.com>
Mon, 18 May 2020 19:13:01 +0000 (12:13 -0700)
If both OpA and OpB is an add with NSW/NUW and with the same LHS operand,
we can guarantee that the transformation is safe if we can prove that OpA
won't overflow when IdxDiff added to the RHS of OpA.

Review: https://reviews.llvm.org/D79817

llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
llvm/test/Transforms/LoadStoreVectorizer/X86/vectorize-i8-nested-add.ll [new file with mode: 0644]

index 9915e27..c02b8f8 100644 (file)
@@ -430,20 +430,78 @@ bool Vectorizer::lookThroughComplexAddresses(Value *PtrA, Value *PtrB,
 
   // Now we need to prove that adding IdxDiff to ValA won't overflow.
   bool Safe = false;
+  auto CheckFlags = [](Instruction *I, bool Signed) {
+    BinaryOperator *BinOpI = cast<BinaryOperator>(I);
+    return (Signed && BinOpI->hasNoSignedWrap()) ||
+           (!Signed && BinOpI->hasNoUnsignedWrap());
+  };
+
   // First attempt: if OpB is an add with NSW/NUW, and OpB is IdxDiff added to
   // ValA, we're okay.
   if (OpB->getOpcode() == Instruction::Add &&
       isa<ConstantInt>(OpB->getOperand(1)) &&
-      IdxDiff.sle(cast<ConstantInt>(OpB->getOperand(1))->getSExtValue())) {
-    if (Signed)
-      Safe = cast<BinaryOperator>(OpB)->hasNoSignedWrap();
-    else
-      Safe = cast<BinaryOperator>(OpB)->hasNoUnsignedWrap();
+      IdxDiff.sle(cast<ConstantInt>(OpB->getOperand(1))->getSExtValue()) &&
+      CheckFlags(OpB, Signed))
+    Safe = true;
+
+  // Second attempt: If both OpA and OpB is an add with NSW/NUW and with
+  // the same LHS operand, we can guarantee that the transformation is safe
+  // if we can prove that OpA won't overflow when IdxDiff added to the RHS
+  // of OpA.
+  // For example:
+  //  %tmp7 = add nsw i32 %tmp2, %v0
+  //  %tmp8 = sext i32 %tmp7 to i64
+  //  ...
+  //  %tmp11 = add nsw i32 %v0, 1
+  //  %tmp12 = add nsw i32 %tmp2, %tmp11
+  //  %tmp13 = sext i32 %tmp12 to i64
+  //
+  //  Both %tmp7 and %tmp2 has the nsw flag and the first operand
+  //  is %tmp2. It's guaranteed that adding 1 to %tmp7 won't overflow
+  //  because %tmp11 adds 1 to %v0 and both %tmp11 and %tmp12 has the
+  //  nsw flag.
+  OpA = dyn_cast<Instruction>(ValA);
+  if (!Safe && OpA && OpA->getOpcode() == Instruction::Add &&
+      OpB->getOpcode() == Instruction::Add &&
+      OpA->getOperand(0) == OpB->getOperand(0) && CheckFlags(OpA, Signed) &&
+      CheckFlags(OpB, Signed)) {
+    Value *RHSA = OpA->getOperand(1);
+    Value *RHSB = OpB->getOperand(1);
+    Instruction *OpRHSA = dyn_cast<Instruction>(RHSA);
+    Instruction *OpRHSB = dyn_cast<Instruction>(RHSB);
+    // Match `x +nsw/nuw y` and `x +nsw/nuw (y +nsw/nuw IdxDiff)`.
+    if (OpRHSB && OpRHSB->getOpcode() == Instruction::Add &&
+        CheckFlags(OpRHSB, Signed) && isa<ConstantInt>(OpRHSB->getOperand(1))) {
+      int64_t CstVal = cast<ConstantInt>(OpRHSB->getOperand(1))->getSExtValue();
+      if (OpRHSB->getOperand(0) == RHSA && IdxDiff.getSExtValue() == CstVal)
+        Safe = true;
+    }
+    // Match `x +nsw/nuw (y +nsw/nuw -Idx)` and `x +nsw/nuw (y +nsw/nuw x)`.
+    if (OpRHSA && OpRHSA->getOpcode() == Instruction::Add &&
+        CheckFlags(OpRHSA, Signed) && isa<ConstantInt>(OpRHSA->getOperand(1))) {
+      int64_t CstVal = cast<ConstantInt>(OpRHSA->getOperand(1))->getSExtValue();
+      if (OpRHSA->getOperand(0) == RHSB && IdxDiff.getSExtValue() == -CstVal)
+        Safe = true;
+    }
+    // Match `x +nsw/nuw (y +nsw/nuw c)` and
+    // `x +nsw/nuw (y +nsw/nuw (c + IdxDiff))`.
+    if (OpRHSA && OpRHSB && OpRHSA->getOpcode() == Instruction::Add &&
+        OpRHSB->getOpcode() == Instruction::Add && CheckFlags(OpRHSA, Signed) &&
+        CheckFlags(OpRHSB, Signed) && isa<ConstantInt>(OpRHSA->getOperand(1)) &&
+        isa<ConstantInt>(OpRHSB->getOperand(1))) {
+      int64_t CstValA =
+          cast<ConstantInt>(OpRHSA->getOperand(1))->getSExtValue();
+      int64_t CstValB =
+          cast<ConstantInt>(OpRHSB->getOperand(1))->getSExtValue();
+      if (OpRHSA->getOperand(0) == OpRHSB->getOperand(0) &&
+          IdxDiff.getSExtValue() == (CstValB - CstValA))
+        Safe = true;
+    }
   }
 
   unsigned BitWidth = ValA->getType()->getScalarSizeInBits();
 
-  // Second attempt:
+  // Third attempt:
   // If all set bits of IdxDiff or any higher order bit other than the sign bit
   // are known to be zero in ValA, we can add Diff to it while guaranteeing no
   // overflow of any sort.
diff --git a/llvm/test/Transforms/LoadStoreVectorizer/X86/vectorize-i8-nested-add.ll b/llvm/test/Transforms/LoadStoreVectorizer/X86/vectorize-i8-nested-add.ll
new file mode 100644 (file)
index 0000000..6ee00ad
--- /dev/null
@@ -0,0 +1,165 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -o - -S -load-store-vectorizer -dce %s | FileCheck %s
+
+; Make sure LoadStoreVectorizer vectorizes the loads below.
+; In order to prove that the vectorization is safe, it tries to
+; match nested adds and find an expression that adds a constant
+; value to an existing index and the result doesn't overflow.
+
+target triple = "x86_64--"
+
+define void @ld_v4i8_add_nsw(i32 %v0, i32 %v1, i8* %src, <4 x i8>* %dst) {
+; CHECK-LABEL: @ld_v4i8_add_nsw(
+; CHECK-NEXT:  bb:
+; CHECK-NEXT:    [[TMP:%.*]] = add nsw i32 [[V0:%.*]], -1
+; CHECK-NEXT:    [[TMP1:%.*]] = add nsw i32 [[V1:%.*]], [[TMP]]
+; CHECK-NEXT:    [[TMP2:%.*]] = sext i32 [[TMP1]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i8, i8* [[SRC:%.*]], i64 [[TMP2]]
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast i8* [[TMP3]] to <4 x i8>*
+; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i8>, <4 x i8>* [[TMP0]], align 1
+; CHECK-NEXT:    [[TMP41:%.*]] = extractelement <4 x i8> [[TMP1]], i32 0
+; CHECK-NEXT:    [[TMP82:%.*]] = extractelement <4 x i8> [[TMP1]], i32 1
+; CHECK-NEXT:    [[TMP133:%.*]] = extractelement <4 x i8> [[TMP1]], i32 2
+; CHECK-NEXT:    [[TMP184:%.*]] = extractelement <4 x i8> [[TMP1]], i32 3
+; CHECK-NEXT:    [[TMP19:%.*]] = insertelement <4 x i8> undef, i8 [[TMP41]], i32 0
+; CHECK-NEXT:    [[TMP20:%.*]] = insertelement <4 x i8> [[TMP19]], i8 [[TMP82]], i32 1
+; CHECK-NEXT:    [[TMP21:%.*]] = insertelement <4 x i8> [[TMP20]], i8 [[TMP133]], i32 2
+; CHECK-NEXT:    [[TMP22:%.*]] = insertelement <4 x i8> [[TMP21]], i8 [[TMP184]], i32 3
+; CHECK-NEXT:    store <4 x i8> [[TMP22]], <4 x i8>* [[DST:%.*]]
+; CHECK-NEXT:    ret void
+;
+bb:
+  %tmp = add nsw i32 %v0, -1
+  %tmp1 = add nsw i32 %v1, %tmp
+  %tmp2 = sext i32 %tmp1 to i64
+  %tmp3 = getelementptr inbounds i8, i8* %src, i64 %tmp2
+  %tmp4 = load i8, i8* %tmp3, align 1
+  %tmp5 = add nsw i32 %v1, %v0
+  %tmp6 = sext i32 %tmp5 to i64
+  %tmp7 = getelementptr inbounds i8, i8* %src, i64 %tmp6
+  %tmp8 = load i8, i8* %tmp7, align 1
+  %tmp9 = add nsw i32 %v0, 1
+  %tmp10 = add nsw i32 %v1, %tmp9
+  %tmp11 = sext i32 %tmp10 to i64
+  %tmp12 = getelementptr inbounds i8, i8* %src, i64 %tmp11
+  %tmp13 = load i8, i8* %tmp12, align 1
+  %tmp14 = add nsw i32 %v0, 2
+  %tmp15 = add nsw i32 %v1, %tmp14
+  %tmp16 = sext i32 %tmp15 to i64
+  %tmp17 = getelementptr inbounds i8, i8* %src, i64 %tmp16
+  %tmp18 = load i8, i8* %tmp17, align 1
+  %tmp19 = insertelement <4 x i8> undef, i8 %tmp4, i32 0
+  %tmp20 = insertelement <4 x i8> %tmp19, i8 %tmp8, i32 1
+  %tmp21 = insertelement <4 x i8> %tmp20, i8 %tmp13, i32 2
+  %tmp22 = insertelement <4 x i8> %tmp21, i8 %tmp18, i32 3
+  store <4 x i8> %tmp22, <4 x i8>* %dst
+  ret void
+}
+
+define void @ld_v4i8_add_nuw(i32 %v0, i32 %v1, i8* %src, <4 x i8>* %dst) {
+; CHECK-LABEL: @ld_v4i8_add_nuw(
+; CHECK-NEXT:  bb:
+; CHECK-NEXT:    [[TMP:%.*]] = add nuw i32 [[V0:%.*]], -1
+; CHECK-NEXT:    [[TMP1:%.*]] = add nuw i32 [[V1:%.*]], [[TMP]]
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i32 [[TMP1]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i8, i8* [[SRC:%.*]], i64 [[TMP2]]
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast i8* [[TMP3]] to <4 x i8>*
+; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i8>, <4 x i8>* [[TMP0]], align 1
+; CHECK-NEXT:    [[TMP41:%.*]] = extractelement <4 x i8> [[TMP1]], i32 0
+; CHECK-NEXT:    [[TMP82:%.*]] = extractelement <4 x i8> [[TMP1]], i32 1
+; CHECK-NEXT:    [[TMP133:%.*]] = extractelement <4 x i8> [[TMP1]], i32 2
+; CHECK-NEXT:    [[TMP184:%.*]] = extractelement <4 x i8> [[TMP1]], i32 3
+; CHECK-NEXT:    [[TMP19:%.*]] = insertelement <4 x i8> undef, i8 [[TMP41]], i32 0
+; CHECK-NEXT:    [[TMP20:%.*]] = insertelement <4 x i8> [[TMP19]], i8 [[TMP82]], i32 1
+; CHECK-NEXT:    [[TMP21:%.*]] = insertelement <4 x i8> [[TMP20]], i8 [[TMP133]], i32 2
+; CHECK-NEXT:    [[TMP22:%.*]] = insertelement <4 x i8> [[TMP21]], i8 [[TMP184]], i32 3
+; CHECK-NEXT:    store <4 x i8> [[TMP22]], <4 x i8>* [[DST:%.*]]
+; CHECK-NEXT:    ret void
+;
+bb:
+  %tmp = add nuw i32 %v0, -1
+  %tmp1 = add nuw i32 %v1, %tmp
+  %tmp2 = zext i32 %tmp1 to i64
+  %tmp3 = getelementptr inbounds i8, i8* %src, i64 %tmp2
+  %tmp4 = load i8, i8* %tmp3, align 1
+  %tmp5 = add nuw i32 %v1, %v0
+  %tmp6 = zext i32 %tmp5 to i64
+  %tmp7 = getelementptr inbounds i8, i8* %src, i64 %tmp6
+  %tmp8 = load i8, i8* %tmp7, align 1
+  %tmp9 = add nuw i32 %v0, 1
+  %tmp10 = add nuw i32 %v1, %tmp9
+  %tmp11 = zext i32 %tmp10 to i64
+  %tmp12 = getelementptr inbounds i8, i8* %src, i64 %tmp11
+  %tmp13 = load i8, i8* %tmp12, align 1
+  %tmp14 = add nuw i32 %v0, 2
+  %tmp15 = add nuw i32 %v1, %tmp14
+  %tmp16 = zext i32 %tmp15 to i64
+  %tmp17 = getelementptr inbounds i8, i8* %src, i64 %tmp16
+  %tmp18 = load i8, i8* %tmp17, align 1
+  %tmp19 = insertelement <4 x i8> undef, i8 %tmp4, i32 0
+  %tmp20 = insertelement <4 x i8> %tmp19, i8 %tmp8, i32 1
+  %tmp21 = insertelement <4 x i8> %tmp20, i8 %tmp13, i32 2
+  %tmp22 = insertelement <4 x i8> %tmp21, i8 %tmp18, i32 3
+  store <4 x i8> %tmp22, <4 x i8>* %dst
+  ret void
+}
+
+; Make sure we don't vectorize the loads below because the source of
+; sext instructions doesn't have the nsw flag.
+
+define void @ld_v4i8_add_not_safe(i32 %v0, i32 %v1, i8* %src, <4 x i8>* %dst) {
+; CHECK-LABEL: @ld_v4i8_add_not_safe(
+; CHECK-NEXT:  bb:
+; CHECK-NEXT:    [[TMP:%.*]] = add nsw i32 [[V0:%.*]], -1
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[V1:%.*]], [[TMP]]
+; CHECK-NEXT:    [[TMP2:%.*]] = sext i32 [[TMP1]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i8, i8* [[SRC:%.*]], i64 [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = load i8, i8* [[TMP3]], align 1
+; CHECK-NEXT:    [[TMP5:%.*]] = add i32 [[V1]], [[V0]]
+; CHECK-NEXT:    [[TMP6:%.*]] = sext i32 [[TMP5]] to i64
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i8, i8* [[SRC]], i64 [[TMP6]]
+; CHECK-NEXT:    [[TMP8:%.*]] = load i8, i8* [[TMP7]], align 1
+; CHECK-NEXT:    [[TMP9:%.*]] = add nsw i32 [[V0]], 1
+; CHECK-NEXT:    [[TMP10:%.*]] = add i32 [[V1]], [[TMP9]]
+; CHECK-NEXT:    [[TMP11:%.*]] = sext i32 [[TMP10]] to i64
+; CHECK-NEXT:    [[TMP12:%.*]] = getelementptr inbounds i8, i8* [[SRC]], i64 [[TMP11]]
+; CHECK-NEXT:    [[TMP13:%.*]] = load i8, i8* [[TMP12]], align 1
+; CHECK-NEXT:    [[TMP14:%.*]] = add nsw i32 [[V0]], 2
+; CHECK-NEXT:    [[TMP15:%.*]] = add i32 [[V1]], [[TMP14]]
+; CHECK-NEXT:    [[TMP16:%.*]] = sext i32 [[TMP15]] to i64
+; CHECK-NEXT:    [[TMP17:%.*]] = getelementptr inbounds i8, i8* [[SRC]], i64 [[TMP16]]
+; CHECK-NEXT:    [[TMP18:%.*]] = load i8, i8* [[TMP17]], align 1
+; CHECK-NEXT:    [[TMP19:%.*]] = insertelement <4 x i8> undef, i8 [[TMP4]], i32 0
+; CHECK-NEXT:    [[TMP20:%.*]] = insertelement <4 x i8> [[TMP19]], i8 [[TMP8]], i32 1
+; CHECK-NEXT:    [[TMP21:%.*]] = insertelement <4 x i8> [[TMP20]], i8 [[TMP13]], i32 2
+; CHECK-NEXT:    [[TMP22:%.*]] = insertelement <4 x i8> [[TMP21]], i8 [[TMP18]], i32 3
+; CHECK-NEXT:    store <4 x i8> [[TMP22]], <4 x i8>* [[DST:%.*]]
+; CHECK-NEXT:    ret void
+;
+bb:
+  %tmp = add nsw i32 %v0, -1
+  %tmp1 = add i32 %v1, %tmp
+  %tmp2 = sext i32 %tmp1 to i64
+  %tmp3 = getelementptr inbounds i8, i8* %src, i64 %tmp2
+  %tmp4 = load i8, i8* %tmp3, align 1
+  %tmp5 = add i32 %v1, %v0
+  %tmp6 = sext i32 %tmp5 to i64
+  %tmp7 = getelementptr inbounds i8, i8* %src, i64 %tmp6
+  %tmp8 = load i8, i8* %tmp7, align 1
+  %tmp9 = add nsw i32 %v0, 1
+  %tmp10 = add i32 %v1, %tmp9
+  %tmp11 = sext i32 %tmp10 to i64
+  %tmp12 = getelementptr inbounds i8, i8* %src, i64 %tmp11
+  %tmp13 = load i8, i8* %tmp12, align 1
+  %tmp14 = add nsw i32 %v0, 2
+  %tmp15 = add i32 %v1, %tmp14
+  %tmp16 = sext i32 %tmp15 to i64
+  %tmp17 = getelementptr inbounds i8, i8* %src, i64 %tmp16
+  %tmp18 = load i8, i8* %tmp17, align 1
+  %tmp19 = insertelement <4 x i8> undef, i8 %tmp4, i32 0
+  %tmp20 = insertelement <4 x i8> %tmp19, i8 %tmp8, i32 1
+  %tmp21 = insertelement <4 x i8> %tmp20, i8 %tmp13, i32 2
+  %tmp22 = insertelement <4 x i8> %tmp21, i8 %tmp18, i32 3
+  store <4 x i8> %tmp22, <4 x i8>* %dst
+  ret void
+}