Remove STJ dependency on IncrementalValuesProvider.Collect() (#86616)
authorEirik Tsarpalis <eirik.tsarpalis@gmail.com>
Tue, 23 May 2023 19:38:17 +0000 (20:38 +0100)
committerGitHub <noreply@github.com>
Tue, 23 May 2023 19:38:17 +0000 (20:38 +0100)
* Remove dependency on IncrementalValuesProvider.Collect()

* Address feedback.

src/libraries/System.Text.Json/gen/Helpers/KnownTypeSymbols.cs
src/libraries/System.Text.Json/gen/JsonSourceGenerator.DiagnosticDescriptors.cs [new file with mode: 0644]
src/libraries/System.Text.Json/gen/JsonSourceGenerator.Emitter.cs
src/libraries/System.Text.Json/gen/JsonSourceGenerator.Parser.cs
src/libraries/System.Text.Json/gen/JsonSourceGenerator.Roslyn3.11.cs
src/libraries/System.Text.Json/gen/JsonSourceGenerator.Roslyn4.0.cs
src/libraries/System.Text.Json/gen/Model/SourceGenerationSpec.cs [deleted file]
src/libraries/System.Text.Json/gen/System.Text.Json.SourceGeneration.targets
src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/CompilationHelper.cs
src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/JsonSourceGeneratorIncrementalTests.cs

index 412a33b..2c32d00 100644 (file)
@@ -11,86 +11,208 @@ using Microsoft.CodeAnalysis.DotnetRuntime.Extensions;
 
 namespace System.Text.Json.SourceGeneration
 {
-    internal sealed class KnownTypeSymbols(Compilation compilation)
+    internal sealed class KnownTypeSymbols
     {
-#pragma warning disable CA1822 // Mark members as static false positive with primary constructors.
-        public Compilation Compilation => compilation!;
-
-        public readonly INamedTypeSymbol? IListOfTType = compilation!.GetBestTypeByMetadataName(typeof(IList<>));
-        public readonly INamedTypeSymbol? ICollectionOfTType = compilation!.GetBestTypeByMetadataName(typeof(ICollection<>));
-        public readonly INamedTypeSymbol? IEnumerableType = compilation!.GetBestTypeByMetadataName(typeof(IEnumerable));
-        public readonly INamedTypeSymbol? IEnumerableOfTType = compilation!.GetBestTypeByMetadataName(typeof(IEnumerable<>));
-
-        public readonly INamedTypeSymbol? ListOfTType = compilation!.GetBestTypeByMetadataName(typeof(List<>));
-        public readonly INamedTypeSymbol? DictionaryOfTKeyTValueType = compilation!.GetBestTypeByMetadataName(typeof(Dictionary<,>));
-        public readonly INamedTypeSymbol? IAsyncEnumerableOfTType = compilation!.GetBestTypeByMetadataName("System.Collections.Generic.IAsyncEnumerable`1");
-        public readonly INamedTypeSymbol? IDictionaryOfTKeyTValueType = compilation!.GetBestTypeByMetadataName(typeof(IDictionary<,>));
-        public readonly INamedTypeSymbol? IReadonlyDictionaryOfTKeyTValueType = compilation!.GetBestTypeByMetadataName(typeof(IReadOnlyDictionary<,>));
-        public readonly INamedTypeSymbol? ISetOfTType = compilation!.GetBestTypeByMetadataName(typeof(ISet<>));
-        public readonly INamedTypeSymbol? StackOfTType = compilation!.GetBestTypeByMetadataName(typeof(Stack<>));
-        public readonly INamedTypeSymbol? QueueOfTType = compilation!.GetBestTypeByMetadataName(typeof(Queue<>));
-        public readonly INamedTypeSymbol? ConcurrentStackType = compilation!.GetBestTypeByMetadataName(typeof(ConcurrentStack<>));
-        public readonly INamedTypeSymbol? ConcurrentQueueType = compilation!.GetBestTypeByMetadataName(typeof(ConcurrentQueue<>));
-        public readonly INamedTypeSymbol? IDictionaryType = compilation!.GetBestTypeByMetadataName(typeof(IDictionary));
-        public readonly INamedTypeSymbol? IListType = compilation!.GetBestTypeByMetadataName(typeof(IList));
-        public readonly INamedTypeSymbol? StackType = compilation!.GetBestTypeByMetadataName(typeof(Stack));
-        public readonly INamedTypeSymbol? QueueType = compilation!.GetBestTypeByMetadataName(typeof(Queue));
-        public readonly INamedTypeSymbol? KeyValuePair = compilation!.GetBestTypeByMetadataName(typeof(KeyValuePair<,>));
-
-        public readonly INamedTypeSymbol? ImmutableArrayType = compilation!.GetBestTypeByMetadataName(typeof(ImmutableArray<>));
-        public readonly INamedTypeSymbol? ImmutableListType = compilation!.GetBestTypeByMetadataName(typeof(ImmutableList<>));
-        public readonly INamedTypeSymbol? IImmutableListType = compilation!.GetBestTypeByMetadataName(typeof(IImmutableList<>));
-        public readonly INamedTypeSymbol? ImmutableStackType = compilation!.GetBestTypeByMetadataName(typeof(ImmutableStack<>));
-        public readonly INamedTypeSymbol? IImmutableStackType = compilation!.GetBestTypeByMetadataName(typeof(IImmutableStack<>));
-        public readonly INamedTypeSymbol? ImmutableQueueType = compilation!.GetBestTypeByMetadataName(typeof(ImmutableQueue<>));
-        public readonly INamedTypeSymbol? IImmutableQueueType = compilation!.GetBestTypeByMetadataName(typeof(IImmutableQueue<>));
-        public readonly INamedTypeSymbol? ImmutableSortedType = compilation!.GetBestTypeByMetadataName(typeof(ImmutableSortedSet<>));
-        public readonly INamedTypeSymbol? ImmutableHashSetType = compilation!.GetBestTypeByMetadataName(typeof(ImmutableHashSet<>));
-        public readonly INamedTypeSymbol? IImmutableSetType = compilation!.GetBestTypeByMetadataName(typeof(IImmutableSet<>));
-        public readonly INamedTypeSymbol? ImmutableDictionaryType = compilation!.GetBestTypeByMetadataName(typeof(ImmutableDictionary<,>));
-        public readonly INamedTypeSymbol? ImmutableSortedDictionaryType = compilation!.GetBestTypeByMetadataName(typeof(ImmutableSortedDictionary<,>));
-        public readonly INamedTypeSymbol? IImmutableDictionaryType = compilation!.GetBestTypeByMetadataName(typeof(IImmutableDictionary<,>));
-
-        public readonly INamedTypeSymbol ObjectType = compilation!.GetSpecialType(SpecialType.System_Object);
-        public readonly INamedTypeSymbol StringType = compilation!.GetSpecialType(SpecialType.System_String);
-
-        public readonly INamedTypeSymbol? DateTimeOffsetType = compilation!.GetBestTypeByMetadataName(typeof(DateTimeOffset));
-        public readonly INamedTypeSymbol? TimeSpanType = compilation!.GetBestTypeByMetadataName(typeof(TimeSpan));
-        public readonly INamedTypeSymbol? DateOnlyType = compilation!.GetBestTypeByMetadataName("System.DateOnly");
-        public readonly INamedTypeSymbol? TimeOnlyType = compilation!.GetBestTypeByMetadataName("System.TimeOnly");
-        public readonly IArrayTypeSymbol? ByteArrayType = compilation!.CreateArrayTypeSymbol(compilation.GetSpecialType(SpecialType.System_Byte), rank: 1);
-        public readonly INamedTypeSymbol? GuidType = compilation!.GetBestTypeByMetadataName(typeof(Guid));
-        public readonly INamedTypeSymbol? UriType = compilation!.GetBestTypeByMetadataName(typeof(Uri));
-        public readonly INamedTypeSymbol? VersionType = compilation!.GetBestTypeByMetadataName(typeof(Version));
+        public KnownTypeSymbols(Compilation compilation)
+            => Compilation = compilation;
+
+        public Compilation Compilation { get; }
+
+        // Caches a set of types with built-in converter support. Populated by the Parser class.
+        public HashSet<ITypeSymbol>? BuiltInSupportTypes { get; set; }
+
+        public INamedTypeSymbol? IListOfTType => GetOrResolveType(typeof(IList<>), ref _IListOfTType);
+        private Option<INamedTypeSymbol?> _IListOfTType;
+
+        public INamedTypeSymbol? ICollectionOfTType => GetOrResolveType(typeof(ICollection<>), ref _ICollectionOfTType);
+        private Option<INamedTypeSymbol?> _ICollectionOfTType;
+
+        public INamedTypeSymbol? IEnumerableType => GetOrResolveType(typeof(IEnumerable), ref _IEnumerableType);
+        private Option<INamedTypeSymbol?> _IEnumerableType;
+
+        public INamedTypeSymbol? IEnumerableOfTType => GetOrResolveType(typeof(IEnumerable<>), ref _IEnumerableOfTType);
+        private Option<INamedTypeSymbol?> _IEnumerableOfTType;
+
+        public INamedTypeSymbol? ListOfTType => GetOrResolveType(typeof(List<>), ref _ListOfTType);
+        private Option<INamedTypeSymbol?> _ListOfTType;
+
+        public INamedTypeSymbol? DictionaryOfTKeyTValueType => GetOrResolveType(typeof(Dictionary<,>), ref _DictionaryOfTKeyTValueType);
+        private Option<INamedTypeSymbol?> _DictionaryOfTKeyTValueType;
+
+        public INamedTypeSymbol? IAsyncEnumerableOfTType => GetOrResolveType("System.Collections.Generic.IAsyncEnumerable`1", ref _AsyncEnumerableOfTType);
+        private Option<INamedTypeSymbol?> _AsyncEnumerableOfTType;
+
+        public INamedTypeSymbol? IDictionaryOfTKeyTValueType => GetOrResolveType(typeof(IDictionary<,>), ref _IDictionaryOfTKeyTValueType);
+        private Option<INamedTypeSymbol?> _IDictionaryOfTKeyTValueType;
+
+        public INamedTypeSymbol? IReadonlyDictionaryOfTKeyTValueType => GetOrResolveType(typeof(IReadOnlyDictionary<,>), ref _IReadonlyDictionaryOfTKeyTValueType);
+        private Option<INamedTypeSymbol?> _IReadonlyDictionaryOfTKeyTValueType;
+
+        public INamedTypeSymbol? ISetOfTType => GetOrResolveType(typeof(ISet<>), ref _ISetOfTType);
+        private Option<INamedTypeSymbol?> _ISetOfTType;
+
+        public INamedTypeSymbol? StackOfTType => GetOrResolveType(typeof(Stack<>), ref _StackOfTType);
+        private Option<INamedTypeSymbol?> _StackOfTType;
+
+        public INamedTypeSymbol? QueueOfTType => GetOrResolveType(typeof(Queue<>), ref _QueueOfTType);
+        private Option<INamedTypeSymbol?> _QueueOfTType;
+
+        public INamedTypeSymbol? ConcurrentStackType => GetOrResolveType(typeof(ConcurrentStack<>), ref _ConcurrentStackType);
+        private Option<INamedTypeSymbol?> _ConcurrentStackType;
+
+        public INamedTypeSymbol? ConcurrentQueueType => GetOrResolveType(typeof(ConcurrentQueue<>), ref _ConcurrentQueueType);
+        private Option<INamedTypeSymbol?> _ConcurrentQueueType;
+
+        public INamedTypeSymbol? IDictionaryType => GetOrResolveType(typeof(IDictionary), ref _IDictionaryType);
+        private Option<INamedTypeSymbol?> _IDictionaryType;
+
+        public INamedTypeSymbol? IListType => GetOrResolveType(typeof(IList), ref _IListType);
+        private Option<INamedTypeSymbol?> _IListType;
+
+        public INamedTypeSymbol? StackType => GetOrResolveType(typeof(Stack), ref _StackType);
+        private Option<INamedTypeSymbol?> _StackType;
+
+        public INamedTypeSymbol? QueueType => GetOrResolveType(typeof(Queue), ref _QueueType);
+        private Option<INamedTypeSymbol?> _QueueType;
+
+        public INamedTypeSymbol? KeyValuePair => GetOrResolveType(typeof(KeyValuePair<,>), ref _KeyValuePair);
+        private Option<INamedTypeSymbol?> _KeyValuePair;
+
+        public INamedTypeSymbol? ImmutableArrayType => GetOrResolveType(typeof(ImmutableArray<>), ref _ImmutableArrayType);
+        private Option<INamedTypeSymbol?> _ImmutableArrayType;
+
+        public INamedTypeSymbol? ImmutableListType => GetOrResolveType(typeof(ImmutableList<>), ref _ImmutableListType);
+        private Option<INamedTypeSymbol?> _ImmutableListType;
+
+        public INamedTypeSymbol? IImmutableListType => GetOrResolveType(typeof(IImmutableList<>), ref _IImmutableListType);
+        private Option<INamedTypeSymbol?> _IImmutableListType;
+
+        public INamedTypeSymbol? ImmutableStackType => GetOrResolveType(typeof(ImmutableStack<>), ref _ImmutableStackType);
+        private Option<INamedTypeSymbol?> _ImmutableStackType;
+
+        public INamedTypeSymbol? IImmutableStackType => GetOrResolveType(typeof(IImmutableStack<>), ref _IImmutableStackType);
+        private Option<INamedTypeSymbol?> _IImmutableStackType;
+
+        public INamedTypeSymbol? ImmutableQueueType => GetOrResolveType(typeof(ImmutableQueue<>), ref _ImmutableQueueType);
+        private Option<INamedTypeSymbol?> _ImmutableQueueType;
+
+        public INamedTypeSymbol? IImmutableQueueType => GetOrResolveType(typeof(IImmutableQueue<>), ref _IImmutableQueueType);
+        private Option<INamedTypeSymbol?> _IImmutableQueueType;
+
+        public INamedTypeSymbol? ImmutableSortedType => GetOrResolveType(typeof(ImmutableSortedSet<>), ref _ImmutableSortedType);
+        private Option<INamedTypeSymbol?> _ImmutableSortedType;
+
+        public INamedTypeSymbol? ImmutableHashSetType => GetOrResolveType(typeof(ImmutableHashSet<>), ref _ImmutableHashSetType);
+        private Option<INamedTypeSymbol?> _ImmutableHashSetType;
+
+        public INamedTypeSymbol? IImmutableSetType => GetOrResolveType(typeof(IImmutableSet<>), ref _IImmutableSetType);
+        private Option<INamedTypeSymbol?> _IImmutableSetType;
+
+        public INamedTypeSymbol? ImmutableDictionaryType => GetOrResolveType(typeof(ImmutableDictionary<,>), ref _ImmutableDictionaryType);
+        private Option<INamedTypeSymbol?> _ImmutableDictionaryType;
+
+        public INamedTypeSymbol? ImmutableSortedDictionaryType => GetOrResolveType(typeof(ImmutableSortedDictionary<,>), ref _ImmutableSortedDictionaryType);
+        private Option<INamedTypeSymbol?> _ImmutableSortedDictionaryType;
+
+        public INamedTypeSymbol? IImmutableDictionaryType => GetOrResolveType(typeof(IImmutableDictionary<,>), ref _IImmutableDictionaryType);
+        private Option<INamedTypeSymbol?> _IImmutableDictionaryType;
+
+        public INamedTypeSymbol ObjectType => _ObjectType ??= Compilation.GetSpecialType(SpecialType.System_Object);
+        private INamedTypeSymbol? _ObjectType;
+
+        public INamedTypeSymbol StringType => _StringType ??= Compilation.GetSpecialType(SpecialType.System_String);
+        private INamedTypeSymbol? _StringType;
+
+        public INamedTypeSymbol? DateTimeOffsetType => GetOrResolveType(typeof(DateTimeOffset), ref _DateTimeOffsetType);
+        private Option<INamedTypeSymbol?> _DateTimeOffsetType;
+
+        public INamedTypeSymbol? TimeSpanType => GetOrResolveType(typeof(TimeSpan), ref _TimeSpanType);
+        private Option<INamedTypeSymbol?> _TimeSpanType;
+
+        public INamedTypeSymbol? DateOnlyType => GetOrResolveType("System.DateOnly", ref _DateOnlyType);
+        private Option<INamedTypeSymbol?> _DateOnlyType;
+
+        public INamedTypeSymbol? TimeOnlyType => GetOrResolveType("System.TimeOnly", ref _TimeOnlyType);
+        private Option<INamedTypeSymbol?> _TimeOnlyType;
+
+        public IArrayTypeSymbol? ByteArrayType => _ByteArrayType.HasValue
+            ? _ByteArrayType.Value
+            : (_ByteArrayType = new(Compilation.CreateArrayTypeSymbol(Compilation.GetSpecialType(SpecialType.System_Byte), rank: 1))).Value;
+
+        private Option<IArrayTypeSymbol?> _ByteArrayType;
+
+        public INamedTypeSymbol? GuidType => GetOrResolveType(typeof(Guid), ref _GuidType);
+        private Option<INamedTypeSymbol?> _GuidType;
+
+        public INamedTypeSymbol? UriType => GetOrResolveType(typeof(Uri), ref _UriType);
+        private Option<INamedTypeSymbol?> _UriType;
+
+        public INamedTypeSymbol? VersionType => GetOrResolveType(typeof(Version), ref _VersionType);
+        private Option<INamedTypeSymbol?> _VersionType;
 
         // System.Text.Json types
-        public readonly INamedTypeSymbol? JsonConverterType = compilation!.GetBestTypeByMetadataName("System.Text.Json.Serialization.JsonConverter");
-        public readonly INamedTypeSymbol? JsonSerializerContextType = compilation.GetBestTypeByMetadataName("System.Text.Json.Serialization.JsonSerializerContext");
-        public readonly INamedTypeSymbol? JsonSerializableAttributeType = compilation.GetBestTypeByMetadataName("System.Text.Json.Serialization.JsonSerializableAttribute");
+        public INamedTypeSymbol? JsonConverterType => GetOrResolveType("System.Text.Json.Serialization.JsonConverter", ref _JsonConverterType);
+        private Option<INamedTypeSymbol?> _JsonConverterType;
+
+        public INamedTypeSymbol? JsonSerializerContextType => GetOrResolveType("System.Text.Json.Serialization.JsonSerializerContext", ref _JsonSerializerContextType);
+        private Option<INamedTypeSymbol?> _JsonSerializerContextType;
+
+        public INamedTypeSymbol? JsonSerializableAttributeType => GetOrResolveType("System.Text.Json.Serialization.JsonSerializableAttribute", ref _JsonSerializableAttributeType);
+        private Option<INamedTypeSymbol?> _JsonSerializableAttributeType;
 
-        public readonly INamedTypeSymbol? JsonDocumentType = compilation!.GetBestTypeByMetadataName("System.Text.Json.JsonDocument");
-        public readonly INamedTypeSymbol? JsonElementType = compilation!.GetBestTypeByMetadataName("System.Text.Json.JsonElement");
+        public INamedTypeSymbol? JsonDocumentType => GetOrResolveType("System.Text.Json.JsonDocument", ref _JsonDocumentType);
+        private Option<INamedTypeSymbol?> _JsonDocumentType;
 
-        public readonly INamedTypeSymbol? JsonNodeType = compilation!.GetBestTypeByMetadataName("System.Text.Json.Nodes.JsonNode");
-        public readonly INamedTypeSymbol? JsonValueType = compilation!.GetBestTypeByMetadataName("System.Text.Json.Nodes.JsonValue");
-        public readonly INamedTypeSymbol? JsonObjectType = compilation!.GetBestTypeByMetadataName("System.Text.Json.Nodes.JsonObject");
-        public readonly INamedTypeSymbol? JsonArrayType = compilation!.GetBestTypeByMetadataName("System.Text.Json.Nodes.JsonArray");
+        public INamedTypeSymbol? JsonElementType => GetOrResolveType("System.Text.Json.JsonElement", ref _JsonElementType);
+        private Option<INamedTypeSymbol?> _JsonElementType;
+
+        public INamedTypeSymbol? JsonNodeType => GetOrResolveType("System.Text.Json.Nodes.JsonNode", ref _JsonNodeType);
+        private Option<INamedTypeSymbol?> _JsonNodeType;
+
+        public INamedTypeSymbol? JsonValueType => GetOrResolveType("System.Text.Json.Nodes.JsonValue", ref _JsonValueType);
+        private Option<INamedTypeSymbol?> _JsonValueType;
+
+        public INamedTypeSymbol? JsonObjectType => GetOrResolveType("System.Text.Json.Nodes.JsonObject", ref _JsonObjectType);
+        private Option<INamedTypeSymbol?> _JsonObjectType;
+
+        public INamedTypeSymbol? JsonArrayType => GetOrResolveType("System.Text.Json.Nodes.JsonArray", ref _JsonArrayType);
+        private Option<INamedTypeSymbol?> _JsonArrayType;
 
         // System.Text.Json attributes
-        public readonly INamedTypeSymbol? JsonConverterAttributeType = compilation!.GetBestTypeByMetadataName("System.Text.Json.Serialization.JsonConverterAttribute");
-        public readonly INamedTypeSymbol? JsonDerivedTypeAttributeType = compilation!.GetBestTypeByMetadataName("System.Text.Json.Serialization.JsonDerivedTypeAttribute");
-        public readonly INamedTypeSymbol? JsonNumberHandlingAttributeType = compilation!.GetBestTypeByMetadataName("System.Text.Json.Serialization.JsonNumberHandlingAttribute");
-        public readonly INamedTypeSymbol? JsonObjectCreationHandlingAttributeType = compilation!.GetBestTypeByMetadataName("System.Text.Json.Serialization.JsonObjectCreationHandlingAttribute");
-        public readonly INamedTypeSymbol? JsonSourceGenerationOptionsAttributeType = compilation.GetBestTypeByMetadataName("System.Text.Json.Serialization.JsonSourceGenerationOptionsAttribute");
-        public readonly INamedTypeSymbol? JsonUnmappedMemberHandlingAttributeType = compilation!.GetBestTypeByMetadataName("System.Text.Json.Serialization.JsonUnmappedMemberHandlingAttribute");
+        public INamedTypeSymbol? JsonConverterAttributeType => GetOrResolveType("System.Text.Json.Serialization.JsonConverterAttribute", ref _JsonConverterAttributeType);
+        private Option<INamedTypeSymbol?> _JsonConverterAttributeType;
+
+        public INamedTypeSymbol? JsonDerivedTypeAttributeType => GetOrResolveType("System.Text.Json.Serialization.JsonDerivedTypeAttribute", ref _JsonDerivedTypeAttributeType);
+        private Option<INamedTypeSymbol?> _JsonDerivedTypeAttributeType;
+
+        public INamedTypeSymbol? JsonNumberHandlingAttributeType => GetOrResolveType("System.Text.Json.Serialization.JsonNumberHandlingAttribute", ref _JsonNumberHandlingAttributeType);
+        private Option<INamedTypeSymbol?> _JsonNumberHandlingAttributeType;
+
+        public INamedTypeSymbol? JsonObjectCreationHandlingAttributeType => GetOrResolveType("System.Text.Json.Serialization.JsonObjectCreationHandlingAttribute", ref _JsonObjectCreationHandlingAttributeType);
+        private Option<INamedTypeSymbol?> _JsonObjectCreationHandlingAttributeType;
+
+        public INamedTypeSymbol? JsonSourceGenerationOptionsAttributeType => GetOrResolveType("System.Text.Json.Serialization.JsonSourceGenerationOptionsAttribute", ref _JsonSourceGenerationOptionsAttributeType);
+        private Option<INamedTypeSymbol?> _JsonSourceGenerationOptionsAttributeType;
+
+        public INamedTypeSymbol? JsonUnmappedMemberHandlingAttributeType => GetOrResolveType("System.Text.Json.Serialization.JsonUnmappedMemberHandlingAttribute", ref _JsonUnmappedMemberHandlingAttributeType);
+        private Option<INamedTypeSymbol?> _JsonUnmappedMemberHandlingAttributeType;
 
         // Unsupported types
-        public readonly INamedTypeSymbol? DelegateType = compilation!.GetSpecialType(SpecialType.System_Delegate);
-        public readonly INamedTypeSymbol? MemberInfoType = compilation!.GetBestTypeByMetadataName(typeof(MemberInfo));
-        public readonly INamedTypeSymbol? SerializationInfoType = compilation!.GetBestTypeByMetadataName(typeof(Runtime.Serialization.SerializationInfo));
-        public readonly INamedTypeSymbol? IntPtrType = compilation!.GetBestTypeByMetadataName(typeof(IntPtr));
-        public readonly INamedTypeSymbol? UIntPtrType = compilation!.GetBestTypeByMetadataName(typeof(UIntPtr));
-#pragma warning restore CA1822 // Mark members as static false positive with primary constructors.
+        public INamedTypeSymbol? DelegateType => _DelegateType ??= Compilation.GetSpecialType(SpecialType.System_Delegate);
+        private INamedTypeSymbol? _DelegateType;
+
+        public INamedTypeSymbol? MemberInfoType => GetOrResolveType(typeof(MemberInfo), ref _MemberInfoType);
+        private Option<INamedTypeSymbol?> _MemberInfoType;
+
+        public INamedTypeSymbol? SerializationInfoType => GetOrResolveType(typeof(Runtime.Serialization.SerializationInfo), ref _SerializationInfoType);
+        private Option<INamedTypeSymbol?> _SerializationInfoType;
+
+        public INamedTypeSymbol? IntPtrType => GetOrResolveType(typeof(IntPtr), ref _IntPtrType);
+        private Option<INamedTypeSymbol?> _IntPtrType;
+
+        public INamedTypeSymbol? UIntPtrType => GetOrResolveType(typeof(UIntPtr), ref _UIntPtrType);
+        private Option<INamedTypeSymbol?> _UIntPtrType;
+
 
         public bool IsImmutableEnumerableType(ITypeSymbol type, out string? factoryTypeFullName)
         {
@@ -171,5 +293,32 @@ namespace System.Text.Json.SourceGeneration
             factoryTypeFullName = null;
             return false;
         }
+
+        private INamedTypeSymbol? GetOrResolveType(Type type, ref Option<INamedTypeSymbol?> field)
+            => GetOrResolveType(type.FullName!, ref field);
+
+        private INamedTypeSymbol? GetOrResolveType(string fullyQualifiedName, ref Option<INamedTypeSymbol?> field)
+        {
+            if (field.HasValue)
+            {
+                return field.Value;
+            }
+
+            INamedTypeSymbol? type = Compilation.GetBestTypeByMetadataName(fullyQualifiedName);
+            field = new(type);
+            return type;
+        }
+
+        private readonly struct Option<T>
+        {
+            public readonly bool HasValue;
+            public readonly T Value;
+
+            public Option(T value)
+            {
+                HasValue = true;
+                Value = value;
+            }
+        }
     }
 }
diff --git a/src/libraries/System.Text.Json/gen/JsonSourceGenerator.DiagnosticDescriptors.cs b/src/libraries/System.Text.Json/gen/JsonSourceGenerator.DiagnosticDescriptors.cs
new file mode 100644 (file)
index 0000000..17d7dd5
--- /dev/null
@@ -0,0 +1,77 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using Microsoft.CodeAnalysis;
+
+namespace System.Text.Json.SourceGeneration
+{
+    public sealed partial class JsonSourceGenerator
+    {
+        internal static class DiagnosticDescriptors
+        {
+            public static DiagnosticDescriptor TypeNotSupported { get; } = new DiagnosticDescriptor(
+                id: "SYSLIB1030",
+                title: new LocalizableResourceString(nameof(SR.TypeNotSupportedTitle), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
+                messageFormat: new LocalizableResourceString(nameof(SR.TypeNotSupportedMessageFormat), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
+                category: JsonConstants.SystemTextJsonSourceGenerationName,
+                defaultSeverity: DiagnosticSeverity.Warning,
+                isEnabledByDefault: true);
+
+            public static DiagnosticDescriptor DuplicateTypeName { get; } = new DiagnosticDescriptor(
+                id: "SYSLIB1031",
+                title: new LocalizableResourceString(nameof(SR.DuplicateTypeNameTitle), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
+                messageFormat: new LocalizableResourceString(nameof(SR.DuplicateTypeNameMessageFormat), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
+                category: JsonConstants.SystemTextJsonSourceGenerationName,
+                defaultSeverity: DiagnosticSeverity.Warning,
+                isEnabledByDefault: true);
+
+            public static DiagnosticDescriptor ContextClassesMustBePartial { get; } = new DiagnosticDescriptor(
+                id: "SYSLIB1032",
+                title: new LocalizableResourceString(nameof(SR.ContextClassesMustBePartialTitle), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
+                messageFormat: new LocalizableResourceString(nameof(SR.ContextClassesMustBePartialMessageFormat), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
+                category: JsonConstants.SystemTextJsonSourceGenerationName,
+                defaultSeverity: DiagnosticSeverity.Warning,
+                isEnabledByDefault: true);
+
+            public static DiagnosticDescriptor MultipleJsonConstructorAttribute { get; } = new DiagnosticDescriptor(
+                id: "SYSLIB1033",
+                title: new LocalizableResourceString(nameof(SR.MultipleJsonConstructorAttributeTitle), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
+                messageFormat: new LocalizableResourceString(nameof(SR.MultipleJsonConstructorAttributeFormat), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
+                category: JsonConstants.SystemTextJsonSourceGenerationName,
+                defaultSeverity: DiagnosticSeverity.Error,
+                isEnabledByDefault: true);
+
+            public static DiagnosticDescriptor MultipleJsonExtensionDataAttribute { get; } = new DiagnosticDescriptor(
+                id: "SYSLIB1035",
+                title: new LocalizableResourceString(nameof(SR.MultipleJsonExtensionDataAttributeTitle), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
+                messageFormat: new LocalizableResourceString(nameof(SR.MultipleJsonExtensionDataAttributeFormat), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
+                category: JsonConstants.SystemTextJsonSourceGenerationName,
+                defaultSeverity: DiagnosticSeverity.Error,
+                isEnabledByDefault: true);
+
+            public static DiagnosticDescriptor DataExtensionPropertyInvalid { get; } = new DiagnosticDescriptor(
+                id: "SYSLIB1036",
+                title: new LocalizableResourceString(nameof(SR.DataExtensionPropertyInvalidTitle), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
+                messageFormat: new LocalizableResourceString(nameof(SR.DataExtensionPropertyInvalidFormat), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
+                category: JsonConstants.SystemTextJsonSourceGenerationName,
+                defaultSeverity: DiagnosticSeverity.Error,
+                isEnabledByDefault: true);
+
+            public static DiagnosticDescriptor InaccessibleJsonIncludePropertiesNotSupported { get; } = new DiagnosticDescriptor(
+                id: "SYSLIB1038",
+                title: new LocalizableResourceString(nameof(SR.InaccessibleJsonIncludePropertiesNotSupportedTitle), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
+                messageFormat: new LocalizableResourceString(nameof(SR.InaccessibleJsonIncludePropertiesNotSupportedFormat), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
+                category: JsonConstants.SystemTextJsonSourceGenerationName,
+                defaultSeverity: DiagnosticSeverity.Warning,
+                isEnabledByDefault: true);
+
+            public static DiagnosticDescriptor PolymorphismNotSupported { get; } = new DiagnosticDescriptor(
+                id: "SYSLIB1039",
+                title: new LocalizableResourceString(nameof(SR.FastPathPolymorphismNotSupportedTitle), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
+                messageFormat: new LocalizableResourceString(nameof(SR.FastPathPolymorphismNotSupportedMessageFormat), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
+                category: JsonConstants.SystemTextJsonSourceGenerationName,
+                defaultSeverity: DiagnosticSeverity.Warning,
+                isEnabledByDefault: true);
+        }
+    }
+}
index e497bcd..e4e8e5a 100644 (file)
@@ -68,9 +68,6 @@ namespace System.Text.Json.SourceGeneration
             private const string JsonTypeInfoTypeRef = "global::System.Text.Json.Serialization.Metadata.JsonTypeInfo";
             private const string JsonTypeInfoResolverTypeRef = "global::System.Text.Json.Serialization.Metadata.IJsonTypeInfoResolver";
 
-            private readonly JsonSourceGenerationContext _sourceGenerationContext;
-            private readonly SourceGenerationSpec _generationSpec;
-
             /// <summary>
             /// Contains an index from TypeRef to TypeGenerationSpec for the current ContextGenerationSpec.
             /// </summary>
@@ -83,53 +80,43 @@ namespace System.Text.Json.SourceGeneration
             /// </summary>
             private readonly Dictionary<string, string> _propertyNames = new();
 
-            public Emitter(in JsonSourceGenerationContext sourceGenerationContext, SourceGenerationSpec generationSpec)
-            {
-                _sourceGenerationContext = sourceGenerationContext;
-                _generationSpec = generationSpec;
-            }
+            /// <summary>
+            /// The SourceText emit implementation filled by the individual Roslyn versions.
+            /// </summary>
+            private partial void AddSource(string hintName, SourceText sourceText);
 
-            public void Emit()
+            public void Emit(ContextGenerationSpec contextGenerationSpec)
             {
-                foreach (DiagnosticInfo diagnostic in _generationSpec.Diagnostics)
+                Debug.Assert(_typeIndex.Count == 0);
+                Debug.Assert(_propertyNames.Count == 0);
+
+                foreach (TypeGenerationSpec spec in contextGenerationSpec.GeneratedTypes)
                 {
-                    // Report any diagnostics produced by the parser ahead of formatting source code.
-                    _sourceGenerationContext.ReportDiagnostic(diagnostic.CreateDiagnostic());
+                    _typeIndex.Add(spec.TypeRef, spec);
                 }
 
-                foreach (ContextGenerationSpec contextGenerationSpec in _generationSpec.ContextGenerationSpecs)
+                foreach (TypeGenerationSpec typeGenerationSpec in contextGenerationSpec.GeneratedTypes)
                 {
-                    Debug.Assert(_typeIndex.Count == 0);
-                    Debug.Assert(_propertyNames.Count == 0);
-
-                    foreach (TypeGenerationSpec spec in contextGenerationSpec.GeneratedTypes)
+                    SourceText? sourceText = GenerateTypeInfo(contextGenerationSpec, typeGenerationSpec);
+                    if (sourceText != null)
                     {
-                        _typeIndex.Add(spec.TypeRef, spec);
-                    }
-
-                    foreach (TypeGenerationSpec typeGenerationSpec in contextGenerationSpec.GeneratedTypes)
-                    {
-                        SourceText? sourceText = GenerateTypeInfo(contextGenerationSpec, typeGenerationSpec);
-                        if (sourceText != null)
-                        {
-                            _sourceGenerationContext.AddSource($"{contextGenerationSpec.ContextType.Name}.{typeGenerationSpec.TypeInfoPropertyName}.g.cs", sourceText);
-                        }
+                        AddSource($"{contextGenerationSpec.ContextType.Name}.{typeGenerationSpec.TypeInfoPropertyName}.g.cs", sourceText);
                     }
+                }
 
-                    string contextName = contextGenerationSpec.ContextType.Name;
+                string contextName = contextGenerationSpec.ContextType.Name;
 
-                    // Add root context implementation.
-                    _sourceGenerationContext.AddSource($"{contextName}.g.cs", GetRootJsonContextImplementation(contextGenerationSpec));
+                // Add root context implementation.
+                AddSource($"{contextName}.g.cs", GetRootJsonContextImplementation(contextGenerationSpec));
 
-                    // Add GetJsonTypeInfo override implementation.
-                    _sourceGenerationContext.AddSource($"{contextName}.GetJsonTypeInfo.g.cs", GetGetTypeInfoImplementation(contextGenerationSpec));
+                // Add GetJsonTypeInfo override implementation.
+                AddSource($"{contextName}.GetJsonTypeInfo.g.cs", GetGetTypeInfoImplementation(contextGenerationSpec));
 
-                    // Add property name initialization.
-                    _sourceGenerationContext.AddSource($"{contextName}.PropertyNames.g.cs", GetPropertyNameInitialization(contextGenerationSpec));
+                // Add property name initialization.
+                AddSource($"{contextName}.PropertyNames.g.cs", GetPropertyNameInitialization(contextGenerationSpec));
 
-                    _propertyNames.Clear();
-                    _typeIndex.Clear();
-                }
+                _propertyNames.Clear();
+                _typeIndex.Clear();
             }
 
             private static SourceWriter CreateSourceWriterWithContextHeader(ContextGenerationSpec contextSpec, bool isPrimaryContextSourceFile = false, string? interfaceImplementation = null)
index 248469a..6fe2320 100644 (file)
@@ -34,27 +34,23 @@ namespace System.Text.Json.SourceGeneration
 
             internal const string JsonSerializableAttributeFullName = "System.Text.Json.Serialization.JsonSerializableAttribute";
 
-            private readonly Compilation _compilation;
             private readonly KnownTypeSymbols _knownSymbols;
+            private readonly bool _compilationContainsCoreJsonTypes;
 
             // Keeps track of generated context type names
             private readonly HashSet<(string ContextName, string TypeName)> _generatedContextAndTypeNames = new();
 
-#pragma warning disable RS1024 // Compare symbols correctly https://github.com/dotnet/roslyn-analyzers/issues/5804
-            private readonly HashSet<ITypeSymbol> _builtInSupportTypes = new(SymbolEqualityComparer.Default);
-#pragma warning restore
-
-            private readonly Queue<(ITypeSymbol type, JsonSourceGenerationMode mode, string? typeInfoPropertyName, Location? attributeLocation)> _typesToGenerate = new();
+            private readonly HashSet<ITypeSymbol> _builtInSupportTypes;
+            private readonly Queue<TypeToGenerate> _typesToGenerate = new();
 #pragma warning disable RS1024 // Compare symbols correctly https://github.com/dotnet/roslyn-analyzers/issues/5804
             private readonly Dictionary<ITypeSymbol, TypeGenerationSpec> _generatedTypes = new(SymbolEqualityComparer.Default);
 #pragma warning restore
-            private JsonKnownNamingPolicy _currentContextNamingPolicy;
 
-            private readonly List<DiagnosticInfo> _diagnostics = new();
+            public List<DiagnosticInfo> Diagnostics { get; } = new();
 
             public void ReportDiagnostic(DiagnosticDescriptor descriptor, Location? location, params object?[]? messageArgs)
             {
-                _diagnostics.Add(new DiagnosticInfo
+                Diagnostics.Add(new DiagnosticInfo
                 {
                     Descriptor = descriptor,
                     Location = location.GetTrimmedLocation(),
@@ -62,214 +58,104 @@ namespace System.Text.Json.SourceGeneration
                 });
             }
 
-            private static DiagnosticDescriptor TypeNotSupported { get; } = new DiagnosticDescriptor(
-                id: "SYSLIB1030",
-                title: new LocalizableResourceString(nameof(SR.TypeNotSupportedTitle), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
-                messageFormat: new LocalizableResourceString(nameof(SR.TypeNotSupportedMessageFormat), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
-                category: JsonConstants.SystemTextJsonSourceGenerationName,
-                defaultSeverity: DiagnosticSeverity.Warning,
-                isEnabledByDefault: true);
-
-            private static DiagnosticDescriptor DuplicateTypeName { get; } = new DiagnosticDescriptor(
-                id: "SYSLIB1031",
-                title: new LocalizableResourceString(nameof(SR.DuplicateTypeNameTitle), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
-                messageFormat: new LocalizableResourceString(nameof(SR.DuplicateTypeNameMessageFormat), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
-                category: JsonConstants.SystemTextJsonSourceGenerationName,
-                defaultSeverity: DiagnosticSeverity.Warning,
-                isEnabledByDefault: true);
-
-            private static DiagnosticDescriptor ContextClassesMustBePartial { get; } = new DiagnosticDescriptor(
-                id: "SYSLIB1032",
-                title: new LocalizableResourceString(nameof(SR.ContextClassesMustBePartialTitle), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
-                messageFormat: new LocalizableResourceString(nameof(SR.ContextClassesMustBePartialMessageFormat), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
-                category: JsonConstants.SystemTextJsonSourceGenerationName,
-                defaultSeverity: DiagnosticSeverity.Warning,
-                isEnabledByDefault: true);
-
-            private static DiagnosticDescriptor MultipleJsonConstructorAttribute { get; } = new DiagnosticDescriptor(
-                id: "SYSLIB1033",
-                title: new LocalizableResourceString(nameof(SR.MultipleJsonConstructorAttributeTitle), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
-                messageFormat: new LocalizableResourceString(nameof(SR.MultipleJsonConstructorAttributeFormat), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
-                category: JsonConstants.SystemTextJsonSourceGenerationName,
-                defaultSeverity: DiagnosticSeverity.Error,
-                isEnabledByDefault: true);
-
-            private static DiagnosticDescriptor MultipleJsonExtensionDataAttribute { get; } = new DiagnosticDescriptor(
-                id: "SYSLIB1035",
-                title: new LocalizableResourceString(nameof(SR.MultipleJsonExtensionDataAttributeTitle), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
-                messageFormat: new LocalizableResourceString(nameof(SR.MultipleJsonExtensionDataAttributeFormat), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
-                category: JsonConstants.SystemTextJsonSourceGenerationName,
-                defaultSeverity: DiagnosticSeverity.Error,
-                isEnabledByDefault: true);
-
-            private static DiagnosticDescriptor DataExtensionPropertyInvalid { get; } = new DiagnosticDescriptor(
-                id: "SYSLIB1036",
-                title: new LocalizableResourceString(nameof(SR.DataExtensionPropertyInvalidTitle), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
-                messageFormat: new LocalizableResourceString(nameof(SR.DataExtensionPropertyInvalidFormat), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
-                category: JsonConstants.SystemTextJsonSourceGenerationName,
-                defaultSeverity: DiagnosticSeverity.Error,
-                isEnabledByDefault: true);
-
-            private static DiagnosticDescriptor InaccessibleJsonIncludePropertiesNotSupported { get; } = new DiagnosticDescriptor(
-                id: "SYSLIB1038",
-                title: new LocalizableResourceString(nameof(SR.InaccessibleJsonIncludePropertiesNotSupportedTitle), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
-                messageFormat: new LocalizableResourceString(nameof(SR.InaccessibleJsonIncludePropertiesNotSupportedFormat), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
-                category: JsonConstants.SystemTextJsonSourceGenerationName,
-                defaultSeverity: DiagnosticSeverity.Warning,
-                isEnabledByDefault: true);
-
-            private static DiagnosticDescriptor PolymorphismNotSupported { get; } = new DiagnosticDescriptor(
-                id: "SYSLIB1039",
-                title: new LocalizableResourceString(nameof(SR.FastPathPolymorphismNotSupportedTitle), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
-                messageFormat: new LocalizableResourceString(nameof(SR.FastPathPolymorphismNotSupportedMessageFormat), SR.ResourceManager, typeof(FxResources.System.Text.Json.SourceGeneration.SR)),
-                category: JsonConstants.SystemTextJsonSourceGenerationName,
-                defaultSeverity: DiagnosticSeverity.Warning,
-                isEnabledByDefault: true);
-
-            public Parser(Compilation compilation)
+            public Parser(KnownTypeSymbols knownSymbols)
             {
-                _compilation = compilation;
-                _knownSymbols = new KnownTypeSymbols(compilation);
-
-                PopulateBuiltInSupportTypes();
+                _knownSymbols = knownSymbols;
+                _compilationContainsCoreJsonTypes =
+                    knownSymbols.JsonSerializerContextType != null &&
+                    knownSymbols.JsonSerializableAttributeType != null &&
+                    knownSymbols.JsonSourceGenerationOptionsAttributeType != null &&
+                    knownSymbols.JsonConverterType != null;
+
+                _builtInSupportTypes = (knownSymbols.BuiltInSupportTypes ??= CreateBuiltInSupportTypeSet(knownSymbols));
             }
 
-            public SourceGenerationSpec? GetGenerationSpec(IEnumerable<ClassDeclarationSyntax> classDeclarationSyntaxList, CancellationToken cancellationToken)
+            public ContextGenerationSpec? ParseContextGenerationSpec(ClassDeclarationSyntax contextClassDeclaration, SemanticModel semanticModel, CancellationToken cancellationToken)
             {
-                Compilation compilation = _compilation;
-                INamedTypeSymbol? jsonSerializerContextSymbol = _knownSymbols.JsonSerializerContextType;
-                INamedTypeSymbol? jsonSerializableAttributeSymbol = _knownSymbols.JsonSerializableAttributeType;
-                INamedTypeSymbol? jsonSourceGenerationOptionsAttributeSymbol = _knownSymbols.JsonSourceGenerationOptionsAttributeType;
-                INamedTypeSymbol? jsonConverterSymbol = _knownSymbols.JsonConverterType;
-
-                if (jsonSerializerContextSymbol == null ||
-                    jsonSerializableAttributeSymbol == null ||
-                    jsonSourceGenerationOptionsAttributeSymbol == null ||
-                    jsonConverterSymbol == null)
+                if (!_compilationContainsCoreJsonTypes)
                 {
                     return null;
                 }
 
-                List<ContextGenerationSpec>? contextGenSpecList = null;
-
-                foreach (IGrouping<SyntaxTree, ClassDeclarationSyntax> group in classDeclarationSyntaxList.GroupBy(c => c.SyntaxTree))
-                {
-                    SyntaxTree syntaxTree = group.Key;
-                    SemanticModel compilationSemanticModel = compilation.GetSemanticModel(syntaxTree);
-                    CompilationUnitSyntax compilationUnitSyntax = (CompilationUnitSyntax)syntaxTree.GetRoot(cancellationToken);
-
-                    foreach (ClassDeclarationSyntax classDeclarationSyntax in group)
-                    {
-                        cancellationToken.ThrowIfCancellationRequested();
-
-                        // Ensure context-scoped metadata caches are empty.
-                        Debug.Assert(_typesToGenerate.Count == 0);
-                        Debug.Assert(_generatedTypes.Count == 0);
-
-                        if (!DerivesFromJsonSerializerContext(classDeclarationSyntax, jsonSerializerContextSymbol, compilationSemanticModel, cancellationToken))
-                        {
-                            continue;
-                        }
-
-                        JsonSourceGenerationOptionsAttribute? options = null;
-                        List<AttributeSyntax>? serializableAttributeList = null;
-
-                        foreach (AttributeListSyntax attributeListSyntax in classDeclarationSyntax.AttributeLists)
-                        {
-                            AttributeSyntax attributeSyntax = attributeListSyntax.Attributes.First();
-                            if (compilationSemanticModel.GetSymbolInfo(attributeSyntax, cancellationToken).Symbol is not IMethodSymbol attributeSymbol)
-                            {
-                                continue;
-                            }
-
-                            INamedTypeSymbol attributeContainingTypeSymbol = attributeSymbol.ContainingType;
-
-                            if (jsonSerializableAttributeSymbol.Equals(attributeContainingTypeSymbol, SymbolEqualityComparer.Default))
-                            {
-                                (serializableAttributeList ??= new List<AttributeSyntax>()).Add(attributeSyntax);
-                            }
-                            else if (jsonSourceGenerationOptionsAttributeSymbol.Equals(attributeContainingTypeSymbol, SymbolEqualityComparer.Default))
-                            {
-                                options = GetSerializerOptions(attributeSyntax);
-                            }
-                        }
-
-                        if (serializableAttributeList == null)
-                        {
-                            // No types were indicated with [JsonSerializable]
-                            continue;
-                        }
+                Debug.Assert(_knownSymbols.JsonSerializerContextType != null);
 
-                        INamedTypeSymbol? contextTypeSymbol = compilationSemanticModel.GetDeclaredSymbol(classDeclarationSyntax, cancellationToken);
-                        Debug.Assert(contextTypeSymbol != null);
+                // Ensure context-scoped metadata caches are empty.
+                Debug.Assert(_typesToGenerate.Count == 0);
+                Debug.Assert(_generatedTypes.Count == 0);
 
-                        Location contextLocation = contextTypeSymbol.Locations.Length > 0 ? contextTypeSymbol.Locations[0] : Location.None;
-
-                        if (!TryGetClassDeclarationList(contextTypeSymbol, out List<string>? classDeclarationList))
-                        {
-                            // Class or one of its containing types is not partial so we can't add to it.
-                            ReportDiagnostic(ContextClassesMustBePartial, contextLocation, new string[] { contextTypeSymbol.Name });
-                            continue;
-                        }
+                if (!DerivesFromJsonSerializerContext(contextClassDeclaration, _knownSymbols.JsonSerializerContextType, semanticModel, cancellationToken))
+                {
+                    return null;
+                }
 
-                        options ??= new JsonSourceGenerationOptionsAttribute();
+                if (!TryParseJsonSerializerContextAttributes(
+                    contextClassDeclaration,
+                    semanticModel,
+                    cancellationToken,
+                    out List<TypeToGenerate>? rootSerializableTypes,
+                    out JsonSourceGenerationOptionsAttribute? options))
+                {
+                    // Context does not specify any source gen attributes.
+                    return null;
+                }
 
-                        // Set the naming policy for the current context.
-                        _currentContextNamingPolicy = options.PropertyNamingPolicy;
+                if (rootSerializableTypes is null)
+                {
+                    // No types were indicated with [JsonSerializable]
+                    return null;
+                }
 
-                        foreach (AttributeSyntax attribute in serializableAttributeList)
-                        {
-                            EnqueueRootType(compilationSemanticModel, attribute, options.GenerationMode, cancellationToken);
-                        }
+                INamedTypeSymbol? contextTypeSymbol = semanticModel.GetDeclaredSymbol(contextClassDeclaration, cancellationToken);
+                Debug.Assert(contextTypeSymbol != null);
 
-                        while (_typesToGenerate.Count > 0)
-                        {
-                            (ITypeSymbol type, JsonSourceGenerationMode mode, string? typeInfoPropertyName, Location? attributeLocation) = _typesToGenerate.Dequeue();
-                            if (!_generatedTypes.ContainsKey(type))
-                            {
-                                TypeGenerationSpec spec = CreateTypeGenerationSpec(type, mode, typeInfoPropertyName, attributeLocation, contextLocation, contextName: contextTypeSymbol.Name);
-                                _generatedTypes.Add(type, spec);
-                            }
-                        }
+                Location contextLocation = contextClassDeclaration.GetLocation();
+                if (!TryGetClassDeclarationList(contextTypeSymbol, out List<string>? classDeclarationList))
+                {
+                    // Class or one of its containing types is not partial so we can't add to it.
+                    ReportDiagnostic(DiagnosticDescriptors.ContextClassesMustBePartial, contextLocation, new string[] { contextTypeSymbol.Name });
+                    return null;
+                }
 
-                        if (_generatedTypes.Count == 0)
-                        {
-                            continue;
-                        }
+                options ??= new JsonSourceGenerationOptionsAttribute();
 
-                        ContextGenerationSpec contextGenSpec = new()
-                        {
-                            ContextType = new(contextTypeSymbol),
-                            GeneratedTypes = _generatedTypes.Values.OrderBy(t => t.TypeRef.FullyQualifiedName).ToImmutableEquatableArray(),
-                            Namespace = contextTypeSymbol.ContainingNamespace.ToDisplayString(),
-                            ContextClassDeclarations = classDeclarationList.ToImmutableEquatableArray(),
-                            DefaultIgnoreCondition = options.DefaultIgnoreCondition,
-                            IgnoreReadOnlyFields = options.IgnoreReadOnlyFields,
-                            IgnoreReadOnlyProperties = options.IgnoreReadOnlyProperties,
-                            IncludeFields = options.IncludeFields,
-                            PropertyNamingPolicy = options.PropertyNamingPolicy,
-                            WriteIndented = options.WriteIndented,
-                        };
-
-                        contextGenSpecList ??= new List<ContextGenerationSpec>();
-                        contextGenSpecList.Add(contextGenSpec);
-
-                        // Clear the caches of generated metadata between the processing of context classes.
-                        _generatedTypes.Clear();
-                        _typesToGenerate.Clear();
-                    }
+                // Enqueue attribute data for spec generation
+                foreach (TypeToGenerate rootSerializableType in rootSerializableTypes)
+                {
+                    EnqueueType(rootSerializableType.Type, rootSerializableType.Mode, rootSerializableType.TypeInfoPropertyName, rootSerializableType.AttributeLocation);
                 }
 
-                if (contextGenSpecList == null)
+                // Walk the transitive type graph generating specs for every encountered type.
+                while (_typesToGenerate.Count > 0)
                 {
-                    return null;
+                    cancellationToken.ThrowIfCancellationRequested();
+                    TypeToGenerate typeToGenerate = _typesToGenerate.Dequeue();
+                    if (!_generatedTypes.ContainsKey(typeToGenerate.Type))
+                    {
+                        TypeGenerationSpec spec = ParseTypeGenerationSpec(typeToGenerate, contextName: contextTypeSymbol.Name, contextLocation, options);
+                        _generatedTypes.Add(typeToGenerate.Type, spec);
+                    }
                 }
 
-                return new SourceGenerationSpec
+                Debug.Assert(_generatedTypes.Count > 0);
+
+                ContextGenerationSpec contextGenSpec = new()
                 {
-                    ContextGenerationSpecs = contextGenSpecList.ToImmutableEquatableArray(),
-                    Diagnostics = _diagnostics.ToImmutableEquatableArray(),
+                    ContextType = new(contextTypeSymbol),
+                    GeneratedTypes = _generatedTypes.Values.OrderBy(t => t.TypeRef.FullyQualifiedName).ToImmutableEquatableArray(),
+                    Namespace = contextTypeSymbol.ContainingNamespace.ToDisplayString(),
+                    ContextClassDeclarations = classDeclarationList.ToImmutableEquatableArray(),
+                    DefaultIgnoreCondition = options.DefaultIgnoreCondition,
+                    IgnoreReadOnlyFields = options.IgnoreReadOnlyFields,
+                    IgnoreReadOnlyProperties = options.IgnoreReadOnlyProperties,
+                    IncludeFields = options.IncludeFields,
+                    PropertyNamingPolicy = options.PropertyNamingPolicy,
+                    WriteIndented = options.WriteIndented,
                 };
+
+                // Clear the caches of generated metadata between the processing of context classes.
+                _generatedTypes.Clear();
+                _typesToGenerate.Clear();
+                return contextGenSpec;
             }
 
             // Returns true if a given type derives directly from JsonSerializerContext.
@@ -380,75 +266,24 @@ namespace System.Text.Json.SourceGeneration
                 return sb.ToString();
             }
 
-            private void EnqueueRootType(
-                SemanticModel compilationSemanticModel,
-                AttributeSyntax attributeSyntax,
-                JsonSourceGenerationMode generationMode,
-                CancellationToken cancellationToken)
-            {
-                IEnumerable<SyntaxNode> attributeArguments = attributeSyntax.DescendantNodes().Where(node => node is AttributeArgumentSyntax);
-
-                ITypeSymbol? typeSymbol = null;
-                string? typeInfoPropertyName = null;
-
-                bool seenFirstArg = false;
-                foreach (AttributeArgumentSyntax node in attributeArguments)
-                {
-                    if (!seenFirstArg)
-                    {
-                        TypeOfExpressionSyntax? typeNode = node.ChildNodes().Single() as TypeOfExpressionSyntax;
-                        if (typeNode != null)
-                        {
-                            ExpressionSyntax typeNameSyntax = (ExpressionSyntax)typeNode.ChildNodes().Single();
-                            typeSymbol = compilationSemanticModel.GetTypeInfo(typeNameSyntax, cancellationToken).ConvertedType;
-                        }
-
-                        seenFirstArg = true;
-                    }
-                    else
-                    {
-                        IEnumerable<SyntaxNode> childNodes = node.ChildNodes();
-
-                        NameEqualsSyntax? propertyNameNode = childNodes.First() as NameEqualsSyntax;
-                        Debug.Assert(propertyNameNode != null);
-
-                        SyntaxNode propertyValueNode = childNodes.ElementAt(1);
-                        string optionName = propertyNameNode.Name.Identifier.ValueText;
-
-                        if (optionName == nameof(JsonSerializableAttribute.TypeInfoPropertyName))
-                        {
-                            typeInfoPropertyName = propertyValueNode.GetFirstToken().ValueText;
-                        }
-                        else if (optionName == nameof(JsonSerializableAttribute.GenerationMode))
-                        {
-                            JsonSourceGenerationMode? mode = GetJsonSourceGenerationModeEnumVal(propertyValueNode);
-                            if (mode.HasValue)
-                            {
-                                generationMode = mode.Value;
-                            }
-                        }
-                    }
-                }
-
-                if (typeSymbol == null)
-                {
-                    return;
-                }
-
-                EnqueueType(typeSymbol, generationMode, typeInfoPropertyName, attributeSyntax.GetLocation());
-            }
-
-            private TypeRef EnqueueType(ITypeSymbol type, JsonSourceGenerationMode generationMode, string? typeInfoPropertyName = null, Location? attributeLocation = null)
+            private TypeRef EnqueueType(ITypeSymbol type, JsonSourceGenerationMode? generationMode, string? typeInfoPropertyName = null, Location? attributeLocation = null)
             {
                 // Trim compile-time erased metadata such as tuple labels and NRT annotations.
-                type = _compilation.EraseCompileTimeMetadata(type);
+                type = _knownSymbols.Compilation.EraseCompileTimeMetadata(type);
 
                 if (_generatedTypes.TryGetValue(type, out TypeGenerationSpec? spec))
                 {
                     return spec.TypeRef;
                 }
 
-                _typesToGenerate.Enqueue((type, generationMode, typeInfoPropertyName, attributeLocation));
+                _typesToGenerate.Enqueue(new TypeToGenerate
+                {
+                    Type = type,
+                    Mode = generationMode,
+                    TypeInfoPropertyName = typeInfoPropertyName,
+                    AttributeLocation = attributeLocation,
+                });
+
                 return new TypeRef(type);
             }
 
@@ -470,13 +305,53 @@ namespace System.Text.Json.SourceGeneration
                 static bool IsValidEnumIdentifier(string token) => token != nameof(JsonSourceGenerationMode) && token != "." && token != "|";
             }
 
-            private static JsonSourceGenerationOptionsAttribute? GetSerializerOptions(AttributeSyntax? attributeSyntax)
+            private bool TryParseJsonSerializerContextAttributes(
+                ClassDeclarationSyntax classDeclarationSyntax,
+                SemanticModel semanticModel,
+                CancellationToken cancellationToken,
+                out List<TypeToGenerate>? rootSerializableTypes,
+                out JsonSourceGenerationOptionsAttribute? options)
             {
-                if (attributeSyntax == null)
+                Debug.Assert(_knownSymbols.JsonSerializableAttributeType != null);
+                Debug.Assert(_knownSymbols.JsonSourceGenerationOptionsAttributeType != null);
+
+                bool foundSourceGenAttributes = false;
+                rootSerializableTypes = null;
+                options = null;
+
+                foreach (AttributeListSyntax attributeListSyntax in classDeclarationSyntax.AttributeLists)
                 {
-                    return null;
+                    AttributeSyntax attributeSyntax = attributeListSyntax.Attributes.First();
+                    if (semanticModel.GetSymbolInfo(attributeSyntax, cancellationToken).Symbol is not IMethodSymbol attributeSymbol)
+                    {
+                        continue;
+                    }
+
+                    INamedTypeSymbol attributeContainingTypeSymbol = attributeSymbol.ContainingType;
+
+                    if (_knownSymbols.JsonSerializableAttributeType.Equals(attributeContainingTypeSymbol, SymbolEqualityComparer.Default))
+                    {
+                        foundSourceGenAttributes = true;
+                        TypeToGenerate? typeToGenerate = ParseJsonSerializableAttribute(semanticModel, attributeSyntax, cancellationToken);
+                        if (typeToGenerate is null)
+                        {
+                            continue;
+                        }
+
+                        (rootSerializableTypes ??= new()).Add(typeToGenerate.Value);
+                    }
+                    else if (_knownSymbols.JsonSourceGenerationOptionsAttributeType.Equals(attributeContainingTypeSymbol, SymbolEqualityComparer.Default))
+                    {
+                        foundSourceGenAttributes = true;
+                        options = ParseJsonSourceGenerationOptionsAttribute(attributeSyntax);
+                    }
                 }
 
+                return foundSourceGenAttributes;
+            }
+
+            private static JsonSourceGenerationOptionsAttribute ParseJsonSourceGenerationOptionsAttribute(AttributeSyntax attributeSyntax)
+            {
                 IEnumerable<SyntaxNode> attributeArguments = attributeSyntax.DescendantNodes().Where(node => node is AttributeArgumentSyntax);
 
                 JsonSourceGenerationOptionsAttribute options = new();
@@ -527,7 +402,7 @@ namespace System.Text.Json.SourceGeneration
                             break;
                         case nameof(JsonSourceGenerationOptionsAttribute.PropertyNamingPolicy):
                             {
-                                if (Enum.TryParse<JsonKnownNamingPolicy>(propertyValueStr, out JsonKnownNamingPolicy value))
+                                if (Enum.TryParse(propertyValueStr, out JsonKnownNamingPolicy value))
                                 {
                                     options.PropertyNamingPolicy = value;
                                 }
@@ -558,9 +433,71 @@ namespace System.Text.Json.SourceGeneration
                 return options;
             }
 
-            private TypeGenerationSpec CreateTypeGenerationSpec(ITypeSymbol type, JsonSourceGenerationMode generationMode, string? typeInfoPropertyName, Location? attributeLocation, Location contextLocation, string contextName)
+            private static TypeToGenerate? ParseJsonSerializableAttribute(SemanticModel semanticModel, AttributeSyntax attributeSyntax, CancellationToken cancellationToken)
             {
-                Location typeLocation = type.GetDiagnosticLocation() ?? attributeLocation ?? contextLocation;
+                IEnumerable<SyntaxNode> attributeArguments = attributeSyntax.DescendantNodes().Where(node => node is AttributeArgumentSyntax);
+
+                ITypeSymbol? typeSymbol = null;
+                string? typeInfoPropertyName = null;
+                JsonSourceGenerationMode? generationMode = null;
+
+                bool seenFirstArg = false;
+                foreach (AttributeArgumentSyntax node in attributeArguments)
+                {
+                    if (!seenFirstArg)
+                    {
+                        TypeOfExpressionSyntax? typeNode = node.ChildNodes().Single() as TypeOfExpressionSyntax;
+                        if (typeNode != null)
+                        {
+                            ExpressionSyntax typeNameSyntax = (ExpressionSyntax)typeNode.ChildNodes().Single();
+                            typeSymbol = semanticModel.GetTypeInfo(typeNameSyntax, cancellationToken).ConvertedType;
+                        }
+
+                        seenFirstArg = true;
+                    }
+                    else
+                    {
+                        IEnumerable<SyntaxNode> childNodes = node.ChildNodes();
+
+                        NameEqualsSyntax? propertyNameNode = childNodes.First() as NameEqualsSyntax;
+                        Debug.Assert(propertyNameNode != null);
+
+                        SyntaxNode propertyValueNode = childNodes.ElementAt(1);
+                        string optionName = propertyNameNode.Name.Identifier.ValueText;
+
+                        if (optionName == nameof(JsonSerializableAttribute.TypeInfoPropertyName))
+                        {
+                            typeInfoPropertyName = propertyValueNode.GetFirstToken().ValueText;
+                        }
+                        else if (optionName == nameof(JsonSerializableAttribute.GenerationMode))
+                        {
+                            JsonSourceGenerationMode? mode = GetJsonSourceGenerationModeEnumVal(propertyValueNode);
+                            if (mode.HasValue)
+                            {
+                                generationMode = mode.Value;
+                            }
+                        }
+                    }
+                }
+
+                if (typeSymbol is null)
+                {
+                    return null;
+                }
+
+                return new TypeToGenerate
+                {
+                    Type = typeSymbol,
+                    Mode = generationMode,
+                    TypeInfoPropertyName = typeInfoPropertyName,
+                    AttributeLocation = attributeSyntax.GetLocation(),
+                };
+            }
+
+            private TypeGenerationSpec ParseTypeGenerationSpec(TypeToGenerate typeToGenerate, string contextName, Location contextLocation, JsonSourceGenerationOptionsAttribute options)
+            {
+                ITypeSymbol type = typeToGenerate.Type;
+                Location typeLocation = type.GetDiagnosticLocation() ?? typeToGenerate.AttributeLocation ?? contextLocation;
 
                 ClassType classType;
                 JsonPrimitiveTypeKind? primitiveTypeKind = GetPrimitiveTypeKind(type);
@@ -618,11 +555,11 @@ namespace System.Text.Json.SourceGeneration
                     {
                         Debug.Assert(attributeData.ConstructorArguments.Length > 0);
                         var derivedType = (ITypeSymbol)attributeData.ConstructorArguments[0].Value!;
-                        EnqueueType(derivedType, generationMode);
+                        EnqueueType(derivedType, typeToGenerate.Mode);
 
-                        if (!isPolymorphic && generationMode == JsonSourceGenerationMode.Serialization)
+                        if (!isPolymorphic && typeToGenerate.Mode == JsonSourceGenerationMode.Serialization)
                         {
-                            ReportDiagnostic(PolymorphismNotSupported, typeLocation, new string[] { type.ToDisplayString() });
+                            ReportDiagnostic(DiagnosticDescriptors.PolymorphismNotSupported, typeLocation, new string[] { type.ToDisplayString() });
                         }
 
                         isPolymorphic = true;
@@ -646,7 +583,7 @@ namespace System.Text.Json.SourceGeneration
                 else if (type.IsNullableValueType(out ITypeSymbol? underlyingType))
                 {
                     classType = ClassType.Nullable;
-                    nullableUnderlyingType = EnqueueType(underlyingType, generationMode);
+                    nullableUnderlyingType = EnqueueType(underlyingType, typeToGenerate.Mode);
                 }
                 else if (type.TypeKind is TypeKind.Enum)
                 {
@@ -661,7 +598,7 @@ namespace System.Text.Json.SourceGeneration
                     }
 
                     ITypeSymbol elementType = iasyncEnumerableType.TypeArguments[0];
-                    collectionValueType = EnqueueType(elementType, generationMode);
+                    collectionValueType = EnqueueType(elementType, typeToGenerate.Mode);
                     collectionType = CollectionType.IAsyncEnumerableOfT;
                     classType = ClassType.Enumerable;
                 }
@@ -816,11 +753,11 @@ namespace System.Text.Json.SourceGeneration
                         valueType = _knownSymbols.ObjectType;
                     }
 
-                    collectionValueType = EnqueueType(valueType, generationMode);
+                    collectionValueType = EnqueueType(valueType, typeToGenerate.Mode);
 
                     if (keyType != null)
                     {
-                        collectionKeyType = EnqueueType(keyType, generationMode);
+                        collectionKeyType = EnqueueType(keyType, typeToGenerate.Mode);
 
                         if (needsRuntimeType)
                         {
@@ -835,7 +772,7 @@ namespace System.Text.Json.SourceGeneration
                     if (!TryGetDeserializationConstructor(type, useDefaultCtorInAnnotatedStructs, out IMethodSymbol? constructor))
                     {
                         classType = ClassType.TypeUnsupportedBySourceGen;
-                        ReportDiagnostic(MultipleJsonConstructorAttribute, typeLocation, new string[] { type.ToDisplayString() });
+                        ReportDiagnostic(DiagnosticDescriptors.MultipleJsonConstructorAttribute, typeLocation, new string[] { type.ToDisplayString() });
                     }
                     else
                     {
@@ -859,7 +796,7 @@ namespace System.Text.Json.SourceGeneration
                                 for (int i = 0; i < paramCount; i++)
                                 {
                                     IParameterSymbol parameterInfo = parameters![i];
-                                    TypeRef parameterTypeRef = EnqueueType(parameterInfo.Type, generationMode);
+                                    TypeRef parameterTypeRef = EnqueueType(parameterInfo.Type, typeToGenerate.Mode);
 
                                     paramGenSpecs[i] = new ParameterGenerationSpec
                                     {
@@ -903,7 +840,7 @@ namespace System.Text.Json.SourceGeneration
                                     continue;
                                 }
 
-                                PropertyGenerationSpec? spec = GetPropertyGenerationSpec(declaringTypeRef, propertyInfo.Type, propertyInfo, isVirtual, generationMode);
+                                PropertyGenerationSpec? spec = ParsePropertyGenerationSpec(declaringTypeRef, propertyInfo.Type, propertyInfo, isVirtual, typeToGenerate.Mode, options);
                                 if (spec is null)
                                 {
                                     continue;
@@ -928,7 +865,7 @@ namespace System.Text.Json.SourceGeneration
                                     continue;
                                 }
 
-                                PropertyGenerationSpec? spec = GetPropertyGenerationSpec(declaringTypeRef, fieldInfo.Type, fieldInfo, isVirtual: false, generationMode);
+                                PropertyGenerationSpec? spec = ParsePropertyGenerationSpec(declaringTypeRef, fieldInfo.Type, fieldInfo, isVirtual: false, typeToGenerate.Mode, options);
                                 if (spec is null)
                                 {
                                     continue;
@@ -947,12 +884,12 @@ namespace System.Text.Json.SourceGeneration
                                 {
                                     if (extensionDataPropertyType != null)
                                     {
-                                        ReportDiagnostic(MultipleJsonExtensionDataAttribute, typeLocation, new string[] { type.Name });
+                                        ReportDiagnostic(DiagnosticDescriptors.MultipleJsonExtensionDataAttribute, typeLocation, new string[] { type.Name });
                                     }
 
                                     if (!IsValidDataExtensionPropertyType(memberType))
                                     {
-                                        ReportDiagnostic(DataExtensionPropertyInvalid, memberInfo.GetDiagnosticLocation(), new string[] { type.Name, spec.MemberName });
+                                        ReportDiagnostic(DiagnosticDescriptors.DataExtensionPropertyInvalid, memberInfo.GetDiagnosticLocation(), new string[] { type.Name, spec.MemberName });
                                     }
 
                                     extensionDataPropertyType = spec.PropertyType;
@@ -981,7 +918,7 @@ namespace System.Text.Json.SourceGeneration
 
                                 if (spec.HasJsonInclude && (!spec.CanUseGetter || !spec.CanUseSetter || !spec.IsPublic))
                                 {
-                                    ReportDiagnostic(InaccessibleJsonIncludePropertiesNotSupported, memberInfo.GetDiagnosticLocation(), new string[] { type.Name, spec.MemberName });
+                                    ReportDiagnostic(DiagnosticDescriptors.InaccessibleJsonIncludePropertiesNotSupported, memberInfo.GetDiagnosticLocation(), new string[] { type.Name, spec.MemberName });
                                 }
                             }
                         }
@@ -994,18 +931,18 @@ namespace System.Text.Json.SourceGeneration
                 }
 
                 var typeRef = new TypeRef(type);
-                typeInfoPropertyName ??= GetTypeInfoPropertyName(type);
+                string typeInfoPropertyName = typeToGenerate.TypeInfoPropertyName ?? GetTypeInfoPropertyName(type);
 
                 if (classType is ClassType.TypeUnsupportedBySourceGen)
                 {
-                    ReportDiagnostic(TypeNotSupported, typeLocation, new string[] { typeRef.FullyQualifiedName });
+                    ReportDiagnostic(DiagnosticDescriptors.TypeNotSupported, typeLocation, new string[] { typeRef.FullyQualifiedName });
                 }
 
                 if (!_generatedContextAndTypeNames.Add((contextName, typeInfoPropertyName)))
                 {
                     // The context name/property name combination will result in a conflict in generated types.
                     // Workaround for https://github.com/dotnet/roslyn/issues/54185 by keeping track of the file names we've used.
-                    ReportDiagnostic(DuplicateTypeName, attributeLocation ?? contextLocation, new string[] { typeInfoPropertyName });
+                    ReportDiagnostic(DiagnosticDescriptors.DuplicateTypeName, typeToGenerate.AttributeLocation ?? contextLocation, new string[] { typeInfoPropertyName });
                     classType = ClassType.TypeUnsupportedBySourceGen;
                 }
 
@@ -1013,7 +950,7 @@ namespace System.Text.Json.SourceGeneration
                 {
                     TypeRef = typeRef,
                     TypeInfoPropertyName = typeInfoPropertyName,
-                    GenerationMode = generationMode,
+                    GenerationMode = typeToGenerate.Mode ?? options.GenerationMode,
                     ClassType = classType,
                     PrimitiveTypeKind = primitiveTypeKind,
                     IsPolymorphic = isPolymorphic,
@@ -1102,12 +1039,13 @@ namespace System.Text.Json.SourceGeneration
                     ignoredMember.IsVirtual();
             }
 
-            private PropertyGenerationSpec? GetPropertyGenerationSpec(
+            private PropertyGenerationSpec? ParsePropertyGenerationSpec(
                 TypeRef declaringType,
                 ITypeSymbol memberType,
                 ISymbol memberInfo,
                 bool isVirtual,
-                JsonSourceGenerationMode generationMode)
+                JsonSourceGenerationMode? generationMode,
+                JsonSourceGenerationOptionsAttribute options)
             {
                 Debug.Assert(memberInfo is IFieldSymbol or IPropertySymbol);
 
@@ -1141,7 +1079,7 @@ namespace System.Text.Json.SourceGeneration
                 bool needsAtSign = memberInfo.MemberNameNeedsAtSign();
 
                 string clrName = memberInfo.Name;
-                string runtimePropertyName = DetermineRuntimePropName(clrName, jsonPropertyName, _currentContextNamingPolicy);
+                string runtimePropertyName = DetermineRuntimePropName(clrName, jsonPropertyName, options.PropertyNamingPolicy);
                 string propertyNameVarName = DeterminePropNameIdentifier(runtimePropertyName);
 
                 return new PropertyGenerationSpec
@@ -1588,25 +1526,29 @@ namespace System.Text.Json.SourceGeneration
                     _builtInSupportTypes.Contains(type);
             }
 
-            private void PopulateBuiltInSupportTypes()
+            private static HashSet<ITypeSymbol> CreateBuiltInSupportTypeSet(KnownTypeSymbols knownSymbols)
             {
-                HashSet<ITypeSymbol> builtInSupportTypes = _builtInSupportTypes;
-
-                AddTypeIfNotNull(_knownSymbols.ByteArrayType);
-                AddTypeIfNotNull(_knownSymbols.TimeSpanType);
-                AddTypeIfNotNull(_knownSymbols.DateTimeOffsetType);
-                AddTypeIfNotNull(_knownSymbols.DateOnlyType);
-                AddTypeIfNotNull(_knownSymbols.TimeOnlyType);
-                AddTypeIfNotNull(_knownSymbols.GuidType);
-                AddTypeIfNotNull(_knownSymbols.UriType);
-                AddTypeIfNotNull(_knownSymbols.VersionType);
-
-                AddTypeIfNotNull(_knownSymbols.JsonArrayType);
-                AddTypeIfNotNull(_knownSymbols.JsonElementType);
-                AddTypeIfNotNull(_knownSymbols.JsonNodeType);
-                AddTypeIfNotNull(_knownSymbols.JsonObjectType);
-                AddTypeIfNotNull(_knownSymbols.JsonValueType);
-                AddTypeIfNotNull(_knownSymbols.JsonDocumentType);
+#pragma warning disable RS1024 // Compare symbols correctly https://github.com/dotnet/roslyn-analyzers/issues/5804
+                HashSet<ITypeSymbol> builtInSupportTypes = new(SymbolEqualityComparer.Default);
+#pragma warning restore
+
+                AddTypeIfNotNull(knownSymbols.ByteArrayType);
+                AddTypeIfNotNull(knownSymbols.TimeSpanType);
+                AddTypeIfNotNull(knownSymbols.DateTimeOffsetType);
+                AddTypeIfNotNull(knownSymbols.DateOnlyType);
+                AddTypeIfNotNull(knownSymbols.TimeOnlyType);
+                AddTypeIfNotNull(knownSymbols.GuidType);
+                AddTypeIfNotNull(knownSymbols.UriType);
+                AddTypeIfNotNull(knownSymbols.VersionType);
+
+                AddTypeIfNotNull(knownSymbols.JsonArrayType);
+                AddTypeIfNotNull(knownSymbols.JsonElementType);
+                AddTypeIfNotNull(knownSymbols.JsonNodeType);
+                AddTypeIfNotNull(knownSymbols.JsonObjectType);
+                AddTypeIfNotNull(knownSymbols.JsonValueType);
+                AddTypeIfNotNull(knownSymbols.JsonDocumentType);
+
+                return builtInSupportTypes;
 
                 void AddTypeIfNotNull(ITypeSymbol? type)
                 {
@@ -1616,6 +1558,14 @@ namespace System.Text.Json.SourceGeneration
                     }
                 }
             }
+
+            private readonly struct TypeToGenerate
+            {
+                public required ITypeSymbol Type { get; init; }
+                public JsonSourceGenerationMode? Mode { get; init; }
+                public string? TypeInfoPropertyName { get; init; }
+                public Location? AttributeLocation { get; init; }
+            }
         }
     }
 }
index 870883e..4c58a3d 100644 (file)
@@ -2,6 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System.Collections.Generic;
+using System.Collections.Immutable;
 using System.Threading;
 using Microsoft.CodeAnalysis;
 using Microsoft.CodeAnalysis.CSharp;
@@ -23,7 +24,7 @@ namespace System.Text.Json.SourceGeneration
         public void Initialize(GeneratorInitializationContext context)
         {
 #if LAUNCH_DEBUGGER
-            Diagnostics.Debugger.Launch();
+            System.Diagnostics.Debugger.Launch();
 #endif
 
             // Unfortunately, there is no cancellation token that can be passed here
@@ -39,22 +40,45 @@ namespace System.Text.Json.SourceGeneration
         /// <param name="executionContext"></param>
         public void Execute(GeneratorExecutionContext executionContext)
         {
-            if (executionContext.SyntaxContextReceiver is not SyntaxContextReceiver receiver || receiver.ClassDeclarationSyntaxList == null)
+            if (executionContext.SyntaxContextReceiver is not SyntaxContextReceiver receiver || receiver.ContextClassDeclarations == null)
             {
                 // nothing to do yet
                 return;
             }
 
-            JsonSourceGenerationContext context = new JsonSourceGenerationContext(executionContext);
-            Parser parser = new(executionContext.Compilation);
-            SourceGenerationSpec? spec = parser.GetGenerationSpec(receiver.ClassDeclarationSyntaxList, executionContext.CancellationToken);
+            // Stage 1. Parse the identified JsonSerializerContext classes and store the model types.
+            KnownTypeSymbols knownSymbols = new(executionContext.Compilation);
+            Parser parser = new(knownSymbols);
 
-            OnSourceEmitting?.Invoke(spec);
+            List<ContextGenerationSpec>? contextGenerationSpecs = null;
+            foreach ((ClassDeclarationSyntax? contextClassDeclaration, SemanticModel semanticModel) in receiver.ContextClassDeclarations)
+            {
+                ContextGenerationSpec? contextGenerationSpec = parser.ParseContextGenerationSpec(contextClassDeclaration, semanticModel, executionContext.CancellationToken);
+                if (contextGenerationSpec is null)
+                {
+                    continue;
+                }
+
+                (contextGenerationSpecs ??= new()).Add(contextGenerationSpec);
+            }
+
+            // Stage 2. Report any diagnostics gathered by the parser.
+            foreach (DiagnosticInfo diagnosticInfo in parser.Diagnostics)
+            {
+                executionContext.ReportDiagnostic(diagnosticInfo.CreateDiagnostic());
+            }
+
+            if (contextGenerationSpecs is null)
+            {
+                return;
+            }
 
-            if (spec != null)
+            // Stage 3. Emit source code from the spec models.
+            OnSourceEmitting?.Invoke(contextGenerationSpecs.ToImmutableArray());
+            Emitter emitter = new(executionContext);
+            foreach (ContextGenerationSpec contextGenerationSpec in contextGenerationSpecs)
             {
-                Emitter emitter = new(context, spec);
-                emitter.Emit();
+                emitter.Emit(contextGenerationSpec);
             }
         }
 
@@ -67,7 +91,7 @@ namespace System.Text.Json.SourceGeneration
                 _cancellationToken = cancellationToken;
             }
 
-            public List<ClassDeclarationSyntax>? ClassDeclarationSyntaxList { get; private set; }
+            public List<(ClassDeclarationSyntax, SemanticModel)>? ContextClassDeclarations { get; private set; }
 
             public void OnVisitSyntaxNode(GeneratorSyntaxContext context)
             {
@@ -76,7 +100,7 @@ namespace System.Text.Json.SourceGeneration
                     ClassDeclarationSyntax? classSyntax = GetSemanticTargetForGeneration(context, _cancellationToken);
                     if (classSyntax != null)
                     {
-                        (ClassDeclarationSyntaxList ??= new List<ClassDeclarationSyntax>()).Add(classSyntax);
+                        (ContextClassDeclarations ??= new()).Add((classSyntax, context.SemanticModel));
                     }
                 }
             }
@@ -107,7 +131,6 @@ namespace System.Text.Json.SourceGeneration
                             return classDeclarationSyntax;
                         }
                     }
-
                 }
 
                 return null;
@@ -117,26 +140,17 @@ namespace System.Text.Json.SourceGeneration
         /// <summary>
         /// Instrumentation helper for unit tests.
         /// </summary>
-        public Action<SourceGenerationSpec?>? OnSourceEmitting { get; init; }
-    }
-
-    internal readonly struct JsonSourceGenerationContext
-    {
-        private readonly GeneratorExecutionContext _context;
+        public Action<ImmutableArray<ContextGenerationSpec>>? OnSourceEmitting { get; init; }
 
-        public JsonSourceGenerationContext(GeneratorExecutionContext context)
+        private partial class Emitter
         {
-            _context = context;
-        }
+            private readonly GeneratorExecutionContext _context;
 
-        public void ReportDiagnostic(Diagnostic diagnostic)
-        {
-            _context.ReportDiagnostic(diagnostic);
-        }
+            public Emitter(GeneratorExecutionContext context)
+                => _context = context;
 
-        public void AddSource(string hintName, SourceText sourceText)
-        {
-            _context.AddSource(hintName, sourceText);
+            private partial void AddSource(string hintName, SourceText sourceText)
+                => _context.AddSource(hintName, sourceText);
         }
     }
 }
index 2224048..e3f8b4a 100644 (file)
@@ -1,6 +1,7 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
+using System.Collections.Immutable;
 using System.Linq;
 using Microsoft.CodeAnalysis;
 using Microsoft.CodeAnalysis.CSharp.Syntax;
@@ -24,70 +25,67 @@ namespace System.Text.Json.SourceGeneration
         public void Initialize(IncrementalGeneratorInitializationContext context)
         {
 #if LAUNCH_DEBUGGER
-            Diagnostics.Debugger.Launch();
+            System.Diagnostics.Debugger.Launch();
 #endif
-            IncrementalValuesProvider<ClassDeclarationSyntax> classDeclarations = context.SyntaxProvider
+            IncrementalValueProvider<KnownTypeSymbols> knownTypeSymbols = context.CompilationProvider
+                .Select((compilation, _) => new KnownTypeSymbols(compilation));
+
+            IncrementalValuesProvider<(ContextGenerationSpec?, ImmutableEquatableArray<DiagnosticInfo>)> contextGenerationSpecs = context.SyntaxProvider
                 .ForAttributeWithMetadataName(
 #if !ROSLYN4_4_OR_GREATER
                     context,
 #endif
                     Parser.JsonSerializableAttributeFullName,
                     (node, _) => node is ClassDeclarationSyntax,
-                    (context, _) => (ClassDeclarationSyntax)context.TargetNode);
-
-            IncrementalValueProvider<SourceGenerationSpec?> sourceGenSpec = context.CompilationProvider
-                .Combine(classDeclarations.Collect())
+                    (context, _) => (ContextClass: (ClassDeclarationSyntax)context.TargetNode, context.SemanticModel))
+                .Combine(knownTypeSymbols)
                 .Select(static (tuple, cancellationToken) =>
                 {
-                    Parser parser = new(tuple.Left);
-                    return parser.GetGenerationSpec(tuple.Right, cancellationToken);
+                    Parser parser = new(tuple.Right);
+                    ContextGenerationSpec? contextGenerationSpec = parser.ParseContextGenerationSpec(tuple.Left.ContextClass, tuple.Left.SemanticModel, cancellationToken);
+                    ImmutableEquatableArray<DiagnosticInfo> diagnostics = parser.Diagnostics.ToImmutableEquatableArray();
+                    return (contextGenerationSpec, diagnostics);
                 })
 #if ROSLYN4_4_OR_GREATER
-                .WithTrackingName(SourceGenerationSpecTrackingName);
-#else
-                ;
+                .WithTrackingName(SourceGenerationSpecTrackingName)
 #endif
+                ;
 
-            context.RegisterSourceOutput(sourceGenSpec, EmitSource);
+            context.RegisterSourceOutput(contextGenerationSpecs, ReportDiagnosticsAndEmitSource);
         }
 
-        private void EmitSource(SourceProductionContext sourceProductionContext, SourceGenerationSpec? sourceGenSpec)
+        private void ReportDiagnosticsAndEmitSource(SourceProductionContext sourceProductionContext, (ContextGenerationSpec? ContextGenerationSpec, ImmutableEquatableArray<DiagnosticInfo> Diagnostics) input)
         {
-            OnSourceEmitting?.Invoke(sourceGenSpec);
+            // Report any diagnostics ahead of emitting.
+            foreach (DiagnosticInfo diagnostic in input.Diagnostics)
+            {
+                sourceProductionContext.ReportDiagnostic(diagnostic.CreateDiagnostic());
+            }
 
-            if (sourceGenSpec is null)
+            if (input.ContextGenerationSpec is null)
             {
                 return;
             }
 
-            JsonSourceGenerationContext context = new JsonSourceGenerationContext(sourceProductionContext);
-            Emitter emitter = new(context, sourceGenSpec);
-            emitter.Emit();
+            OnSourceEmitting?.Invoke(ImmutableArray.Create(input.ContextGenerationSpec));
+            Emitter emitter = new(sourceProductionContext);
+            emitter.Emit(input.ContextGenerationSpec);
         }
 
         /// <summary>
         /// Instrumentation helper for unit tests.
         /// </summary>
-        public Action<SourceGenerationSpec?>? OnSourceEmitting { get; init; }
-    }
+        public Action<ImmutableArray<ContextGenerationSpec>>? OnSourceEmitting { get; init; }
 
-    internal readonly struct JsonSourceGenerationContext
-    {
-        private readonly SourceProductionContext _context;
-
-        public JsonSourceGenerationContext(SourceProductionContext context)
+        private partial class Emitter
         {
-            _context = context;
-        }
+            private readonly SourceProductionContext _context;
 
-        public void ReportDiagnostic(Diagnostic diagnostic)
-        {
-            _context.ReportDiagnostic(diagnostic);
-        }
+            public Emitter(SourceProductionContext context)
+                => _context = context;
 
-        public void AddSource(string hintName, SourceText sourceText)
-        {
-            _context.AddSource(hintName, sourceText);
+            private partial void AddSource(string hintName, SourceText sourceText)
+                => _context.AddSource(hintName, sourceText);
         }
     }
 }
diff --git a/src/libraries/System.Text.Json/gen/Model/SourceGenerationSpec.cs b/src/libraries/System.Text.Json/gen/Model/SourceGenerationSpec.cs
deleted file mode 100644 (file)
index b93e1e5..0000000
+++ /dev/null
@@ -1,29 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-
-namespace System.Text.Json.SourceGeneration
-{
-    /// <summary>
-    /// Models all output produced by the source generator
-    /// </summary>
-    /// <remarks>
-    /// Type needs to be cacheable as a Roslyn incremental value so it must be
-    ///
-    /// 1) immutable and
-    /// 2) implement structural (pointwise) equality comparison.
-    ///
-    /// We can get these properties for free provided that we
-    ///
-    /// a) define the type as an immutable C# record and
-    /// b) ensure all nested members are also immutable and implement structural equality.
-    ///
-    /// When adding new members to the type, please ensure that these properties
-    /// are satisfied otherwise we risk breaking incremental caching in the source generator!
-    /// </remarks>
-    public sealed record SourceGenerationSpec
-    {
-        public required ImmutableEquatableArray<ContextGenerationSpec> ContextGenerationSpecs { get; init; }
-
-        public required ImmutableEquatableArray<DiagnosticInfo> Diagnostics { get; init; }
-    }
-}
index affe327..b7aa0a9 100644 (file)
@@ -54,6 +54,7 @@
     <Compile Include="Helpers\RoslynExtensions.cs" />
     <Compile Include="Helpers\SourceWriter.cs" />
     <Compile Include="JsonConstants.cs" />
+    <Compile Include="JsonSourceGenerator.DiagnosticDescriptors.cs" />
     <Compile Include="JsonSourceGenerator.Emitter.cs" />
     <Compile Include="JsonSourceGenerator.Emitter.ExceptionMessages.cs" />
     <Compile Include="JsonSourceGenerator.Parser.cs" />
@@ -65,7 +66,6 @@
     <Compile Include="Model\ParameterGenerationSpec.cs" />
     <Compile Include="Model\PropertyGenerationSpec.cs" />
     <Compile Include="Model\PropertyInitializerGenerationSpec.cs" />
-    <Compile Include="Model\SourceGenerationSpec.cs" />
     <Compile Include="Model\TypeGenerationSpec.cs" />
     <Compile Include="Model\TypeRef.cs" />
   </ItemGroup>
index 37a5e4c..66a55a4 100644 (file)
@@ -20,13 +20,11 @@ namespace System.Text.Json.SourceGeneration.UnitTests
     public record JsonSourceGeneratorResult
     {
         public Compilation NewCompilation { get; set; }
-        public SourceGenerationSpec? SourceGenModel { get; set; }
+        public ImmutableArray<ContextGenerationSpec> ContextGenerationSpecs { get; set; }
         public ImmutableArray<Diagnostic> Diagnostics { get; set; }
 
         public IEnumerable<TypeGenerationSpec> AllGeneratedTypes
-            => SourceGenModel is { } model
-                ? model.ContextGenerationSpecs.SelectMany(ctx => ctx.GeneratedTypes)
-                : Array.Empty<TypeGenerationSpec>();
+            => ContextGenerationSpecs.SelectMany(ctx => ctx.GeneratedTypes);
 
         public void AssertContainsType(string fullyQualifiedName)
             => Assert.Contains(
@@ -118,10 +116,10 @@ namespace System.Text.Json.SourceGeneration.UnitTests
 
         public static JsonSourceGeneratorResult RunJsonSourceGenerator(Compilation compilation)
         {
-            SourceGenerationSpec? generatedSpec = null;
+            var generatedSpecs = ImmutableArray<ContextGenerationSpec>.Empty;
             var generator = new JsonSourceGenerator
             {
-                OnSourceEmitting = spec => generatedSpec = spec
+                OnSourceEmitting = specs => generatedSpecs = specs
             };
 
             CSharpGeneratorDriver driver = CreateJsonSourceGeneratorDriver(generator);
@@ -130,7 +128,7 @@ namespace System.Text.Json.SourceGeneration.UnitTests
             {
                 NewCompilation = outCompilation,
                 Diagnostics = diagnostics,
-                SourceGenModel = generatedSpec,
+                ContextGenerationSpecs = generatedSpecs,
             };
         }
 
index 35081ca..60a3ed9 100644 (file)
@@ -21,17 +21,18 @@ namespace System.Text.Json.SourceGeneration.UnitTests
             JsonSourceGeneratorResult result1 = CompilationHelper.RunJsonSourceGenerator(factory());
             JsonSourceGeneratorResult result2 = CompilationHelper.RunJsonSourceGenerator(factory());
 
-            if (result1.SourceGenModel is null)
-            {
-                Assert.Null(result2.SourceGenModel);
-            }
-            else
+            Assert.Equal(result1.ContextGenerationSpecs.Length, result2.ContextGenerationSpecs.Length);
+
+            for (int i = 0; i < result1.ContextGenerationSpecs.Length; i++)
             {
-                Assert.NotSame(result1.SourceGenModel, result2.SourceGenModel);
-                AssertStructurallyEqual(result1.SourceGenModel, result2.SourceGenModel);
+                ContextGenerationSpec ctx1 = result1.ContextGenerationSpecs[i];
+                ContextGenerationSpec ctx2 = result2.ContextGenerationSpecs[i];
+
+                Assert.NotSame(ctx1, ctx2);
+                AssertStructurallyEqual(ctx1, ctx2);
 
-                Assert.Equal(result1.SourceGenModel, result2.SourceGenModel);
-                Assert.Equal(result1.SourceGenModel.GetHashCode(), result2.SourceGenModel.GetHashCode());
+                Assert.Equal(ctx1, ctx2);
+                Assert.Equal(ctx1.GetHashCode(), ctx2.GetHashCode());
             }
         }
 
@@ -81,11 +82,17 @@ namespace System.Text.Json.SourceGeneration.UnitTests
             Assert.Empty(result1.Diagnostics);
             Assert.Empty(result2.Diagnostics);
 
-            Assert.NotSame(result1.SourceGenModel, result2.SourceGenModel);
-            AssertStructurallyEqual(result1.SourceGenModel, result2.SourceGenModel);
+            Assert.Equal(1, result1.ContextGenerationSpecs.Length);
+            Assert.Equal(1, result2.ContextGenerationSpecs.Length);
 
-            Assert.Equal(result1.SourceGenModel, result2.SourceGenModel);
-            Assert.Equal(result1.SourceGenModel.GetHashCode(), result2.SourceGenModel.GetHashCode());
+            ContextGenerationSpec ctx1 = result1.ContextGenerationSpecs[0];
+            ContextGenerationSpec ctx2 = result2.ContextGenerationSpecs[0];
+
+            Assert.NotSame(ctx1, ctx2);
+            AssertStructurallyEqual(ctx1, ctx2);
+
+            Assert.Equal(ctx1, ctx2);
+            Assert.Equal(ctx1.GetHashCode(), ctx2.GetHashCode());
         }
 
         [Fact]
@@ -128,7 +135,12 @@ namespace System.Text.Json.SourceGeneration.UnitTests
             Assert.Empty(result1.Diagnostics);
             Assert.Empty(result2.Diagnostics);
 
-            Assert.NotEqual(result1.SourceGenModel, result2.SourceGenModel);
+            Assert.Equal(1, result1.ContextGenerationSpecs.Length);
+            Assert.Equal(1, result2.ContextGenerationSpecs.Length);
+
+            ContextGenerationSpec ctx1 = result1.ContextGenerationSpecs[0];
+            ContextGenerationSpec ctx2 = result2.ContextGenerationSpecs[0];
+            Assert.NotEqual(ctx1, ctx2);
         }
 
         [Theory]
@@ -136,7 +148,7 @@ namespace System.Text.Json.SourceGeneration.UnitTests
         public static void SourceGenModelDoesNotEncapsulateSymbolsOrCompilationData(Func<Compilation> factory)
         {
             JsonSourceGeneratorResult result = CompilationHelper.RunJsonSourceGenerator(factory());
-            WalkObjectGraph(result.SourceGenModel);
+            WalkObjectGraph(result.ContextGenerationSpecs);
 
             static void WalkObjectGraph(object obj)
             {
@@ -187,26 +199,50 @@ namespace System.Text.Json.SourceGeneration.UnitTests
 
             driver = driver.RunGenerators(compilation);
             GeneratorRunResult runResult = driver.GetRunResult().Results[0];
-            Assert.Collection(runResult.TrackedSteps[JsonSourceGenerator.SourceGenerationSpecTrackingName],
-                step =>
-                {
-                    Assert.Collection(step.Inputs,
-                        source => Assert.Equal(IncrementalStepRunReason.New, source.Source.Outputs[source.OutputIndex].Reason));
-                    Assert.Collection(step.Outputs,
-                        output => Assert.Equal(IncrementalStepRunReason.New, output.Reason));
-                });
+
+            IncrementalGeneratorRunStep[] runSteps = GetSourceGenRunStep(runResult);
+            if (runSteps != null)
+            {
+                Assert.Collection(runSteps,
+                    step =>
+                    {
+                        Assert.Collection(step.Inputs,
+                            source => Assert.Equal(IncrementalStepRunReason.New, source.Source.Outputs[source.OutputIndex].Reason));
+                        Assert.Collection(step.Outputs,
+                            output => Assert.Equal(IncrementalStepRunReason.New, output.Reason));
+                    });
+            }
 
             // run the same compilation through again, and confirm the output wasn't called
             driver = driver.RunGenerators(compilation);
             runResult = driver.GetRunResult().Results[0];
-            Assert.Collection(runResult.TrackedSteps[JsonSourceGenerator.SourceGenerationSpecTrackingName],
-                step =>
+            IncrementalGeneratorRunStep[] runSteps2 = GetSourceGenRunStep(runResult);
+
+            if (runSteps != null)
+            {
+                Assert.Collection(runSteps2,
+                    step =>
+                    {
+                        Assert.Collection(step.Inputs,
+                            source => Assert.Equal(IncrementalStepRunReason.Cached, source.Source.Outputs[source.OutputIndex].Reason));
+                        Assert.Collection(step.Outputs,
+                            output => Assert.Equal(IncrementalStepRunReason.Cached, output.Reason));
+                    });
+            }
+            else
+            {
+                Assert.Null(runSteps2);
+            }
+
+            static IncrementalGeneratorRunStep[]? GetSourceGenRunStep(GeneratorRunResult runResult)
+            {
+                if (!runResult.TrackedSteps.TryGetValue(JsonSourceGenerator.SourceGenerationSpecTrackingName, out var runSteps))
                 {
-                    Assert.Collection(step.Inputs,
-                        source => Assert.Equal(IncrementalStepRunReason.Cached, source.Source.Outputs[source.OutputIndex].Reason));
-                    Assert.Collection(step.Outputs,
-                        output => Assert.Equal(IncrementalStepRunReason.Cached, output.Reason));
-                });
+                    return null;
+                }
+
+                return runSteps.ToArray();
+            }
         }
 
         [Fact]