// * Single constant active lane -> store
// * Adjacent vector addresses -> masked.store
// * Narrow store width by halfs excluding zero/undef lanes
-// * Vector splat address w/known mask -> scalar store
// * Vector incrementing address -> vector masked store
Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
if (ConstMask->isNullValue())
return eraseInstFromFunction(II);
+ // Vector splat address -> scalar store
+ if (auto *SplatPtr = getSplatValue(II.getArgOperand(1))) {
+ // scatter(splat(value), splat(ptr), non-zero-mask) -> store value, ptr
+ if (auto *SplatValue = getSplatValue(II.getArgOperand(0))) {
+ Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
+ StoreInst *S =
+ new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false, Alignment);
+ S->copyMetadata(II);
+ return S;
+ }
+ // scatter(vector, splat(ptr), splat(true)) -> store extract(vector,
+ // lastlane), ptr
+ if (ConstMask->isAllOnesValue()) {
+ Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
+ VectorType *WideLoadTy = cast<VectorType>(II.getArgOperand(1)->getType());
+ ElementCount VF = WideLoadTy->getElementCount();
+ Constant *EC =
+ ConstantInt::get(Builder.getInt32Ty(), VF.getKnownMinValue());
+ Value *RunTimeVF = VF.isScalable() ? Builder.CreateVScale(EC) : EC;
+ Value *LastLane = Builder.CreateSub(RunTimeVF, Builder.getInt32(1));
+ Value *Extract =
+ Builder.CreateExtractElement(II.getArgOperand(0), LastLane);
+ StoreInst *S =
+ new StoreInst(Extract, SplatPtr, /*IsVolatile=*/false, Alignment);
+ S->copyMetadata(II);
+ return S;
+ }
+ }
if (isa<ScalableVectorType>(ConstMask->getType()))
return nullptr;
call void @llvm.masked.scatter.v2f64.v2p0f64(<2 x double> %valvec2, <2 x double*> %ptrs, i32 8, <2 x i1> <i1 true, i1 false>)
ret void
}


; Test scatters that can be simplified to scalar stores.

;; Value splat (mask is not used)
define void @scatter_v4i16_uniform_vals_uniform_ptrs_no_all_active_mask(i16* %dst, i16 %val) {
; CHECK-LABEL: @scatter_v4i16_uniform_vals_uniform_ptrs_no_all_active_mask(
; CHECK-NEXT: entry:
; CHECK-NEXT: store i16 [[VAL:%.*]], i16* [[DST:%.*]], align 2
; CHECK-NEXT: ret void
;
entry:
  %broadcast.splatinsert = insertelement <4 x i16*> poison, i16* %dst, i32 0
  %broadcast.splat = shufflevector <4 x i16*> %broadcast.splatinsert, <4 x i16*> poison, <4 x i32> zeroinitializer
  %broadcast.value = insertelement <4 x i16> poison, i16 %val, i32 0
  %broadcast.splatvalue = shufflevector <4 x i16> %broadcast.value, <4 x i16> poison, <4 x i32> zeroinitializer
  call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %broadcast.splatvalue, <4 x i16*> %broadcast.splat, i32 2, <4 x i1> <i1 0, i1 0, i1 1, i1 1>)
  ret void
}

define void @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask(i16* %dst, i16 %val) {
; CHECK-LABEL: @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask(
; CHECK-NEXT: entry:
; CHECK-NEXT: store i16 [[VAL:%.*]], i16* [[DST:%.*]], align 2
; CHECK-NEXT: ret void
;
entry:
  %broadcast.splatinsert = insertelement <vscale x 4 x i16*> poison, i16* %dst, i32 0
  %broadcast.splat = shufflevector <vscale x 4 x i16*> %broadcast.splatinsert, <vscale x 4 x i16*> poison, <vscale x 4 x i32> zeroinitializer
  %broadcast.value = insertelement <vscale x 4 x i16> poison, i16 %val, i32 0
  %broadcast.splatvalue = shufflevector <vscale x 4 x i16> %broadcast.value, <vscale x 4 x i16> poison, <vscale x 4 x i32> zeroinitializer
  call void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16> %broadcast.splatvalue, <vscale x 4 x i16*> %broadcast.splat, i32 2, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> zeroinitializer , i1 true, i32 0), <vscale x 4 x i1> zeroinitializer, <vscale x 4 x i32> zeroinitializer))
  ret void
}

;; The pointer is splat and mask is all active, but value is not a splat
define void @scatter_v4i16_no_uniform_vals_uniform_ptrs_all_active_mask(i16* %dst, <4 x i16>* %src) {
; CHECK-LABEL: @scatter_v4i16_no_uniform_vals_uniform_ptrs_all_active_mask(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i16>, <4 x i16>* [[SRC:%.*]], align 2
; CHECK-NEXT: [[TMP0:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i64 3
; CHECK-NEXT: store i16 [[TMP0]], i16* [[DST:%.*]], align 2
; CHECK-NEXT: ret void
;
entry:
  %broadcast.splatinsert = insertelement <4 x i16*> poison, i16* %dst, i32 0
  %broadcast.splat = shufflevector <4 x i16*> %broadcast.splatinsert, <4 x i16*> poison, <4 x i32> zeroinitializer
  %wide.load = load <4 x i16>, <4 x i16>* %src, align 2
  call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %wide.load, <4 x i16*> %broadcast.splat, i32 2, <4 x i1> <i1 1, i1 1, i1 1, i1 1>)
  ret void
}

define void @scatter_nxv4i16_no_uniform_vals_uniform_ptrs_all_active_mask(i16* %dst, <vscale x 4 x i16>* %src) {
; CHECK-LABEL: @scatter_nxv4i16_no_uniform_vals_uniform_ptrs_all_active_mask(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <vscale x 4 x i16>, <vscale x 4 x i16>* [[SRC:%.*]], align 2
; CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.vscale.i32()
; CHECK-NEXT: [[TMP1:%.*]] = shl i32 [[TMP0]], 2
; CHECK-NEXT: [[TMP2:%.*]] = add i32 [[TMP1]], -1
; CHECK-NEXT: [[TMP3:%.*]] = extractelement <vscale x 4 x i16> [[WIDE_LOAD]], i32 [[TMP2]]
; CHECK-NEXT: store i16 [[TMP3]], i16* [[DST:%.*]], align 2
; CHECK-NEXT: ret void
;
entry:
  %broadcast.splatinsert = insertelement <vscale x 4 x i16*> poison, i16* %dst, i32 0
  %broadcast.splat = shufflevector <vscale x 4 x i16*> %broadcast.splatinsert, <vscale x 4 x i16*> poison, <vscale x 4 x i32> zeroinitializer
  %wide.load = load <vscale x 4 x i16>, <vscale x 4 x i16>* %src, align 2
  call void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16> %wide.load, <vscale x 4 x i16*> %broadcast.splat, i32 2, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i32 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
  ret void
}

; Negative scatter tests

;; Pointer is splat, but mask is not all active and value is not a splat
define void @negative_scatter_v4i16_no_uniform_vals_uniform_ptrs_all_inactive_mask(i16* %dst, <4 x i16>* %src) {
; CHECK-LABEL: @negative_scatter_v4i16_no_uniform_vals_uniform_ptrs_all_inactive_mask(
; CHECK-NEXT: [[INSERT_ELT:%.*]] = insertelement <4 x i16*> poison, i16* [[DST:%.*]], i64 0
; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = shufflevector <4 x i16*> [[INSERT_ELT]], <4 x i16*> poison, <4 x i32> <i32 undef, i32 undef, i32 0, i32 0>
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i16>, <4 x i16>* [[SRC:%.*]], align 2
; CHECK-NEXT: call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> [[WIDE_LOAD]], <4 x i16*> [[BROADCAST_SPLAT]], i32 2, <4 x i1> <i1 false, i1 false, i1 true, i1 true>)
; CHECK-NEXT: ret void
;
  %insert.elt = insertelement <4 x i16*> poison, i16* %dst, i32 0
  %broadcast.splat = shufflevector <4 x i16*> %insert.elt, <4 x i16*> poison, <4 x i32> zeroinitializer
  %wide.load = load <4 x i16>, <4 x i16>* %src, align 2
  call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %wide.load, <4 x i16*> %broadcast.splat, i32 2, <4 x i1> <i1 0, i1 0, i1 1, i1 1>)
  ret void
}

;; The pointer is NOT a splat
define void @negative_scatter_v4i16_no_uniform_vals_no_uniform_ptrs_all_active_mask(<4 x i16*> %inPtr, <4 x i16>* %src) {
; CHECK-LABEL: @negative_scatter_v4i16_no_uniform_vals_no_uniform_ptrs_all_active_mask(
; CHECK-NEXT: [[BROADCAST:%.*]] = shufflevector <4 x i16*> [[INPTR:%.*]], <4 x i16*> poison, <4 x i32> zeroinitializer
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i16>, <4 x i16>* [[SRC:%.*]], align 2
; CHECK-NEXT: call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> [[WIDE_LOAD]], <4 x i16*> [[BROADCAST]], i32 2, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
; CHECK-NEXT: ret void
;
  %broadcast= shufflevector <4 x i16*> %inPtr, <4 x i16*> poison, <4 x i32> zeroinitializer
  %wide.load = load <4 x i16>, <4 x i16>* %src, align 2
  call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %wide.load, <4 x i16*> %broadcast, i32 2, <4 x i1> <i1 1, i1 1, i1 1, i1 1> )
  ret void
}


; Function Attrs:
declare void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16>, <4 x i16*>, i32 immarg, <4 x i1>)
declare void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16>, <vscale x 4 x i16*>, i32 immarg, <vscale x 4 x i1>)