Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migration enums #312

Merged
merged 17 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ jobs:
- run: rustup self update
- run: cd refinery_core && cargo test --all-features -- --test-threads 1
- run: cd refinery && cargo build --all-features
- run: cd refinery_macros && cargo test
- run: cd refinery_macros && cargo test --features=enums
- run: cd refinery_cli && cargo test

test-sqlite:
Expand Down
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@ members = [
"refinery",
"refinery_cli",
"refinery_core",
"refinery_macros"
"refinery_macros",
"examples"
]
19 changes: 19 additions & 0 deletions examples/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
[package]
name = "refinery-examples"
version = "0.8.12"
authors = ["Katharina Fey <[email protected]>", "João Oliveira <[email protected]>"]
description = "Minimal Refinery usage example"
license = "MIT OR Apache-2.0"
documentation = "https://docs.rs/refinery/"
repository = "https://github.com/rust-db/refinery"
edition = "2021"

[features]
enums = ["refinery/enums"]

[dependencies]
refinery = { path = "../refinery", features = ["rusqlite"] }
rusqlite = "0.29"
barrel = { version = "0.7", features = ["sqlite3"] }
log = "0.4"
env_logger = "0.11"
18 changes: 0 additions & 18 deletions examples/main.rs

This file was deleted.

42 changes: 42 additions & 0 deletions examples/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
use barrel::backend::Sqlite as Sql;
use log::info;
use refinery::Migration;
use rusqlite::Connection;

refinery::embed_migrations!("migrations");

fn main() {
env_logger::init();

let mut conn = Connection::open_in_memory().unwrap();

let use_iteration = std::env::args().any(|a| a.to_lowercase().eq("--iterate"));

if use_iteration {
// create an iterator over migrations as they run
for migration in migrations::runner().run_iter(&mut conn) {
process_migration(migration.expect("Migration failed!"));
}
} else {
// or run all migrations in one go
migrations::runner().run(&mut conn).unwrap();
}
}

fn process_migration(migration: Migration) {
#[cfg(not(feature = "enums"))]
{
// run something after each migration
info!("Post-processing a migration: {}", migration)
}

#[cfg(feature = "enums")]
{
// or with the `enums` feature enabled, match against migrations to run specific post-migration steps
use migrations::EmbeddedMigration;
match migration.into() {
EmbeddedMigration::Initial(m) => info!("V{}: Initialized the database!", m.version()),
m => info!("Got a migration: {:?}", m),
}
}
}
1 change: 1 addition & 0 deletions refinery/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ tiberius = ["refinery-core/tiberius"]
tiberius-config = ["refinery-core/tiberius", "refinery-core/tiberius-config"]
serde = ["refinery-core/serde"]
toml = ["refinery-core/toml"]
enums = ["refinery-macros/enums"]

[dependencies]
refinery-core = { version = "0.8.12", path = "../refinery_core" }
Expand Down
4 changes: 3 additions & 1 deletion refinery_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ pub use crate::error::Error;
pub use crate::runner::{Migration, Report, Runner, Target};
pub use crate::traits::r#async::AsyncMigrate;
pub use crate::traits::sync::Migrate;
pub use crate::util::{find_migration_files, load_sql_migrations, MigrationType};
pub use crate::util::{
find_migration_files, load_sql_migrations, parse_migration_name, MigrationType,
};

#[cfg(feature = "rusqlite")]
pub use rusqlite;
Expand Down
25 changes: 2 additions & 23 deletions refinery_core/src/runner.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use regex::Regex;
use siphasher::sip::SipHasher13;
use time::OffsetDateTime;

Expand All @@ -7,19 +6,12 @@ use std::cmp::Ordering;
use std::collections::VecDeque;
use std::fmt;
use std::hash::{Hash, Hasher};
use std::sync::OnceLock;

use crate::error::Kind;
use crate::traits::{sync::migrate as sync_migrate, DEFAULT_MIGRATION_TABLE_NAME};
use crate::util::parse_migration_name;
use crate::{AsyncMigrate, Error, Migrate};
use std::fmt::Formatter;

// regex used to match file names
pub fn file_match_re() -> &'static Regex {
static RE: OnceLock<regex::Regex> = OnceLock::new();
RE.get_or_init(|| Regex::new(r"^([U|V])(\d+(?:\.\d+)?)__(\w+)").unwrap())
}

/// An enum set that represents the type of the Migration
#[derive(Clone, PartialEq)]
pub enum Type {
Expand Down Expand Up @@ -84,20 +76,7 @@ impl Migration {
/// Create an unapplied migration, name and version are parsed from the input_name,
/// which must be named in the format (U|V){1}__{2}.rs where {1} represents the migration version and {2} the name.
pub fn unapplied(input_name: &str, sql: &str) -> Result<Migration, Error> {
let captures = file_match_re()
.captures(input_name)
.filter(|caps| caps.len() == 4)
.ok_or_else(|| Error::new(Kind::InvalidName, None))?;
let version: i32 = captures[2]
.parse()
.map_err(|_| Error::new(Kind::InvalidVersion, None))?;

let name: String = (&captures[3]).into();
let prefix = match &captures[1] {
"V" => Type::Versioned,
"U" => Type::Unversioned,
_ => unreachable!(),
};
let (prefix, version, name) = parse_migration_name(input_name)?;

// Previously, `std::collections::hash_map::DefaultHasher` was used
// to calculate the checksum and the implementation at that time
Expand Down
54 changes: 47 additions & 7 deletions refinery_core/src/util.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,32 @@
use crate::error::{Error, Kind};
use crate::runner::Type;
use crate::Migration;
use regex::Regex;
use std::ffi::OsStr;
use std::path::{Path, PathBuf};
use std::sync::OnceLock;
use walkdir::{DirEntry, WalkDir};

const STEM_RE: &'static str = r"^([U|V])(\d+(?:\.\d+)?)__(\w+)";

/// Matches the stem of a migration file.
fn file_stem_re() -> &'static Regex {
static RE: OnceLock<Regex> = OnceLock::new();
RE.get_or_init(|| Regex::new(STEM_RE).unwrap())
}

/// Matches the stem + extension of a SQL migration file.
fn file_re_sql() -> &'static Regex {
jxs marked this conversation as resolved.
Show resolved Hide resolved
static RE: OnceLock<Regex> = OnceLock::new();
RE.get_or_init(|| Regex::new([STEM_RE, r"\.sql$"].concat().as_str()).unwrap())
}

/// Matches the stem + extension of any migration file.
fn file_re_all() -> &'static Regex {
static RE: OnceLock<Regex> = OnceLock::new();
RE.get_or_init(|| Regex::new([STEM_RE, r"\.(rs|sql)$"].concat().as_str()).unwrap())
}

/// enum containing the migration types used to search for migrations
/// either just .sql files or both .sql and .rs
pub enum MigrationType {
Expand All @@ -13,16 +35,34 @@ pub enum MigrationType {
}

impl MigrationType {
fn file_match_re(&self) -> Regex {
let ext = match self {
MigrationType::All => "(rs|sql)",
MigrationType::Sql => "sql",
};
let re_str = format!(r"^(U|V)(\d+(?:\.\d+)?)__(\w+)\.{}$", ext);
Regex::new(re_str.as_str()).unwrap()
fn file_match_re(&self) -> &'static Regex {
match self {
MigrationType::All => file_re_all(),
MigrationType::Sql => file_re_sql(),
}
}
}

/// Parse a migration filename stem into a prefix, version, and name.
pub fn parse_migration_name(name: &str) -> Result<(Type, i32, String), Error> {
let captures = file_stem_re()
.captures(name)
.filter(|caps| caps.len() == 4)
.ok_or_else(|| Error::new(Kind::InvalidName, None))?;
let version: i32 = captures[2]
.parse()
.map_err(|_| Error::new(Kind::InvalidVersion, None))?;

let name: String = (&captures[3]).into();
let prefix = match &captures[1] {
"V" => Type::Versioned,
"U" => Type::Unversioned,
_ => unreachable!(),
};

Ok((prefix, version, name))
}

/// find migrations on file system recursively across directories given a location and [MigrationType]
pub fn find_migration_files(
location: impl AsRef<Path>,
Expand Down
4 changes: 4 additions & 0 deletions refinery_macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ documentation = "https://docs.rs/refinery/"
repository = "https://github.com/rust-db/refinery"
edition = "2018"

[features]
enums = []

[lib]
proc-macro = true

Expand All @@ -17,6 +20,7 @@ quote = "1"
syn = "2"
proc-macro2 = "1"
regex = "1"
heck = "0.4"

[dev-dependencies]
tempfile = "3"
62 changes: 62 additions & 0 deletions refinery_macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Contains Refinery macros that are used to import and embed migration files.
#![recursion_limit = "128"]

use heck::ToUpperCamelCase;
use proc_macro::TokenStream;
use proc_macro2::{Span as Span2, TokenStream as TokenStream2};
use quote::quote;
Expand Down Expand Up @@ -31,6 +32,42 @@ fn migration_fn_quoted<T: ToTokens>(_migrations: Vec<T>) -> TokenStream2 {
result
}

fn migration_enum_quoted(migration_names: &[impl AsRef<str>]) -> TokenStream2 {
if cfg!(feature = "enums") {
let mut variants = Vec::new();
let mut discriminants = Vec::new();

for m in migration_names {
let m = m.as_ref();
let (_, version, name) = refinery_core::parse_migration_name(m)
.unwrap_or_else(|e| panic!("Couldn't parse migration filename '{}': {:?}", m, e));
let variant = Ident::new(name.to_upper_camel_case().as_str(), Span2::call_site());
variants.push(quote! { #variant(Migration) = #version });
discriminants.push(quote! { #version => Self::#variant(migration) });
}
discriminants.push(quote! { v => panic!("Invalid migration version '{}'", v) });

let result = quote! {
#[repr(i32)]
#[derive(Debug)]
pub enum EmbeddedMigration {
#(#variants),*
}

impl From<Migration> for EmbeddedMigration {
fn from(migration: Migration) -> Self {
match migration.version() as i32 {
#(#discriminants),*
}
}
}
};
result
} else {
quote!()
}
}

/// Interpret Rust or SQL migrations and inserts a function called runner that when called returns a [`Runner`] instance with the collected migration modules.
///
/// When called without arguments `embed_migrations` searches for migration files on a directory called `migrations` at the root level of your crate.
Expand All @@ -56,6 +93,7 @@ pub fn embed_migrations(input: TokenStream) -> TokenStream {

let mut migrations_mods = Vec::new();
let mut _migrations = Vec::new();
let mut migration_filenames = Vec::new();

for migration in migration_files {
// safe to call unwrap as find_migration_filenames returns canonical paths
Expand All @@ -65,6 +103,7 @@ pub fn embed_migrations(input: TokenStream) -> TokenStream {
.unwrap();
let path = migration.display().to_string();
let extension = migration.extension().unwrap();
migration_filenames.push(filename.clone());

if extension == "sql" {
_migrations.push(quote! {(#filename, include_str!(#path).to_string())});
Expand All @@ -85,10 +124,12 @@ pub fn embed_migrations(input: TokenStream) -> TokenStream {
}

let fnq = migration_fn_quoted(_migrations);
let enums = migration_enum_quoted(migration_filenames.as_slice());
(quote! {
pub mod migrations {
#(#migrations_mods)*
#fnq
#enums
}
})
.into()
Expand All @@ -98,6 +139,27 @@ pub fn embed_migrations(input: TokenStream) -> TokenStream {
mod tests {
use super::{migration_fn_quoted, quote};

#[test]
#[cfg(feature = "enums")]
fn test_enum_fn() {
let expected = concat! {
"# [repr (i32)] # [derive (Debug)] ",
"pub enum EmbeddedMigration { ",
"Foo (Migration) = 1i32 , ",
"BarBaz (Migration) = 3i32 ",
"} ",
"impl From < Migration > for EmbeddedMigration { ",
"fn from (migration : Migration) -> Self { ",
"match migration . version () as i32 { ",
"1i32 => Self :: Foo (migration) , ",
"3i32 => Self :: BarBaz (migration) , ",
"v => panic ! (\"Invalid migration version '{}'\" , v) ",
"} } }"
};
let enums = super::migration_enum_quoted(&["V1__foo", "U3__barBAZ"]).to_string();
assert_eq!(expected, enums);
}

#[test]
fn test_quote_fn() {
let migs = vec![quote!("V1__first", "valid_sql_file")];
Expand Down
Loading