Skip to content

Commit

Permalink
Only emit JsonDerivedType attribute when not already present
Browse files Browse the repository at this point in the history
  • Loading branch information
bash committed Nov 19, 2022
1 parent d290d2c commit 10d132a
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 9 deletions.
6 changes: 6 additions & 0 deletions Funcky.DiscriminatedUnion.SourceGeneration/Functional.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
namespace Funcky.DiscriminatedUnion.SourceGeneration;

internal static class Functional
{
public static Lazy<T> Lazy<T>(Func<T> func) => new(func);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
using System.Collections.Immutable;
using Microsoft.CodeAnalysis;
using static Funcky.DiscriminatedUnion.SourceGeneration.Functional;

namespace Funcky.DiscriminatedUnion.SourceGeneration;

internal static partial class Parser
{
private const string JsonPolymorphicAttributeName = "System.Text.Json.Serialization.JsonPolymorphicAttribute";
private const string JsonDerivedTypeAttributeName = "System.Text.Json.Serialization.JsonDerivedTypeAttribute";

private static Func<INamedTypeSymbol, bool> GenerateJsonDerivedTypeAttribute(INamedTypeSymbol discriminatedUnion)
{
var generateJsonDerivedTypeAttributes = Lazy(() => discriminatedUnion.GetAttributes().Any(IsJsonPolymorphicAttribute));
var jsonDerivedTypes = Lazy(() => GetJsonDerivedTypes(discriminatedUnion));
return variant => generateJsonDerivedTypeAttributes.Value && !jsonDerivedTypes.Value.Contains(variant);
}

private static ImmutableHashSet<INamedTypeSymbol> GetJsonDerivedTypes(INamedTypeSymbol discriminatedUnion)
=> discriminatedUnion.GetAttributes()
.Select(GetJsonDerivedType)
.Where(t => t is not null)!
.ToImmutableHashSet<INamedTypeSymbol>(SymbolEqualityComparer.Default);

private static bool IsJsonPolymorphicAttribute(AttributeData attribute)
=> attribute.AttributeClass?.ToDisplayString() is JsonPolymorphicAttributeName or JsonDerivedTypeAttributeName;

private static INamedTypeSymbol? GetJsonDerivedType(AttributeData attribute)
=> attribute.AttributeClass?.ToDisplayString() is JsonDerivedTypeAttributeName
&& attribute.ConstructorArguments.First() is { Kind: TypedConstantKind.Type, Value: INamedTypeSymbol value }
? value
: null;
}
11 changes: 2 additions & 9 deletions Funcky.DiscriminatedUnion.SourceGeneration/Parser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,8 @@

namespace Funcky.DiscriminatedUnion.SourceGeneration;

internal static class Parser
internal static partial class Parser
{
private const string JsonPolymorphicAttributeName = "System.Text.Json.Serialization.JsonPolymorphicAttribute";
private const string JsonDerivedTypeAttributeName = "System.Text.Json.Serialization.JsonDerivedTypeAttribute";

public static bool IsSyntaxTargetForGeneration(SyntaxNode node)
=> node is ClassDeclarationSyntax { AttributeLists.Count: > 0 }
or RecordDeclarationSyntax { AttributeLists.Count: > 0 };
Expand All @@ -35,7 +32,6 @@ public static bool IsSyntaxTargetForGeneration(SyntaxNode node)

var (nonExhaustive, flatten, matchResultType) = ParseAttribute(typeSymbol);
var isVariant = flatten ? IsVariantOfDiscriminatedUnionFlattened(typeSymbol, semanticModel) : IsVariantOfDiscriminatedUnion(typeSymbol, semanticModel);
var generateJsonDerivedTypeAttributes = typeSymbol.GetAttributes().Any(IsJsonPolymorphicAttribute);

return new DiscriminatedUnion(
Type: typeDeclaration,
Expand All @@ -44,7 +40,7 @@ public static bool IsSyntaxTargetForGeneration(SyntaxNode node)
MatchResultTypeName: matchResultType ?? "TResult",
MethodVisibility: nonExhaustive ? "internal" : "public",
Variants: GetVariantTypeDeclarations(typeDeclaration, isVariant)
.Select(GetDiscriminatedUnionVariant(typeDeclaration, semanticModel, _ => generateJsonDerivedTypeAttributes))
.Select(GetDiscriminatedUnionVariant(typeDeclaration, semanticModel, GenerateJsonDerivedTypeAttribute(typeSymbol)))
.ToList());
}

Expand Down Expand Up @@ -123,9 +119,6 @@ private static Func<AttributeSyntax, bool> IsDiscriminatedUnionAttribute(Generat
=> context.SemanticModel.GetSymbolInfo(attribute, cancellationToken).Symbol is IMethodSymbol attributeSymbol
&& attributeSymbol.ContainingType.ToDisplayString() == AttributeFullName;

private static bool IsJsonPolymorphicAttribute(AttributeData attribute)
=> attribute.AttributeClass?.ToDisplayString() is JsonPolymorphicAttributeName or JsonDerivedTypeAttributeName;

private sealed record DiscriminatedUnionAttributeData(bool NonExhaustive, bool Flatten, string? MatchResultType);

private sealed class VariantCollectingVisitor : CSharpSyntaxWalker
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,40 @@ partial record Error
}
}
}

[global::System.Text.Json.Serialization.JsonDerivedType(typeof(global::Shape.EquilateralTriangle), "EquilateralTriangle")]
partial record Shape
{
[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.0.0.0")]
public abstract TResult Match<TResult>(global::System.Func<Rectangle, TResult> rectangle, global::System.Func<Circle, TResult> circle, global::System.Func<EquilateralTriangle, TResult> equilateralTriangle);

[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.0.0.0")]
public abstract void Switch(global::System.Action<Rectangle> rectangle, global::System.Action<Circle> circle, global::System.Action<EquilateralTriangle> equilateralTriangle);

partial record Rectangle
{
[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.0.0.0")]
public override TResult Match<TResult>(global::System.Func<Rectangle, TResult> rectangle, global::System.Func<Circle, TResult> circle, global::System.Func<EquilateralTriangle, TResult> equilateralTriangle) => rectangle(this);

[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.0.0.0")]
public override void Switch(global::System.Action<Rectangle> rectangle, global::System.Action<Circle> circle, global::System.Action<EquilateralTriangle> equilateralTriangle) => rectangle(this);
}

partial record Circle
{
[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.0.0.0")]
public override TResult Match<TResult>(global::System.Func<Rectangle, TResult> rectangle, global::System.Func<Circle, TResult> circle, global::System.Func<EquilateralTriangle, TResult> equilateralTriangle) => circle(this);

[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.0.0.0")]
public override void Switch(global::System.Action<Rectangle> rectangle, global::System.Action<Circle> circle, global::System.Action<EquilateralTriangle> equilateralTriangle) => circle(this);
}

partial record EquilateralTriangle
{
[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.0.0.0")]
public override TResult Match<TResult>(global::System.Func<Rectangle, TResult> rectangle, global::System.Func<Circle, TResult> circle, global::System.Func<EquilateralTriangle, TResult> equilateralTriangle) => equilateralTriangle(this);

[global::System.CodeDom.Compiler.GeneratedCode("Funcky.DiscriminatedUnion.SourceGeneration", "1.0.0.0")]
public override void Switch(global::System.Action<Rectangle> rectangle, global::System.Action<Circle> circle, global::System.Action<EquilateralTriangle> equilateralTriangle) => equilateralTriangle(this);
}
}
12 changes: 12 additions & 0 deletions Funcky.DiscriminatedUnion.Test/Sources/JsonPolymorphic.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,15 @@ public sealed partial record Error(string Message) : Result<T>;
}
}
}

[Funcky.DiscriminatedUnion]
[System.Text.Json.Serialization.JsonDerivedType(typeof(Rectangle), typeDiscriminator: 1)]
[System.Text.Json.Serialization.JsonDerivedType(typeof(Circle), typeDiscriminator: "")]
public abstract partial record Shape
{
public sealed partial record Rectangle(double Width, double Length) : Shape;

public sealed partial record Circle(double Radius) : Shape;

public partial record EquilateralTriangle(double SideLength) : Shape;
}

0 comments on commit 10d132a

Please sign in to comment.