diff --git a/src/Generators/DurableTaskSourceGenerator.cs b/src/Generators/DurableTaskSourceGenerator.cs index 19a24cc9..538448e7 100644 --- a/src/Generators/DurableTaskSourceGenerator.cs +++ b/src/Generators/DurableTaskSourceGenerator.cs @@ -67,6 +67,12 @@ public void Initialize(IncrementalGeneratorInitializationContext context) context.RegisterSourceOutput(compilationAndTasks, static (spc, source) => Execute(spc, source.Item1, source.Item2, source.Item3)); } + static string GetNamespaceOrEmpty(INamespaceSymbol namespaceSymbol) + { + // Return empty string for global namespace, otherwise return the display string + return namespaceSymbol.IsGlobalNamespace ? string.Empty : namespaceSymbol.ToDisplayString(); + } + static DurableTaskTypeInfo? GetDurableTaskTypeInfo(GeneratorSyntaxContext context) { AttributeSyntax attribute = (AttributeSyntax)context.Node; @@ -94,6 +100,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) } string className = classType.ToDisplayString(); + string classNamespace = GetNamespaceOrEmpty(classType.ContainingNamespace); INamedTypeSymbol? taskType = null; DurableTaskKind kind = DurableTaskKind.Orchestrator; @@ -158,7 +165,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) taskName = context.SemanticModel.GetConstantValue(expression).ToString(); } - return new DurableTaskTypeInfo(className, taskName, inputType, outputType, kind); + return new DurableTaskTypeInfo(className, classNamespace, taskName, inputType, outputType, kind); } static DurableFunction? GetDurableFunction(GeneratorSyntaxContext context) @@ -216,12 +223,31 @@ static void Execute( return; } + // Group tasks by namespace + // For tasks in the global namespace (empty string), use "Microsoft.DurableTask" for backward compatibility + Dictionary> tasksByNamespace = new(); + + foreach (DurableTaskTypeInfo task in allTasks) + { + string targetNamespace = string.IsNullOrEmpty(task.Namespace) ? "Microsoft.DurableTask" : task.Namespace; + + if (!tasksByNamespace.TryGetValue(targetNamespace, out List? tasksInNamespace)) + { + tasksInNamespace = new List(); + tasksByNamespace[targetNamespace] = tasksInNamespace; + } + + tasksInNamespace.Add(task); + } + + // Generate a separate class for each namespace StringBuilder sourceBuilder = new(capacity: found * 1024); sourceBuilder.Append(@"// #nullable enable using System; using System.Threading.Tasks; +using Microsoft.DurableTask; using Microsoft.DurableTask.Internal;"); if (isDurableFunctions) @@ -231,135 +257,260 @@ static void Execute( using Microsoft.Extensions.DependencyInjection;"); } - sourceBuilder.Append(@" + sourceBuilder.AppendLine(); -namespace Microsoft.DurableTask -{ - public static class GeneratedDurableTaskExtensions - {"); - if (isDurableFunctions) + // Activity function triggers are supported for code-gen (but not orchestration triggers) + IEnumerable activityTriggers = allFunctions.Where( + df => df.Kind == DurableFunctionKind.Activity); + + // For non-Functions, we need to add registration methods to Microsoft.DurableTask namespace + bool needsRegistrationBlock = !isDurableFunctions && (orchestrators.Count > 0 || activities.Count > 0 || entities.Count > 0); + + // Generate extension classes grouped by namespace + foreach (KeyValuePair> namespaceGroup in tasksByNamespace) { - // Generate a singleton orchestrator object instance that can be reused for all invocations. - foreach (DurableTaskTypeInfo orchestrator in orchestrators) + string targetNamespace = namespaceGroup.Key; + List tasksInNamespace = namespaceGroup.Value; + + List orchestratorsInNamespace = tasksInNamespace.Where(t => t.IsOrchestrator).ToList(); + List activitiesInNamespace = tasksInNamespace.Where(t => t.IsActivity).ToList(); + List entitiesInNamespace = tasksInNamespace.Where(t => t.IsEntity).ToList(); + + // Check if there's actually any content to generate for this namespace + bool hasOrchestratorMethods = orchestratorsInNamespace.Count > 0; + bool hasActivityMethods = activitiesInNamespace.Count > 0; + bool hasEntityFunctions = isDurableFunctions && entitiesInNamespace.Count > 0; + bool hasActivityTriggers = targetNamespace == "Microsoft.DurableTask" && activityTriggers.Any(); + bool hasRegistrationMethod = !isDurableFunctions && targetNamespace == "Microsoft.DurableTask" && needsRegistrationBlock; + + // Skip this namespace block if there's nothing to generate + if (!hasOrchestratorMethods && !hasActivityMethods && !hasEntityFunctions && !hasActivityTriggers && !hasRegistrationMethod) { - sourceBuilder.AppendLine($@" - static readonly ITaskOrchestrator singleton{orchestrator.TaskName} = new {orchestrator.TypeName}();"); + continue; } - } - foreach (DurableTaskTypeInfo orchestrator in orchestrators) - { + sourceBuilder.AppendLine(); + sourceBuilder.AppendLine($"namespace {targetNamespace}"); + sourceBuilder.AppendLine("{"); + sourceBuilder.AppendLine(" public static class GeneratedDurableTaskExtensions"); + sourceBuilder.AppendLine(" {"); + if (isDurableFunctions) { - // Generate the function definition required to trigger orchestrators in Azure Functions - AddOrchestratorFunctionDeclaration(sourceBuilder, orchestrator); + // Generate a singleton orchestrator object instance that can be reused for all invocations. + foreach (DurableTaskTypeInfo orchestrator in orchestratorsInNamespace) + { + string simplifiedTypeName = SimplifyTypeNameForNamespace(orchestrator.TypeName, targetNamespace); + sourceBuilder.AppendLine($@" static readonly ITaskOrchestrator singleton{orchestrator.TaskName} = new {simplifiedTypeName}();"); + } } - AddOrchestratorCallMethod(sourceBuilder, orchestrator); - AddSubOrchestratorCallMethod(sourceBuilder, orchestrator); - } + foreach (DurableTaskTypeInfo orchestrator in orchestratorsInNamespace) + { + if (isDurableFunctions) + { + // Generate the function definition required to trigger orchestrators in Azure Functions + AddOrchestratorFunctionDeclaration(sourceBuilder, orchestrator, targetNamespace); + } - foreach (DurableTaskTypeInfo activity in activities) - { - AddActivityCallMethod(sourceBuilder, activity); + AddOrchestratorCallMethod(sourceBuilder, orchestrator, targetNamespace); + AddSubOrchestratorCallMethod(sourceBuilder, orchestrator, targetNamespace); + } - if (isDurableFunctions) + foreach (DurableTaskTypeInfo activity in activitiesInNamespace) { - // Generate the function definition required to trigger activities in Azure Functions - AddActivityFunctionDeclaration(sourceBuilder, activity); + AddActivityCallMethod(sourceBuilder, activity, targetNamespace); + + if (isDurableFunctions) + { + // Generate the function definition required to trigger activities in Azure Functions + AddActivityFunctionDeclaration(sourceBuilder, activity, targetNamespace); + } + } + + foreach (DurableTaskTypeInfo entity in entitiesInNamespace) + { + if (isDurableFunctions) + { + // Generate the function definition required to trigger entities in Azure Functions + AddEntityFunctionDeclaration(sourceBuilder, entity, targetNamespace); + } + } + + // Add activity triggers from DurableFunction to Microsoft.DurableTask namespace only + if (targetNamespace == "Microsoft.DurableTask" && activityTriggers.Any()) + { + foreach (DurableFunction function in activityTriggers) + { + AddActivityCallMethod(sourceBuilder, function); + } } - } - foreach (DurableTaskTypeInfo entity in entities) - { if (isDurableFunctions) { - // Generate the function definition required to trigger entities in Azure Functions - AddEntityFunctionDeclaration(sourceBuilder, entity); + if (activitiesInNamespace.Count > 0) + { + // Functions-specific helper class, which is only needed when + // using the class-based syntax. + AddGeneratedActivityContextClass(sourceBuilder); + } + } + else + { + // ASP.NET Core-specific service registration methods - add to Microsoft.DurableTask namespace only + if (targetNamespace == "Microsoft.DurableTask" && needsRegistrationBlock) + { + AddRegistrationMethodForAllTasks( + sourceBuilder, + orchestrators, + activities, + entities); + needsRegistrationBlock = false; // Mark as added + } } - } - // Activity function triggers are supported for code-gen (but not orchestration triggers) - IEnumerable activityTriggers = allFunctions.Where( - df => df.Kind == DurableFunctionKind.Activity); - foreach (DurableFunction function in activityTriggers) - { - AddActivityCallMethod(sourceBuilder, function); + sourceBuilder.AppendLine(" }"); + sourceBuilder.AppendLine("}"); } - if (isDurableFunctions) + // If we still need to add activity triggers or registration methods and they haven't been added yet + // (because there's no Microsoft.DurableTask namespace block), create one now + if (activityTriggers.Any() && !tasksByNamespace.ContainsKey("Microsoft.DurableTask")) { - if (activities.Count > 0) + sourceBuilder.AppendLine(); + sourceBuilder.AppendLine("namespace Microsoft.DurableTask"); + sourceBuilder.AppendLine("{"); + sourceBuilder.AppendLine(" public static class GeneratedDurableTaskExtensions"); + sourceBuilder.AppendLine(" {"); + + foreach (DurableFunction function in activityTriggers) { - // Functions-specific helper class, which is only needed when - // using the class-based syntax. - AddGeneratedActivityContextClass(sourceBuilder); + AddActivityCallMethod(sourceBuilder, function); } + + sourceBuilder.AppendLine(" }"); + sourceBuilder.AppendLine("}"); } - else + + if (needsRegistrationBlock && !tasksByNamespace.ContainsKey("Microsoft.DurableTask")) { - // ASP.NET Core-specific service registration methods + sourceBuilder.AppendLine(); + sourceBuilder.AppendLine("namespace Microsoft.DurableTask"); + sourceBuilder.AppendLine("{"); + sourceBuilder.AppendLine(" public static class GeneratedDurableTaskExtensions"); + sourceBuilder.AppendLine(" {"); + AddRegistrationMethodForAllTasks( sourceBuilder, orchestrators, activities, entities); - } - sourceBuilder.AppendLine(" }").AppendLine("}"); + sourceBuilder.AppendLine(" }"); + sourceBuilder.AppendLine("}"); + } context.AddSource("GeneratedDurableTaskExtensions.cs", SourceText.From(sourceBuilder.ToString(), Encoding.UTF8, SourceHashAlgorithm.Sha256)); + } - static void AddOrchestratorFunctionDeclaration(StringBuilder sourceBuilder, DurableTaskTypeInfo orchestrator) + static string SimplifyTypeNameForNamespace(string fullyQualifiedTypeName, string targetNamespace) { + // Don't simplify if target namespace is empty (global namespace) + if (string.IsNullOrEmpty(targetNamespace)) + { + return fullyQualifiedTypeName; + } + + if (fullyQualifiedTypeName.StartsWith(targetNamespace + ".", StringComparison.Ordinal)) + { + return fullyQualifiedTypeName.Substring(targetNamespace.Length + 1); + } + + return fullyQualifiedTypeName; + } + + static void AddOrchestratorFunctionDeclaration(StringBuilder sourceBuilder, DurableTaskTypeInfo orchestrator, string targetNamespace) + { + string inputType = DurableTaskTypeInfo.GetRenderedTypeExpressionForNamespace(orchestrator.InputTypeSymbol, targetNamespace); + string outputType = DurableTaskTypeInfo.GetRenderedTypeExpressionForNamespace(orchestrator.OutputTypeSymbol, targetNamespace); + sourceBuilder.AppendLine($@" [Function(nameof({orchestrator.TaskName}))] - public static Task<{orchestrator.OutputType}> {orchestrator.TaskName}([OrchestrationTrigger] TaskOrchestrationContext context) + public static Task<{outputType}> {orchestrator.TaskName}([OrchestrationTrigger] TaskOrchestrationContext context) {{ - return singleton{orchestrator.TaskName}.RunAsync(context, context.GetInput<{orchestrator.InputType}>()) - .ContinueWith(t => ({orchestrator.OutputType})(t.Result ?? default({orchestrator.OutputType})!), TaskContinuationOptions.ExecuteSynchronously); + return singleton{orchestrator.TaskName}.RunAsync(context, context.GetInput<{inputType}>()) + .ContinueWith(t => ({outputType})(t.Result ?? default({outputType})!), TaskContinuationOptions.ExecuteSynchronously); }}"); } - static void AddOrchestratorCallMethod(StringBuilder sourceBuilder, DurableTaskTypeInfo orchestrator) + static void AddOrchestratorCallMethod(StringBuilder sourceBuilder, DurableTaskTypeInfo orchestrator, string targetNamespace) { + string inputType = DurableTaskTypeInfo.GetRenderedTypeExpressionForNamespace(orchestrator.InputTypeSymbol, targetNamespace); + string inputParameter = inputType + " input"; + if (inputType.EndsWith("?", StringComparison.Ordinal)) + { + inputParameter += " = default"; + } + + string simplifiedTypeName = SimplifyTypeNameForNamespace(orchestrator.TypeName, targetNamespace); + sourceBuilder.AppendLine($@" /// - /// Schedules a new instance of the orchestrator. + /// Schedules a new instance of the orchestrator. /// /// public static Task ScheduleNew{orchestrator.TaskName}InstanceAsync( - this IOrchestrationSubmitter client, {orchestrator.InputParameter}, StartOrchestrationOptions? options = null) + this IOrchestrationSubmitter client, {inputParameter}, StartOrchestrationOptions? options = null) {{ return client.ScheduleNewOrchestrationInstanceAsync(""{orchestrator.TaskName}"", input, options); }}"); } - static void AddSubOrchestratorCallMethod(StringBuilder sourceBuilder, DurableTaskTypeInfo orchestrator) + static void AddSubOrchestratorCallMethod(StringBuilder sourceBuilder, DurableTaskTypeInfo orchestrator, string targetNamespace) { + string inputType = DurableTaskTypeInfo.GetRenderedTypeExpressionForNamespace(orchestrator.InputTypeSymbol, targetNamespace); + string outputType = DurableTaskTypeInfo.GetRenderedTypeExpressionForNamespace(orchestrator.OutputTypeSymbol, targetNamespace); + string inputParameter = inputType + " input"; + if (inputType.EndsWith("?", StringComparison.Ordinal)) + { + inputParameter += " = default"; + } + + string simplifiedTypeName = SimplifyTypeNameForNamespace(orchestrator.TypeName, targetNamespace); + sourceBuilder.AppendLine($@" /// - /// Calls the sub-orchestrator. + /// Calls the sub-orchestrator. /// /// - public static Task<{orchestrator.OutputType}> Call{orchestrator.TaskName}Async( - this TaskOrchestrationContext context, {orchestrator.InputParameter}, TaskOptions? options = null) + public static Task<{outputType}> Call{orchestrator.TaskName}Async( + this TaskOrchestrationContext context, {inputParameter}, TaskOptions? options = null) {{ - return context.CallSubOrchestratorAsync<{orchestrator.OutputType}>(""{orchestrator.TaskName}"", input, options); + return context.CallSubOrchestratorAsync<{outputType}>(""{orchestrator.TaskName}"", input, options); }}"); } - static void AddActivityCallMethod(StringBuilder sourceBuilder, DurableTaskTypeInfo activity) + static void AddActivityCallMethod(StringBuilder sourceBuilder, DurableTaskTypeInfo activity, string targetNamespace) { + string inputType = DurableTaskTypeInfo.GetRenderedTypeExpressionForNamespace(activity.InputTypeSymbol, targetNamespace); + string outputType = DurableTaskTypeInfo.GetRenderedTypeExpressionForNamespace(activity.OutputTypeSymbol, targetNamespace); + string inputParameter = inputType + " input"; + if (inputType.EndsWith("?", StringComparison.Ordinal)) + { + inputParameter += " = default"; + } + + string simplifiedTypeName = SimplifyTypeNameForNamespace(activity.TypeName, targetNamespace); + sourceBuilder.AppendLine($@" /// - /// Calls the activity. + /// Calls the activity. /// /// - public static Task<{activity.OutputType}> Call{activity.TaskName}Async(this TaskOrchestrationContext ctx, {activity.InputParameter}, TaskOptions? options = null) + public static Task<{outputType}> Call{activity.TaskName}Async(this TaskOrchestrationContext ctx, {inputParameter}, TaskOptions? options = null) {{ - return ctx.CallActivityAsync<{activity.OutputType}>(""{activity.TaskName}"", input, options); + return ctx.CallActivityAsync<{outputType}>(""{activity.TaskName}"", input, options); }}"); } @@ -376,29 +527,41 @@ static void AddActivityCallMethod(StringBuilder sourceBuilder, DurableFunction a }}"); } - static void AddActivityFunctionDeclaration(StringBuilder sourceBuilder, DurableTaskTypeInfo activity) + static void AddActivityFunctionDeclaration(StringBuilder sourceBuilder, DurableTaskTypeInfo activity, string targetNamespace) { + string inputType = DurableTaskTypeInfo.GetRenderedTypeExpressionForNamespace(activity.InputTypeSymbol, targetNamespace); + string outputType = DurableTaskTypeInfo.GetRenderedTypeExpressionForNamespace(activity.OutputTypeSymbol, targetNamespace); + string inputParameter = inputType + " input"; + if (inputType.EndsWith("?", StringComparison.Ordinal)) + { + inputParameter += " = default"; + } + + string simplifiedActivityTypeName = SimplifyTypeNameForNamespace(activity.TypeName, targetNamespace); + // GeneratedActivityContext is a generated class that we use for each generated activity trigger definition. // Note that the second "instanceId" parameter is populated via the Azure Functions binding context. sourceBuilder.AppendLine($@" [Function(nameof({activity.TaskName}))] - public static async Task<{activity.OutputType}> {activity.TaskName}([ActivityTrigger] {activity.InputParameter}, string instanceId, FunctionContext executionContext) + public static async Task<{outputType}> {activity.TaskName}([ActivityTrigger] {inputParameter}, string instanceId, FunctionContext executionContext) {{ - ITaskActivity activity = ActivatorUtilities.GetServiceOrCreateInstance<{activity.TypeName}>(executionContext.InstanceServices); + ITaskActivity activity = ActivatorUtilities.GetServiceOrCreateInstance<{simplifiedActivityTypeName}>(executionContext.InstanceServices); TaskActivityContext context = new GeneratedActivityContext(""{activity.TaskName}"", instanceId); object? result = await activity.RunAsync(context, input); - return ({activity.OutputType})result!; + return ({outputType})result!; }}"); } - static void AddEntityFunctionDeclaration(StringBuilder sourceBuilder, DurableTaskTypeInfo entity) + static void AddEntityFunctionDeclaration(StringBuilder sourceBuilder, DurableTaskTypeInfo entity, string targetNamespace) { + string simplifiedEntityTypeName = SimplifyTypeNameForNamespace(entity.TypeName, targetNamespace); + // Generate the entity trigger function that dispatches to the entity implementation. sourceBuilder.AppendLine($@" [Function(nameof({entity.TaskName}))] public static Task {entity.TaskName}([EntityTrigger] TaskEntityDispatcher dispatcher) {{ - return dispatcher.DispatchAsync<{entity.TypeName}>(); + return dispatcher.DispatchAsync<{simplifiedEntityTypeName}>(); }}"); } @@ -474,12 +637,14 @@ class DurableTaskTypeInfo { public DurableTaskTypeInfo( string taskType, + string taskNamespace, string taskName, ITypeSymbol? inputType, ITypeSymbol? outputType, DurableTaskKind kind) { this.TypeName = taskType; + this.Namespace = taskNamespace; this.TaskName = taskName; this.Kind = kind; @@ -489,9 +654,13 @@ public DurableTaskTypeInfo( this.InputType = string.Empty; this.InputParameter = string.Empty; this.OutputType = string.Empty; + this.InputTypeSymbol = null; + this.OutputTypeSymbol = null; } else { + this.InputTypeSymbol = inputType; + this.OutputTypeSymbol = outputType; this.InputType = GetRenderedTypeExpression(inputType); this.InputParameter = this.InputType + " input"; if (this.InputType[this.InputType.Length - 1] == '?') @@ -504,11 +673,14 @@ public DurableTaskTypeInfo( } public string TypeName { get; } + public string Namespace { get; } public string TaskName { get; } public string InputType { get; } public string InputParameter { get; } public string OutputType { get; } public DurableTaskKind Kind { get; } + public ITypeSymbol? InputTypeSymbol { get; } + public ITypeSymbol? OutputTypeSymbol { get; } public bool IsActivity => this.Kind == DurableTaskKind.Activity; @@ -516,6 +688,34 @@ public DurableTaskTypeInfo( public bool IsEntity => this.Kind == DurableTaskKind.Entity; + /// + /// Gets a rendered type expression for the given type symbol relative to a target namespace. + /// + public static string GetRenderedTypeExpressionForNamespace(ITypeSymbol? symbol, string targetNamespace) + { + if (symbol == null) + { + return "object"; + } + + string expression = symbol.ToDisplayString(); + + // Simplify System types + if (expression.StartsWith("System.", StringComparison.Ordinal) + && symbol.ContainingNamespace.ToDisplayString() == "System") + { + expression = expression.Substring("System.".Length); + } + // Simplify types in the same namespace + else if (symbol.ContainingNamespace.ToDisplayString() == targetNamespace) + { + // Use the simple name if the type is in the same namespace + expression = symbol.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat); + } + + return expression; + } + static string GetRenderedTypeExpression(ITypeSymbol? symbol) { if (symbol == null) diff --git a/test/Generators.Tests/AzureFunctionsTests.cs b/test/Generators.Tests/AzureFunctionsTests.cs index d9d7fad0..ccba5697 100644 --- a/test/Generators.Tests/AzureFunctionsTests.cs +++ b/test/Generators.Tests/AzureFunctionsTests.cs @@ -221,7 +221,7 @@ public class MyOrchestrator : TaskOrchestrator<{inputType}, {outputType}> string expectedOutput = TestHelpers.WrapAndFormat( GeneratedClassName, methodList: $@" -static readonly ITaskOrchestrator singletonMyOrchestrator = new MyNS.MyOrchestrator(); +static readonly ITaskOrchestrator singletonMyOrchestrator = new MyOrchestrator(); [Function(nameof(MyOrchestrator))] public static Task<{outputType}> MyOrchestrator([OrchestrationTrigger] TaskOrchestrationContext context) @@ -231,7 +231,7 @@ public class MyOrchestrator : TaskOrchestrator<{inputType}, {outputType}> }} /// -/// Schedules a new instance of the orchestrator. +/// Schedules a new instance of the orchestrator. /// /// public static Task ScheduleNewMyOrchestratorInstanceAsync( @@ -241,7 +241,7 @@ public static Task ScheduleNewMyOrchestratorInstanceAsync( }} /// -/// Calls the sub-orchestrator. +/// Calls the sub-orchestrator. /// /// public static Task<{outputType}> CallMyOrchestratorAsync( @@ -249,6 +249,7 @@ public static Task ScheduleNewMyOrchestratorInstanceAsync( {{ return context.CallSubOrchestratorAsync<{outputType}>(""MyOrchestrator"", input, options); }}", + targetNamespace: "MyNS", isDurableFunctions: true); await TestHelpers.RunTestAsync( @@ -304,7 +305,7 @@ public abstract class MyOrchestratorBase : TaskOrchestrator<{inputType}, {output string expectedOutput = TestHelpers.WrapAndFormat( GeneratedClassName, methodList: $@" -static readonly ITaskOrchestrator singletonMyOrchestrator = new MyNS.MyOrchestrator(); +static readonly ITaskOrchestrator singletonMyOrchestrator = new MyOrchestrator(); [Function(nameof(MyOrchestrator))] public static Task<{outputType}> MyOrchestrator([OrchestrationTrigger] TaskOrchestrationContext context) @@ -314,7 +315,7 @@ public abstract class MyOrchestratorBase : TaskOrchestrator<{inputType}, {output }} /// -/// Schedules a new instance of the orchestrator. +/// Schedules a new instance of the orchestrator. /// /// public static Task ScheduleNewMyOrchestratorInstanceAsync( @@ -324,7 +325,7 @@ public static Task ScheduleNewMyOrchestratorInstanceAsync( }} /// -/// Calls the sub-orchestrator. +/// Calls the sub-orchestrator. /// /// public static Task<{outputType}> CallMyOrchestratorAsync( @@ -332,6 +333,7 @@ public static Task ScheduleNewMyOrchestratorInstanceAsync( {{ return context.CallSubOrchestratorAsync<{outputType}>(""MyOrchestrator"", input, options); }}", + targetNamespace: "MyNS", isDurableFunctions: true); await TestHelpers.RunTestAsync( @@ -372,8 +374,9 @@ public class MyEntity : TaskEntity<{stateType}> [Function(nameof(MyEntity))] public static Task MyEntity([EntityTrigger] TaskEntityDispatcher dispatcher) { - return dispatcher.DispatchAsync(); + return dispatcher.DispatchAsync(); }", + targetNamespace: "MyNS", isDurableFunctions: true); await TestHelpers.RunTestAsync( @@ -419,8 +422,9 @@ public abstract class MyEntityBase : TaskEntity<{stateType}> [Function(nameof(MyEntity))] public static Task MyEntity([EntityTrigger] TaskEntityDispatcher dispatcher) { - return dispatcher.DispatchAsync(); + return dispatcher.DispatchAsync(); }", + targetNamespace: "MyNS", isDurableFunctions: true); await TestHelpers.RunTestAsync( @@ -463,8 +467,9 @@ public class MyEntity : TaskEntity [Function(nameof(MyEntity))] public static Task MyEntity([EntityTrigger] TaskEntityDispatcher dispatcher) { - return dispatcher.DispatchAsync(); + return dispatcher.DispatchAsync(); }", + targetNamespace: "MyNS", isDurableFunctions: true); await TestHelpers.RunTestAsync( @@ -512,7 +517,7 @@ public class MyEntity : TaskEntity string expectedOutput = TestHelpers.WrapAndFormat( GeneratedClassName, methodList: $@" -static readonly ITaskOrchestrator singletonMyOrchestrator = new MyNS.MyOrchestrator(); +static readonly ITaskOrchestrator singletonMyOrchestrator = new MyOrchestrator(); [Function(nameof(MyOrchestrator))] public static Task MyOrchestrator([OrchestrationTrigger] TaskOrchestrationContext context) @@ -522,7 +527,7 @@ public static Task MyOrchestrator([OrchestrationTrigger] TaskOrchestrati }} /// -/// Schedules a new instance of the orchestrator. +/// Schedules a new instance of the orchestrator. /// /// public static Task ScheduleNewMyOrchestratorInstanceAsync( @@ -532,7 +537,7 @@ public static Task ScheduleNewMyOrchestratorInstanceAsync( }} /// -/// Calls the sub-orchestrator. +/// Calls the sub-orchestrator. /// /// public static Task CallMyOrchestratorAsync( @@ -542,7 +547,7 @@ public static Task CallMyOrchestratorAsync( }} /// -/// Calls the activity. +/// Calls the activity. /// /// public static Task CallMyActivityAsync(this TaskOrchestrationContext ctx, int input, TaskOptions? options = null) @@ -553,7 +558,7 @@ public static Task CallMyActivityAsync(this TaskOrchestrationContext ctx [Function(nameof(MyActivity))] public static async Task MyActivity([ActivityTrigger] int input, string instanceId, FunctionContext executionContext) {{ - ITaskActivity activity = ActivatorUtilities.GetServiceOrCreateInstance(executionContext.InstanceServices); + ITaskActivity activity = ActivatorUtilities.GetServiceOrCreateInstance(executionContext.InstanceServices); TaskActivityContext context = new GeneratedActivityContext(""MyActivity"", instanceId); object? result = await activity.RunAsync(context, input); return (string)result!; @@ -562,9 +567,10 @@ public static async Task MyActivity([ActivityTrigger] int input, string [Function(nameof(MyEntity))] public static Task MyEntity([EntityTrigger] TaskEntityDispatcher dispatcher) {{ - return dispatcher.DispatchAsync(); + return dispatcher.DispatchAsync(); }} {TestHelpers.DeIndent(DurableTaskSourceGenerator.GetGeneratedActivityContextCode(), spacesToRemove: 8)}", + targetNamespace: "MyNS", isDurableFunctions: true); await TestHelpers.RunTestAsync( diff --git a/test/Generators.Tests/ClassBasedSyntaxTests.cs b/test/Generators.Tests/ClassBasedSyntaxTests.cs index c3abc6c5..b6960c6e 100644 --- a/test/Generators.Tests/ClassBasedSyntaxTests.cs +++ b/test/Generators.Tests/ClassBasedSyntaxTests.cs @@ -228,23 +228,44 @@ class MyActivityImpl : TaskActivity public class MyClass { } }"; - string expectedOutput = TestHelpers.WrapAndFormat( - GeneratedClassName, - methodList: @" -/// -/// Calls the activity. -/// -/// -public static Task CallMyActivityAsync(this TaskOrchestrationContext ctx, MyNS.MyClass input, TaskOptions? options = null) + string expectedOutput = @" +// +#nullable enable + +using System; +using System.Threading.Tasks; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Internal; + +namespace MyNS { - return ctx.CallActivityAsync(""MyActivity"", input, options); + public static class GeneratedDurableTaskExtensions + { + + /// + /// Calls the activity. + /// + /// + public static Task CallMyActivityAsync(this TaskOrchestrationContext ctx, MyClass input, TaskOptions? options = null) + { + return ctx.CallActivityAsync(""MyActivity"", input, options); + } + } } -internal static DurableTaskRegistry AddAllGeneratedTasks(this DurableTaskRegistry builder) +namespace Microsoft.DurableTask { - builder.AddActivity(); - return builder; -}"); + public static class GeneratedDurableTaskExtensions + { + + internal static DurableTaskRegistry AddAllGeneratedTasks(this DurableTaskRegistry builder) + { + builder.AddActivity(); + return builder; + } + } +} +".TrimStart(); return TestHelpers.RunTestAsync( GeneratedFileName, diff --git a/test/Generators.Tests/Utils/TestHelpers.cs b/test/Generators.Tests/Utils/TestHelpers.cs index 67f030d9..480483a1 100644 --- a/test/Generators.Tests/Utils/TestHelpers.cs +++ b/test/Generators.Tests/Utils/TestHelpers.cs @@ -57,11 +57,17 @@ public static Task RunTestAsync( } public static string WrapAndFormat(string generatedClassName, string methodList, bool isDurableFunctions = false) + { + return WrapAndFormat(generatedClassName, methodList, "Microsoft.DurableTask", isDurableFunctions); + } + + public static string WrapAndFormat(string generatedClassName, string methodList, string targetNamespace, bool isDurableFunctions = false) { string formattedMethodList = IndentLines(spaces: 8, methodList); string usings = @" using System; using System.Threading.Tasks; +using Microsoft.DurableTask; using Microsoft.DurableTask.Internal;"; if (isDurableFunctions) @@ -71,16 +77,26 @@ public static string WrapAndFormat(string generatedClassName, string methodList, using Microsoft.Extensions.DependencyInjection;"; } + // The generator adds a blank line after the opening brace of the class, except when in Functions + // mode and the first content is singleton declarations. This logic matches that behavior. + // Note: This creates tight coupling between test formatting and generator implementation. + // If the generator's blank line logic changes, this will need to be updated as well. + string blankLineAfterBrace = ""; + if (!isDurableFunctions || !methodList.TrimStart().StartsWith("static readonly")) + { + blankLineAfterBrace = "\n"; + } + return $@" // #nullable enable {usings} -namespace Microsoft.DurableTask +namespace {targetNamespace} {{ public static class {generatedClassName} {{ - {formattedMethodList.TrimStart()} +{blankLineAfterBrace} {formattedMethodList.TrimStart()} }} }} ".TrimStart();