Skip to content

rust: implement ListRuns RPC #4385

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 45 commits into from
Nov 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
84d6175
rust: add `data_compat` module
wchargin Nov 17, 2020
1459d88
rust: use `thiserror` for error types
wchargin Nov 17, 2020
f52200f
rust: create `types` module with primitives like `Step`
wchargin Nov 18, 2020
5e16ea3
Merge remote-tracking branches 'origin/wchargin-rust-data-compat', 'o…
wchargin Nov 18, 2020
997ddfa
[update patch]
wchargin Nov 18, 2020
103ccbd
[update diffbase]
wchargin Nov 18, 2020
734d780
[update diffbase]
wchargin Nov 18, 2020
d7fddfa
rust: promote event values to `data_compat`
wchargin Nov 18, 2020
76bcca2
[update diffbase]
wchargin Nov 19, 2020
9ba9726
[update patch]
wchargin Nov 19, 2020
aef800e
[update diffbase]
wchargin Nov 19, 2020
dd3e955
rust: define shared `Commit` type
wchargin Nov 19, 2020
0e1a45b
rust: commit changes after run loading
wchargin Nov 19, 2020
532ea1a
[update patch]
wchargin Nov 19, 2020
c577558
[update diffbase]
wchargin Nov 19, 2020
8b11660
[update diffbase]
wchargin Nov 19, 2020
b8121ac
[update diffbase]
wchargin Nov 19, 2020
f5d7542
[update patch]
wchargin Nov 19, 2020
e7a6f6b
[update diffbase]
wchargin Nov 19, 2020
b41f47f
[update patch]
wchargin Nov 19, 2020
3c7e13c
[update diffbase]
wchargin Nov 19, 2020
587e1af
[update diffbase]
wchargin Nov 19, 2020
ec0cf2c
[update diffbase]
wchargin Nov 19, 2020
d9f9f19
[update diffbase]
wchargin Nov 19, 2020
d39710a
[update patch]
wchargin Nov 19, 2020
5025feb
[update diffbase]
wchargin Nov 21, 2020
6480fdf
[update patch]
wchargin Nov 21, 2020
d1aa9a4
[update diffbase]
wchargin Nov 23, 2020
4c28c91
[update patch]
wchargin Nov 23, 2020
d2b3bf4
[update diffbase]
wchargin Nov 23, 2020
d5f2ae0
[update patch]
wchargin Nov 23, 2020
140aae1
rust: discover event files in log directories
wchargin Nov 24, 2020
f7a7cc0
rust: reload data in log directories
wchargin Nov 24, 2020
6a38bb6
[update diffbase]
wchargin Nov 24, 2020
5bedfd8
[update patch]
wchargin Nov 24, 2020
293f887
[update diffbase]
wchargin Nov 24, 2020
ed39120
rust: implement `ListRuns` RPC
wchargin Nov 24, 2020
cf7d310
[update patch]
wchargin Nov 24, 2020
d96e53d
[update diffbase]
wchargin Nov 24, 2020
1c625f0
[update diffbase]
wchargin Nov 24, 2020
46ac8fd
[update diffbase]
wchargin Nov 24, 2020
34ce4d2
[update patch]
wchargin Nov 24, 2020
8e58aca
[update diffbase]
wchargin Nov 24, 2020
758ae4d
[update diffbase]
wchargin Nov 24, 2020
17eeeb7
[update diffbase]
wchargin Nov 25, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion tensorboard/data/server/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,44 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

use std::path::PathBuf;
use std::time::{Duration, Instant};
use tonic::transport::Server;

use rustboard_core::commit::Commit;
use rustboard_core::logdir::LogdirLoader;
use rustboard_core::proto::tensorboard::data::tensor_board_data_provider_server::TensorBoardDataProviderServer;
use rustboard_core::server::DataProviderHandler;

const RELOAD_INTERVAL: Duration = Duration::from_secs(5);

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let addr = "[::0]:6806".parse::<std::net::SocketAddr>()?;
let handler = DataProviderHandler;
let logdir = match std::env::args_os().nth(1) {
Some(d) => PathBuf::from(d),
None => {
eprintln!("fatal: specify logdir as first command-line argument");
std::process::exit(1);
}
};

// Leak the commit object, since the Tonic server must have only 'static references. This only
// leaks the outer commit structure (of constant size), not the pointers to the actual data.
let commit: &'static Commit = Box::leak(Box::new(Commit::new()));
std::thread::spawn(move || {
let mut loader = LogdirLoader::new(commit, logdir);
loop {
eprintln!("beginning load cycle");
let start = Instant::now();
loader.reload();
let end = Instant::now();
eprintln!("finished load cycle ({:?})", end - start);
std::thread::sleep(RELOAD_INTERVAL);
}
});

let handler = DataProviderHandler { commit };
Server::builder()
.add_service(TensorBoardDataProviderServer::new(handler))
.serve(addr)
Expand Down
3 changes: 2 additions & 1 deletion tensorboard/data/server/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ impl RunLoader {
}

fn commit_all(&mut self, run_data: &RwLock<commit::RunData>) {
let mut run = run_data.write().expect("acquiring run data lock");
let mut run = run_data.write().expect("acquiring tags lock");
run.start_time = self.start_time;
for (tag, ts) in &mut self.time_series {
ts.commit(tag, &mut *run);
}
Expand Down
147 changes: 133 additions & 14 deletions tensorboard/data/server/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,31 @@ limitations under the License.
==============================================================================*/

use futures_core::Stream;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::{RwLock, RwLockReadGuard};
use tonic::{Request, Response, Status};

use crate::commit::{self, Commit};
use crate::proto::tensorboard::data;
use crate::types::{Run, WallTime};
use data::tensor_board_data_provider_server::TensorBoardDataProvider;

/// Data provider gRPC service implementation.
#[derive(Debug)]
pub struct DataProviderHandler;
pub struct DataProviderHandler {
pub commit: &'static Commit,
}

impl DataProviderHandler {
    /// Acquires a read lock on `self.commit.runs`.
    ///
    /// # Errors
    ///
    /// Returns `Status::internal` if the lock is poisoned.
    fn read_runs(&self) -> Result<RwLockReadGuard<HashMap<Run, RwLock<commit::RunData>>>, Status> {
        match self.commit.runs.read() {
            Ok(guard) => Ok(guard),
            Err(_) => Err(Status::internal("failed to read commit.runs")),
        }
    }
}

const FAKE_START_TIME: f64 = 1605752017.0;

Expand All @@ -44,12 +60,33 @@ impl TensorBoardDataProvider for DataProviderHandler {
&self,
_request: Request<data::ListRunsRequest>,
) -> Result<Response<data::ListRunsResponse>, Status> {
let mut res: data::ListRunsResponse = Default::default();
res.runs.push(data::Run {
name: "train".to_string(),
start_time: FAKE_START_TIME,
let runs = self.read_runs()?;

// Buffer up started runs to sort by wall time. Keep `WallTime` rather than projecting down
// to f64 so that we're guaranteed that they're non-NaN and can sort them.
let mut results: Vec<(Run, WallTime)> = Vec::with_capacity(runs.len());
for (run, data) in runs.iter() {
let data = data
.read()
.map_err(|_| Status::internal(format!("failed to read run data for {:?}", run)))?;
if let Some(start_time) = data.start_time {
results.push((run.clone(), start_time));
}
}
results.sort_by_key(|&(_, start_time)| start_time);
drop(runs); // release lock a bit earlier

let res = data::ListRunsResponse {
runs: results
.into_iter()
.map(|(Run(name), start_time)| data::Run {
name,
start_time: start_time.into(),
..Default::default()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I gather the point of this line is to be future-proof in case we add more fields to Run? Is there any mechanism that enforces this by default, or is it just a convention we'll need to uphold ourselves?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah—specifically, so that proto updates that add new fields don’t
introduce compilation errors.

I was considering adding a tiny “style guide” to DEVELOPMENT.md with
things like this, so I’ll probably go ahead and do that.

There is a #[non_exhaustive] attribute that you can add to structs to
indicate that they may have more fields later, preventing callers
outside the crate from initializing them directly but still permitting
them to access public fields. This doesn’t suffice because (a) we’re in
the same crate (though we could change that) and (b) it doesn’t permit
functional record updates (..Default::default()):

error[E0639]: cannot create non-exhaustive struct using struct expression
 --> main.rs:3:13
  |
3 |       let s = sub::S {
  |  _____________^
4 | |         a: 1,
5 | |         b: 2,
6 | |         ..Default::default()
7 | |     };
  | |_____^

The RFC threads have a lot of comments, and from skimming them it’s not
obvious to me whether this was “intended, by design” or merely “artifact
of how #[non_exhaustive] desugars to a hidden private field”. (It was
definitely known at the time.)

Another potential check for this would be in Clippy: “lint if you
construct a value at a type that implements prost::Message without
using a functional record update”. There is no such check currently.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reading through the next PR, it's really rather noisy to have the ..Default::default() everywhere we create a proto, especially in tests :/ Given the fact that we control these protos ourselves, I wonder if on net maybe it's easier for places where we don't expect to add fields to just relax this and accept that we'll need to make modifications to the code if we update the proto?

The other alternative that came to mind would be to initialize structs to default-empty, then mutate to populate fields; needing mutation is a downside although I think in some of the more convoluted multi-layer struct cases this may be significantly more concise.

I can see why neither of these is ideal, it just feels distracting IMO to have every single proto struct need the extra syntax. Would be nice if there was some way to specify on the struct definition that they can be initialized a subset of fields if the rest are Default (basically like optional arguments), or if at least maybe the functional update syntax supported plain .. as sugar for ..Default::default().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just seeing this. For posterity, further discussion at #4399.

})
.collect(),
..Default::default()
});
};
Ok(Response::new(res))
}

Expand Down Expand Up @@ -139,9 +176,80 @@ impl TensorBoardDataProvider for DataProviderHandler {
mod tests {
use super::*;

use crate::commit::{ScalarValue, TimeSeries};
use crate::data_compat;
use crate::proto::tensorboard as pb;
use crate::reservoir::StageReservoir;
use crate::types::{Run, Step, Tag, WallTime};

/// Creates a commit with some test data.
/// Builds a commit with test data: two started runs ("train" and "test"),
/// each holding one scalar time series, plus one empty run that has neither
/// a start time nor data.
fn sample_commit() -> Commit {
    let commit = Commit::new();

    // Builds a committed scalar time series from `(step, wall_time, value)`
    // triples, using a reservoir sized to keep every point.
    fn make_scalar_series(points: Vec<(Step, WallTime, f64)>) -> TimeSeries<ScalarValue> {
        use pb::summary::value::Value::SimpleValue;
        let metadata = data_compat::SummaryValue(Box::new(SimpleValue(0.0))).initial_metadata(None);
        let mut series = commit::TimeSeries::new(metadata);
        let mut reservoir = StageReservoir::new(points.len());
        for (step, wall_time, value) in points {
            reservoir.offer(step, (wall_time, Ok(commit::ScalarValue(value))));
        }
        reservoir.commit(&mut series.basin);
        series
    }

    let mut run_map = commit.runs.write().unwrap();

    // Scoped so the per-run write guard is released before the next entry.
    {
        let mut train = run_map
            .entry(Run("train".to_string()))
            .or_default()
            .write()
            .unwrap();
        train.start_time = Some(WallTime::new(1234.0).unwrap());
        train.scalars.insert(
            Tag("xent".to_string()),
            make_scalar_series(vec![
                (Step(0), WallTime::new(1235.0).unwrap(), 0.5),
                (Step(1), WallTime::new(1236.0).unwrap(), 0.25),
                (Step(2), WallTime::new(1237.0).unwrap(), 0.125),
            ]),
        );
    }

    {
        let mut test = run_map
            .entry(Run("test".to_string()))
            .or_default()
            .write()
            .unwrap();
        test.start_time = Some(WallTime::new(6234.0).unwrap());
        test.scalars.insert(
            Tag("accuracy".to_string()),
            make_scalar_series(vec![
                (Step(0), WallTime::new(6235.0).unwrap(), 0.125),
                (Step(1), WallTime::new(6236.0).unwrap(), 0.25),
                (Step(2), WallTime::new(6237.0).unwrap(), 0.5),
            ]),
        );
    }

    // A run with no start time or data: should not show up in results.
    run_map.entry(Run("empty".to_string())).or_default();

    drop(run_map);
    commit
}

fn sample_handler() -> DataProviderHandler {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll leave it up to you (and would also be fine to do in a follow up or punt for now), but I try to avoid data fixtures unless it's really pretty complicated or expensive to set up the right data. In this case it seems like it wouldn't be so bad to 1) create a commit, 2) explicitly populate it with the data we want via test helpers, and then 3) wrap a handler around it and do the assertions. I just find the assertions to be a lot more meaningful when I can see the data they depend on in the same test method vs having to go dig up the fixture.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay; noted, thanks! I’m certainly sympathetic to this. I think I might
punt it for now since a shared fixture suffices for this set of changes.
But I’m happy to go back and reswizzle things after that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in #4433.

DataProviderHandler {
// Leak the commit object, since the Tonic server must have only 'static references.
commit: Box::leak(Box::new(sample_commit())),
}
}

#[tokio::test]
async fn test_list_plugins() {
let handler = DataProviderHandler;
let handler = sample_handler();
let req = Request::new(data::ListPluginsRequest {
experiment_id: "123".to_string(),
..Default::default()
Expand All @@ -155,21 +263,32 @@ mod tests {

#[tokio::test]
async fn test_list_runs() {
    // `DataProviderHandler` now carries a `commit` reference, so it must be
    // built via the fixture rather than as a unit struct.
    let handler = sample_handler();
    let req = Request::new(data::ListRunsRequest {
        experiment_id: "123".to_string(),
        ..Default::default()
    });
    let res = handler.list_runs(req).await.unwrap().into_inner();
    // Runs are ordered by start time: "train" (1234.0) precedes "test"
    // (6234.0); the "empty" run has no start time and is omitted.
    assert_eq!(
        res.runs,
        vec![
            data::Run {
                name: "train".to_string(),
                start_time: 1234.0,
                ..Default::default()
            },
            data::Run {
                name: "test".to_string(),
                start_time: 6234.0,
                ..Default::default()
            },
        ]
    );
}

#[tokio::test]
async fn test_list_scalars() {
let handler = DataProviderHandler;
let handler = sample_handler();
let req = Request::new(data::ListScalarsRequest {
experiment_id: "123".to_string(),
plugin_filter: Some(data::PluginFilter {
Expand All @@ -186,7 +305,7 @@ mod tests {

#[tokio::test]
async fn test_read_scalars() {
let handler = DataProviderHandler;
let handler = sample_handler();
let req = Request::new(data::ReadScalarsRequest {
experiment_id: "123".to_string(),
plugin_filter: Some(data::PluginFilter {
Expand Down