Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 39 additions & 6 deletions src/Generators/AzureFunctions/DurableFunction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ public class DurableFunction
public DurableFunctionKind Kind { get; }
public TypedParameter Parameter { get; }
public string ReturnType { get; }
public bool ReturnsVoid { get; }

public DurableFunction(
string fullTypeName,
string name,
DurableFunctionKind kind,
TypedParameter parameter,
ITypeSymbol returnType,
bool returnsVoid,
HashSet<string> requiredNamespaces)
{
this.FullTypeName = fullTypeName;
Expand All @@ -37,6 +39,7 @@ public DurableFunction(
this.Kind = kind;
this.Parameter = parameter;
this.ReturnType = SyntaxNodeUtility.GetRenderedTypeExpression(returnType, false);
this.ReturnsVoid = returnsVoid;
}

public static bool TryParse(SemanticModel model, MethodDeclarationSyntax method, out DurableFunction? function)
Expand All @@ -59,12 +62,42 @@ public static bool TryParse(SemanticModel model, MethodDeclarationSyntax method,
return false;
}

INamedTypeSymbol taskSymbol = model.Compilation.GetTypeByMetadataName("System.Threading.Tasks.Task`1")!;
INamedTypeSymbol returnSymbol = (INamedTypeSymbol)model.GetTypeInfo(returnType).Type!;
if (SymbolEqualityComparer.Default.Equals(returnSymbol.OriginalDefinition, taskSymbol))
ITypeSymbol returnTypeSymbol = model.GetTypeInfo(returnType).Type!;
bool returnsVoid = false;
INamedTypeSymbol returnSymbol;

// Check if it's a void return type
if (returnTypeSymbol.SpecialType == SpecialType.System_Void)
{
returnsVoid = true;
// For void, we'll use object as a placeholder since it won't be used
returnSymbol = model.Compilation.GetSpecialType(SpecialType.System_Object);
}
// Check if it's Task (non-generic)
else if (returnTypeSymbol is INamedTypeSymbol namedReturn)
{
INamedTypeSymbol? nonGenericTaskSymbol = model.Compilation.GetTypeByMetadataName("System.Threading.Tasks.Task");
if (nonGenericTaskSymbol != null && SymbolEqualityComparer.Default.Equals(namedReturn, nonGenericTaskSymbol))
{
returnsVoid = true;
// For Task with no return, we'll use object as a placeholder since it won't be used
returnSymbol = model.Compilation.GetSpecialType(SpecialType.System_Object);
}
// Check if it's Task<T>
else
{
INamedTypeSymbol? taskSymbol = model.Compilation.GetTypeByMetadataName("System.Threading.Tasks.Task`1");
returnSymbol = namedReturn;
if (taskSymbol != null && SymbolEqualityComparer.Default.Equals(returnSymbol.OriginalDefinition, taskSymbol))
{
// this is a Task<T> return value, lets pull out the generic.
returnSymbol = (INamedTypeSymbol)returnSymbol.TypeArguments[0];
}
}
}
else
{
// this is a Task<T> return value, lets pull out the generic.
returnSymbol = (INamedTypeSymbol)returnSymbol.TypeArguments[0];
returnSymbol = (INamedTypeSymbol)returnTypeSymbol;
}

if (!SyntaxNodeUtility.TryGetParameter(model, method, kind, out TypedParameter? parameter) || parameter == null)
Expand Down Expand Up @@ -93,7 +126,7 @@ public static bool TryParse(SemanticModel model, MethodDeclarationSyntax method,

requiredNamespaces!.UnionWith(GetRequiredGlobalNamespaces());

function = new DurableFunction(fullTypeName!, name, kind, parameter, returnSymbol, requiredNamespaces);
function = new DurableFunction(fullTypeName!, name, kind, parameter, returnSymbol, returnsVoid, requiredNamespaces);
return true;
}

Expand Down
12 changes: 11 additions & 1 deletion src/Generators/AzureFunctions/TypedParameter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,17 @@ public TypedParameter(INamedTypeSymbol type, string name)

public override string ToString()
{
return $"{SyntaxNodeUtility.GetRenderedTypeExpression(this.Type, false)} {this.Name}";
// Use the type as-is, preserving the nullability annotation from the source
string typeExpression = SyntaxNodeUtility.GetRenderedTypeExpression(this.Type, false);

// Special case: if the type is exactly System.Object (not a nullable object), make it nullable
// This is because object parameters are typically nullable in the context of Durable Functions
if (this.Type.SpecialType == SpecialType.System_Object && this.Type.NullableAnnotation != NullableAnnotation.Annotated)
{
typeExpression = "object?";
}

return $"{typeExpression} {this.Name}";
}
}
}
17 changes: 16 additions & 1 deletion src/Generators/DurableTaskSourceGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,21 @@ static void AddActivityCallMethod(StringBuilder sourceBuilder, DurableTaskTypeIn

static void AddActivityCallMethod(StringBuilder sourceBuilder, DurableFunction activity)
{
sourceBuilder.AppendLine($@"
if (activity.ReturnsVoid)
{
sourceBuilder.AppendLine($@"
/// <summary>
/// Calls the <see cref=""{activity.FullTypeName}""/> activity.
/// </summary>
/// <inheritdoc cref=""TaskOrchestrationContext.CallActivityAsync(TaskName, object?, TaskOptions?)""/>
public static Task Call{activity.Name}Async(this TaskOrchestrationContext ctx, {activity.Parameter}, TaskOptions? options = null)
{{
return ctx.CallActivityAsync(""{activity.Name}"", {activity.Parameter.Name}, options);
}}");
}
else
{
sourceBuilder.AppendLine($@"
/// <summary>
/// Calls the <see cref=""{activity.FullTypeName}""/> activity.
/// </summary>
Expand All @@ -374,6 +388,7 @@ static void AddActivityCallMethod(StringBuilder sourceBuilder, DurableFunction a
{{
return ctx.CallActivityAsync<{activity.ReturnType}>(""{activity.Name}"", {activity.Parameter.Name}, options);
}}");
}
}

static void AddActivityFunctionDeclaration(StringBuilder sourceBuilder, DurableTaskTypeInfo activity)
Expand Down
73 changes: 73 additions & 0 deletions test/Generators.Tests/AzureFunctionsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,79 @@ await TestHelpers.RunTestAsync<DurableTaskSourceGenerator>(
isDurableFunctions: true);
}

[Fact]
public async Task Activities_SimpleFunctionTrigger_VoidReturn()
{
string code = @"
using Microsoft.Azure.Functions.Worker;
using Microsoft.DurableTask;

public class Activities
{
[Function(nameof(FlakeyActivity))]
public static void FlakeyActivity([ActivityTrigger] object _)
{
throw new System.ApplicationException(""Kah-BOOOOM!!!"");
}
}";

string expectedOutput = TestHelpers.WrapAndFormat(
GeneratedClassName,
methodList: @"
/// <summary>
/// Calls the <see cref=""Activities.FlakeyActivity""/> activity.
/// </summary>
/// <inheritdoc cref=""TaskOrchestrationContext.CallActivityAsync(TaskName, object?, TaskOptions?)""/>
public static Task CallFlakeyActivityAsync(this TaskOrchestrationContext ctx, object? _, TaskOptions? options = null)
{
return ctx.CallActivityAsync(""FlakeyActivity"", _, options);
}",
isDurableFunctions: true);

await TestHelpers.RunTestAsync<DurableTaskSourceGenerator>(
GeneratedFileName,
code,
expectedOutput,
isDurableFunctions: true);
}

[Fact]
public async Task Activities_SimpleFunctionTrigger_TaskReturn()
{
string code = @"
using System.Threading.Tasks;
using Microsoft.Azure.Functions.Worker;
using Microsoft.DurableTask;

public class Activities
{
[Function(nameof(FlakeyActivity))]
public static Task FlakeyActivity([ActivityTrigger] object _)
{
throw new System.ApplicationException(""Kah-BOOOOM!!!"");
}
}";

string expectedOutput = TestHelpers.WrapAndFormat(
GeneratedClassName,
methodList: @"
/// <summary>
/// Calls the <see cref=""Activities.FlakeyActivity""/> activity.
/// </summary>
/// <inheritdoc cref=""TaskOrchestrationContext.CallActivityAsync(TaskName, object?, TaskOptions?)""/>
public static Task CallFlakeyActivityAsync(this TaskOrchestrationContext ctx, object? _, TaskOptions? options = null)
{
return ctx.CallActivityAsync(""FlakeyActivity"", _, options);
}",
isDurableFunctions: true);

await TestHelpers.RunTestAsync<DurableTaskSourceGenerator>(
GeneratedFileName,
code,
expectedOutput,
isDurableFunctions: true);
}

/// <summary>
/// Verifies that using the class-based activity syntax generates a <see cref="TaskOrchestrationContext"/>
/// extension method as well as an <see cref="ActivityTriggerAttribute"/> function definition.
Expand Down
Loading