diff --git a/src/FlexLabs.EntityFrameworkCore.Upsert/FlexLabs.EntityFrameworkCore.Upsert.csproj b/src/FlexLabs.EntityFrameworkCore.Upsert/FlexLabs.EntityFrameworkCore.Upsert.csproj index ee41933..cd08d0b 100644 --- a/src/FlexLabs.EntityFrameworkCore.Upsert/FlexLabs.EntityFrameworkCore.Upsert.csproj +++ b/src/FlexLabs.EntityFrameworkCore.Upsert/FlexLabs.EntityFrameworkCore.Upsert.csproj @@ -24,6 +24,7 @@ Also supports injecting sql command generators to add support for other provider v8.1.0 + Adding initial support for Oracle DB! (Thanks to @dadyarri) ++ Adding test support for returning inserted objects (Thanks to @PhenX) + Adding support for upserting into views (ymmv) ! Patching argument count calculation (for max argument count handling) ! Patching null constant handling in the update condition diff --git a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/IUpsertCommandRunner.cs b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/IUpsertCommandRunner.cs index cf2b586..e93dc6b 100644 --- a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/IUpsertCommandRunner.cs +++ b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/IUpsertCommandRunner.cs @@ -35,6 +35,21 @@ int Run(DbContext dbContext, IEntityType entityType, ICollection>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions) where TEntity : class; + /// + /// Run the upsert command for the entities passed and return new or updated entities + /// + /// Entity type of the entities + /// Data context to be used + /// Metadata for the entity + /// Array of entities to be upserted + /// Expression that represents which properties will be used as a match clause for the upsert command + /// Expression that represents which properties will be updated, and what values will be set + /// Expression that checks whether the database entry should be updated + /// Options for the current query that will affect it's behaviour + ICollection RunAndReturn(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, + Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions) + where TEntity : class; + /// /// Run the upsert command for the entities passed /// @@ -51,5 +66,20 @@ int Run(DbContext dbContext, IEntityType entityType, ICollection RunAsync(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions, CancellationToken cancellationToken) where TEntity : class; + + /// + /// Run the upsert command for the entities passed and return new or updated entities + /// + /// Entity type of the entities + /// Data context to be used + /// Metadata for the entity + /// Array of entities to be upserted + /// Expression that represents which properties will be used as a match clause for the upsert command + /// Expression that represents which properties will be updated, and what values will be set + /// Expression that checks whether the database entry should be updated + /// Options for the current query that will affect it's behaviour + Task> RunAndReturnAsync(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, + Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions) + where TEntity : class; } } diff --git a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/InMemoryUpsertCommandRunner.cs b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/InMemoryUpsertCommandRunner.cs index f72aeae..f2f7d3e 100644 --- a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/InMemoryUpsertCommandRunner.cs +++ b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/InMemoryUpsertCommandRunner.cs @@ -18,7 +18,7 @@ public class InMemoryUpsertCommandRunner : UpsertCommandRunnerBase /// public override bool Supports(string providerName) => providerName == "Microsoft.EntityFrameworkCore.InMemory"; - private static void RunCore(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, + private static IEnumerable RunCore(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions) where TEntity : class { // Find matching entities in the dbContext @@ -92,6 +92,8 @@ private static void RunCore(DbContext dbContext, IEntityType entityType updateAction?.Invoke(match.DbEntity, match.NewEntity); } + + return matches.Select(m => m.DbEntity ?? m.NewEntity); } private struct EntityMatch @@ -142,6 +144,20 @@ public override int Run(DbContext dbContext, IEntityType entityType, IC return dbContext.SaveChanges(); } + /// + public override ICollection RunAndReturn(DbContext dbContext, IEntityType entityType, ICollection entities, + Expression>? matchExpression, Expression>? updateExpression, Expression>? updateCondition, + RunnerQueryOptions queryOptions) + { + ArgumentNullException.ThrowIfNull(dbContext); + ArgumentNullException.ThrowIfNull(entityType); + + var result = RunCore(dbContext, entityType, entities, matchExpression, updateExpression, updateCondition, queryOptions); + dbContext.SaveChanges(); + + return result.ToArray(); + } + /// public override Task RunAsync(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions, @@ -153,5 +169,19 @@ public override Task RunAsync(DbContext dbContext, IEntityType ent RunCore(dbContext, entityType, entities, matchExpression, updateExpression, updateCondition, queryOptions); return dbContext.SaveChangesAsync(cancellationToken); } + + /// + public override async Task> RunAndReturnAsync(DbContext dbContext, IEntityType entityType, ICollection entities, + Expression>? matchExpression, Expression>? updateExpression, Expression>? updateCondition, + RunnerQueryOptions queryOptions) + { + ArgumentNullException.ThrowIfNull(dbContext); + ArgumentNullException.ThrowIfNull(entityType); + + var result = RunCore(dbContext, entityType, entities, matchExpression, updateExpression, updateCondition, queryOptions); + await dbContext.SaveChangesAsync().ConfigureAwait(false); + + return result.ToArray(); + } } } diff --git a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/MySqlUpsertCommandRunner.cs b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/MySqlUpsertCommandRunner.cs index 89e0d4b..8b666f9 100644 --- a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/MySqlUpsertCommandRunner.cs +++ b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/MySqlUpsertCommandRunner.cs @@ -27,10 +27,17 @@ public class MySqlUpsertCommandRunner : RelationalUpsertCommandRunner protected override int? MaxQueryParams => 65535; /// - public override string GenerateCommand(string tableName, ICollection> entities, - ICollection<(string ColumnName, bool IsNullable)> joinColumns, ICollection<(string ColumnName, IKnownValue Value)>? updateExpressions, - KnownExpression? updateCondition) + public override string GenerateCommand( + string tableName, + ICollection> entities, + ICollection<(string ColumnName, bool IsNullable)> joinColumns, + ICollection<(string ColumnName, IKnownValue Value)>? updateExpressions, + KnownExpression? updateCondition, + bool returnResult = false) { + if (returnResult) + throw new NotImplementedException("MySql runner does not support returning the result of the upsert operation yet"); + var result = new StringBuilder("INSERT "); if (updateExpressions == null) result.Append("IGNORE "); diff --git a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/OracleUpsertCommandRunner.cs b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/OracleUpsertCommandRunner.cs index 0391001..3c6f0c1 100644 --- a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/OracleUpsertCommandRunner.cs +++ b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/OracleUpsertCommandRunner.cs @@ -33,11 +33,15 @@ public override string GenerateCommand( ICollection> entities, ICollection<(string ColumnName, bool IsNullable)> joinColumns, ICollection<(string ColumnName, IKnownValue Value)>? updateExpressions, - KnownExpression? updateCondition) + KnownExpression? updateCondition, + bool returnResult = false) { ArgumentNullException.ThrowIfNull(entities); - var result = new StringBuilder(); + if (returnResult) + throw new NotImplementedException("Oracle runner does not support returning the result of the upsert operation yet"); + + var result = new StringBuilder(); result.Append(CultureInfo.InvariantCulture, $"MERGE INTO {tableName} t USING ("); foreach (var item in entities.Select((e, ind) => new {e, ind})) { diff --git a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/PostgreSqlUpsertCommandRunner.cs b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/PostgreSqlUpsertCommandRunner.cs index 712c4b2..9c5f2ad 100644 --- a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/PostgreSqlUpsertCommandRunner.cs +++ b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/PostgreSqlUpsertCommandRunner.cs @@ -23,9 +23,13 @@ public class PostgreSqlUpsertCommandRunner : RelationalUpsertCommandRunner protected override int? MaxQueryParams => 32767; /// - public override string GenerateCommand(string tableName, ICollection> entities, - ICollection<(string ColumnName, bool IsNullable)> joinColumns, ICollection<(string ColumnName, IKnownValue Value)>? updateExpressions, - KnownExpression? updateCondition) + public override string GenerateCommand( + string tableName, + ICollection> entities, + ICollection<(string ColumnName, bool IsNullable)> joinColumns, + ICollection<(string ColumnName, IKnownValue Value)>? updateExpressions, + KnownExpression? updateCondition, + bool returnResult = false) { var result = new StringBuilder(); result.Append(CultureInfo.InvariantCulture, $"INSERT INTO {tableName} AS \"T\" ("); @@ -46,6 +50,12 @@ public override string GenerateCommand(string tableName, ICollectionThe columns used to match existing items in the database /// The expressions that represent update commands for matched entities /// The expression that tests whether existing entities should be updated + /// If true, the generated command should return upserted entities /// A fully formed database query public abstract string GenerateCommand(string tableName, ICollection> entities, ICollection<(string ColumnName, bool IsNullable)> joinColumns, ICollection<(string ColumnName, IKnownValue Value)>? updateExpressions, - KnownExpression? updateCondition); + KnownExpression? updateCondition, bool returnResult = false); /// /// Escape the name of the table/column/schema in a given database language /// @@ -96,7 +97,7 @@ protected virtual string GetTableName(IEntityType entityType) private IEnumerable<(string SqlCommand, IEnumerable Arguments)> PrepareCommand(IEntityType entityType, ICollection entities, Expression>? match, Expression>? updater, Expression>? updateCondition, - RunnerQueryOptions queryOptions) + RunnerQueryOptions queryOptions, bool returnResult = false) { var joinColumns = ProcessMatchExpression(entityType, match, queryOptions); var joinColumnNames = joinColumns.Select(c => (ColumnName: c.GetColumnName(), c.IsColumnNullable())).ToArray(); @@ -202,7 +203,7 @@ protected virtual string GetTableName(IEntityType entityType) var columnUpdateExpressions = updateExpressions?.Count > 0 ? updateExpressions.Select(x => (x.Property.GetColumnName(), x.Value)).ToArray() : null; - var sqlCommand = GenerateCommand(GetTableName(entityType), newEntities.Skip(entitiesProcessed - entitiesHere).Take(entitiesHere).ToArray(), joinColumnNames, columnUpdateExpressions, updateConditionExpression); + var sqlCommand = GenerateCommand(GetTableName(entityType), newEntities.Skip(entitiesProcessed - entitiesHere).Take(entitiesHere).ToArray(), joinColumnNames, columnUpdateExpressions, updateConditionExpression, returnResult); yield return (sqlCommand, arguments); } } @@ -376,6 +377,26 @@ public override int Run(DbContext dbContext, IEntityType entityType, IC return result; } + /// + public override ICollection RunAndReturn(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, + Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions) + { + ArgumentNullException.ThrowIfNull(dbContext); + ArgumentNullException.ThrowIfNull(entityType); + + var relationalTypeMappingSource = dbContext.GetService(); + var commands = PrepareCommand(entityType, entities, matchExpression, updateExpression, updateCondition, queryOptions, true); + + var result = new List(); + foreach (var (sqlCommand, arguments) in commands) + { + using var dbCommand = dbContext.Database.GetDbConnection().CreateCommand(); + var dbArguments = arguments.Select(a => PrepareDbCommandArgument(dbCommand, relationalTypeMappingSource, a)).ToArray(); + result.AddRange(dbContext.Set().FromSqlRaw(sqlCommand, dbArguments).ToArray()); + } + return result; + } + /// public override async Task RunAsync(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions, @@ -397,6 +418,26 @@ public override async Task RunAsync(DbContext dbContext, IEntityTy return result; } + /// + public override async Task> RunAndReturnAsync(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, + Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions) + { + ArgumentNullException.ThrowIfNull(dbContext); + ArgumentNullException.ThrowIfNull(entityType); + + var relationalTypeMappingSource = dbContext.GetService(); + var commands = PrepareCommand(entityType, entities, matchExpression, updateExpression, updateCondition, queryOptions, true); + + var result = new List(); + foreach (var (sqlCommand, arguments) in commands) + { + using var dbCommand = dbContext.Database.GetDbConnection().CreateCommand(); + var dbArguments = arguments.Select(a => PrepareDbCommandArgument(dbCommand, relationalTypeMappingSource, a)).ToArray(); + result.AddRange(await dbContext.Set().FromSqlRaw(sqlCommand, dbArguments).ToArrayAsync().ConfigureAwait(false)); + } + return result; + } + private DbParameter PrepareDbCommandArgument(DbCommand dbCommand, IRelationalTypeMappingSource relationalTypeMappingSource, ConstantValue constantValue) { RelationalTypeMapping? relationalTypeMapping = null; diff --git a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/SqlServerUpsertCommandRunner.cs b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/SqlServerUpsertCommandRunner.cs index f525a71..d6700f5 100644 --- a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/SqlServerUpsertCommandRunner.cs +++ b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/SqlServerUpsertCommandRunner.cs @@ -23,9 +23,13 @@ public class SqlServerUpsertCommandRunner : RelationalUpsertCommandRunner protected override int? MaxQueryParams => 2090; /// - public override string GenerateCommand(string tableName, ICollection> entities, - ICollection<(string ColumnName, bool IsNullable)> joinColumns, ICollection<(string ColumnName, IKnownValue Value)>? updateExpressions, - KnownExpression? updateCondition) + public override string GenerateCommand( + string tableName, + ICollection> entities, + ICollection<(string ColumnName, bool IsNullable)> joinColumns, + ICollection<(string ColumnName, IKnownValue Value)>? updateExpressions, + KnownExpression? updateCondition, + bool returnResult = false) { var result = new StringBuilder(); result.Append(CultureInfo.InvariantCulture, $"MERGE INTO {tableName} WITH (HOLDLOCK) AS [T] USING ( VALUES ("); @@ -49,6 +53,10 @@ public override string GenerateCommand(string tableName, ICollection $"{EscapeName(e.ColumnName)} = {ExpandValue(e.Value)}"))); } + if (returnResult) + { + result.Append(" OUTPUT inserted.*"); + } result.Append(';'); return result.ToString(); } diff --git a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/UpsertCommandRunnerBase.cs b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/UpsertCommandRunnerBase.cs index c0c8c45..3a27473 100644 --- a/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/UpsertCommandRunnerBase.cs +++ b/src/FlexLabs.EntityFrameworkCore.Upsert/Runners/UpsertCommandRunnerBase.cs @@ -23,11 +23,21 @@ public abstract int Run(DbContext dbContext, IEntityType entityType, IC Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions) where TEntity : class; + /// + public abstract ICollection RunAndReturn(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, + Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions) + where TEntity : class; + /// public abstract Task RunAsync(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions, CancellationToken cancellationToken) where TEntity : class; + /// + public abstract Task> RunAndReturnAsync(DbContext dbContext, IEntityType entityType, ICollection entities, Expression>? matchExpression, + Expression>? updateExpression, Expression>? updateCondition, RunnerQueryOptions queryOptions) + where TEntity : class; + /// /// Extract property metadata from the match expression /// diff --git a/src/FlexLabs.EntityFrameworkCore.Upsert/UpsertCommandBuilder.cs b/src/FlexLabs.EntityFrameworkCore.Upsert/UpsertCommandBuilder.cs index 8fde16a..7abbce7 100644 --- a/src/FlexLabs.EntityFrameworkCore.Upsert/UpsertCommandBuilder.cs +++ b/src/FlexLabs.EntityFrameworkCore.Upsert/UpsertCommandBuilder.cs @@ -187,6 +187,18 @@ public int Run() return commandRunner.Run(_dbContext, _entityType, _entities, _matchExpression, _updateExpression, _updateCondition, _queryOptions); } + /// + /// Execute the upsert command against the database and return new or updated entities + /// + public ICollection RunAndReturn() + { + if (_entities.Count == 0) + return []; + + var commandRunner = GetCommandRunner(); + return commandRunner.RunAndReturn(_dbContext, _entityType, _entities, _matchExpression, _updateExpression, _updateCondition, _queryOptions); + } + /// /// Execute the upsert command against the database asynchronously /// @@ -200,5 +212,17 @@ public Task RunAsync(CancellationToken token = default) var commandRunner = GetCommandRunner(); return commandRunner.RunAsync(_dbContext, _entityType, _entities, _matchExpression, _updateExpression, _updateCondition, _queryOptions, token); } + + /// + /// Execute the upsert command against the database and return new or updated entities + /// + public Task> RunAndReturnAsync() + { + if (_entities.Count == 0) + return Task.FromResult>([]); + + var commandRunner = GetCommandRunner(); + return commandRunner.RunAndReturnAsync(_dbContext, _entityType, _entities, _matchExpression, _updateExpression, _updateCondition, _queryOptions); + } } } diff --git a/test/FlexLabs.EntityFrameworkCore.Upsert.IntegrationTests/DbTestsBase.cs b/test/FlexLabs.EntityFrameworkCore.Upsert.IntegrationTests/DbTestsBase.cs index 12eb409..c48f2a6 100644 --- a/test/FlexLabs.EntityFrameworkCore.Upsert.IntegrationTests/DbTestsBase.cs +++ b/test/FlexLabs.EntityFrameworkCore.Upsert.IntegrationTests/DbTestsBase.cs @@ -224,6 +224,91 @@ public void Upsert_IdentityKey_NoOn_AllowWithOverride() dbContext.Countries.Should().HaveCount(2); } + [Fact] + public void Upsert_ReturnResult_Single() + { + if (_fixture.DbDriver == DbDriver.MySQL || _fixture.DbDriver == DbDriver.Oracle) + return; // Returning records is not implemented in MySQL and Oracle runners + + ResetDb(); + using var dbContext = new TestDbContext(_fixture.DataContextOptions); + + var dashTable = new DashTable + { + DataSet = "Test", + Updated = _now, + }; + + var result = dbContext.DashTable.Upsert(dashTable) + .On(c => c.DataSet) + .RunAndReturn(); + + result.Should().ContainEquivalentOf(new DashTable + { + ID = result.First().ID, + DataSet = "Test", + Updated = _now, + }); + } + + [Fact] + public void Upsert_ReturnResult_Multiple() + { + if (_fixture.DbDriver == DbDriver.MySQL || _fixture.DbDriver == DbDriver.Oracle) + return; // Returning records is not implemented in MySQL and Oracle runners + + ResetDb(new DashTable { DataSet = "Test1" }); + using var dbContext = new TestDbContext(_fixture.DataContextOptions); + + var dashTables = new[] + { + new DashTable + { + DataSet = "Test1", + Updated = _now, + }, + new DashTable + { + DataSet = "Test2", + Updated = _now, + } + }; + + var result = dbContext.DashTable.UpsertRange(dashTables) + .On(c => c.DataSet) + .RunAndReturn(); + + result.Should().HaveCount(2); + + dbContext.DashTable.Should().HaveCount(2); + } + + [Fact] + public void Upsert_ReturnResult_TracksChanges() + { + if (_fixture.DbDriver == DbDriver.MySQL || _fixture.DbDriver == DbDriver.Oracle) + return; // Returning records is not implemented in MySQL and Oracle runners + + ResetDb(new DashTable { DataSet = "Test" }); + using var dbContext = new TestDbContext(_fixture.DataContextOptions); + + var dashTable = new DashTable + { + DataSet = "Test", + Updated = _now, + }; + + var result = dbContext.DashTable.Upsert(dashTable) + .On(c => c.DataSet) + .RunAndReturn(); + + result.Single().Updated = _now.AddYears(1); + dbContext.SaveChanges(); + + var updatedResult = dbContext.DashTable.Single(); + updatedResult.Updated.Should().Be(_now.AddYears(1)); + } + [Fact] public void Upsert_IdentityKey_ExplicitOn_AllowWithOverride() { diff --git a/test/FlexLabs.EntityFrameworkCore.Upsert.Tests/ReplaceRunnerTests.cs b/test/FlexLabs.EntityFrameworkCore.Upsert.Tests/ReplaceRunnerTests.cs index edb807d..56ca1e5 100644 --- a/test/FlexLabs.EntityFrameworkCore.Upsert.Tests/ReplaceRunnerTests.cs +++ b/test/FlexLabs.EntityFrameworkCore.Upsert.Tests/ReplaceRunnerTests.cs @@ -19,8 +19,13 @@ public class CustomSqliteCommandRunner : RelationalUpsertCommandRunner public override bool Supports(string name) => name == "Microsoft.EntityFrameworkCore.Sqlite"; public static int GenerateCalled; - public override string GenerateCommand(string tableName, ICollection> entities, - ICollection<(string ColumnName, bool IsNullable)> joinColumns, ICollection<(string ColumnName, IKnownValue Value)> updateExpressions, KnownExpression updateCondition) + public override string GenerateCommand( + string tableName, + ICollection> entities, + ICollection<(string ColumnName, bool IsNullable)> joinColumns, + ICollection<(string ColumnName, IKnownValue Value)> updateExpressions, + KnownExpression updateCondition, + bool returnResult = false) { GenerateCalled++; return "sql";