diff --git a/src/persistence/Elsa.Persistence.EFCore.Common/Extensions/BulkUpsertExtensions.cs b/src/persistence/Elsa.Persistence.EFCore.Common/Extensions/BulkUpsertExtensions.cs index 0238d4ed..00ba7052 100644 --- a/src/persistence/Elsa.Persistence.EFCore.Common/Extensions/BulkUpsertExtensions.cs +++ b/src/persistence/Elsa.Persistence.EFCore.Common/Extensions/BulkUpsertExtensions.cs @@ -2,6 +2,8 @@ using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.ChangeTracking; +using Microsoft.Extensions.DependencyInjection; using System.Linq.Expressions; // ReSharper disable once CheckNamespace @@ -55,6 +57,42 @@ public static async Task BulkUpsertAsync( { if (entities.Count == 0) return; + + // Call entity saving handlers if this is an ElsaDbContextBase and handlers are registered + if (dbContext is ElsaDbContextBase elsaDbContext) + { + var serviceProvider = elsaDbContext.GetService(); + var entitySavingHandlers = serviceProvider.GetServices().ToList(); + + if (entitySavingHandlers.Count > 0) + { + // Get key property info once since it's the same for all entities + var keyPropertyInfo = keySelector.GetMemberAccess() as System.Reflection.PropertyInfo; + + // Process each entity through the handlers + foreach (var entity in entities) + { + // Attach entity temporarily to create EntityEntry + var entry = dbContext.Entry(entity); + + // Determine proper EntityState based on whether it exists + if (keyPropertyInfo != null) + { + var keyValue = keyPropertyInfo.GetValue(entity) as string; + entry.State = !string.IsNullOrEmpty(keyValue) ? EntityState.Modified : EntityState.Added; + } + + // Call each handler + foreach (var handler in entitySavingHandlers) + { + await handler.HandleAsync(elsaDbContext, entry, cancellationToken); + } + + // Detach the entity to avoid duplicate tracking + entry.State = EntityState.Detached; + } + } + } // Identify the current provider (e.g., "Microsoft.EntityFrameworkCore.SqlServer") var providerName = dbContext.Database.ProviderName?.ToLowerInvariant() ?? string.Empty;