Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TableDataSavingCallback #602

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/TrixiParticles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ export WeaklyCompressibleSPHSystem, EntropicallyDampedSPHSystem, TotalLagrangian
BoundarySPHSystem, DEMSystem, BoundaryDEMSystem, OpenBoundarySPHSystem, InFlow,
OutFlow
export InfoCallback, SolutionSavingCallback, DensityReinitializationCallback,
PostprocessCallback, StepsizeCallback, UpdateCallback, SteadyStateReachedCallback
PostprocessCallback, StepsizeCallback, UpdateCallback, SteadyStateReachedCallback, TableDataSavingCallback

export ContinuityDensity, SummationDensity
export PenaltyForceGanzenmueller, TransportVelocityAdami
export SchoenbergCubicSplineKernel, SchoenbergQuarticSplineKernel,
Expand Down
1 change: 1 addition & 0 deletions src/callbacks/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ end
include("info.jl")
include("solution_saving.jl")
include("density_reinit.jl")
include("data_saving.jl")
include("post_process.jl")
include("stepsize.jl")
include("update.jl")
Expand Down
218 changes: 218 additions & 0 deletions src/callbacks/data_saving.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
struct TableDataSavingCallback{I, ELTYPE, F}
interval :: I
write_file_interval :: Int
start_at :: ELTYPE
save_interval :: Int
data :: Dict{String, Vector{Vector{ELTYPE}}}
axis_ticks :: Dict{<:Function, Vector{Float64}}
output_directory :: String
functions :: F
empty_vector :: Vector{ELTYPE}
end

function TableDataSavingCallback(; interval::Integer=0, dt=0.0, save_interval::Integer=-1,
output_directory="out", start_at=0.0,
axis_ticks=Dict{Function, Vector{Float64}}(),
write_file_interval::Integer=1, funcs...)
if isempty(funcs)
throw(ArgumentError("`funcs` cannot be empty"))
end

if dt > 0 && interval > 0
throw(ArgumentError("setting both `interval` and `dt` is not supported"))
end

if dt > 0
interval = Float64(dt)
end

table_data_cb = TableDataSavingCallback(interval, write_file_interval, start_at,
save_interval,
Dict{String, Vector{Vector{Float64}}}(),
axis_ticks, output_directory, funcs, Float64[])
if dt > 0
# Add a `tstop` every `dt`, and save the final solution
return PeriodicCallback(table_data_cb, dt,
save_positions=(false, false), final_affect=true)
else
# The first one is the `condition`, the second the `affect!`
return DiscreteCallback(table_data_cb, table_data_cb,
save_positions=(false, false))
end
end

function Base.show(io::IO, cb::DiscreteCallback{<:Any, <:TableDataSavingCallback})
@nospecialize cb # reduce precompilation time

cb_ = cb.affect!

print(io, "TableDataSavingCallback(interval=", cb_.interval, ", ",
"start_at=", cb_.start_at, ")")
end

function Base.show(io::IO,
cb::DiscreteCallback{<:Any,
<:PeriodicCallbackAffect{<:TableDataSavingCallback}})
@nospecialize cb # reduce precompilation time

cb_ = cb.affect!.affect!

print(io, "TableDataSavingCallback(interval=", cb_.interval, ", ",
"start_at=", cb_.start_at, ")")
end

function Base.show(io::IO, ::MIME"text/plain",
cb::DiscreteCallback{<:Any, <:TableDataSavingCallback})
@nospecialize cb # reduce precompilation time

function write_file_interval(interval)
if interval > 1
return "every $(interval) * interval"
elseif interval == 1
return "always"
elseif interval == 0
return "no"
end
end

function save_interval_size(interval)
if interval > 1
return "$(interval) * interval"
elseif interval == -1
return "entire simulation"
elseif interval == 0
return "never"
end
end
if get(io, :compact, false)
show(io, cb)
else
cb_ = cb.affect!

setup = [
"interval" => string(cb_.interval),
"write file" => write_file_interval(cb_.write_file_interval),
"output directory" => abspath(cb_.output_directory),
"start at" => "t = " * string(cb_.start_at),
"save interval size" => save_interval_size(cb_.save_interval)]
summary_box(io, "TableDataSavingCallback", setup)
end
end

function Base.show(io::IO, ::MIME"text/plain",
cb::DiscreteCallback{<:Any,
<:PeriodicCallbackAffect{<:TableDataSavingCallback}})
@nospecialize cb # reduce precompilation time

function write_file_interval(interval)
if interval > 1
return "every $(interval) * interval"
elseif interval == 1
return "always"
elseif interval == 0
return "no"
end
end

function save_interval_size(interval)
if interval > 1
return "$(interval) * dt"
elseif interval == -1
return "entire simulation"
elseif interval == 0
return "never"
end
end

if get(io, :compact, false)
show(io, cb)
else
cb_ = cb.affect!.affect!

setup = [
"dt" => string(cb_.interval),
"write file" => write_file_interval(cb_.write_file_interval),
"output directory" => abspath(cb_.output_directory),
"start at" => "t = " * string(cb_.start_at),
"save interval size" => save_interval_size(cb_.save_interval)]
summary_box(io, "TableDataSavingCallback", setup)
end
end

# `condition` with interval
function (cb::TableDataSavingCallback)(u, t, integrator)
(; interval, start_at) = cb
return t >= start_at && condition_integrator_interval(integrator, interval)
end

# `affect!`
function (cb::TableDataSavingCallback)(integrator)
(; empty_vector, functions, data, save_interval, start_at) = cb

vu_ode = integrator.u
v_ode, u_ode = vu_ode.x
semi = integrator.p
t = integrator.t

t >= start_at || return nothing

# Update systems to compute quantities like density and pressure
update_systems_and_nhs(v_ode, u_ode, semi, t; update_from_callback=true)

if !isempty(data) && save_interval == length(data[first(keys(data))])
for value in values(data)
popfirst!(value)
end
end

filenames = system_names(semi.systems)

foreach_system(semi) do system
v = wrap_v(v_ode, system, semi)
u = wrap_u(u_ode, system, semi)
for (key, f) in functions
result = f(v, u, t, system)
if result !== nothing
data_key = string(key) * "_" * filenames[system_indices(system, semi)]

push!(get!(data, data_key, empty_vector), result)
end
end
end

if isfinished(integrator) ||
(cb.write_file_interval > 0 && backup_condition(cb, integrator))
write_table_data(cb, integrator)
end

# Tell OrdinaryDiffEq that `u` has not been modified
u_modified!(integrator, false)
end

function write_table_data(cb::TableDataSavingCallback, integrator)
(; data, axis_ticks, functions) = cb
semi = integrator.p

filenames = system_names(semi.systems)

foreach_system(semi) do system
for (function_key, f) in functions
data_key = string(function_key) * "_" * filenames[system_indices(system, semi)]

haskey(data, data_key) || continue

value = stack(values(data[data_key]))

df = DataFrame(value, [Symbol(data_key * "_$i") for i in 1:size(value, 2)])

df[!, Symbol(data_key * "_avg")] = [sum(value, dims=2)...] ./ size(value, 2)

if !isempty(axis_ticks) && haskey(axis_ticks, f)
df[!, Symbol("x")] = axis_ticks[f]
end

# Write the DataFrame to a CSV file
CSV.write(joinpath(cb.output_directory, data_key * ".csv"), df)
end
end
end
6 changes: 4 additions & 2 deletions src/callbacks/post_process.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,12 +269,14 @@ function (pp::PostprocessCallback)(integrator)
u_modified!(integrator, false)
end

@inline function backup_condition(cb::PostprocessCallback{Int}, integrator)
@inline function backup_condition(cb::Union{PostprocessCallback{Int},
TableDataSavingCallback{Int}}, integrator)
return integrator.stats.naccept > 0 &&
round(integrator.stats.naccept / cb.interval) % cb.write_file_interval == 0
end

@inline function backup_condition(cb::PostprocessCallback, integrator)
@inline function backup_condition(cb::Union{PostprocessCallback, TableDataSavingCallback},
integrator)
return integrator.stats.naccept > 0 &&
round(Int, integrator.t / cb.interval) % cb.write_file_interval == 0
end
Expand Down
Loading