Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CredentialVault path fix and unit tests #34

Closed
wants to merge 7 commits into from
14 changes: 7 additions & 7 deletions src/OneDrive.Sdk.Authentication.Common/OAuthHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ public async Task<AccountSession> RedeemAuthorizationCodeAsync(
returnUrl,
scopes,
clientSecret),
httpProvider).ConfigureAwait(false);
httpProvider, clientId).ConfigureAwait(false);
}

public async Task<AccountSession> RedeemRefreshTokenAsync(
Expand Down Expand Up @@ -262,21 +262,21 @@ public async Task<AccountSession> RedeemRefreshTokenAsync(

if (httpProvider == null)
{
return await this.SendTokenRequestAsync(tokenRequestBody);
return await this.SendTokenRequestAsync(tokenRequestBody, clientId);
}

return await this.SendTokenRequestAsync(tokenRequestBody, httpProvider).ConfigureAwait(false);
return await this.SendTokenRequestAsync(tokenRequestBody, httpProvider, clientId).ConfigureAwait(false);
}

public async Task<AccountSession> SendTokenRequestAsync(string requestBodyString)
public async Task<AccountSession> SendTokenRequestAsync(string requestBodyString, string clientId)
{
using (var httpProvider = new HttpProvider())
{
return await this.SendTokenRequestAsync(requestBodyString, httpProvider).ConfigureAwait(false);
return await this.SendTokenRequestAsync(requestBodyString, httpProvider, clientId).ConfigureAwait(false);
}
}

public async Task<AccountSession> SendTokenRequestAsync(string requestBodyString, IHttpProvider httpProvider)
public async Task<AccountSession> SendTokenRequestAsync(string requestBodyString, IHttpProvider httpProvider, string clientId)
{
var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, OAuthConstants.MicrosoftAccountTokenServiceUrl);

Expand All @@ -292,7 +292,7 @@ public async Task<AccountSession> SendTokenRequestAsync(string requestBodyString
if (responseValues != null)
{
OAuthErrorHandler.ThrowIfError(responseValues);
return new AccountSession(responseValues);
return new AccountSession(responseValues, clientId);
}

throw new ServiceException(
Expand Down
83 changes: 64 additions & 19 deletions src/OneDrive.Sdk.Authentication.Desktop/CredentialVault.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,28 @@ public class CredentialVault : ICredentialVault

private string VaultFileName => $"{VaultNamePrefix}_{this.ClientId}.dat";

private readonly byte[] _additionalEntropy;
private IProtectedData protectedData;

public CredentialVault(string clientId)
private IFile fileSystem;

public CredentialVault(string clientId, byte[] secondaryKeyBytes = null, IFile fileSystem = null, IProtectedData protectedData = null)
{
if (string.IsNullOrEmpty(clientId))
{
throw new ArgumentException("You must provide a clientId");
}

this.ClientId = clientId;
this._additionalEntropy = null;
}

public CredentialVault(string clientId, byte[] secondaryKeyBytes) : this(clientId)
{
this._additionalEntropy = secondaryKeyBytes;
this.protectedData = protectedData ?? new ProtectedDataDefault(secondaryKeyBytes);
this.fileSystem = fileSystem ?? new FileSystem();
}

public void AddCredentialCacheToVault(CredentialCache credentialCache)
{
this.DeleteStoredCredentialCache();
var cacheBlob = this.Protect(credentialCache.GetCacheBlob());
using (var outStream = File.OpenWrite(this.VaultFileName))

var cacheBlob = this.protectedData.Protect(credentialCache.GetCacheBlob());
using (var outStream = fileSystem.OpenWrite(this.GetVaultFilePath()))
{
outStream.Write(cacheBlob, 0, cacheBlob.Length);
}
Expand All @@ -49,9 +47,9 @@ public bool RetrieveCredentialCache(CredentialCache credentialCache)
{
var filePath = this.GetVaultFilePath();

if (File.Exists(filePath))
if (fileSystem.Exists(filePath))
{
credentialCache.InitializeCacheFromBlob(this.Unprotect(File.ReadAllBytes(filePath)));
credentialCache.InitializeCacheFromBlob(this.protectedData.Unprotect(fileSystem.ReadAllBytes(filePath)));
return true;
}

Expand All @@ -62,9 +60,9 @@ public bool DeleteStoredCredentialCache()
{
var filePath = this.GetVaultFilePath();

if (File.Exists(filePath))
if (fileSystem.Exists(filePath))
{
File.Delete(filePath);
fileSystem.Delete(filePath);
return true;
}

Expand All @@ -76,14 +74,61 @@ private string GetVaultFilePath()
return Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), this.VaultFileName);
}

private byte[] Protect(byte[] data)
public interface IFile
{
return ProtectedData.Protect(data, this._additionalEntropy, DataProtectionScope.CurrentUser);
Stream OpenWrite(string path);
bool Exists(string path);
void Delete(string path);
byte[] ReadAllBytes(string path);
}

private byte[] Unprotect(byte[] protectedData)
private class FileSystem : IFile
{
return ProtectedData.Unprotect(protectedData, this._additionalEntropy, DataProtectionScope.CurrentUser);
public void Delete(string path)
{
File.Delete(path);
}

public bool Exists(string path)
{
return File.Exists(path);
}

public Stream OpenWrite(string path)
{
return File.OpenWrite(path);
}

public byte[] ReadAllBytes(string path)
{
return File.ReadAllBytes(path);
}
}

public interface IProtectedData
{
byte[] Protect(byte[] data);
byte[] Unprotect(byte[] protectedData);
}

public class ProtectedDataDefault : IProtectedData
{
public ProtectedDataDefault(byte[] additionalEntropy = null)
{
this._additionalEntropy = additionalEntropy;
}

private readonly byte[] _additionalEntropy;

public byte[] Protect(byte[] data)
{
return ProtectedData.Protect(data, this._additionalEntropy, DataProtectionScope.CurrentUser);
}

public byte[] Unprotect(byte[] protectedData)
{
return ProtectedData.Unprotect(protectedData, this._additionalEntropy, DataProtectionScope.CurrentUser);
}
}
}
}
106 changes: 106 additions & 0 deletions tests/Test.OneDrive.Sdk.Authentication.Desktop/CredentialVaultTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
using Microsoft.OneDrive.Sdk.Authentication;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Collections.Generic;
using System.Linq;
using static Microsoft.OneDrive.Sdk.Authentication.CredentialVault;
using System.IO;
using Moq;

namespace Test.OneDrive.Sdk.Authentication.Desktop
{
[TestClass]
public class CredentialVaultTests
{
InMemoryFileSystem fileSystem = new InMemoryFileSystem();
CredentialVault credentialVault;
CredentialCache credentialCache;
AccountSession accountSession;
CredentialCache retrievedCache;

[TestInitialize]
public void TestInitialize()
{
var dict = new Dictionary<string, string>()
{
{ OAuthConstants.AccessTokenKeyName, "token"},
{ OAuthConstants.UserIdKeyName, "myUserId" }
};
accountSession = new AccountSession(dict, "myClientId");
credentialCache = new CredentialCache();
credentialCache.AddToCache(accountSession);
credentialVault = new CredentialVault("myClientId", fileSystem: fileSystem);
credentialVault.AddCredentialCacheToVault(credentialCache);
retrievedCache = new CredentialCache();
}

[TestMethod]
public void CredentialVaultTests_AddRetrieveSucceeds()
{
Assert.AreEqual(1, fileSystem.fs.Count, "File system should be storing only one cache");
bool success = credentialVault.RetrieveCredentialCache(retrievedCache);
Assert.IsTrue(success, "CredentialCache not found in vault.");
AccountSession retrievedAccountSession = retrievedCache.GetResultFromCache("myClientId", "myUserId");
Assert.IsNotNull(retrievedAccountSession, "AccountSession is null.");
Assert.AreEqual("token", retrievedAccountSession.AccessToken, "AccountSession not stored properly.");
}

[TestMethod]
public void CredentialVaultTests_DeleteSucceeds()
{
bool success1 = credentialVault.DeleteStoredCredentialCache();
Assert.IsTrue(success1, "CredentialCache not found in vault.");
bool success2 = credentialVault.RetrieveCredentialCache(retrievedCache);
Assert.IsFalse(success2, "CredentialCache is not erased from vault.");
AccountSession retrievedAccountSession = retrievedCache.GetResultFromCache("myClientId", "myUserId");
Assert.IsNull(retrievedAccountSession, "AccountSession must be null.");
}

[TestMethod]
public void CredentialVaultTests_ProtectMethodCalled()
{
var mockProtectedData = new Mock<IProtectedData>();
credentialVault = new CredentialVault("myClientId", null, fileSystem, mockProtectedData.Object);
credentialVault.AddCredentialCacheToVault(credentialCache);
mockProtectedData.Verify(
mock => mock.Protect(
It.Is<byte[]>(b => b.SequenceEqual(credentialCache.GetCacheBlob()))),
Times.Once(),
"Protect method not called with CredentialCache as parameter.");
}

[TestMethod]
public void CredentialVaultTests_ProtectMethodTransformsData()
{
ProtectedDataDefault protectedData = new ProtectedDataDefault();
byte[] b = { 1, 2, 3 };
var c = protectedData.Protect(b);
Assert.IsFalse(b.SequenceEqual(c),"Protect method does not transform data.");
}
}

internal class InMemoryFileSystem : IFile
{
public Dictionary<string, byte[]> fs = new Dictionary<string, byte[]>();

public void Delete(string path)
{
fs.Remove(path);
}

public bool Exists(string path)
{
return fs.ContainsKey(path);
}

public Stream OpenWrite(string path)
{
fs.Add(path, new byte[550]);
return new MemoryStream(fs[path]);
}

public byte[] ReadAllBytes(string path)
{
return fs[path];
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -290,5 +290,42 @@ private async Task AuthenticateWithRefreshToken(AccountSession refreshedAccountS
"Unexpected cached refresh token.");
}
}

[TestMethod]
public async Task AuthenticateUserAsync_ClientIdAddedToAccountSession()
{
using (var responseStream = new MemoryStream())
using (var streamContent = new StreamContent(responseStream))
{
httpResponseMessage.Content = streamContent;

var tokenResponseDictionary = new Dictionary<string, string> { { "code", "code" } };

this.webAuthenticationUi.Setup(webUi => webUi.AuthenticateAsync(
It.Is<Uri>(uri => uri.ToString().Contains("response_type=code")),
It.Is<Uri>(uri => uri.ToString().Equals(MsaAuthenticationProviderTests.ReturnUrl))))
.Returns(
Task.FromResult<IDictionary<string, string>>(tokenResponseDictionary));

this.httpProvider.Setup(
provider => provider.SendAsync(
It.Is<HttpRequestMessage>(
request => request.RequestUri.ToString().Equals(OAuthConstants.MicrosoftAccountTokenServiceUrl))))
.Returns(Task.FromResult<HttpResponseMessage>(httpResponseMessage));

this.serializer.Setup(
serializer => serializer.DeserializeObject<IDictionary<string, string>>(It.IsAny<Stream>()))
.Returns(new Dictionary<string, string>
{
{ OAuthConstants.AccessTokenKeyName, "token" },
{ OAuthConstants.UserIdKeyName, UserId }
});

await this.authenticationProvider.AuthenticateUserAsync(this.httpProvider.Object).ConfigureAwait(false);

var accountSession = this.authenticationProvider.CredentialCache.GetResultFromCache(ClientId, UserId);
Assert.IsNotNull(accountSession, "AccountSession not found in cache. AccountSession.ClientId may be null or have an incorrect value.");
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
<Reference Include="System" />
<Reference Include="System.Net" />
<Reference Include="System.Net.Http" />
<Reference Include="System.Security" />
</ItemGroup>
<Choose>
<When Condition="('$(VisualStudioVersion)' == '10.0' or '$(VisualStudioVersion)' == '') and '$(TargetFrameworkVersion)' == 'v3.5'">
Expand All @@ -83,6 +84,7 @@
<ItemGroup>
<Compile Include="AccountSessionTests.cs" />
<Compile Include="AdalAuthenticationProviderTestBase.cs" />
<Compile Include="CredentialVaultTests.cs" />
<Compile Include="DiscoveryServiceHelperTests.cs" />
<Compile Include="AdalAuthenticationProviderTests.cs" />
<Compile Include="AdalCredentialCacheTests.cs" />
Expand Down