Skip to content

Commit

Permalink
Merge branch 'feature/return-result' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
artiomchi committed Nov 24, 2024
2 parents 3445887 + eae18cb commit 924bc1c
Show file tree
Hide file tree
Showing 12 changed files with 272 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Also supports injecting sql command generators to add support for other provider
<PackageReleaseNotes>
v8.1.0
+ Adding initial support for Oracle DB! (Thanks to @dadyarri)
+ Adding test support for returning inserted objects (Thanks to @PhenX)
+ Adding support for upserting into views (ymmv)
! Patching argument count calculation (for max argument count handling)
! Patching null constant handling in the update condition
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,21 @@ int Run<TEntity>(DbContext dbContext, IEntityType entityType, ICollection<TEntit
Expression<Func<TEntity, TEntity, TEntity>>? updateExpression, Expression<Func<TEntity, TEntity, bool>>? updateCondition, RunnerQueryOptions queryOptions)
where TEntity : class;

/// <summary>
/// Run the upsert command for the entities passed and return new or updated entities
/// </summary>
/// <typeparam name="TEntity">Entity type of the entities</typeparam>
/// <param name="dbContext">Data context to be used</param>
/// <param name="entityType">Metadata for the entity</param>
/// <param name="entities">Array of entities to be upserted</param>
/// <param name="matchExpression">Expression that represents which properties will be used as a match clause for the upsert command</param>
/// <param name="updateExpression">Expression that represents which properties will be updated, and what values will be set</param>
/// <param name="updateCondition">Expression that checks whether the database entry should be updated</param>
/// <param name="queryOptions">Options for the current query that will affect it's behaviour</param>
ICollection<TEntity> RunAndReturn<TEntity>(DbContext dbContext, IEntityType entityType, ICollection<TEntity> entities, Expression<Func<TEntity, object>>? matchExpression,
Expression<Func<TEntity, TEntity, TEntity>>? updateExpression, Expression<Func<TEntity, TEntity, bool>>? updateCondition, RunnerQueryOptions queryOptions)
where TEntity : class;

/// <summary>
/// Run the upsert command for the entities passed
/// </summary>
Expand All @@ -51,5 +66,20 @@ int Run<TEntity>(DbContext dbContext, IEntityType entityType, ICollection<TEntit
Task<int> RunAsync<TEntity>(DbContext dbContext, IEntityType entityType, ICollection<TEntity> entities, Expression<Func<TEntity, object>>? matchExpression,
Expression<Func<TEntity, TEntity, TEntity>>? updateExpression, Expression<Func<TEntity, TEntity, bool>>? updateCondition, RunnerQueryOptions queryOptions,
CancellationToken cancellationToken) where TEntity : class;

/// <summary>
/// Run the upsert command for the entities passed and return new or updated entities
/// </summary>
/// <typeparam name="TEntity">Entity type of the entities</typeparam>
/// <param name="dbContext">Data context to be used</param>
/// <param name="entityType">Metadata for the entity</param>
/// <param name="entities">Array of entities to be upserted</param>
/// <param name="matchExpression">Expression that represents which properties will be used as a match clause for the upsert command</param>
/// <param name="updateExpression">Expression that represents which properties will be updated, and what values will be set</param>
/// <param name="updateCondition">Expression that checks whether the database entry should be updated</param>
/// <param name="queryOptions">Options for the current query that will affect it's behaviour</param>
Task<ICollection<TEntity>> RunAndReturnAsync<TEntity>(DbContext dbContext, IEntityType entityType, ICollection<TEntity> entities, Expression<Func<TEntity, object>>? matchExpression,
Expression<Func<TEntity, TEntity, TEntity>>? updateExpression, Expression<Func<TEntity, TEntity, bool>>? updateCondition, RunnerQueryOptions queryOptions)
where TEntity : class;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public class InMemoryUpsertCommandRunner : UpsertCommandRunnerBase
/// <inheritdoc/>
public override bool Supports(string providerName) => providerName == "Microsoft.EntityFrameworkCore.InMemory";

private static void RunCore<TEntity>(DbContext dbContext, IEntityType entityType, ICollection<TEntity> entities, Expression<Func<TEntity, object>>? matchExpression,
private static IEnumerable<TEntity> RunCore<TEntity>(DbContext dbContext, IEntityType entityType, ICollection<TEntity> entities, Expression<Func<TEntity, object>>? matchExpression,
Expression<Func<TEntity, TEntity, TEntity>>? updateExpression, Expression<Func<TEntity, TEntity, bool>>? updateCondition, RunnerQueryOptions queryOptions) where TEntity : class
{
// Find matching entities in the dbContext
Expand Down Expand Up @@ -92,6 +92,8 @@ private static void RunCore<TEntity>(DbContext dbContext, IEntityType entityType

updateAction?.Invoke(match.DbEntity, match.NewEntity);
}

return matches.Select(m => m.DbEntity ?? m.NewEntity);
}

private struct EntityMatch<TEntity>
Expand Down Expand Up @@ -142,6 +144,20 @@ public override int Run<TEntity>(DbContext dbContext, IEntityType entityType, IC
return dbContext.SaveChanges();
}

/// <inheritdoc/>
public override ICollection<TEntity> RunAndReturn<TEntity>(DbContext dbContext, IEntityType entityType, ICollection<TEntity> entities,
Expression<Func<TEntity, object>>? matchExpression, Expression<Func<TEntity, TEntity, TEntity>>? updateExpression, Expression<Func<TEntity, TEntity, bool>>? updateCondition,
RunnerQueryOptions queryOptions)
{
ArgumentNullException.ThrowIfNull(dbContext);
ArgumentNullException.ThrowIfNull(entityType);

var result = RunCore(dbContext, entityType, entities, matchExpression, updateExpression, updateCondition, queryOptions);
dbContext.SaveChanges();

return result.ToArray();
}

/// <inheritdoc/>
public override Task<int> RunAsync<TEntity>(DbContext dbContext, IEntityType entityType, ICollection<TEntity> entities, Expression<Func<TEntity, object>>? matchExpression,
Expression<Func<TEntity, TEntity, TEntity>>? updateExpression, Expression<Func<TEntity, TEntity, bool>>? updateCondition, RunnerQueryOptions queryOptions,
Expand All @@ -153,5 +169,19 @@ public override Task<int> RunAsync<TEntity>(DbContext dbContext, IEntityType ent
RunCore(dbContext, entityType, entities, matchExpression, updateExpression, updateCondition, queryOptions);
return dbContext.SaveChangesAsync(cancellationToken);
}

/// <inheritdoc/>
public override async Task<ICollection<TEntity>> RunAndReturnAsync<TEntity>(DbContext dbContext, IEntityType entityType, ICollection<TEntity> entities,
Expression<Func<TEntity, object>>? matchExpression, Expression<Func<TEntity, TEntity, TEntity>>? updateExpression, Expression<Func<TEntity, TEntity, bool>>? updateCondition,
RunnerQueryOptions queryOptions)
{
ArgumentNullException.ThrowIfNull(dbContext);
ArgumentNullException.ThrowIfNull(entityType);

var result = RunCore(dbContext, entityType, entities, matchExpression, updateExpression, updateCondition, queryOptions);
await dbContext.SaveChangesAsync().ConfigureAwait(false);

return result.ToArray();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,17 @@ public class MySqlUpsertCommandRunner : RelationalUpsertCommandRunner
protected override int? MaxQueryParams => 65535;

/// <inheritdoc/>
public override string GenerateCommand(string tableName, ICollection<ICollection<(string ColumnName, ConstantValue Value, string DefaultSql, bool AllowInserts)>> entities,
ICollection<(string ColumnName, bool IsNullable)> joinColumns, ICollection<(string ColumnName, IKnownValue Value)>? updateExpressions,
KnownExpression? updateCondition)
public override string GenerateCommand(
string tableName,
ICollection<ICollection<(string ColumnName, ConstantValue Value, string DefaultSql, bool AllowInserts)>> entities,
ICollection<(string ColumnName, bool IsNullable)> joinColumns,
ICollection<(string ColumnName, IKnownValue Value)>? updateExpressions,
KnownExpression? updateCondition,
bool returnResult = false)
{
if (returnResult)
throw new NotImplementedException("MySql runner does not support returning the result of the upsert operation yet");

var result = new StringBuilder("INSERT ");
if (updateExpressions == null)
result.Append("IGNORE ");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,15 @@ public override string GenerateCommand(
ICollection<ICollection<(string ColumnName, ConstantValue Value, string DefaultSql, bool AllowInserts)>> entities,
ICollection<(string ColumnName, bool IsNullable)> joinColumns,
ICollection<(string ColumnName, IKnownValue Value)>? updateExpressions,
KnownExpression? updateCondition)
KnownExpression? updateCondition,
bool returnResult = false)
{
ArgumentNullException.ThrowIfNull(entities);
var result = new StringBuilder();

if (returnResult)
throw new NotImplementedException("Oracle runner does not support returning the result of the upsert operation yet");

var result = new StringBuilder();
result.Append(CultureInfo.InvariantCulture, $"MERGE INTO {tableName} t USING (");
foreach (var item in entities.Select((e, ind) => new {e, ind}))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,13 @@ public class PostgreSqlUpsertCommandRunner : RelationalUpsertCommandRunner
protected override int? MaxQueryParams => 32767;

/// <inheritdoc/>
public override string GenerateCommand(string tableName, ICollection<ICollection<(string ColumnName, ConstantValue Value, string DefaultSql, bool AllowInserts)>> entities,
ICollection<(string ColumnName, bool IsNullable)> joinColumns, ICollection<(string ColumnName, IKnownValue Value)>? updateExpressions,
KnownExpression? updateCondition)
public override string GenerateCommand(
string tableName,
ICollection<ICollection<(string ColumnName, ConstantValue Value, string DefaultSql, bool AllowInserts)>> entities,
ICollection<(string ColumnName, bool IsNullable)> joinColumns,
ICollection<(string ColumnName, IKnownValue Value)>? updateExpressions,
KnownExpression? updateCondition,
bool returnResult = false)
{
var result = new StringBuilder();
result.Append(CultureInfo.InvariantCulture, $"INSERT INTO {tableName} AS \"T\" (");
Expand All @@ -46,6 +50,12 @@ public override string GenerateCommand(string tableName, ICollection<ICollection
{
result.Append("NOTHING");
}

if (returnResult)
{
result.Append(" RETURNING *");
}

return result.ToString();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ public abstract class RelationalUpsertCommandRunner : UpsertCommandRunnerBase
/// <param name="joinColumns">The columns used to match existing items in the database</param>
/// <param name="updateExpressions">The expressions that represent update commands for matched entities</param>
/// <param name="updateCondition">The expression that tests whether existing entities should be updated</param>
/// <param name="returnResult">If true, the generated command should return upserted entities</param>
/// <returns>A fully formed database query</returns>
public abstract string GenerateCommand(string tableName, ICollection<ICollection<(string ColumnName, ConstantValue Value, string DefaultSql, bool AllowInserts)>> entities,
ICollection<(string ColumnName, bool IsNullable)> joinColumns, ICollection<(string ColumnName, IKnownValue Value)>? updateExpressions,
KnownExpression? updateCondition);
KnownExpression? updateCondition, bool returnResult = false);
/// <summary>
/// Escape the name of the table/column/schema in a given database language
/// </summary>
Expand Down Expand Up @@ -96,7 +97,7 @@ protected virtual string GetTableName(IEntityType entityType)

private IEnumerable<(string SqlCommand, IEnumerable<ConstantValue> Arguments)> PrepareCommand<TEntity>(IEntityType entityType, ICollection<TEntity> entities,
Expression<Func<TEntity, object>>? match, Expression<Func<TEntity, TEntity, TEntity>>? updater, Expression<Func<TEntity, TEntity, bool>>? updateCondition,
RunnerQueryOptions queryOptions)
RunnerQueryOptions queryOptions, bool returnResult = false)
{
var joinColumns = ProcessMatchExpression(entityType, match, queryOptions);
var joinColumnNames = joinColumns.Select(c => (ColumnName: c.GetColumnName(), c.IsColumnNullable())).ToArray();
Expand Down Expand Up @@ -202,7 +203,7 @@ protected virtual string GetTableName(IEntityType entityType)
var columnUpdateExpressions = updateExpressions?.Count > 0
? updateExpressions.Select(x => (x.Property.GetColumnName(), x.Value)).ToArray()
: null;
var sqlCommand = GenerateCommand(GetTableName(entityType), newEntities.Skip(entitiesProcessed - entitiesHere).Take(entitiesHere).ToArray(), joinColumnNames, columnUpdateExpressions, updateConditionExpression);
var sqlCommand = GenerateCommand(GetTableName(entityType), newEntities.Skip(entitiesProcessed - entitiesHere).Take(entitiesHere).ToArray(), joinColumnNames, columnUpdateExpressions, updateConditionExpression, returnResult);
yield return (sqlCommand, arguments);
}
}
Expand Down Expand Up @@ -376,6 +377,26 @@ public override int Run<TEntity>(DbContext dbContext, IEntityType entityType, IC
return result;
}

/// <inheritdoc/>
public override ICollection<TEntity> RunAndReturn<TEntity>(DbContext dbContext, IEntityType entityType, ICollection<TEntity> entities, Expression<Func<TEntity, object>>? matchExpression,
Expression<Func<TEntity, TEntity, TEntity>>? updateExpression, Expression<Func<TEntity, TEntity, bool>>? updateCondition, RunnerQueryOptions queryOptions)
{
ArgumentNullException.ThrowIfNull(dbContext);
ArgumentNullException.ThrowIfNull(entityType);

var relationalTypeMappingSource = dbContext.GetService<IRelationalTypeMappingSource>();
var commands = PrepareCommand(entityType, entities, matchExpression, updateExpression, updateCondition, queryOptions, true);

var result = new List<TEntity>();
foreach (var (sqlCommand, arguments) in commands)
{
using var dbCommand = dbContext.Database.GetDbConnection().CreateCommand();
var dbArguments = arguments.Select(a => PrepareDbCommandArgument(dbCommand, relationalTypeMappingSource, a)).ToArray();
result.AddRange(dbContext.Set<TEntity>().FromSqlRaw(sqlCommand, dbArguments).ToArray());
}
return result;
}

/// <inheritdoc/>
public override async Task<int> RunAsync<TEntity>(DbContext dbContext, IEntityType entityType, ICollection<TEntity> entities, Expression<Func<TEntity, object>>? matchExpression,
Expression<Func<TEntity, TEntity, TEntity>>? updateExpression, Expression<Func<TEntity, TEntity, bool>>? updateCondition, RunnerQueryOptions queryOptions,
Expand All @@ -397,6 +418,26 @@ public override async Task<int> RunAsync<TEntity>(DbContext dbContext, IEntityTy
return result;
}

/// <inheritdoc/>
public override async Task<ICollection<TEntity>> RunAndReturnAsync<TEntity>(DbContext dbContext, IEntityType entityType, ICollection<TEntity> entities, Expression<Func<TEntity, object>>? matchExpression,
Expression<Func<TEntity, TEntity, TEntity>>? updateExpression, Expression<Func<TEntity, TEntity, bool>>? updateCondition, RunnerQueryOptions queryOptions)
{
ArgumentNullException.ThrowIfNull(dbContext);
ArgumentNullException.ThrowIfNull(entityType);

var relationalTypeMappingSource = dbContext.GetService<IRelationalTypeMappingSource>();
var commands = PrepareCommand(entityType, entities, matchExpression, updateExpression, updateCondition, queryOptions, true);

var result = new List<TEntity>();
foreach (var (sqlCommand, arguments) in commands)
{
using var dbCommand = dbContext.Database.GetDbConnection().CreateCommand();
var dbArguments = arguments.Select(a => PrepareDbCommandArgument(dbCommand, relationalTypeMappingSource, a)).ToArray();
result.AddRange(await dbContext.Set<TEntity>().FromSqlRaw(sqlCommand, dbArguments).ToArrayAsync().ConfigureAwait(false));
}
return result;
}

private DbParameter PrepareDbCommandArgument(DbCommand dbCommand, IRelationalTypeMappingSource relationalTypeMappingSource, ConstantValue constantValue)
{
RelationalTypeMapping? relationalTypeMapping = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,13 @@ public class SqlServerUpsertCommandRunner : RelationalUpsertCommandRunner
protected override int? MaxQueryParams => 2090;

/// <inheritdoc/>
public override string GenerateCommand(string tableName, ICollection<ICollection<(string ColumnName, ConstantValue Value, string DefaultSql, bool AllowInserts)>> entities,
ICollection<(string ColumnName, bool IsNullable)> joinColumns, ICollection<(string ColumnName, IKnownValue Value)>? updateExpressions,
KnownExpression? updateCondition)
public override string GenerateCommand(
string tableName,
ICollection<ICollection<(string ColumnName, ConstantValue Value, string DefaultSql, bool AllowInserts)>> entities,
ICollection<(string ColumnName, bool IsNullable)> joinColumns,
ICollection<(string ColumnName, IKnownValue Value)>? updateExpressions,
KnownExpression? updateCondition,
bool returnResult = false)
{
var result = new StringBuilder();
result.Append(CultureInfo.InvariantCulture, $"MERGE INTO {tableName} WITH (HOLDLOCK) AS [T] USING ( VALUES (");
Expand All @@ -49,6 +53,10 @@ public override string GenerateCommand(string tableName, ICollection<ICollection
result.Append(" THEN UPDATE SET ");
result.Append(string.Join(", ", updateExpressions.Select((e, i) => $"{EscapeName(e.ColumnName)} = {ExpandValue(e.Value)}")));
}
if (returnResult)
{
result.Append(" OUTPUT inserted.*");
}
result.Append(';');
return result.ToString();
}
Expand Down
Loading

0 comments on commit 924bc1c

Please sign in to comment.