Skip to content

Commit 4495a03

Browse files
authored
rust: replace static fixture with test helpers (#4433)
Summary: In #4385, @nfelt suggested using test helpers to easily populate a commit rather than using a single fixture across all test cases. This patch implements such test helpers and uses them in the server tests. The net diff to the test code (not the helpers) is −32 lines. Test Plan: Existing tests pass. The new test helpers mostly don’t have explicit tests, but that’s okay since they’re exercised by the main test code. wchargin-branch: rust-test-data-builders
1 parent abaf81a commit 4495a03

File tree

2 files changed

+209
-72
lines changed

2 files changed

+209
-72
lines changed

tensorboard/data/server/commit.rs

+163
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,166 @@ mod tests {
136136
);
137137
}
138138
}
139+
140+
/// Utilities for constructing commits with test data.
141+
//
142+
// Not `#[cfg(test)]` because we have a doctest:
143+
// <https://github.com/rust-lang/rust/issues/45599>
144+
pub mod test_data {
145+
use super::*;
146+
use crate::data_compat;
147+
use crate::reservoir::StageReservoir;
148+
149+
#[derive(Default)]
150+
pub struct CommitBuilder(Commit);
151+
152+
impl CommitBuilder {
153+
/// Creates a new builder for an empty commit.
154+
pub fn new() -> Self {
155+
Self::default()
156+
}
157+
158+
/// Ensures that data for a run exists, and update it.
159+
///
160+
/// This takes a callback because we can't just return a mutable reference, since there are
161+
/// nested `RwLockWriteGuard`s.
162+
fn with_run_data(&self, run: Run, update: impl FnOnce(&mut RunData)) {
163+
let mut runs = self.0.runs.write().expect("runs.write");
164+
let mut run_data = runs
165+
.entry(run)
166+
.or_default()
167+
.write()
168+
.expect("runs[run].write");
169+
update(&mut run_data);
170+
}
171+
172+
/// Adds a scalar time series, creating the run if it doesn't exist, and setting its start
173+
/// time if unset.
174+
///
175+
/// # Examples
176+
///
177+
/// ```
178+
/// use rustboard_core::commit::{test_data::CommitBuilder, Commit};
179+
/// use rustboard_core::types::Step;
180+
///
181+
/// let my_commit: Commit = CommitBuilder::new()
182+
/// .scalars("train", "xent", |mut b| {
183+
/// b.eval(|Step(i)| 1.0 / (i + 1) as f32).build()
184+
/// })
185+
/// .build();
186+
/// ```
187+
pub fn scalars(
188+
self,
189+
run: &str,
190+
tag: &str,
191+
build: impl FnOnce(ScalarTimeSeriesBuilder) -> TimeSeries<ScalarValue>,
192+
) -> Self {
193+
self.with_run_data(Run(run.to_string()), |run_data| {
194+
let time_series = build(ScalarTimeSeriesBuilder::default());
195+
if let (None, Some((_step, wall_time, _value))) =
196+
(run_data.start_time, time_series.valid_values().next())
197+
{
198+
run_data.start_time = Some(wall_time);
199+
}
200+
run_data.scalars.insert(Tag(tag.to_string()), time_series);
201+
});
202+
self
203+
}
204+
205+
/// Ensures that a run is present and sets its start time.
206+
///
207+
/// If you don't care about the start time and the run is going to have data, anyway, you
208+
/// can just add the data directly, which will create the run as a side effect.
209+
///
210+
/// # Panics
211+
///
212+
/// If `start_time` represents an invalid wall time (i.e., `start_time` is `Some(wt)` but
213+
/// `WallTime::new(wt)` is `None`).
214+
pub fn run(self, run: &str, start_time: Option<f64>) -> Self {
215+
self.with_run_data(Run(run.to_string()), |run_data| {
216+
run_data.start_time = start_time.map(|f| WallTime::new(f).unwrap());
217+
});
218+
self
219+
}
220+
221+
/// Consumes this `CommitBuilder` and returns the underlying commit.
222+
pub fn build(self) -> Commit {
223+
self.0
224+
}
225+
}
226+
227+
pub struct ScalarTimeSeriesBuilder {
228+
/// Initial step. Increments by `1` for each point.
229+
step_start: Step,
230+
/// Initial wall time. Increments by `1.0` for each point.
231+
wall_time_start: WallTime,
232+
/// Number of points in this time series.
233+
len: u64,
234+
/// Custom summary metadata. Leave `None` to use default.
235+
metadata: Option<Box<pb::SummaryMetadata>>,
236+
/// Scalar evaluation function, called for each point in the series.
237+
///
238+
/// By default, this maps every step to `0.0`.
239+
eval: Box<dyn Fn(Step) -> f32>,
240+
}
241+
242+
impl Default for ScalarTimeSeriesBuilder {
243+
fn default() -> Self {
244+
ScalarTimeSeriesBuilder {
245+
step_start: Step(0),
246+
wall_time_start: WallTime::new(0.0).unwrap(),
247+
len: 1,
248+
metadata: None,
249+
eval: Box::new(|_| 0.0),
250+
}
251+
}
252+
}
253+
254+
impl ScalarTimeSeriesBuilder {
255+
pub fn step_start(&mut self, raw_step: i64) -> &mut Self {
256+
self.step_start = Step(raw_step);
257+
self
258+
}
259+
pub fn wall_time_start(&mut self, raw_wall_time: f64) -> &mut Self {
260+
self.wall_time_start = WallTime::new(raw_wall_time).unwrap();
261+
self
262+
}
263+
pub fn len(&mut self, len: u64) -> &mut Self {
264+
self.len = len;
265+
self
266+
}
267+
pub fn metadata(&mut self, metadata: Option<Box<pb::SummaryMetadata>>) -> &mut Self {
268+
self.metadata = metadata;
269+
self
270+
}
271+
pub fn eval(&mut self, eval: impl Fn(Step) -> f32 + 'static) -> &mut Self {
272+
self.eval = Box::new(eval);
273+
self
274+
}
275+
276+
/// Constructs a scalar time series from the state of this builder.
277+
///
278+
/// # Panics
279+
///
280+
/// If the wall time of a point would overflow to be infinite.
281+
pub fn build(&self) -> TimeSeries<ScalarValue> {
282+
let metadata = self.metadata.clone().unwrap_or_else(|| {
283+
let sample_point = Box::new(pb::summary::value::Value::SimpleValue(0.0));
284+
data_compat::SummaryValue(sample_point).initial_metadata(None)
285+
});
286+
let mut time_series = TimeSeries::new(metadata);
287+
288+
let mut rsv = StageReservoir::new(self.len as usize);
289+
for i in 0..self.len {
290+
let step = Step(self.step_start.0 + i as i64);
291+
let wall_time =
292+
WallTime::new(f64::from(self.wall_time_start) + (i as f64)).unwrap();
293+
let value = (self.eval)(step);
294+
rsv.offer(step, (wall_time, Ok(ScalarValue(value))));
295+
}
296+
rsv.commit(&mut time_series.basin);
297+
298+
time_series
299+
}
300+
}
301+
}

tensorboard/data/server/server.rs

+46-72
Original file line numberDiff line numberDiff line change
@@ -324,80 +324,22 @@ mod tests {
324324
use super::*;
325325
use tonic::Code;
326326

327-
use crate::commit::{ScalarValue, TimeSeries};
328-
use crate::data_compat;
329-
use crate::proto::tensorboard as pb;
330-
use crate::reservoir::StageReservoir;
331-
use crate::types::{Run, Step, Tag, WallTime};
332-
333-
/// Creates a commit with some test data.
334-
fn sample_commit() -> Commit {
335-
let commit = Commit::new();
336-
337-
let mut runs = commit.runs.write().unwrap();
338-
339-
fn scalar_series(points: Vec<(Step, WallTime, f32)>) -> TimeSeries<ScalarValue> {
340-
use pb::summary::value::Value::SimpleValue;
341-
let mut ts = commit::TimeSeries::new(
342-
data_compat::SummaryValue(Box::new(SimpleValue(0.0))).initial_metadata(None),
343-
);
344-
let mut rsv = StageReservoir::new(points.len());
345-
for (step, wall_time, value) in points {
346-
rsv.offer(step, (wall_time, Ok(commit::ScalarValue(value))));
347-
}
348-
rsv.commit(&mut ts.basin);
349-
ts
350-
}
351-
352-
let mut train = runs
353-
.entry(Run("train".to_string()))
354-
.or_default()
355-
.write()
356-
.unwrap();
357-
train.start_time = Some(WallTime::new(1234.0).unwrap());
358-
train.scalars.insert(
359-
Tag("xent".to_string()),
360-
scalar_series(vec![
361-
(Step(0), WallTime::new(1235.0).unwrap(), 0.5),
362-
(Step(1), WallTime::new(1236.0).unwrap(), 0.25),
363-
(Step(2), WallTime::new(1237.0).unwrap(), 0.125),
364-
]),
365-
);
366-
drop(train);
367-
368-
let mut test = runs
369-
.entry(Run("test".to_string()))
370-
.or_default()
371-
.write()
372-
.unwrap();
373-
test.start_time = Some(WallTime::new(6234.0).unwrap());
374-
test.scalars.insert(
375-
Tag("accuracy".to_string()),
376-
scalar_series(vec![
377-
(Step(0), WallTime::new(6235.0).unwrap(), 0.125),
378-
(Step(1), WallTime::new(6236.0).unwrap(), 0.25),
379-
(Step(2), WallTime::new(6237.0).unwrap(), 0.5),
380-
]),
381-
);
382-
drop(test);
383-
384-
// An run with no start time or data: should not show up in results.
385-
runs.entry(Run("empty".to_string())).or_default();
386-
387-
drop(runs);
388-
commit
389-
}
327+
use crate::commit::test_data::CommitBuilder;
328+
use crate::types::{Run, Step, Tag};
390329

391-
fn sample_handler() -> DataProviderHandler {
330+
fn sample_handler(commit: Commit) -> DataProviderHandler {
392331
DataProviderHandler {
393332
// Leak the commit object, since the Tonic server must have only 'static references.
394-
commit: Box::leak(Box::new(sample_commit())),
333+
commit: Box::leak(Box::new(commit)),
395334
}
396335
}
397336

398337
#[tokio::test]
399338
async fn test_list_plugins() {
400-
let handler = sample_handler();
339+
let commit = CommitBuilder::new()
340+
.scalars("train", "xent", |b| b.build())
341+
.build();
342+
let handler = sample_handler(commit);
401343
let req = Request::new(data::ListPluginsRequest {
402344
experiment_id: "123".to_string(),
403345
});
@@ -410,7 +352,14 @@ mod tests {
410352

411353
#[tokio::test]
412354
async fn test_list_runs() {
413-
let handler = sample_handler();
355+
let commit = CommitBuilder::new()
356+
.run("train", Some(1234.0))
357+
.run("test", Some(6234.0))
358+
.run("run_with_no_data", None)
359+
.scalars("train", "xent", |mut b| b.wall_time_start(1235.0).build())
360+
.scalars("test", "acc", |mut b| b.wall_time_start(6235.0).build())
361+
.build();
362+
let handler = sample_handler(commit);
414363
let req = Request::new(data::ListRunsRequest {
415364
experiment_id: "123".to_string(),
416365
});
@@ -457,7 +406,18 @@ mod tests {
457406

458407
#[tokio::test]
459408
async fn test_list_scalars() {
460-
let handler = sample_handler();
409+
let commit = CommitBuilder::new()
410+
.run("train", Some(1234.0))
411+
.run("test", Some(6234.0))
412+
.run("run_with_no_data", None)
413+
.scalars("train", "xent", |mut b| {
414+
b.wall_time_start(1235.0).step_start(0).len(3).build()
415+
})
416+
.scalars("test", "accuracy", |mut b| {
417+
b.wall_time_start(6235.0).step_start(0).len(3).build()
418+
})
419+
.build();
420+
let handler = sample_handler(commit);
461421
let req = Request::new(data::ListScalarsRequest {
462422
experiment_id: "123".to_string(),
463423
plugin_filter: Some(data::PluginFilter {
@@ -501,7 +461,17 @@ mod tests {
501461

502462
#[tokio::test]
503463
async fn test_read_scalars() {
504-
let handler = sample_handler();
464+
let commit = CommitBuilder::new()
465+
.scalars("train", "xent", |mut b| {
466+
b.len(3)
467+
.wall_time_start(1235.0)
468+
.step_start(0)
469+
.eval(|Step(i)| 0.5f32.powi(i as i32))
470+
.build()
471+
})
472+
.scalars("test", "xent", |b| b.build())
473+
.build();
474+
let handler = sample_handler(commit);
505475
let req = Request::new(data::ReadScalarsRequest {
506476
experiment_id: "123".to_string(),
507477
plugin_filter: Some(data::PluginFilter {
@@ -527,12 +497,12 @@ mod tests {
527497
let xent_data = &train_run[&Tag("xent".to_string())].data.as_ref().unwrap();
528498
assert_eq!(xent_data.step, vec![0, 1, 2]);
529499
assert_eq!(xent_data.wall_time, vec![1235.0, 1236.0, 1237.0]);
530-
assert_eq!(xent_data.value, vec![0.5, 0.25, 0.125]);
500+
assert_eq!(xent_data.value, vec![1.0, 0.5, 0.25]);
531501
}
532502

533503
#[tokio::test]
534504
async fn test_read_scalars_needs_downsample() {
535-
let handler = sample_handler();
505+
let handler = sample_handler(Commit::default());
536506
let req = Request::new(data::ReadScalarsRequest {
537507
experiment_id: "123".to_string(),
538508
plugin_filter: Some(data::PluginFilter {
@@ -550,7 +520,11 @@ mod tests {
550520

551521
#[tokio::test]
552522
async fn test_read_scalars_downsample_zero_okay() {
553-
let handler = sample_handler();
523+
let commit = CommitBuilder::new()
524+
.scalars("train", "xent", |b| b.build())
525+
.scalars("test", "xent", |b| b.build())
526+
.build();
527+
let handler = sample_handler(commit);
554528
let req = Request::new(data::ReadScalarsRequest {
555529
experiment_id: "123".to_string(),
556530
plugin_filter: Some(data::PluginFilter {

0 commit comments

Comments
 (0)