Skip to content

Commit

Permalink
Merge pull request #38 from sveinungf/dev/srcgen-caching
Browse files Browse the repository at this point in the history
Source generator - Cachable pipeline
  • Loading branch information
sveinungf authored Feb 18, 2024
2 parents a9c4718 + 093ee7c commit 5b39f72
Show file tree
Hide file tree
Showing 29 changed files with 1,111 additions and 474 deletions.
192 changes: 186 additions & 6 deletions SpreadCheetah.SourceGenerator.SnapshotTest/Helpers/TestHelper.cs
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using SpreadCheetah.SourceGeneration;
using System.Collections;
using System.Collections.Immutable;
using System.Reflection;

namespace SpreadCheetah.SourceGenerator.SnapshotTest.Helpers;

internal static class TestHelper
{
public static SettingsTask CompileAndVerify<T>(string source, params object?[] parameters) where T : IIncrementalGenerator, new()
private static PortableExecutableReference[] GetAssemblyReferences()
{
var syntaxTree = CSharpSyntaxTree.ParseText(source);

var dotNetAssemblyPath = Path.GetDirectoryName(typeof(object).Assembly.Location) ?? throw new InvalidOperationException();

var references = new[]
{
return
[
MetadataReference.CreateFromFile(Path.Combine(dotNetAssemblyPath, "mscorlib.dll")),
MetadataReference.CreateFromFile(Path.Combine(dotNetAssemblyPath, "netstandard.dll")),
MetadataReference.CreateFromFile(Path.Combine(dotNetAssemblyPath, "System.dll")),
Expand All @@ -22,8 +23,13 @@ internal static class TestHelper
MetadataReference.CreateFromFile(Path.Combine(dotNetAssemblyPath, "System.Runtime.dll")),
MetadataReference.CreateFromFile(typeof(WorksheetRowAttribute).Assembly.Location),
MetadataReference.CreateFromFile(typeof(TestHelper).Assembly.Location)
};
];
}

public static SettingsTask CompileAndVerify<T>(string source, bool replaceEscapedLineEndings = false, params object?[] parameters) where T : IIncrementalGenerator, new()
{
var syntaxTree = CSharpSyntaxTree.ParseText(source);
var references = GetAssemblyReferences();
var compilation = CSharpCompilation.Create("Tests", [syntaxTree], references);

var generator = new T();
Expand All @@ -35,10 +41,184 @@ internal static class TestHelper
var settings = new VerifySettings();
settings.UseDirectory("../Snapshots");

if (replaceEscapedLineEndings)
settings.ScrubLinesWithReplace(x => x.Replace("\\r\\n", "\\n", StringComparison.Ordinal));

var task = Verify(target, settings);

return parameters.Length > 0
? task.UseParameters(parameters)
: task;
}

/// <summary>
/// Based on the implementation from:
/// https://andrewlock.net/creating-a-source-generator-part-10-testing-your-incremental-generator-pipeline-outputs-are-cacheable/
/// </summary>
public static (ImmutableArray<Diagnostic> Diagnostics, string[] Output) GetGeneratedTrees<T>(
string source,
string[] trackingStages,
bool assertOutputs = true)
where T : IIncrementalGenerator, new()
{
var syntaxTree = CSharpSyntaxTree.ParseText(source);
var references = GetAssemblyReferences();
var options = new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary);
var compilation = CSharpCompilation.Create("SpreadCheetah.Generated", [syntaxTree], references, options);

// Run the generator, get the results, and assert cacheability if applicable
var runResult = RunGeneratorAndAssertOutput<T>(compilation, trackingStages, assertOutputs);

// Return the generator diagnostics and generated sources
return (runResult.Diagnostics, runResult.GeneratedTrees.Select(x => x.ToString()).ToArray());
}

private static GeneratorDriverRunResult RunGeneratorAndAssertOutput<T>(CSharpCompilation compilation, string[] trackingNames, bool assertOutput = true)
where T : IIncrementalGenerator, new()
{
var generator = new T().AsSourceGenerator();

// ⚠ Tell the driver to track all the incremental generator outputs
// without this, you'll have no tracked outputs!
var opts = new GeneratorDriverOptions(
disabledOutputs: IncrementalGeneratorOutputKind.None,
trackIncrementalGeneratorSteps: true);

GeneratorDriver driver = CSharpGeneratorDriver.Create([generator], driverOptions: opts);

// Create a clone of the compilation that we will use later
var clone = compilation.Clone();

// Do the initial run
// Note that we store the returned driver value, as it contains cached previous outputs
driver = driver.RunGenerators(compilation);
GeneratorDriverRunResult runResult = driver.GetRunResult();

if (assertOutput)
{
// Run again, using the same driver, with a clone of the compilation
var runResult2 = driver.RunGenerators(clone).GetRunResult();

// Compare all the tracked outputs, throw if there's a failure
AssertRunsEqual(runResult, runResult2, trackingNames);

// verify the second run only generated cached source outputs
var outputs = runResult2
.Results[0]
.TrackedOutputSteps
.SelectMany(x => x.Value) // step executions
.SelectMany(x => x.Outputs); // execution results

var output = Assert.Single(outputs);
Assert.Equal(IncrementalStepRunReason.Cached, output.Reason);
}

return runResult;
}

private static void AssertRunsEqual(
GeneratorDriverRunResult runResult1,
GeneratorDriverRunResult runResult2,
string[] trackingNames)
{
// We're given all the tracking names, but not all the
// stages will necessarily execute, so extract all the
// output steps, and filter to ones we know about
var trackedSteps1 = GetTrackedSteps(runResult1, trackingNames);
var trackedSteps2 = GetTrackedSteps(runResult2, trackingNames);

// Both runs should have the same tracked steps
var trackedSteps1Keys = trackedSteps1.Keys.ToHashSet(StringComparer.Ordinal);
Assert.True(trackedSteps1Keys.SetEquals(trackedSteps2.Keys));

// Get the IncrementalGeneratorRunStep collection for each run
foreach (var (trackingName, runSteps1) in trackedSteps1)
{
// Assert that both runs produced the same outputs
var runSteps2 = trackedSteps2[trackingName];
AssertEqual(runSteps1, runSteps2, trackingName);
}

static Dictionary<string, ImmutableArray<IncrementalGeneratorRunStep>> GetTrackedSteps(
GeneratorDriverRunResult runResult, string[] trackingNames)
{
return runResult
.Results[0] // We're only running a single generator, so this is safe
.TrackedSteps // Get the pipeline outputs
.Where(step => trackingNames.Contains(step.Key, StringComparer.Ordinal))
.ToDictionary(x => x.Key, x => x.Value, StringComparer.Ordinal);
}
}

private static void AssertEqual(
ImmutableArray<IncrementalGeneratorRunStep> runSteps1,
ImmutableArray<IncrementalGeneratorRunStep> runSteps2,
string stepName)
{
Assert.Equal(runSteps1.Length, runSteps2.Length);

foreach (var (runStep1, runStep2) in runSteps1.Zip(runSteps2))
{
// The outputs should be equal between different runs
var outputs1 = runStep1.Outputs.Select(x => x.Value);
var outputs2 = runStep2.Outputs.Select(x => x.Value);

Assert.True(outputs1.SequenceEqual(outputs2), $"Step {stepName} did not produce cacheable outputs");

// Therefore, on the second run the results should always be cached or unchanged!
// - Unchanged is when the _input_ has changed, but the output hasn't
// - Cached is when the the input has not changed, so the cached output is used
Assert.All(runStep2.Outputs, x => Assert.True(x.Reason is IncrementalStepRunReason.Cached or IncrementalStepRunReason.Unchanged));

// Make sure we're not using anything we shouldn't
AssertObjectGraph(runStep1);
}
}

private static void AssertObjectGraph(IncrementalGeneratorRunStep runStep)
{
var visited = new HashSet<object>();

// Check all of the outputs - probably overkill, but why not
foreach (var (obj, _) in runStep.Outputs)
{
Visit(obj);
}

void Visit(object? node)
{
// If we've already seen this object, or it's null, stop.
if (node is null || !visited.Add(node))
return;

// Make sure it's not a banned type
Assert.IsNotAssignableFrom<Compilation>(node);
Assert.IsNotAssignableFrom<ISymbol>(node);
Assert.IsNotAssignableFrom<SyntaxNode>(node);

// Examine the object
var type = node.GetType();
if (type.IsPrimitive || type.IsEnum || type == typeof(string))
return;

// If the object is a collection, check each of the values
if (node is IEnumerable collection and not string)
{
foreach (object element in collection)
{
// recursively check each element in the collection
Visit(element);
}

return;
}

// Recursively check each field in the object
foreach (var field in type.GetFields(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance))
{
var fieldValue = field.GetValue(node);
Visit(fieldValue);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//HintName: MyNamespace.MyGenRowContext.g.cs
//HintName: MyNamespace.MyGenRowContext.g.cs
// <auto-generated />
#nullable enable
using SpreadCheetah;
Expand Down Expand Up @@ -32,9 +32,9 @@ private static async ValueTask AddHeaderRow0Async(SpreadCheetah.Spreadsheet spre
cells[0] = new StyledCell("First name", styleId);
cells[1] = new StyledCell("", styleId);
cells[2] = new StyledCell("Nationality (escaped characters \", ', \\)", styleId);
cells[3] = new StyledCell("Address line 1 (escaped characters \r\n, \t)", styleId);
cells[4] = new StyledCell("Address line 2 (verbatim\r\nstring: \", \\)", styleId);
cells[5] = new StyledCell(" Age (\r\n raw\r\n string\r\n literal\r\n )", styleId);
cells[3] = new StyledCell("Address line 1 (escaped characters \n, \t)", styleId);
cells[4] = new StyledCell("Address line 2 (verbatim\nstring: \", \\)", styleId);
cells[5] = new StyledCell(" Age (\n raw\n string\n literal\n )", styleId);
cells[6] = new StyledCell("Note (unicode escape sequence 🌉, 👍, ç)", styleId);
cells[7] = new StyledCell("Note 2 (constant interpolated string: This is a constant)", styleId);
await spreadsheet.AddRowAsync(cells.AsMemory(0, 8), token).ConfigureAwait(false);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// <auto-generated />
#nullable enable
using SpreadCheetah;
using SpreadCheetah.SourceGeneration;
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

namespace MyNamespace
{
public partial class MyGenRowContext
{
private static MyGenRowContext? _default;
public static MyGenRowContext Default => _default ??= new MyGenRowContext();

public MyGenRowContext()
{
}

private WorksheetRowTypeInfo<SpreadCheetah.SourceGenerator.SnapshotTest.Models.ClassWithSingleProperty>? _ClassWithSingleProperty;
public WorksheetRowTypeInfo<SpreadCheetah.SourceGenerator.SnapshotTest.Models.ClassWithSingleProperty> ClassWithSingleProperty => _ClassWithSingleProperty
??= WorksheetRowMetadataServices.CreateObjectInfo<SpreadCheetah.SourceGenerator.SnapshotTest.Models.ClassWithSingleProperty>(AddHeaderRow0Async, AddAsRowAsync, AddRangeAsRowsAsync);

private static async ValueTask AddHeaderRow0Async(SpreadCheetah.Spreadsheet spreadsheet, SpreadCheetah.Styling.StyleId? styleId, CancellationToken token)
{
var cells = ArrayPool<StyledCell>.Shared.Rent(1);
try
{
cells[0] = new StyledCell("Name", styleId);
await spreadsheet.AddRowAsync(cells.AsMemory(0, 1), token).ConfigureAwait(false);
}
finally
{
ArrayPool<StyledCell>.Shared.Return(cells, true);
}
}

private static ValueTask AddAsRowAsync(SpreadCheetah.Spreadsheet spreadsheet, SpreadCheetah.SourceGenerator.SnapshotTest.Models.ClassWithSingleProperty? obj, CancellationToken token)
{
if (spreadsheet is null)
throw new ArgumentNullException(nameof(spreadsheet));
if (obj is null)
return spreadsheet.AddRowAsync(ReadOnlyMemory<DataCell>.Empty, token);
return AddAsRowInternalAsync(spreadsheet, obj, token);
}

private static ValueTask AddRangeAsRowsAsync(SpreadCheetah.Spreadsheet spreadsheet, IEnumerable<SpreadCheetah.SourceGenerator.SnapshotTest.Models.ClassWithSingleProperty?> objs, CancellationToken token)
{
if (spreadsheet is null)
throw new ArgumentNullException(nameof(spreadsheet));
if (objs is null)
throw new ArgumentNullException(nameof(objs));
return AddRangeAsRowsInternalAsync(spreadsheet, objs, token);
}

private static async ValueTask AddAsRowInternalAsync(SpreadCheetah.Spreadsheet spreadsheet, SpreadCheetah.SourceGenerator.SnapshotTest.Models.ClassWithSingleProperty obj, CancellationToken token)
{
var cells = ArrayPool<DataCell>.Shared.Rent(1);
try
{
await AddCellsAsRowAsync(spreadsheet, obj, cells, token).ConfigureAwait(false);
}
finally
{
ArrayPool<DataCell>.Shared.Return(cells, true);
}
}

private static async ValueTask AddRangeAsRowsInternalAsync(SpreadCheetah.Spreadsheet spreadsheet, IEnumerable<SpreadCheetah.SourceGenerator.SnapshotTest.Models.ClassWithSingleProperty?> objs, CancellationToken token)
{
var cells = ArrayPool<DataCell>.Shared.Rent(1);
try
{
await AddEnumerableAsRowsAsync(spreadsheet, objs, cells, token).ConfigureAwait(false);
}
finally
{
ArrayPool<DataCell>.Shared.Return(cells, true);
}
}

private static async ValueTask AddEnumerableAsRowsAsync(SpreadCheetah.Spreadsheet spreadsheet, IEnumerable<SpreadCheetah.SourceGenerator.SnapshotTest.Models.ClassWithSingleProperty?> objs, DataCell[] cells, CancellationToken token)
{
foreach (var obj in objs)
{
await AddCellsAsRowAsync(spreadsheet, obj, cells, token).ConfigureAwait(false);
}
}

private static ValueTask AddCellsAsRowAsync(SpreadCheetah.Spreadsheet spreadsheet, SpreadCheetah.SourceGenerator.SnapshotTest.Models.ClassWithSingleProperty? obj, DataCell[] cells, CancellationToken token)
{
if (obj is null)
return spreadsheet.AddRowAsync(ReadOnlyMemory<DataCell>.Empty, token);

cells[0] = new DataCell(obj.Name);
return spreadsheet.AddRowAsync(cells.AsMemory(0, 1), token);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,6 @@ public partial class MyGenRowContext : WorksheetRowContext;
""";

// Act & Assert
return TestHelper.CompileAndVerify<WorksheetRowGenerator>(source);
return TestHelper.CompileAndVerify<WorksheetRowGenerator>(source, replaceEscapedLineEndings: true);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,36 @@ namespace SpreadCheetah.SourceGenerator.SnapshotTest.Tests;

public class WorksheetRowGeneratorTests
{
[Fact]
public Task WorksheetRowGenerator_Generate_CachingCorrectly()
{
// Arrange
const string source = """
using SpreadCheetah.SourceGeneration;
using SpreadCheetah.SourceGenerator.SnapshotTest.Models;
using System;
namespace MyNamespace
{
[WorksheetRow(typeof(ClassWithSingleProperty))]
public partial class MyGenRowContext : WorksheetRowContext
{
}
}
""";

// Act
var (diagnostics, output) = TestHelper.GetGeneratedTrees<WorksheetRowGenerator>(source, ["Transform"]);

// Assert
Assert.Empty(diagnostics);
var outputSource = Assert.Single(output);

var settings = new VerifySettings();
settings.UseDirectory("../Snapshots");
return Verify(outputSource, settings);
}

[Fact]
public Task WorksheetRowGenerator_Generate_ClassWithSingleProperty()
{
Expand Down
Loading

0 comments on commit 5b39f72

Please sign in to comment.