diff --git a/src/Foundatio.TestHarness/Storage/FileStorageTestsBase.cs b/src/Foundatio.TestHarness/Storage/FileStorageTestsBase.cs index 493622bf..ea177b22 100644 --- a/src/Foundatio.TestHarness/Storage/FileStorageTestsBase.cs +++ b/src/Foundatio.TestHarness/Storage/FileStorageTestsBase.cs @@ -525,17 +525,30 @@ public virtual async Task WillWriteStreamContentAsync() using (storage) { - await using var stream = await storage.GetFileStreamAsync(path, StreamMode.Write); - Assert.NotNull(stream); - Assert.True(stream.CanWrite); + await using (var stream = await storage.GetFileStreamAsync(path, StreamMode.Write)) { + Assert.NotNull(stream); + Assert.True(stream.CanWrite); - await using (var writer = new StreamWriter(stream, Encoding.UTF8, 1024, false)) - { + await using var writer = new StreamWriter(stream, Encoding.UTF8, 1024, false); await writer.WriteAsync(testContent); } string content = await storage.GetFileContentsAsync(path); Assert.Equal(testContent, content); + + string newTestContent = testContent.Substring(0, testContent.Length - 1); + + await using (var stream2 = await storage.GetFileStreamAsync(path, StreamMode.Write)) + { + Assert.NotNull(stream2); + Assert.True(stream2.CanWrite); + + await using var writer = new StreamWriter(stream2, Encoding.UTF8, 1024, false); + await writer.WriteAsync(newTestContent); + } + + content = await storage.GetFileContentsAsync(path); + Assert.Equal(newTestContent, content); } } diff --git a/src/Foundatio/Storage/ActionableStream.cs b/src/Foundatio/Storage/ActionableStream.cs index 582804f7..216ba9a7 100644 --- a/src/Foundatio/Storage/ActionableStream.cs +++ b/src/Foundatio/Storage/ActionableStream.cs @@ -5,24 +5,23 @@ namespace Foundatio.Storage; -public class ActionableStream : Stream +public class ActionableStream : Stream, IAsyncDisposable { - private readonly Action _disposeAction; + private readonly Func _disposeAction; private readonly Stream _stream; + private bool _disposed; - protected override void Dispose(bool disposing) + public ActionableStream(Stream stream, Action disposeAction) { - try + _stream = stream ?? throw new ArgumentNullException(); + _disposeAction = () => { - _disposeAction.Invoke(); - } - catch { /* ignore if these are already disposed; this is to make sure they are */ } - - _stream.Dispose(); - base.Dispose(disposing); + disposeAction(); + return Task.CompletedTask; + }; } - public ActionableStream(Stream stream, Action disposeAction) + public ActionableStream(Stream stream, Func disposeAction) { _stream = stream ?? throw new ArgumentNullException(); _disposeAction = disposeAction; @@ -86,4 +85,37 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati { return _stream.WriteAsync(buffer, offset, count, cancellationToken); } + + protected override void Dispose(bool disposing) + { + DisposeAsync().GetAwaiter().GetResult(); + } + + protected virtual async ValueTask DisposeAsyncCore() + { + if (_disposed) + return; + + try + { + _disposed = true; + await _disposeAction.Invoke(); + } + catch { /* ignore if these are already disposed; this is to make sure they are */ } + + if (_stream is IAsyncDisposable streamAsyncDisposable) + { + await streamAsyncDisposable.DisposeAsync(); + } + else + { + _stream?.Dispose(); + } + } + + public async ValueTask DisposeAsync() + { + await DisposeAsyncCore(); + GC.SuppressFinalize(this); + } }