diff --git a/tensorboard/data/server/commit.rs b/tensorboard/data/server/commit.rs index dd12b82d5c..fb8574fed9 100644 --- a/tensorboard/data/server/commit.rs +++ b/tensorboard/data/server/commit.rs @@ -136,3 +136,166 @@ mod tests { ); } } + +/// Utilities for constructing commits with test data. +// +// Not `#[cfg(test)]` because we have a doctest: +// +pub mod test_data { + use super::*; + use crate::data_compat; + use crate::reservoir::StageReservoir; + + #[derive(Default)] + pub struct CommitBuilder(Commit); + + impl CommitBuilder { + /// Creates a new builder for an empty commit. + pub fn new() -> Self { + Self::default() + } + + /// Ensures that data for a run exists, and update it. + /// + /// This takes a callback because we can't just return a mutable reference, since there are + /// nested `RwLockWriteGuard`s. + fn with_run_data(&self, run: Run, update: impl FnOnce(&mut RunData)) { + let mut runs = self.0.runs.write().expect("runs.write"); + let mut run_data = runs + .entry(run) + .or_default() + .write() + .expect("runs[run].write"); + update(&mut run_data); + } + + /// Adds a scalar time series, creating the run if it doesn't exist, and setting its start + /// time if unset. + /// + /// # Examples + /// + /// ``` + /// use rustboard_core::commit::{test_data::CommitBuilder, Commit}; + /// use rustboard_core::types::Step; + /// + /// let my_commit: Commit = CommitBuilder::new() + /// .scalars("train", "xent", |mut b| { + /// b.eval(|Step(i)| 1.0 / (i + 1) as f32).build() + /// }) + /// .build(); + /// ``` + pub fn scalars( + self, + run: &str, + tag: &str, + build: impl FnOnce(ScalarTimeSeriesBuilder) -> TimeSeries, + ) -> Self { + self.with_run_data(Run(run.to_string()), |run_data| { + let time_series = build(ScalarTimeSeriesBuilder::default()); + if let (None, Some((_step, wall_time, _value))) = + (run_data.start_time, time_series.valid_values().next()) + { + run_data.start_time = Some(wall_time); + } + run_data.scalars.insert(Tag(tag.to_string()), time_series); + }); + self + } + + /// Ensures that a run is present and sets its start time. + /// + /// If you don't care about the start time and the run is going to have data, anyway, you + /// can just add the data directly, which will create the run as a side effect. + /// + /// # Panics + /// + /// If `start_time` represents an invalid wall time (i.e., `start_time` is `Some(wt)` but + /// `WallTime::new(wt)` is `None`). + pub fn run(self, run: &str, start_time: Option) -> Self { + self.with_run_data(Run(run.to_string()), |run_data| { + run_data.start_time = start_time.map(|f| WallTime::new(f).unwrap()); + }); + self + } + + /// Consumes this `CommitBuilder` and returns the underlying commit. + pub fn build(self) -> Commit { + self.0 + } + } + + pub struct ScalarTimeSeriesBuilder { + /// Initial step. Increments by `1` for each point. + step_start: Step, + /// Initial wall time. Increments by `1.0` for each point. + wall_time_start: WallTime, + /// Number of points in this time series. + len: u64, + /// Custom summary metadata. Leave `None` to use default. + metadata: Option>, + /// Scalar evaluation function, called for each point in the series. + /// + /// By default, this maps every step to `0.0`. + eval: Box f32>, + } + + impl Default for ScalarTimeSeriesBuilder { + fn default() -> Self { + ScalarTimeSeriesBuilder { + step_start: Step(0), + wall_time_start: WallTime::new(0.0).unwrap(), + len: 1, + metadata: None, + eval: Box::new(|_| 0.0), + } + } + } + + impl ScalarTimeSeriesBuilder { + pub fn step_start(&mut self, raw_step: i64) -> &mut Self { + self.step_start = Step(raw_step); + self + } + pub fn wall_time_start(&mut self, raw_wall_time: f64) -> &mut Self { + self.wall_time_start = WallTime::new(raw_wall_time).unwrap(); + self + } + pub fn len(&mut self, len: u64) -> &mut Self { + self.len = len; + self + } + pub fn metadata(&mut self, metadata: Option>) -> &mut Self { + self.metadata = metadata; + self + } + pub fn eval(&mut self, eval: impl Fn(Step) -> f32 + 'static) -> &mut Self { + self.eval = Box::new(eval); + self + } + + /// Constructs a scalar time series from the state of this builder. + /// + /// # Panics + /// + /// If the wall time of a point would overflow to be infinite. + pub fn build(&self) -> TimeSeries { + let metadata = self.metadata.clone().unwrap_or_else(|| { + let sample_point = Box::new(pb::summary::value::Value::SimpleValue(0.0)); + data_compat::SummaryValue(sample_point).initial_metadata(None) + }); + let mut time_series = TimeSeries::new(metadata); + + let mut rsv = StageReservoir::new(self.len as usize); + for i in 0..self.len { + let step = Step(self.step_start.0 + i as i64); + let wall_time = + WallTime::new(f64::from(self.wall_time_start) + (i as f64)).unwrap(); + let value = (self.eval)(step); + rsv.offer(step, (wall_time, Ok(ScalarValue(value)))); + } + rsv.commit(&mut time_series.basin); + + time_series + } + } +} diff --git a/tensorboard/data/server/server.rs b/tensorboard/data/server/server.rs index 2ad5e0bd9e..c62d8e18a3 100644 --- a/tensorboard/data/server/server.rs +++ b/tensorboard/data/server/server.rs @@ -324,80 +324,22 @@ mod tests { use super::*; use tonic::Code; - use crate::commit::{ScalarValue, TimeSeries}; - use crate::data_compat; - use crate::proto::tensorboard as pb; - use crate::reservoir::StageReservoir; - use crate::types::{Run, Step, Tag, WallTime}; - - /// Creates a commit with some test data. - fn sample_commit() -> Commit { - let commit = Commit::new(); - - let mut runs = commit.runs.write().unwrap(); - - fn scalar_series(points: Vec<(Step, WallTime, f32)>) -> TimeSeries { - use pb::summary::value::Value::SimpleValue; - let mut ts = commit::TimeSeries::new( - data_compat::SummaryValue(Box::new(SimpleValue(0.0))).initial_metadata(None), - ); - let mut rsv = StageReservoir::new(points.len()); - for (step, wall_time, value) in points { - rsv.offer(step, (wall_time, Ok(commit::ScalarValue(value)))); - } - rsv.commit(&mut ts.basin); - ts - } - - let mut train = runs - .entry(Run("train".to_string())) - .or_default() - .write() - .unwrap(); - train.start_time = Some(WallTime::new(1234.0).unwrap()); - train.scalars.insert( - Tag("xent".to_string()), - scalar_series(vec![ - (Step(0), WallTime::new(1235.0).unwrap(), 0.5), - (Step(1), WallTime::new(1236.0).unwrap(), 0.25), - (Step(2), WallTime::new(1237.0).unwrap(), 0.125), - ]), - ); - drop(train); - - let mut test = runs - .entry(Run("test".to_string())) - .or_default() - .write() - .unwrap(); - test.start_time = Some(WallTime::new(6234.0).unwrap()); - test.scalars.insert( - Tag("accuracy".to_string()), - scalar_series(vec![ - (Step(0), WallTime::new(6235.0).unwrap(), 0.125), - (Step(1), WallTime::new(6236.0).unwrap(), 0.25), - (Step(2), WallTime::new(6237.0).unwrap(), 0.5), - ]), - ); - drop(test); - - // An run with no start time or data: should not show up in results. - runs.entry(Run("empty".to_string())).or_default(); - - drop(runs); - commit - } + use crate::commit::test_data::CommitBuilder; + use crate::types::{Run, Step, Tag}; - fn sample_handler() -> DataProviderHandler { + fn sample_handler(commit: Commit) -> DataProviderHandler { DataProviderHandler { // Leak the commit object, since the Tonic server must have only 'static references. - commit: Box::leak(Box::new(sample_commit())), + commit: Box::leak(Box::new(commit)), } } #[tokio::test] async fn test_list_plugins() { - let handler = sample_handler(); + let commit = CommitBuilder::new() + .scalars("train", "xent", |b| b.build()) + .build(); + let handler = sample_handler(commit); let req = Request::new(data::ListPluginsRequest { experiment_id: "123".to_string(), }); @@ -410,7 +352,14 @@ mod tests { #[tokio::test] async fn test_list_runs() { - let handler = sample_handler(); + let commit = CommitBuilder::new() + .run("train", Some(1234.0)) + .run("test", Some(6234.0)) + .run("run_with_no_data", None) + .scalars("train", "xent", |mut b| b.wall_time_start(1235.0).build()) + .scalars("test", "acc", |mut b| b.wall_time_start(6235.0).build()) + .build(); + let handler = sample_handler(commit); let req = Request::new(data::ListRunsRequest { experiment_id: "123".to_string(), }); @@ -457,7 +406,18 @@ mod tests { #[tokio::test] async fn test_list_scalars() { - let handler = sample_handler(); + let commit = CommitBuilder::new() + .run("train", Some(1234.0)) + .run("test", Some(6234.0)) + .run("run_with_no_data", None) + .scalars("train", "xent", |mut b| { + b.wall_time_start(1235.0).step_start(0).len(3).build() + }) + .scalars("test", "accuracy", |mut b| { + b.wall_time_start(6235.0).step_start(0).len(3).build() + }) + .build(); + let handler = sample_handler(commit); let req = Request::new(data::ListScalarsRequest { experiment_id: "123".to_string(), plugin_filter: Some(data::PluginFilter { @@ -501,7 +461,17 @@ mod tests { #[tokio::test] async fn test_read_scalars() { - let handler = sample_handler(); + let commit = CommitBuilder::new() + .scalars("train", "xent", |mut b| { + b.len(3) + .wall_time_start(1235.0) + .step_start(0) + .eval(|Step(i)| 0.5f32.powi(i as i32)) + .build() + }) + .scalars("test", "xent", |b| b.build()) + .build(); + let handler = sample_handler(commit); let req = Request::new(data::ReadScalarsRequest { experiment_id: "123".to_string(), plugin_filter: Some(data::PluginFilter { @@ -527,12 +497,12 @@ mod tests { let xent_data = &train_run[&Tag("xent".to_string())].data.as_ref().unwrap(); assert_eq!(xent_data.step, vec![0, 1, 2]); assert_eq!(xent_data.wall_time, vec![1235.0, 1236.0, 1237.0]); - assert_eq!(xent_data.value, vec![0.5, 0.25, 0.125]); + assert_eq!(xent_data.value, vec![1.0, 0.5, 0.25]); } #[tokio::test] async fn test_read_scalars_needs_downsample() { - let handler = sample_handler(); + let handler = sample_handler(Commit::default()); let req = Request::new(data::ReadScalarsRequest { experiment_id: "123".to_string(), plugin_filter: Some(data::PluginFilter { @@ -550,7 +520,11 @@ mod tests { #[tokio::test] async fn test_read_scalars_downsample_zero_okay() { - let handler = sample_handler(); + let commit = CommitBuilder::new() + .scalars("train", "xent", |b| b.build()) + .scalars("test", "xent", |b| b.build()) + .build(); + let handler = sample_handler(commit); let req = Request::new(data::ReadScalarsRequest { experiment_id: "123".to_string(), plugin_filter: Some(data::PluginFilter {