[SVE][AArch64TTI] Fix invalid mla combine that miscomputes the value of inactive...
authorPaul Walker <paul.walker@arm.com>
Sat, 17 Jun 2023 16:51:49 +0000 (17:51 +0100)
committerPaul Walker <paul.walker@arm.com>
Sun, 18 Jun 2023 12:07:03 +0000 (13:07 +0100)
Consider: add(pg, a, mul_u(pg, b, c))

Although the multiply's inactive lanes are undefined, they don't
contribute to the final result.  The overall result of the inactive
lanes come from "a" and thus the above is another form of mla
rather than mla_u.

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-muladdsub.ll

index 7b587da..8a00325 100644 (file)
@@ -1305,11 +1305,11 @@ instCombineSVEVectorFAdd(InstCombiner &IC, IntrinsicInst &II) {
                                             Intrinsic::aarch64_sve_fmad>(IC, II,
                                                                          false))
     return FMAD;
-  if (auto FMLA_U =
+  if (auto FMLA =
           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
-                                            Intrinsic::aarch64_sve_fmla_u>(
-              IC, II, true))
-    return FMLA_U;
+                                            Intrinsic::aarch64_sve_fmla>(IC, II,
+                                                                         true))
+    return FMLA;
   return instCombineSVEVectorBinOp(IC, II);
 }
 
@@ -1345,11 +1345,11 @@ instCombineSVEVectorFSub(InstCombiner &IC, IntrinsicInst &II) {
                                             Intrinsic::aarch64_sve_fnmsb>(
               IC, II, false))
     return FMSB;
-  if (auto FMLS_U =
+  if (auto FMLS =
           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
-                                            Intrinsic::aarch64_sve_fmls_u>(
-              IC, II, true))
-    return FMLS_U;
+                                            Intrinsic::aarch64_sve_fmls>(IC, II,
+                                                                         true))
+    return FMLS;
   return instCombineSVEVectorBinOp(IC, II);
 }
 
index a1c6a37..f9239dd 100644 (file)
@@ -14,11 +14,10 @@ define <vscale x 8 x half> @combine_fmuladd_1(<vscale x 8 x i1> %p, <vscale x 8
   ret <vscale x 8 x half> %2
 }
 
-; TODO: Test highlights an invalid combine!
 ; fadd(a, fmul_u(b, c)) -> fmla(a, b, c)
 define <vscale x 8 x half> @combine_fmuladd_2(<vscale x 8 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) #0 {
 ; CHECK-LABEL: @combine_fmuladd_2(
-; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 8 x half> @llvm.aarch64.sve.fmla.u.nxv8f16(<vscale x 8 x i1> [[P:%.*]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 8 x half> @llvm.aarch64.sve.fmla.nxv8f16(<vscale x 8 x i1> [[P:%.*]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
 ; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP1]]
 ;
   %1 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.u.nxv8f16(<vscale x 8 x i1> %p, <vscale x 8 x half> %b, <vscale x 8 x half> %c)
@@ -109,11 +108,9 @@ define <vscale x 8 x half> @combine_fmulsub_1(<vscale x 8 x i1> %p, <vscale x 8
   ret <vscale x 8 x half> %2
 }
 
-; TODO: Test highlights an invalid combine!
-; fsub(a, fmul_u(b, c)) -> fmls(a, b, c)
 define <vscale x 8 x half> @combine_fmulsub_2(<vscale x 8 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) #0 {
 ; CHECK-LABEL: @combine_fmulsub_2(
-; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 8 x half> @llvm.aarch64.sve.fmls.u.nxv8f16(<vscale x 8 x i1> [[P:%.*]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 8 x half> @llvm.aarch64.sve.fmls.nxv8f16(<vscale x 8 x i1> [[P:%.*]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]], <vscale x 8 x half> [[C:%.*]])
 ; CHECK-NEXT:    ret <vscale x 8 x half> [[TMP1]]
 ;
   %1 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.u.nxv8f16(<vscale x 8 x i1> %p, <vscale x 8 x half> %b, <vscale x 8 x half> %c)