From: Florian Hahn Date: Wed, 7 Jun 2023 19:45:07 +0000 (+0100) Subject: [Matrix] Convert binop operand of dot product to a row vector. X-Git-Tag: upstream/17.0.6~5796 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c10a7772bd7603cb6b252ef89e2009fa0c60dc94;p=platform%2Fupstream%2Fllvm.git [Matrix] Convert binop operand of dot product to a row vector. The dot product lowering will use the left operand as row vector. If the operand is a binary op, convert it to operate on a row vector instead of a column vector. Depends on D148428. Reviewed By: thegameg Differential Revision: https://reviews.llvm.org/D148429 --- diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 8508c90cc939..6fe02a950415 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -1358,7 +1358,9 @@ public: if (!IsIntVec && !FMF.allowReassoc()) return; - auto CanBeFlattened = [](Value *Op) { + auto CanBeFlattened = [this](Value *Op) { + if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) + return true; return match( Op, m_OneUse(m_CombineOr( m_Load(m_Value()), @@ -1386,6 +1388,16 @@ public: return EmbedCost; } + if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) { + InstructionCost OriginalCost = + TTI.getArithmeticInstrCost(cast(Op)->getOpcode(), + EltTy) * + N; + InstructionCost NewCost = TTI.getArithmeticInstrCost( + cast(Op)->getOpcode(), VecTy); + return NewCost - OriginalCost; + } + if (match(Op, m_Intrinsic())) { // The transpose can be skipped for the dot product lowering, roughly // estimate the savings as the cost of embedding the columns in a @@ -1433,8 +1445,12 @@ public: if (!CanBeFlattened(Op)) return Op; - FusedInsts.insert(cast(Op)); + if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) { + ShapeMap[Op] = ShapeMap[Op].t(); + return Op; + } + FusedInsts.insert(cast(Op)); // If vector uses the builtin load, lower to a LoadInst Value *Arg; if (match(Op, m_Intrinsic( diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int.ll index e998fb2483ca..feb38ee3a618 100644 --- a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int.ll +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int.ll @@ -51,41 +51,13 @@ entry: define <1 x i32> @add_feeding_dotproduct_i32_v8_1(<8 x i32> %a, <8 x i32> %b, <8 x i32> %c) { ; CHECK-LABEL: @add_feeding_dotproduct_i32_v8_1( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x i32> [[A:%.*]], <8 x i32> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT4:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT5:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT6:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT7:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT8:%.*]] = shufflevector <8 x i32> [[B:%.*]], <8 x i32> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[SPLIT9:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT10:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT11:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT12:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT13:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT14:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT15:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[TMP0:%.*]] = add <1 x i32> [[SPLIT]], [[SPLIT8]] -; CHECK-NEXT: [[TMP1:%.*]] = add <1 x i32> [[SPLIT1]], [[SPLIT9]] -; CHECK-NEXT: [[TMP2:%.*]] = add <1 x i32> [[SPLIT2]], [[SPLIT10]] -; CHECK-NEXT: [[TMP3:%.*]] = add <1 x i32> [[SPLIT3]], [[SPLIT11]] -; CHECK-NEXT: [[TMP4:%.*]] = add <1 x i32> [[SPLIT4]], [[SPLIT12]] -; CHECK-NEXT: [[TMP5:%.*]] = add <1 x i32> [[SPLIT5]], [[SPLIT13]] -; CHECK-NEXT: [[TMP6:%.*]] = add <1 x i32> [[SPLIT6]], [[SPLIT14]] -; CHECK-NEXT: [[TMP7:%.*]] = add <1 x i32> [[SPLIT7]], [[SPLIT15]] -; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <1 x i32> [[TMP0]], <1 x i32> [[TMP1]], <2 x i32> -; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <1 x i32> [[TMP2]], <1 x i32> [[TMP3]], <2 x i32> -; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <1 x i32> [[TMP4]], <1 x i32> [[TMP5]], <2 x i32> -; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <1 x i32> [[TMP6]], <1 x i32> [[TMP7]], <2 x i32> -; CHECK-NEXT: [[TMP12:%.*]] = shufflevector <2 x i32> [[TMP8]], <2 x i32> [[TMP9]], <4 x i32> -; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <2 x i32> [[TMP10]], <2 x i32> [[TMP11]], <4 x i32> -; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <4 x i32> [[TMP12]], <4 x i32> [[TMP13]], <8 x i32> -; CHECK-NEXT: [[TMP15:%.*]] = mul <8 x i32> [[TMP14]], [[C:%.*]] -; CHECK-NEXT: [[TMP16:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP15]]) -; CHECK-NEXT: [[TMP17:%.*]] = insertelement <1 x i32> poison, i32 [[TMP16]], i64 0 -; CHECK-NEXT: ret <1 x i32> [[TMP17]] +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x i32> [[A:%.*]], <8 x i32> poison, <8 x i32> +; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <8 x i32> [[B:%.*]], <8 x i32> poison, <8 x i32> +; CHECK-NEXT: [[TMP0:%.*]] = add <8 x i32> [[SPLIT]], [[SPLIT1]] +; CHECK-NEXT: [[TMP1:%.*]] = mul <8 x i32> [[TMP0]], [[C:%.*]] +; CHECK-NEXT: [[TMP2:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP1]]) +; CHECK-NEXT: [[TMP3:%.*]] = insertelement <1 x i32> poison, i32 [[TMP2]], i64 0 +; CHECK-NEXT: ret <1 x i32> [[TMP3]] ; entry: %add = add <8 x i32> %a, %b @@ -113,41 +85,13 @@ entry: define <1 x i32> @sub_feeding_dotproduct_i32_v8_1(<8 x i32> %a, <8 x i32> %b, <8 x i32> %c) { ; CHECK-LABEL: @sub_feeding_dotproduct_i32_v8_1( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x i32> [[A:%.*]], <8 x i32> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT4:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT5:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT6:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT7:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT8:%.*]] = shufflevector <8 x i32> [[B:%.*]], <8 x i32> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[SPLIT9:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT10:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT11:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT12:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT13:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT14:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT15:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[TMP0:%.*]] = sub <1 x i32> [[SPLIT]], [[SPLIT8]] -; CHECK-NEXT: [[TMP1:%.*]] = sub <1 x i32> [[SPLIT1]], [[SPLIT9]] -; CHECK-NEXT: [[TMP2:%.*]] = sub <1 x i32> [[SPLIT2]], [[SPLIT10]] -; CHECK-NEXT: [[TMP3:%.*]] = sub <1 x i32> [[SPLIT3]], [[SPLIT11]] -; CHECK-NEXT: [[TMP4:%.*]] = sub <1 x i32> [[SPLIT4]], [[SPLIT12]] -; CHECK-NEXT: [[TMP5:%.*]] = sub <1 x i32> [[SPLIT5]], [[SPLIT13]] -; CHECK-NEXT: [[TMP6:%.*]] = sub <1 x i32> [[SPLIT6]], [[SPLIT14]] -; CHECK-NEXT: [[TMP7:%.*]] = sub <1 x i32> [[SPLIT7]], [[SPLIT15]] -; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <1 x i32> [[TMP0]], <1 x i32> [[TMP1]], <2 x i32> -; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <1 x i32> [[TMP2]], <1 x i32> [[TMP3]], <2 x i32> -; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <1 x i32> [[TMP4]], <1 x i32> [[TMP5]], <2 x i32> -; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <1 x i32> [[TMP6]], <1 x i32> [[TMP7]], <2 x i32> -; CHECK-NEXT: [[TMP12:%.*]] = shufflevector <2 x i32> [[TMP8]], <2 x i32> [[TMP9]], <4 x i32> -; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <2 x i32> [[TMP10]], <2 x i32> [[TMP11]], <4 x i32> -; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <4 x i32> [[TMP12]], <4 x i32> [[TMP13]], <8 x i32> -; CHECK-NEXT: [[TMP15:%.*]] = mul <8 x i32> [[TMP14]], [[C:%.*]] -; CHECK-NEXT: [[TMP16:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP15]]) -; CHECK-NEXT: [[TMP17:%.*]] = insertelement <1 x i32> poison, i32 [[TMP16]], i64 0 -; CHECK-NEXT: ret <1 x i32> [[TMP17]] +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x i32> [[A:%.*]], <8 x i32> poison, <8 x i32> +; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <8 x i32> [[B:%.*]], <8 x i32> poison, <8 x i32> +; CHECK-NEXT: [[TMP0:%.*]] = sub <8 x i32> [[SPLIT]], [[SPLIT1]] +; CHECK-NEXT: [[TMP1:%.*]] = mul <8 x i32> [[TMP0]], [[C:%.*]] +; CHECK-NEXT: [[TMP2:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP1]]) +; CHECK-NEXT: [[TMP3:%.*]] = insertelement <1 x i32> poison, i32 [[TMP2]], i64 0 +; CHECK-NEXT: ret <1 x i32> [[TMP3]] ; entry: %sub = sub <8 x i32> %a, %b @@ -199,33 +143,20 @@ define <1 x i32> @add_chain_feeding_dotproduct_i32_v8_1(<8 x i32> %a, <8 x i32> ; CHECK-NEXT: [[TMP5:%.*]] = add <1 x i32> [[SPLIT5]], [[SPLIT13]] ; CHECK-NEXT: [[TMP6:%.*]] = add <1 x i32> [[SPLIT6]], [[SPLIT14]] ; CHECK-NEXT: [[TMP7:%.*]] = add <1 x i32> [[SPLIT7]], [[SPLIT15]] -; CHECK-NEXT: [[SPLIT16:%.*]] = shufflevector <8 x i32> [[C:%.*]], <8 x i32> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[SPLIT17:%.*]] = shufflevector <8 x i32> [[C]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT18:%.*]] = shufflevector <8 x i32> [[C]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT19:%.*]] = shufflevector <8 x i32> [[C]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT20:%.*]] = shufflevector <8 x i32> [[C]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT21:%.*]] = shufflevector <8 x i32> [[C]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT22:%.*]] = shufflevector <8 x i32> [[C]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[SPLIT23:%.*]] = shufflevector <8 x i32> [[C]], <8 x i32> poison, <1 x i32> -; CHECK-NEXT: [[TMP8:%.*]] = add <1 x i32> [[TMP0]], [[SPLIT16]] -; CHECK-NEXT: [[TMP9:%.*]] = add <1 x i32> [[TMP1]], [[SPLIT17]] -; CHECK-NEXT: [[TMP10:%.*]] = add <1 x i32> [[TMP2]], [[SPLIT18]] -; CHECK-NEXT: [[TMP11:%.*]] = add <1 x i32> [[TMP3]], [[SPLIT19]] -; CHECK-NEXT: [[TMP12:%.*]] = add <1 x i32> [[TMP4]], [[SPLIT20]] -; CHECK-NEXT: [[TMP13:%.*]] = add <1 x i32> [[TMP5]], [[SPLIT21]] -; CHECK-NEXT: [[TMP14:%.*]] = add <1 x i32> [[TMP6]], [[SPLIT22]] -; CHECK-NEXT: [[TMP15:%.*]] = add <1 x i32> [[TMP7]], [[SPLIT23]] -; CHECK-NEXT: [[TMP16:%.*]] = shufflevector <1 x i32> [[TMP8]], <1 x i32> [[TMP9]], <2 x i32> -; CHECK-NEXT: [[TMP17:%.*]] = shufflevector <1 x i32> [[TMP10]], <1 x i32> [[TMP11]], <2 x i32> -; CHECK-NEXT: [[TMP18:%.*]] = shufflevector <1 x i32> [[TMP12]], <1 x i32> [[TMP13]], <2 x i32> -; CHECK-NEXT: [[TMP19:%.*]] = shufflevector <1 x i32> [[TMP14]], <1 x i32> [[TMP15]], <2 x i32> -; CHECK-NEXT: [[TMP20:%.*]] = shufflevector <2 x i32> [[TMP16]], <2 x i32> [[TMP17]], <4 x i32> -; CHECK-NEXT: [[TMP21:%.*]] = shufflevector <2 x i32> [[TMP18]], <2 x i32> [[TMP19]], <4 x i32> -; CHECK-NEXT: [[TMP22:%.*]] = shufflevector <4 x i32> [[TMP20]], <4 x i32> [[TMP21]], <8 x i32> -; CHECK-NEXT: [[TMP23:%.*]] = mul <8 x i32> [[TMP22]], [[D:%.*]] -; CHECK-NEXT: [[TMP24:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP23]]) -; CHECK-NEXT: [[TMP25:%.*]] = insertelement <1 x i32> poison, i32 [[TMP24]], i64 0 -; CHECK-NEXT: ret <1 x i32> [[TMP25]] +; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <1 x i32> [[TMP0]], <1 x i32> [[TMP1]], <2 x i32> +; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <1 x i32> [[TMP2]], <1 x i32> [[TMP3]], <2 x i32> +; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <1 x i32> [[TMP4]], <1 x i32> [[TMP5]], <2 x i32> +; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <1 x i32> [[TMP6]], <1 x i32> [[TMP7]], <2 x i32> +; CHECK-NEXT: [[TMP12:%.*]] = shufflevector <2 x i32> [[TMP8]], <2 x i32> [[TMP9]], <4 x i32> +; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <2 x i32> [[TMP10]], <2 x i32> [[TMP11]], <4 x i32> +; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <4 x i32> [[TMP12]], <4 x i32> [[TMP13]], <8 x i32> +; CHECK-NEXT: [[SPLIT16:%.*]] = shufflevector <8 x i32> [[TMP14]], <8 x i32> poison, <8 x i32> +; CHECK-NEXT: [[SPLIT17:%.*]] = shufflevector <8 x i32> [[C:%.*]], <8 x i32> poison, <8 x i32> +; CHECK-NEXT: [[TMP15:%.*]] = add <8 x i32> [[SPLIT16]], [[SPLIT17]] +; CHECK-NEXT: [[TMP16:%.*]] = mul <8 x i32> [[TMP15]], [[D:%.*]] +; CHECK-NEXT: [[TMP17:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP16]]) +; CHECK-NEXT: [[TMP18:%.*]] = insertelement <1 x i32> poison, i32 [[TMP17]], i64 0 +; CHECK-NEXT: ret <1 x i32> [[TMP18]] ; entry: %add.1 = add <8 x i32> %a, %b