diff --git a/Funcky.DiscriminatedUnion.SourceGeneration/Functional.cs b/Funcky.DiscriminatedUnion.SourceGeneration/Functional.cs new file mode 100644 index 0000000..7342824 --- /dev/null +++ b/Funcky.DiscriminatedUnion.SourceGeneration/Functional.cs @@ -0,0 +1,6 @@ +namespace Funcky.DiscriminatedUnion.SourceGeneration; + +internal static class Functional +{ + public static Lazy Lazy(Func func) => new(func); +} diff --git a/Funcky.DiscriminatedUnion.SourceGeneration/Parser.JsonPolymorphic.cs b/Funcky.DiscriminatedUnion.SourceGeneration/Parser.JsonPolymorphic.cs new file mode 100644 index 0000000..2e22109 --- /dev/null +++ b/Funcky.DiscriminatedUnion.SourceGeneration/Parser.JsonPolymorphic.cs @@ -0,0 +1,33 @@ +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; +using static Funcky.DiscriminatedUnion.SourceGeneration.Functional; + +namespace Funcky.DiscriminatedUnion.SourceGeneration; + +internal static partial class Parser +{ + private const string JsonPolymorphicAttributeName = "System.Text.Json.Serialization.JsonPolymorphicAttribute"; + private const string JsonDerivedTypeAttributeName = "System.Text.Json.Serialization.JsonDerivedTypeAttribute"; + + private static Func GenerateJsonDerivedTypeAttribute(INamedTypeSymbol discriminatedUnion) + { + var generateJsonDerivedTypeAttributes = Lazy(() => discriminatedUnion.GetAttributes().Any(IsJsonPolymorphicAttribute)); + var jsonDerivedTypes = Lazy(() => GetJsonDerivedTypes(discriminatedUnion)); + return variant => generateJsonDerivedTypeAttributes.Value && !jsonDerivedTypes.Value.Contains(variant); + } + + private static ImmutableHashSet GetJsonDerivedTypes(INamedTypeSymbol discriminatedUnion) + => discriminatedUnion.GetAttributes() + .Select(GetJsonDerivedType) + .Where(t => t is not null)! + .ToImmutableHashSet(SymbolEqualityComparer.Default); + + private static bool IsJsonPolymorphicAttribute(AttributeData attribute) + => attribute.AttributeClass?.ToDisplayString() is JsonPolymorphicAttributeName or JsonDerivedTypeAttributeName; + + private static INamedTypeSymbol? GetJsonDerivedType(AttributeData attribute) + => attribute.AttributeClass?.ToDisplayString() is JsonDerivedTypeAttributeName + && attribute.ConstructorArguments.First() is { Kind: TypedConstantKind.Type, Value: INamedTypeSymbol value } + ? value + : null; +} diff --git a/Funcky.DiscriminatedUnion.SourceGeneration/Parser.cs b/Funcky.DiscriminatedUnion.SourceGeneration/Parser.cs index e4f467c..f236948 100644 --- a/Funcky.DiscriminatedUnion.SourceGeneration/Parser.cs +++ b/Funcky.DiscriminatedUnion.SourceGeneration/Parser.cs @@ -6,11 +6,8 @@ namespace Funcky.DiscriminatedUnion.SourceGeneration; -internal static class Parser +internal static partial class Parser { - private const string JsonPolymorphicAttributeName = "System.Text.Json.Serialization.JsonPolymorphicAttribute"; - private const string JsonDerivedTypeAttributeName = "System.Text.Json.Serialization.JsonDerivedTypeAttribute"; - public static bool IsSyntaxTargetForGeneration(SyntaxNode node) => node is ClassDeclarationSyntax { AttributeLists.Count: > 0 } or RecordDeclarationSyntax { AttributeLists.Count: > 0 }; @@ -35,7 +32,6 @@ public static bool IsSyntaxTargetForGeneration(SyntaxNode node) var (nonExhaustive, flatten, matchResultType) = ParseAttribute(typeSymbol); var isVariant = flatten ? IsVariantOfDiscriminatedUnionFlattened(typeSymbol, semanticModel) : IsVariantOfDiscriminatedUnion(typeSymbol, semanticModel); - var generateJsonDerivedTypeAttributes = typeSymbol.GetAttributes().Any(IsJsonPolymorphicAttribute); return new DiscriminatedUnion( Type: typeDeclaration, @@ -44,7 +40,7 @@ public static bool IsSyntaxTargetForGeneration(SyntaxNode node) MatchResultTypeName: matchResultType ?? "TResult", MethodVisibility: nonExhaustive ? "internal" : "public", Variants: GetVariantTypeDeclarations(typeDeclaration, isVariant) - .Select(GetDiscriminatedUnionVariant(typeDeclaration, semanticModel, _ => generateJsonDerivedTypeAttributes)) + .Select(GetDiscriminatedUnionVariant(typeDeclaration, semanticModel, GenerateJsonDerivedTypeAttribute(typeSymbol))) .ToList()); } @@ -123,9 +119,6 @@ private static Func IsDiscriminatedUnionAttribute(Generat => context.SemanticModel.GetSymbolInfo(attribute, cancellationToken).Symbol is IMethodSymbol attributeSymbol && attributeSymbol.ContainingType.ToDisplayString() == AttributeFullName; - private static bool IsJsonPolymorphicAttribute(AttributeData attribute) - => attribute.AttributeClass?.ToDisplayString() is JsonPolymorphicAttributeName or JsonDerivedTypeAttributeName; - private sealed record DiscriminatedUnionAttributeData(bool NonExhaustive, bool Flatten, string? MatchResultType); private sealed class VariantCollectingVisitor : CSharpSyntaxWalker diff --git a/Funcky.DiscriminatedUnion.Test/SourceGeneratorTest.GeneratesExpectedSourceCode_sourceFileName=JsonPolymorphic.01.verified.cs b/Funcky.DiscriminatedUnion.Test/SourceGeneratorTest.GeneratesExpectedSourceCode_sourceFileName=JsonPolymorphic.01.verified.cs index 8fa0555..77018a6 100644 --- a/Funcky.DiscriminatedUnion.Test/SourceGeneratorTest.GeneratesExpectedSourceCode_sourceFileName=JsonPolymorphic.01.verified.cs +++ b/Funcky.DiscriminatedUnion.Test/SourceGeneratorTest.GeneratesExpectedSourceCode_sourceFileName=JsonPolymorphic.01.verified.cs @@ -65,3 +65,40 @@ partial record Error } } } + +[global::System.Text.Json.Serialization.JsonDerivedType(typeof(global::Shape.EquilateralTriangle), "EquilateralTriangle")] +partial record Shape +{ + [global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.0.0.0")] + public abstract TResult Match(global::System.Func rectangle, global::System.Func circle, global::System.Func equilateralTriangle); + + [global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.0.0.0")] + public abstract void Switch(global::System.Action rectangle, global::System.Action circle, global::System.Action equilateralTriangle); + + partial record Rectangle + { + [global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.0.0.0")] + public override TResult Match(global::System.Func rectangle, global::System.Func circle, global::System.Func equilateralTriangle) => rectangle(this); + + [global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.0.0.0")] + public override void Switch(global::System.Action rectangle, global::System.Action circle, global::System.Action equilateralTriangle) => rectangle(this); + } + + partial record Circle + { + [global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.0.0.0")] + public override TResult Match(global::System.Func rectangle, global::System.Func circle, global::System.Func equilateralTriangle) => circle(this); + + [global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.0.0.0")] + public override void Switch(global::System.Action rectangle, global::System.Action circle, global::System.Action equilateralTriangle) => circle(this); + } + + partial record EquilateralTriangle + { + [global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.0.0.0")] + public override TResult Match(global::System.Func rectangle, global::System.Func circle, global::System.Func equilateralTriangle) => equilateralTriangle(this); + + [global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.0.0.0")] + public override void Switch(global::System.Action rectangle, global::System.Action circle, global::System.Action equilateralTriangle) => equilateralTriangle(this); + } +} diff --git a/Funcky.DiscriminatedUnion.Test/Sources/JsonPolymorphic.cs b/Funcky.DiscriminatedUnion.Test/Sources/JsonPolymorphic.cs index dd38d77..b4c3406 100644 --- a/Funcky.DiscriminatedUnion.Test/Sources/JsonPolymorphic.cs +++ b/Funcky.DiscriminatedUnion.Test/Sources/JsonPolymorphic.cs @@ -21,3 +21,15 @@ public sealed partial record Error(string Message) : Result; } } } + +[Funcky.DiscriminatedUnion] +[System.Text.Json.Serialization.JsonDerivedType(typeof(Rectangle), typeDiscriminator: 1)] +[System.Text.Json.Serialization.JsonDerivedType(typeof(Circle), typeDiscriminator: "⏺")] +public abstract partial record Shape +{ + public sealed partial record Rectangle(double Width, double Length) : Shape; + + public sealed partial record Circle(double Radius) : Shape; + + public partial record EquilateralTriangle(double SideLength) : Shape; +}