From bcc9f7878d1bf1dc5f5981cc6c24cb365aad6243 Mon Sep 17 00:00:00 2001 From: "fabien.menager" Date: Sat, 12 Mar 2022 00:19:39 +0100 Subject: [PATCH 1/3] Add methods to get affected rows instead of their count --- .../Runners/IUpsertCommandRunner.cs | 30 +++++++++++ .../Runners/InMemoryUpsertCommandRunner.cs | 36 ++++++++++++- .../Runners/MySqlUpsertCommandRunner.cs | 9 ++-- .../Runners/PostgreSqlUpsertCommandRunner.cs | 14 +++-- .../Runners/RelationalUpsertCommandRunner.cs | 53 +++++++++++++++++-- .../Runners/SqlServerUpsertCommandRunner.cs | 14 +++-- .../Runners/UpsertCommandRunnerBase.cs | 10 ++++ .../UpsertCommandBuilder.cs | 24 +++++++++ .../DbTestsBase.cs | 53 +++++++++++++++++++ .../ReplaceRunnerTests.cs | 8 ++- 10 files changed, 236 insertions(+), 15 deletions(-) diff --git a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/IUpsertCommandRunner.cs b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/IUpsertCommandRunner.cs index cf2b586..b1efd63 100644 --- a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/IUpsertCommandRunner.cs +++ b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/IUpsertCommandRunner.cs @@ -35,6 +35,36 @@ int Run(DbContext dbContext, IEntityType entityType, ICollection>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions) where TEntity : class; + /// + /// Run the upsert command for the entities passed + /// + /// Entity type of the entities + /// Data context to be used + /// Metadata for the entity + /// Array of entities to be upserted + /// Expression that represents which properties will be used as a match clause for the upsert command + /// Expression that represents which properties will be updated, and what values will be set + /// Expression that checks whether the database entry should be updated + /// Options for the current query that will affect it's behaviour + ICollection RunAndReturn(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, + Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions) + where TEntity : class; + + /// + /// Run the upsert command for the entities passed + /// + /// Entity type of the entities + /// Data context to be used + /// Metadata for the entity + /// Array of entities to be upserted + /// Expression that represents which properties will be used as a match clause for the upsert command + /// Expression that represents which properties will be updated, and what values will be set + /// Expression that checks whether the database entry should be updated + /// Options for the current query that will affect it's behaviour + Task> RunAndReturnAsync(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, + Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions) + where TEntity : class; + /// /// Run the upsert command for the entities passed /// diff --git a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/InMemoryUpsertCommandRunner.cs b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/InMemoryUpsertCommandRunner.cs index 2eea9a7..30560dc 100644 --- a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/InMemoryUpsertCommandRunner.cs +++ b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/InMemoryUpsertCommandRunner.cs @@ -19,7 +19,7 @@ public class InMemoryUpsertCommandRunner : UpsertCommandRunnerBase /// public override bool Supports(string providerName) => providerName == "Microsoft.EntityFrameworkCore.InMemory"; - private static void RunCore(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, + private static IEnumerable RunCore(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions) where TEntity : class { // Find matching entities in the dbContext @@ -92,6 +92,8 @@ private static void RunCore(DbContext dbContext, IEntityType entityType continue; updateAction?.Invoke(match.DbEntity, match.NewEntity); + + yield return match.NewEntity; } } @@ -145,6 +147,22 @@ public override int Run(DbContext dbContext, IEntityType entityType, IC return dbContext.SaveChanges(); } + /// + public override ICollection RunAndReturn(DbContext dbContext, IEntityType entityType, ICollection entities, + Expression>? matchExpression, Expression>? updateExpression, Expression>? updateCondition, + RunnerQueryOptions queryOptions) + { + if (dbContext is null) + throw new ArgumentNullException(nameof(dbContext)); + if (entityType == null) + throw new ArgumentNullException(nameof(entityType)); + + var result = RunCore(dbContext, entityType, entities, matchExpression, updateExpression, updateCondition, queryOptions); + dbContext.SaveChanges(); + + return result.ToArray(); + } + /// public override Task RunAsync(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions, @@ -158,5 +176,21 @@ public override Task RunAsync(DbContext dbContext, IEntityType ent RunCore(dbContext, entityType, entities, matchExpression, updateExpression, updateCondition, queryOptions); return dbContext.SaveChangesAsync(cancellationToken); } + + /// + public override async Task> RunAndReturnAsync(DbContext dbContext, IEntityType entityType, ICollection entities, + Expression>? matchExpression, Expression>? updateExpression, Expression>? updateCondition, + RunnerQueryOptions queryOptions) + { + if (dbContext is null) + throw new ArgumentNullException(nameof(dbContext)); + if (entityType == null) + throw new ArgumentNullException(nameof(entityType)); + + var result = RunCore(dbContext, entityType, entities, matchExpression, updateExpression, updateCondition, queryOptions); + await dbContext.SaveChangesAsync().ConfigureAwait(false); + + return result.ToArray(); + } } } diff --git a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/MySqlUpsertCommandRunner.cs b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/MySqlUpsertCommandRunner.cs index 0bda6cd..c163396 100644 --- a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/MySqlUpsertCommandRunner.cs +++ b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/MySqlUpsertCommandRunner.cs @@ -26,9 +26,12 @@ public class MySqlUpsertCommandRunner : RelationalUpsertCommandRunner protected override int? MaxQueryParams => 65535; /// - public override string GenerateCommand(string tableName, ICollection> entities, - ICollection<(string ColumnName, bool IsNullable)> joinColumns, ICollection<(string ColumnName, IKnownValue Value)>? updateExpressions, - KnownExpression? updateCondition) + public override string GenerateCommand(string tableName, + ICollection> + entities, + ICollection<(string ColumnName, bool IsNullable)> joinColumns, + ICollection<(string ColumnName, IKnownValue Value)>? updateExpressions, + KnownExpression? updateCondition, bool returnResult = false) { var result = new StringBuilder("INSERT "); if (updateExpressions == null) diff --git a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/PostgreSqlUpsertCommandRunner.cs b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/PostgreSqlUpsertCommandRunner.cs index a432b47..fbe1d62 100644 --- a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/PostgreSqlUpsertCommandRunner.cs +++ b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/PostgreSqlUpsertCommandRunner.cs @@ -22,9 +22,11 @@ public class PostgreSqlUpsertCommandRunner : RelationalUpsertCommandRunner protected override int? MaxQueryParams => 32767; /// - public override string GenerateCommand(string tableName, ICollection> entities, - ICollection<(string ColumnName, bool IsNullable)> joinColumns, ICollection<(string ColumnName, IKnownValue Value)>? updateExpressions, - KnownExpression? updateCondition) + public override string GenerateCommand(string tableName, + ICollection> entities, + ICollection<(string ColumnName, bool IsNullable)> joinColumns, + ICollection<(string ColumnName, IKnownValue Value)>? updateExpressions, + KnownExpression? updateCondition, bool returnResult = false) { var result = new StringBuilder(); result.Append($"INSERT INTO {tableName} AS \"T\" ("); @@ -45,6 +47,12 @@ public override string GenerateCommand(string tableName, ICollectionThe columns used to match existing items in the database /// The expressions that represent update commands for matched entities /// The expression that tests whether existing entities should be updated + /// /// A fully formed database query public abstract string GenerateCommand(string tableName, ICollection> entities, ICollection<(string ColumnName, bool IsNullable)> joinColumns, ICollection<(string ColumnName, IKnownValue Value)>? updateExpressions, - KnownExpression? updateCondition); + KnownExpression? updateCondition, bool returnResult = false); /// /// Escape the name of the table/column/schema in a given database language /// @@ -97,7 +98,7 @@ protected virtual string GetTableName(IEntityType entityType) private IEnumerable<(string SqlCommand, IEnumerable Arguments)> PrepareCommand(IEntityType entityType, ICollection entities, Expression>? match, Expression>? updater, Expression>? updateCondition, - RunnerQueryOptions queryOptions) + RunnerQueryOptions queryOptions, bool returnResult = false) { var joinColumns = ProcessMatchExpression(entityType, match, queryOptions); var joinColumnNames = joinColumns.Select(c => (ColumnName: c.GetColumnBaseName(), c.IsColumnNullable())).ToArray(); @@ -204,7 +205,7 @@ protected virtual string GetTableName(IEntityType entityType) var columnUpdateExpressions = updateExpressions?.Count > 0 ? updateExpressions.Select(x => (x.Property.GetColumnBaseName(), x.Value)).ToArray() : null; - var sqlCommand = GenerateCommand(GetTableName(entityType), newEntities.Skip(entitiesProcessed - entitiesHere).Take(entitiesHere).ToArray(), joinColumnNames, columnUpdateExpressions, updateConditionExpression); + var sqlCommand = GenerateCommand(GetTableName(entityType), newEntities.Skip(entitiesProcessed - entitiesHere).Take(entitiesHere).ToArray(), joinColumnNames, columnUpdateExpressions, updateConditionExpression, returnResult); yield return (sqlCommand, arguments); } } @@ -381,6 +382,29 @@ public override int Run(DbContext dbContext, IEntityType entityType, IC return result; } + /// + public override ICollection RunAndReturn(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, + Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions) + { + if (dbContext == null) + throw new ArgumentNullException(nameof(dbContext)); + if (entityType == null) + throw new ArgumentNullException(nameof(entityType)); + + var relationalTypeMappingSource = dbContext.GetService(); + var commands = PrepareCommand(entityType, entities, matchExpression, updateExpression, updateCondition, queryOptions, true); + + var result = new List(); + + foreach (var (sqlCommand, arguments) in commands) + { + using var dbCommand = dbContext.Database.GetDbConnection().CreateCommand(); + var dbArguments = arguments.Select(a => PrepareDbCommandArgument(dbCommand, relationalTypeMappingSource, a)).ToArray(); + result.AddRange(dbContext.Set().FromSqlRaw(sqlCommand, dbArguments).AsNoTracking().ToArray()); + } + return result; + } + /// public override async Task RunAsync(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions, @@ -404,6 +428,29 @@ public override async Task RunAsync(DbContext dbContext, IEntityTy return result; } + /// + public override async Task> RunAndReturnAsync(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, + Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions) + { + if (dbContext == null) + throw new ArgumentNullException(nameof(dbContext)); + if (entityType == null) + throw new ArgumentNullException(nameof(entityType)); + + var relationalTypeMappingSource = dbContext.GetService(); + var commands = PrepareCommand(entityType, entities, matchExpression, updateExpression, updateCondition, queryOptions, true); + + var result = new List(); + + foreach (var (sqlCommand, arguments) in commands) + { + using var dbCommand = dbContext.Database.GetDbConnection().CreateCommand(); + var dbArguments = arguments.Select(a => PrepareDbCommandArgument(dbCommand, relationalTypeMappingSource, a)).ToArray(); + result.AddRange(await dbContext.Set().FromSqlRaw(sqlCommand, dbArguments).AsNoTracking().ToArrayAsync().ConfigureAwait(false)); + } + return result; + } + private object PrepareDbCommandArgument(DbCommand dbCommand, IRelationalTypeMappingSource relationalTypeMappingSource, ConstantValue constantValue) { RelationalTypeMapping? relationalTypeMapping = null; diff --git a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/SqlServerUpsertCommandRunner.cs b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/SqlServerUpsertCommandRunner.cs index 15c99c8..c3eff20 100644 --- a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/SqlServerUpsertCommandRunner.cs +++ b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/SqlServerUpsertCommandRunner.cs @@ -22,9 +22,12 @@ public class SqlServerUpsertCommandRunner : RelationalUpsertCommandRunner protected override int? MaxQueryParams => 2100; /// - public override string GenerateCommand(string tableName, ICollection> entities, - ICollection<(string ColumnName, bool IsNullable)> joinColumns, ICollection<(string ColumnName, IKnownValue Value)>? updateExpressions, - KnownExpression? updateCondition) + public override string GenerateCommand(string tableName, + ICollection> + entities, + ICollection<(string ColumnName, bool IsNullable)> joinColumns, + ICollection<(string ColumnName, IKnownValue Value)>? updateExpressions, + KnownExpression? updateCondition, bool returnResult = false) { var result = new StringBuilder(); result.Append($"MERGE INTO {tableName} WITH (HOLDLOCK) AS [T] USING ( VALUES ("); @@ -48,6 +51,11 @@ public override string GenerateCommand(string tableName, ICollection $"{EscapeName(e.ColumnName)} = {ExpandValue(e.Value)}"))); } + + if (returnResult) + { + result.Append(" OUTPUT inserted.*"); + } result.Append(';'); return result.ToString(); } diff --git a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/UpsertCommandRunnerBase.cs b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/UpsertCommandRunnerBase.cs index 50d4b18..fa799ab 100644 --- a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/UpsertCommandRunnerBase.cs +++ b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/UpsertCommandRunnerBase.cs @@ -24,11 +24,21 @@ public abstract int Run(DbContext dbContext, IEntityType entityType, IC Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions) where TEntity : class; + /// + public abstract ICollection RunAndReturn(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, + Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions) + where TEntity : class; + /// public abstract Task RunAsync(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions, CancellationToken cancellationToken) where TEntity : class; + /// + public abstract Task> RunAndReturnAsync(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, + Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions) + where TEntity : class; + /// /// Extract property metadata from the match expression /// diff --git a/src/FlexLabs.EntityFrameworkCore.Upsert/UpsertCommandBuilder.cs b/src/FlexLabs.EntityFrameworkCore.Upsert/UpsertCommandBuilder.cs index 341d535..74f646c 100644 --- a/src/FlexLabs.EntityFrameworkCore.Upsert/UpsertCommandBuilder.cs +++ b/src/FlexLabs.EntityFrameworkCore.Upsert/UpsertCommandBuilder.cs @@ -193,6 +193,18 @@ public int Run() return commandRunner.Run(_dbContext, _entityType, _entities, _matchExpression, _updateExpression, _updateCondition, _queryOptions); } + /// + /// Execute the upsert command against the database and returns new or updated entities + /// + public ICollection RunAndReturn() + { + if (_entities.Count == 0) + return Array.Empty(); + + var commandRunner = GetCommandRunner(); + return commandRunner.RunAndReturn(_dbContext, _entityType, _entities, _matchExpression, _updateExpression, _updateCondition, _queryOptions); + } + /// /// Execute the upsert command against the database asynchronously /// @@ -206,5 +218,17 @@ public Task RunAsync(CancellationToken token = default) var commandRunner = GetCommandRunner(); return commandRunner.RunAsync(_dbContext, _entityType, _entities, _matchExpression, _updateExpression, _updateCondition, _queryOptions, token); } + + /// + /// Execute the upsert command against the database and returns new or updated entities + /// + public Task> RunAndReturnAsync() + { + if (_entities.Count == 0) + return Task.FromResult>(Array.Empty()); + + var commandRunner = GetCommandRunner(); + return commandRunner.RunAndReturnAsync(_dbContext, _entityType, _entities, _matchExpression, _updateExpression, _updateCondition, _queryOptions); + } } } diff --git a/test/FlexLabs.EntityFrameworkCore.Upsert.IntegrationTests/DbTestsBase.cs b/test/FlexLabs.EntityFrameworkCore.Upsert.IntegrationTests/DbTestsBase.cs index 35fa5b3..6e19b8f 100644 --- a/test/FlexLabs.EntityFrameworkCore.Upsert.IntegrationTests/DbTestsBase.cs +++ b/test/FlexLabs.EntityFrameworkCore.Upsert.IntegrationTests/DbTestsBase.cs @@ -207,6 +207,59 @@ public void Upsert_IdentityKey_NoOn_AllowWithOverride() dbContext.Countries.Should().HaveCount(2); } + [Fact] + public void Upsert_ReturnResult_Single() + { + ResetDb(); + using var dbContext = new TestDbContext(_fixture.DataContextOptions); + + var dashTable = new DashTable + { + DataSet = "Test", + Updated = _now, + }; + + var result = dbContext.DashTable.Upsert(dashTable) + .On(c => c.DataSet) + .RunAndReturn(); + + result.Should().ContainEquivalentOf(new DashTable + { + ID = 1, + DataSet = "Test", + Updated = _now, + }); + } + + [Fact] + public void Upsert_ReturnResult_Multiple() + { + ResetDb(); + using var dbContext = new TestDbContext(_fixture.DataContextOptions); + + var dashTables = new[] + { + new DashTable + { + DataSet = "Test", + Updated = _now, + }, + new DashTable + { + DataSet = "Test", + Updated = _now, + } + }; + + var result = dbContext.DashTable.UpsertRange(dashTables) + .On(c => c.DataSet) + .RunAndReturn(); + + result.Should().HaveCount(2); + + dbContext.DashTable.Should().HaveCount(1); + } + [Fact] public void Upsert_IdentityKey_ExplicitOn_AllowWithOverride() { diff --git a/test/FlexLabs.EntityFrameworkCore.Upsert.Tests/ReplaceRunnerTests.cs b/test/FlexLabs.EntityFrameworkCore.Upsert.Tests/ReplaceRunnerTests.cs index edb807d..61d3b11 100644 --- a/test/FlexLabs.EntityFrameworkCore.Upsert.Tests/ReplaceRunnerTests.cs +++ b/test/FlexLabs.EntityFrameworkCore.Upsert.Tests/ReplaceRunnerTests.cs @@ -19,8 +19,12 @@ public class CustomSqliteCommandRunner : RelationalUpsertCommandRunner public override bool Supports(string name) => name == "Microsoft.EntityFrameworkCore.Sqlite"; public static int GenerateCalled; - public override string GenerateCommand(string tableName, ICollection> entities, - ICollection<(string ColumnName, bool IsNullable)> joinColumns, ICollection<(string ColumnName, IKnownValue Value)> updateExpressions, KnownExpression updateCondition) + public override string GenerateCommand(string tableName, + ICollection> + entities, + ICollection<(string ColumnName, bool IsNullable)> joinColumns, + ICollection<(string ColumnName, IKnownValue Value)> updateExpressions, + KnownExpression updateCondition, bool returnResult = false) { GenerateCalled++; return "sql"; From 5a982530f8c85dce0db8fa06ca665a8404ad78d8 Mon Sep 17 00:00:00 2001 From: Artiom Chilaru Date: Sun, 24 Nov 2024 17:37:31 +0000 Subject: [PATCH 2/3] Code formatting updates --- .../Runners/IUpsertCommandRunner.cs | 20 +++++++++---------- .../Runners/InMemoryUpsertCommandRunner.cs | 12 ++++------- .../Runners/MySqlUpsertCommandRunner.cs | 12 +++++++---- .../Runners/OracleUpsertCommandRunner.cs | 5 ++++- .../Runners/PostgreSqlUpsertCommandRunner.cs | 6 ++++-- .../Runners/RelationalUpsertCommandRunner.cs | 16 +++++---------- .../Runners/SqlServerUpsertCommandRunner.cs | 10 +++++----- .../UpsertCommandBuilder.cs | 8 ++++---- .../DbTestsBase.cs | 16 ++++++++++----- .../ReplaceRunnerTests.cs | 9 +++++---- 10 files changed, 60 insertions(+), 54 deletions(-) diff --git a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/IUpsertCommandRunner.cs b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/IUpsertCommandRunner.cs index b1efd63..e93dc6b 100644 --- a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/IUpsertCommandRunner.cs +++ b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/IUpsertCommandRunner.cs @@ -36,7 +36,7 @@ int Run(DbContext dbContext, IEntityType entityType, ICollection - /// Run the upsert command for the entities passed + /// Run the upsert command for the entities passed and return new or updated entities /// /// Entity type of the entities /// Data context to be used @@ -61,12 +61,14 @@ ICollection RunAndReturn(DbContext dbContext, IEntityType enti /// Expression that represents which properties will be updated, and what values will be set /// Expression that checks whether the database entry should be updated /// Options for the current query that will affect it's behaviour - Task> RunAndReturnAsync(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, - Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions) - where TEntity : class; + /// The CancellationToken to observe while waiting for the task to complete. + /// The task that represents the asynchronous upsert operation + Task RunAsync(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, + Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions, + CancellationToken cancellationToken) where TEntity : class; /// - /// Run the upsert command for the entities passed + /// Run the upsert command for the entities passed and return new or updated entities /// /// Entity type of the entities /// Data context to be used @@ -76,10 +78,8 @@ Task> RunAndReturnAsync(DbContext dbContext, IEnti /// Expression that represents which properties will be updated, and what values will be set /// Expression that checks whether the database entry should be updated /// Options for the current query that will affect it's behaviour - /// The CancellationToken to observe while waiting for the task to complete. - /// The task that represents the asynchronous upsert operation - Task RunAsync(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, - Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions, - CancellationToken cancellationToken) where TEntity : class; + Task> RunAndReturnAsync(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, + Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions) + where TEntity : class; } } diff --git a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/InMemoryUpsertCommandRunner.cs b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/InMemoryUpsertCommandRunner.cs index edf6a3a..5b480d1 100644 --- a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/InMemoryUpsertCommandRunner.cs +++ b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/InMemoryUpsertCommandRunner.cs @@ -149,10 +149,8 @@ public override ICollection RunAndReturn(DbContext dbContext, Expression>? matchExpression, Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions) { - if (dbContext is null) - throw new ArgumentNullException(nameof(dbContext)); - if (entityType == null) - throw new ArgumentNullException(nameof(entityType)); + ArgumentNullException.ThrowIfNull(dbContext); + ArgumentNullException.ThrowIfNull(entityType); var result = RunCore(dbContext, entityType, entities, matchExpression, updateExpression, updateCondition, queryOptions); dbContext.SaveChanges(); @@ -177,10 +175,8 @@ public override async Task> RunAndReturnAsync(DbCo Expression>? matchExpression, Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions) { - if (dbContext is null) - throw new ArgumentNullException(nameof(dbContext)); - if (entityType == null) - throw new ArgumentNullException(nameof(entityType)); + ArgumentNullException.ThrowIfNull(dbContext); + ArgumentNullException.ThrowIfNull(entityType); var result = RunCore(dbContext, entityType, entities, matchExpression, updateExpression, updateCondition, queryOptions); await dbContext.SaveChangesAsync().ConfigureAwait(false); diff --git a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/MySqlUpsertCommandRunner.cs b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/MySqlUpsertCommandRunner.cs index 09442ea..8b666f9 100644 --- a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/MySqlUpsertCommandRunner.cs +++ b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/MySqlUpsertCommandRunner.cs @@ -27,13 +27,17 @@ public class MySqlUpsertCommandRunner : RelationalUpsertCommandRunner protected override int? MaxQueryParams => 65535; /// - public override string GenerateCommand(string tableName, - ICollection> - entities, + public override string GenerateCommand( + string tableName, + ICollection> entities, ICollection<(string ColumnName, bool IsNullable)> joinColumns, ICollection<(string ColumnName, IKnownValue Value)>? updateExpressions, - KnownExpression? updateCondition, bool returnResult = false) + KnownExpression? updateCondition, + bool returnResult = false) { + if (returnResult) + throw new NotImplementedException("MySql runner does not support returning the result of the upsert operation yet"); + var result = new StringBuilder("INSERT "); if (updateExpressions == null) result.Append("IGNORE "); diff --git a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/OracleUpsertCommandRunner.cs b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/OracleUpsertCommandRunner.cs index a08bd29..3c6f0c1 100644 --- a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/OracleUpsertCommandRunner.cs +++ b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/OracleUpsertCommandRunner.cs @@ -37,8 +37,11 @@ public override string GenerateCommand( bool returnResult = false) { ArgumentNullException.ThrowIfNull(entities); - var result = new StringBuilder(); + if (returnResult) + throw new NotImplementedException("Oracle runner does not support returning the result of the upsert operation yet"); + + var result = new StringBuilder(); result.Append(CultureInfo.InvariantCulture, $"MERGE INTO {tableName} t USING ("); foreach (var item in entities.Select((e, ind) => new {e, ind})) { diff --git a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/PostgreSqlUpsertCommandRunner.cs b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/PostgreSqlUpsertCommandRunner.cs index 56a25f2..9c5f2ad 100644 --- a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/PostgreSqlUpsertCommandRunner.cs +++ b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/PostgreSqlUpsertCommandRunner.cs @@ -23,11 +23,13 @@ public class PostgreSqlUpsertCommandRunner : RelationalUpsertCommandRunner protected override int? MaxQueryParams => 32767; /// - public override string GenerateCommand(string tableName, + public override string GenerateCommand( + string tableName, ICollection> entities, ICollection<(string ColumnName, bool IsNullable)> joinColumns, ICollection<(string ColumnName, IKnownValue Value)>? updateExpressions, - KnownExpression? updateCondition, bool returnResult = false) + KnownExpression? updateCondition, + bool returnResult = false) { var result = new StringBuilder(); result.Append(CultureInfo.InvariantCulture, $"INSERT INTO {tableName} AS \"T\" ("); diff --git a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/RelationalUpsertCommandRunner.cs b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/RelationalUpsertCommandRunner.cs index 8a47608..e45541d 100644 --- a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/RelationalUpsertCommandRunner.cs +++ b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/RelationalUpsertCommandRunner.cs @@ -27,7 +27,7 @@ public abstract class RelationalUpsertCommandRunner : UpsertCommandRunnerBase /// The columns used to match existing items in the database /// The expressions that represent update commands for matched entities /// The expression that tests whether existing entities should be updated - /// + /// If true, the generated command should return upserted entities /// A fully formed database query public abstract string GenerateCommand(string tableName, ICollection> entities, ICollection<(string ColumnName, bool IsNullable)> joinColumns, ICollection<(string ColumnName, IKnownValue Value)>? updateExpressions, @@ -381,16 +381,13 @@ public override int Run(DbContext dbContext, IEntityType entityType, IC public override ICollection RunAndReturn(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions) { - if (dbContext == null) - throw new ArgumentNullException(nameof(dbContext)); - if (entityType == null) - throw new ArgumentNullException(nameof(entityType)); + ArgumentNullException.ThrowIfNull(dbContext); + ArgumentNullException.ThrowIfNull(entityType); var relationalTypeMappingSource = dbContext.GetService(); var commands = PrepareCommand(entityType, entities, matchExpression, updateExpression, updateCondition, queryOptions, true); var result = new List(); - foreach (var (sqlCommand, arguments) in commands) { using var dbCommand = dbContext.Database.GetDbConnection().CreateCommand(); @@ -425,16 +422,13 @@ public override async Task RunAsync(DbContext dbContext, IEntityTy public override async Task> RunAndReturnAsync(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions) { - if (dbContext == null) - throw new ArgumentNullException(nameof(dbContext)); - if (entityType == null) - throw new ArgumentNullException(nameof(entityType)); + ArgumentNullException.ThrowIfNull(dbContext); + ArgumentNullException.ThrowIfNull(entityType); var relationalTypeMappingSource = dbContext.GetService(); var commands = PrepareCommand(entityType, entities, matchExpression, updateExpression, updateCondition, queryOptions, true); var result = new List(); - foreach (var (sqlCommand, arguments) in commands) { using var dbCommand = dbContext.Database.GetDbConnection().CreateCommand(); diff --git a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/SqlServerUpsertCommandRunner.cs b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/SqlServerUpsertCommandRunner.cs index e537c45..d6700f5 100644 --- a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/SqlServerUpsertCommandRunner.cs +++ b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/SqlServerUpsertCommandRunner.cs @@ -23,12 +23,13 @@ public class SqlServerUpsertCommandRunner : RelationalUpsertCommandRunner protected override int? MaxQueryParams => 2090; /// - public override string GenerateCommand(string tableName, - ICollection> - entities, + public override string GenerateCommand( + string tableName, + ICollection> entities, ICollection<(string ColumnName, bool IsNullable)> joinColumns, ICollection<(string ColumnName, IKnownValue Value)>? updateExpressions, - KnownExpression? updateCondition, bool returnResult = false) + KnownExpression? updateCondition, + bool returnResult = false) { var result = new StringBuilder(); result.Append(CultureInfo.InvariantCulture, $"MERGE INTO {tableName} WITH (HOLDLOCK) AS [T] USING ( VALUES ("); @@ -52,7 +53,6 @@ public override string GenerateCommand(string tableName, result.Append(" THEN UPDATE SET "); result.Append(string.Join(", ", updateExpressions.Select((e, i) => $"{EscapeName(e.ColumnName)} = {ExpandValue(e.Value)}"))); } - if (returnResult) { result.Append(" OUTPUT inserted.*"); diff --git a/src/FlexLabs.EntityFrameworkCore.Upsert/UpsertCommandBuilder.cs b/src/FlexLabs.EntityFrameworkCore.Upsert/UpsertCommandBuilder.cs index 50ee697..7abbce7 100644 --- a/src/FlexLabs.EntityFrameworkCore.Upsert/UpsertCommandBuilder.cs +++ b/src/FlexLabs.EntityFrameworkCore.Upsert/UpsertCommandBuilder.cs @@ -188,12 +188,12 @@ public int Run() } /// - /// Execute the upsert command against the database and returns new or updated entities + /// Execute the upsert command against the database and return new or updated entities /// public ICollection RunAndReturn() { if (_entities.Count == 0) - return Array.Empty(); + return []; var commandRunner = GetCommandRunner(); return commandRunner.RunAndReturn(_dbContext, _entityType, _entities, _matchExpression, _updateExpression, _updateCondition, _queryOptions); @@ -214,12 +214,12 @@ public Task RunAsync(CancellationToken token = default) } /// - /// Execute the upsert command against the database and returns new or updated entities + /// Execute the upsert command against the database and return new or updated entities /// public Task> RunAndReturnAsync() { if (_entities.Count == 0) - return Task.FromResult>(Array.Empty()); + return Task.FromResult>([]); var commandRunner = GetCommandRunner(); return commandRunner.RunAndReturnAsync(_dbContext, _entityType, _entities, _matchExpression, _updateExpression, _updateCondition, _queryOptions); diff --git a/test/FlexLabs.EntityFrameworkCore.Upsert.IntegrationTests/DbTestsBase.cs b/test/FlexLabs.EntityFrameworkCore.Upsert.IntegrationTests/DbTestsBase.cs index d981597..b69099f 100644 --- a/test/FlexLabs.EntityFrameworkCore.Upsert.IntegrationTests/DbTestsBase.cs +++ b/test/FlexLabs.EntityFrameworkCore.Upsert.IntegrationTests/DbTestsBase.cs @@ -227,6 +227,9 @@ public void Upsert_IdentityKey_NoOn_AllowWithOverride() [Fact] public void Upsert_ReturnResult_Single() { + if (_fixture.DbDriver == DbDriver.MySQL || _fixture.DbDriver == DbDriver.Oracle) + return; // Returning records is not implemented in MySQL and Oracle runners + ResetDb(); using var dbContext = new TestDbContext(_fixture.DataContextOptions); @@ -242,7 +245,7 @@ public void Upsert_ReturnResult_Single() result.Should().ContainEquivalentOf(new DashTable { - ID = 1, + ID = result.First().ID, DataSet = "Test", Updated = _now, }); @@ -251,19 +254,22 @@ public void Upsert_ReturnResult_Single() [Fact] public void Upsert_ReturnResult_Multiple() { - ResetDb(); + if (_fixture.DbDriver == DbDriver.MySQL || _fixture.DbDriver == DbDriver.Oracle) + return; // Returning records is not implemented in MySQL and Oracle runners + + ResetDb(new DashTable { DataSet = "Test1" }); using var dbContext = new TestDbContext(_fixture.DataContextOptions); var dashTables = new[] { new DashTable { - DataSet = "Test", + DataSet = "Test1", Updated = _now, }, new DashTable { - DataSet = "Test", + DataSet = "Test2", Updated = _now, } }; @@ -274,7 +280,7 @@ public void Upsert_ReturnResult_Multiple() result.Should().HaveCount(2); - dbContext.DashTable.Should().HaveCount(1); + dbContext.DashTable.Should().HaveCount(2); } [Fact] diff --git a/test/FlexLabs.EntityFrameworkCore.Upsert.Tests/ReplaceRunnerTests.cs b/test/FlexLabs.EntityFrameworkCore.Upsert.Tests/ReplaceRunnerTests.cs index 61d3b11..56ca1e5 100644 --- a/test/FlexLabs.EntityFrameworkCore.Upsert.Tests/ReplaceRunnerTests.cs +++ b/test/FlexLabs.EntityFrameworkCore.Upsert.Tests/ReplaceRunnerTests.cs @@ -19,12 +19,13 @@ public class CustomSqliteCommandRunner : RelationalUpsertCommandRunner public override bool Supports(string name) => name == "Microsoft.EntityFrameworkCore.Sqlite"; public static int GenerateCalled; - public override string GenerateCommand(string tableName, - ICollection> - entities, + public override string GenerateCommand( + string tableName, + ICollection> entities, ICollection<(string ColumnName, bool IsNullable)> joinColumns, ICollection<(string ColumnName, IKnownValue Value)> updateExpressions, - KnownExpression updateCondition, bool returnResult = false) + KnownExpression updateCondition, + bool returnResult = false) { GenerateCalled++; return "sql"; From eae18cb8c56627eed5f5818f90fb42e589032ac1 Mon Sep 17 00:00:00 2001 From: Artiom Chilaru Date: Sun, 24 Nov 2024 18:08:45 +0000 Subject: [PATCH 3/3] Fix the inMemory runner Change the runner to enable tracking and add a test for that --- .../Runners/InMemoryUpsertCommandRunner.cs | 4 +-- .../Runners/RelationalUpsertCommandRunner.cs | 4 +-- .../DbTestsBase.cs | 26 +++++++++++++++++++ 3 files changed, 30 insertions(+), 4 deletions(-) diff --git a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/InMemoryUpsertCommandRunner.cs b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/InMemoryUpsertCommandRunner.cs index 5b480d1..f2f7d3e 100644 --- a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/InMemoryUpsertCommandRunner.cs +++ b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/InMemoryUpsertCommandRunner.cs @@ -91,9 +91,9 @@ private static IEnumerable RunCore(DbContext dbContext, IEntit continue; updateAction?.Invoke(match.DbEntity, match.NewEntity); - - yield return match.NewEntity; } + + return matches.Select(m => m.DbEntity ?? m.NewEntity); } private struct EntityMatch diff --git a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/RelationalUpsertCommandRunner.cs b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/RelationalUpsertCommandRunner.cs index e45541d..a68ec6d 100644 --- a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/RelationalUpsertCommandRunner.cs +++ b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/RelationalUpsertCommandRunner.cs @@ -392,7 +392,7 @@ public override ICollection RunAndReturn(DbContext dbContext, { using var dbCommand = dbContext.Database.GetDbConnection().CreateCommand(); var dbArguments = arguments.Select(a => PrepareDbCommandArgument(dbCommand, relationalTypeMappingSource, a)).ToArray(); - result.AddRange(dbContext.Set().FromSqlRaw(sqlCommand, dbArguments).AsNoTracking().ToArray()); + result.AddRange(dbContext.Set().FromSqlRaw(sqlCommand, dbArguments).ToArray()); } return result; } @@ -433,7 +433,7 @@ public override async Task> RunAndReturnAsync(DbCo { using var dbCommand = dbContext.Database.GetDbConnection().CreateCommand(); var dbArguments = arguments.Select(a => PrepareDbCommandArgument(dbCommand, relationalTypeMappingSource, a)).ToArray(); - result.AddRange(await dbContext.Set().FromSqlRaw(sqlCommand, dbArguments).AsNoTracking().ToArrayAsync().ConfigureAwait(false)); + result.AddRange(await dbContext.Set().FromSqlRaw(sqlCommand, dbArguments).ToArrayAsync().ConfigureAwait(false)); } return result; } diff --git a/test/FlexLabs.EntityFrameworkCore.Upsert.IntegrationTests/DbTestsBase.cs b/test/FlexLabs.EntityFrameworkCore.Upsert.IntegrationTests/DbTestsBase.cs index b69099f..c48f2a6 100644 --- a/test/FlexLabs.EntityFrameworkCore.Upsert.IntegrationTests/DbTestsBase.cs +++ b/test/FlexLabs.EntityFrameworkCore.Upsert.IntegrationTests/DbTestsBase.cs @@ -283,6 +283,32 @@ public void Upsert_ReturnResult_Multiple() dbContext.DashTable.Should().HaveCount(2); } + [Fact] + public void Upsert_ReturnResult_TracksChanges() + { + if (_fixture.DbDriver == DbDriver.MySQL || _fixture.DbDriver == DbDriver.Oracle) + return; // Returning records is not implemented in MySQL and Oracle runners + + ResetDb(new DashTable { DataSet = "Test" }); + using var dbContext = new TestDbContext(_fixture.DataContextOptions); + + var dashTable = new DashTable + { + DataSet = "Test", + Updated = _now, + }; + + var result = dbContext.DashTable.Upsert(dashTable) + .On(c => c.DataSet) + .RunAndReturn(); + + result.Single().Updated = _now.AddYears(1); + dbContext.SaveChanges(); + + var updatedResult = dbContext.DashTable.Single(); + updatedResult.Updated.Should().Be(_now.AddYears(1)); + } + [Fact] public void Upsert_IdentityKey_ExplicitOn_AllowWithOverride() {