-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* added retrieval of equations from CausalGraph * Exclude test_regression.ipynb from merge with main * added DAG checking, OLS residuals/intercept, tables.jl, SCM dag . toposort for data generation * put ols function outside of estimate function * removed DataFrames.jl as dependency, exported SCM functions in CI.jl and inserted a runnable test for them in test/equations.jl * Delete Manifest.toml * Reversed Project.toml * readded svg figure * included equations.jl in runtests.jl * Removed TikzGraphs from test/equations.jl
- Loading branch information
1 parent
aac0ac8
commit 408ab81
Showing
4 changed files
with
179 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
using LinearAlgebra, Graphs, Tables, Random, Statistics | ||
|
||
# Define the SCM struct | ||
""" | ||
struct SCM | ||
variables::Vector{<:AbstractString} | ||
coefficients::Vector{<:Vector{<:AbstractFloat}} | ||
residuals::Vector{<:Vector{<:AbstractFloat}} | ||
dag::DiGraph | ||
A struct representing a Structural Causal Model (SCM). | ||
# Fields | ||
- `variables::Vector{<:AbstractString}`: A list of variable names. | ||
- `coefficients::Vector{<:Vector{<:AbstractFloat}}`: A list of coefficient vectors for each variable. | ||
- `residuals::Vector{<:Vector{<:AbstractFloat}}`: A list of residuals for each variable. | ||
- `dag::DiGraph`: The directed graph representing the structure of the SCM. | ||
""" | ||
struct SCM | ||
variables::Vector{String} | ||
coefficients::Vector{Vector{Float64}} | ||
residuals::Vector{Vector{Float64}} | ||
dag::DiGraph | ||
end | ||
|
||
function ols_compute(X, y) | ||
X = hcat(ones(size(X, 1)), X) | ||
coef = X \ y | ||
yhat = X * coef | ||
resids = y - yhat | ||
return coef, resids | ||
end | ||
|
||
# Function to estimate equations and return an SCM struct | ||
""" | ||
estimate_equations(t, est_g::DiGraph)::SCM | ||
Estimate linear equations from the given table `t` based on the structure of the directed graph `est_g`. | ||
# Arguments | ||
- `t`: A table containing the data for estimation (supports any Tables.jl-compatible format). | ||
- `est_g::DiGraph`: A directed graph representing the structure of the SCM. | ||
# Returns | ||
- `SCM`: A struct containing the estimated variables, their corresponding coefficients, residuals, and the DAG. | ||
""" | ||
function estimate_equations(t, est_g::DiGraph)::SCM | ||
Tables.istable(t) || throw(ArgumentError("Argument supports just Tables.jl types")) | ||
|
||
columns = Tables.columns(t) | ||
schema = Tables.schema(t) | ||
variables = propertynames(schema.names) | ||
|
||
# Check if it is a DAG | ||
if is_cyclic(est_g) | ||
throw(ArgumentError("The provided graph is cyclic -> est_g::DiGraph should be a DAG.")) | ||
end | ||
|
||
adj_list = collect(edges(est_g)) | ||
|
||
var_names = String[] | ||
coefficients = Vector{Vector{Float64}}() | ||
residuals = Vector{Vector{Float64}}() | ||
nodes = variables | ||
|
||
for node in nodes | ||
node_index = findfirst(==(node), nodes) | ||
preds = [nodes[e.src] for e in adj_list if e.dst == node_index] | ||
|
||
if !isempty(preds) | ||
X = hcat([columns[pred] for pred in preds]...) | ||
y = columns[node] | ||
|
||
coef, resid = ols_compute(X, y) | ||
|
||
if isa(coef, Vector) | ||
push!(var_names, string(node)) | ||
push!(coefficients, coef) | ||
push!(residuals, resid) | ||
else | ||
println("Warning: Coefficients not stored for node $node. Expected vector, got $coef") | ||
end | ||
else | ||
y = columns[node] | ||
intercept = mean(y) | ||
resid = y .- intercept | ||
push!(var_names, string(node)) | ||
push!(coefficients, [intercept]) | ||
push!(residuals, resid) | ||
end | ||
end | ||
|
||
return SCM(var_names, coefficients, residuals, est_g) | ||
end | ||
|
||
# Function to generate data from the SCM | ||
""" | ||
generate_data(scm::SCM, N::Int)::NamedTuple | ||
Generate data from the given SCM. | ||
# Arguments | ||
- `scm::SCM`: The structural causal model. | ||
- `N::Int`: The number of data points to generate. | ||
# Returns | ||
- `NamedTuple`: A NamedTuple containing the generated data. | ||
""" | ||
function generate_data(scm::SCM, N::Int)::NamedTuple | ||
columns = Dict{Symbol, Vector{Float64}}() | ||
|
||
sorted_indices = topological_sort_by_dfs(scm.dag) | ||
sorted_variables = [scm.variables[i] for i in sorted_indices] | ||
variable_index_map = Dict(variable => index for (index, variable) in enumerate(scm.variables)) | ||
|
||
for node in sorted_variables | ||
idx = variable_index_map[node] | ||
coef = scm.coefficients[idx] | ||
residual_std = std(scm.residuals[idx]) | ||
|
||
if length(coef) == 1 | ||
columns[Symbol(node)] = coef[1] .+ residual_std * randn(N) | ||
else | ||
preds = [Symbol(scm.variables[i]) for i in inneighbors(scm.dag, idx)] | ||
if isempty(preds) | ||
columns[Symbol(node)] = coef[1] .+ residual_std * randn(N) | ||
else | ||
X = hcat(ones(N), [columns[pred] for pred in preds]...) | ||
columns[Symbol(node)] = X * coef .+ residual_std * randn(N) | ||
end | ||
end | ||
end | ||
|
||
return NamedTuple(columns) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
using CausalInference | ||
using Random | ||
Random.seed!(1) | ||
|
||
# Generate some sample data to use with the GES algorithm | ||
|
||
N = 2000 # number of data points | ||
|
||
# define simple linear model with added noise | ||
x = randn(N) | ||
v = x + randn(N)*0.25 | ||
w = x + randn(N)*0.25 | ||
z = v + w + randn(N)*0.25 | ||
s = z + randn(N)*0.25 | ||
|
||
df = (x=x, v=v, w=w, z=z, s=s) | ||
|
||
est_g, score = ges(df; penalty=1.0, parallel=true) | ||
|
||
|
||
est_dag= pdag2dag!(est_g) | ||
|
||
scm= estimate_equations(df,est_dag) | ||
|
||
display(scm) | ||
|
||
#println(CI.SCM) | ||
|
||
df_generated= generate_data(scm, 2000) | ||
|
||
println("df: ") | ||
|
||
display(df) | ||
|
||
println("df_generated: ") | ||
|
||
|
||
|
||
display(df_generated) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,3 +18,4 @@ include("witness.jl") | |
include("fci.jl") | ||
include("klentropy.jl") | ||
include("backdoor.jl") | ||
include("equations.jl") |
408ab81
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mschauer can we get this registered so we can use it with CounterfactualExplanations.jl? 😃 cc @JorgeLuizFranco