[AArch64] Add support to loop vectorization for non temporal loads
authorZain Jaffal <z_jaffal@apple.com>
Mon, 3 Oct 2022 16:06:47 +0000 (17:06 +0100)
committerFlorian Hahn <flo@fhahn.com>
Mon, 3 Oct 2022 16:06:47 +0000 (17:06 +0100)
Currently, AArch64 doesn't support vectorization for non temporal loads because `isLegalNTLoad` is not implemented for the target.
This patch applies similar functionality as `D73158` but for non temporal loads

Reviewed By: fhahn

Differential Revision: https://reviews.llvm.org/D131964

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
llvm/test/Transforms/LoopVectorize/AArch64/nontemporal-load-store.ll

index eef5295..55961ba 100644 (file)
@@ -308,24 +308,34 @@ public:
     return false;
   }
 
-  bool isLegalNTStore(Type *DataType, Align Alignment) {
+  bool isLegalNTStoreLoad(Type *DataType, Align Alignment) {
     // NOTE: The logic below is mostly geared towards LV, which calls it with
     //       vectors with 2 elements. We might want to improve that, if other
     //       users show up.
-    // Nontemporal vector stores can be directly lowered to STNP, if the vector
-    // can be halved so that each half fits into a register. That's the case if
-    // the element type fits into a register and the number of elements is a
-    // power of 2 > 1.
-    if (auto *DataTypeVTy = dyn_cast<VectorType>(DataType)) {
-      unsigned NumElements =
-          cast<FixedVectorType>(DataTypeVTy)->getNumElements();
-      unsigned EltSize = DataTypeVTy->getElementType()->getScalarSizeInBits();
+    // Nontemporal vector loads/stores can be directly lowered to LDNP/STNP, if
+    // the vector can be halved so that each half fits into a register. That's
+    // the case if the element type fits into a register and the number of
+    // elements is a power of 2 > 1.
+    if (auto *DataTypeTy = dyn_cast<FixedVectorType>(DataType)) {
+      unsigned NumElements = DataTypeTy->getNumElements();
+      unsigned EltSize = DataTypeTy->getElementType()->getScalarSizeInBits();
       return NumElements > 1 && isPowerOf2_64(NumElements) && EltSize >= 8 &&
              EltSize <= 128 && isPowerOf2_64(EltSize);
     }
     return BaseT::isLegalNTStore(DataType, Alignment);
   }
 
+  bool isLegalNTStore(Type *DataType, Align Alignment) {
+    return isLegalNTStoreLoad(DataType, Alignment);
+  }
+
+  bool isLegalNTLoad(Type *DataType, Align Alignment) {
+    // Only supports little-endian targets.
+    if (ST->isLittleEndian())
+      return isLegalNTStoreLoad(DataType, Alignment);
+    return BaseT::isLegalNTLoad(DataType, Alignment);
+  }
+
   bool enableOrderedReductions() const { return true; }
 
   InstructionCost getInterleavedMemoryOpCost(
index 9df651a..4af55b7 100644 (file)
@@ -258,8 +258,7 @@ for.cond.cleanup:                                 ; preds = %for.body
 
 define i4 @test_i4_load(i4* %ddst) {
 ; CHECK-LABEL: define i4 @test_i4_load
-; CHECK-LABEL: vector.body:
-; CHECK:         [[LOAD:%.*]] = load i4, i4* {{.*}}, align 1, !nontemporal !0
+; CHECK-NOT: vector.body:
 ; CHECk: ret i4 %{{.*}}
 ;
 entry:
@@ -281,7 +280,8 @@ for.cond.cleanup:                                 ; preds = %for.body
 
 define i8 @test_load_i8(i8* %ddst) {
 ; CHECK-LABEL: @test_load_i8(
-; CHECK-NOT:   vector.body:
+; CHECK:   vector.body:
+; CHECK: load <4 x i8>, <4 x i8>* {{.*}}, align 1, !nontemporal !0
 ; CHECk: ret i8 %{{.*}}
 ;
 entry:
@@ -303,7 +303,8 @@ for.cond.cleanup:                                 ; preds = %for.body
 
 define half @test_half_load(half* %ddst) {
 ; CHECK-LABEL: @test_half_load
-; CHECK-NOT:   vector.body:
+; CHECK-LABEL:   vector.body:
+; CHECK: load <4 x half>, <4 x half>* {{.*}}, align 2, !nontemporal !0
 ; CHECk: ret half %{{.*}}
 ;
 entry:
@@ -325,7 +326,8 @@ for.cond.cleanup:                                 ; preds = %for.body
 
 define i16 @test_i16_load(i16* %ddst) {
 ; CHECK-LABEL: @test_i16_load
-; CHECK-NOT:   vector.body:
+; CHECK-LABEL:   vector.body:
+; CHECK: load <4 x i16>, <4 x i16>* {{.*}}, align 2, !nontemporal !0
 ; CHECk: ret i16 %{{.*}}
 ;
 entry:
@@ -347,7 +349,8 @@ for.cond.cleanup:                                 ; preds = %for.body
 
 define i32 @test_i32_load(i32* %ddst) {
 ; CHECK-LABEL: @test_i32_load
-; CHECK-NOT:   vector.body:
+; CHECK-LABEL:   vector.body:
+; CHECK: load <4 x i32>, <4 x i32>* {{.*}}, align 4, !nontemporal !0
 ; CHECk: ret i32 %{{.*}}
 ;
 entry:
@@ -413,7 +416,8 @@ for.cond.cleanup:                                 ; preds = %for.body
 
 define i64 @test_i64_load(i64* %ddst) {
 ; CHECK-LABEL: @test_i64_load
-; CHECK-NOT:   vector.body:
+; CHECK-LABEL:   vector.body:
+; CHECK: load <4 x i64>, <4 x i64>* {{.*}}, align 4, !nontemporal !0
 ; CHECk: ret i64 %{{.*}}
 ;
 entry:
@@ -435,7 +439,8 @@ for.cond.cleanup:                                 ; preds = %for.body
 
 define double @test_double_load(double* %ddst) {
 ; CHECK-LABEL: @test_double_load
-; CHECK-NOT:   vector.body:
+; CHECK-LABEL:   vector.body:
+; CHECK: load <4 x double>, <4 x double>* {{.*}}, align 4, !nontemporal !0
 ; CHECk: ret double %{{.*}}
 ;
 entry:
@@ -457,7 +462,8 @@ for.cond.cleanup:                                 ; preds = %for.body
 
 define i128 @test_i128_load(i128* %ddst) {
 ; CHECK-LABEL: @test_i128_load
-; CHECK-NOT:   vector.body:
+; CHECK-LABEL:   vector.body:
+; CHECK: load <4 x i128>, <4 x i128>* {{.*}}, align 4, !nontemporal !0
 ; CHECk: ret i128 %{{.*}}
 ;
 entry: