Skip to content

Commit

Permalink
Handling tasks that return internal types
Browse files Browse the repository at this point in the history
  • Loading branch information
Elad Zelingher committed Sep 7, 2014
1 parent 05c17b5 commit d3a678e
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 19 deletions.
44 changes: 43 additions & 1 deletion src/WampSharp.Tests/Api/RpcServerTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Moq;
using System.Threading.Tasks;
using Moq;
using NUnit.Framework;
using WampSharp.Rpc;
using WampSharp.Tests.TestHelpers;
Expand All @@ -15,6 +16,12 @@ public interface ICalculator
int Square(int x);
}

public interface INumberProcessor
{
[WampRpcMethod("test/square")]
Task ProcessNumber(int x);
}

[Test]
public void RequestContextIsSet()
{
Expand Down Expand Up @@ -44,5 +51,40 @@ public void RequestContextIsSet()

Assert.That(context.SessionId, Is.EqualTo(channel.GetMonitor().SessionId));
}

#if NET45

[Test]
public void AsyncAwaitTaskWork()
{
WampPlayground playground = new WampPlayground();

IWampHost host = playground.Host;

WampRequestContext context = null;

Mock<INumberProcessor> mock = new Mock<INumberProcessor>();

mock.Setup(x => x.ProcessNumber(It.IsAny<int>()))
.Returns(async (int x) =>
{
});

host.HostService(mock.Object);

host.Open();

IWampChannel<MockRaw> channel = playground.CreateNewChannel();

channel.Open();

INumberProcessor proxy = channel.GetRpcProxy<INumberProcessor>();

Task task = proxy.ProcessNumber(4);

mock.Verify(x => x.ProcessNumber(4));
}

#endif
}
}
49 changes: 45 additions & 4 deletions src/WampSharp/Rpc/Client/TaskExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,22 +1,30 @@
using System;
using System.Reflection;
using System.Threading.Tasks;
using WampSharp.Core.Utilities;

namespace WampSharp.Rpc.Client
{
internal static class TaskExtensions
{
private static readonly MethodInfo mCastTask = GetCastTaskMethod();
private static readonly MethodInfo mCastTaskToGenericTask = GetCastTaskToGenericTaskMethod();
private static readonly MethodInfo mCastToNonGenericTask = GetCastGenericTaskToNonGenericMethod();

private static MethodInfo GetCastTaskMethod()
private static MethodInfo GetCastGenericTaskToNonGenericMethod()
{
return typeof(TaskExtensions).GetMethod("InnerCastTask",
BindingFlags.Static | BindingFlags.NonPublic);
}

private static MethodInfo GetCastTaskToGenericTaskMethod()
{
return typeof(TaskExtensions).GetMethod("InternalCastTask",
BindingFlags.Static | BindingFlags.NonPublic);
}

public static Task Cast(this Task<object> task, Type taskType)
{
return (Task)mCastTask.MakeGenericMethod(taskType).Invoke(null, new object[] { task });
return (Task)mCastTaskToGenericTask.MakeGenericMethod(taskType).Invoke(null, new object[] { task });
}

private static Task<T> InternalCastTask<T>(Task<object> task)
Expand All @@ -39,7 +47,12 @@ public static Task<object> CastTask(this Task task)
}
else
{
result = InnerCastTask((dynamic)task);
Type underlyingType = UnwrapReturnType(task.GetType());

MethodInfo method =
mCastToNonGenericTask.MakeGenericMethod(underlyingType);

result = (Task<object>) method.Invoke(null, new object[] {task});
}

return result;
Expand Down Expand Up @@ -71,5 +84,33 @@ private static TResult ContinueWithSafeCallback<TTask, TResult>(TTask task, Func

return result;
}

/// <summary>
/// Unwraps the return type of a given method.
/// </summary>
/// <param name="returnType">The given return type.</param>
/// <returns>The unwrapped return type.</returns>
/// <example>
/// void, Task -> object
/// Task{string} -> string
/// int -> int
/// </example>
public static Type UnwrapReturnType(Type returnType)
{
if (returnType == typeof(void) || returnType == typeof(Task))
{
return typeof(object);
}

Type taskType =
returnType.GetClosedGenericTypeImplementation(typeof(Task<>));

if (taskType != null)
{
return returnType.GetGenericArguments()[0];
}

return returnType;
}
}
}
15 changes: 1 addition & 14 deletions src/WampSharp/Rpc/Client/WampRpcSerializer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,7 @@ public WampRpcCall Serialize(MethodInfo method, object[] arguments)

private Type ExtractReturnType(Type returnType)
{
if (returnType == typeof (void) || returnType == typeof(Task))
{
return typeof (object);
}

Type taskType =
returnType.GetClosedGenericTypeImplementation(typeof (Task<>));

if (taskType != null)
{
return returnType.GetGenericArguments()[0];
}

return returnType;
return TaskExtensions.UnwrapReturnType(returnType);
}
}
}

0 comments on commit d3a678e

Please sign in to comment.