[HLSL] Further improve to numthreads diagnostics
authorChris Bieneman <chris.bieneman@me.com>
Thu, 31 Mar 2022 14:38:47 +0000 (09:38 -0500)
committerChris Bieneman <chris.bieneman@me.com>
Thu, 31 Mar 2022 16:34:01 +0000 (11:34 -0500)
This adds diagnostics for conflicting attributes on the same
declarataion, conflicting attributes on a forward and final
declaration, and defines a more narrowly scoped HLSLEntry attribute
target.

Big shout out to @aaron.ballman for the great feedback and review on
this!

clang/include/clang/Basic/Attr.td
clang/include/clang/Basic/DiagnosticSemaKinds.td
clang/include/clang/Sema/Sema.h
clang/lib/Sema/SemaDecl.cpp
clang/lib/Sema/SemaDeclAttr.cpp
clang/test/SemaHLSL/num_threads.hlsl

index 97b1027742f60adc39a8cd66bbe2e6ae090ea044..4789493399ec2e0f0e6c0e1c297608fbd5ff9b0a 100644 (file)
@@ -126,8 +126,10 @@ def FunctionTmpl
                                  FunctionDecl::TK_FunctionTemplate}],
                     "function templates">;
 
-def GlobalFunction
-    : SubsetSubject<Function, [{S->isGlobal()}], "global functions">;
+def HLSLEntry
+    : SubsetSubject<Function,
+                    [{S->isExternallyVisible() && !isa<CXXMethodDecl>(S)}],
+                    "global functions">;
 
 def ClassTmpl : SubsetSubject<CXXRecord, [{S->getDescribedClassTemplate()}],
                               "class templates">;
@@ -3946,7 +3948,7 @@ def Error : InheritableAttr {
 def HLSLNumThreads: InheritableAttr {
   let Spellings = [Microsoft<"numthreads">];
   let Args = [IntArgument<"X">, IntArgument<"Y">, IntArgument<"Z">];
-  let Subjects = SubjectList<[GlobalFunction]>;
+  let Subjects = SubjectList<[HLSLEntry]>;
   let LangOpts = [HLSL];
   let Documentation = [NumThreadsDocs];
 }
index a272cb741270f8cef2f07220e7a1095e9de67c49..aec172c39ed9a8ef8d09358f0e0d7c0e2baab671 100644 (file)
@@ -11570,6 +11570,7 @@ def err_hlsl_attr_unsupported_in_stage : Error<"attribute %0 is unsupported in %
 
 def err_hlsl_numthreads_argument_oor : Error<"argument '%select{X|Y|Z}0' to numthreads attribute cannot exceed %1">;
 def err_hlsl_numthreads_invalid : Error<"total number of threads cannot exceed %0">;
+def err_hlsl_attribute_param_mismatch : Error<"%0 attribute parameters do not match the previous declaration">;
 
 } // end of sema component.
 
index 6523c3001c294afcf3599b7a869bc02bb31dbb12..c0ad55d52bb31558cd64480b0d19e560e5037e6b 100644 (file)
@@ -3471,6 +3471,9 @@ public:
   EnforceTCBLeafAttr *mergeEnforceTCBLeafAttr(Decl *D,
                                               const EnforceTCBLeafAttr &AL);
   BTFDeclTagAttr *mergeBTFDeclTagAttr(Decl *D, const BTFDeclTagAttr &AL);
+  HLSLNumThreadsAttr *mergeHLSLNumThreadsAttr(Decl *D,
+                                              const AttributeCommonInfo &AL,
+                                              int X, int Y, int Z);
 
   void mergeDeclAttributes(NamedDecl *New, Decl *Old,
                            AvailabilityMergeKind AMK = AMK_Redeclaration);
index b913f805bc877322c01b447cfaabc568ded0f126..1e25346fde6f680fc0a0480320af483cca0f647a 100644 (file)
@@ -2770,6 +2770,9 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
     NewAttr = S.mergeEnforceTCBLeafAttr(D, *TCBLA);
   else if (const auto *BTFA = dyn_cast<BTFDeclTagAttr>(Attr))
     NewAttr = S.mergeBTFDeclTagAttr(D, *BTFA);
+  else if (const auto *NT = dyn_cast<HLSLNumThreadsAttr>(Attr))
+    NewAttr =
+        S.mergeHLSLNumThreadsAttr(D, *NT, NT->getX(), NT->getY(), NT->getZ());
   else if (Attr->shouldInheritEvenIfAlreadyPresent() || !DeclHasAttr(D, Attr))
     NewAttr = cast<InheritableAttr>(Attr->clone(S.Context));
 
index 87e16635f302126766237dd0d642113673733dcf..4b5201db7517c195400a7bea18e294a9130b673a 100644 (file)
@@ -6892,7 +6892,22 @@ static void handleHLSLNumThreadsAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
     return;
   }
 
-  D->addAttr(::new (S.Context) HLSLNumThreadsAttr(S.Context, AL, X, Y, Z));
+  HLSLNumThreadsAttr *NewAttr = S.mergeHLSLNumThreadsAttr(D, AL, X, Y, Z);
+  if (NewAttr)
+    D->addAttr(NewAttr);
+}
+
+HLSLNumThreadsAttr *Sema::mergeHLSLNumThreadsAttr(Decl *D,
+                                                  const AttributeCommonInfo &AL,
+                                                  int X, int Y, int Z) {
+  if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) {
+    if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) {
+      Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
+      Diag(AL.getLoc(), diag::note_conflicting_attribute);
+    }
+    return nullptr;
+  }
+  return ::new (Context) HLSLNumThreadsAttr(Context, AL, X, Y, Z);
 }
 
 static void handleMSInheritanceAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
index cf9e24804a0932d8ee52531768ef4aa4710273e2..f93e67d54257c83a6a2d8ea5500a3b7b1d62b831 100644 (file)
 
 #if __SHADER_TARGET_STAGE == __SHADER_STAGE_COMPUTE || __SHADER_TARGET_STAGE == __SHADER_STAGE_MESH || __SHADER_TARGET_STAGE == __SHADER_STAGE_AMPLIFICATION || __SHADER_TARGET_STAGE == __SHADER_STAGE_LIBRARY
 #ifdef FAIL
+
+// expected-warning@+1 {{'numthreads' attribute only applies to global functions}}
+[numthreads(1,1,1)]
+struct Fido {
+  // expected-warning@+1 {{'numthreads' attribute only applies to global functions}}
+  [numthreads(1,1,1)]
+  void wag() {}
+
+  // expected-warning@+1 {{'numthreads' attribute only applies to global functions}}
+  [numthreads(1,1,1)]
+  static void oops() {}
+};
+
+// expected-warning@+1 {{'numthreads' attribute only applies to global functions}}
+[numthreads(1,1,1)]
+static void oops() {}
+
+namespace spec {
+// expected-warning@+1 {{'numthreads' attribute only applies to global functions}}
+[numthreads(1,1,1)]
+static void oops() {}
+}
+
+// expected-error@+1 {{'numthreads' attribute parameters do not match the previous declaration}}
+[numthreads(1,1,1)]
+// expected-note@+1 {{conflicting attribute is here}}
+[numthreads(2,2,1)]
+int doubledUp() {
+  return 1;
+}
+
+// expected-note@+1 {{conflicting attribute is here}}
+[numthreads(1,1,1)]
+int forwardDecl();
+
+// expected-error@+1 {{'numthreads' attribute parameters do not match the previous declaration}}
+[numthreads(2,2,1)]
+int forwardDecl() {
+  return 1;
+}
+
 #if __SHADER_TARGET_MAJOR == 6
 // expected-error@+1 {{'numthreads' attribute requires exactly 3 arguments}}
 [numthreads]
 [numthreads(1,2,2)]
 // expected-error@+1 {{total number of threads cannot exceed 768}}
 [numthreads(1024,1,1)]
-#endif
-#endif
+#endif // __SHADER_TARGET_MAJOR
+#endif // FAIL
 // CHECK: HLSLNumThreadsAttr 0x{{[0-9a-fA-F]+}} <line:{{[0-9]+}}:2, col:18> 1 2 1
 [numthreads(1,2,1)]
 int entry() {
  return 1;
 }
 
-// expected-warning@+1 {{'numthreads' attribute only applies to global functions}}
-[numthreads(1,1,1)]
-struct Fido {
-  // expected-warning@+1 {{'numthreads' attribute only applies to global functions}}
-  [numthreads(1,1,1)]
-  void wag() {}
-};
+// Because these two attributes match, they should both appear in the AST
+[numthreads(2,2,1)]
+// CHECK: HLSLNumThreadsAttr 0x{{[0-9a-fA-F]+}} <line:90:2, col:18> 2 2 1
+int secondFn();
 
-#else
+[numthreads(2,2,1)]
+// CHECK: HLSLNumThreadsAttr 0x{{[0-9a-fA-F]+}} <line:94:2, col:18> 2 2 1
+int secondFn() {
+  return 1;
+}
+
+
+#else // Vertex and Pixel only beyond here
 // expected-error-re@+1 {{attribute 'numthreads' is unsupported in {{[A-Za-z]+}} shaders, requires Compute, Amplification, Mesh or Library}}
 [numthreads(1,1,1)]
 int main() {
  return 1;
 }
+
 #endif