[X86] Split AVX512 getCastInstrCost into tables that require useAVX512Regs() and...
authorCraig Topper <craig.topper@intel.com>
Tue, 14 Apr 2020 02:29:33 +0000 (19:29 -0700)
committerCraig Topper <craig.topper@intel.com>
Tue, 14 Apr 2020 04:09:42 +0000 (21:09 -0700)
Use useAVX512Regs() to skip lookups instead of using type legalization
action.

llvm/lib/Target/X86/X86TargetTransformInfo.cpp

index fc7c317..da776be 100644 (file)
@@ -1319,18 +1319,10 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
     { ISD::ZERO_EXTEND, MVT::v32i16, MVT::v32i8, 1 },
 
     // Mask sign extend has an instruction.
-    { ISD::SIGN_EXTEND, MVT::v8i16,  MVT::v8i1,  1 },
-    { ISD::SIGN_EXTEND, MVT::v16i8,  MVT::v16i1, 1 },
-    { ISD::SIGN_EXTEND, MVT::v16i16, MVT::v16i1, 1 },
-    { ISD::SIGN_EXTEND, MVT::v32i8,  MVT::v32i1, 1 },
     { ISD::SIGN_EXTEND, MVT::v32i16, MVT::v32i1, 1 },
     { ISD::SIGN_EXTEND, MVT::v64i8,  MVT::v64i1, 1 },
 
     // Mask zero extend is a load + broadcast.
-    { ISD::ZERO_EXTEND, MVT::v8i16,  MVT::v8i1,  2 },
-    { ISD::ZERO_EXTEND, MVT::v16i8,  MVT::v16i1, 2 },
-    { ISD::ZERO_EXTEND, MVT::v16i16, MVT::v16i1, 2 },
-    { ISD::ZERO_EXTEND, MVT::v32i8,  MVT::v32i1, 2 },
     { ISD::ZERO_EXTEND, MVT::v32i16, MVT::v32i1, 2 },
     { ISD::ZERO_EXTEND, MVT::v64i8,  MVT::v64i1, 2 },
 
@@ -1338,32 +1330,16 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
   };
 
   static const TypeConversionCostTblEntry AVX512DQConversionTbl[] = {
-    { ISD::SINT_TO_FP,  MVT::v2f32,  MVT::v2i64,  1 },
-    { ISD::SINT_TO_FP,  MVT::v2f64,  MVT::v2i64,  1 },
-    { ISD::SINT_TO_FP,  MVT::v4f32,  MVT::v4i64,  1 },
-    { ISD::SINT_TO_FP,  MVT::v4f64,  MVT::v4i64,  1 },
     { ISD::SINT_TO_FP,  MVT::v8f32,  MVT::v8i64,  1 },
     { ISD::SINT_TO_FP,  MVT::v8f64,  MVT::v8i64,  1 },
 
-    { ISD::UINT_TO_FP,  MVT::v2f32,  MVT::v2i64,  1 },
-    { ISD::UINT_TO_FP,  MVT::v2f64,  MVT::v2i64,  1 },
-    { ISD::UINT_TO_FP,  MVT::v4f32,  MVT::v4i64,  1 },
-    { ISD::UINT_TO_FP,  MVT::v4f64,  MVT::v4i64,  1 },
     { ISD::UINT_TO_FP,  MVT::v8f32,  MVT::v8i64,  1 },
     { ISD::UINT_TO_FP,  MVT::v8f64,  MVT::v8i64,  1 },
 
-    { ISD::FP_TO_SINT,  MVT::v2i64,  MVT::v2f32,  1 },
-    { ISD::FP_TO_SINT,  MVT::v4i64,  MVT::v4f32,  1 },
     { ISD::FP_TO_SINT,  MVT::v8i64,  MVT::v8f32,  1 },
-    { ISD::FP_TO_SINT,  MVT::v2i64,  MVT::v2f64,  1 },
-    { ISD::FP_TO_SINT,  MVT::v4i64,  MVT::v4f64,  1 },
     { ISD::FP_TO_SINT,  MVT::v8i64,  MVT::v8f64,  1 },
 
-    { ISD::FP_TO_UINT,  MVT::v2i64,  MVT::v2f32,  1 },
-    { ISD::FP_TO_UINT,  MVT::v4i64,  MVT::v4f32,  1 },
     { ISD::FP_TO_UINT,  MVT::v8i64,  MVT::v8f32,  1 },
-    { ISD::FP_TO_UINT,  MVT::v2i64,  MVT::v2f64,  1 },
-    { ISD::FP_TO_UINT,  MVT::v4i64,  MVT::v4f64,  1 },
     { ISD::FP_TO_UINT,  MVT::v8i64,  MVT::v8f64,  1 },
   };
 
@@ -1406,28 +1382,74 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
 
     { ISD::UINT_TO_FP,  MVT::v8f64,  MVT::v8i1,   4 },
     { ISD::UINT_TO_FP,  MVT::v16f32, MVT::v16i1,  3 },
+    { ISD::UINT_TO_FP,  MVT::v8f64,  MVT::v8i8,   2 },
+    { ISD::UINT_TO_FP,  MVT::v16f32, MVT::v16i8,  2 },
+    { ISD::UINT_TO_FP,  MVT::v8f64,  MVT::v8i16,  2 },
+    { ISD::UINT_TO_FP,  MVT::v16f32, MVT::v16i16, 2 },
+    { ISD::UINT_TO_FP,  MVT::v8f64,  MVT::v8i32,  1 },
+    { ISD::UINT_TO_FP,  MVT::v16f32, MVT::v16i32, 1 },
+    { ISD::UINT_TO_FP,  MVT::v8f32,  MVT::v8i64, 26 },
+    { ISD::UINT_TO_FP,  MVT::v8f64,  MVT::v8i64,  5 },
+
+    { ISD::FP_TO_UINT,  MVT::v8i32,  MVT::v8f64,  1 },
+    { ISD::FP_TO_UINT,  MVT::v8i16,  MVT::v8f64,  2 },
+    { ISD::FP_TO_UINT,  MVT::v8i8,   MVT::v8f64,  2 },
+    { ISD::FP_TO_UINT,  MVT::v16i32, MVT::v16f32, 1 },
+    { ISD::FP_TO_UINT,  MVT::v16i16, MVT::v16f32, 2 },
+    { ISD::FP_TO_UINT,  MVT::v16i8,  MVT::v16f32, 2 },
+  };
+
+  static const TypeConversionCostTblEntry AVX512BWVLConversionTbl[] {
+    // Mask sign extend has an instruction.
+    { ISD::SIGN_EXTEND, MVT::v8i16,  MVT::v8i1,  1 },
+    { ISD::SIGN_EXTEND, MVT::v16i8,  MVT::v16i1, 1 },
+    { ISD::SIGN_EXTEND, MVT::v16i16, MVT::v16i1, 1 },
+    { ISD::SIGN_EXTEND, MVT::v32i8,  MVT::v32i1, 1 },
+
+    // Mask zero extend is a load + broadcast.
+    { ISD::ZERO_EXTEND, MVT::v8i16,  MVT::v8i1,  2 },
+    { ISD::ZERO_EXTEND, MVT::v16i8,  MVT::v16i1, 2 },
+    { ISD::ZERO_EXTEND, MVT::v16i16, MVT::v16i1, 2 },
+    { ISD::ZERO_EXTEND, MVT::v32i8,  MVT::v32i1, 2 },
+  };
+
+  static const TypeConversionCostTblEntry AVX512DQVLConversionTbl[] = {
+    { ISD::SINT_TO_FP,  MVT::v2f32,  MVT::v2i64,  1 },
+    { ISD::SINT_TO_FP,  MVT::v2f64,  MVT::v2i64,  1 },
+    { ISD::SINT_TO_FP,  MVT::v4f32,  MVT::v4i64,  1 },
+    { ISD::SINT_TO_FP,  MVT::v4f64,  MVT::v4i64,  1 },
+
+    { ISD::UINT_TO_FP,  MVT::v2f32,  MVT::v2i64,  1 },
+    { ISD::UINT_TO_FP,  MVT::v2f64,  MVT::v2i64,  1 },
+    { ISD::UINT_TO_FP,  MVT::v4f32,  MVT::v4i64,  1 },
+    { ISD::UINT_TO_FP,  MVT::v4f64,  MVT::v4i64,  1 },
+
+    { ISD::FP_TO_SINT,  MVT::v2i64,  MVT::v2f32,  1 },
+    { ISD::FP_TO_SINT,  MVT::v4i64,  MVT::v4f32,  1 },
+    { ISD::FP_TO_SINT,  MVT::v2i64,  MVT::v2f64,  1 },
+    { ISD::FP_TO_SINT,  MVT::v4i64,  MVT::v4f64,  1 },
+
+    { ISD::FP_TO_UINT,  MVT::v2i64,  MVT::v2f32,  1 },
+    { ISD::FP_TO_UINT,  MVT::v4i64,  MVT::v4f32,  1 },
+    { ISD::FP_TO_UINT,  MVT::v2i64,  MVT::v2f64,  1 },
+    { ISD::FP_TO_UINT,  MVT::v4i64,  MVT::v4f64,  1 },
+  };
+
+  static const TypeConversionCostTblEntry AVX512VLConversionTbl[] = {
     { ISD::UINT_TO_FP,  MVT::v2f64,  MVT::v2i8,   2 },
     { ISD::UINT_TO_FP,  MVT::v4f64,  MVT::v4i8,   2 },
     { ISD::UINT_TO_FP,  MVT::v8f32,  MVT::v8i8,   2 },
-    { ISD::UINT_TO_FP,  MVT::v8f64,  MVT::v8i8,   2 },
-    { ISD::UINT_TO_FP,  MVT::v16f32, MVT::v16i8,  2 },
     { ISD::UINT_TO_FP,  MVT::v2f64,  MVT::v2i16,  5 },
     { ISD::UINT_TO_FP,  MVT::v4f64,  MVT::v4i16,  2 },
     { ISD::UINT_TO_FP,  MVT::v8f32,  MVT::v8i16,  2 },
-    { ISD::UINT_TO_FP,  MVT::v8f64,  MVT::v8i16,  2 },
-    { ISD::UINT_TO_FP,  MVT::v16f32, MVT::v16i16, 2 },
     { ISD::UINT_TO_FP,  MVT::v2f32,  MVT::v2i32,  2 },
     { ISD::UINT_TO_FP,  MVT::v2f64,  MVT::v2i32,  1 },
     { ISD::UINT_TO_FP,  MVT::v4f32,  MVT::v4i32,  1 },
     { ISD::UINT_TO_FP,  MVT::v4f64,  MVT::v4i32,  1 },
     { ISD::UINT_TO_FP,  MVT::v8f32,  MVT::v8i32,  1 },
-    { ISD::UINT_TO_FP,  MVT::v8f64,  MVT::v8i32,  1 },
-    { ISD::UINT_TO_FP,  MVT::v16f32, MVT::v16i32, 1 },
     { ISD::UINT_TO_FP,  MVT::v2f32,  MVT::v2i64,  5 },
-    { ISD::UINT_TO_FP,  MVT::v8f32,  MVT::v8i64, 26 },
     { ISD::UINT_TO_FP,  MVT::v2f64,  MVT::v2i64,  5 },
     { ISD::UINT_TO_FP,  MVT::v4f64,  MVT::v4i64,  5 },
-    { ISD::UINT_TO_FP,  MVT::v8f64,  MVT::v8i64,  5 },
 
     { ISD::UINT_TO_FP,  MVT::f32,    MVT::i64,    1 },
     { ISD::UINT_TO_FP,  MVT::f64,    MVT::i64,    1 },
@@ -1438,12 +1460,6 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
     { ISD::FP_TO_UINT,  MVT::v4i32,  MVT::v4f32,  1 },
     { ISD::FP_TO_UINT,  MVT::v4i32,  MVT::v4f64,  1 },
     { ISD::FP_TO_UINT,  MVT::v8i32,  MVT::v8f32,  1 },
-    { ISD::FP_TO_UINT,  MVT::v8i32,  MVT::v8f64,  1 },
-    { ISD::FP_TO_UINT,  MVT::v8i16,  MVT::v8f64,  2 },
-    { ISD::FP_TO_UINT,  MVT::v8i8,   MVT::v8f64,  2 },
-    { ISD::FP_TO_UINT,  MVT::v16i32, MVT::v16f32, 1 },
-    { ISD::FP_TO_UINT,  MVT::v16i16, MVT::v16f32, 2 },
-    { ISD::FP_TO_UINT,  MVT::v16i8,  MVT::v16f32, 2 },
   };
 
   static const TypeConversionCostTblEntry AVX2ConversionTbl[] = {
@@ -1693,11 +1709,7 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
   MVT SimpleSrcTy = SrcTy.getSimpleVT();
   MVT SimpleDstTy = DstTy.getSimpleVT();
 
-  // Make sure that neither type is going to be split before using the
-  // AVX512 tables. This handles -mprefer-vector-width=256
-  // with -min-legal-vector-width<=256
-  if (TLI->getTypeAction(SimpleSrcTy) != TargetLowering::TypeSplitVector &&
-      TLI->getTypeAction(SimpleDstTy) != TargetLowering::TypeSplitVector) {
+  if (ST->useAVX512Regs()) {
     if (ST->hasBWI())
       if (const auto *Entry = ConvertCostTableLookup(AVX512BWConversionTbl, ISD,
                                                      SimpleDstTy, SimpleSrcTy))
@@ -1714,6 +1726,21 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
         return Entry->Cost;
   }
 
+  if (ST->hasBWI())
+    if (const auto *Entry = ConvertCostTableLookup(AVX512BWVLConversionTbl, ISD,
+                                                   SimpleDstTy, SimpleSrcTy))
+      return Entry->Cost;
+
+  if (ST->hasDQI())
+    if (const auto *Entry = ConvertCostTableLookup(AVX512DQVLConversionTbl, ISD,
+                                                   SimpleDstTy, SimpleSrcTy))
+      return Entry->Cost;
+
+  if (ST->hasAVX512())
+    if (const auto *Entry = ConvertCostTableLookup(AVX512VLConversionTbl, ISD,
+                                                   SimpleDstTy, SimpleSrcTy))
+      return Entry->Cost;
+
   if (ST->hasAVX2()) {
     if (const auto *Entry = ConvertCostTableLookup(AVX2ConversionTbl, ISD,
                                                    SimpleDstTy, SimpleSrcTy))