Skip to content

Commit

Permalink
Merge pull request #11 from polyadic/json-polymorphic
Browse files Browse the repository at this point in the history
Generate `[JsonDerivedType]` attributes
  • Loading branch information
bash authored Nov 22, 2022
2 parents 3c41a93 + 10d132a commit 039f3bb
Show file tree
Hide file tree
Showing 14 changed files with 281 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,7 @@ internal sealed record DiscriminatedUnionVariant(
TypeDeclarationSyntax Type,
IReadOnlyList<TypeDeclarationSyntax> ParentTypes,
string ParameterName,
string LocalTypeName);
string LocalTypeName,
string TypeOfTypeName,
string JsonDerivedTypeDiscriminator,
bool GenerateJsonDerivedTypeAttribute);
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ private static void AddSource(SourceProductionContext context, ImmutableArray<st
{
if (code.Any())
{
var combinedCode = GeneratedFileHeadersSource + string.Join(Environment.NewLine, code);
var combinedCode = $"{GeneratedFileHeadersSource}{Environment.NewLine}{Environment.NewLine}" +
$"{string.Join(Environment.NewLine, code)}";
context.AddSource("DiscriminatedUnionGenerator.g.cs", combinedCode);
}
}
Expand Down
17 changes: 17 additions & 0 deletions Funcky.DiscriminatedUnion.SourceGeneration/Emitter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ public static string EmitDiscriminatedUnion(DiscriminatedUnion discriminatedUnio

WriteParentTypes(writer, discriminatedUnion.ParentTypes);

WriteJsonDerivedTypeAttributes(writer, discriminatedUnion);
writer.WriteLine(FormatPartialTypeDeclaration(discriminatedUnion.Type));
writer.OpenScope();

Expand Down Expand Up @@ -77,6 +78,22 @@ private static void WriteGeneratedMethod(IndentedTextWriter writer, string metho
writer.WriteLine(method);
}

private static void WriteJsonDerivedTypeAttributes(IndentedTextWriter writer, DiscriminatedUnion discriminatedUnion)
{
foreach (var variant in discriminatedUnion.Variants)
{
WriteJsonDerivedTypeAttribute(writer, variant);
}
}

private static void WriteJsonDerivedTypeAttribute(IndentedTextWriter writer, DiscriminatedUnionVariant variant)
{
if (variant.GenerateJsonDerivedTypeAttribute)
{
writer.WriteLine($"[global::System.Text.Json.Serialization.JsonDerivedType(typeof({variant.TypeOfTypeName}), {SyntaxFactory.Literal(variant.JsonDerivedTypeDiscriminator)})]");
}
}

private static string FormatMatchMethodDeclaration(string genericTypeName, IEnumerable<DiscriminatedUnionVariant> variants)
=> $"{genericTypeName} Match<{genericTypeName}>({string.Join(", ", variants.Select(variant => $"global::System.Func<{variant.LocalTypeName}, {genericTypeName}> {FormatIdentifier(variant.ParameterName)}"))})";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<TargetFramework>netstandard2.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<LangVersion>10.0</LangVersion>
<LangVersion>11.0</LangVersion>
</PropertyGroup>
<PropertyGroup Label="NuGet Metadata">
<Version>1.0.0</Version>
Expand Down
6 changes: 6 additions & 0 deletions Funcky.DiscriminatedUnion.SourceGeneration/Functional.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
namespace Funcky.DiscriminatedUnion.SourceGeneration;

internal static class Functional
{
public static Lazy<T> Lazy<T>(Func<T> func) => new(func);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
using System.Collections.Immutable;
using Microsoft.CodeAnalysis;
using static Funcky.DiscriminatedUnion.SourceGeneration.Functional;

namespace Funcky.DiscriminatedUnion.SourceGeneration;

internal static partial class Parser
{
private const string JsonPolymorphicAttributeName = "System.Text.Json.Serialization.JsonPolymorphicAttribute";
private const string JsonDerivedTypeAttributeName = "System.Text.Json.Serialization.JsonDerivedTypeAttribute";

private static Func<INamedTypeSymbol, bool> GenerateJsonDerivedTypeAttribute(INamedTypeSymbol discriminatedUnion)
{
var generateJsonDerivedTypeAttributes = Lazy(() => discriminatedUnion.GetAttributes().Any(IsJsonPolymorphicAttribute));
var jsonDerivedTypes = Lazy(() => GetJsonDerivedTypes(discriminatedUnion));
return variant => generateJsonDerivedTypeAttributes.Value && !jsonDerivedTypes.Value.Contains(variant);
}

private static ImmutableHashSet<INamedTypeSymbol> GetJsonDerivedTypes(INamedTypeSymbol discriminatedUnion)
=> discriminatedUnion.GetAttributes()
.Select(GetJsonDerivedType)
.Where(t => t is not null)!
.ToImmutableHashSet<INamedTypeSymbol>(SymbolEqualityComparer.Default);

private static bool IsJsonPolymorphicAttribute(AttributeData attribute)
=> attribute.AttributeClass?.ToDisplayString() is JsonPolymorphicAttributeName or JsonDerivedTypeAttributeName;

private static INamedTypeSymbol? GetJsonDerivedType(AttributeData attribute)
=> attribute.AttributeClass?.ToDisplayString() is JsonDerivedTypeAttributeName
&& attribute.ConstructorArguments.First() is { Kind: TypedConstantKind.Type, Value: INamedTypeSymbol value }
? value
: null;
}
42 changes: 28 additions & 14 deletions Funcky.DiscriminatedUnion.SourceGeneration/Parser.cs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
using System.Collections.Immutable;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using static Funcky.DiscriminatedUnion.SourceGeneration.SourceCodeSnippets;

namespace Funcky.DiscriminatedUnion.SourceGeneration;

internal static class Parser
internal static partial class Parser
{
public static bool IsSyntaxTargetForGeneration(SyntaxNode node)
=> node is ClassDeclarationSyntax { AttributeLists: { Count: > 0 } }
or RecordDeclarationSyntax { AttributeLists: { Count: > 0 } };
=> node is ClassDeclarationSyntax { AttributeLists.Count: > 0 }
or RecordDeclarationSyntax { AttributeLists.Count: > 0 };

public static TypeDeclarationSyntax? GetSemanticTargetForGeneration(GeneratorSyntaxContext context, CancellationToken cancellationToken)
{
Expand Down Expand Up @@ -39,37 +40,43 @@ public static bool IsSyntaxTargetForGeneration(SyntaxNode node)
MatchResultTypeName: matchResultType ?? "TResult",
MethodVisibility: nonExhaustive ? "internal" : "public",
Variants: GetVariantTypeDeclarations(typeDeclaration, isVariant)
.Select(GetDiscriminatedUnionVariant(typeDeclaration, semanticModel))
.Select(GetDiscriminatedUnionVariant(typeDeclaration, semanticModel, GenerateJsonDerivedTypeAttribute(typeSymbol)))
.ToList());
}

private static AttributeData ParseAttribute(ITypeSymbol type)
private static DiscriminatedUnionAttributeData ParseAttribute(ITypeSymbol type)
{
var attribute = type.GetAttributes().Single(a => a.AttributeClass?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat.WithGlobalNamespaceStyle(SymbolDisplayGlobalNamespaceStyle.Omitted)) == AttributeFullName);
var nonExhaustive = attribute.GetNamedArgumentOrDefault<bool>(AttributeProperties.NonExhaustive);
var flatten = attribute.GetNamedArgumentOrDefault<bool>(AttributeProperties.Flatten);
var matchResultType = attribute.GetNamedArgumentOrDefault<string>(AttributeProperties.MatchResultTypeName);
return new AttributeData(nonExhaustive, flatten, matchResultType);
return new DiscriminatedUnionAttributeData(nonExhaustive, flatten, matchResultType);
}

private static string? FormatNamespace(INamedTypeSymbol typeSymbol)
=> typeSymbol.ContainingNamespace?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat.WithGlobalNamespaceStyle(SymbolDisplayGlobalNamespaceStyle.Omitted));

private static Func<TypeDeclarationSyntax, DiscriminatedUnionVariant> GetDiscriminatedUnionVariant(TypeDeclarationSyntax discrimatedUnionTypeDeclaration, SemanticModel semanticModel)
private static Func<TypeDeclarationSyntax, DiscriminatedUnionVariant> GetDiscriminatedUnionVariant(
TypeDeclarationSyntax discriminatedUnionTypeDeclaration,
SemanticModel semanticModel,
Func<INamedTypeSymbol, bool> generateJsonDerivedTypeAttribute)
=> typeDeclaration =>
{
var symbol = semanticModel.GetDeclaredSymbol(typeDeclaration)!;
return new DiscriminatedUnionVariant(
typeDeclaration,
ParentTypes: typeDeclaration.Ancestors().OfType<TypeDeclarationSyntax>().TakeWhile(t => t != discrimatedUnionTypeDeclaration).ToList(),
ParentTypes: typeDeclaration.Ancestors().OfType<TypeDeclarationSyntax>().TakeWhile(t => t != discriminatedUnionTypeDeclaration).ToList(),
ParameterName: FormatParameterName(symbol),
LocalTypeName: symbol.ToMinimalDisplayString(semanticModel, NullableFlowState.NotNull, discrimatedUnionTypeDeclaration.SpanStart));
LocalTypeName: symbol.ToMinimalDisplayString(semanticModel, NullableFlowState.NotNull, discriminatedUnionTypeDeclaration.SpanStart),
TypeOfTypeName: ToTypeNameSuitableForTypeOf(symbol),
JsonDerivedTypeDiscriminator: symbol.Name,
GenerateJsonDerivedTypeAttribute: generateJsonDerivedTypeAttribute(symbol));
};

private static IEnumerable<TypeDeclarationSyntax> GetVariantTypeDeclarations(TypeDeclarationSyntax discrimatedUnionTypeDeclaration, Func<TypeDeclarationSyntax, bool> isVariant)
private static IEnumerable<TypeDeclarationSyntax> GetVariantTypeDeclarations(TypeDeclarationSyntax discriminatedUnionTypeDeclaration, Func<TypeDeclarationSyntax, bool> isVariant)
{
var visitor = new VariantCollectingVisitor(isVariant);
discrimatedUnionTypeDeclaration.Accept(visitor);
discriminatedUnionTypeDeclaration.Accept(visitor);
return visitor.Variants;
}

Expand All @@ -78,8 +85,8 @@ private static Func<TypeDeclarationSyntax, bool> IsVariantOfDiscriminatedUnion(I
&& SymbolEqualityComparer.Default.Equals(symbol.BaseType, discriminatedUnionType);

private static Func<TypeDeclarationSyntax, bool> IsVariantOfDiscriminatedUnionFlattened(ITypeSymbol discriminatedUnionType, SemanticModel semanticModel)
=> node => semanticModel.GetDeclaredSymbol(node) is ITypeSymbol symbol
&& !symbol.IsAbstract
=> node
=> semanticModel.GetDeclaredSymbol(node) is ITypeSymbol { IsAbstract: false } symbol
&& GetBaseTypes(symbol).Any(t => SymbolEqualityComparer.Default.Equals(t, discriminatedUnionType));

private static IEnumerable<ITypeSymbol> GetBaseTypes(ITypeSymbol symbol)
Expand All @@ -97,6 +104,13 @@ private static string FormatParameterName(ITypeSymbol variant)

private static string LowerCaseFirst(string input) => char.ToLowerInvariant(input.First()) + input.Substring(1);

private static string ToTypeNameSuitableForTypeOf(INamedTypeSymbol type)
=> type
.ToDisplayParts(SymbolDisplayFormat.FullyQualifiedFormat)
.Where(part => part.Kind != SymbolDisplayPartKind.TypeParameterName)
.ToImmutableArray()
.ToDisplayString();

private static Func<AttributeListSyntax, bool> HasDiscriminatedUnionAttribute(GeneratorSyntaxContext context, CancellationToken cancellationToken)
=> attributeList => attributeList.Attributes.Any(IsDiscriminatedUnionAttribute(context, cancellationToken));

Expand All @@ -105,7 +119,7 @@ private static Func<AttributeSyntax, bool> IsDiscriminatedUnionAttribute(Generat
=> context.SemanticModel.GetSymbolInfo(attribute, cancellationToken).Symbol is IMethodSymbol attributeSymbol
&& attributeSymbol.ContainingType.ToDisplayString() == AttributeFullName;

private sealed record AttributeData(bool NonExhaustive, bool Flatten, string? MatchResultType);
private sealed record DiscriminatedUnionAttributeData(bool NonExhaustive, bool Flatten, string? MatchResultType);

private sealed class VariantCollectingVisitor : CSharpSyntaxWalker
{
Expand Down
43 changes: 25 additions & 18 deletions Funcky.DiscriminatedUnion.SourceGeneration/SourceCodeSnippets.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,37 @@ internal static class SourceCodeSnippets
{
public const string AttributeFullName = "Funcky.DiscriminatedUnionAttribute";

// language=c#
public const string GeneratedFileHeadersSource =
"// <auto-generated/>\n" +
"#nullable enable\n\n";
"#nullable enable";

// language=c#
public const string DiscriminatedUnionAttributeSource =
$"{GeneratedFileHeadersSource}" +
$"namespace Funcky\n" +
$"{{\n" +
$" [global::System.Diagnostics.Conditional(\"Funcky_DiscriminatedUnion\")]\n" +
$" [global::System.AttributeUsage(global::System.AttributeTargets.Class)]\n" +
$" internal sealed class DiscriminatedUnionAttribute : global::System.Attribute\n" +
$" {{\n" +
$" /// <summary>Allow only consumers in the same assembly to use the exhaustive <c>Match</c> and <c>Switch</c> methods.</summary>\n" +
$" public bool {AttributeProperties.NonExhaustive} {{ get; set; }}\n" +
$"\n" +
$" /// <summary>Generates exhaustive <c>Match</c> and <c>Switch</c> methods for the entire type hierarchy.</summary>\n" +
$" public bool {AttributeProperties.Flatten} {{ get; set; }}\n" +
$"\n" +
$" /// <summary>Customized the generic type name used for the result in the generated <c>Match</c> methods. Defaults to <c>TResult</c>.</summary>\n" +
$" public string? {AttributeProperties.MatchResultTypeName} {{ get; set; }}\n" +
$" }}\n" +
$"}}\n";
$$"""
{{GeneratedFileHeadersSource}}

namespace Funcky
{
[global::System.Diagnostics.Conditional("Funcky_DiscriminatedUnion")]
[global::System.AttributeUsage(global::System.AttributeTargets.Class)]
internal sealed class DiscriminatedUnionAttribute : global::System.Attribute
{
/// <summary>Allow only consumers in the same assembly to use the exhaustive <c>Match</c> and <c>Switch</c> methods.</summary>
public bool {{AttributeProperties.NonExhaustive}} { get; set; }

/// <summary>Generates exhaustive <c>Match</c> and <c>Switch</c> methods for the entire type hierarchy.</summary>
public bool {{AttributeProperties.Flatten}} { get; set; }

/// <summary>Customized the generic type name used for the result in the generated <c>Match</c> methods. Defaults to <c>TResult</c>.</summary>
public string? {{AttributeProperties.MatchResultTypeName}} { get; set; }
}
}

""";

private static readonly AssemblyName GeneratorAssemblyName = typeof(DiscriminatedUnionGenerator).Assembly.GetName();

public static readonly string GeneratedCodeAttributeSource = $"[global::System.CodeDom.Compiler.GeneratedCode(" +
$"{Literal(GeneratorAssemblyName.Name)}, " +
$"{Literal(GeneratorAssemblyName.Version.ToString())})]";
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net6.0</TargetFramework>
<TargetFramework>net7.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<LangVersion>10.0</LangVersion>
<LangVersion>11.0</LangVersion>
<IsPackable>false</IsPackable>
</PropertyGroup>
<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//HintName: DiscriminatedUnionAttribute.g.cs
// <auto-generated/>
#nullable enable

namespace Funcky
{
[global::System.Diagnostics.Conditional("Funcky_DiscriminatedUnion")]
[global::System.AttributeUsage(global::System.AttributeTargets.Class)]
internal sealed class DiscriminatedUnionAttribute : global::System.Attribute
{
/// <summary>Allow only consumers in the same assembly to use the exhaustive <c>Match</c> and <c>Switch</c> methods.</summary>
public bool NonExhaustive { get; set; }

/// <summary>Generates exhaustive <c>Match</c> and <c>Switch</c> methods for the entire type hierarchy.</summary>
public bool Flatten { get; set; }

/// <summary>Customized the generic type name used for the result in the generated <c>Match</c> methods. Defaults to <c>TResult</c>.</summary>
public string? MatchResultTypeName { get; set; }
}
}
Loading

0 comments on commit 039f3bb

Please sign in to comment.