diff --git a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs index 6f3c87291..40d59f4ae 100644 --- a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs @@ -1,4 +1,4 @@ -/* +/* * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. */ @@ -21,6 +21,8 @@ namespace Snowflake.Data.Tests.IntegrationTests using Snowflake.Data.Tests.Mock; using System.Runtime.InteropServices; using System.Net.Http; + using Snowflake.Data.Core.CredentialManager; + using Snowflake.Data.Core.CredentialManager.Infrastructure; [TestFixture] class SFConnectionIT : SFBaseTest @@ -1046,6 +1048,71 @@ public void TestSSOConnectionTimeoutAfter10s() Assert.LessOrEqual(stopwatch.ElapsedMilliseconds, (waitSeconds + 5) * 1000); } + [Test] + [Ignore("This test requires manual interaction and therefore cannot be run in CI")] + public void TestSSOConnectionWithTokenCaching() + { + /* + * This test checks that the connector successfully stores an SSO token and uses it for authentication if it exists + * 1. Login normally using external browser with allow_sso_token_caching enabled + * 2. Login again, this time without a browser, as the connector should be using the SSO token retrieved from step 1 + */ + + using (IDbConnection conn = new SnowflakeDbConnection()) + { + // Set the allow_sso_token_caching property to true to enable token caching + // The specified user should be configured for SSO + conn.ConnectionString + = ConnectionStringWithoutAuth + + $";authenticator=externalbrowser;user={testConfig.user};allow_sso_token_caching=true;"; + + // Authenticate to retrieve and store the token if doesn't exist or invalid + conn.Open(); + Assert.AreEqual(ConnectionState.Open, conn.State); + + // Authenticate using the SSO token (the connector will automatically use the token and a browser should not pop-up in this step) + conn.Open(); + Assert.AreEqual(ConnectionState.Open, conn.State); + + conn.Close(); + Assert.AreEqual(ConnectionState.Closed, conn.State); + } + } + + [Test] + [Ignore("This test requires manual interaction and therefore cannot be run in CI")] + public void TestSSOConnectionWithInvalidCachedToken() + { + /* + * This test checks that the connector will attempt to re-authenticate using external browser if the token retrieved from the cache is invalid + * 1. Create a credential manager and save credentials for the user with a wrong token + * 2. Open a connection which initially should try to use the token and then switch to external browser when the token fails + */ + + using (IDbConnection conn = new SnowflakeDbConnection()) + { + // Set the allow_sso_token_caching property to true to enable token caching + conn.ConnectionString + = ConnectionStringWithoutAuth + + $";authenticator=externalbrowser;user={testConfig.user};allow_sso_token_caching=true;"; + + // Create a credential manager and save a wrong token for the test user + var key = SFCredentialManagerFactory.BuildCredentialKey(testConfig.host, testConfig.user, TokenType.IdToken); + var credentialManager = SFCredentialManagerInMemoryImpl.Instance; + credentialManager.SaveCredentials(key, "wrongToken"); + + // Use the credential manager with the wrong token + SFCredentialManagerFactory.SetCredentialManager(credentialManager); + + // Open a connection which should switch to external browser after trying to connect using the wrong token + conn.Open(); + Assert.AreEqual(ConnectionState.Open, conn.State); + + // Switch back to the default credential manager + SFCredentialManagerFactory.UseDefaultCredentialManager(); + } + } + [Test] [Ignore("This test requires manual interaction and therefore cannot be run in CI")] public void TestSSOConnectionWithWrongUser() @@ -2311,6 +2378,40 @@ public void TestOpenAsyncThrowExceptionWhenOperationIsCancelled() } } + [Test] + [Ignore("This test requires manual interaction and therefore cannot be run in CI")] + public void TestSSOConnectionWithTokenCachingAsync() + { + /* + * This test checks that the connector successfully stores an SSO token and uses it for authentication if it exists + * 1. Login normally using external browser with allow_sso_token_caching enabled + * 2. Login again, this time without a browser, as the connector should be using the SSO token retrieved from step 1 + */ + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + // Set the allow_sso_token_caching property to true to enable token caching + // The specified user should be configured for SSO + conn.ConnectionString + = ConnectionStringWithoutAuth + + $";authenticator=externalbrowser;user={testConfig.user};allow_sso_token_caching=true;"; + + // Authenticate to retrieve and store the token if doesn't exist or invalid + Task connectTask = conn.OpenAsync(CancellationToken.None); + connectTask.Wait(); + Assert.AreEqual(ConnectionState.Open, conn.State); + + // Authenticate using the SSO token (the connector will automatically use the token and a browser should not pop-up in this step) + connectTask = conn.OpenAsync(CancellationToken.None); + connectTask.Wait(); + Assert.AreEqual(ConnectionState.Open, conn.State); + + connectTask = conn.CloseAsync(CancellationToken.None); + connectTask.Wait(); + Assert.AreEqual(ConnectionState.Closed, conn.State); + } + } + [Test] public void TestCloseSessionWhenGarbageCollectorFinalizesConnection() { diff --git a/Snowflake.Data.Tests/Mock/MockExternalBrowser.cs b/Snowflake.Data.Tests/Mock/MockExternalBrowser.cs new file mode 100644 index 000000000..147a2d1b1 --- /dev/null +++ b/Snowflake.Data.Tests/Mock/MockExternalBrowser.cs @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using Snowflake.Data.Core; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; + +namespace Snowflake.Data.Tests.Mock +{ + + class MockExternalBrowserRestRequester : IMockRestRequester + { + public string ProofKey { get; set; } + public string SSOUrl { get; set; } + + public T Get(IRestRequest request) + { + throw new System.NotImplementedException(); + } + + public Task GetAsync(IRestRequest request, CancellationToken cancellationToken) + { + throw new System.NotImplementedException(); + } + + public T Post(IRestRequest postRequest) + { + return Task.Run(async () => await (PostAsync(postRequest, CancellationToken.None)).ConfigureAwait(false)).Result; + } + + public Task PostAsync(IRestRequest postRequest, CancellationToken cancellationToken) + { + SFRestRequest sfRequest = (SFRestRequest)postRequest; + if (sfRequest.jsonBody is AuthenticatorRequest) + { + if (string.IsNullOrEmpty(SSOUrl)) + { + var body = (AuthenticatorRequest)sfRequest.jsonBody; + var port = body.Data.BrowserModeRedirectPort; + SSOUrl = $"http://localhost:{port}/?token=mockToken"; + } + + // authenticator + var authnResponse = new AuthenticatorResponse + { + success = true, + data = new AuthenticatorResponseData + { + proofKey = ProofKey, + ssoUrl = SSOUrl, + } + }; + + return Task.FromResult((T)(object)authnResponse); + } + else + { + // login + var loginResponse = new LoginResponse + { + success = true, + data = new LoginResponseData + { + sessionId = "", + token = "", + masterToken = "", + masterValidityInSeconds = 0, + authResponseSessionInfo = new SessionInfo + { + databaseName = "", + schemaName = "", + roleName = "", + warehouseName = "", + } + } + }; + + return Task.FromResult((T)(object)loginResponse); + } + } + + public HttpResponseMessage Get(IRestRequest request) + { + throw new System.NotImplementedException(); + } + + public Task GetAsync(IRestRequest request, CancellationToken cancellationToken) + { + throw new System.NotImplementedException(); + } + + public void setHttpClient(HttpClient httpClient) + { + // Nothing to do + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/CredentialManager/SFCredentialManagerTest.cs b/Snowflake.Data.Tests/UnitTests/CredentialManager/SFCredentialManagerTest.cs new file mode 100644 index 000000000..83d4d84fe --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/CredentialManager/SFCredentialManagerTest.cs @@ -0,0 +1,379 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +namespace Snowflake.Data.Tests.UnitTests.CredentialManager +{ + using Mono.Unix; + using Mono.Unix.Native; + using Moq; + using NUnit.Framework; + using Snowflake.Data.Client; + using Snowflake.Data.Core.CredentialManager; + using Snowflake.Data.Core.CredentialManager.Infrastructure; + using Snowflake.Data.Core.Tools; + using System; + using System.IO; + using System.Runtime.InteropServices; + using System.Security; + + public abstract class SFBaseCredentialManagerTest + { + protected ISnowflakeCredentialManager _credentialManager; + + [Test] + public void TestSavingAndRemovingCredentials() + { + // arrange + var key = "mockKey"; + var expectedToken = "token"; + + // act + _credentialManager.SaveCredentials(key, expectedToken); + + // assert + Assert.AreEqual(expectedToken, _credentialManager.GetCredentials(key)); + + // act + _credentialManager.RemoveCredentials(key); + + // assert + Assert.IsTrue(string.IsNullOrEmpty(_credentialManager.GetCredentials(key))); + } + + [Test] + public void TestSavingCredentialsForAnExistingKey() + { + // arrange + var key = "mockKey"; + var firstExpectedToken = "mockToken1"; + var secondExpectedToken = "mockToken2"; + + try + { + // act + _credentialManager.SaveCredentials(key, firstExpectedToken); + + // assert + Assert.AreEqual(firstExpectedToken, _credentialManager.GetCredentials(key)); + + // act + _credentialManager.SaveCredentials(key, secondExpectedToken); + + // assert + Assert.AreEqual(secondExpectedToken, _credentialManager.GetCredentials(key)); + + // act + _credentialManager.RemoveCredentials(key); + + // assert + Assert.IsTrue(string.IsNullOrEmpty(_credentialManager.GetCredentials(key))); + } + catch (Exception ex) + { + // assert + Assert.Fail("Should not throw an exception: " + ex.Message); + } + } + + [Test] + public void TestRemovingCredentialsForKeyThatDoesNotExist() + { + // arrange + var key = "mockKey"; + + try + { + // act + _credentialManager.RemoveCredentials(key); + + // assert + Assert.IsTrue(string.IsNullOrEmpty(_credentialManager.GetCredentials(key))); + } + catch (Exception ex) + { + // assert + Assert.Fail("Should not throw an exception: " + ex.Message); + } + } + } + + [TestFixture] + [Platform("Win")] + public class SFNativeCredentialManagerTest : SFBaseCredentialManagerTest + { + [SetUp] + public void SetUp() + { + _credentialManager = SFCredentialManagerWindowsNativeImpl.Instance; + } + } + + [TestFixture] + public class SFInMemoryCredentialManagerTest : SFBaseCredentialManagerTest + { + [SetUp] + public void SetUp() + { + _credentialManager = SFCredentialManagerInMemoryImpl.Instance; + } + } + + [TestFixture] + public class SFFileCredentialManagerTest : SFBaseCredentialManagerTest + { + [SetUp] + public void SetUp() + { + _credentialManager = SFCredentialManagerFileImpl.Instance; + } + } + + [TestFixture] + class SFCredentialManagerTest + { + ISnowflakeCredentialManager _credentialManager; + + [ThreadStatic] + private static Mock t_fileOperations; + + [ThreadStatic] + private static Mock t_directoryOperations; + + [ThreadStatic] + private static Mock t_unixOperations; + + [ThreadStatic] + private static Mock t_environmentOperations; + + private const string CustomJsonDir = "testdirectory"; + + private static readonly string s_customJsonPath = Path.Combine(CustomJsonDir, SFCredentialManagerFileImpl.CredentialCacheFileName); + + [SetUp] public void SetUp() + { + t_fileOperations = new Mock(); + t_directoryOperations = new Mock(); + t_unixOperations = new Mock(); + t_environmentOperations = new Mock(); + SFCredentialManagerFactory.SetCredentialManager(SFCredentialManagerInMemoryImpl.Instance); + } + + [TearDown] public void TearDown() + { + SFCredentialManagerFactory.UseDefaultCredentialManager(); + } + + [Test] + public void TestUsingDefaultCredentialManager() + { + // arrange + SFCredentialManagerFactory.UseDefaultCredentialManager(); + + // act + _credentialManager = SFCredentialManagerFactory.GetCredentialManager(); + + // assert + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + Assert.IsInstanceOf(_credentialManager); + } + else + { + Assert.IsInstanceOf(_credentialManager); + } + } + + [Test] + public void TestSettingCustomCredentialManager() + { + // arrange + SFCredentialManagerFactory.SetCredentialManager(SFCredentialManagerFileImpl.Instance); + + // act + _credentialManager = SFCredentialManagerFactory.GetCredentialManager(); + + // assert + Assert.IsInstanceOf(_credentialManager); + } + + [Test] + public void TestThatThrowsErrorWhenCacheFileIsNotCreated() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + Assert.Ignore("skip test on Windows"); + } + + // arrange + t_directoryOperations + .Setup(d => d.Exists(s_customJsonPath)) + .Returns(false); + t_unixOperations + .Setup(u => u.CreateFileWithPermissions(s_customJsonPath, + FilePermissions.S_IRUSR | FilePermissions.S_IWUSR | FilePermissions.S_IXUSR)) + .Returns(-1); + t_environmentOperations + .Setup(e => e.GetEnvironmentVariable(SFCredentialManagerFileImpl.CredentialCacheDirectoryEnvironmentName)) + .Returns(CustomJsonDir); + SFCredentialManagerFactory.SetCredentialManager(new SFCredentialManagerFileImpl(t_fileOperations.Object, t_directoryOperations.Object, t_unixOperations.Object, t_environmentOperations.Object)); + _credentialManager = SFCredentialManagerFactory.GetCredentialManager(); + + // act + var thrown = Assert.Throws(() => _credentialManager.SaveCredentials("key", "token")); + + // assert + Assert.That(thrown.Message, Does.Contain("Failed to create the JSON token cache file")); + } + + [Test] + public void TestThatThrowsErrorWhenCacheFileCanBeAccessedByOthers() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + Assert.Ignore("skip test on Windows"); + } + + // arrange + t_unixOperations + .Setup(u => u.CreateFileWithPermissions(s_customJsonPath, + FilePermissions.S_IRUSR | FilePermissions.S_IWUSR | FilePermissions.S_IXUSR)) + .Returns(0); + t_unixOperations + .Setup(u => u.GetFilePermissions(s_customJsonPath)) + .Returns(FileAccessPermissions.AllPermissions); + t_environmentOperations + .Setup(e => e.GetEnvironmentVariable(SFCredentialManagerFileImpl.CredentialCacheDirectoryEnvironmentName)) + .Returns(CustomJsonDir); + SFCredentialManagerFactory.SetCredentialManager(new SFCredentialManagerFileImpl(t_fileOperations.Object, t_directoryOperations.Object, t_unixOperations.Object, t_environmentOperations.Object)); + _credentialManager = SFCredentialManagerFactory.GetCredentialManager(); + + // act + var thrown = Assert.Throws(() => _credentialManager.SaveCredentials("key", "token")); + + // assert + Assert.That(thrown.Message, Does.Contain("Permission for the JSON token cache file should contain only the owner access")); + } + + [Test] + public void TestThatJsonFileIsCheckedIfAlreadyExists() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + Assert.Ignore("skip test on Windows"); + } + + // arrange + t_unixOperations + .Setup(u => u.CreateFileWithPermissions(s_customJsonPath, + FilePermissions.S_IRUSR | FilePermissions.S_IWUSR | FilePermissions.S_IXUSR)) + .Returns(0); + t_unixOperations + .Setup(u => u.GetFilePermissions(s_customJsonPath)) + .Returns(FileAccessPermissions.UserReadWriteExecute); + t_environmentOperations + .Setup(e => e.GetEnvironmentVariable(SFCredentialManagerFileImpl.CredentialCacheDirectoryEnvironmentName)) + .Returns(CustomJsonDir); + t_fileOperations + .SetupSequence(f => f.Exists(s_customJsonPath)) + .Returns(false) + .Returns(true); + + SFCredentialManagerFactory.SetCredentialManager(new SFCredentialManagerFileImpl(t_fileOperations.Object, t_directoryOperations.Object, t_unixOperations.Object, t_environmentOperations.Object)); + _credentialManager = SFCredentialManagerFactory.GetCredentialManager(); + + // act + _credentialManager.SaveCredentials("key", "token"); + + // assert + t_fileOperations.Verify(f => f.Exists(s_customJsonPath), Times.Exactly(2)); + } + + [Test] + public void TestThatJsonFileIsCheckedIfOwnedByCurrentUser() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + Assert.Ignore("skip test on Windows"); + } + + // arrange + t_unixOperations + .Setup(u => u.CheckFileIsNotOwnedByCurrentUser(s_customJsonPath)) + .Returns(true); + t_environmentOperations + .Setup(e => e.GetEnvironmentVariable(SFCredentialManagerFileImpl.CredentialCacheDirectoryEnvironmentName)) + .Returns(CustomJsonDir); + t_fileOperations + .SetupSequence(f => f.Exists(s_customJsonPath)) + .Returns(true); + + SFCredentialManagerFactory.SetCredentialManager(new SFCredentialManagerFileImpl(t_fileOperations.Object, t_directoryOperations.Object, t_unixOperations.Object, t_environmentOperations.Object)); + _credentialManager = SFCredentialManagerFactory.GetCredentialManager(); + + // act + var thrown = Assert.Throws(() => _credentialManager.GetCredentials("key")); + + // assert + Assert.That(thrown.Message, Does.Contain("Attempting to read a file not owned by the effective user of the current process")); + } + + [Test] + public void TestThatJsonFileIsCheckedIfOwnedByCurrentGroup() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + Assert.Ignore("skip test on Windows"); + } + + // arrange + t_unixOperations + .Setup(u => u.CheckFileIsNotOwnedByCurrentGroup(s_customJsonPath)) + .Returns(true); + t_environmentOperations + .Setup(e => e.GetEnvironmentVariable(SFCredentialManagerFileImpl.CredentialCacheDirectoryEnvironmentName)) + .Returns(CustomJsonDir); + t_fileOperations + .SetupSequence(f => f.Exists(s_customJsonPath)) + .Returns(true); + + SFCredentialManagerFactory.SetCredentialManager(new SFCredentialManagerFileImpl(t_fileOperations.Object, t_directoryOperations.Object, t_unixOperations.Object, t_environmentOperations.Object)); + _credentialManager = SFCredentialManagerFactory.GetCredentialManager(); + + // act + var thrown = Assert.Throws(() => _credentialManager.GetCredentials("key")); + + // assert + Assert.That(thrown.Message, Does.Contain("Attempting to read a file not owned by the effective group of the current process")); + } + + [Test] + public void TestThatJsonFileIsCheckedIfItHasTooBroadPermissions() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + Assert.Ignore("skip test on Windows"); + } + + // arrange + t_unixOperations + .Setup(u => u.CheckFileHasAnyOfPermissions(s_customJsonPath, FileAccessPermissions.GroupReadWriteExecute | FileAccessPermissions.OtherReadWriteExecute)) + .Returns(true); + t_environmentOperations + .Setup(e => e.GetEnvironmentVariable(SFCredentialManagerFileImpl.CredentialCacheDirectoryEnvironmentName)) + .Returns(CustomJsonDir); + t_fileOperations + .SetupSequence(f => f.Exists(s_customJsonPath)) + .Returns(true); + + SFCredentialManagerFactory.SetCredentialManager(new SFCredentialManagerFileImpl(t_fileOperations.Object, t_directoryOperations.Object, t_unixOperations.Object, t_environmentOperations.Object)); + _credentialManager = SFCredentialManagerFactory.GetCredentialManager(); + + // act + var thrown = Assert.Throws(() => _credentialManager.GetCredentials("key")); + + // assert + Assert.That(thrown.Message, Does.Contain("Attempting to read a file with too broad permissions assigned")); + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/SFExternalBrowserTest.cs b/Snowflake.Data.Tests/UnitTests/SFExternalBrowserTest.cs new file mode 100644 index 000000000..7816e335f --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/SFExternalBrowserTest.cs @@ -0,0 +1,317 @@ +using Moq; +using NUnit.Framework; +using Snowflake.Data.Client; +using Snowflake.Data.Core; +using Snowflake.Data.Core.CredentialManager; +using Snowflake.Data.Core.CredentialManager.Infrastructure; +using Snowflake.Data.Core.Tools; +using System; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using System.Web; + +namespace Snowflake.Data.Tests.UnitTests +{ + [TestFixture] + class SFExternalBrowserTest + { + [ThreadStatic] + private static Mock t_browserOperations; + + private static HttpClient s_httpClient = new HttpClient(); + + [SetUp] + public void BeforeEach() + { + t_browserOperations = new Mock(); + } + + [Test] + public void TestDefaultAuthentication() + { + try + { + t_browserOperations + .Setup(b => b.OpenUrl(It.IsAny())) + .Callback((string url) => { + s_httpClient.GetAsync(url); + }); + + var restRequester = new Mock.MockExternalBrowserRestRequester() + { + ProofKey = "mockProofKey", + }; + var sfSession = new SFSession("account=test;user=test;password=test;authenticator=externalbrowser;host=test.okta.com", null, restRequester, t_browserOperations.Object); + sfSession.Open(); + + t_browserOperations.Verify(b => b.OpenUrl(It.IsAny()), Times.Once()); + } catch (SnowflakeDbException e) + { + Assert.Fail("Should pass without exception", e); + } + } + + [Test] + public void TestConsoleLogin() + { + try + { + t_browserOperations + .Setup(b => b.OpenUrl(It.IsAny())) + .Callback((string url) => { + Uri uri = new Uri(url); + var port = HttpUtility.ParseQueryString(uri.Query).Get("browser_mode_redirect_port"); + var browserUrl = $"http://localhost:{port}/?token=mockToken"; + s_httpClient.GetAsync(browserUrl); + }); + + var restRequester = new Mock.MockExternalBrowserRestRequester() + { + ProofKey = "mockProofKey", + }; + var sfSession = new SFSession("disable_console_login=false;account=test;user=test;password=test;authenticator=externalbrowser;host=test.okta.com", null, restRequester, t_browserOperations.Object); + sfSession.Open(); + + t_browserOperations.Verify(b => b.OpenUrl(It.IsAny()), Times.Once()); + } + catch (SnowflakeDbException e) + { + Assert.Fail("Should pass without exception", e); + } + } + + [Test] + public void TestSSOToken() + { + try + { + var user = "test"; + var host = $"{user}.okta.com"; + var key = SFCredentialManagerFactory.BuildCredentialKey(host, user, TokenType.IdToken); + var credentialManager = SFCredentialManagerInMemoryImpl.Instance; + credentialManager.SaveCredentials(key, "mockIdToken"); + SFCredentialManagerFactory.SetCredentialManager(credentialManager); + + var restRequester = new Mock.MockExternalBrowserRestRequester() + { + ProofKey = "mockProofKey", + SSOUrl = "https://www.mockSSOUrl.com", + }; + var sfSession = new SFSession($"allow_sso_token_caching=true;account=test;user={user};password=test;authenticator=externalbrowser;host={host}", null, restRequester, t_browserOperations.Object); + sfSession.Open(); + + t_browserOperations.Verify(b => b.OpenUrl(It.IsAny()), Times.Never()); + } + catch (SnowflakeDbException e) + { + Assert.Fail("Should pass without exception", e); + } + } + + [Test] + public void TestThatThrowsTimeoutErrorWhenNoBrowserResponse() + { + try + { + t_browserOperations + .Setup(b => b.OpenUrl(It.IsAny())) + .Callback(async (string url) => { + await Task.Delay(1000).ContinueWith(_ => + { + s_httpClient.GetAsync(url); + }); + }); + + var restRequester = new Mock.MockExternalBrowserRestRequester() + { + ProofKey = "mockProofKey", + }; + var sfSession = new SFSession($"browser_response_timeout=0;account=test;user=test;password=test;authenticator=externalbrowser;host=test.okta.com", null, restRequester, t_browserOperations.Object); + sfSession.Open(); + Assert.Fail("Should fail"); + } + catch (SnowflakeDbException e) + { + Assert.AreEqual(SFError.BROWSER_RESPONSE_TIMEOUT.GetAttribute().errorCode, e.ErrorCode); + } + } + + [Test] + public void TestThatThrowsErrorWhenUrlDoesNotMatchRegex() + { + try + { + var restRequester = new Mock.MockExternalBrowserRestRequester() + { + ProofKey = "mockProofKey", + SSOUrl = "non-matching-regex.com" + }; + var sfSession = new SFSession("account=test;user=test;password=test;authenticator=externalbrowser;host=test.okta.com", null, restRequester, t_browserOperations.Object); + sfSession.Open(); + Assert.Fail("Should fail"); + } + catch (SnowflakeDbException e) + { + Assert.AreEqual(SFError.INVALID_BROWSER_URL.GetAttribute().errorCode, e.ErrorCode); + } + } + + [Test] + public void TestThatThrowsErrorWhenUrlIsNotWellFormedUriString() + { + try + { + var restRequester = new Mock.MockExternalBrowserRestRequester() + { + ProofKey = "mockProofKey", + SSOUrl = "http://localhost:123/?token=mockToken\\\\" + }; + var sfSession = new SFSession("account=test;user=test;password=test;authenticator=externalbrowser;host=test.okta.com", null, restRequester, t_browserOperations.Object); + sfSession.Open(); + Assert.Fail("Should fail"); + } + catch (SnowflakeDbException e) + { + Assert.AreEqual(SFError.INVALID_BROWSER_URL.GetAttribute().errorCode, e.ErrorCode); + } + } + + [Test] + public void TestThatThrowsErrorWhenBrowserRequestMethodIsNotGet() + { + try + { + t_browserOperations + .Setup(b => b.OpenUrl(It.IsAny())) + .Callback((string url) => { + s_httpClient.PostAsync(url, new StringContent("")); + }); + + var restRequester = new Mock.MockExternalBrowserRestRequester() + { + ProofKey = "mockProofKey", + }; + var sfSession = new SFSession("account=test;user=test;password=test;authenticator=externalbrowser;host=test.okta.com", null, restRequester, t_browserOperations.Object); + sfSession.Open(); + Assert.Fail("Should fail"); + } + catch (SnowflakeDbException e) + { + Assert.AreEqual(SFError.BROWSER_RESPONSE_WRONG_METHOD.GetAttribute().errorCode, e.ErrorCode); + } + } + + [Test] + public void TestThatThrowsErrorWhenBrowserRequestHasInvalidQuery() + { + try + { + t_browserOperations + .Setup(b => b.OpenUrl(It.IsAny())) + .Callback((string url) => { + var urlWithoutQuery = url.Substring(0, url.IndexOf("?token=")); + s_httpClient.GetAsync(urlWithoutQuery); + }); + + var restRequester = new Mock.MockExternalBrowserRestRequester() + { + ProofKey = "mockProofKey", + }; + var sfSession = new SFSession("account=test;user=test;password=test;authenticator=externalbrowser;host=test.okta.com", null, restRequester, t_browserOperations.Object); + sfSession.Open(); + Assert.Fail("Should fail"); + } + catch (SnowflakeDbException e) + { + Assert.AreEqual(SFError.BROWSER_RESPONSE_INVALID_PREFIX.GetAttribute().errorCode, e.ErrorCode); + } + } + + [Test] + public void TestDefaultAuthenticationAsync() + { + try + { + t_browserOperations + .Setup(b => b.OpenUrl(It.IsAny())) + .Callback((string url) => { + s_httpClient.GetAsync(url); + }); + + var restRequester = new Mock.MockExternalBrowserRestRequester() + { + ProofKey = "mockProofKey", + }; + var sfSession = new SFSession("account=test;user=test;password=test;authenticator=externalbrowser;host=test.okta.com", null, restRequester, t_browserOperations.Object); + Task connectTask = sfSession.OpenAsync(CancellationToken.None); + connectTask.Wait(); + + t_browserOperations.Verify(b => b.OpenUrl(It.IsAny()), Times.Once()); + } + catch (SnowflakeDbException e) + { + Assert.Fail("Should pass without exception", e); + } + } + + [Test] + public void TestConsoleLoginAsync() + { + try + { + t_browserOperations + .Setup(b => b.OpenUrl(It.IsAny())) + .Callback((string url) => { + Uri uri = new Uri(url); + var port = HttpUtility.ParseQueryString(uri.Query).Get("browser_mode_redirect_port"); + var browserUrl = $"http://localhost:{port}/?token=mockToken"; + s_httpClient.GetAsync(browserUrl); + }); + + var restRequester = new Mock.MockExternalBrowserRestRequester() + { + ProofKey = "mockProofKey", + }; + var sfSession = new SFSession("disable_console_login=false;account=test;user=test;password=test;authenticator=externalbrowser;host=test.okta.com", null, restRequester, t_browserOperations.Object); + Task connectTask = sfSession.OpenAsync(CancellationToken.None); + connectTask.Wait(); + + t_browserOperations.Verify(b => b.OpenUrl(It.IsAny()), Times.Once()); + } + catch (SnowflakeDbException e) + { + Assert.Fail("Should pass without exception", e); + } + } + + [Test] + public void TestSSOTokenAsync() + { + try + { + var user = "test"; + var host = $"{user}.okta.com"; + var key = SFCredentialManagerFactory.BuildCredentialKey(host, user, TokenType.IdToken); + var credentialManager = SFCredentialManagerInMemoryImpl.Instance; + credentialManager.SaveCredentials(key, "mockIdToken"); + SFCredentialManagerFactory.SetCredentialManager(credentialManager); + + var restRequester = new Mock.MockExternalBrowserRestRequester() + { + ProofKey = "mockProofKey", + SSOUrl = "https://www.mockSSOUrl.com", + }; + var sfSession = new SFSession($"allow_sso_token_caching=true;account=test;user={user};password=test;authenticator=externalbrowser;host={host}", null, restRequester, t_browserOperations.Object); + Task connectTask = sfSession.OpenAsync(CancellationToken.None); + connectTask.Wait(); + + t_browserOperations.Verify(b => b.OpenUrl(It.IsAny()), Times.Never()); + } + catch (SnowflakeDbException e) + { + Assert.Fail("Should pass without exception", e); + } + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs b/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs index a57a9fb74..dffe36bd8 100644 --- a/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs @@ -1,4 +1,4 @@ -/* +/* * Copyright (c) 2019 Snowflake Computing Inc. All rights reserved. */ @@ -166,6 +166,21 @@ public void TestResolveConnectionArea(string host, string expectedMessage) Assert.AreEqual(expectedMessage, message); } + [Test] + [TestCase("true")] + [TestCase("false")] + public void TestValidateAllowSSOTokenCachingProperty(string expectedAllowSsoTokenCaching) + { + // arrange + var connectionString = $"ACCOUNT=account;USER=test;PASSWORD=test;ALLOW_SSO_TOKEN_CACHING={expectedAllowSsoTokenCaching}"; + + // act + var properties = SFSessionProperties.ParseConnectionString(connectionString, null); + + // assert + Assert.AreEqual(expectedAllowSsoTokenCaching, properties[SFSessionProperty.ALLOW_SSO_TOKEN_CACHING]); + } + public static IEnumerable ConnectionStringTestCases() { string defAccount = "testaccount"; @@ -222,7 +237,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.ALLOW_SSO_TOKEN_CACHING, DefaultValue(SFSessionProperty.ALLOW_SSO_TOKEN_CACHING) }, } }; @@ -258,7 +274,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.ALLOW_SSO_TOKEN_CACHING, DefaultValue(SFSessionProperty.ALLOW_SSO_TOKEN_CACHING) }, } }; var testCaseWithProxySettings = new TestCase() @@ -296,7 +313,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.ALLOW_SSO_TOKEN_CACHING, DefaultValue(SFSessionProperty.ALLOW_SSO_TOKEN_CACHING) }, }, ConnectionString = $"ACCOUNT={defAccount};USER={defUser};PASSWORD={defPassword};useProxy=true;proxyHost=proxy.com;proxyPort=1234;nonProxyHosts=localhost" @@ -336,7 +354,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.ALLOW_SSO_TOKEN_CACHING, DefaultValue(SFSessionProperty.ALLOW_SSO_TOKEN_CACHING) }, }, ConnectionString = $"ACCOUNT={defAccount};USER={defUser};PASSWORD={defPassword};proxyHost=proxy.com;proxyPort=1234;nonProxyHosts=localhost" @@ -375,7 +394,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.ALLOW_SSO_TOKEN_CACHING, DefaultValue(SFSessionProperty.ALLOW_SSO_TOKEN_CACHING) }, } }; var testCaseWithIncludeRetryReason = new TestCase() @@ -411,7 +431,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.ALLOW_SSO_TOKEN_CACHING, DefaultValue(SFSessionProperty.ALLOW_SSO_TOKEN_CACHING) }, } }; var testCaseWithDisableQueryContextCache = new TestCase() @@ -446,7 +467,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.ALLOW_SSO_TOKEN_CACHING, DefaultValue(SFSessionProperty.ALLOW_SSO_TOKEN_CACHING) }, }, ConnectionString = $"ACCOUNT={defAccount};USER={defUser};PASSWORD={defPassword};DISABLEQUERYCONTEXTCACHE=true" @@ -483,7 +505,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.ALLOW_SSO_TOKEN_CACHING, DefaultValue(SFSessionProperty.ALLOW_SSO_TOKEN_CACHING) }, }, ConnectionString = $"ACCOUNT={defAccount};USER={defUser};PASSWORD={defPassword};DISABLE_CONSOLE_LOGIN=false" @@ -522,7 +545,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.ALLOW_SSO_TOKEN_CACHING, DefaultValue(SFSessionProperty.ALLOW_SSO_TOKEN_CACHING) }, } }; var testCaseUnderscoredAccountName = new TestCase() @@ -558,7 +582,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.ALLOW_SSO_TOKEN_CACHING, DefaultValue(SFSessionProperty.ALLOW_SSO_TOKEN_CACHING) }, } }; var testCaseUnderscoredAccountNameWithEnabledAllowUnderscores = new TestCase() @@ -594,9 +619,11 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.ALLOW_SSO_TOKEN_CACHING, DefaultValue(SFSessionProperty.ALLOW_SSO_TOKEN_CACHING) }, } }; + var testQueryTag = "Test QUERY_TAG 12345"; var testCaseQueryTag = new TestCase() { @@ -632,7 +659,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.ALLOW_SSO_TOKEN_CACHING, DefaultValue(SFSessionProperty.ALLOW_SSO_TOKEN_CACHING) }, } }; diff --git a/Snowflake.Data.Tests/UnitTests/SFSessionTest.cs b/Snowflake.Data.Tests/UnitTests/SFSessionTest.cs index 262122b2d..ed889b436 100644 --- a/Snowflake.Data.Tests/UnitTests/SFSessionTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFSessionTest.cs @@ -6,6 +6,7 @@ using Snowflake.Data.Core; using NUnit.Framework; using Snowflake.Data.Tests.Mock; +using System; namespace Snowflake.Data.Tests.UnitTests { @@ -105,6 +106,48 @@ public void TestThatConfiguresEasyLogging(string configPath) easyLoggingStarter.Verify(starter => starter.Init(configPath)); } + [Test] + public void TestThatIdTokenIsStoredWhenCachingIsEnabled() + { + // arrange + var expectedIdToken = "mockIdToken"; + var connectionString = $"account=account;user=user;password=test;allow_sso_token_caching=true"; + var session = new SFSession(connectionString, null); + LoginResponse authnResponse = new LoginResponse + { + data = new LoginResponseData() + { + idToken = expectedIdToken, + authResponseSessionInfo = new SessionInfo(), + }, + success = true + }; + + // act + session.ProcessLoginResponse(authnResponse); + + // assert + Assert.AreEqual(expectedIdToken, session._idToken); + } + + [Test] + public void TestThatRetriesAuthenticationForInvalidIdToken() + { + // arrange + var connectionString = "account=test;user=test;password=test;allow_sso_token_caching=true"; + var session = new SFSession(connectionString, null); + LoginResponse authnResponse = new LoginResponse + { + code = SFError.ID_TOKEN_INVALID.GetAttribute().errorCode, + message = "", + success = false + }; + + // assert + Assert.Throws(() => session.ProcessLoginResponse(authnResponse)); + } + + [Test] [TestCase(null, "accountDefault", "accountDefault", false)] [TestCase("initial", "initial", "initial", false)] [TestCase("initial", null, "initial", false)] diff --git a/Snowflake.Data.Tests/UnitTests/Session/SFHttpClientPropertiesTest.cs b/Snowflake.Data.Tests/UnitTests/Session/SFHttpClientPropertiesTest.cs index 18f1ff7d7..0c76fff29 100644 --- a/Snowflake.Data.Tests/UnitTests/Session/SFHttpClientPropertiesTest.cs +++ b/Snowflake.Data.Tests/UnitTests/Session/SFHttpClientPropertiesTest.cs @@ -17,7 +17,8 @@ public class SFHttpClientPropertiesTest [Test] public void TestConvertToMapOnly2Properties( [Values(true, false)] bool validateDefaultParameters, - [Values(true, false)] bool clientSessionKeepAlive) + [Values(true, false)] bool clientSessionKeepAlive, + [Values(true, false)] bool clientStoreTemporaryCredential) { // arrange var proxyProperties = new SFSessionHttpClientProxyProperties() @@ -32,6 +33,7 @@ public void TestConvertToMapOnly2Properties( { validateDefaultParameters = validateDefaultParameters, clientSessionKeepAlive = clientSessionKeepAlive, + _allowSSOTokenCaching = clientStoreTemporaryCredential, connectionTimeout = SFSessionHttpClientProperties.DefaultRetryTimeout, insecureMode = false, disableRetry = false, @@ -45,9 +47,10 @@ public void TestConvertToMapOnly2Properties( var parameterMap = properties.ToParameterMap(); // assert - Assert.AreEqual(2, parameterMap.Count); + Assert.AreEqual(3, parameterMap.Count); Assert.AreEqual(validateDefaultParameters, parameterMap[SFSessionParameter.CLIENT_VALIDATE_DEFAULT_PARAMETERS]); Assert.AreEqual(clientSessionKeepAlive, parameterMap[SFSessionParameter.CLIENT_SESSION_KEEP_ALIVE]); + Assert.AreEqual(clientStoreTemporaryCredential, parameterMap[SFSessionParameter.CLIENT_STORE_TEMPORARY_CREDENTIAL]); } [Test] diff --git a/Snowflake.Data/Client/ISnowflakeCredentialManager.cs b/Snowflake.Data/Client/ISnowflakeCredentialManager.cs new file mode 100644 index 000000000..802d8fe21 --- /dev/null +++ b/Snowflake.Data/Client/ISnowflakeCredentialManager.cs @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +namespace Snowflake.Data.Client +{ + public interface ISnowflakeCredentialManager + { + string GetCredentials(string key); + + void RemoveCredentials(string key); + + void SaveCredentials(string key, string token); + } +} diff --git a/Snowflake.Data/Core/Authenticator/ExternalBrowserAuthenticator.cs b/Snowflake.Data/Core/Authenticator/ExternalBrowserAuthenticator.cs index e39ec18f8..2cbd89e9d 100644 --- a/Snowflake.Data/Core/Authenticator/ExternalBrowserAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/ExternalBrowserAuthenticator.cs @@ -1,18 +1,17 @@ -/* +/* * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. */ using System; -using System.Diagnostics; using System.Net; using System.Net.Sockets; -using System.Runtime.InteropServices; using System.Threading; using System.Threading.Tasks; using Snowflake.Data.Log; using Snowflake.Data.Client; using System.Text.RegularExpressions; using System.Collections.Generic; +using Snowflake.Data.Core.CredentialManager; namespace Snowflake.Data.Core.Authenticator { @@ -36,6 +35,8 @@ class ExternalBrowserAuthenticator : BaseAuthenticator, IAuthenticator private string _proofKey; // Event for successful authentication. private ManualResetEvent _successEvent; + // Placeholder in case an exception occurs while extracting the token from the browser response. + private Exception _tokenExtractionException; /// /// Constructor of the External authenticator @@ -44,51 +45,26 @@ class ExternalBrowserAuthenticator : BaseAuthenticator, IAuthenticator internal ExternalBrowserAuthenticator(SFSession session) : base(session, AUTH_NAME) { } + /// async Task IAuthenticator.AuthenticateAsync(CancellationToken cancellationToken) { logger.Info("External Browser Authentication"); - int localPort = GetRandomUnusedPort(); - using (var httpListener = GetHttpListener(localPort)) + if (string.IsNullOrEmpty(session._idToken)) { - httpListener.Start(); - - logger.Debug("Get IdpUrl and ProofKey"); - string loginUrl; - if (session._disableConsoleLogin) - { - var authenticatorRestRequest = BuildAuthenticatorRestRequest(localPort); - var authenticatorRestResponse = - await session.restRequester.PostAsync( - authenticatorRestRequest, - cancellationToken - ).ConfigureAwait(false); - authenticatorRestResponse.FilterFailedResponse(); - - loginUrl = authenticatorRestResponse.data.ssoUrl; - _proofKey = authenticatorRestResponse.data.proofKey; - } - else - { - _proofKey = GenerateProofKey(); - loginUrl = GetLoginUrl(_proofKey, localPort); - } - - logger.Debug("Open browser"); - StartBrowser(loginUrl); - - logger.Debug("Get the redirect SAML request"); - _successEvent = new ManualResetEvent(false); - httpListener.BeginGetContext(GetContextCallback, httpListener); - var timeoutInSec = int.Parse(session.properties[SFSessionProperty.BROWSER_RESPONSE_TIMEOUT]); - if (!_successEvent.WaitOne(timeoutInSec * 1000)) + int localPort = GetRandomUnusedPort(); + using (var httpListener = GetHttpListener(localPort)) { - logger.Warn("Browser response timeout"); - throw new SnowflakeDbException(SFError.BROWSER_RESPONSE_TIMEOUT, timeoutInSec); + httpListener.Start(); + logger.Debug("Get IdpUrl and ProofKey"); + var loginUrl = await GetIdpUrlAndProofKeyAsync(localPort, cancellationToken); + logger.Debug("Open browser"); + StartBrowser(loginUrl); + logger.Debug("Get the redirect SAML request"); + GetRedirectSamlRequest(httpListener); + httpListener.Stop(); } - - httpListener.Stop(); } logger.Debug("Send login request"); @@ -100,70 +76,108 @@ void IAuthenticator.Authenticate() { logger.Info("External Browser Authentication"); - int localPort = GetRandomUnusedPort(); - using (var httpListener = GetHttpListener(localPort)) + if (string.IsNullOrEmpty(session._idToken)) { - httpListener.Start(); - - logger.Debug("Get IdpUrl and ProofKey"); - string loginUrl; - if (session._disableConsoleLogin) - { - var authenticatorRestRequest = BuildAuthenticatorRestRequest(localPort); - var authenticatorRestResponse = session.restRequester.Post(authenticatorRestRequest); - authenticatorRestResponse.FilterFailedResponse(); - - loginUrl = authenticatorRestResponse.data.ssoUrl; - _proofKey = authenticatorRestResponse.data.proofKey; - } - else + int localPort = GetRandomUnusedPort(); + using (var httpListener = GetHttpListener(localPort)) { - _proofKey = GenerateProofKey(); - loginUrl = GetLoginUrl(_proofKey, localPort); + httpListener.Start(); + logger.Debug("Get IdpUrl and ProofKey"); + var loginUrl = GetIdpUrlAndProofKey(localPort); + logger.Debug("Open browser"); + StartBrowser(loginUrl); + logger.Debug("Get the redirect SAML request"); + GetRedirectSamlRequest(httpListener); + httpListener.Stop(); } + } - logger.Debug("Open browser"); - StartBrowser(loginUrl); + logger.Debug("Send login request"); + base.Login(); + } - logger.Debug("Get the redirect SAML request"); - _successEvent = new ManualResetEvent(false); - httpListener.BeginGetContext(GetContextCallback, httpListener); - var timeoutInSec = int.Parse(session.properties[SFSessionProperty.BROWSER_RESPONSE_TIMEOUT]); - if (!_successEvent.WaitOne(timeoutInSec * 1000)) - { - logger.Warn("Browser response timeout"); - throw new SnowflakeDbException(SFError.BROWSER_RESPONSE_TIMEOUT, timeoutInSec); - } + private string GetIdpUrlAndProofKey(int localPort) + { + if (session._disableConsoleLogin) + { + var authenticatorRestRequest = BuildAuthenticatorRestRequest(localPort); + var authenticatorRestResponse = session.restRequester.Post(authenticatorRestRequest); + authenticatorRestResponse.FilterFailedResponse(); - httpListener.Stop(); + _proofKey = authenticatorRestResponse.data.proofKey; + return authenticatorRestResponse.data.ssoUrl; + } + else + { + _proofKey = GenerateProofKey(); + return GetLoginUrl(_proofKey, localPort); } + } - logger.Debug("Send login request"); - base.Login(); + private async Task GetIdpUrlAndProofKeyAsync(int localPort, CancellationToken cancellationToken) + { + if (session._disableConsoleLogin) + { + var authenticatorRestRequest = BuildAuthenticatorRestRequest(localPort); + var authenticatorRestResponse = + await session.restRequester.PostAsync( + authenticatorRestRequest, + cancellationToken + ).ConfigureAwait(false); + authenticatorRestResponse.FilterFailedResponse(); + + _proofKey = authenticatorRestResponse.data.proofKey; + return authenticatorRestResponse.data.ssoUrl; + } + else + { + _proofKey = GenerateProofKey(); + return GetLoginUrl(_proofKey, localPort); + } } - private void GetContextCallback(IAsyncResult result) + private void GetRedirectSamlRequest(HttpListener httpListener) { - HttpListener httpListener = (HttpListener) result.AsyncState; + _successEvent = new ManualResetEvent(false); + _tokenExtractionException = null; + httpListener.BeginGetContext(new AsyncCallback(GetContextCallback), httpListener); + var timeoutInSec = int.Parse(session.properties[SFSessionProperty.BROWSER_RESPONSE_TIMEOUT]); + if (!_successEvent.WaitOne(timeoutInSec * 1000)) + { + _successEvent.Set(); + logger.Error("Browser response timeout has been reached"); + throw new SnowflakeDbException(SFError.BROWSER_RESPONSE_TIMEOUT, timeoutInSec); + } + if (_tokenExtractionException != null) + { + throw _tokenExtractionException; + } + } - if (httpListener.IsListening) + private void GetContextCallback(IAsyncResult result) + { + HttpListener httpListener = (HttpListener)result.AsyncState; + if (httpListener.IsListening && !_successEvent.WaitOne(0)) { HttpListenerContext context = httpListener.EndGetContext(result); HttpListenerRequest request = context.Request; _samlResponseToken = ValidateAndExtractToken(request); - HttpListenerResponse response = context.Response; - try + if (!string.IsNullOrEmpty(_samlResponseToken)) { - using (var output = response.OutputStream) + HttpListenerResponse response = context.Response; + try { - output.Write(SUCCESS_RESPONSE, 0, SUCCESS_RESPONSE.Length); + using (var output = response.OutputStream) + { + output.Write(SUCCESS_RESPONSE, 0, SUCCESS_RESPONSE.Length); + } + } + catch + { + // Ignore the exception as it does not affect the overall authentication flow + logger.Warn("External browser response not sent out"); } - } - catch - { - // Ignore the exception as it does not affect the overall authentication flow - logger.Warn("External browser response not sent out"); } } @@ -187,53 +201,33 @@ private static HttpListener GetHttpListener(int port) return listener; } - private static void StartBrowser(string url) + private void StartBrowser(string url) { string regexStr = "^http(s?)\\:\\/\\/[0-9a-zA-Z]([-.\\w]*[0-9a-zA-Z@:])*(:(0-9)*)*(\\/?)([a-zA-Z0-9\\-\\.\\?\\,\\&\\(\\)\\/\\\\\\+&%\\$#_=@]*)?$"; Match m = Regex.Match(url, regexStr, RegexOptions.IgnoreCase); - if (!m.Success) - { - logger.Error("Failed to start browser. Invalid url."); - throw new SnowflakeDbException(SFError.INVALID_BROWSER_URL); - } - - if (!Uri.IsWellFormedUriString(url, UriKind.Absolute)) + if (!m.Success || !Uri.IsWellFormedUriString(url, UriKind.Absolute)) { logger.Error("Failed to start browser. Invalid url."); - throw new SnowflakeDbException(SFError.INVALID_BROWSER_URL); + throw new SnowflakeDbException(SFError.INVALID_BROWSER_URL, url); } - // The following code is learnt from https://brockallen.com/2016/09/24/process-start-for-urls-on-net-core/ - // hack because of this: https://github.com/dotnet/corefx/issues/10361 - if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - { - url = url.Replace("&", "^&"); - Process.Start(new ProcessStartInfo("cmd", $"/c start {url}") { UseShellExecute = true }); - } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) - { - Process.Start("xdg-open", url); - } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) - { - Process.Start("open", url); - } - else - { - throw new SnowflakeDbException(SFError.UNSUPPORTED_PLATFORM); - } + session._browserOperations.OpenUrl(url); } - private static string ValidateAndExtractToken(HttpListenerRequest request) + private string ValidateAndExtractToken(HttpListenerRequest request) { if (request.HttpMethod != "GET") { - throw new SnowflakeDbException(SFError.BROWSER_RESPONSE_WRONG_METHOD, request.HttpMethod); + logger.Error("Failed to extract token due to invalid HTTP method."); + _tokenExtractionException = new SnowflakeDbException(SFError.BROWSER_RESPONSE_WRONG_METHOD, request.Url.Query); + return null; } if (request.Url.Query == null || !request.Url.Query.StartsWith(TOKEN_REQUEST_PREFIX)) { - throw new SnowflakeDbException(SFError.BROWSER_RESPONSE_INVALID_PREFIX, request.Url.Query); + logger.Error("Failed to extract token due to invalid query."); + _tokenExtractionException = new SnowflakeDbException(SFError.BROWSER_RESPONSE_INVALID_PREFIX, request.Url.Query); + return null; } return Uri.UnescapeDataString(request.Url.Query.Substring(TOKEN_REQUEST_PREFIX.Length)); @@ -247,6 +241,8 @@ private SFRestRequest BuildAuthenticatorRestRequest(int port) AccountName = session.properties[SFSessionProperty.ACCOUNT], Authenticator = AUTH_NAME, BrowserModeRedirectPort = port.ToString(), + DriverName = SFEnvironment.DriverName, + DriverVersion = SFEnvironment.DriverVersion, }; int connectionTimeoutSec = int.Parse(session.properties[SFSessionProperty.CONNECTION_TIMEOUT]); @@ -257,9 +253,17 @@ private SFRestRequest BuildAuthenticatorRestRequest(int port) /// protected override void SetSpecializedAuthenticatorData(ref LoginRequestData data) { - // Add the token and proof key to the Data - data.Token = _samlResponseToken; - data.ProofKey = _proofKey; + if (string.IsNullOrEmpty(session._idToken)) + { + // Add the token and proof key to the Data + data.Token = _samlResponseToken; + data.ProofKey = _proofKey; + } + else + { + data.Token = session._idToken; + data.Authenticator = TokenType.IdToken.GetAttribute().value; + } } private string GetLoginUrl(string proofKey, int localPort) diff --git a/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerFileImpl.cs b/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerFileImpl.cs new file mode 100644 index 000000000..bd9dbf386 --- /dev/null +++ b/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerFileImpl.cs @@ -0,0 +1,175 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using Mono.Unix; +using Mono.Unix.Native; +using Newtonsoft.Json; +using Snowflake.Data.Client; +using Snowflake.Data.Core.Tools; +using Snowflake.Data.Log; +using System; +using System.IO; +using System.Runtime.InteropServices; +using System.Security; +using KeyToken = System.Collections.Generic.Dictionary; + +namespace Snowflake.Data.Core.CredentialManager.Infrastructure +{ + internal class SFCredentialManagerFileImpl : ISnowflakeCredentialManager + { + internal const string CredentialCacheDirectoryEnvironmentName = "SF_TEMPORARY_CREDENTIAL_CACHE_DIR"; + + internal const string CredentialCacheDirName = ".snowflake"; + + internal const string CredentialCacheFileName = "temporary_credential.json"; + + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + + private readonly string _jsonCacheDirectory; + + private readonly string _jsonCacheFilePath; + + private readonly FileOperations _fileOperations; + + private readonly DirectoryOperations _directoryOperations; + + private readonly UnixOperations _unixOperations; + + private readonly EnvironmentOperations _environmentOperations; + + public static readonly SFCredentialManagerFileImpl Instance = new SFCredentialManagerFileImpl(FileOperations.Instance, DirectoryOperations.Instance, UnixOperations.Instance, EnvironmentOperations.Instance); + + internal SFCredentialManagerFileImpl(FileOperations fileOperations, DirectoryOperations directoryOperations, UnixOperations unixOperations, EnvironmentOperations environmentOperations) + { + _fileOperations = fileOperations; + _directoryOperations = directoryOperations; + _unixOperations = unixOperations; + _environmentOperations = environmentOperations; + SetCredentialCachePath(ref _jsonCacheDirectory, ref _jsonCacheFilePath); + } + + private void SetCredentialCachePath(ref string _jsonCacheDirectory, ref string _jsonCacheFilePath) + { + var customDirectory = _environmentOperations.GetEnvironmentVariable(CredentialCacheDirectoryEnvironmentName); + _jsonCacheDirectory = string.IsNullOrEmpty(customDirectory) ? Path.Combine(HomeDirectoryProvider.HomeDirectory(_environmentOperations), CredentialCacheDirName) : customDirectory; + if (!_directoryOperations.Exists(_jsonCacheDirectory)) + { + _directoryOperations.CreateDirectory(_jsonCacheDirectory); + } + _jsonCacheFilePath = Path.Combine(_jsonCacheDirectory, CredentialCacheFileName); + s_logger.Info($"Setting the json credential cache path to {_jsonCacheFilePath}"); + } + + internal void WriteToJsonFile(string content) + { + s_logger.Debug($"Writing credentials to json file in {_jsonCacheFilePath}"); + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + _fileOperations.Write(_jsonCacheFilePath, content); + } + else + { + if (!_directoryOperations.Exists(_jsonCacheDirectory)) + { + _directoryOperations.CreateDirectory(_jsonCacheDirectory); + } + s_logger.Info($"Creating the json file for credential cache in {_jsonCacheFilePath}"); + if (_fileOperations.Exists(_jsonCacheFilePath)) + { + s_logger.Info($"The existing json file for credential cache in {_jsonCacheFilePath} will be overwritten"); + } + var createFileResult = _unixOperations.CreateFileWithPermissions(_jsonCacheFilePath, + FilePermissions.S_IRUSR | FilePermissions.S_IWUSR | FilePermissions.S_IXUSR); + if (createFileResult == -1) + { + var errorMessage = "Failed to create the JSON token cache file"; + s_logger.Error(errorMessage); + throw new Exception(errorMessage); + } + else + { + _fileOperations.Write(_jsonCacheFilePath, content); + } + + var jsonPermissions = _unixOperations.GetFilePermissions(_jsonCacheFilePath); + if (jsonPermissions != FileAccessPermissions.UserReadWriteExecute) + { + var errorMessage = "Permission for the JSON token cache file should contain only the owner access"; + s_logger.Error(errorMessage); + throw new Exception(errorMessage); + } + } + } + + internal KeyToken ReadJsonFile() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + return JsonConvert.DeserializeObject(File.ReadAllText(_jsonCacheFilePath)); + } + else + { + if (_unixOperations.CheckFileIsNotOwnedByCurrentUser(_jsonCacheFilePath)) + { + var errorMessage = "Attempting to read a file not owned by the effective user of the current process"; + s_logger.Error(errorMessage); + throw new SecurityException(errorMessage); + } + if (_unixOperations.CheckFileIsNotOwnedByCurrentGroup(_jsonCacheFilePath)) + { + var errorMessage = "Attempting to read a file not owned by the effective group of the current process"; + s_logger.Error(errorMessage); + throw new SecurityException(errorMessage); + } + if (_unixOperations.CheckFileHasAnyOfPermissions(_jsonCacheFilePath, + FileAccessPermissions.GroupReadWriteExecute | FileAccessPermissions.OtherReadWriteExecute)) + { + var errorMessage = "Attempting to read a file with too broad permissions assigned"; + s_logger.Error(errorMessage); + throw new SecurityException(errorMessage); + } + + return JsonConvert.DeserializeObject(_unixOperations.ReadFile(_jsonCacheFilePath)); + } + } + + public string GetCredentials(string key) + { + s_logger.Debug($"Getting credentials from json file in {_jsonCacheFilePath} for key: {key}"); + if (_fileOperations.Exists(_jsonCacheFilePath)) + { + var keyTokenPairs = ReadJsonFile(); + + if (keyTokenPairs.TryGetValue(key, out string token)) + { + return token; + } + } + + s_logger.Info("Unable to get credentials for the specified key"); + return ""; + } + + public void RemoveCredentials(string key) + { + s_logger.Debug($"Removing credentials from json file in {_jsonCacheFilePath} for key: {key}"); + if (_fileOperations.Exists(_jsonCacheFilePath)) + { + var keyTokenPairs = ReadJsonFile(); + keyTokenPairs.Remove(key); + WriteToJsonFile(JsonConvert.SerializeObject(keyTokenPairs)); + } + } + + public void SaveCredentials(string key, string token) + { + s_logger.Debug($"Saving credentials to json file in {_jsonCacheFilePath} for key: {key}"); + KeyToken keyTokenPairs = _fileOperations.Exists(_jsonCacheFilePath) ? ReadJsonFile() : new KeyToken(); + keyTokenPairs[key] = token; + + string jsonString = JsonConvert.SerializeObject(keyTokenPairs); + WriteToJsonFile(jsonString); + } + } +} diff --git a/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerInMemoryImpl.cs b/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerInMemoryImpl.cs new file mode 100644 index 000000000..5b7fac8b3 --- /dev/null +++ b/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerInMemoryImpl.cs @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using Snowflake.Data.Client; +using Snowflake.Data.Core.Tools; +using Snowflake.Data.Log; +using System.Collections.Generic; +using System.Security; + +namespace Snowflake.Data.Core.CredentialManager.Infrastructure +{ + internal class SFCredentialManagerInMemoryImpl : ISnowflakeCredentialManager + { + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + + private Dictionary s_credentials = new Dictionary(); + + public static readonly SFCredentialManagerInMemoryImpl Instance = new SFCredentialManagerInMemoryImpl(); + + public string GetCredentials(string key) + { + s_logger.Debug($"Getting credentials from memory for key: {key}"); + SecureString token; + if (s_credentials.TryGetValue(key, out token)) + { + return SecureStringHelper.Decode(token); + } + else + { + s_logger.Info("Unable to get credentials for the specified key"); + return ""; + } + } + + public void RemoveCredentials(string key) + { + s_logger.Debug($"Removing credentials from memory for key: {key}"); + s_credentials.Remove(key); + } + + public void SaveCredentials(string key, string token) + { + s_logger.Debug($"Saving credentials into memory for key: {key}"); + s_credentials[key] = SecureStringHelper.Encode(token); + } + } +} diff --git a/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerWindowsNativeImpl.cs b/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerWindowsNativeImpl.cs new file mode 100644 index 000000000..45bef2a38 --- /dev/null +++ b/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerWindowsNativeImpl.cs @@ -0,0 +1,124 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using Microsoft.Win32.SafeHandles; +using Snowflake.Data.Client; +using Snowflake.Data.Log; +using System; +using System.Runtime.InteropServices; +using System.Text; + +namespace Snowflake.Data.Core.CredentialManager.Infrastructure +{ + internal class SFCredentialManagerWindowsNativeImpl : ISnowflakeCredentialManager + { + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + + public static readonly SFCredentialManagerWindowsNativeImpl Instance = new SFCredentialManagerWindowsNativeImpl(); + + public string GetCredentials(string key) + { + s_logger.Debug($"Getting the credentials for key: {key}"); + + IntPtr nCredPtr; + if (!CredRead(key, 1 /* Generic */, 0, out nCredPtr)) + { + s_logger.Info($"Unable to get credentials for key: {key}"); + return ""; + } + + var critCred = new CriticalCredentialHandle(nCredPtr); + Credential cred = critCred.GetCredential(); + return cred.CredentialBlob; + } + + public void RemoveCredentials(string key) + { + s_logger.Debug($"Removing the credentials for key: {key}"); + + if (!CredDelete(key, 1 /* Generic */, 0)) + { + s_logger.Info($"Unable to remove credentials because the specified key did not exist: {key}"); + } + } + + public void SaveCredentials(string key, string token) + { + s_logger.Debug($"Saving the credentials for key: {key}"); + + byte[] byteArray = Encoding.Unicode.GetBytes(token); + Credential credential = new Credential(); + credential.AttributeCount = 0; + credential.Attributes = IntPtr.Zero; + credential.Comment = IntPtr.Zero; + credential.TargetAlias = IntPtr.Zero; + credential.Type = 1; // Generic + credential.Persist = 2; // Local Machine + credential.CredentialBlobSize = (uint)(byteArray == null ? 0 : byteArray.Length); + credential.TargetName = key; + credential.CredentialBlob = token; + credential.UserName = Environment.UserName; + + CredWrite(ref credential, 0); + } + + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] + private struct Credential + { + public uint Flags; + public uint Type; + [MarshalAs(UnmanagedType.LPWStr)] + public string TargetName; + public IntPtr Comment; + public System.Runtime.InteropServices.ComTypes.FILETIME LastWritten; + public uint CredentialBlobSize; + [MarshalAs(UnmanagedType.LPWStr)] + public string CredentialBlob; + public uint Persist; + public uint AttributeCount; + public IntPtr Attributes; + public IntPtr TargetAlias; + [MarshalAs(UnmanagedType.LPWStr)] + public string UserName; + } + + sealed class CriticalCredentialHandle : CriticalHandleZeroOrMinusOneIsInvalid + { + public CriticalCredentialHandle(IntPtr handle) + { + SetHandle(handle); + } + + public Credential GetCredential() + { + var credential = (Credential)Marshal.PtrToStructure(handle, typeof(Credential)); + return credential; + } + + protected override bool ReleaseHandle() + { + if (IsInvalid) + { + return false; + } + + CredFree(handle); + SetHandleAsInvalid(); + return true; + } + } + + [DllImport("Advapi32.dll", EntryPoint = "CredDeleteW", CharSet = CharSet.Unicode, SetLastError = true)] + internal static extern bool CredDelete(string target, uint type, int reservedFlag); + + [DllImport("Advapi32.dll", EntryPoint = "CredReadW", CharSet = CharSet.Unicode, SetLastError = true)] + static extern bool CredRead(string target, uint type, int reservedFlag, out IntPtr credentialPtr); + + [DllImport("Advapi32.dll", EntryPoint = "CredWriteW", CharSet = CharSet.Unicode, SetLastError = true)] + static extern bool CredWrite([In] ref Credential userCredential, [In] uint flags); + + [DllImport("Advapi32.dll", EntryPoint = "CredFree", SetLastError = true)] + static extern bool CredFree([In] IntPtr cred); + } +} diff --git a/Snowflake.Data/Core/CredentialManager/SFCredentialManagerFactory.cs b/Snowflake.Data/Core/CredentialManager/SFCredentialManagerFactory.cs new file mode 100644 index 000000000..8e573cde8 --- /dev/null +++ b/Snowflake.Data/Core/CredentialManager/SFCredentialManagerFactory.cs @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using Snowflake.Data.Client; +using Snowflake.Data.Core.CredentialManager.Infrastructure; +using Snowflake.Data.Log; +using System.Runtime.InteropServices; + +namespace Snowflake.Data.Core.CredentialManager +{ + internal enum TokenType + { + [StringAttr(value = "ID_TOKEN")] + IdToken + } + + internal class SFCredentialManagerFactory + { + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + + private static ISnowflakeCredentialManager s_customCredentialManager = null; + + internal static string BuildCredentialKey(string host, string user, TokenType tokenType) + { + return $"{host.ToUpper()}:{user.ToUpper()}:{SFEnvironment.DriverName}:{tokenType.ToString().ToUpper()}"; + } + + public static void UseDefaultCredentialManager() + { + s_logger.Info("Clearing the custom credential manager"); + s_customCredentialManager = null; + } + + public static void SetCredentialManager(ISnowflakeCredentialManager customCredentialManager) + { + s_logger.Info($"Setting the custom credential manager: {customCredentialManager.GetType().Name}"); + s_customCredentialManager = customCredentialManager; + } + + internal static ISnowflakeCredentialManager GetCredentialManager() + { + if (s_customCredentialManager == null) + { + var defaultCredentialManager = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? (ISnowflakeCredentialManager) + SFCredentialManagerWindowsNativeImpl.Instance : SFCredentialManagerInMemoryImpl.Instance; + s_logger.Info($"Using the default credential manager: {defaultCredentialManager.GetType().Name}"); + return defaultCredentialManager; + } + else + { + s_logger.Info($"Using a custom credential manager: {s_customCredentialManager.GetType().Name}"); + return s_customCredentialManager; + } + } + } +} diff --git a/Snowflake.Data/Core/ErrorMessages.resx b/Snowflake.Data/Core/ErrorMessages.resx index 3532f3394..664122e11 100755 --- a/Snowflake.Data/Core/ErrorMessages.resx +++ b/Snowflake.Data/Core/ErrorMessages.resx @@ -180,6 +180,9 @@ Snowflake type {0} is not supported for parameters. + + Invalid browser url "{0}" cannot be used for authentication. + Browser response timed out after {0} seconds. diff --git a/Snowflake.Data/Core/RestResponse.cs b/Snowflake.Data/Core/RestResponse.cs index 64275fa42..c4cd43cdc 100755 --- a/Snowflake.Data/Core/RestResponse.cs +++ b/Snowflake.Data/Core/RestResponse.cs @@ -91,6 +91,9 @@ internal class LoginResponseData [JsonProperty(PropertyName = "masterValidityInSeconds", NullValueHandling = NullValueHandling.Ignore)] internal int masterValidityInSeconds { get; set; } + + [JsonProperty(PropertyName = "idToken", NullValueHandling = NullValueHandling.Ignore)] + internal string idToken { get; set; } } internal class AuthenticatorResponseData diff --git a/Snowflake.Data/Core/SFError.cs b/Snowflake.Data/Core/SFError.cs old mode 100755 new mode 100644 index 44de969a1..6677cb2ee --- a/Snowflake.Data/Core/SFError.cs +++ b/Snowflake.Data/Core/SFError.cs @@ -1,4 +1,4 @@ -/* +/* * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. */ @@ -92,7 +92,10 @@ public enum SFError STRUCTURED_TYPE_READ_ERROR, [SFErrorAttr(errorCode = 270062)] - STRUCTURED_TYPE_READ_DETAILED_ERROR + STRUCTURED_TYPE_READ_DETAILED_ERROR, + + [SFErrorAttr(errorCode = 390195)] + ID_TOKEN_INVALID } class SFErrorAttr : Attribute diff --git a/Snowflake.Data/Core/Session/SFSession.cs b/Snowflake.Data/Core/Session/SFSession.cs old mode 100755 new mode 100644 index b6a0ebf79..662a954ef --- a/Snowflake.Data/Core/Session/SFSession.cs +++ b/Snowflake.Data/Core/Session/SFSession.cs @@ -1,4 +1,4 @@ -/* +/* * Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. */ @@ -14,6 +14,7 @@ using System.Threading.Tasks; using System.Net.Http; using System.Text.RegularExpressions; +using Snowflake.Data.Core.CredentialManager; using Snowflake.Data.Core.Session; using Snowflake.Data.Core.Tools; @@ -69,6 +70,8 @@ public class SFSession private readonly EasyLoggingStarter _easyLoggingStarter = EasyLoggingStarter.Instance; + internal readonly BrowserOperations _browserOperations = BrowserOperations.Instance; + private long _startTime = 0; internal string ConnectionString { get; } internal SecureString Password { get; } @@ -98,6 +101,12 @@ public void SetPooling(bool isEnabled) internal String _queryTag; + private readonly ISnowflakeCredentialManager _credManager = SFCredentialManagerFactory.GetCredentialManager(); + + internal bool _allowSSOTokenCaching; + + internal string _idToken; + internal void ProcessLoginResponse(LoginResponse authnResponse) { if (authnResponse.success) @@ -116,6 +125,12 @@ internal void ProcessLoginResponse(LoginResponse authnResponse) { logger.Debug("Query context cache disabled."); } + if (_allowSSOTokenCaching && !string.IsNullOrEmpty(authnResponse.data.idToken)) + { + _idToken = authnResponse.data.idToken; + var key = SFCredentialManagerFactory.BuildCredentialKey(properties[SFSessionProperty.HOST], properties[SFSessionProperty.USER], TokenType.IdToken); + _credManager.SaveCredentials(key, _idToken); + } logger.Debug($"Session opened: {sessionId}"); _startTime = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); } @@ -128,7 +143,17 @@ internal void ProcessLoginResponse(LoginResponse authnResponse) ""); logger.Error("Authentication failed", e); - throw e; + + if (e.ErrorCode == SFError.ID_TOKEN_INVALID.GetAttribute().errorCode) + { + logger.Info("SSO Token has expired or not valid. Reauthenticating without SSO token...", e); + _idToken = null; + authenticator.Authenticate(); + } + else + { + throw e; + } } } @@ -190,6 +215,13 @@ internal SFSession( _maxRetryCount = extractedProperties.maxHttpRetries; _maxRetryTimeout = extractedProperties.retryTimeout; _disableSamlUrlCheck = extractedProperties._disableSamlUrlCheck; + _allowSSOTokenCaching = extractedProperties._allowSSOTokenCaching; + + if (_allowSSOTokenCaching) + { + var key = SFCredentialManagerFactory.BuildCredentialKey(properties[SFSessionProperty.HOST], properties[SFSessionProperty.USER], TokenType.IdToken); + _idToken = _credManager.GetCredentials(key); + } } catch (SnowflakeDbException e) { @@ -229,6 +261,11 @@ internal SFSession(String connectionString, SecureString password, IMockRestRequ this.restRequester = restRequester; } + internal SFSession(String connectionString, SecureString password, IMockRestRequester restRequester, BrowserOperations browserOperations) : this(connectionString, password, restRequester) + { + _browserOperations = browserOperations; + } + internal Uri BuildUri(string path, Dictionary queryParams = null) { UriBuilder uriBuilder = new UriBuilder(); diff --git a/Snowflake.Data/Core/Session/SFSessionHttpClientProperties.cs b/Snowflake.Data/Core/Session/SFSessionHttpClientProperties.cs index 2d818f8c8..1cd2b2c98 100644 --- a/Snowflake.Data/Core/Session/SFSessionHttpClientProperties.cs +++ b/Snowflake.Data/Core/Session/SFSessionHttpClientProperties.cs @@ -40,6 +40,7 @@ internal class SFSessionHttpClientProperties private TimeSpan _waitingForSessionIdleTimeout; private TimeSpan _expirationTimeout; private bool _poolingEnabled; + internal bool _allowSSOTokenCaching; public static SFSessionHttpClientProperties ExtractAndValidate(SFSessionProperties properties) { @@ -207,6 +208,7 @@ internal Dictionary ToParameterMap() var parameterMap = new Dictionary(); parameterMap[SFSessionParameter.CLIENT_VALIDATE_DEFAULT_PARAMETERS] = validateDefaultParameters; parameterMap[SFSessionParameter.CLIENT_SESSION_KEEP_ALIVE] = clientSessionKeepAlive; + parameterMap[SFSessionParameter.CLIENT_STORE_TEMPORARY_CREDENTIAL] = _allowSSOTokenCaching; return parameterMap; } @@ -245,7 +247,8 @@ public SFSessionHttpClientProperties ExtractProperties(SFSessionProperties prope _waitingForSessionIdleTimeout = extractor.ExtractTimeout(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT), _expirationTimeout = extractor.ExtractTimeout(SFSessionProperty.EXPIRATIONTIMEOUT), _poolingEnabled = extractor.ExtractBooleanWithDefaultValue(SFSessionProperty.POOLINGENABLED), - _disableSamlUrlCheck = extractor.ExtractBooleanWithDefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) + _disableSamlUrlCheck = extractor.ExtractBooleanWithDefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK), + _allowSSOTokenCaching = Boolean.Parse(propertiesDictionary[SFSessionProperty.ALLOW_SSO_TOKEN_CACHING]), }; } diff --git a/Snowflake.Data/Core/Session/SFSessionParameter.cs b/Snowflake.Data/Core/Session/SFSessionParameter.cs index 97fdcec23..445e4fad5 100755 --- a/Snowflake.Data/Core/Session/SFSessionParameter.cs +++ b/Snowflake.Data/Core/Session/SFSessionParameter.cs @@ -14,5 +14,6 @@ internal enum SFSessionParameter QUERY_CONTEXT_CACHE_SIZE, DATE_OUTPUT_FORMAT, TIME_OUTPUT_FORMAT, + CLIENT_STORE_TEMPORARY_CREDENTIAL, } } diff --git a/Snowflake.Data/Core/Session/SFSessionProperty.cs b/Snowflake.Data/Core/Session/SFSessionProperty.cs index bfbe71a2a..987fb4060 100644 --- a/Snowflake.Data/Core/Session/SFSessionProperty.cs +++ b/Snowflake.Data/Core/Session/SFSessionProperty.cs @@ -112,7 +112,9 @@ internal enum SFSessionProperty [SFSessionPropertyAttr(required = false, defaultValue = "true")] POOLINGENABLED, [SFSessionPropertyAttr(required = false, defaultValue = "false")] - DISABLE_SAML_URL_CHECK + DISABLE_SAML_URL_CHECK, + [SFSessionPropertyAttr(required = false, defaultValue = "false")] + ALLOW_SSO_TOKEN_CACHING } class SFSessionPropertyAttr : Attribute diff --git a/Snowflake.Data/Core/Tools/BrowserOperations.cs b/Snowflake.Data/Core/Tools/BrowserOperations.cs new file mode 100644 index 000000000..48ca1baff --- /dev/null +++ b/Snowflake.Data/Core/Tools/BrowserOperations.cs @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using Snowflake.Data.Client; +using System.Diagnostics; +using System.Runtime.InteropServices; + +namespace Snowflake.Data.Core.Tools +{ + internal class BrowserOperations + { + public static readonly BrowserOperations Instance = new BrowserOperations(); + + public virtual void OpenUrl(string url) + { + // The following code is learnt from https://brockallen.com/2016/09/24/process-start-for-urls-on-net-core/ +#if NETFRAMEWORK + // .net standard would pass here + Process.Start(new ProcessStartInfo(url) { UseShellExecute = true }); +#else + // hack because of this: https://github.com/dotnet/corefx/issues/10361 + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + url = url.Replace("&", "^&"); + Process.Start(new ProcessStartInfo("cmd", $"/c start {url}") { UseShellExecute = true }); + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + Process.Start("xdg-open", url); + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + Process.Start("open", url); + } + else + { + throw new SnowflakeDbException(SFError.UNSUPPORTED_PLATFORM); + } +#endif + } + } +} diff --git a/Snowflake.Data/Core/Tools/FileOperations.cs b/Snowflake.Data/Core/Tools/FileOperations.cs index 9efe481bd..656c51257 100644 --- a/Snowflake.Data/Core/Tools/FileOperations.cs +++ b/Snowflake.Data/Core/Tools/FileOperations.cs @@ -14,5 +14,7 @@ public virtual bool Exists(string path) { return File.Exists(path); } + + public virtual void Write(string path, string content) => File.WriteAllText(path, content); } } diff --git a/Snowflake.Data/Core/Tools/UnixOperations.cs b/Snowflake.Data/Core/Tools/UnixOperations.cs index cb44099b7..4d3d82f59 100644 --- a/Snowflake.Data/Core/Tools/UnixOperations.cs +++ b/Snowflake.Data/Core/Tools/UnixOperations.cs @@ -4,6 +4,8 @@ using Mono.Unix; using Mono.Unix.Native; +using System.IO; +using System.Text; namespace Snowflake.Data.Core.Tools { @@ -11,11 +13,34 @@ internal class UnixOperations { public static readonly UnixOperations Instance = new UnixOperations(); + public virtual int CreateFileWithPermissions(string path, FilePermissions permissions) + { + return Syscall.creat(path, permissions); + } + + public virtual string ReadFile(string path) + { + var fileInfo = new UnixFileInfo(path); + using (var handle = fileInfo.OpenRead()) + { + using (var streamReader = new StreamReader(handle, Encoding.Default)) + { + return streamReader.ReadToEnd(); + } + } + } + public virtual int CreateDirectoryWithPermissions(string path, FilePermissions permissions) { return Syscall.mkdir(path, permissions); } + public virtual FileAccessPermissions GetFilePermissions(string path) + { + var fileInfo = new UnixFileInfo(path); + return fileInfo.FileAccessPermissions; + } + public virtual FileAccessPermissions GetDirPermissions(string path) { var dirInfo = new UnixDirectoryInfo(path); @@ -27,5 +52,23 @@ public virtual bool CheckFileHasAnyOfPermissions(string path, FileAccessPermissi var fileInfo = new UnixFileInfo(path); return (permissions & fileInfo.FileAccessPermissions) != 0; } + + public virtual bool CheckFileIsNotOwnedByCurrentUser(string path) + { + var fileInfo = new UnixFileInfo(path); + using (var handle = fileInfo.OpenRead()) + { + return handle.OwnerUser.UserId != Syscall.geteuid(); + } + } + + public virtual bool CheckFileIsNotOwnedByCurrentGroup(string path) + { + var fileInfo = new UnixFileInfo(path); + using (var handle = fileInfo.OpenRead()) + { + return handle.OwnerGroup.GroupId != Syscall.getegid(); + } + } } } diff --git a/Snowflake.Data/Snowflake.Data.csproj b/Snowflake.Data/Snowflake.Data.csproj index 35599c903..3ed3356c9 100644 --- a/Snowflake.Data/Snowflake.Data.csproj +++ b/Snowflake.Data/Snowflake.Data.csproj @@ -1,4 +1,4 @@ - + netstandard2.0 Snowflake.Data diff --git a/doc/Connecting.md b/doc/Connecting.md index 576120f79..1c5f44697 100644 --- a/doc/Connecting.md +++ b/doc/Connecting.md @@ -50,6 +50,7 @@ The following table lists all valid connection properties: | EXPIRATIONTIMEOUT | No | Timeout for using each connection. Connections which last more than specified timeout are considered to be expired and are being removed from the pool. The default is 1 hour. Usage of units possible and allowed are: e. g. `360000ms` (milliseconds), `3600s` (seconds), `60m` (minutes) where seconds are default for a skipped postfix. Special values: `0` - immediate expiration of the connection just after its creation. Expiration timeout cannot be set to infinity. | | POOLINGENABLED | No | Boolean flag indicating if the connection should be a part of a pool. The default value is `true`. | | DISABLE_SAML_URL_CHECK | No | Specifies whether to check if the saml postback url matches the host url from the connection string. The default value is `false`. | +| ALLOW_SSO_TOKEN_CACHING | No | Specifies whether to cache tokens and use them for SSO authentication. The default value is `false`. |