Skip to content

Commit

Permalink
Added option to specify order of species
Browse files Browse the repository at this point in the history
  • Loading branch information
kaandocal committed Jul 12, 2024
1 parent 35307c2 commit 9ec1ebd
Show file tree
Hide file tree
Showing 4 changed files with 301 additions and 183 deletions.
8 changes: 7 additions & 1 deletion src/fspsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ struct FSPSystem{IHT <: AbstractIndexHandler, RT}
rfs::RT
end

function FSPSystem(rs::ReactionSystem, ih=DefaultIndexHandler(); combinatoric_ratelaw::Bool=true)
function FSPSystem(rs::ReactionSystem, ih::AbstractIndexHandler=DefaultIndexHandler{length(Catalyst.species(rs))}();
combinatoric_ratelaw::Bool=true)
isempty(Catalyst.get_systems(rs)) ||
error("Supported Catalyst models can not contain subsystems. Please use `rs = Catalyst.flatten(rs::ReactionSystem)` to generate a single system with no subsystems from your Catalyst model.")
any(eq -> !(eq isa Reaction), equations(rs)) &&
Expand All @@ -19,6 +20,11 @@ function FSPSystem(rs::ReactionSystem, ih=DefaultIndexHandler(); combinatoric_ra
FSPSystem(rs, ih, rfs)
end

function FSPSystem(rs::ReactionSystem, order::AbstractVector{Symbol}; kwargs...)
FSPSystem(rs, PermutingIndexHandler(rs, order); kwargs...)
end


"""
build_ratefuncs(rs, ih; state_sym::Symbol, combinatoric_ratelaw::Bool)::Vector
Expand Down
72 changes: 59 additions & 13 deletions src/indexhandlers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,51 +51,53 @@ function LinearIndices end


"""
struct DefaultIndexHandler <: AbstractIndexHandler
struct DefaultIndexHandler{N} <: AbstractIndexHandler
offset::Int
perm::NTuple{N,Int}
end
Basic index handler that stores the state of a system with
`s` species in an `s`-dimensional array. The `offset` parameter
denotes the offset by which the array is indexed (defaults to 1
in Julia).
in Julia). The order of the species is given by the tuple `perm`.
This is the simplest index handler, but it will not be optimal
if some states cannot be reached from the initial state, e.g.
due to the presence of conservation laws. In these cases one should
use `ReducingIndexHandler`, which will automatically elide species
where possible.
use try to remove redundant species where possible.
Constructors: `DefaultIndexHandler([sys::FSPSystem, offset::Int=1])`
"""
struct DefaultIndexHandler <: AbstractIndexHandler
struct DefaultIndexHandler{N} <: AbstractIndexHandler
offset::Int
perm::NTuple{N,Int}
end

DefaultIndexHandler() = DefaultIndexHandler(1)
DefaultIndexHandler{N}() where {N} = DefaultIndexHandler{N}(1, Tuple(1:N))

@deprecate NaiveIndexHandler DefaultIndexHandler true

Base.vec(::DefaultIndexHandler, arr) = vec(arr)
Base.LinearIndices(::DefaultIndexHandler, arr) = LinearIndices(arr)

function pairedindices(ih::DefaultIndexHandler, arr::AbstractArray{T,N},
function pairedindices(ih::DefaultIndexHandler{N}, arr::AbstractArray{T,N},
shift::CartesianIndex{N}) where {T,N}
pairedindices(ih, axes(arr), shift)
end

function pairedindices(ih::DefaultIndexHandler, dims::NTuple{N,T},
function pairedindices(ih::DefaultIndexHandler{N}, dims::NTuple{N,T},
shift::CartesianIndex{N}) where {N,T<:Number}
pairedindices(ih, Base.OneTo.(dims), shift)
end

function pairedindices(::DefaultIndexHandler, dims::NTuple{N,T},
# Important: the species in `shift` are ordered according to `Catalyst.species`!
function pairedindices(ih::DefaultIndexHandler{N}, dims::NTuple{N,T},
shift::CartesianIndex{N}) where {N,T<:AbstractVector}
ranges = tuple((UnitRange(max(first(ax), first(ax)+shift[i]),
min(last(ax), last(ax)+shift[i]))
ranges = tuple((UnitRange(max(first(ax), first(ax)+shift[ih.perm[i]]),
min(last(ax), last(ax)+shift[ih.perm[i]]))
for (i, ax) in enumerate(dims))...)

ranges_shifted = tuple((rng .- shift[i] for (i, rng) in enumerate(ranges))...)
ranges_shifted = tuple((rng .- shift[ih.perm[i]] for (i, rng) in enumerate(ranges))...)

zip(CartesianIndices(ranges_shifted), CartesianIndices(ranges))
end
Expand All @@ -114,5 +116,49 @@ Defines the abundance of species ``S_i`` to be `state_sym[i] - offset`.
function getsubstitutions(ih::DefaultIndexHandler, rs::ReactionSystem; state_sym::Symbol)
nspecs = numspecies(rs)
state_sym_vec = ModelingToolkit.value.(ModelingToolkit.scalarize((@variables ($state_sym)[1:nspecs])[1]))
Dict(symbol => state_sym_vec[i] - ih.offset for (i, symbol) in enumerate(species(rs)))

species_orig = species(rs)
species_perm = [ species_orig[ih.perm[i]] for i in 1:nspecs ]

Dict(symbol => state_sym_vec[i] - ih.offset for (i, symbol) in enumerate(species_perm))
end

"""
PermutingIndexHandler(rs::ReactionSystem, order::AbstractVector)
Constructs an index handler for the reaction system in which the species appear in the order
defined by the vector `order`.
"""
function PermutingIndexHandler(rs::ReactionSystem, order::AbstractVector{Symbol})
PermutingIndexHandler(rs, map(sym -> Catalyst._symbol_to_var(rs, sym), order))
end

function PermutingIndexHandler(rs::ReactionSystem, order::AbstractVector)
spec = Catalyst.species(rs)
nspec = length(spec)

if nspec != length(order)
@error "Length of species vector ($(length(order))) does not match number of species ($nspec)"
end

perm = zeros(Int, nspec)
count = zeros(Int, nspec)

for i in 1:nspec
idx = findfirst(s -> isequal(s, order[i]), spec)
if isnothing(idx)
@error "Cannot find species $(order[i]) in reaction system"
end

if count[idx] > 0
@error "Species $(order[i]) specified twice in ordering"
end

count[idx] += 1
perm[i] = idx
end

@assert count == ones(Int, nspec)

DefaultIndexHandler(1, Tuple(perm))
end
180 changes: 107 additions & 73 deletions test/birthdeath2D.jl
Original file line number Diff line number Diff line change
@@ -1,73 +1,107 @@
using Test
using OrdinaryDiffEq
using SteadyStateDiffEq
using Distributions
using FiniteStateProjection
using SparseArrays
using LinearAlgebra
using Sundials

marg(vec; dims) = dropdims(sum(vec; dims); dims)

rs = @reaction_network begin
r1, 0 --> A
r2, A --> 0
s1, 0 --> B
s2, B --> 0
end

sys = FSPSystem(rs)

prs = exp.(2 .* rand(2))
pmap = [ :r1 => prs[1],
:r2 => prs[1] / exp(4 * rand()),
:s1 => prs[2],
:s2 => prs[2] / exp(4 * rand()) ]

ps = last.(pmap)

Nmax = 100
u0 = zeros(Nmax+1, Nmax+1)
u0[1] = 1.0

tt = [ 0.25, 1.0, 10.0 ]

prob = ODEProblem(sys, u0, 10.0, pmap)
sol = solve(prob, Vern7(), abstol=1e-6, saveat=tt)

@test marg(sol.u[1], dims=2) pdf.(Poisson(ps[1] / ps[2] * (1 - exp(-ps[2] * tt[1]))), 0:Nmax) atol=1e-4
@test marg(sol.u[1], dims=1) pdf.(Poisson(ps[3] / ps[4] * (1 - exp(-ps[4] * tt[1]))), 0:Nmax) atol=1e-4

@test marg(sol.u[2], dims=2) pdf.(Poisson(ps[1] / ps[2] * (1 - exp(-ps[2] * tt[2]))), 0:Nmax) atol=1e-4
@test marg(sol.u[2], dims=1) pdf.(Poisson(ps[3] / ps[4] * (1 - exp(-ps[4] * tt[2]))), 0:Nmax) atol=1e-4

@test marg(sol.u[3], dims=2) pdf.(Poisson(ps[1] / ps[2] * (1 - exp(-ps[2] * tt[3]))), 0:Nmax) atol=1e-4
@test marg(sol.u[3], dims=1) pdf.(Poisson(ps[3] / ps[4] * (1 - exp(-ps[4] * tt[3]))), 0:Nmax) atol=1e-4

A = SparseMatrixCSC(sys, (Nmax+1, Nmax+1), pmap, 0)
f = (du,u,p,t) -> mul!(vec(du), A, vec(u))

probA = ODEProblem(f, u0, 10.0)
solA = solve(probA, Vern7(), abstol=1e-6, saveat=tt)

@test sol.u[1] solA.u[1] atol=1e-4
@test sol.u[2] solA.u[2] atol=1e-4
@test sol.u[3] solA.u[3] atol=1e-4

## Steady-State Tests

prob = SteadyStateProblem(sys, u0, pmap)
sol = solve(prob, SSRootfind())
sol.u ./= sum(sol.u)

@test marg(sol.u, dims=2) pdf.(Poisson(ps[1] / ps[2]), 0:Nmax) atol=1e-4
@test marg(sol.u, dims=1) pdf.(Poisson(ps[3] / ps[4]), 0:Nmax) atol=1e-4

A = SparseMatrixCSC(sys, (Nmax+1, Nmax+1), pmap, SteadyState())
f = (du,u,p,t) -> mul!(vec(du), A, vec(u))

probA = SteadyStateProblem(f, u0)
solA = solve(probA, SSRootfind())
solA.u ./= sum(solA.u)

@test sol.u solA.u atol=1e-4
using Test
using OrdinaryDiffEq
using SteadyStateDiffEq
using Distributions
using FiniteStateProjection
using SparseArrays
using LinearAlgebra
using Sundials

marg(vec; dims) = dropdims(sum(vec; dims); dims)

rs = @reaction_network begin
r1, 0 --> A
r2, A --> 0
s1, 0 --> B
s2, B --> 0
end

sys = FSPSystem(rs)

prs = exp.(2 .* rand(2))

pmap = [ :r1 => prs[1],
:r2 => prs[1] / exp(3 * rand()),
:s1 => prs[2],
:s2 => prs[2] / exp(3 * rand()) ]

ps = last.(pmap)

Nmax = 40

u0 = zeros(Nmax+1, Nmax+1)
u0[1] = 1.0

tt = [ 0.25, 1.0, 10.0 ]

prob = ODEProblem(sys, u0, 10.0, pmap)
sol = solve(prob, Vern7(), abstol=1e-6, saveat=tt)

@test marg(sol.u[1], dims=2) pdf.(Poisson(ps[1] / ps[2] * (1 - exp(-ps[2] * tt[1]))), 0:Nmax) atol=1e-4
@test marg(sol.u[1], dims=1) pdf.(Poisson(ps[3] / ps[4] * (1 - exp(-ps[4] * tt[1]))), 0:Nmax) atol=1e-4

@test marg(sol.u[2], dims=2) pdf.(Poisson(ps[1] / ps[2] * (1 - exp(-ps[2] * tt[2]))), 0:Nmax) atol=1e-4
@test marg(sol.u[2], dims=1) pdf.(Poisson(ps[3] / ps[4] * (1 - exp(-ps[4] * tt[2]))), 0:Nmax) atol=1e-4

@test marg(sol.u[3], dims=2) pdf.(Poisson(ps[1] / ps[2] * (1 - exp(-ps[2] * tt[3]))), 0:Nmax) atol=1e-4
@test marg(sol.u[3], dims=1) pdf.(Poisson(ps[3] / ps[4] * (1 - exp(-ps[4] * tt[3]))), 0:Nmax) atol=1e-4

A = SparseMatrixCSC(sys, (Nmax+1, Nmax+1), pmap, 0)
f = (du,u,p,t) -> mul!(vec(du), A, vec(u))

probA = ODEProblem(f, u0, 10.0)
solA = solve(probA, Vern7(), abstol=1e-6, saveat=tt)

@test sol.u[1] solA.u[1] atol=1e-4
@test sol.u[2] solA.u[2] atol=1e-4
@test sol.u[3] solA.u[3] atol=1e-4

## Steady-State Tests

prob_ss = SteadyStateProblem(sys, u0, pmap)
sol_ss = solve(prob_ss, SSRootfind())
sol_ss.u ./= sum(sol_ss.u)

@test marg(sol_ss.u, dims=2) pdf.(Poisson(ps[1] / ps[2]), 0:Nmax) atol=1e-4
@test marg(sol_ss.u, dims=1) pdf.(Poisson(ps[3] / ps[4]), 0:Nmax) atol=1e-4

A_ss = SparseMatrixCSC(sys, (Nmax+1, Nmax+1), pmap, SteadyState())
f_ss = (du,u,p,t) -> mul!(vec(du), A_ss, vec(u))

probA_ss = SteadyStateProblem(f_ss, u0)
solA_ss = solve(probA_ss, SSRootfind())
solA_ss.u ./= sum(solA_ss.u)

@test sol_ss.u solA_ss.u atol=1e-4

## Permutation tests

sys_perm = FSPSystem(rs, [:B, :A])

u0_perm = u0'

A_perm = SparseMatrixCSC(sys_perm, (Nmax+1, Nmax+1), pmap, 0)

idx_perm = vec(reshape(1:(Nmax+1)^2, (Nmax+1, Nmax+1))')
P = sparse(1:(Nmax+1)^2, idx_perm, 1)'

@test A_perm P * A * P'

prob_perm = ODEProblem(sys_perm, u0_perm, 10.0, pmap)
sol_perm = solve(prob_perm, Vern7(), abstol=1e-6, saveat=tt)

@test sol_perm.u[1] sol.u[1]' atol=1e-4
@test sol_perm.u[2] sol.u[2]' atol=1e-4
@test sol_perm.u[3] sol.u[3]' atol=1e-4

## Steady-State Tests

A_fsp_ss_perm = SparseMatrixCSC(sys_perm, (Nmax+1, Nmax+1), pmap, SteadyState())

@test A_fsp_ss_perm P * A_ss * P'

prob_ss_perm = SteadyStateProblem(sys_perm, u0_perm, pmap)
sol_ss_perm = solve(prob_ss_perm, SSRootfind())
sol_ss_perm.u ./= sum(sol_ss_perm.u)

@test sol_ss_perm.u sol_ss.u' atol=1e-4
Loading

0 comments on commit 9ec1ebd

Please sign in to comment.