[LSV] Improve chain splitting in some corner cases.
authorArtem Belevich <tra@google.com>
Mon, 10 Apr 2023 22:55:39 +0000 (15:55 -0700)
committerArtem Belevich <tra@google.com>
Mon, 17 Apr 2023 20:42:00 +0000 (13:42 -0700)
Currently we happen to split a chain of 12xi8 accesses into 6xi8 + 6xi8, which
produces rather suboptimal code.

This change attempts to split-off non-multiples of 4bytes at the end and if that
does not work, splits on the smaller power-of-2 boundary.

Differential Revision: https://reviews.llvm.org/D147976

llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
llvm/test/Transforms/LoadStoreVectorizer/NVPTX/vectorize_i8.ll [new file with mode: 0644]

index 6d8fbd1..5633d4c 100644 (file)
@@ -665,14 +665,16 @@ Vectorizer::splitOddVectorElts(ArrayRef<Instruction *> Chain,
                                unsigned ElementSizeBits) {
   unsigned ElementSizeBytes = ElementSizeBits / 8;
   unsigned SizeBytes = ElementSizeBytes * Chain.size();
-  unsigned NumLeft = (SizeBytes - (SizeBytes % 4)) / ElementSizeBytes;
-  if (NumLeft == Chain.size()) {
-    if ((NumLeft & 1) == 0)
-      NumLeft /= 2; // Split even in half
-    else
-      --NumLeft;    // Split off last element
-  } else if (NumLeft == 0)
+  unsigned LeftBytes = (SizeBytes - (SizeBytes % 4));
+  // If we're already a multiple of 4 bytes or the whole chain is shorter than 4
+  // bytes, then try splitting down on power-of-2 boundary.
+  if (LeftBytes == SizeBytes || LeftBytes == 0)
+    LeftBytes = PowerOf2Ceil(SizeBytes) / 2;
+  unsigned NumLeft = LeftBytes / ElementSizeBytes;
+  if (NumLeft == 0)
     NumLeft = 1;
+  LLVM_DEBUG(dbgs() << "LSV: Splitting the chain into " << NumLeft << "+"
+                    << Chain.size() - NumLeft << " elements\n");
   return std::make_pair(Chain.slice(0, NumLeft), Chain.slice(NumLeft));
 }
 
diff --git a/llvm/test/Transforms/LoadStoreVectorizer/NVPTX/vectorize_i8.ll b/llvm/test/Transforms/LoadStoreVectorizer/NVPTX/vectorize_i8.ll
new file mode 100644 (file)
index 0000000..387d678
--- /dev/null
@@ -0,0 +1,309 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -mtriple=nvptx64-nvidia-cuda -passes=load-store-vectorizer -S -o - %s | FileCheck %s
+
+; Vectorize and emit valid code (Issue #54896).
+
+define void @int8x3a2(ptr nocapture align 2 %ptr) {
+  %ptr0 = getelementptr i8, ptr %ptr, i64 0
+  %ptr1 = getelementptr i8, ptr %ptr, i64 1
+  %ptr2 = getelementptr i8, ptr %ptr, i64 2
+
+  %l0 = load i8, ptr %ptr0, align 2
+  %l1 = load i8, ptr %ptr1, align 1
+  %l2 = load i8, ptr %ptr2, align 2
+
+  store i8 %l2, ptr %ptr0, align 2
+  store i8 %l1, ptr %ptr1, align 1
+  store i8 %l0, ptr %ptr2, align 2
+
+  ret void
+
+; CHECK-LABEL: @int8x3a2
+; CHECK-DAG: load <2 x i8>
+; CHECK-DAG: load i8
+; CHECK-DAG: store <2 x i8>
+; CHECK-DAG: store i8
+}
+
+define void @int8x3a4(ptr nocapture align 4 %ptr) {
+  %ptr0 = getelementptr i8, ptr %ptr, i64 0
+  %ptr1 = getelementptr i8, ptr %ptr, i64 1
+  %ptr2 = getelementptr i8, ptr %ptr, i64 2
+
+  %l0 = load i8, ptr %ptr0, align 4
+  %l1 = load i8, ptr %ptr1, align 1
+  %l2 = load i8, ptr %ptr2, align 2
+
+  store i8 %l2, ptr %ptr0, align 2
+  store i8 %l1, ptr %ptr1, align 1
+  store i8 %l0, ptr %ptr2, align 4
+
+  ret void
+
+; CHECK-LABEL: @int8x3a4
+; CHECK: load <3 x i8>
+; CHECK: store <3 x i8>
+}
+
+define void @int8x12a4(ptr nocapture align 4 %ptr) {
+  %ptr0 = getelementptr i8, ptr %ptr, i64 0
+  %ptr1 = getelementptr i8, ptr %ptr, i64 1
+  %ptr2 = getelementptr i8, ptr %ptr, i64 2
+  %ptr3 = getelementptr i8, ptr %ptr, i64 3
+  %ptr4 = getelementptr i8, ptr %ptr, i64 4
+  %ptr5 = getelementptr i8, ptr %ptr, i64 5
+  %ptr6 = getelementptr i8, ptr %ptr, i64 6
+  %ptr7 = getelementptr i8, ptr %ptr, i64 7
+  %ptr8 = getelementptr i8, ptr %ptr, i64 8
+  %ptr9 = getelementptr i8, ptr %ptr, i64 9
+  %ptra = getelementptr i8, ptr %ptr, i64 10
+  %ptrb = getelementptr i8, ptr %ptr, i64 11
+
+  %l0 = load i8, ptr %ptr0, align 4
+  %l1 = load i8, ptr %ptr1, align 1
+  %l2 = load i8, ptr %ptr2, align 2
+  %l3 = load i8, ptr %ptr3, align 1
+  %l4 = load i8, ptr %ptr4, align 4
+  %l5 = load i8, ptr %ptr5, align 1
+  %l6 = load i8, ptr %ptr6, align 2
+  %l7 = load i8, ptr %ptr7, align 1
+  %l8 = load i8, ptr %ptr8, align 4
+  %l9 = load i8, ptr %ptr9, align 1
+  %la = load i8, ptr %ptra, align 2
+  %lb = load i8, ptr %ptrb, align 1
+
+  store i8 %lb, ptr %ptr0, align 4
+  store i8 %la, ptr %ptr1, align 1
+  store i8 %l9, ptr %ptr2, align 2
+  store i8 %l8, ptr %ptr3, align 1
+  store i8 %l7, ptr %ptr4, align 4
+  store i8 %l6, ptr %ptr5, align 1
+  store i8 %l5, ptr %ptr6, align 2
+  store i8 %l4, ptr %ptr7, align 1
+  store i8 %l3, ptr %ptr8, align 4
+  store i8 %l2, ptr %ptr9, align 1
+  store i8 %l1, ptr %ptra, align 2
+  store i8 %l0, ptr %ptrb, align 1
+
+  ret void
+
+; CHECK-LABEL: @int8x12a4
+; CHECK: load <4 x i8>
+; CHECK: load <4 x i8>
+; CHECK: load <4 x i8>
+; CHECK: store <4 x i8>
+; CHECK: store <4 x i8>
+; CHECK: store <4 x i8>
+}
+
+
+define void @int8x16a4(ptr nocapture align 4 %ptr) {
+  %ptr0 = getelementptr i8, ptr %ptr, i64 0
+  %ptr1 = getelementptr i8, ptr %ptr, i64 1
+  %ptr2 = getelementptr i8, ptr %ptr, i64 2
+  %ptr3 = getelementptr i8, ptr %ptr, i64 3
+  %ptr4 = getelementptr i8, ptr %ptr, i64 4
+  %ptr5 = getelementptr i8, ptr %ptr, i64 5
+  %ptr6 = getelementptr i8, ptr %ptr, i64 6
+  %ptr7 = getelementptr i8, ptr %ptr, i64 7
+  %ptr8 = getelementptr i8, ptr %ptr, i64 8
+  %ptr9 = getelementptr i8, ptr %ptr, i64 9
+  %ptra = getelementptr i8, ptr %ptr, i64 10
+  %ptrb = getelementptr i8, ptr %ptr, i64 11
+  %ptrc = getelementptr i8, ptr %ptr, i64 12
+  %ptrd = getelementptr i8, ptr %ptr, i64 13
+  %ptre = getelementptr i8, ptr %ptr, i64 14
+  %ptrf = getelementptr i8, ptr %ptr, i64 15
+
+  %l0 = load i8, ptr %ptr0, align 4
+  %l1 = load i8, ptr %ptr1, align 1
+  %l2 = load i8, ptr %ptr2, align 2
+  %l3 = load i8, ptr %ptr3, align 1
+  %l4 = load i8, ptr %ptr4, align 4
+  %l5 = load i8, ptr %ptr5, align 1
+  %l6 = load i8, ptr %ptr6, align 2
+  %l7 = load i8, ptr %ptr7, align 1
+  %l8 = load i8, ptr %ptr8, align 4
+  %l9 = load i8, ptr %ptr9, align 1
+  %la = load i8, ptr %ptra, align 2
+  %lb = load i8, ptr %ptrb, align 1
+  %lc = load i8, ptr %ptrc, align 4
+  %ld = load i8, ptr %ptrd, align 1
+  %le = load i8, ptr %ptre, align 2
+  %lf = load i8, ptr %ptrf, align 1
+
+  store i8 %lf, ptr %ptrc, align 4
+  store i8 %le, ptr %ptrd, align 1
+  store i8 %ld, ptr %ptre, align 2
+  store i8 %lc, ptr %ptrf, align 1
+  store i8 %lb, ptr %ptr0, align 4
+  store i8 %la, ptr %ptr1, align 1
+  store i8 %l9, ptr %ptr2, align 2
+  store i8 %l8, ptr %ptr3, align 1
+  store i8 %l7, ptr %ptr4, align 4
+  store i8 %l6, ptr %ptr5, align 1
+  store i8 %l5, ptr %ptr6, align 2
+  store i8 %l4, ptr %ptr7, align 1
+  store i8 %l3, ptr %ptr8, align 4
+  store i8 %l2, ptr %ptr9, align 1
+  store i8 %l1, ptr %ptra, align 2
+  store i8 %l0, ptr %ptrb, align 1
+
+  ret void
+
+; CHECK-LABEL: @int8x16a4
+; CHECK: load <4 x i8>
+; CHECK: load <4 x i8>
+; CHECK: load <4 x i8>
+; CHECK: load <4 x i8>
+; CHECK: store <4 x i8>
+; CHECK: store <4 x i8>
+; CHECK: store <4 x i8>
+; CHECK: store <4 x i8>
+}
+
+define void @int8x8a8(ptr nocapture align 8 %ptr) {
+  %ptr0 = getelementptr i8, ptr %ptr, i64 0
+  %ptr1 = getelementptr i8, ptr %ptr, i64 1
+  %ptr2 = getelementptr i8, ptr %ptr, i64 2
+  %ptr3 = getelementptr i8, ptr %ptr, i64 3
+  %ptr4 = getelementptr i8, ptr %ptr, i64 4
+  %ptr5 = getelementptr i8, ptr %ptr, i64 5
+  %ptr6 = getelementptr i8, ptr %ptr, i64 6
+  %ptr7 = getelementptr i8, ptr %ptr, i64 7
+
+  %l0 = load i8, ptr %ptr0, align 8
+  %l1 = load i8, ptr %ptr1, align 1
+  %l2 = load i8, ptr %ptr2, align 2
+  %l3 = load i8, ptr %ptr3, align 1
+  %l4 = load i8, ptr %ptr4, align 4
+  %l5 = load i8, ptr %ptr5, align 1
+  %l6 = load i8, ptr %ptr6, align 2
+  %l7 = load i8, ptr %ptr7, align 1
+
+  store i8 %l7, ptr %ptr0, align 8
+  store i8 %l6, ptr %ptr1, align 1
+  store i8 %l5, ptr %ptr2, align 2
+  store i8 %l4, ptr %ptr3, align 1
+  store i8 %l3, ptr %ptr4, align 4
+  store i8 %l2, ptr %ptr5, align 1
+  store i8 %l1, ptr %ptr6, align 2
+  store i8 %l0, ptr %ptr7, align 1
+
+  ret void
+
+; CHECK-LABEL: @int8x8a8
+; CHECK: load <8 x i8>
+; CHECK: store <8 x i8>
+}
+
+define void @int8x12a8(ptr nocapture align 8 %ptr) {
+  %ptr0 = getelementptr i8, ptr %ptr, i64 0
+  %ptr1 = getelementptr i8, ptr %ptr, i64 1
+  %ptr2 = getelementptr i8, ptr %ptr, i64 2
+  %ptr3 = getelementptr i8, ptr %ptr, i64 3
+  %ptr4 = getelementptr i8, ptr %ptr, i64 4
+  %ptr5 = getelementptr i8, ptr %ptr, i64 5
+  %ptr6 = getelementptr i8, ptr %ptr, i64 6
+  %ptr7 = getelementptr i8, ptr %ptr, i64 7
+  %ptr8 = getelementptr i8, ptr %ptr, i64 8
+  %ptr9 = getelementptr i8, ptr %ptr, i64 9
+  %ptra = getelementptr i8, ptr %ptr, i64 10
+  %ptrb = getelementptr i8, ptr %ptr, i64 11
+
+  %l0 = load i8, ptr %ptr0, align 8
+  %l1 = load i8, ptr %ptr1, align 1
+  %l2 = load i8, ptr %ptr2, align 2
+  %l3 = load i8, ptr %ptr3, align 1
+  %l4 = load i8, ptr %ptr4, align 4
+  %l5 = load i8, ptr %ptr5, align 1
+  %l6 = load i8, ptr %ptr6, align 2
+  %l7 = load i8, ptr %ptr7, align 1
+  %l8 = load i8, ptr %ptr8, align 8
+  %l9 = load i8, ptr %ptr9, align 1
+  %la = load i8, ptr %ptra, align 2
+  %lb = load i8, ptr %ptrb, align 1
+
+  store i8 %lb, ptr %ptr0, align 8
+  store i8 %la, ptr %ptr1, align 1
+  store i8 %l9, ptr %ptr2, align 2
+  store i8 %l8, ptr %ptr3, align 1
+  store i8 %l7, ptr %ptr4, align 4
+  store i8 %l6, ptr %ptr5, align 1
+  store i8 %l5, ptr %ptr6, align 2
+  store i8 %l4, ptr %ptr7, align 1
+  store i8 %l3, ptr %ptr8, align 8
+  store i8 %l2, ptr %ptr9, align 1
+  store i8 %l1, ptr %ptra, align 2
+  store i8 %l0, ptr %ptrb, align 1
+
+  ret void
+
+; CHECK-LABEL: @int8x12a8
+; CHECK-DAG: load <8 x i8>
+; CHECK-DAG: load <4 x i8>
+; CHECK-DAG: store <8 x i8>
+; CHECK-DAG: store <4 x i8>
+}
+
+
+define void @int8x16a8(ptr nocapture align 8 %ptr) {
+  %ptr0 = getelementptr i8, ptr %ptr, i64 0
+  %ptr1 = getelementptr i8, ptr %ptr, i64 1
+  %ptr2 = getelementptr i8, ptr %ptr, i64 2
+  %ptr3 = getelementptr i8, ptr %ptr, i64 3
+  %ptr4 = getelementptr i8, ptr %ptr, i64 4
+  %ptr5 = getelementptr i8, ptr %ptr, i64 5
+  %ptr6 = getelementptr i8, ptr %ptr, i64 6
+  %ptr7 = getelementptr i8, ptr %ptr, i64 7
+  %ptr8 = getelementptr i8, ptr %ptr, i64 8
+  %ptr9 = getelementptr i8, ptr %ptr, i64 9
+  %ptra = getelementptr i8, ptr %ptr, i64 10
+  %ptrb = getelementptr i8, ptr %ptr, i64 11
+  %ptrc = getelementptr i8, ptr %ptr, i64 12
+  %ptrd = getelementptr i8, ptr %ptr, i64 13
+  %ptre = getelementptr i8, ptr %ptr, i64 14
+  %ptrf = getelementptr i8, ptr %ptr, i64 15
+
+  %l0 = load i8, ptr %ptr0, align 8
+  %l1 = load i8, ptr %ptr1, align 1
+  %l2 = load i8, ptr %ptr2, align 2
+  %l3 = load i8, ptr %ptr3, align 1
+  %l4 = load i8, ptr %ptr4, align 4
+  %l5 = load i8, ptr %ptr5, align 1
+  %l6 = load i8, ptr %ptr6, align 2
+  %l7 = load i8, ptr %ptr7, align 1
+  %l8 = load i8, ptr %ptr8, align 8
+  %l9 = load i8, ptr %ptr9, align 1
+  %la = load i8, ptr %ptra, align 2
+  %lb = load i8, ptr %ptrb, align 1
+  %lc = load i8, ptr %ptrc, align 4
+  %ld = load i8, ptr %ptrd, align 1
+  %le = load i8, ptr %ptre, align 2
+  %lf = load i8, ptr %ptrf, align 1
+
+  store i8 %lf, ptr %ptr0, align 8
+  store i8 %le, ptr %ptr1, align 1
+  store i8 %ld, ptr %ptr2, align 2
+  store i8 %lc, ptr %ptr3, align 1
+  store i8 %lb, ptr %ptr4, align 4
+  store i8 %la, ptr %ptr5, align 1
+  store i8 %l9, ptr %ptr6, align 2
+  store i8 %l8, ptr %ptr7, align 1
+  store i8 %l7, ptr %ptr8, align 8
+  store i8 %l6, ptr %ptr9, align 1
+  store i8 %l5, ptr %ptra, align 2
+  store i8 %l4, ptr %ptrb, align 1
+  store i8 %l3, ptr %ptrc, align 4
+  store i8 %l2, ptr %ptrd, align 1
+  store i8 %l1, ptr %ptre, align 2
+  store i8 %l0, ptr %ptrf, align 1
+
+  ret void
+
+; CHECK-LABEL: @int8x16a8
+; CHECK: load <8 x i8>
+; CHECK: load <8 x i8>
+; CHECK: store <8 x i8>
+; CHECK: store <8 x i8>
+}