From 5358968e131cc486f236b0a5593d2201a609cdc1 Mon Sep 17 00:00:00 2001 From: Philip Reames Date: Sat, 24 Sep 2022 17:06:48 -0700 Subject: [PATCH] [RISCV] Pattern match scalable strided load/store Very straight forward extension of the existing pattern matching pass to handle scalable types as well as fixed length types. The only extra bit beyond removing a bailout is recognizing stepvector. Differential Revision: https://reviews.llvm.org/D134502 --- .../Target/RISCV/RISCVGatherScatterLowering.cpp | 14 +++++++--- llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll | 30 +++++++++------------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp index 2410cc1..eada2bf 100644 --- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp @@ -21,9 +21,11 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/IntrinsicsRISCV.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; +using namespace PatternMatch; #define DEBUG_TYPE "riscv-gather-scatter-lowering" @@ -139,6 +141,12 @@ static std::pair matchStridedStart(Value *Start, if (StartC) return matchStridedConstant(StartC); + // Base case, start is a stepvector + if (match(Start, m_Intrinsic())) { + auto *Ty = Start->getType()->getScalarType(); + return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1)); + } + // Not a constant, maybe it's a strided constant with a splat added to it. auto *BO = dyn_cast(Start); if (!BO || BO->getOpcode() != Instruction::Add) @@ -482,11 +490,9 @@ bool RISCVGatherScatterLowering::runOnFunction(Function &F) { for (BasicBlock &BB : F) { for (Instruction &I : BB) { IntrinsicInst *II = dyn_cast(&I); - if (II && II->getIntrinsicID() == Intrinsic::masked_gather && - isa(II->getType())) { + if (II && II->getIntrinsicID() == Intrinsic::masked_gather) { Gathers.push_back(II); - } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter && - isa(II->getArgOperand(0)->getType())) { + } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) { Scatters.push_back(II); } } diff --git a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll index dd663be..be151e4 100644 --- a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll +++ b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll @@ -10,21 +10,18 @@ define @gather(ptr %a, i32 %len) { ; CHECK-NEXT: vector.ph: ; CHECK-NEXT: [[WIDE_TRIP_COUNT:%.*]] = zext i32 [[LEN:%.*]] to i64 ; CHECK-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.vscale.i64() -; CHECK-NEXT: [[TMP1:%.*]] = tail call @llvm.experimental.stepvector.nxv1i64() -; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement poison, i64 [[TMP0]], i64 0 -; CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector [[DOTSPLATINSERT]], poison, zeroinitializer ; CHECK-NEXT: br label [[VECTOR_BODY:%.*]] ; CHECK: vector.body: ; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ] -; CHECK-NEXT: [[VEC_IND:%.*]] = phi [ [[TMP1]], [[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[VEC_IND_SCALAR:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[VEC_IND_NEXT_SCALAR:%.*]], [[VECTOR_BODY]] ] ; CHECK-NEXT: [[ACCUM:%.*]] = phi [ zeroinitializer, [[VECTOR_PH]] ], [ [[ACCUM_NEXT:%.*]], [[VECTOR_BODY]] ] -; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds [[STRUCT_FOO:%.*]], ptr [[A:%.*]], [[VEC_IND]], i32 3 -; CHECK-NEXT: [[GATHER:%.*]] = call @llvm.masked.gather.nxv1i64.nxv1p0( [[TMP2]], i32 8, shufflevector ( insertelement ( poison, i1 true, i32 0), poison, zeroinitializer), undef) +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr [[STRUCT_FOO:%.*]], ptr [[A:%.*]], i64 [[VEC_IND_SCALAR]], i32 3 +; CHECK-NEXT: [[GATHER:%.*]] = call @llvm.riscv.masked.strided.load.nxv1i64.p0.i64( undef, ptr [[TMP1]], i64 16, shufflevector ( insertelement ( poison, i1 true, i32 0), poison, zeroinitializer)) ; CHECK-NEXT: [[ACCUM_NEXT]] = add [[ACCUM]], [[GATHER]] ; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP0]] -; CHECK-NEXT: [[VEC_IND_NEXT]] = add [[VEC_IND]], [[DOTSPLAT]] -; CHECK-NEXT: [[TMP3:%.*]] = icmp ne i64 [[INDEX_NEXT]], [[WIDE_TRIP_COUNT]] -; CHECK-NEXT: br i1 [[TMP3]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_BODY]] +; CHECK-NEXT: [[VEC_IND_NEXT_SCALAR]] = add i64 [[VEC_IND_SCALAR]], [[TMP0]] +; CHECK-NEXT: [[TMP2:%.*]] = icmp ne i64 [[INDEX_NEXT]], [[WIDE_TRIP_COUNT]] +; CHECK-NEXT: br i1 [[TMP2]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_BODY]] ; CHECK: for.cond.cleanup: ; CHECK-NEXT: ret [[ACCUM_NEXT]] ; @@ -57,19 +54,16 @@ define void @scatter(ptr %a, i32 %len) { ; CHECK-NEXT: vector.ph: ; CHECK-NEXT: [[WIDE_TRIP_COUNT:%.*]] = zext i32 [[LEN:%.*]] to i64 ; CHECK-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.vscale.i64() -; CHECK-NEXT: [[TMP1:%.*]] = tail call @llvm.experimental.stepvector.nxv1i64() -; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement poison, i64 [[TMP0]], i64 0 -; CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector [[DOTSPLATINSERT]], poison, zeroinitializer ; CHECK-NEXT: br label [[VECTOR_BODY:%.*]] ; CHECK: vector.body: ; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ] -; CHECK-NEXT: [[VEC_IND:%.*]] = phi [ [[TMP1]], [[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], [[VECTOR_BODY]] ] -; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds [[STRUCT_FOO:%.*]], ptr [[A:%.*]], [[VEC_IND]], i32 3 -; CHECK-NEXT: tail call void @llvm.masked.scatter.nxv1i64.nxv1p0( zeroinitializer, [[TMP2]], i32 8, shufflevector ( insertelement ( poison, i1 true, i32 0), poison, zeroinitializer)) +; CHECK-NEXT: [[VEC_IND_SCALAR:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[VEC_IND_NEXT_SCALAR:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr [[STRUCT_FOO:%.*]], ptr [[A:%.*]], i64 [[VEC_IND_SCALAR]], i32 3 +; CHECK-NEXT: call void @llvm.riscv.masked.strided.store.nxv1i64.p0.i64( zeroinitializer, ptr [[TMP1]], i64 16, shufflevector ( insertelement ( poison, i1 true, i32 0), poison, zeroinitializer)) ; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP0]] -; CHECK-NEXT: [[VEC_IND_NEXT]] = add [[VEC_IND]], [[DOTSPLAT]] -; CHECK-NEXT: [[TMP3:%.*]] = icmp ne i64 [[INDEX_NEXT]], [[WIDE_TRIP_COUNT]] -; CHECK-NEXT: br i1 [[TMP3]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_BODY]] +; CHECK-NEXT: [[VEC_IND_NEXT_SCALAR]] = add i64 [[VEC_IND_SCALAR]], [[TMP0]] +; CHECK-NEXT: [[TMP2:%.*]] = icmp ne i64 [[INDEX_NEXT]], [[WIDE_TRIP_COUNT]] +; CHECK-NEXT: br i1 [[TMP2]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_BODY]] ; CHECK: for.cond.cleanup: ; CHECK-NEXT: ret void ; -- 2.7.4