#![cfg_attr(docsrs, feature(doc_auto_cfg))]
#![doc = include_str!(concat!(env!("OUT_DIR"), "/readme_for_rustdoc.md"))]
use log::{debug, info, trace, warn};
use rusqlite::{Connection, Transaction};
#[cfg(feature = "from-directory")]
use include_dir::Dir;
#[cfg(feature = "from-directory")]
mod loader;
#[cfg(feature = "from-directory")]
use loader::from_directory;
#[cfg(feature = "from-directory")]
mod builder;
#[cfg(feature = "from-directory")]
pub use builder::MigrationsBuilder;
#[cfg(feature = "alpha-async-tokio-rusqlite")]
mod asynch;
mod errors;
#[cfg(test)]
mod tests;
#[cfg(feature = "alpha-async-tokio-rusqlite")]
pub use asynch::AsyncMigrations;
pub use errors::{
Error, ForeignKeyCheckError, HookError, HookResult, MigrationDefinitionError, Result,
SchemaVersionError,
};
use std::{
cmp::{self, Ordering},
fmt::{self, Debug},
iter::FromIterator,
num::NonZeroUsize,
ptr::addr_of,
};
/// A hook executed inside a migration's transaction, alongside the migration's SQL.
///
/// Any `'static` closure or function that is `Clone + Send + Sync` and takes a
/// `&Transaction`, returning a [`HookResult`], implements this trait through a
/// blanket implementation in this module.
pub trait MigrationHook: Fn(&Transaction) -> HookResult + Send + Sync {
    /// Clone this hook into a fresh boxed trait object.
    ///
    /// This is what makes `Box<dyn MigrationHook>` (and therefore [`M`]) `Clone`.
    fn clone_box(&self) -> Box<dyn MigrationHook>;
}
/// Every cloneable, thread-safe, `'static` closure over `&Transaction` is a hook.
impl<T> MigrationHook for T
where
    T: 'static + Clone + Send + Sync + Fn(&Transaction) -> HookResult,
{
    /// Produce an owned duplicate of the closure and erase its type behind a box.
    fn clone_box(&self) -> Box<dyn MigrationHook> {
        let duplicate: T = T::clone(self);
        Box::new(duplicate)
    }
}
impl Debug for Box<dyn MigrationHook> {
    /// Closures have no meaningful textual form, so a fixed placeholder is printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("MigrationHook(<closure>)")
    }
}
impl Clone for Box<dyn MigrationHook> {
fn clone(&self) -> Self {
(**self).clone_box()
}
}
/// One migration step: SQL to apply, optional SQL to revert it, and optional hooks.
#[derive(Debug, Clone)]
#[must_use]
pub struct M<'u> {
    /// SQL executed as a batch when migrating up through this step.
    up: &'u str,
    /// Optional hook run in the same transaction after `up` (and after the optional
    /// foreign key check).
    up_hook: Option<Box<dyn MigrationHook>>,
    /// SQL executed when reverting this step; `None` means the step cannot be reverted.
    down: Option<&'u str>,
    /// Optional hook run in the same transaction *before* `down` is executed.
    down_hook: Option<Box<dyn MigrationHook>>,
    /// When `true`, `PRAGMA foreign_key_check` runs after this step's SQL and any
    /// reported violation makes the migration fail.
    foreign_key_check: bool,
    /// Free-form description of the step; not read by the migration logic in this file.
    comment: Option<&'u str>,
}
impl<'u> PartialEq for M<'u> {
fn eq(&self, other: &Self) -> bool {
let equal_up_hooks = match (self.up_hook.as_ref(), other.up_hook.as_ref()) {
(None, None) => true,
(Some(a), Some(b)) => addr_of!(*a) as usize == addr_of!(*b) as usize,
_ => false,
};
let equal_down_hooks = match (self.down_hook.as_ref(), other.down_hook.as_ref()) {
(None, None) => true,
(Some(a), Some(b)) => addr_of!(*a) as usize == addr_of!(*b) as usize,
_ => false,
};
self.up == other.up
&& self.down == other.down
&& equal_up_hooks
&& equal_down_hooks
&& self.foreign_key_check == other.foreign_key_check
}
}
impl<'u> Eq for M<'u> {}
impl<'u> M<'u> {
    /// Create a migration step that runs `sql` when migrating up.
    ///
    /// The step has no down SQL, no hooks and no foreign key check; use the builder
    /// methods to add them.
    pub const fn up(sql: &'u str) -> Self {
        Self {
            up: sql,
            up_hook: None,
            down: None,
            down_hook: None,
            foreign_key_check: false,
            comment: None,
        }
    }

    /// Attach a free-form comment to this step.
    pub const fn comment(mut self, comment: &'u str) -> Self {
        self.comment = Some(comment);
        self
    }

    /// Create a migration step that runs `sql` and then `hook` when migrating up.
    ///
    /// The hook runs in the same transaction as `sql`.
    pub fn up_with_hook(sql: &'u str, hook: impl MigrationHook + 'static) -> Self {
        let mut m = Self::up(sql);
        // We already own `hook`: box it directly instead of cloning it via `clone_box`.
        m.up_hook = Some(Box::new(hook));
        m
    }

    /// Set the SQL used to revert this step.
    pub const fn down(mut self, sql: &'u str) -> Self {
        self.down = Some(sql);
        self
    }

    /// Set the SQL used to revert this step, together with a hook.
    ///
    /// When reverting, `hook` runs first and `sql` second, both in the same transaction.
    pub fn down_with_hook(mut self, sql: &'u str, hook: impl MigrationHook + 'static) -> Self {
        self.down = Some(sql);
        // We already own `hook`: box it directly instead of cloning it via `clone_box`.
        self.down_hook = Some(Box::new(hook));
        self
    }

    /// Run `PRAGMA foreign_key_check` after this step's SQL.
    ///
    /// Any reported violation makes the migration fail.
    pub const fn foreign_key_check(mut self) -> Self {
        self.foreign_key_check = true;
        self
    }
}
/// Schema version of a database, classified relative to a given set of migrations.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum SchemaVersion {
    /// No version recorded (the stored version is 0).
    NoneSet,
    /// Version `n` with `1 <= n <= number of migrations`.
    Inside(NonZeroUsize),
    /// Version greater than the number of migrations, so the database is ahead of them.
    Outside(NonZeroUsize),
}
impl From<&SchemaVersion> for usize {
    /// Numeric value of the version; `NoneSet` maps to 0.
    fn from(schema_version: &SchemaVersion) -> usize {
        match *schema_version {
            SchemaVersion::Inside(n) | SchemaVersion::Outside(n) => n.get(),
            SchemaVersion::NoneSet => 0,
        }
    }
}
impl From<SchemaVersion> for usize {
    /// Numeric value of the version; delegates to the by-reference conversion.
    fn from(schema_version: SchemaVersion) -> Self {
        Self::from(&schema_version)
    }
}
impl fmt::Display for SchemaVersion {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SchemaVersion::NoneSet => write!(f, "0 (no version set)"),
SchemaVersion::Inside(v) => write!(f, "{v} (inside)"),
SchemaVersion::Outside(v) => write!(f, "{v} (outside)"),
}
}
}
impl cmp::PartialOrd for SchemaVersion {
    /// Order schema versions by their numeric value.
    ///
    /// `Inside(n)` and `Outside(n)` share a number but are *not* equal under the derived
    /// `PartialEq`. `PartialOrd` requires `a == b` exactly when
    /// `partial_cmp(a, b) == Some(Equal)`, so such pairs are reported as incomparable
    /// (`None`) instead of `Some(Equal)`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let self_usize: usize = self.into();
        let other_usize: usize = other.into();
        match self_usize.cmp(&other_usize) {
            Ordering::Equal if self != other => None,
            ordering => Some(ordering),
        }
    }
}
/// An ordered set of migrations.
///
/// The migration at index `i` moves the database from schema version `i` to `i + 1`.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Migrations<'m> {
    /// Migration steps, in application order.
    ms: Vec<M<'m>>,
}
impl<'m> Migrations<'m> {
    /// Create a set of migrations from a vector of [`M`].
    ///
    /// The migration at index `i` moves the database from schema version `i` to `i + 1`,
    /// so `ms` must be given in application order.
    #[must_use]
    pub fn new(ms: Vec<M<'m>>) -> Self {
        Self { ms }
    }

    /// Build migrations from a directory embedded with the `include_dir` crate.
    ///
    /// # Errors
    ///
    /// Propagates errors from the directory loader. Returns [`Error::FileLoad`] when any
    /// migration slot produced by the loader is empty.
    #[cfg(feature = "from-directory")]
    pub fn from_directory(dir: &'static Dir<'static>) -> Result<Self> {
        // A single `None` slot fails the whole load. Presumably this is a gap in the
        // directory numbering; confirm in `loader`.
        let migrations = from_directory(dir)?
            .into_iter()
            .collect::<Option<Vec<_>>>()
            // Build the error lazily so the `String` is only allocated on failure.
            .ok_or_else(|| Error::FileLoad("Could not load migrations".to_string()))?;
        Ok(Self { ms: migrations })
    }

    /// Create a set of migrations from any iterator of [`M`].
    #[deprecated = "Use the `FromIterator` trait implementation instead. For instance, you can call Migrations::from_iter."]
    pub fn new_iter<I: IntoIterator<Item = M<'m>>>(ms: I) -> Self {
        Self::new(Vec::from_iter(ms))
    }

    /// Classify a raw `user_version` value relative to this set of migrations.
    fn db_version_to_schema(&self, db_version: usize) -> SchemaVersion {
        match db_version {
            0 => SchemaVersion::NoneSet,
            // 0 is handled above, so `v >= 1` here.
            v if v <= self.ms.len() => SchemaVersion::Inside(
                NonZeroUsize::new(v).expect("schema version should not be equal to 0"),
            ),
            v => SchemaVersion::Outside(
                NonZeroUsize::new(v).expect("schema version should not be equal to 0"),
            ),
        }
    }

    /// Read the version stored in the database (`PRAGMA user_version`) and classify it
    /// relative to this set of migrations.
    ///
    /// # Errors
    ///
    /// Returns an error if the pragma cannot be queried.
    pub fn current_version(&self, conn: &Connection) -> Result<SchemaVersion> {
        let raw_version = user_version(conn)?;
        Ok(self.db_version_to_schema(raw_version))
    }

    /// Apply migrations `current_version..target_version` in a single transaction, then
    /// record `target_version`.
    ///
    /// On error the transaction is dropped without being committed. rusqlite's default
    /// drop behavior rolls it back.
    fn goto_up(
        &self,
        conn: &mut Connection,
        current_version: usize,
        target_version: usize,
    ) -> Result<()> {
        debug_assert!(current_version <= target_version);
        debug_assert!(target_version <= self.ms.len());
        trace!("start migration transaction");
        let tx = conn.transaction()?;
        for v in current_version..target_version {
            let m = &self.ms[v];
            debug!("Running: {}", m.up);
            tx.execute_batch(m.up)
                .map_err(|e| Error::with_sql(e, m.up))?;
            if m.foreign_key_check {
                validate_foreign_keys(&tx)?;
            }
            // The up hook runs after the step's SQL (and its optional FK check).
            if let Some(hook) = &m.up_hook {
                hook(&tx)?;
            }
        }
        set_user_version(&tx, target_version)?;
        tx.commit()?;
        trace!("committed migration transaction");
        Ok(())
    }

    /// Revert migrations `target_version..current_version` (newest first) in a single
    /// transaction, then record `target_version`.
    ///
    /// Fails before touching the database if any step in the range has no down SQL.
    fn goto_down(
        &self,
        conn: &mut Connection,
        current_version: usize,
        target_version: usize,
    ) -> Result<()> {
        debug_assert!(current_version >= target_version);
        debug_assert!(target_version <= self.ms.len());
        // Check the whole range up front so we never start a partial revert.
        if let Some((i, bad_m)) = self
            .ms
            .iter()
            .enumerate()
            .skip(target_version)
            .take(current_version - target_version)
            .find(|(_, m)| m.down.is_none())
        {
            warn!("Cannot revert: {:?}", bad_m);
            return Err(Error::MigrationDefinition(
                MigrationDefinitionError::DownNotDefined { migration_index: i },
            ));
        }
        trace!("start migration transaction");
        let tx = conn.transaction()?;
        for v in (target_version..current_version).rev() {
            let m = &self.ms[v];
            if let Some(down) = m.down {
                debug!("Running: {}", &down);
                // The down hook runs *before* the step's down SQL.
                if let Some(hook) = &m.down_hook {
                    hook(&tx)?;
                }
                tx.execute_batch(down)
                    .map_err(|e| Error::with_sql(e, down))?;
                if m.foreign_key_check {
                    validate_foreign_keys(&tx)?;
                }
            } else {
                // Every step in this range was checked for `down` above.
                unreachable!();
            }
        }
        set_user_version(&tx, target_version)?;
        tx.commit()?;
        trace!("committed migration transaction");
        Ok(())
    }

    /// Move the database from its stored version to `target_db_version`, going up or down.
    fn goto(&self, conn: &mut Connection, target_db_version: usize) -> Result<()> {
        let current_version = user_version(conn)?;

        let res = match target_db_version.cmp(&current_version) {
            Ordering::Less => {
                // The database was migrated by a newer set of migrations than ours.
                if current_version > self.ms.len() {
                    return Err(Error::MigrationDefinition(
                        MigrationDefinitionError::DatabaseTooFarAhead,
                    ));
                }
                debug!(
                    "rollback to older version requested, target_db_version: {}, current_version: {}",
                    target_db_version, current_version
                );
                self.goto_down(conn, current_version, target_db_version)
            }
            Ordering::Equal => {
                debug!("no migration to run, db already up to date");
                return Ok(());
            }
            Ordering::Greater => {
                debug!(
                    "some migrations to run, target: {target_db_version}, current: {current_version}"
                );
                self.goto_up(conn, current_version, target_db_version)
            }
        };

        if res.is_ok() {
            info!("Database migrated to version {}", target_db_version);
        }
        res
    }

    /// Highest version reachable with these migrations (`NoneSet` when there are none).
    fn max_schema_version(&self) -> SchemaVersion {
        match self.ms.len() {
            0 => SchemaVersion::NoneSet,
            v => SchemaVersion::Inside(
                NonZeroUsize::new(v).expect("schema version should not be equal to 0"),
            ),
        }
    }

    /// Apply every pending migration so the database reaches the latest version.
    ///
    /// # Errors
    ///
    /// - [`MigrationDefinitionError::NoMigrationsDefined`] if the set is empty.
    /// - [`MigrationDefinitionError::DatabaseTooFarAhead`] if the stored version is higher
    ///   than the number of migrations.
    /// - Any SQL error, foreign key check failure, or hook error raised while migrating.
    pub fn to_latest(&self, conn: &mut Connection) -> Result<()> {
        let v_max = self.max_schema_version();
        match v_max {
            SchemaVersion::NoneSet => {
                warn!("no migration defined");
                Err(Error::MigrationDefinition(
                    MigrationDefinitionError::NoMigrationsDefined,
                ))
            }
            SchemaVersion::Inside(v) => {
                debug!("some migrations defined (version: {v}), try to migrate");
                self.goto(conn, v_max.into())
            }
            // `max_schema_version` never returns `Outside`.
            SchemaVersion::Outside(_) => unreachable!(),
        }
    }

    /// Migrate the database up or down to exactly `version`.
    ///
    /// A `version` of 0 reverts every migration.
    ///
    /// # Errors
    ///
    /// - [`MigrationDefinitionError::NoMigrationsDefined`] if the set is empty.
    /// - [`SchemaVersionError::TargetVersionOutOfRange`] if `version` exceeds the number of
    ///   migrations.
    /// - [`MigrationDefinitionError::DownNotDefined`] if reverting through a step without
    ///   down SQL.
    /// - [`MigrationDefinitionError::DatabaseTooFarAhead`] if the stored version is higher
    ///   than the number of migrations.
    /// - Any SQL error, foreign key check failure, or hook error raised while migrating.
    pub fn to_version(&self, conn: &mut Connection, version: usize) -> Result<()> {
        let target_version: SchemaVersion = self.db_version_to_schema(version);
        let v_max = self.max_schema_version();
        match v_max {
            SchemaVersion::NoneSet => {
                warn!("no migrations defined");
                Err(Error::MigrationDefinition(
                    MigrationDefinitionError::NoMigrationsDefined,
                ))
            }
            SchemaVersion::Inside(v) => {
                debug!("some migrations defined (version: {v}), try to migrate");
                if target_version > v_max {
                    warn!("specified version is higher than the max supported version");
                    return Err(Error::SpecifiedSchemaVersion(
                        SchemaVersionError::TargetVersionOutOfRange {
                            specified: target_version,
                            highest: v_max,
                        },
                    ));
                }
                self.goto(conn, target_version.into())
            }
            // `max_schema_version` never returns `Outside`.
            SchemaVersion::Outside(_) => unreachable!(),
        }
    }

    /// Check that all migrations apply cleanly by running them on a fresh in-memory database.
    ///
    /// # Errors
    ///
    /// Returns an error if the in-memory database cannot be opened. Otherwise returns any
    /// error [`Migrations::to_latest`] would return for an empty database.
    pub fn validate(&self) -> Result<()> {
        let mut conn = Connection::open_in_memory()?;
        self.to_latest(&mut conn)
    }
}
/// Read `PRAGMA user_version`, where the schema version is stored.
///
/// The pragma is read as an `i64` and converted with a plain `as` cast. A negative stored
/// value therefore wraps to a very large `usize`.
fn user_version(conn: &Connection) -> Result<usize, rusqlite::Error> {
    // Lint exemption for the lossy cast performed by this converter.
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    let into_version = |raw: i64| raw as usize;
    conn.query_row("PRAGMA user_version", [], |row| row.get(0))
        .map(into_version)
}
fn set_user_version(conn: &Connection, v: usize) -> Result<()> {
trace!("set user version to: {}", v);
#[allow(clippy::cast_possible_truncation)]
let v = v as u32;
conn.pragma_update(None, "user_version", v)
.map_err(|e| Error::RusqliteError {
query: format!("PRAGMA user_version = {v}; -- Approximate query"),
err: e,
})
}
fn validate_foreign_keys(conn: &Connection) -> Result<()> {
let pragma_fk_check = "PRAGMA foreign_key_check";
let mut stmt = conn
.prepare_cached(pragma_fk_check)
.map_err(|e| Error::with_sql(e, pragma_fk_check))?;
let fk_errors = stmt
.query_map([], |row| {
Ok(ForeignKeyCheckError {
table: row.get(0)?,
rowid: row.get(1)?,
parent: row.get(2)?,
fkid: row.get(3)?,
})
})
.map_err(|e| Error::with_sql(e, pragma_fk_check))?
.collect::<Result<Vec<_>, _>>()?;
if !fk_errors.is_empty() {
Err(crate::Error::ForeignKeyCheck(fk_errors))
} else {
Ok(())
}
}
impl<'u> FromIterator<M<'u>> for Migrations<'u> {
    /// Collect migration steps, in iteration order, into a [`Migrations`] set.
    fn from_iter<T: IntoIterator<Item = M<'u>>>(iter: T) -> Self {
        let ms: Vec<M<'u>> = iter.into_iter().collect();
        Self { ms }
    }
}