Skip to content

Commit

Permalink
feat: implement derive macro for stages (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
scarmuega authored May 3, 2023
1 parent 7b2d357 commit 4bb48b5
Show file tree
Hide file tree
Showing 9 changed files with 357 additions and 142 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[workspace]

members = ["gasket"]
members = ["gasket", "gasket-derive"]
9 changes: 9 additions & 0 deletions gasket-derive/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
/target


# Added by cargo
#
# already existing elements were commented out

#/target
Cargo.lock
22 changes: 22 additions & 0 deletions gasket-derive/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
[package]
name = "gasket-derive"
version = "0.3.0"
edition = "2021"
description = "Staged Event-Driven Architecture (SEDA) framework"
repository = "https://github.com/construkts/gasket-rs"
homepage = "https://github.com/construkts/gasket-rs"
documentation = "https://docs.rs/gasket"
license = "Apache-2.0"
readme = "README.md"
authors = ["Santiago Carmuega <[email protected]>"]

[lib]
proc-macro = true

[dependencies]
proc-macro2 = "1.0.56"
quote = "^1"
syn = "^2"

[dev-dependencies]
gasket = { version = "0.3.0", path = "../gasket" }
171 changes: 171 additions & 0 deletions gasket-derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
use proc_macro2::TokenStream;
use quote::quote;
use syn::parse::ParseStream;
use syn::{parse_macro_input, Attribute, DataStruct, DeriveInput, Expr, Lit, MetaNameValue, Token};

use syn::Field;

fn expect_struct(input: &DeriveInput) -> syn::Result<&DataStruct> {
match &input.data {
syn::Data::Struct(x) => Ok(x),
_ => {
let err = syn::Error::new_spanned(&input.ident, "you need to derive from a struct");
Err(err)
}
}
}

fn has_attribute(attrs: &[Attribute], attr_name: &str) -> bool {
attrs
.iter()
.any(|attr| attr.meta.path().is_ident(attr_name))
}

fn fields_with_attribute<'a>(data: &'a DataStruct, attr_name: &'static str) -> Vec<&'a Field> {
data.fields
.iter()
.filter(|field| has_attribute(&field.attrs, attr_name))
.collect()
}

struct StageArgs {
name: Option<String>,
unit: syn::Type,
worker: syn::Type,
}

impl syn::parse::Parse for StageArgs {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut name = None;
let mut unit = None;
let mut worker = None;

while !input.is_empty() {
let name_value: MetaNameValue = input.parse()?;
let MetaNameValue { path, value, .. } = name_value;

if path.is_ident("name") {
if let Expr::Lit(expr) = value {
if let Lit::Str(x) = expr.lit {
name = Some(x.value());
}
}
} else if path.is_ident("unit") {
if let Expr::Lit(expr) = value {
if let Lit::Str(x) = expr.lit {
unit = Some(x.parse()?);
}
}
} else if path.is_ident("worker") {
if let Expr::Lit(expr) = value {
if let Lit::Str(x) = expr.lit {
worker = Some(x.parse()?);
}
}
} else {
return Err(input.error("Unexpected argument"));
}

// Ignore commas between arguments
let _ = input.parse::<Option<Token![,]>>();
}

Ok(Self {
name,
unit: unit.ok_or_else(|| input.error("Missing unit type"))?,
worker: worker.ok_or_else(|| input.error("Missing worker type"))?,
})
}
}

fn matches_type(type_path: &syn::TypePath, target_type: &str) -> bool {
type_path
.path
.segments
.last()
.map_or(false, |segment| segment.ident == target_type)
}

fn expand_metric_registration(struct_: &DataStruct) -> syn::Result<Vec<TokenStream>> {
let metrics_code: syn::Result<Vec<_>> = fields_with_attribute(struct_, "metric")
.iter()
.map(|field| {
let field_name = field.ident.as_ref().unwrap();

match &field.ty {
syn::Type::Path(x) if matches_type(x, "Counter") => {
let q = quote! {
registry.track_counter(stringify!(#field_name), &self.#field_name);
};

Ok(q)
}
syn::Type::Path(x) if matches_type(x, "Gauge") => {
let q = quote! {
registry.track_gauge(stringify!(#field_name), &self.#field_name);
};

Ok(q)
}
_ => Err(syn::Error::new_spanned(
field,
"unknown return type for metric",
)),
}
})
.collect();

metrics_code
}

fn expand_stage_impl(input: DeriveInput) -> TokenStream {
let struct_ = match expect_struct(&input) {
Ok(x) => x,
Err(err) => return TokenStream::from(err.to_compile_error()),
};

let metrics_code = match expand_metric_registration(struct_) {
Ok(x) => x,
Err(err) => return err.to_compile_error(),
};

let stage_args: StageArgs = input
.attrs
.iter()
.find(|a| a.meta.path().is_ident("stage"))
.unwrap()
.parse_args()
.unwrap();

let name = input.ident;

let hri = stage_args.name.unwrap_or_else(|| stringify!(name).into());
let unit_type = stage_args.unit;
let worker_type = stage_args.worker;

quote! {
impl gasket::framework::Stage for #name {
type Unit = #unit_type;
type Worker = #worker_type;

fn name(&self) -> &str {
#hri
}

fn metrics(&self) -> gasket::metrics::Registry {
let mut registry = gasket::metrics::Registry::default();

#(#metrics_code)*

registry
}
}
}
}

#[proc_macro_derive(Stage, attributes(stage, metric))]
pub fn stage_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let expanded = expand_stage_impl(input);
proc_macro::TokenStream::from(expanded)
}
4 changes: 4 additions & 0 deletions gasket/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ serde = { version = "1.0.160", features = ["derive"] }
thiserror = "1.0.30"
tokio = { version = "1", features = ["rt", "time", "sync", "macros"] }
tracing = "0.1.37"
gasket-derive = { version = "*", path = "../gasket-derive", optional = true }

[dev-dependencies]
tracing-subscriber = "0.3.16"

[features]
derive = ["gasket-derive"]
78 changes: 40 additions & 38 deletions gasket/examples/dumb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,17 @@ struct TickerSpec {
}

impl Stage for TickerSpec {
type Unit = TickerUnit;
type Worker = Ticker;

fn name(&self) -> &str {
"ticker"
}

fn policy(&self) -> gasket::runtime::Policy {
Policy {
tick_timeout: Some(Duration::from_secs(3)),
bootstrap_retry: retries::Policy::no_retry(),
work_retry: retries::Policy::no_retry(),
teardown_retry: retries::Policy::no_retry(),
}
}

fn register_metrics(&self, registry: &mut gasket::metrics::Registry) {
fn metrics(&self) -> gasket::metrics::Registry {
let mut registry = gasket::metrics::Registry::default();
registry.track_counter("value_1", &self.value_1);
registry
}
}

Expand All @@ -43,20 +39,17 @@ struct TickerUnit {
}

#[async_trait::async_trait(?Send)]
impl Worker for Ticker {
type Unit = TickerUnit;
type Stage = TickerSpec;

async fn bootstrap(_: &Self::Stage) -> Result<Self, WorkerError> {
impl Worker<TickerSpec> for Ticker {
async fn bootstrap(_: &TickerSpec) -> Result<Self, WorkerError> {
Ok(Self {
next_delay: Default::default(),
})
}

async fn schedule(
&mut self,
_: &mut Self::Stage,
) -> Result<WorkSchedule<Self::Unit>, WorkerError> {
_: &mut TickerSpec,
) -> Result<WorkSchedule<TickerUnit>, WorkerError> {
let unit = TickerUnit {
instant: Instant::now(),
delay: self.next_delay,
Expand All @@ -67,8 +60,8 @@ impl Worker for Ticker {

async fn execute(
&mut self,
unit: &Self::Unit,
stage: &mut Self::Stage,
unit: &TickerUnit,
stage: &mut TickerSpec,
) -> Result<(), WorkerError> {
tokio::time::sleep(Duration::from_secs(unit.delay)).await;
stage.output.send(unit.instant.into()).await.or_panic()?;
Expand All @@ -85,42 +78,35 @@ struct TerminalSpec {
}

impl Stage for TerminalSpec {
type Unit = Instant;
type Worker = Terminal;

fn name(&self) -> &str {
"terminal"
}

fn policy(&self) -> gasket::runtime::Policy {
Policy {
tick_timeout: None,
bootstrap_retry: retries::Policy::no_retry(),
work_retry: retries::Policy::no_retry(),
teardown_retry: retries::Policy::no_retry(),
}
fn metrics(&self) -> gasket::metrics::Registry {
gasket::metrics::Registry::default()
}

fn register_metrics(&self, _: &mut gasket::metrics::Registry) {}
}

struct Terminal;

#[async_trait::async_trait(?Send)]
impl Worker for Terminal {
type Unit = Instant;
type Stage = TerminalSpec;

async fn bootstrap(_: &Self::Stage) -> Result<Self, WorkerError> {
impl Worker<TerminalSpec> for Terminal {
async fn bootstrap(_: &TerminalSpec) -> Result<Self, WorkerError> {
Ok(Self)
}

async fn schedule(
&mut self,
stage: &mut Self::Stage,
) -> Result<WorkSchedule<Self::Unit>, WorkerError> {
stage: &mut TerminalSpec,
) -> Result<WorkSchedule<Instant>, WorkerError> {
let msg = stage.input.recv().await.or_panic()?;
Ok(WorkSchedule::Unit(msg.payload))
}

async fn execute(&mut self, unit: &Self::Unit, _: &mut Self::Stage) -> Result<(), WorkerError> {
async fn execute(&mut self, unit: &Instant, _: &mut TerminalSpec) -> Result<(), WorkerError> {
println!("{:?}", unit.elapsed());

Ok(())
Expand All @@ -146,9 +132,25 @@ fn main() {

connect_ports(&mut ticker.output, &mut terminal.input, 10);

let tether1 = spawn_stage::<Ticker>(ticker);
let tether1 = spawn_stage(
ticker,
Policy {
tick_timeout: Some(Duration::from_secs(3)),
bootstrap_retry: retries::Policy::no_retry(),
work_retry: retries::Policy::no_retry(),
teardown_retry: retries::Policy::no_retry(),
},
);

let tether2 = spawn_stage::<Terminal>(terminal);
let tether2 = spawn_stage(
terminal,
Policy {
tick_timeout: None,
bootstrap_retry: retries::Policy::no_retry(),
work_retry: retries::Policy::no_retry(),
teardown_retry: retries::Policy::no_retry(),
},
);

let tethers = vec![tether1, tether2];

Expand Down
Loading

0 comments on commit 4bb48b5

Please sign in to comment.