Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ LogDensityProblems = "2.1.2"
Mooncake = "0.4"
PrettyTables = "3"
ReverseDiff = "1.15.3"
StableRNGs = "1"
StableRNGs = "1"
2 changes: 2 additions & 0 deletions benchmarks/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ chosen_combinations = [
("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true),
("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true),
("Smorgasbord", smorgasbord_instance, :typed, :enzyme, true),
Expand Down
4 changes: 4 additions & 0 deletions benchmarks/src/DynamicPPLBenchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
retvals = model(rng)
vns = [VarName{k}() for k in keys(retvals)]
SimpleVarInfo{Float64}(Dict(zip(vns, values(retvals))))
elseif varinfo_choice == :typed_vector
DynamicPPL.typed_vector_varinfo(rng, model)
elseif varinfo_choice == :untyped_vector
DynamicPPL.untyped_vector_varinfo(rng, model)
else
error("Unknown varinfo choice: $varinfo_choice")
end
Expand Down
207 changes: 141 additions & 66 deletions src/varnamedvector.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
const CHECK_CONSISTENCY_DEFAULT = true

"""
VarNamedVector

Expand Down Expand Up @@ -40,6 +42,11 @@ contents of the internal storage quickly with `getindex_internal(vnv, :)`. The o
of `VarNamedVector` are mostly used to keep track of which part of the internal storage
belongs to which `VarName`.

All constructors that take arguments accept a keyword argument
`check_consistency::Bool=true` that controls whether to run checks like the number of
values matching the number of variables. Some of these checks can be expensive, so if you
are confident in the input, you may want to turn `check_consistency` off for performance.

# Fields

$(FIELDS)
Expand Down Expand Up @@ -184,68 +191,71 @@ struct VarNamedVector{
vals::TVal,
transforms::TTrans,
is_unconstrained=fill!(BitVector(undef, length(varnames)), 0),
num_inactive=OrderedDict{Int,Int}(),
num_inactive=OrderedDict{Int,Int}();
check_consistency::Bool=CHECK_CONSISTENCY_DEFAULT,
) where {K,V,TVN<:AbstractVector{K},TVal<:AbstractVector{V},TTrans<:AbstractVector}
if length(varnames) != length(ranges) ||
length(varnames) != length(transforms) ||
length(varnames) != length(is_unconstrained) ||
length(varnames) != length(varname_to_index)
msg = (
"Inputs to VarNamedVector have inconsistent lengths. Got lengths varnames: " *
"$(length(varnames)), ranges: " *
"$(length(ranges)), " *
"transforms: $(length(transforms)), " *
"is_unconstrained: $(length(is_unconstrained)), " *
"varname_to_index: $(length(varname_to_index))."
)
throw(ArgumentError(msg))
end
if check_consistency
if length(varnames) != length(ranges) ||
length(varnames) != length(transforms) ||
length(varnames) != length(is_unconstrained) ||
length(varnames) != length(varname_to_index)
msg = (
"Inputs to VarNamedVector have inconsistent lengths. " *
"Got lengths varnames: $(length(varnames)), " *
"ranges: $(length(ranges)), " *
"transforms: $(length(transforms)), " *
"is_unconstrained: $(length(is_unconstrained)), " *
"varname_to_index: $(length(varname_to_index))."
)
throw(ArgumentError(msg))
end

num_vals = mapreduce(length, (+), ranges; init=0) + sum(values(num_inactive))
if num_vals != length(vals)
msg = (
"The total number of elements in `vals` ($(length(vals))) does not match " *
"the sum of the lengths of the ranges and the number of inactive entries " *
"($(num_vals))."
)
throw(ArgumentError(msg))
end
num_vals = mapreduce(length, (+), ranges; init=0) + sum(values(num_inactive))
if num_vals != length(vals)
msg = (
"The total number of elements in `vals` ($(length(vals))) does not " *
"match the sum of the lengths of the ranges and the number of " *
"inactive entries ($(num_vals))."
)
throw(ArgumentError(msg))
end

if Set(values(varname_to_index)) != Set(axes(varnames, 1))
msg = (
"The set of values of `varname_to_index` is not the set of valid indices " *
"for `varnames`."
)
throw(ArgumentError(msg))
end
if Set(values(varname_to_index)) != Set(axes(varnames, 1))
msg = (
"The set of values of `varname_to_index` is not the set of valid " *
"indices for `varnames`."
)
throw(ArgumentError(msg))
end

if !issubset(Set(keys(num_inactive)), Set(values(varname_to_index)))
msg = (
"The keys of `num_inactive` are not a subset of the values of " *
"`varname_to_index`."
)
throw(ArgumentError(msg))
end
if !issubset(Set(keys(num_inactive)), Set(values(varname_to_index)))
msg = (
"The keys of `num_inactive` are not a subset of the values of " *
"`varname_to_index`."
)
throw(ArgumentError(msg))
end

# Check that the varnames don't overlap. The time cost is quadratic in number of
# variables. If this ever becomes an issue, we should be able to go down to at least
# N log N by sorting based on subsumes-order.
for vn1 in keys(varname_to_index)
for vn2 in keys(varname_to_index)
vn1 === vn2 && continue
if subsumes(vn1, vn2)
msg = (
"Variables in a VarNamedVector should not subsume each other, " *
"but $vn1 subsumes $vn2, i.e. $vn2 describes a subrange of $vn1."
)
throw(ArgumentError(msg))
# Check that the varnames don't overlap. The time cost is quadratic in number of
# variables. If this ever becomes an issue, we should be able to go down to at
# least N log N by sorting based on subsumes-order.
for vn1 in keys(varname_to_index)
for vn2 in keys(varname_to_index)
vn1 === vn2 && continue
if subsumes(vn1, vn2)
msg = (
"Variables in a VarNamedVector should not subsume each " *
"other, but $vn1 subsumes $vn2."
)
throw(ArgumentError(msg))
end
end
end
end

# We could also have a test to check that the ranges don't overlap, but that sounds
# unlikely to occur, and implementing it in linear time would require a tiny bit of
# thought.
# We could also have a test to check that the ranges don't overlap, but that
# sounds unlikely to occur, and implementing it in linear time would require a
# tiny bit of thought.
end

return new{K,V,TVN,TVal,TTrans}(
varname_to_index,
Expand All @@ -260,7 +270,9 @@ struct VarNamedVector{
end

function VarNamedVector{K,V}() where {K,V}
return VarNamedVector(OrderedDict{K,Int}(), K[], UnitRange{Int}[], V[], Any[])
return VarNamedVector(
OrderedDict{K,Int}(), K[], UnitRange{Int}[], V[], Any[]; check_consistency=false
)
end

# TODO(mhauru) I would like for this to be VarNamedVector(Union{}, Union{}). Similarly the
Expand All @@ -269,15 +281,22 @@ end
# making that change here opens some other cans of worms related to how VarInfo uses
# BangBang, that I don't want to deal with right now.
VarNamedVector() = VarNamedVector{VarName,Real}()
VarNamedVector(xs::Pair...) = VarNamedVector(OrderedDict(xs...))
VarNamedVector(x::AbstractDict) = VarNamedVector(keys(x), values(x))
function VarNamedVector(varnames, vals)
return VarNamedVector(collect_maybe(varnames), collect_maybe(vals))
function VarNamedVector(xs::Pair...; check_consistency=CHECK_CONSISTENCY_DEFAULT)
return VarNamedVector(OrderedDict(xs...); check_consistency=check_consistency)
end
function VarNamedVector(x::AbstractDict; check_consistency=CHECK_CONSISTENCY_DEFAULT)
return VarNamedVector(keys(x), values(x); check_consistency=check_consistency)
end
function VarNamedVector(varnames, vals; check_consistency=CHECK_CONSISTENCY_DEFAULT)
return VarNamedVector(
collect_maybe(varnames), collect_maybe(vals); check_consistency=check_consistency
)
end
function VarNamedVector(
varnames::AbstractVector,
orig_vals::AbstractVector,
transforms=fill(identity, length(varnames)),
transforms=fill(identity, length(varnames));
check_consistency=CHECK_CONSISTENCY_DEFAULT,
)
# Convert `vals` into a vector of vectors.
vals_vecs = map(tovec, orig_vals)
Expand All @@ -301,7 +320,19 @@ function VarNamedVector(
offset = r[end]
end

return VarNamedVector(varname_to_index, varnames, ranges, vals, transforms)
# Passing on check_consistency here seems wasteful. Wouldn't it be faster to do a
# lightweight check of the arguments of this function, and rely on the correctness
# of what this function does? However, the expensive check is whether any variable
# subsumes another, and that's the same regardless of where it's done, so the
# optimisation would be quite pointless.
return VarNamedVector(
varname_to_index,
varnames,
ranges,
vals,
transforms;
check_consistency=check_consistency,
)
end

function ==(vnv_left::VarNamedVector, vnv_right::VarNamedVector)
Expand Down Expand Up @@ -832,7 +863,8 @@ function loosen_types!!(
vnv.vals,
Vector{transform_type}(vnv.transforms),
vnv.is_unconstrained,
vnv.num_inactive,
vnv.num_inactive;
check_consistency=false,
)
end
end
Expand Down Expand Up @@ -887,7 +919,8 @@ function tighten_types(vnv::VarNamedVector)
map(identity, vnv.vals),
map(identity, vnv.transforms),
copy(vnv.is_unconstrained),
copy(vnv.num_inactive),
copy(vnv.num_inactive);
check_consistency=false,
)
end

Expand Down Expand Up @@ -1041,6 +1074,14 @@ julia> unflatten(vnv, vnv[:]) == vnv
true
"""
function unflatten(vnv::VarNamedVector, vals::AbstractVector)
if length(vals) != vector_length(vnv)
throw(
ArgumentError(
"Length of `vals` ($(length(vals))) does not match the length of " *
"`vnv` ($(vector_length(vnv))).",
),
)
end
new_ranges = deepcopy(vnv.ranges)
recontiguify_ranges!(new_ranges)
return VarNamedVector(
Expand All @@ -1049,7 +1090,8 @@ function unflatten(vnv::VarNamedVector, vals::AbstractVector)
new_ranges,
vals,
vnv.transforms,
vnv.is_unconstrained,
vnv.is_unconstrained;
check_consistency=false,
)
end

Expand All @@ -1063,6 +1105,32 @@ function Base.merge(left_vnv::VarNamedVector, right_vnv::VarNamedVector)
vns_right = right_vnv.varnames
vns_both = union(vns_left, vns_right)

# Check that varnames do not subsume each other.
for vn_left in vns_left
for vn_right in vns_right
vn_left == vn_right && continue
# TODO(mhauru) Subsumption doesn't actually need to be a showstopper. For
# instance, if right has a value for `x` and left has a value for `x[1]`, then
# right will take precedence anyway, and we could merge. However, that requires
# some extra logic that hasn't been done yet.
if subsumes(vn_left, vn_right)
throw(
ArgumentError(
"Cannot merge VarNamedVectors: variable name $vn_left " *
"subsumes $vn_right.",
),
)
elseif subsumes(vn_right, vn_left)
throw(
ArgumentError(
"Cannot merge VarNamedVectors: variable name $vn_right " *
"subsumes $vn_left.",
),
)
end
end
end

# Determine `eltype` of `vals`.
T_left = eltype(left_vnv.vals)
T_right = eltype(right_vnv.vals)
Expand Down Expand Up @@ -1117,7 +1185,13 @@ function Base.merge(left_vnv::VarNamedVector, right_vnv::VarNamedVector)
end

return VarNamedVector(
varname_to_index, vns_both, ranges, vals, transforms, is_unconstrained
varname_to_index,
vns_both,
ranges,
vals,
transforms,
is_unconstrained;
check_consistency=false,
)
end

Expand Down Expand Up @@ -1193,7 +1267,8 @@ function Base.similar(vnv::VarNamedVector)
similar(vnv.vals, 0),
similar(vnv.transforms, 0),
BitVector(),
empty(vnv.num_inactive),
empty(vnv.num_inactive);
check_consistency=false,
)
end

Expand Down