Skip to content

Commit 5be4523

Browse files
committed
perf: Memory performance improvements
Use jemalloc as the global memory allocator, and update the IPM with a new matrix-size threshold that controls when to switch between sequential and parallel matrix multiplication.
1 parent 5023871 commit 5be4523

File tree

4 files changed

+72
-64
lines changed

4 files changed

+72
-64
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ pharmsol = "=0.22.1"
3333
rand = "0.9.0"
3434
anyhow = "1.0.100"
3535
rayon = "1.10.0"
36+
tikv-jemallocator = "0.6.1"
3637

3738
[features]
3839
default = []

src/lib.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,9 @@ pub mod prelude {
7373
pub use pharmsol::fetch_params;
7474
pub use pharmsol::lag;
7575
}
76+
77+
use tikv_jemallocator::Jemalloc;
78+
79+
// Use jemalloc as the global allocator
80+
#[global_allocator]
81+
static GLOBAL: Jemalloc = Jemalloc;

src/routines/estimation/ipm.rs

Lines changed: 29 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use anyhow::bail;
44
use faer::linalg::triangular_solve::solve_lower_triangular_in_place;
55
use faer::linalg::triangular_solve::solve_upper_triangular_in_place;
66
use faer::{Col, Mat, Row};
7-
use rayon::prelude::*;
7+
88
/// Applies Burke's Interior Point Method (IPM) to solve a convex optimization problem.
99
///
1010
/// The objective function to maximize is:
@@ -93,14 +93,14 @@ pub fn burke(psi: &Psi) -> anyhow::Result<(Weights, f64)> {
9393

9494
let mut psi_inner: Mat<f64> = Mat::zeros(psi.nrows(), psi.ncols());
9595

96-
let n_threads = faer::get_global_parallelism().degree();
97-
9896
let rows = psi.nrows();
9997

100-
let mut output: Vec<Mat<f64>> = (0..n_threads).map(|_| Mat::zeros(rows, rows)).collect();
101-
10298
let mut h: Mat<f64> = Mat::zeros(rows, rows);
10399

100+
// Cache-size threshold: prefer sequential for small matrices to avoid thread overhead
101+
// For larger matrices, use faer's built-in parallelism which has better cache behavior
102+
const PARALLEL_THRESHOLD: usize = 512;
103+
104104
while mu > eps || norm_r > eps || gap > eps {
105105
let smu = sig * mu;
106106
// inner = lam ./ y, elementwise.
@@ -109,46 +109,32 @@ pub fn burke(psi: &Psi) -> anyhow::Result<(Weights, f64)> {
109109
let w_plam = Col::from_fn(plam.nrows(), |i| plam.get(i) / w.get(i));
110110

111111
// Scale each column of psi by the corresponding element of 'inner'
112-
113-
if psi.ncols() > n_threads * 128 {
114-
psi_inner
115-
.par_col_partition_mut(n_threads)
116-
.zip(psi.par_col_partition(n_threads))
117-
.zip(inner.par_partition(n_threads))
118-
.zip(output.par_iter_mut())
119-
.for_each(|(((mut psi_inner, psi), inner), output)| {
120-
psi_inner
121-
.as_mut()
122-
.col_iter_mut()
123-
.zip(psi.col_iter())
124-
.zip(inner.iter())
125-
.for_each(|((col, psi_col), inner_val)| {
126-
col.iter_mut().zip(psi_col.iter()).for_each(|(x, psi_val)| {
127-
*x = psi_val * inner_val;
128-
});
129-
});
130-
faer::linalg::matmul::triangular::matmul(
131-
output.as_mut(),
132-
faer::linalg::matmul::triangular::BlockStructure::TriangularLower,
133-
faer::Accum::Replace,
134-
&psi_inner,
135-
faer::linalg::matmul::triangular::BlockStructure::Rectangular,
136-
psi.transpose(),
137-
faer::linalg::matmul::triangular::BlockStructure::Rectangular,
138-
1.0,
139-
faer::Par::Seq,
140-
);
112+
// Use sequential column scaling - cache-friendly access pattern
113+
psi_inner
114+
.as_mut()
115+
.col_iter_mut()
116+
.zip(psi.col_iter())
117+
.zip(inner.iter())
118+
.for_each(|((col, psi_col), inner_val)| {
119+
col.iter_mut().zip(psi_col.iter()).for_each(|(x, psi_val)| {
120+
*x = psi_val * inner_val;
141121
});
122+
});
142123

143-
let mut first_iter = true;
144-
for output in &output {
145-
if first_iter {
146-
h.copy_from(output);
147-
first_iter = false;
148-
} else {
149-
h += output;
150-
}
151-
}
124+
// Use faer's built-in parallelism for matmul - it has better cache tiling
125+
// than our manual partitioning which caused false sharing
126+
if psi.ncols() > PARALLEL_THRESHOLD {
127+
faer::linalg::matmul::triangular::matmul(
128+
h.as_mut(),
129+
faer::linalg::matmul::triangular::BlockStructure::TriangularLower,
130+
faer::Accum::Replace,
131+
&psi_inner,
132+
faer::linalg::matmul::triangular::BlockStructure::Rectangular,
133+
psi.transpose(),
134+
faer::linalg::matmul::triangular::BlockStructure::Rectangular,
135+
1.0,
136+
faer::Par::rayon(0), // Let faer handle parallelism with proper cache tiling
137+
);
152138
} else {
153139
psi_inner
154140
.as_mut()

src/routines/expansion/adaptative_grid.rs

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
use crate::structs::theta::Theta;
22
use anyhow::Result;
3-
use faer::Row;
43

54
/// Implements the adaptive grid algorithm for support point expansion.
65
///
@@ -25,36 +24,52 @@ pub fn adaptative_grid(
2524
ranges: &[(f64, f64)],
2625
min_dist: f64,
2726
) -> Result<()> {
28-
let mut candidates = Vec::new();
27+
let n_params = ranges.len();
28+
let n_spp = theta.nspp();
2929

30-
// Collect all points first to avoid borrowing conflicts
30+
// Pre-compute deltas for each dimension (cache-friendly: sequential access)
31+
let deltas: Vec<f64> = ranges.iter().map(|(lo, hi)| eps * (hi - lo)).collect();
32+
33+
// Pre-allocate flat buffer for candidates to minimize allocations
34+
// Max candidates = n_spp * n_params * 2 directions
35+
let max_candidates = n_spp * n_params * 2;
36+
let mut candidates: Vec<f64> = Vec::with_capacity(max_candidates * n_params);
37+
let mut n_candidates = 0usize;
38+
39+
// Generate candidates using flat buffer
3140
for spp in theta.matrix().row_iter() {
32-
for (j, val) in spp.iter().enumerate() {
33-
let l = eps * (ranges[j].1 - ranges[j].0); //abs?
41+
for (j, &val) in spp.iter().enumerate() {
42+
let l = deltas[j];
43+
44+
// Check +delta direction
3445
if val + l < ranges[j].1 {
35-
let mut plus = Row::zeros(spp.ncols());
36-
plus[j] = l;
37-
plus += spp;
38-
candidates.push(plus.iter().copied().collect::<Vec<f64>>());
46+
// Append candidate point to flat buffer
47+
for (k, &v) in spp.iter().enumerate() {
48+
candidates.push(if k == j { v + l } else { v });
49+
}
50+
n_candidates += 1;
3951
}
52+
53+
// Check -delta direction
4054
if val - l > ranges[j].0 {
41-
let mut minus = Row::zeros(spp.ncols());
42-
minus[j] = -l;
43-
minus += spp;
44-
candidates.push(minus.iter().copied().collect::<Vec<f64>>());
55+
for (k, &v) in spp.iter().enumerate() {
56+
candidates.push(if k == j { v - l } else { v });
57+
}
58+
n_candidates += 1;
4559
}
4660
}
4761
}
4862

49-
// Option 1: Check all points against the original theta, then add them
50-
let keep = candidates
51-
.iter()
52-
.filter(|point| theta.check_point(point, min_dist))
53-
.cloned()
54-
.collect::<Vec<_>>();
63+
// Filter and add valid candidates
64+
// Use slice views into the flat buffer to avoid allocations
65+
for i in 0..n_candidates {
66+
let start = i * n_params;
67+
let end = start + n_params;
68+
let point = &candidates[start..end];
5569

56-
for point in keep {
57-
theta.add_point(point.as_slice())?;
70+
if theta.check_point(point, min_dist) {
71+
theta.add_point(point)?;
72+
}
5873
}
5974

6075
Ok(())

0 commit comments

Comments
 (0)