Skip to content

Commit 08638dd

Browse files
authored
Merge pull request #9087 from pbackus/sumtype-template-overhead
sumtype: reduce template overhead of match
2 parents edf6fb9 + f3d92d9 commit 08638dd

File tree

1 file changed

+106
-76
lines changed

1 file changed

+106
-76
lines changed

std/sumtype.d

Lines changed: 106 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1860,88 +1860,65 @@ private template Iota(size_t n)
18601860
assert(Iota!3 == AliasSeq!(0, 1, 2));
18611861
}
18621862

1863-
/* The number that the dim-th argument's tag is multiplied by when
1864-
* converting TagTuples to and from case indices ("caseIds").
1865-
*
1866-
* Named by analogy to the stride that the dim-th index into a
1867-
* multidimensional static array is multiplied by to calculate the
1868-
* offset of a specific element.
1869-
*/
1870-
private size_t stride(size_t dim, lengths...)()
1871-
{
1872-
import core.checkedint : mulu;
1873-
1874-
size_t result = 1;
1875-
bool overflow = false;
1876-
1877-
static foreach (i; 0 .. dim)
1878-
{
1879-
result = mulu(result, lengths[i], overflow);
1880-
}
1881-
1882-
/* The largest number matchImpl uses, numCases, is calculated with
1883-
* stride!(SumTypes.length), so as long as this overflow check
1884-
* passes, we don't need to check for overflow anywhere else.
1885-
*/
1886-
assert(!overflow, "Integer overflow");
1887-
return result;
1888-
}
1889-
18901863
private template matchImpl(Flag!"exhaustive" exhaustive, handlers...)
18911864
{
18921865
auto ref matchImpl(SumTypes...)(auto ref SumTypes args)
18931866
if (allSatisfy!(isSumType, SumTypes) && args.length > 0)
18941867
{
1895-
alias stride(size_t i) = .stride!(i, Map!(typeCount, SumTypes));
1896-
alias TagTuple = .TagTuple!(SumTypes);
1897-
1898-
/*
1899-
* A list of arguments to be passed to a handler needed for the case
1900-
* labeled with `caseId`.
1901-
*/
1902-
template handlerArgs(size_t caseId)
1868+
// Single dispatch (fast path)
1869+
static if (args.length == 1)
19031870
{
1904-
enum tags = TagTuple.fromCaseId(caseId);
1905-
enum argsFrom(size_t i : tags.length) = "";
1906-
enum argsFrom(size_t i) = "args[" ~ toCtString!i ~ "].get!(SumTypes[" ~ toCtString!i ~ "]" ~
1907-
".Types[" ~ toCtString!(tags[i]) ~ "])(), " ~ argsFrom!(i + 1);
1908-
enum handlerArgs = argsFrom!0;
1909-
}
1871+
/* When there's only one argument, the caseId is just that
1872+
* argument's tag, so there's no need for TagTuple.
1873+
*/
1874+
enum handlerArgs(size_t caseId) =
1875+
"args[0].get!(SumTypes[0].Types[" ~ toCtString!caseId ~ "])()";
19101876

1911-
/* An AliasSeq of the types of the member values in the argument list
1912-
* returned by `handlerArgs!caseId`.
1913-
*
1914-
* Note that these are the actual (that is, qualified) types of the
1915-
* member values, which may not be the same as the types listed in
1916-
* the arguments' `.Types` properties.
1917-
*/
1918-
template valueTypes(size_t caseId)
1877+
alias valueTypes(size_t caseId) =
1878+
typeof(args[0].get!(SumTypes[0].Types[caseId])());
1879+
1880+
enum numCases = SumTypes[0].Types.length;
1881+
}
1882+
// Multiple dispatch (slow path)
1883+
else
19191884
{
1920-
enum tags = TagTuple.fromCaseId(caseId);
1885+
alias typeCounts = Map!(typeCount, SumTypes);
1886+
alias stride(size_t i) = .stride!(i, typeCounts);
1887+
alias TagTuple = .TagTuple!typeCounts;
1888+
1889+
alias handlerArgs(size_t caseId) = .handlerArgs!(caseId, typeCounts);
19211890

1922-
template getType(size_t i)
1891+
/* An AliasSeq of the types of the member values in the argument list
1892+
* returned by `handlerArgs!caseId`.
1893+
*
1894+
* Note that these are the actual (that is, qualified) types of the
1895+
* member values, which may not be the same as the types listed in
1896+
* the arguments' `.Types` properties.
1897+
*/
1898+
template valueTypes(size_t caseId)
19231899
{
1924-
enum tid = tags[i];
1925-
alias T = SumTypes[i].Types[tid];
1926-
alias getType = typeof(args[i].get!T());
1900+
enum tags = TagTuple.fromCaseId(caseId);
1901+
1902+
template getType(size_t i)
1903+
{
1904+
enum tid = tags[i];
1905+
alias T = SumTypes[i].Types[tid];
1906+
alias getType = typeof(args[i].get!T());
1907+
}
1908+
1909+
alias valueTypes = Map!(getType, Iota!(tags.length));
19271910
}
19281911

1929-
alias valueTypes = Map!(getType, Iota!(tags.length));
1912+
/* The total number of cases is
1913+
*
1914+
* Π SumTypes[i].Types.length for 0 ≤ i < SumTypes.length
1915+
*
1916+
* Conveniently, this is equal to stride!(SumTypes.length), so we can
1917+
* use that function to compute it.
1918+
*/
1919+
enum numCases = stride!(SumTypes.length);
19301920
}
19311921

1932-
/* The total number of cases is
1933-
*
1934-
* Π SumTypes[i].Types.length for 0 ≤ i < SumTypes.length
1935-
*
1936-
* Or, equivalently,
1937-
*
1938-
* ubyte[SumTypes[0].Types.length]...[SumTypes[$-1].Types.length].sizeof
1939-
*
1940-
* Conveniently, this is equal to stride!(SumTypes.length), so we can
1941-
* use that function to compute it.
1942-
*/
1943-
enum numCases = stride!(SumTypes.length);
1944-
19451922
/* Guaranteed to never be a valid handler index, since
19461923
* handlers.length <= size_t.max.
19471924
*/
@@ -1998,7 +1975,12 @@ private template matchImpl(Flag!"exhaustive" exhaustive, handlers...)
19981975
mixin("alias ", handlerName!hid, " = handler;");
19991976
}
20001977

2001-
immutable argsId = TagTuple(args).toCaseId;
1978+
// Single dispatch (fast path)
1979+
static if (args.length == 1)
1980+
immutable argsId = args[0].tag;
1981+
// Multiple dispatch (slow path)
1982+
else
1983+
immutable argsId = TagTuple(args).toCaseId;
20021984

20031985
final switch (argsId)
20041986
{
@@ -2029,10 +2011,11 @@ private template matchImpl(Flag!"exhaustive" exhaustive, handlers...)
20292011
}
20302012
}
20312013

2014+
// Predicate for staticMap
20322015
private enum typeCount(SumType) = SumType.Types.length;
20332016

2034-
/* A TagTuple represents a single possible set of tags that `args`
2035-
* could have at runtime.
2017+
/* A TagTuple represents a single possible set of tags that the arguments to
2018+
* `matchImpl` could have at runtime.
20362019
*
20372020
* Because D does not allow a struct to be the controlling expression
20382021
* of a switch statement, we cannot dispatch on the TagTuple directly.
@@ -2054,22 +2037,23 @@ private enum typeCount(SumType) = SumType.Types.length;
20542037
* When there is only one argument, the caseId is equal to that
20552038
* argument's tag.
20562039
*/
2057-
private struct TagTuple(SumTypes...)
2040+
private struct TagTuple(typeCounts...)
20582041
{
2059-
size_t[SumTypes.length] tags;
2042+
size_t[typeCounts.length] tags;
20602043
alias tags this;
20612044

2062-
alias stride(size_t i) = .stride!(i, Map!(typeCount, SumTypes));
2045+
alias stride(size_t i) = .stride!(i, typeCounts);
20632046

20642047
invariant
20652048
{
20662049
static foreach (i; 0 .. tags.length)
20672050
{
2068-
assert(tags[i] < SumTypes[i].Types.length, "Invalid tag");
2051+
assert(tags[i] < typeCounts[i], "Invalid tag");
20692052
}
20702053
}
20712054

2072-
this(ref const(SumTypes) args)
2055+
this(SumTypes...)(ref const SumTypes args)
2056+
if (allSatisfy!(isSumType, SumTypes) && args.length == typeCounts.length)
20732057
{
20742058
static foreach (i; 0 .. tags.length)
20752059
{
@@ -2104,6 +2088,52 @@ private struct TagTuple(SumTypes...)
21042088
}
21052089
}
21062090

2091+
/* The number that the dim-th argument's tag is multiplied by when
2092+
* converting TagTuples to and from case indices ("caseIds").
2093+
*
2094+
* Named by analogy to the stride that the dim-th index into a
2095+
* multidimensional static array is multiplied by to calculate the
2096+
* offset of a specific element.
2097+
*/
2098+
private size_t stride(size_t dim, lengths...)()
2099+
{
2100+
import core.checkedint : mulu;
2101+
2102+
size_t result = 1;
2103+
bool overflow = false;
2104+
2105+
static foreach (i; 0 .. dim)
2106+
{
2107+
result = mulu(result, lengths[i], overflow);
2108+
}
2109+
2110+
/* The largest number matchImpl uses, numCases, is calculated with
2111+
* stride!(SumTypes.length), so as long as this overflow check
2112+
* passes, we don't need to check for overflow anywhere else.
2113+
*/
2114+
assert(!overflow, "Integer overflow");
2115+
return result;
2116+
}
2117+
2118+
/* A list of arguments to be passed to a handler needed for the case
2119+
* labeled with `caseId`.
2120+
*/
2121+
private template handlerArgs(size_t caseId, typeCounts...)
2122+
{
2123+
enum tags = TagTuple!typeCounts.fromCaseId(caseId);
2124+
2125+
alias handlerArgs = AliasSeq!();
2126+
2127+
static foreach (i; 0 .. tags.length)
2128+
{
2129+
handlerArgs = AliasSeq!(
2130+
handlerArgs,
2131+
"args[" ~ toCtString!i ~ "].get!(SumTypes[" ~ toCtString!i ~ "]" ~
2132+
".Types[" ~ toCtString!(tags[i]) ~ "])(), "
2133+
);
2134+
}
2135+
}
2136+
21072137
// Matching
21082138
@safe unittest
21092139
{

0 commit comments

Comments
 (0)