Skip to content

Commit a917e21

Browse files
committed
rust: replace static fixture with test helpers
Summary: In #4385, @nfelt suggested using test helpers to easily populate a commit rather than using a single fixture across all test cases. This patch implements such test helpers and uses them in the server tests. The net diff to the test code (not the helpers) is −32 lines. Test Plan: Existing tests pass. The new test helpers mostly don’t have explicit tests, but that’s okay since they’re exercised by the main test code. wchargin-branch: rust-test-data-builders wchargin-source: 31164b4d2b4a1fb9bbabc584249e1d3715c17a3e
1 parent 6c7b0b2 commit a917e21

File tree

2 files changed

+199
-72
lines changed

2 files changed

+199
-72
lines changed

tensorboard/data/server/commit.rs

+160-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ pub struct DataLoss;
105105
#[derive(Debug, Copy, Clone, PartialEq)]
106106
pub struct ScalarValue(pub f32);
107107

108-
#[cfg(test)]
108+
#[cfg(any(test, doctest))]
109109
mod tests {
110110
use super::*;
111111

@@ -136,3 +136,162 @@ mod tests {
136136
);
137137
}
138138
}
139+
140+
/// Utilities for constructing commits with test data.
141+
pub mod test_data {
142+
use super::*;
143+
use crate::data_compat;
144+
use crate::reservoir::StageReservoir;
145+
146+
#[derive(Default)]
147+
pub struct CommitBuilder(Commit);
148+
149+
impl CommitBuilder {
150+
/// Creates a new builder for an empty commit.
151+
pub fn new() -> Self {
152+
Self::default()
153+
}
154+
155+
/// Ensures that data for a run exists, and update it.
156+
///
157+
/// This takes a callback because we can't just return a mutable reference, since there are
158+
/// nested `RwLockWriteGuard`s.
159+
fn with_run_data(&self, run: Run, update: impl FnOnce(&mut RunData)) {
160+
let mut runs = self.0.runs.write().expect("runs.write");
161+
let mut run_data = runs
162+
.entry(run)
163+
.or_default()
164+
.write()
165+
.expect("runs[run].write");
166+
update(&mut run_data);
167+
}
168+
169+
/// Adds a scalar time series, creating the run if it doesn't exist, and setting its start
170+
/// time if unset.
171+
///
172+
/// # Examples
173+
///
174+
/// ```
175+
/// use rustboard_core::commit::{test_data::CommitBuilder, Commit};
176+
/// use rustboard_core::types::Step;
177+
///
178+
/// let my_commit: Commit = CommitBuilder::new()
179+
/// .scalars("train", "xent", |mut b| {
180+
/// b.eval(|Step(i)| 1.0 / (i + 1) as f32).build()
181+
/// })
182+
/// .build();
183+
/// ```
184+
pub fn scalars(
185+
self,
186+
run: &str,
187+
tag: &str,
188+
build: impl FnOnce(ScalarTimeSeriesBuilder) -> TimeSeries<ScalarValue>,
189+
) -> Self {
190+
self.with_run_data(Run(run.to_string()), |run_data| {
191+
let time_series = build(ScalarTimeSeriesBuilder::default());
192+
match (run_data.start_time, time_series.valid_values().nth(0)) {
193+
(None, Some((_step, wt, _value))) => run_data.start_time = Some(wt),
194+
_ => (),
195+
};
196+
run_data.scalars.insert(Tag(tag.to_string()), time_series);
197+
});
198+
self
199+
}
200+
201+
/// Ensures that a run is present and sets its start time.
202+
///
203+
/// If you don't care about the start time and the run is going to have data, anyway, you
204+
/// can just add the data directly, which will create the run as a side effect.
205+
///
206+
/// # Panics
207+
///
208+
/// If `start_time` represents an invalid wall time (i.e., `start_time` is `Some(wt)` but
209+
/// `WallTime::new(wt)` is `None`).
210+
pub fn run(self, run: &str, start_time: Option<f64>) -> Self {
211+
self.with_run_data(Run(run.to_string()), |run_data| {
212+
run_data.start_time = start_time.map(|f| WallTime::new(f).unwrap());
213+
});
214+
self
215+
}
216+
217+
/// Consumes this `CommitBuilder` and returns the underlying commit.
218+
pub fn build(self) -> Commit {
219+
self.0
220+
}
221+
}
222+
223+
pub struct ScalarTimeSeriesBuilder {
224+
/// Initial step. Increments by `1` for each point.
225+
step_start: Step,
226+
/// Initial wall time. Increments by `1.0` for each point.
227+
wall_time_start: WallTime,
228+
/// Number of points in this time series.
229+
len: u64,
230+
/// Custom summary metadata. Leave `None` to use default.
231+
metadata: Option<Box<pb::SummaryMetadata>>,
232+
/// Scalar evaluation function, called for each point in the series.
233+
///
234+
/// By default, this maps step `i` to `0.5 ** (i + 1)`.
235+
eval: Box<dyn Fn(Step) -> f32>,
236+
}
237+
238+
impl Default for ScalarTimeSeriesBuilder {
239+
fn default() -> Self {
240+
ScalarTimeSeriesBuilder {
241+
step_start: Step(0),
242+
wall_time_start: WallTime::new(1235.0).unwrap(),
243+
len: 3,
244+
metadata: None,
245+
eval: Box::new(|Step(i)| 0.5f32.powi((i + 1) as i32)),
246+
}
247+
}
248+
}
249+
250+
impl ScalarTimeSeriesBuilder {
251+
pub fn step_start(&mut self, raw_step: i64) -> &mut Self {
252+
self.step_start = Step(raw_step);
253+
self
254+
}
255+
pub fn wall_time_start(&mut self, raw_wall_time: f64) -> &mut Self {
256+
self.wall_time_start = WallTime::new(raw_wall_time).unwrap();
257+
self
258+
}
259+
pub fn len(&mut self, len: u64) -> &mut Self {
260+
self.len = len;
261+
self
262+
}
263+
pub fn metadata(&mut self, metadata: Option<Box<pb::SummaryMetadata>>) -> &mut Self {
264+
self.metadata = metadata;
265+
self
266+
}
267+
pub fn eval(&mut self, eval: impl Fn(Step) -> f32 + 'static) -> &mut Self {
268+
self.eval = Box::new(eval);
269+
self
270+
}
271+
272+
/// Constructs a scalar time series from the state of this builder.
273+
///
274+
/// # Panics
275+
///
276+
/// If the wall time of a point would overflow to be infinite.
277+
pub fn build(&self) -> TimeSeries<ScalarValue> {
278+
let metadata = self.metadata.clone().unwrap_or_else(|| {
279+
let sample_point = Box::new(pb::summary::value::Value::SimpleValue(0.0));
280+
data_compat::SummaryValue(sample_point).initial_metadata(None)
281+
});
282+
let mut time_series = TimeSeries::new(metadata);
283+
284+
let mut rsv = StageReservoir::new(self.len as usize);
285+
for i in 0..self.len {
286+
let step = Step(self.step_start.0 + i as i64);
287+
let wall_time =
288+
WallTime::new(f64::from(self.wall_time_start) + (i as f64)).unwrap();
289+
let value = (self.eval)(step);
290+
rsv.offer(step, (wall_time, Ok(ScalarValue(value))));
291+
}
292+
rsv.commit(&mut time_series.basin);
293+
294+
time_series
295+
}
296+
}
297+
}

tensorboard/data/server/server.rs

+39-71
Original file line numberDiff line numberDiff line change
@@ -324,80 +324,22 @@ mod tests {
324324
use super::*;
325325
use tonic::Code;
326326

327-
use crate::commit::{ScalarValue, TimeSeries};
328-
use crate::data_compat;
329-
use crate::proto::tensorboard as pb;
330-
use crate::reservoir::StageReservoir;
331-
use crate::types::{Run, Step, Tag, WallTime};
332-
333-
/// Creates a commit with some test data.
334-
fn sample_commit() -> Commit {
335-
let commit = Commit::new();
336-
337-
let mut runs = commit.runs.write().unwrap();
338-
339-
fn scalar_series(points: Vec<(Step, WallTime, f32)>) -> TimeSeries<ScalarValue> {
340-
use pb::summary::value::Value::SimpleValue;
341-
let mut ts = commit::TimeSeries::new(
342-
data_compat::SummaryValue(Box::new(SimpleValue(0.0))).initial_metadata(None),
343-
);
344-
let mut rsv = StageReservoir::new(points.len());
345-
for (step, wall_time, value) in points {
346-
rsv.offer(step, (wall_time, Ok(commit::ScalarValue(value))));
347-
}
348-
rsv.commit(&mut ts.basin);
349-
ts
350-
}
351-
352-
let mut train = runs
353-
.entry(Run("train".to_string()))
354-
.or_default()
355-
.write()
356-
.unwrap();
357-
train.start_time = Some(WallTime::new(1234.0).unwrap());
358-
train.scalars.insert(
359-
Tag("xent".to_string()),
360-
scalar_series(vec![
361-
(Step(0), WallTime::new(1235.0).unwrap(), 0.5),
362-
(Step(1), WallTime::new(1236.0).unwrap(), 0.25),
363-
(Step(2), WallTime::new(1237.0).unwrap(), 0.125),
364-
]),
365-
);
366-
drop(train);
367-
368-
let mut test = runs
369-
.entry(Run("test".to_string()))
370-
.or_default()
371-
.write()
372-
.unwrap();
373-
test.start_time = Some(WallTime::new(6234.0).unwrap());
374-
test.scalars.insert(
375-
Tag("accuracy".to_string()),
376-
scalar_series(vec![
377-
(Step(0), WallTime::new(6235.0).unwrap(), 0.125),
378-
(Step(1), WallTime::new(6236.0).unwrap(), 0.25),
379-
(Step(2), WallTime::new(6237.0).unwrap(), 0.5),
380-
]),
381-
);
382-
drop(test);
383-
384-
// An run with no start time or data: should not show up in results.
385-
runs.entry(Run("empty".to_string())).or_default();
386-
387-
drop(runs);
388-
commit
389-
}
327+
use crate::commit::test_data::CommitBuilder;
328+
use crate::types::{Run, Tag};
390329

391-
fn sample_handler() -> DataProviderHandler {
330+
fn sample_handler(commit: Commit) -> DataProviderHandler {
392331
DataProviderHandler {
393332
// Leak the commit object, since the Tonic server must have only 'static references.
394-
commit: Box::leak(Box::new(sample_commit())),
333+
commit: Box::leak(Box::new(commit)),
395334
}
396335
}
397336

398337
#[tokio::test]
399338
async fn test_list_plugins() {
400-
let handler = sample_handler();
339+
let commit = CommitBuilder::new()
340+
.scalars("train", "xent", |b| b.build())
341+
.build();
342+
let handler = sample_handler(commit);
401343
let req = Request::new(data::ListPluginsRequest {
402344
experiment_id: "123".to_string(),
403345
});
@@ -410,7 +352,14 @@ mod tests {
410352

411353
#[tokio::test]
412354
async fn test_list_runs() {
413-
let handler = sample_handler();
355+
let commit = CommitBuilder::new()
356+
.run("train", Some(1234.0))
357+
.run("test", Some(6234.0))
358+
.run("run_with_no_data", None)
359+
.scalars("train", "xent", |mut b| b.wall_time_start(1235.0).build())
360+
.scalars("test", "acc", |mut b| b.wall_time_start(6235.0).build())
361+
.build();
362+
let handler = sample_handler(commit);
414363
let req = Request::new(data::ListRunsRequest {
415364
experiment_id: "123".to_string(),
416365
});
@@ -457,7 +406,18 @@ mod tests {
457406

458407
#[tokio::test]
459408
async fn test_list_scalars() {
460-
let handler = sample_handler();
409+
let commit = CommitBuilder::new()
410+
.run("train", Some(1234.0))
411+
.run("test", Some(6234.0))
412+
.run("run_with_no_data", None)
413+
.scalars("train", "xent", |mut b| {
414+
b.wall_time_start(1235.0).len(3).build()
415+
})
416+
.scalars("test", "accuracy", |mut b| {
417+
b.wall_time_start(6235.0).len(3).build()
418+
})
419+
.build();
420+
let handler = sample_handler(commit);
461421
let req = Request::new(data::ListScalarsRequest {
462422
experiment_id: "123".to_string(),
463423
plugin_filter: Some(data::PluginFilter {
@@ -501,7 +461,11 @@ mod tests {
501461

502462
#[tokio::test]
503463
async fn test_read_scalars() {
504-
let handler = sample_handler();
464+
let commit = CommitBuilder::new()
465+
.scalars("train", "xent", |b| b.build())
466+
.scalars("test", "xent", |b| b.build())
467+
.build();
468+
let handler = sample_handler(commit);
505469
let req = Request::new(data::ReadScalarsRequest {
506470
experiment_id: "123".to_string(),
507471
plugin_filter: Some(data::PluginFilter {
@@ -532,7 +496,7 @@ mod tests {
532496

533497
#[tokio::test]
534498
async fn test_read_scalars_needs_downsample() {
535-
let handler = sample_handler();
499+
let handler = sample_handler(Commit::default());
536500
let req = Request::new(data::ReadScalarsRequest {
537501
experiment_id: "123".to_string(),
538502
plugin_filter: Some(data::PluginFilter {
@@ -550,7 +514,11 @@ mod tests {
550514

551515
#[tokio::test]
552516
async fn test_read_scalars_downsample_zero_okay() {
553-
let handler = sample_handler();
517+
let commit = CommitBuilder::new()
518+
.scalars("train", "xent", |b| b.build())
519+
.scalars("test", "xent", |b| b.build())
520+
.build();
521+
let handler = sample_handler(commit);
554522
let req = Request::new(data::ReadScalarsRequest {
555523
experiment_id: "123".to_string(),
556524
plugin_filter: Some(data::PluginFilter {

0 commit comments

Comments
 (0)