Skip to content

Commit

Permalink
add [BatchSize] and pass thru to multi-row execute (#76)
Browse files Browse the repository at this point in the history
  • Loading branch information
mgravell authored Nov 15, 2023
1 parent b8d5dd0 commit 66e1e2b
Show file tree
Hide file tree
Showing 10 changed files with 365 additions and 7 deletions.
11 changes: 9 additions & 2 deletions src/Dapper.AOT.Analyzers/CodeAnalysis/DapperAnalyzer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,7 @@ enum ParameterMode
}
}

int? batchSize = null;
foreach (var attrib in methodAttribs)
{
if (IsDapperAttribute(attrib))
Expand Down Expand Up @@ -778,6 +779,12 @@ enum ParameterMode
case Types.CommandPropertyAttribute:
cmdPropsCount++;
break;
case Types.BatchSizeAttribute:
if (attrib.ConstructorArguments.Length == 1 && attrib.ConstructorArguments[0].Value is int batchTmp)
{
batchSize = batchTmp;
}
break;
}
}
}
Expand Down Expand Up @@ -806,8 +813,8 @@ enum ParameterMode
}


return cmdProps.IsDefaultOrEmpty && rowCountHint <= 0 && rowCountHintMember is null
? null : new(rowCountHint, rowCountHintMember?.Member.Name, cmdProps);
return cmdProps.IsDefaultOrEmpty && rowCountHint <= 0 && rowCountHintMember is null && batchSize is null
? null : new(rowCountHint, rowCountHintMember?.Member.Name, batchSize, cmdProps);
}

internal static ImmutableArray<ElementMember>? SharedGetParametersToInclude(MemberMap? map, ref OperationFlags flags, string? sql, Action<Diagnostic>? reportDiagnostic, out SqlParseOutputFlags parseFlags)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ void WriteMultiExecExpression(ITypeSymbol elementType, string castType)
bool isAsync = flags.HasAny(OperationFlags.Async);
sb.Append("Execute").Append(isAsync ? "Async" : "").Append("(");
sb.Append("(").Append(castType).Append(")param!");
if (additionalCommandState?.BatchSize is { } batchSize)
{
sb.Append(", batchSize: ").Append(batchSize);
}
if (isAsync && HasParam(methodParameters, "cancellationToken"))
{
sb.Append(", cancellationToken: ").Append(Forward(methodParameters, "cancellationToken"));
Expand Down
16 changes: 12 additions & 4 deletions src/Dapper.AOT.Analyzers/Internal/AdditionalCommandState.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public bool Equals(in CommandProperty other)
internal sealed class AdditionalCommandState : IEquatable<AdditionalCommandState>
{
public readonly int RowCountHint;
public readonly int? BatchSize;
public readonly string? RowCountHintMemberName;
public readonly ImmutableArray<CommandProperty> CommandProperties;

Expand Down Expand Up @@ -72,7 +73,8 @@ private static AdditionalCommandState Combine(AdditionalCommandState inherited,
countMember = null;
}

return new(count, countMember, Concat(inherited.CommandProperties, overrides.CommandProperties));
return new(count, countMember, inherited.BatchSize ?? overrides.BatchSize,
Concat(inherited.CommandProperties, overrides.CommandProperties));
}

static ImmutableArray<CommandProperty> Concat(ImmutableArray<CommandProperty> x, ImmutableArray<CommandProperty> y)
Expand All @@ -85,10 +87,13 @@ static ImmutableArray<CommandProperty> Concat(ImmutableArray<CommandProperty> x,
return builder.ToImmutable();
}

internal AdditionalCommandState(int rowCountHint, string? rowCountHintMemberName, ImmutableArray<CommandProperty> commandProperties)
internal AdditionalCommandState(
int rowCountHint, string? rowCountHintMemberName, int? batchSize,
ImmutableArray<CommandProperty> commandProperties)
{
RowCountHint = rowCountHint;
RowCountHintMemberName = rowCountHintMemberName;
BatchSize = batchSize;
CommandProperties = commandProperties;
}

Expand All @@ -98,7 +103,9 @@ internal AdditionalCommandState(int rowCountHint, string? rowCountHintMemberName
bool IEquatable<AdditionalCommandState>.Equals(AdditionalCommandState other) => Equals(in other);

public bool Equals(in AdditionalCommandState other)
=> RowCountHint == other.RowCountHint && RowCountHintMemberName == other.RowCountHintMemberName
=> RowCountHint == other.RowCountHint
&& BatchSize == other.BatchSize
&& RowCountHintMemberName == other.RowCountHintMemberName
&& ((CommandProperties.IsDefaultOrEmpty && other.CommandProperties.IsDefaultOrEmpty) || Equals(CommandProperties, other.CommandProperties));

private static bool Equals(in ImmutableArray<CommandProperty> x, in ImmutableArray<CommandProperty> y)
Expand Down Expand Up @@ -136,6 +143,7 @@ static int GetHashCode(in ImmutableArray<CommandProperty> x)
}

public override int GetHashCode()
=> (RowCountHint + (RowCountHintMemberName is null ? 0 : RowCountHintMemberName.GetHashCode()))
=> (RowCountHint + BatchSize.GetValueOrDefault()
+ (RowCountHintMemberName is null ? 0 : RowCountHintMemberName.GetHashCode()))
^ (CommandProperties.IsDefaultOrEmpty ? 0 : GetHashCode(in CommandProperties));
}
3 changes: 2 additions & 1 deletion src/Dapper.AOT.Analyzers/Internal/Types.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ public const string
IDynamicParameters = nameof(IDynamicParameters),
SqlMapper = nameof(SqlMapper),
SqlAttribute = nameof(SqlAttribute),
ExplicitConstructorAttribute = nameof(ExplicitConstructorAttribute);
ExplicitConstructorAttribute = nameof(ExplicitConstructorAttribute),
BatchSizeAttribute = nameof(BatchSizeAttribute);
}
19 changes: 19 additions & 0 deletions src/Dapper.AOT/BatchSizeAttribute.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using System;
using System.ComponentModel;
using System.Diagnostics;

namespace Dapper;

/// <summary>
/// Indicates the batch size to use when executing commands with a sequence of argument rows.
/// </summary>
[Conditional("DEBUG")] // not needed post-build, so: evaporate
[ImmutableObject(true)]
[AttributeUsage(AttributeTargets.Assembly | AttributeTargets.Module | AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Method, AllowMultiple = false)]
public sealed class BatchSizeAttribute : Attribute
{
/// <summary>
/// Indicates the batch size to use when executing commands with a sequence of argument row; a value of zero disables batch usage; a negative value uses a single batch for all rows.
/// </summary>
public BatchSizeAttribute(int batchSize) => _ = batchSize;
}
23 changes: 23 additions & 0 deletions test/Dapper.AOT.Test/Interceptors/BatchSize.input.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using Dapper;
using System.Data.Common;

[module: DapperAot]

public static class Foo
{
[BatchSize(10)] // should be passed explicitly
static void SomeCode(DbConnection connection, string sql, string bar)
{
var objs = new[] { new { id = 12, bar }, new { id = 34, bar = "def" } };

connection.Execute("insert Foo (Id, Value) values (@id, @bar)", objs);
}

// no batch size, should be passed implicitly
static void SomeOtherCode(DbConnection connection, string sql, string bar)
{
var objs = new[] { new { id = 12, bar }, new { id = 34, bar = "def" } };

connection.Execute("insert Foo (Id, Value) values (@id, @bar)", objs);
}
}
144 changes: 144 additions & 0 deletions test/Dapper.AOT.Test/Interceptors/BatchSize.output.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
#nullable enable
namespace Dapper.AOT // interceptors must be in a known namespace
{
file static class DapperGeneratedInterceptors
{
[global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\BatchSize.input.cs", 13, 20)]
internal static int Execute0(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, int? commandTimeout, global::System.Data.CommandType? commandType)
{
// Execute, HasParameters, Text, KnownParameters
// takes parameter: global::<anonymous type: int id, string bar>[]
// parameter map: bar id
global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql));
global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.Text);
global::System.Diagnostics.Debug.Assert(param is not null);

return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.Text, commandTimeout.GetValueOrDefault(), CommandFactory0.Instance).Execute((object?[])param!, batchSize: 10);

}

[global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\BatchSize.input.cs", 21, 20)]
internal static int Execute1(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, int? commandTimeout, global::System.Data.CommandType? commandType)
{
// Execute, HasParameters, Text, KnownParameters
// takes parameter: global::<anonymous type: int id, string bar>[]
// parameter map: bar id
global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql));
global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.Text);
global::System.Diagnostics.Debug.Assert(param is not null);

return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.Text, commandTimeout.GetValueOrDefault(), CommandFactory1.Instance).Execute((object?[])param!);

}

private class CommonCommandFactory<T> : global::Dapper.CommandFactory<T>
{
public override global::System.Data.Common.DbCommand GetCommand(global::System.Data.Common.DbConnection connection, string sql, global::System.Data.CommandType commandType, T args)
{
var cmd = base.GetCommand(connection, sql, commandType, args);
// apply special per-provider command initialization logic for OracleCommand
if (cmd is global::Oracle.ManagedDataAccess.Client.OracleCommand cmd0)
{
cmd0.BindByName = true;
cmd0.InitialLONGFetchSize = -1;

}
return cmd;
}

}

private static readonly CommonCommandFactory<object?> DefaultCommandFactory = new();

private sealed class CommandFactory0 : CommonCommandFactory<object?> // <anonymous type: int id, string bar>
{
internal static readonly CommandFactory0 Instance = new();
public override void AddParameters(in global::Dapper.UnifiedCommand cmd, object? args)
{
var typed = Cast(args, static () => new { id = default(int), bar = default(string)! }); // expected shape
var ps = cmd.Parameters;
global::System.Data.Common.DbParameter p;
p = cmd.CreateParameter();
p.ParameterName = "id";
p.DbType = global::System.Data.DbType.Int32;
p.Direction = global::System.Data.ParameterDirection.Input;
p.Value = AsValue(typed.id);
ps.Add(p);

p = cmd.CreateParameter();
p.ParameterName = "bar";
p.DbType = global::System.Data.DbType.String;
p.Size = -1;
p.Direction = global::System.Data.ParameterDirection.Input;
p.Value = AsValue(typed.bar);
ps.Add(p);

}
public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, object? args)
{
var typed = Cast(args, static () => new { id = default(int), bar = default(string)! }); // expected shape
var ps = cmd.Parameters;
ps[0].Value = AsValue(typed.id);
ps[1].Value = AsValue(typed.bar);

}
public override bool CanPrepare => true;

}

private sealed class CommandFactory1 : CommonCommandFactory<object?> // <anonymous type: int id, string bar>
{
internal static readonly CommandFactory1 Instance = new();
public override void AddParameters(in global::Dapper.UnifiedCommand cmd, object? args)
{
var typed = Cast(args, static () => new { id = default(int), bar = default(string)! }); // expected shape
var ps = cmd.Parameters;
global::System.Data.Common.DbParameter p;
p = cmd.CreateParameter();
p.ParameterName = "id";
p.DbType = global::System.Data.DbType.Int32;
p.Direction = global::System.Data.ParameterDirection.Input;
p.Value = AsValue(typed.id);
ps.Add(p);

p = cmd.CreateParameter();
p.ParameterName = "bar";
p.DbType = global::System.Data.DbType.String;
p.Size = -1;
p.Direction = global::System.Data.ParameterDirection.Input;
p.Value = AsValue(typed.bar);
ps.Add(p);

}
public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, object? args)
{
var typed = Cast(args, static () => new { id = default(int), bar = default(string)! }); // expected shape
var ps = cmd.Parameters;
ps[0].Value = AsValue(typed.id);
ps[1].Value = AsValue(typed.bar);

}
public override bool CanPrepare => true;

}


}
}
namespace System.Runtime.CompilerServices
{
// this type is needed by the compiler to implement interceptors - it doesn't need to
// come from the runtime itself, though

[global::System.Diagnostics.Conditional("DEBUG")] // not needed post-build, so: evaporate
[global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = true)]
sealed file class InterceptsLocationAttribute : global::System.Attribute
{
public InterceptsLocationAttribute(string path, int lineNumber, int columnNumber)
{
_ = path;
_ = lineNumber;
_ = columnNumber;
}
}
}
Loading

0 comments on commit 66e1e2b

Please sign in to comment.