If you're using Rider, there is a convenient template for adding a source generator project to your solution that sets up the above for you. Use the wizard to add a new project to the solution and select the Roslyn Source Generators template.
With that done, we can move on to our code generator - or, rather, a test.
Our test class doesn't need anything special, but running the code generator takes some setup. To keep the tests tidy, I will move that setup into a helper method.
using System.Collections.Generic;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;

public static class GeneratorHelpers
{
    public static IDictionary<string, string> RunSourceGenerator(IIncrementalGenerator generator, string sourceCode)
    {
        var driver = CSharpGeneratorDriver.Create(generator);

        // Create a compilation with the provided source code
        var compilation = CSharpCompilation.Create(nameof(GeneratorHelpers),
            new[] { CSharpSyntaxTree.ParseText(sourceCode) },
            new[]
            {
                // To support 'System.Attribute' inheritance, add reference to 'System.Private.CoreLib'.
                MetadataReference.CreateFromFile(typeof(object).Assembly.Location)
            });

        // Run generators and retrieve all results.
        var runResult = driver.RunGenerators(compilation).GetRunResult();

        // Map the generated file names to their source code and return that
        return runResult.GeneratedTrees
            .ToDictionary(tree => tree.FilePath, tree => tree.GetText().ToString());
    }
}
Here we build an in-memory compilation containing our source code as a single syntax tree, run the generator over it, and return every generated file keyed by its path so that we can make assertions about the contents.
We then use this helper to write our test:
[Fact]
public void ThatOneFileIsGeneratedForEachDatabaseEngine()
{
    // Given a source code generator
    var generator = new DbTestGenerator();
    // And a source file that should trigger source generation
    var source = @"
using Codegen;
namespace User.Repository;
[DbTest]
public abstract class UserRepositoryTests<T>
{
}
";
    // When source code generation is run
    var generatedSources = GeneratorHelpers.RunSourceGenerator(generator, source);
    // Then one file is generated, containing a test class for each database engine
    generatedSources.Should().ContainKey("Codegen/Codegen.DbTestGenerator/UserRepositoryTests.generated.cs")
        .WhoseValue.Trim().Should().Be(@"
// <auto-generated/>
namespace User.Repository;
public class PostgresUserRepositoryTests(): UserRepositoryTests<PostgresEngine>(new PostgresEngine()) {}
public class MariaDbUserRepositoryTests(): UserRepositoryTests<MariaDbEngine>(new MariaDbEngine()) {}
".Trim());
}
The key part of this test is the source code we feed to the generator. It differs from the code at the top of this post in two important ways: the class is generic, and it is decorated with an attribute.
The generic parameter is required if we are using xUnit collections to supply the database engine as a shared fixture. The attribute is a so-called marker attribute: it carries no data, and our generator uses it as a filter so that it only generates code where it should. We don't want to create database-flavoured implementations of every abstract class in the code base :)
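For the generated subclasses to compile, the hand-written base class needs a constructor that accepts the engine. The post does not show the members of IDbEngine, so the sketch below assumes a hypothetical one-member interface with stand-in engine classes; only the two derived classes are copied from the expected generator output. It uses C# 12 primary constructors, as the generated code does.

```csharp
using System;

// Hypothetical engine abstraction: the post names IDbEngine but does not show its members.
public interface IDbEngine
{
    string Name { get; }
}

// Stand-in engines; real ones would start containers, open connections, etc.
public sealed class PostgresEngine : IDbEngine
{
    public string Name => "Postgres";
}

public sealed class MariaDbEngine : IDbEngine
{
    public string Name => "MariaDb";
}

// The hand-written abstract test class. The generic parameter lets each generated
// subclass fix the engine type, and the primary constructor receives the instance.
public abstract class UserRepositoryTests<T>(T engine) where T : IDbEngine
{
    protected T Engine { get; } = engine;

    // Shared test logic would live here and run once per generated subclass.
    public string Describe() => $"{GetType().Name} runs against {Engine.Name}";
}

// Copied from the generator's expected output in the test above.
public class PostgresUserRepositoryTests(): UserRepositoryTests<PostgresEngine>(new PostgresEngine()) {}
public class MariaDbUserRepositoryTests(): UserRepositoryTests<MariaDbEngine>(new MariaDbEngine()) {}

public static class Program
{
    public static void Main()
    {
        Console.WriteLine(new PostgresUserRepositoryTests().Describe());
        Console.WriteLine(new MariaDbUserRepositoryTests().Describe());
    }
}
```

Each generated class is a thin, non-generic, concrete subclass, which is what the test runner needs to discover it; all of the actual test methods stay in the abstract base.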
With the test written, let's implement some code to satisfy it:
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;

namespace Codegen;

[Generator]
public class DbTestGenerator : IIncrementalGenerator
{
    private const string Namespace = "Codegen";
    private const string AttributeName = "DbTestAttribute";
    private const string AttributeSourceCode = $@"
// <auto-generated/>
namespace {Namespace};
[System.AttributeUsage(System.AttributeTargets.Class)]
public class {AttributeName} : System.Attribute
{{
}}
";

    public void Initialize(IncrementalGeneratorInitializationContext context)
    {
        // Add the marker attribute to the compilation.
        context.RegisterPostInitializationOutput(ctx => ctx.AddSource(
            $"{AttributeName}.g.cs",
            SourceText.From(AttributeSourceCode, Encoding.UTF8)));

        // Locate and register generators for all classes marked with the marker attribute
        var testsToGenerate = context.SyntaxProvider
            .CreateSyntaxProvider(
                predicate: static (s, _) => IsClassWithAttributes(s),
                transform: static (context, _) => GetTestGenerator(context)
            )
            .Where(static c => c is not null)
            // Narrow Generator? to Generator so the output step needs no null checks
            .Select(static (c, _) => c!)
            .Collect()
            .Combine(context.CompilationProvider);

        context.RegisterSourceOutput(testsToGenerate,
            static (context, pair) =>
            {
                // pair.Left: the generators found above; pair.Right: the compilation
                foreach (var generator in pair.Left)
                {
                    generator.Generate(context, pair.Right);
                }
            });
    }
A lot is going on here, so let's break it down.
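First, we register the marker attribute so that it is available to our test classes. RegisterPostInitializationOutput adds a source file (DbTestAttribute.g.cs) containing the AttributeSourceCode constant to the compilation.
Next, to avoid running the semantic model over every node, the predicate is a cheap syntactic filter that matches every class declaration with at least one attribute. For each class located, the transform checks that one of its attributes resolves to our marker attribute and, if so, creates a Generator (more on that later). Finally, we collect those generators, combine them with the compilation and call Generate on each one from RegisterSourceOutput. Both helpers below are static members of DbTestGenerator.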
    private static bool IsClassWithAttributes(SyntaxNode syntaxNode)
    {
        // Cheap syntactic check: a class declaration with at least one attribute
        return syntaxNode is ClassDeclarationSyntax classNode
            && classNode.AttributeLists
                .SelectMany(list => list.Attributes).Any();
    }

    private static Generator? GetTestGenerator(GeneratorSyntaxContext context)
    {
        // The predicate only lets class declarations through, so this cast is safe
        var classNode = (ClassDeclarationSyntax)context.Node;
        // Semantic check: resolve each attribute to its constructor and compare the attribute type's full name
        var hasAttribute = classNode.AttributeLists.SelectMany(list => list.Attributes)
            .Select(attribute => context.SemanticModel.GetSymbolInfo(attribute).Symbol)
            .OfType<IMethodSymbol>()
            .Any(symbol => symbol.ContainingType.ToDisplayString().Equals($"{Namespace}.{AttributeName}"));
        return hasAttribute ? new Generator(context) : null;
    }
Finally, let's look at the Generator class:
    private class Generator(GeneratorSyntaxContext gsc)
    {
        private readonly ClassDeclarationSyntax classNode = (ClassDeclarationSyntax)gsc.Node;

        public void Generate(SourceProductionContext spc, Compilation compilation)
        {
            var semanticModel = compilation.GetSemanticModel(classNode.SyntaxTree);
            if (semanticModel.GetDeclaredSymbol(classNode) is not INamedTypeSymbol classSymbol)
            {
                return;
            }
            var @namespace = classSymbol.ContainingNamespace.ToDisplayString();
            var className = classNode.Identifier.Text;
            GenerateCode(spc, @namespace, className, ["Postgres", "MariaDb"]);
        }

        private void GenerateCode(SourceProductionContext spc, string @namespace, string className, IEnumerable<string> providers)
        {
            var code = $@"
// <auto-generated/>
namespace {@namespace};
";
            foreach (var provider in providers)
            {
                code += $"public class {provider}{className}(): {className}<{provider}Engine>(new {provider}Engine()) {{}}\n";
            }
            spc.AddSource($"{className}.generated.cs", code);
        }
    }
}
Here we read two pieces of metadata about the test class, its name and its namespace, and use them to generate the subclasses with string interpolation. We then add the result to the compilation under a hint name that must be unique within the generator (first-hand experience taught me that non-unique names cause a silent failure).
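Because the hint name is built from the class name alone, two marked classes with the same name in different namespaces would collide. One way around that, sketched below as a hypothetical helper (DbTestSource is not part of the post's code), is to fold the namespace into the hint name. The same helper pulls the string building out of Roslyn types so the template can be checked on its own; apart from the leading newline, it produces the same text as GenerateCode.

```csharp
using System;
using System.Collections.Generic;
using System.Text;

// Hypothetical helper: the same template as GenerateCode, with no Roslyn dependencies.
public static class DbTestSource
{
    // Uses '\n' explicitly so the output does not depend on the platform's newline.
    public static string Build(string @namespace, string className, IEnumerable<string> providers)
    {
        var code = new StringBuilder();
        code.Append("// <auto-generated/>\n");
        code.Append($"namespace {@namespace};\n");
        foreach (var provider in providers)
        {
            // One subclass per engine, closing the generic base class over that engine type.
            code.Append($"public class {provider}{className}(): {className}<{provider}Engine>(new {provider}Engine()) {{}}\n");
        }
        return code.ToString();
    }

    // Hypothetical hint name that includes the namespace, so that classes with the
    // same name in different namespaces get distinct file names.
    public static string HintName(string @namespace, string className) =>
        $"{@namespace}.{className}.generated.cs";
}

public static class Program
{
    public static void Main()
    {
        Console.Write(DbTestSource.Build("User.Repository", "UserRepositoryTests", new[] { "Postgres", "MariaDb" }));
        Console.WriteLine(DbTestSource.HintName("User.Repository", "UserRepositoryTests"));
    }
}
```

Inside GenerateCode this would become `spc.AddSource(DbTestSource.HintName(@namespace, className), DbTestSource.Build(@namespace, className, providers));`. Note that changing the hint name also changes the dictionary key the test looks up, which would become "Codegen/Codegen.DbTestGenerator/User.Repository.UserRepositoryTests.generated.cs".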
We can now run our test, and it will pass.
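The last step is to add the code generator to our unit-test project and let it generate the test classes for us. We do this with a project reference to the generator project: OutputItemType="Analyzer" tells MSBuild to load it as an analyzer, which is how source generators run, and ReferenceOutputAssembly="false" keeps the generator assembly itself out of the test project's references.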
<ProjectReference
Include="..\Codegen\Codegen.csproj"
OutputItemType="Analyzer"
ReferenceOutputAssembly="false"
/>
Now, when we compile our tests, the generator creates the database flavours for us.
Adding support for a new database is as easy as implementing a new IDbEngine, adding its name to the list of providers in our generator and updating the expected output in the test:
diff --git a/Codegen.Tests/DbTestGeneratorTests.cs b/Codegen.Tests/DbTestGeneratorTests.cs
index 6c843da..365b806 100644
--- a/Codegen.Tests/DbTestGeneratorTests.cs
+++ b/Codegen.Tests/DbTestGeneratorTests.cs
@@ -38,6 +38,7 @@ public class DbTestGeneratorTests
namespace User.Repository;
public class PostgresUserRepositoryTests(): UserRepositoryTests<PostgresEngine>(new PostgresEngine()) {}
public class MariaDbUserRepositoryTests(): UserRepositoryTests<MariaDbEngine>(new MariaDbEngine()) {}
+public class SqliteUserRepositoryTests(): UserRepositoryTests<SqliteEngine>(new SqliteEngine()) {}
".Trim());
}
}
diff --git a/Codegen/DbTestGenerator.cs b/Codegen/DbTestGenerator.cs
index d189496..c01d53b 100644
--- a/Codegen/DbTestGenerator.cs
+++ b/Codegen/DbTestGenerator.cs
@@ -100,7 +100,7 @@ namespace {Namespace}
var @namespace = classSymbol.ContainingNamespace.ToDisplayString();
var className = classNode.Identifier.Text;
- GenerateCode(spc, @namespace, className, ["Postgres", "MariaDb"]);
+ GenerateCode(spc, @namespace, className, ["Postgres", "MariaDb", "Sqlite"]);
}
private void GenerateCode(SourceProductionContext spc, string @namespace, string className, IEnumerable<string> providers)
All that remains is to implement the database engines and the user repository. I leave that as an exercise for the reader. You can also check out a complete example with passing tests over on my GitHub.