Skip to content

Commit 04f8485

Browse files
authored
Add extension methods for ConcurrentDictionary + AsyncAtomicFactory (#513)
1 parent 8bfee3d commit 04f8485

File tree

2 files changed

+136
-1
lines changed

2 files changed

+136
-1
lines changed

BitFaster.Caching.UnitTests/Atomic/ConcurrentDictionaryExtensionTests.cs

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11

22
using System.Collections.Concurrent;
33
using System.Collections.Generic;
4+
using System.Threading.Tasks;
45
using BitFaster.Caching.Atomic;
56
using FluentAssertions;
67
using Xunit;
@@ -9,7 +10,8 @@ namespace BitFaster.Caching.UnitTests.Atomic
910
{
1011
public class ConcurrentDictionaryExtensionTests
1112
{
12-
private ConcurrentDictionary<int, AtomicFactory<int, int>> dictionary = new ConcurrentDictionary<int, AtomicFactory<int, int>>();
13+
private ConcurrentDictionary<int, AtomicFactory<int, int>> dictionary = new();
14+
private ConcurrentDictionary<int, AsyncAtomicFactory<int, int>> dictionaryAsync = new();
1315

1416
[Fact]
1517
public void WhenItemIsAddedItCanBeRetrieved()
@@ -20,6 +22,15 @@ public void WhenItemIsAddedItCanBeRetrieved()
2022
value.Should().Be(1);
2123
}
2224

25+
[Fact]
26+
public async Task WhenItemIsAddedAsyncItCanBeRetrieved()
27+
{
28+
await dictionaryAsync.GetOrAddAsync(1, k => Task.FromResult(k));
29+
30+
dictionaryAsync.TryGetValue(1, out int value).Should().BeTrue();
31+
value.Should().Be(1);
32+
}
33+
2334
[Fact]
2435
public void WhenItemIsAddedWithArgItCanBeRetrieved()
2536
{
@@ -29,12 +40,27 @@ public void WhenItemIsAddedWithArgItCanBeRetrieved()
2940
value.Should().Be(3);
3041
}
3142

43+
[Fact]
44+
public async Task WhenItemIsAddedWithArgAsyncItCanBeRetrieved()
45+
{
46+
await dictionaryAsync.GetOrAddAsync(1, (k, a) => Task.FromResult(k + a), 2);
47+
48+
dictionaryAsync.TryGetValue(1, out int value).Should().BeTrue();
49+
value.Should().Be(3);
50+
}
51+
3252
[Fact]
3353
public void WhenKeyDoesNotExistTryGetReturnsFalse()
3454
{
3555
dictionary.TryGetValue(1, out int _).Should().BeFalse();
3656
}
3757

58+
[Fact]
59+
public void WhenKeyDoesNotExistAsyncTryGetReturnsFalse()
60+
{
61+
dictionaryAsync.TryGetValue(1, out int _).Should().BeFalse();
62+
}
63+
3864
[Fact]
3965
public void WhenItemIsAddedItCanBeRemovedByKey()
4066
{
@@ -44,6 +70,15 @@ public void WhenItemIsAddedItCanBeRemovedByKey()
4470
value.Should().Be(1);
4571
}
4672

73+
[Fact]
74+
public async Task WhenItemIsAddedAsyncItCanBeRemovedByKey()
75+
{
76+
await dictionaryAsync.GetOrAddAsync(1, k => Task.FromResult(k));
77+
78+
dictionaryAsync.TryRemove(1, out int value).Should().BeTrue();
79+
value.Should().Be(1);
80+
}
81+
4782
[Fact]
4883
public void WhenItemIsAddedItCanBeRemovedByKvp()
4984
{
@@ -53,10 +88,25 @@ public void WhenItemIsAddedItCanBeRemovedByKvp()
5388
dictionary.TryGetValue(1, out _).Should().BeFalse();
5489
}
5590

91+
[Fact]
92+
public async Task WhenItemIsAddedAsyncItCanBeRemovedByKvp()
93+
{
94+
await dictionaryAsync.GetOrAddAsync(1, k => Task.FromResult(k));
95+
96+
dictionaryAsync.TryRemove(new KeyValuePair<int, int>(1, 1)).Should().BeTrue();
97+
dictionaryAsync.TryGetValue(1, out _).Should().BeFalse();
98+
}
99+
56100
[Fact]
57101
public void WhenKeyDoesNotExistTryRemoveReturnsFalse()
58102
{
59103
dictionary.TryRemove(1, out int _).Should().BeFalse();
60104
}
105+
106+
[Fact]
107+
public void WhenKeyDoesNotExistAsyncTryRemoveReturnsFalse()
108+
{
109+
dictionaryAsync.TryRemove(1, out int _).Should().BeFalse();
110+
}
61111
}
62112
}

BitFaster.Caching/Atomic/ConcurrentDictionaryExtensions.cs

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Collections.Concurrent;
33
using System.Collections.Generic;
4+
using System.Threading.Tasks;
45

56
namespace BitFaster.Caching.Atomic
67
{
@@ -36,6 +37,33 @@ public static V GetOrAdd<K, V, TArg>(this ConcurrentDictionary<K, AtomicFactory<
3637
return atomicFactory.GetValue(key, valueFactory, factoryArgument);
3738
}
3839

40+
/// <summary>
41+
/// Adds a key/value pair to the ConcurrentDictionary if the key does not already exist. Returns the new value, or the existing value if the key already exists.
42+
/// </summary>
43+
/// <param name="dictionary">The ConcurrentDictionary to use.</param>
44+
/// <param name="key">The key of the element to add.</param>
45+
/// <param name="valueFactory">The function used to generate a value for the key.</param>
46+
/// <returns>The value for the key. This will be either the existing value for the key if the key is already in the dictionary, or the new value if the key was not in the dictionary.</returns>
47+
public static ValueTask<V> GetOrAddAsync<K, V>(this ConcurrentDictionary<K, AsyncAtomicFactory<K, V>> dictionary, K key, Func<K, Task<V>> valueFactory)
48+
{
49+
var asyncAtomicFactory = dictionary.GetOrAdd(key, _ => new AsyncAtomicFactory<K, V>());
50+
return asyncAtomicFactory.GetValueAsync(key, valueFactory);
51+
}
52+
53+
/// <summary>
54+
/// Adds a key/value pair to the ConcurrentDictionary by using the specified function and an argument if the key does not already exist, or returns the existing value if the key exists.
55+
/// </summary>
56+
/// <param name="dictionary">The ConcurrentDictionary to use.</param>
57+
/// <param name="key">The key of the element to add.</param>
58+
/// <param name="valueFactory">The function used to generate a value for the key.</param>
59+
/// <param name="factoryArgument">An argument value to pass into valueFactory.</param>
60+
/// <returns>The value for the key. This will be either the existing value for the key if the key is already in the dictionary, or the new value if the key was not in the dictionary.</returns>
61+
public static ValueTask<V> GetOrAddAsync<K, V, TArg>(this ConcurrentDictionary<K, AsyncAtomicFactory<K, V>> dictionary, K key, Func<K, TArg, Task<V>> valueFactory, TArg factoryArgument)
62+
{
63+
var asyncAtomicFactory = dictionary.GetOrAdd(key, _ => new AsyncAtomicFactory<K, V>());
64+
return asyncAtomicFactory.GetValueAsync(key, valueFactory, factoryArgument);
65+
}
66+
3967
/// <summary>
4068
/// Attempts to get the value associated with the specified key from the ConcurrentDictionary.
4169
/// </summary>
@@ -58,6 +86,27 @@ public static bool TryGetValue<K, V>(this ConcurrentDictionary<K, AtomicFactory<
5886
return false;
5987
}
6088

89+
/// <summary>
90+
/// Attempts to get the value associated with the specified key from the ConcurrentDictionary.
91+
/// </summary>
92+
/// <param name="dictionary">The ConcurrentDictionary to use.</param>
93+
/// <param name="key">The key of the value to get.</param>
94+
/// <param name="value">When this method returns, contains the object from the ConcurrentDictionary that has the specified key, or the default value of the type if the operation failed.</param>
95+
public static bool TryGetValue<K, V>(this ConcurrentDictionary<K, AsyncAtomicFactory<K, V>> dictionary, K key, out V value)
96+
{
97+
AsyncAtomicFactory<K, V> output;
98+
var ret = dictionary.TryGetValue(key, out output);
99+
100+
if (ret && output.IsValueCreated)
101+
{
102+
value = output.ValueIfCreated;
103+
return true;
104+
}
105+
106+
value = default;
107+
return false;
108+
}
109+
61110
/// <summary>
62111
/// Removes a key and value from the dictionary.
63112
/// </summary>
@@ -75,6 +124,23 @@ public static bool TryRemove<K, V>(this ConcurrentDictionary<K, AtomicFactory<K,
75124
#endif
76125
}
77126

127+
/// <summary>
128+
/// Removes a key and value from the dictionary.
129+
/// </summary>
130+
/// <param name="dictionary">The ConcurrentDictionary to use.</param>
131+
/// <param name="item">The KeyValuePair representing the key and value to remove.</param>
132+
/// <returns>true if the object was removed successfully; otherwise, false.</returns>
133+
public static bool TryRemove<K, V>(this ConcurrentDictionary<K, AsyncAtomicFactory<K, V>> dictionary, KeyValuePair<K, V> item)
134+
{
135+
var kvp = new KeyValuePair<K, AsyncAtomicFactory<K, V>>(item.Key, new AsyncAtomicFactory<K, V>(item.Value));
136+
#if NET6_0_OR_GREATER
137+
return dictionary.TryRemove(kvp);
138+
#else
139+
// https://devblogs.microsoft.com/pfxteam/little-known-gems-atomic-conditional-removals-from-concurrentdictionary/
140+
return ((ICollection<KeyValuePair<K, AsyncAtomicFactory<K, V>>>)dictionary).Remove(kvp);
141+
#endif
142+
}
143+
78144
/// <summary>
79145
/// Attempts to remove and return the value that has the specified key from the ConcurrentDictionary.
80146
/// </summary>
@@ -93,5 +159,24 @@ public static bool TryRemove<K, V>(this ConcurrentDictionary<K, AtomicFactory<K,
93159
value = default;
94160
return false;
95161
}
162+
163+
/// <summary>
164+
/// Attempts to remove and return the value that has the specified key from the ConcurrentDictionary.
165+
/// </summary>
166+
/// <param name="dictionary">The ConcurrentDictionary to use.</param>
167+
/// <param name="key">The key of the element to remove and return.</param>
168+
/// <param name="value">When this method returns, contains the object removed from the ConcurrentDictionary, or the default value of the TValue type if key does not exist.</param>
169+
/// <returns>true if the object was removed successfully; otherwise, false.</returns>
170+
public static bool TryRemove<K, V>(this ConcurrentDictionary<K, AsyncAtomicFactory<K, V>> dictionary, K key, out V value)
171+
{
172+
if (dictionary.TryRemove(key, out var atomic))
173+
{
174+
value = atomic.ValueIfCreated;
175+
return true;
176+
}
177+
178+
value = default;
179+
return false;
180+
}
96181
}
97182
}

0 commit comments

Comments
 (0)