[SPIRV] fix several issues in builds with expensive checks
authorIlia Diachkov <ilia.diachkov@gmail.com>
Mon, 27 Feb 2023 18:16:48 +0000 (21:16 +0300)
committerIlia Diachkov <ilia.diachkov@gmail.com>
Thu, 16 Mar 2023 21:08:23 +0000 (00:08 +0300)
The patch fixes "Virtual register does not match instruction constraint"
and partly "Illegal virtual register for instruction" fails in the SPIRV
backend builds with LLVM_ENABLE_EXPENSIVE_CHECKS enabled. As a result,
the number of passed LIT tests with enabled checks is doubled.

Also, support for ndrange_*D builtins is placed in a separate function.

Differential Revision: https://reviews.llvm.org/D144897

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp

index c11b36a..40b6520 100644 (file)
@@ -291,6 +291,7 @@ buildBoolRegister(MachineIRBuilder &MIRBuilder, const SPIRVType *ResultType,
 
   Register ResultRegister =
       MIRBuilder.getMRI()->createGenericVirtualRegister(Type);
+  MIRBuilder.getMRI()->setRegClass(ResultRegister, &SPIRV::IDRegClass);
   GR->assignSPIRVTypeToVReg(BoolType, ResultRegister, MIRBuilder.getMF());
   return std::make_tuple(ResultRegister, BoolType);
 }
@@ -417,33 +418,41 @@ static Register buildConstantIntReg(uint64_t Val, MachineIRBuilder &MIRBuilder,
 }
 
 static Register buildScopeReg(Register CLScopeRegister,
+                              SPIRV::Scope::Scope Scope,
                               MachineIRBuilder &MIRBuilder,
                               SPIRVGlobalRegistry *GR,
-                              const MachineRegisterInfo *MRI) {
-  auto CLScope =
-      static_cast<SPIRV::CLMemoryScope>(getIConstVal(CLScopeRegister, MRI));
-  SPIRV::Scope::Scope Scope = getSPIRVScope(CLScope);
-
-  if (CLScope == static_cast<unsigned>(Scope))
-    return CLScopeRegister;
-
+                              MachineRegisterInfo *MRI) {
+  if (CLScopeRegister.isValid()) {
+    auto CLScope =
+        static_cast<SPIRV::CLMemoryScope>(getIConstVal(CLScopeRegister, MRI));
+    Scope = getSPIRVScope(CLScope);
+
+    if (CLScope == static_cast<unsigned>(Scope)) {
+      MRI->setRegClass(CLScopeRegister, &SPIRV::IDRegClass);
+      return CLScopeRegister;
+    }
+  }
   return buildConstantIntReg(Scope, MIRBuilder, GR);
 }
 
 static Register buildMemSemanticsReg(Register SemanticsRegister,
-                                     Register PtrRegister,
-                                     const MachineRegisterInfo *MRI,
+                                     Register PtrRegister, unsigned &Semantics,
+                                     MachineIRBuilder &MIRBuilder,
                                      SPIRVGlobalRegistry *GR) {
-  std::memory_order Order =
-      static_cast<std::memory_order>(getIConstVal(SemanticsRegister, MRI));
-  unsigned Semantics =
-      getSPIRVMemSemantics(Order) |
-      getMemSemanticsForStorageClass(GR->getPointerStorageClass(PtrRegister));
-
-  if (Order == Semantics)
-    return SemanticsRegister;
+  if (SemanticsRegister.isValid()) {
+    MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+    std::memory_order Order =
+        static_cast<std::memory_order>(getIConstVal(SemanticsRegister, MRI));
+    Semantics =
+        getSPIRVMemSemantics(Order) |
+        getMemSemanticsForStorageClass(GR->getPointerStorageClass(PtrRegister));
 
-  return Register();
+    if (Order == Semantics) {
+      MRI->setRegClass(SemanticsRegister, &SPIRV::IDRegClass);
+      return SemanticsRegister;
+    }
+  }
+  return buildConstantIntReg(Semantics, MIRBuilder, GR);
 }
 
 /// Helper function for translating atomic init to OpStore.
@@ -451,7 +460,8 @@ static bool buildAtomicInitInst(const SPIRV::IncomingCall *Call,
                                 MachineIRBuilder &MIRBuilder) {
   assert(Call->Arguments.size() == 2 &&
          "Need 2 arguments for atomic init translation");
-
+  MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
+  MIRBuilder.getMRI()->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
   MIRBuilder.buildInstr(SPIRV::OpStore)
       .addUse(Call->Arguments[0])
       .addUse(Call->Arguments[1]);
@@ -463,19 +473,22 @@ static bool buildAtomicLoadInst(const SPIRV::IncomingCall *Call,
                                 MachineIRBuilder &MIRBuilder,
                                 SPIRVGlobalRegistry *GR) {
   Register PtrRegister = Call->Arguments[0];
+  MIRBuilder.getMRI()->setRegClass(PtrRegister, &SPIRV::IDRegClass);
   // TODO: if true insert call to __translate_ocl_memory_sccope before
   // OpAtomicLoad and the function implementation. We can use Translator's
   // output for transcoding/atomic_explicit_arguments.cl as an example.
   Register ScopeRegister;
-  if (Call->Arguments.size() > 1)
+  if (Call->Arguments.size() > 1) {
     ScopeRegister = Call->Arguments[1];
-  else
+    MIRBuilder.getMRI()->setRegClass(ScopeRegister, &SPIRV::IDRegClass);
+  } else
     ScopeRegister = buildConstantIntReg(SPIRV::Scope::Device, MIRBuilder, GR);
 
   Register MemSemanticsReg;
   if (Call->Arguments.size() > 2) {
     // TODO: Insert call to __translate_ocl_memory_order before OpAtomicLoad.
     MemSemanticsReg = Call->Arguments[2];
+    MIRBuilder.getMRI()->setRegClass(MemSemanticsReg, &SPIRV::IDRegClass);
   } else {
     int Semantics =
         SPIRV::MemorySemantics::SequentiallyConsistent |
@@ -499,11 +512,12 @@ static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call,
   Register ScopeRegister =
       buildConstantIntReg(SPIRV::Scope::Device, MIRBuilder, GR);
   Register PtrRegister = Call->Arguments[0];
+  MIRBuilder.getMRI()->setRegClass(PtrRegister, &SPIRV::IDRegClass);
   int Semantics =
       SPIRV::MemorySemantics::SequentiallyConsistent |
       getMemSemanticsForStorageClass(GR->getPointerStorageClass(PtrRegister));
   Register MemSemanticsReg = buildConstantIntReg(Semantics, MIRBuilder, GR);
-
+  MIRBuilder.getMRI()->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
   MIRBuilder.buildInstr(SPIRV::OpAtomicStore)
       .addUse(PtrRegister)
       .addUse(ScopeRegister)
@@ -525,6 +539,9 @@ static bool buildAtomicCompareExchangeInst(const SPIRV::IncomingCall *Call,
   Register ObjectPtr = Call->Arguments[0];   // Pointer (volatile A *object.)
   Register ExpectedArg = Call->Arguments[1]; // Comparator (C* expected).
   Register Desired = Call->Arguments[2];     // Value (C Desired).
+  MRI->setRegClass(ObjectPtr, &SPIRV::IDRegClass);
+  MRI->setRegClass(ExpectedArg, &SPIRV::IDRegClass);
+  MRI->setRegClass(Desired, &SPIRV::IDRegClass);
   SPIRVType *SpvDesiredTy = GR->getSPIRVTypeForVReg(Desired);
   LLT DesiredLLT = MRI->getType(Desired);
 
@@ -564,6 +581,8 @@ static bool buildAtomicCompareExchangeInst(const SPIRV::IncomingCall *Call,
       MemSemEqualReg = Call->Arguments[3];
     if (MemOrdNeq == MemSemEqual)
       MemSemUnequalReg = Call->Arguments[4];
+    MRI->setRegClass(Call->Arguments[3], &SPIRV::IDRegClass);
+    MRI->setRegClass(Call->Arguments[4], &SPIRV::IDRegClass);
   }
   if (!MemSemEqualReg.isValid())
     MemSemEqualReg = buildConstantIntReg(MemSemEqual, MIRBuilder, GR);
@@ -580,6 +599,7 @@ static bool buildAtomicCompareExchangeInst(const SPIRV::IncomingCall *Call,
     Scope = getSPIRVScope(ClScope);
     if (ClScope == static_cast<unsigned>(Scope))
       ScopeReg = Call->Arguments[5];
+    MRI->setRegClass(Call->Arguments[5], &SPIRV::IDRegClass);
   }
   if (!ScopeReg.isValid())
     ScopeReg = buildConstantIntReg(Scope, MIRBuilder, GR);
@@ -591,6 +611,8 @@ static bool buildAtomicCompareExchangeInst(const SPIRV::IncomingCall *Call,
   MRI->setType(Expected, DesiredLLT);
   Register Tmp = !IsCmpxchg ? MRI->createGenericVirtualRegister(DesiredLLT)
                             : Call->ReturnRegister;
+  if (!MRI->getRegClassOrNull(Tmp))
+    MRI->setRegClass(Tmp, &SPIRV::IDRegClass);
   GR->assignSPIRVTypeToVReg(SpvDesiredTy, Tmp, MIRBuilder.getMF());
 
   SPIRVType *IntTy = GR->getOrCreateSPIRVIntegerType(32, MIRBuilder);
@@ -614,30 +636,23 @@ static bool buildAtomicCompareExchangeInst(const SPIRV::IncomingCall *Call,
 static bool buildAtomicRMWInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
                                MachineIRBuilder &MIRBuilder,
                                SPIRVGlobalRegistry *GR) {
-  const MachineRegisterInfo *MRI = MIRBuilder.getMRI();
-  SPIRV::Scope::Scope Scope = SPIRV::Scope::Workgroup;
-  Register ScopeRegister;
-
-  if (Call->Arguments.size() >= 4) {
-    assert(Call->Arguments.size() == 4 &&
-           "Too many args for explicit atomic RMW");
-    ScopeRegister = buildScopeReg(Call->Arguments[3], MIRBuilder, GR, MRI);
-  }
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+  Register ScopeRegister =
+      Call->Arguments.size() >= 4 ? Call->Arguments[3] : Register();
 
-  if (!ScopeRegister.isValid())
-    ScopeRegister = buildConstantIntReg(Scope, MIRBuilder, GR);
+  assert(Call->Arguments.size() <= 4 &&
+         "Too many args for explicit atomic RMW");
+  ScopeRegister = buildScopeReg(ScopeRegister, SPIRV::Scope::Workgroup,
+                                MIRBuilder, GR, MRI);
 
   Register PtrRegister = Call->Arguments[0];
   unsigned Semantics = SPIRV::MemorySemantics::None;
-  Register MemSemanticsReg;
-
-  if (Call->Arguments.size() >= 3)
-    MemSemanticsReg =
-        buildMemSemanticsReg(Call->Arguments[2], PtrRegister, MRI, GR);
-
-  if (!MemSemanticsReg.isValid())
-    MemSemanticsReg = buildConstantIntReg(Semantics, MIRBuilder, GR);
-
+  MRI->setRegClass(PtrRegister, &SPIRV::IDRegClass);
+  Register MemSemanticsReg =
+      Call->Arguments.size() >= 3 ? Call->Arguments[2] : Register();
+  MemSemanticsReg = buildMemSemanticsReg(MemSemanticsReg, PtrRegister,
+                                         Semantics, MIRBuilder, GR);
+  MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
   MIRBuilder.buildInstr(Opcode)
       .addDef(Call->ReturnRegister)
       .addUse(GR->getSPIRVTypeID(Call->ReturnType))
@@ -653,32 +668,23 @@ static bool buildAtomicRMWInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
 static bool buildAtomicFlagInst(const SPIRV::IncomingCall *Call,
                                 unsigned Opcode, MachineIRBuilder &MIRBuilder,
                                 SPIRVGlobalRegistry *GR) {
-  const MachineRegisterInfo *MRI = MIRBuilder.getMRI();
-
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
   Register PtrRegister = Call->Arguments[0];
   unsigned Semantics = SPIRV::MemorySemantics::SequentiallyConsistent;
-  Register MemSemanticsReg;
-
-  if (Call->Arguments.size() >= 2)
-    MemSemanticsReg =
-        buildMemSemanticsReg(Call->Arguments[1], PtrRegister, MRI, GR);
-
-  if (!MemSemanticsReg.isValid())
-    MemSemanticsReg = buildConstantIntReg(Semantics, MIRBuilder, GR);
+  Register MemSemanticsReg =
+      Call->Arguments.size() >= 2 ? Call->Arguments[1] : Register();
+  MemSemanticsReg = buildMemSemanticsReg(MemSemanticsReg, PtrRegister,
+                                         Semantics, MIRBuilder, GR);
 
   assert((Opcode != SPIRV::OpAtomicFlagClear ||
           (Semantics != SPIRV::MemorySemantics::Acquire &&
            Semantics != SPIRV::MemorySemantics::AcquireRelease)) &&
          "Invalid memory order argument!");
 
-  SPIRV::Scope::Scope Scope = SPIRV::Scope::Device;
-  Register ScopeRegister;
-
-  if (Call->Arguments.size() >= 3)
-    ScopeRegister = buildScopeReg(Call->Arguments[2], MIRBuilder, GR, MRI);
-
-  if (!ScopeRegister.isValid())
-    ScopeRegister = buildConstantIntReg(Scope, MIRBuilder, GR);
+  Register ScopeRegister =
+      Call->Arguments.size() >= 3 ? Call->Arguments[2] : Register();
+  ScopeRegister =
+      buildScopeReg(ScopeRegister, SPIRV::Scope::Device, MIRBuilder, GR, MRI);
 
   auto MIB = MIRBuilder.buildInstr(Opcode);
   if (Opcode == SPIRV::OpAtomicFlagTestAndSet)
@@ -694,7 +700,7 @@ static bool buildAtomicFlagInst(const SPIRV::IncomingCall *Call,
 static bool buildBarrierInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
                              MachineIRBuilder &MIRBuilder,
                              SPIRVGlobalRegistry *GR) {
-  const MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
   unsigned MemFlags = getIConstVal(Call->Arguments[0], MRI);
   unsigned MemSemantics = SPIRV::MemorySemantics::None;
 
@@ -716,9 +722,10 @@ static bool buildBarrierInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
   }
 
   Register MemSemanticsReg;
-  if (MemFlags == MemSemantics)
+  if (MemFlags == MemSemantics) {
     MemSemanticsReg = Call->Arguments[0];
-  else
+    MRI->setRegClass(MemSemanticsReg, &SPIRV::IDRegClass);
+  } else
     MemSemanticsReg = buildConstantIntReg(MemSemantics, MIRBuilder, GR);
 
   Register ScopeReg;
@@ -738,8 +745,10 @@ static bool buildBarrierInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
         (Opcode == SPIRV::OpMemoryBarrier))
       Scope = MemScope;
 
-    if (CLScope == static_cast<unsigned>(Scope))
+    if (CLScope == static_cast<unsigned>(Scope)) {
       ScopeReg = Call->Arguments[1];
+      MRI->setRegClass(ScopeReg, &SPIRV::IDRegClass);
+    }
   }
 
   if (!ScopeReg.isValid())
@@ -834,7 +843,7 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
   const SPIRV::GroupBuiltin *GroupBuiltin =
       SPIRV::lookupGroupBuiltin(Builtin->Name);
-  const MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
   Register Arg0;
   if (GroupBuiltin->HasBoolArg) {
     Register ConstRegister = Call->Arguments[0];
@@ -876,8 +885,11 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
     MIB.addImm(GroupBuiltin->GroupOperation);
   if (Call->Arguments.size() > 0) {
     MIB.addUse(Arg0.isValid() ? Arg0 : Call->Arguments[0]);
-    for (unsigned i = 1; i < Call->Arguments.size(); i++)
+    MRI->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
+    for (unsigned i = 1; i < Call->Arguments.size(); i++) {
       MIB.addUse(Call->Arguments[i]);
+      MRI->setRegClass(Call->Arguments[i], &SPIRV::IDRegClass);
+    }
   }
 
   // Build select instruction.
@@ -936,16 +948,17 @@ static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
   // If it's out of range (max dimension is 3), we can just return the constant
   // default value (0 or 1 depending on which query function).
   if (IsConstantIndex && getIConstVal(IndexRegister, MRI) >= 3) {
-    Register defaultReg = Call->ReturnRegister;
+    Register DefaultReg = Call->ReturnRegister;
     if (PointerSize != ResultWidth) {
-      defaultReg = MRI->createGenericVirtualRegister(LLT::scalar(PointerSize));
-      GR->assignSPIRVTypeToVReg(PointerSizeType, defaultReg,
+      DefaultReg = MRI->createGenericVirtualRegister(LLT::scalar(PointerSize));
+      MRI->setRegClass(DefaultReg, &SPIRV::IDRegClass);
+      GR->assignSPIRVTypeToVReg(PointerSizeType, DefaultReg,
                                 MIRBuilder.getMF());
-      ToTruncate = defaultReg;
+      ToTruncate = DefaultReg;
     }
     auto NewRegister =
         GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType);
-    MIRBuilder.buildCopy(defaultReg, NewRegister);
+    MIRBuilder.buildCopy(DefaultReg, NewRegister);
   } else { // If it could be in range, we need to load from the given builtin.
     auto Vec3Ty =
         GR->getOrCreateSPIRVVectorType(PointerSizeType, 3, MIRBuilder);
@@ -956,6 +969,7 @@ static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
     Register Extracted = Call->ReturnRegister;
     if (!IsConstantIndex || PointerSize != ResultWidth) {
       Extracted = MRI->createGenericVirtualRegister(LLT::scalar(PointerSize));
+      MRI->setRegClass(Extracted, &SPIRV::IDRegClass);
       GR->assignSPIRVTypeToVReg(PointerSizeType, Extracted, MIRBuilder.getMF());
     }
     // Use Intrinsic::spv_extractelt so dynamic vs static extraction is
@@ -974,6 +988,7 @@ static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
 
       Register CompareRegister =
           MRI->createGenericVirtualRegister(LLT::scalar(1));
+      MRI->setRegClass(CompareRegister, &SPIRV::IDRegClass);
       GR->assignSPIRVTypeToVReg(BoolType, CompareRegister, MIRBuilder.getMF());
 
       // Use G_ICMP to check if idxVReg < 3.
@@ -990,6 +1005,7 @@ static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
       if (PointerSize != ResultWidth) {
         SelectionResult =
             MRI->createGenericVirtualRegister(LLT::scalar(PointerSize));
+        MRI->setRegClass(SelectionResult, &SPIRV::IDRegClass);
         GR->assignSPIRVTypeToVReg(PointerSizeType, SelectionResult,
                                   MIRBuilder.getMF());
       }
@@ -1125,6 +1141,7 @@ static bool generateImageSizeQueryInst(const SPIRV::IncomingCall *Call,
   if (NumExpectedRetComponents != NumActualRetComponents) {
     QueryResult = MIRBuilder.getMRI()->createGenericVirtualRegister(
         LLT::fixed_vector(NumActualRetComponents, 32));
+    MIRBuilder.getMRI()->setRegClass(QueryResult, &SPIRV::IDRegClass);
     SPIRVType *IntTy = GR->getOrCreateSPIRVIntegerType(32, MIRBuilder);
     QueryResultType = GR->getOrCreateSPIRVVectorType(
         IntTy, NumActualRetComponents, MIRBuilder);
@@ -1133,6 +1150,7 @@ static bool generateImageSizeQueryInst(const SPIRV::IncomingCall *Call,
   bool IsDimBuf = ImgType->getOperand(2).getImm() == SPIRV::Dim::DIM_Buffer;
   unsigned Opcode =
       IsDimBuf ? SPIRV::OpImageQuerySize : SPIRV::OpImageQuerySizeLod;
+  MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
   auto MIB = MIRBuilder.buildInstr(Opcode)
                  .addDef(QueryResult)
                  .addUse(GR->getSPIRVTypeID(QueryResultType))
@@ -1177,6 +1195,7 @@ static bool generateImageMiscQueryInst(const SPIRV::IncomingCall *Call,
       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
 
   Register Image = Call->Arguments[0];
+  MIRBuilder.getMRI()->setRegClass(Image, &SPIRV::IDRegClass);
   SPIRV::Dim::Dim ImageDimensionality = static_cast<SPIRV::Dim::Dim>(
       GR->getSPIRVTypeForVReg(Image)->getOperand(2).getImm());
 
@@ -1239,8 +1258,13 @@ static bool generateReadImageInst(const StringRef DemangledCall,
                                   SPIRVGlobalRegistry *GR) {
   Register Image = Call->Arguments[0];
   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
-
-  if (DemangledCall.contains_insensitive("ocl_sampler")) {
+  MRI->setRegClass(Image, &SPIRV::IDRegClass);
+  MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
+  bool HasOclSampler = DemangledCall.contains_insensitive("ocl_sampler");
+  bool HasMsaa = DemangledCall.contains_insensitive("msaa");
+  if (HasOclSampler || HasMsaa)
+    MRI->setRegClass(Call->Arguments[2], &SPIRV::IDRegClass);
+  if (HasOclSampler) {
     Register Sampler = Call->Arguments[1];
 
     if (!GR->isScalarOfType(Sampler, SPIRV::OpTypeSampler) &&
@@ -1274,6 +1298,7 @@ static bool generateReadImageInst(const StringRef DemangledCall,
     }
     LLT LLType = LLT::scalar(GR->getScalarOrVectorBitWidth(TempType));
     Register TempRegister = MRI->createGenericVirtualRegister(LLType);
+    MRI->setRegClass(TempRegister, &SPIRV::IDRegClass);
     GR->assignSPIRVTypeToVReg(TempType, TempRegister, MIRBuilder.getMF());
 
     MIRBuilder.buildInstr(SPIRV::OpImageSampleExplicitLod)
@@ -1290,7 +1315,7 @@ static bool generateReadImageInst(const StringRef DemangledCall,
           .addUse(GR->getSPIRVTypeID(Call->ReturnType))
           .addUse(TempRegister)
           .addImm(0);
-  } else if (DemangledCall.contains_insensitive("msaa")) {
+  } else if (HasMsaa) {
     MIRBuilder.buildInstr(SPIRV::OpImageRead)
         .addDef(Call->ReturnRegister)
         .addUse(GR->getSPIRVTypeID(Call->ReturnType))
@@ -1311,6 +1336,9 @@ static bool generateReadImageInst(const StringRef DemangledCall,
 static bool generateWriteImageInst(const SPIRV::IncomingCall *Call,
                                    MachineIRBuilder &MIRBuilder,
                                    SPIRVGlobalRegistry *GR) {
+  MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
+  MIRBuilder.getMRI()->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
+  MIRBuilder.getMRI()->setRegClass(Call->Arguments[2], &SPIRV::IDRegClass);
   MIRBuilder.buildInstr(SPIRV::OpImageWrite)
       .addUse(Call->Arguments[0])  // Image.
       .addUse(Call->Arguments[1])  // Coordinate.
@@ -1322,10 +1350,11 @@ static bool generateSampleImageInst(const StringRef DemangledCall,
                                     const SPIRV::IncomingCall *Call,
                                     MachineIRBuilder &MIRBuilder,
                                     SPIRVGlobalRegistry *GR) {
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
   if (Call->Builtin->Name.contains_insensitive(
           "__translate_sampler_initializer")) {
     // Build sampler literal.
-    uint64_t Bitmask = getIConstVal(Call->Arguments[0], MIRBuilder.getMRI());
+    uint64_t Bitmask = getIConstVal(Call->Arguments[0], MRI);
     Register Sampler = GR->buildConstantSampler(
         Call->ReturnRegister, getSamplerAddressingModeFromBitmask(Bitmask),
         getSamplerParamFromBitmask(Bitmask),
@@ -1340,7 +1369,7 @@ static bool generateSampleImageInst(const StringRef DemangledCall,
     Register SampledImage =
         Call->ReturnRegister.isValid()
             ? Call->ReturnRegister
-            : MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
+            : MRI->createVirtualRegister(&SPIRV::IDRegClass);
     MIRBuilder.buildInstr(SPIRV::OpSampledImage)
         .addDef(SampledImage)
         .addUse(GR->getSPIRVTypeID(SampledImageType))
@@ -1356,6 +1385,10 @@ static bool generateSampleImageInst(const StringRef DemangledCall,
       ReturnType = ReturnType.substr(0, ReturnType.find('('));
     }
     SPIRVType *Type = GR->getOrCreateSPIRVTypeByName(ReturnType, MIRBuilder);
+    MRI->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
+    MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
+    MRI->setRegClass(Call->Arguments[3], &SPIRV::IDRegClass);
+
     MIRBuilder.buildInstr(SPIRV::OpImageSampleExplicitLod)
         .addDef(Call->ReturnRegister)
         .addUse(GR->getSPIRVTypeID(Type))
@@ -1431,6 +1464,75 @@ static bool generateSpecConstantInst(const SPIRV::IncomingCall *Call,
   }
 }
 
+static bool buildNDRange(const SPIRV::IncomingCall *Call,
+                         MachineIRBuilder &MIRBuilder,
+                         SPIRVGlobalRegistry *GR) {
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+  MRI->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
+  SPIRVType *PtrType = GR->getSPIRVTypeForVReg(Call->Arguments[0]);
+  assert(PtrType->getOpcode() == SPIRV::OpTypePointer &&
+         PtrType->getOperand(2).isReg());
+  Register TypeReg = PtrType->getOperand(2).getReg();
+  SPIRVType *StructType = GR->getSPIRVTypeForVReg(TypeReg);
+  MachineFunction &MF = MIRBuilder.getMF();
+  Register TmpReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+  GR->assignSPIRVTypeToVReg(StructType, TmpReg, MF);
+  // Skip the first arg, it's the destination pointer. OpBuildNDRange takes
+  // three other arguments, so pass zero constant on absence.
+  unsigned NumArgs = Call->Arguments.size();
+  assert(NumArgs >= 2);
+  Register GlobalWorkSize = Call->Arguments[NumArgs < 4 ? 1 : 2];
+  MRI->setRegClass(GlobalWorkSize, &SPIRV::IDRegClass);
+  Register LocalWorkSize =
+      NumArgs == 2 ? Register(0) : Call->Arguments[NumArgs < 4 ? 2 : 3];
+  if (LocalWorkSize.isValid())
+    MRI->setRegClass(LocalWorkSize, &SPIRV::IDRegClass);
+  Register GlobalWorkOffset = NumArgs <= 3 ? Register(0) : Call->Arguments[1];
+  if (GlobalWorkOffset.isValid())
+    MRI->setRegClass(GlobalWorkOffset, &SPIRV::IDRegClass);
+  if (NumArgs < 4) {
+    Register Const;
+    SPIRVType *SpvTy = GR->getSPIRVTypeForVReg(GlobalWorkSize);
+    if (SpvTy->getOpcode() == SPIRV::OpTypePointer) {
+      MachineInstr *DefInstr = MRI->getUniqueVRegDef(GlobalWorkSize);
+      assert(DefInstr && isSpvIntrinsic(*DefInstr, Intrinsic::spv_gep) &&
+             DefInstr->getOperand(3).isReg());
+      Register GWSPtr = DefInstr->getOperand(3).getReg();
+      if (!MRI->getRegClassOrNull(GWSPtr))
+        MRI->setRegClass(GWSPtr, &SPIRV::IDRegClass);
+      // TODO: Maybe simplify generation of the type of the fields.
+      unsigned Size = Call->Builtin->Name.equals("ndrange_3D") ? 3 : 2;
+      unsigned BitWidth = GR->getPointerSize() == 64 ? 64 : 32;
+      Type *BaseTy = IntegerType::get(MF.getFunction().getContext(), BitWidth);
+      Type *FieldTy = ArrayType::get(BaseTy, Size);
+      SPIRVType *SpvFieldTy = GR->getOrCreateSPIRVType(FieldTy, MIRBuilder);
+      GlobalWorkSize = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+      GR->assignSPIRVTypeToVReg(SpvFieldTy, GlobalWorkSize, MF);
+      MIRBuilder.buildInstr(SPIRV::OpLoad)
+          .addDef(GlobalWorkSize)
+          .addUse(GR->getSPIRVTypeID(SpvFieldTy))
+          .addUse(GWSPtr);
+      Const = GR->getOrCreateConsIntArray(0, MIRBuilder, SpvFieldTy);
+    } else {
+      Const = GR->buildConstantInt(0, MIRBuilder, SpvTy);
+    }
+    if (!LocalWorkSize.isValid())
+      LocalWorkSize = Const;
+    if (!GlobalWorkOffset.isValid())
+      GlobalWorkOffset = Const;
+  }
+  assert(LocalWorkSize.isValid() && GlobalWorkOffset.isValid());
+  MIRBuilder.buildInstr(SPIRV::OpBuildNDRange)
+      .addDef(TmpReg)
+      .addUse(TypeReg)
+      .addUse(GlobalWorkSize)
+      .addUse(LocalWorkSize)
+      .addUse(GlobalWorkOffset);
+  return MIRBuilder.buildInstr(SPIRV::OpStore)
+      .addUse(Call->Arguments[0])
+      .addUse(TmpReg);
+}
+
 static MachineInstr *getBlockStructInstr(Register ParamReg,
                                          MachineRegisterInfo *MRI) {
   // We expect the following sequence of instructions:
@@ -1538,9 +1640,8 @@ static bool buildEnqueueKernel(const SPIRV::IncomingCall *Call,
     const SPIRVType *PointerSizeTy = GR->getOrCreateSPIRVPointerType(
         Int32Ty, MIRBuilder, SPIRV::StorageClass::Function);
     for (unsigned I = 0; I < LocalSizeNum; ++I) {
-      Register Reg =
-          MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
-      MIRBuilder.getMRI()->setType(Reg, LLType);
+      Register Reg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+      MRI->setType(Reg, LLType);
       GR->assignSPIRVTypeToVReg(PointerSizeTy, Reg, MIRBuilder.getMF());
       auto GEPInst = MIRBuilder.buildIntrinsic(Intrinsic::spv_gep,
                                                ArrayRef<Register>{Reg}, true);
@@ -1605,6 +1706,7 @@ static bool generateEnqueueInst(const SPIRV::IncomingCall *Call,
   switch (Opcode) {
   case SPIRV::OpRetainEvent:
   case SPIRV::OpReleaseEvent:
+    MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
     return MIRBuilder.buildInstr(Opcode).addUse(Call->Arguments[0]);
   case SPIRV::OpCreateUserEvent:
   case SPIRV::OpGetDefaultQueue:
@@ -1612,77 +1714,27 @@ static bool generateEnqueueInst(const SPIRV::IncomingCall *Call,
         .addDef(Call->ReturnRegister)
         .addUse(GR->getSPIRVTypeID(Call->ReturnType));
   case SPIRV::OpIsValidEvent:
+    MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
     return MIRBuilder.buildInstr(Opcode)
         .addDef(Call->ReturnRegister)
         .addUse(GR->getSPIRVTypeID(Call->ReturnType))
         .addUse(Call->Arguments[0]);
   case SPIRV::OpSetUserEventStatus:
+    MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
+    MIRBuilder.getMRI()->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
     return MIRBuilder.buildInstr(Opcode)
         .addUse(Call->Arguments[0])
         .addUse(Call->Arguments[1]);
   case SPIRV::OpCaptureEventProfilingInfo:
+    MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
+    MIRBuilder.getMRI()->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
+    MIRBuilder.getMRI()->setRegClass(Call->Arguments[2], &SPIRV::IDRegClass);
     return MIRBuilder.buildInstr(Opcode)
         .addUse(Call->Arguments[0])
         .addUse(Call->Arguments[1])
         .addUse(Call->Arguments[2]);
-  case SPIRV::OpBuildNDRange: {
-    MachineRegisterInfo *MRI = MIRBuilder.getMRI();
-    SPIRVType *PtrType = GR->getSPIRVTypeForVReg(Call->Arguments[0]);
-    assert(PtrType->getOpcode() == SPIRV::OpTypePointer &&
-           PtrType->getOperand(2).isReg());
-    Register TypeReg = PtrType->getOperand(2).getReg();
-    SPIRVType *StructType = GR->getSPIRVTypeForVReg(TypeReg);
-    Register TmpReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
-    GR->assignSPIRVTypeToVReg(StructType, TmpReg, MIRBuilder.getMF());
-    // Skip the first arg, it's the destination pointer. OpBuildNDRange takes
-    // three other arguments, so pass zero constant on absence.
-    unsigned NumArgs = Call->Arguments.size();
-    assert(NumArgs >= 2);
-    Register GlobalWorkSize = Call->Arguments[NumArgs < 4 ? 1 : 2];
-    Register LocalWorkSize =
-        NumArgs == 2 ? Register(0) : Call->Arguments[NumArgs < 4 ? 2 : 3];
-    Register GlobalWorkOffset = NumArgs <= 3 ? Register(0) : Call->Arguments[1];
-    if (NumArgs < 4) {
-      Register Const;
-      SPIRVType *SpvTy = GR->getSPIRVTypeForVReg(GlobalWorkSize);
-      if (SpvTy->getOpcode() == SPIRV::OpTypePointer) {
-        MachineInstr *DefInstr = MRI->getUniqueVRegDef(GlobalWorkSize);
-        assert(DefInstr && isSpvIntrinsic(*DefInstr, Intrinsic::spv_gep) &&
-               DefInstr->getOperand(3).isReg());
-        Register GWSPtr = DefInstr->getOperand(3).getReg();
-        // TODO: Maybe simplify generation of the type of the fields.
-        unsigned Size = Call->Builtin->Name.equals("ndrange_3D") ? 3 : 2;
-        unsigned BitWidth = GR->getPointerSize() == 64 ? 64 : 32;
-        Type *BaseTy = IntegerType::get(
-            MIRBuilder.getMF().getFunction().getContext(), BitWidth);
-        Type *FieldTy = ArrayType::get(BaseTy, Size);
-        SPIRVType *SpvFieldTy = GR->getOrCreateSPIRVType(FieldTy, MIRBuilder);
-        GlobalWorkSize = MRI->createVirtualRegister(&SPIRV::IDRegClass);
-        GR->assignSPIRVTypeToVReg(SpvFieldTy, GlobalWorkSize,
-                                  MIRBuilder.getMF());
-        MIRBuilder.buildInstr(SPIRV::OpLoad)
-            .addDef(GlobalWorkSize)
-            .addUse(GR->getSPIRVTypeID(SpvFieldTy))
-            .addUse(GWSPtr);
-        Const = GR->getOrCreateConsIntArray(0, MIRBuilder, SpvFieldTy);
-      } else {
-        Const = GR->buildConstantInt(0, MIRBuilder, SpvTy);
-      }
-      if (!LocalWorkSize.isValid())
-        LocalWorkSize = Const;
-      if (!GlobalWorkOffset.isValid())
-        GlobalWorkOffset = Const;
-    }
-    MIRBuilder.buildInstr(Opcode)
-        .addDef(TmpReg)
-        .addUse(TypeReg)
-        .addUse(GlobalWorkSize)
-        .addUse(LocalWorkSize)
-        .addUse(GlobalWorkOffset);
-    return MIRBuilder.buildInstr(SPIRV::OpStore)
-        .addUse(Call->Arguments[0])
-        .addUse(TmpReg);
-  }
+  case SPIRV::OpBuildNDRange:
+    return buildNDRange(Call, MIRBuilder, GR);
   case SPIRV::OpEnqueueKernel:
     return buildEnqueueKernel(Call, MIRBuilder, GR);
   default:
@@ -1817,16 +1869,23 @@ static bool generateLoadStoreInst(const SPIRV::IncomingCall *Call,
   }
   // Add a pointer to the value to load/store.
   MIB.addUse(Call->Arguments[0]);
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+  MRI->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
   // Add a value to store.
-  if (!IsLoad)
+  if (!IsLoad) {
     MIB.addUse(Call->Arguments[1]);
+    MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
+  }
   // Add optional memory attributes and an alignment.
-  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
   unsigned NumArgs = Call->Arguments.size();
-  if ((IsLoad && NumArgs >= 2) || NumArgs >= 3)
+  if ((IsLoad && NumArgs >= 2) || NumArgs >= 3) {
     MIB.addImm(getConstFromIntrinsic(Call->Arguments[IsLoad ? 1 : 2], MRI));
-  if ((IsLoad && NumArgs >= 3) || NumArgs >= 4)
+    MRI->setRegClass(Call->Arguments[IsLoad ? 1 : 2], &SPIRV::IDRegClass);
+  }
+  if ((IsLoad && NumArgs >= 3) || NumArgs >= 4) {
     MIB.addImm(getConstFromIntrinsic(Call->Arguments[IsLoad ? 2 : 3], MRI));
+    MRI->setRegClass(Call->Arguments[IsLoad ? 2 : 3], &SPIRV::IDRegClass);
+  }
   return true;
 }
 
@@ -1846,6 +1905,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
   SPIRVType *ReturnType = nullptr;
   if (OrigRetTy && !OrigRetTy->isVoidTy()) {
     ReturnType = GR->assignTypeToVReg(OrigRetTy, OrigRet, MIRBuilder);
+    if (!MIRBuilder.getMRI()->getRegClassOrNull(ReturnRegister))
+      MIRBuilder.getMRI()->setRegClass(ReturnRegister, &SPIRV::IDRegClass);
   } else if (OrigRetTy && OrigRetTy->isVoidTy()) {
     ReturnRegister = MIRBuilder.getMRI()->createVirtualRegister(&IDRegClass);
     MIRBuilder.getMRI()->setType(ReturnRegister, LLT::scalar(32));
index 8b61868..47b25a1 100644 (file)
@@ -374,6 +374,7 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
     FTy = getOriginalFunctionType(*CF);
   }
 
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
   Register ResVReg =
       Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
   std::string FuncName = Info.Callee.getGlobal()->getName().str();
@@ -410,8 +411,9 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
     for (const Argument &Arg : CF->args()) {
       if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero())
         continue; // Don't handle zero sized types.
-      ToInsert.push_back(
-          {MIRBuilder.getMRI()->createGenericVirtualRegister(LLT::scalar(32))});
+      Register Reg = MRI->createGenericVirtualRegister(LLT::scalar(32));
+      MRI->setRegClass(Reg, &SPIRV::IDRegClass);
+      ToInsert.push_back({Reg});
       VRegArgs.push_back(ToInsert.back());
     }
     // TODO: Reuse FunctionLoweringInfo
index 062188a..c77a7f8 100644 (file)
@@ -143,6 +143,7 @@ SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
     unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
     LLT LLTy = LLT::scalar(32);
     Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
+    CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
     if (MIRBuilder)
       assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder);
     else
@@ -202,6 +203,7 @@ Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
     unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
     LLT LLTy = LLT::scalar(EmitIR ? BitWidth : 32);
     Res = MF.getRegInfo().createGenericVirtualRegister(LLTy);
+    MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
     assignTypeToVReg(LLVMIntTy, Res, MIRBuilder,
                      SPIRV::AccessQualifier::ReadWrite, EmitIR);
     DT.add(ConstInt, &MIRBuilder.getMF(), Res);
@@ -247,6 +249,7 @@ Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
   if (!Res.isValid()) {
     unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
     Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
+    MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
     assignTypeToVReg(LLVMFPTy, Res, MIRBuilder);
     DT.add(ConstFP, &MF, Res);
     MIRBuilder.buildFConstant(Res, *ConstFP);
@@ -272,6 +275,7 @@ Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
     LLT LLTy = LLT::scalar(32);
     Register SpvVecConst =
         CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
+    CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass);
     assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
     DT.add(CA, CurMF, SpvVecConst);
     MachineInstrBuilder MIB;
@@ -343,6 +347,7 @@ Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
     LLT LLTy = EmitIR ? LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(32);
     Register SpvVecConst =
         CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
+    CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass);
     assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
     DT.add(CA, CurMF, SpvVecConst);
     if (EmitIR) {
@@ -411,6 +416,7 @@ SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
   if (!Res.isValid()) {
     LLT LLTy = LLT::pointer(LLVMPtrTy->getAddressSpace(), PointerSize);
     Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
+    CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
     assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
     MIRBuilder.buildInstr(SPIRV::OpConstantNull)
         .addDef(Res)
@@ -1090,6 +1096,7 @@ Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
     return Res;
   LLT LLTy = LLT::scalar(32);
   Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
+  CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
   assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
   DT.add(UV, CurMF, Res);
 
index 27d0e8a..2818329 100644 (file)
@@ -85,6 +85,9 @@ static void addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR) {
     Register Reg = MI->getOperand(2).getReg();
     if (RegsAlreadyAddedToDT.find(MI) != RegsAlreadyAddedToDT.end())
       Reg = RegsAlreadyAddedToDT[MI];
+    auto *RC = MRI.getRegClassOrNull(MI->getOperand(0).getReg());
+    if (!MRI.getRegClassOrNull(Reg) && RC)
+      MRI.setRegClass(Reg, RC);
     MRI.replaceRegWith(MI->getOperand(0).getReg(), Reg);
     MI->eraseFromParent();
   }
@@ -201,8 +204,12 @@ Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
                   (Def->getNextNode() ? Def->getNextNode()->getIterator()
                                       : Def->getParent()->end()));
   Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
-  if (auto *RC = MRI.getRegClassOrNull(Reg))
+  if (auto *RC = MRI.getRegClassOrNull(Reg)) {
     MRI.setRegClass(NewReg, RC);
+  } else {
+    MRI.setRegClass(NewReg, &SPIRV::IDRegClass);
+    MRI.setRegClass(Reg, &SPIRV::IDRegClass);
+  }
   SpirvTy = SpirvTy ? SpirvTy : GR->getOrCreateSPIRVType(Ty, MIB);
   GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
   // This is to make it convenient for Legalizer to get the SPIRVType
@@ -217,7 +224,6 @@ Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
       .addUse(GR->getSPIRVTypeID(SpirvTy))
       .setMIFlags(Flags);
   Def->getOperand(0).setReg(NewReg);
-  MRI.setRegClass(Reg, &SPIRV::ANYIDRegClass);
   return NewReg;
 }
 } // namespace llvm