use std::str::FromStr;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::{FromRow, QueryBuilder, Sqlite, SqlitePool};
use ts_rs_forge::TS;
use uuid::Uuid;
#[derive(Debug, Clone, Copy, Serialize, Deserialize, TS, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
#[ts(rename_all = "snake_case")]
pub enum DraftType {
FollowUp,
Retry,
}
impl DraftType {
pub fn as_str(&self) -> &'static str {
match self {
DraftType::FollowUp => "follow_up",
DraftType::Retry => "retry",
}
}
}
impl FromStr for DraftType {
type Err = ();
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"follow_up" => Ok(DraftType::FollowUp),
"retry" => Ok(DraftType::Retry),
_ => Err(()),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, TS)]
pub struct Draft {
pub id: Uuid,
pub task_attempt_id: Uuid,
pub draft_type: DraftType,
#[serde(skip_serializing_if = "Option::is_none")]
pub retry_process_id: Option<Uuid>,
pub prompt: String,
pub queued: bool,
pub sending: bool,
pub variant: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub image_ids: Option<Vec<Uuid>>,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
pub version: i64,
}
#[derive(Debug, Clone, FromRow)]
struct DraftRow {
pub id: Uuid,
pub task_attempt_id: Uuid,
pub draft_type: String,
pub retry_process_id: Option<Uuid>,
pub prompt: String,
pub queued: bool,
pub sending: bool,
pub variant: Option<String>,
pub image_ids: Option<String>,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
pub version: i64,
}
impl From<DraftRow> for Draft {
fn from(r: DraftRow) -> Self {
let image_ids = r
.image_ids
.as_deref()
.and_then(|s| serde_json::from_str::<Vec<Uuid>>(s).ok());
Draft {
id: r.id,
task_attempt_id: r.task_attempt_id,
draft_type: DraftType::from_str(&r.draft_type).unwrap_or(DraftType::FollowUp),
retry_process_id: r.retry_process_id,
prompt: r.prompt,
queued: r.queued,
sending: r.sending,
variant: r.variant,
image_ids,
created_at: r.created_at,
updated_at: r.updated_at,
version: r.version,
}
}
}
#[derive(Debug, Deserialize, TS)]
pub struct UpsertDraft {
pub task_attempt_id: Uuid,
pub draft_type: DraftType,
pub retry_process_id: Option<Uuid>,
pub prompt: String,
pub queued: bool,
pub variant: Option<String>,
pub image_ids: Option<Vec<Uuid>>,
}
impl Draft {
pub async fn find_by_rowid(pool: &SqlitePool, rowid: i64) -> Result<Option<Self>, sqlx::Error> {
sqlx::query_as!(
DraftRow,
r#"SELECT
id as "id!: Uuid",
task_attempt_id as "task_attempt_id!: Uuid",
draft_type,
retry_process_id as "retry_process_id?: Uuid",
prompt,
queued as "queued!: bool",
sending as "sending!: bool",
variant,
image_ids,
created_at as "created_at!: DateTime<Utc>",
updated_at as "updated_at!: DateTime<Utc>",
version as "version!: i64"
FROM drafts
WHERE rowid = $1"#,
rowid
)
.fetch_optional(pool)
.await
.map(|opt| opt.map(Draft::from))
}
pub async fn find_by_task_attempt_and_type(
pool: &SqlitePool,
task_attempt_id: Uuid,
draft_type: DraftType,
) -> Result<Option<Self>, sqlx::Error> {
let draft_type_str = draft_type.as_str();
sqlx::query_as!(
DraftRow,
r#"SELECT
id as "id!: Uuid",
task_attempt_id as "task_attempt_id!: Uuid",
draft_type,
retry_process_id as "retry_process_id?: Uuid",
prompt,
queued as "queued!: bool",
sending as "sending!: bool",
variant,
image_ids,
created_at as "created_at!: DateTime<Utc>",
updated_at as "updated_at!: DateTime<Utc>",
version as "version!: i64"
FROM drafts
WHERE task_attempt_id = $1 AND draft_type = $2"#,
task_attempt_id,
draft_type_str
)
.fetch_optional(pool)
.await
.map(|opt| opt.map(Draft::from))
}
pub async fn upsert(pool: &SqlitePool, data: &UpsertDraft) -> Result<Self, sqlx::Error> {
if data.draft_type == DraftType::Retry && data.retry_process_id.is_none() {
return Err(sqlx::Error::Protocol(
"retry_process_id is required for retry drafts".into(),
));
}
let id = Uuid::new_v4();
let image_ids_json = data
.image_ids
.as_ref()
.map(|ids| serde_json::to_string(ids).unwrap_or_else(|_| "[]".to_string()));
let draft_type_str = data.draft_type.as_str();
let prompt = data.prompt.clone();
let variant = data.variant.clone();
sqlx::query_as!(
DraftRow,
r#"INSERT INTO drafts (id, task_attempt_id, draft_type, retry_process_id, prompt, queued, variant, image_ids)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT(task_attempt_id, draft_type) DO UPDATE SET
retry_process_id = excluded.retry_process_id,
prompt = excluded.prompt,
queued = excluded.queued,
variant = excluded.variant,
image_ids = excluded.image_ids,
version = drafts.version + 1
RETURNING
id as "id!: Uuid",
task_attempt_id as "task_attempt_id!: Uuid",
draft_type,
retry_process_id as "retry_process_id?: Uuid",
prompt,
queued as "queued!: bool",
sending as "sending!: bool",
variant,
image_ids,
created_at as "created_at!: DateTime<Utc>",
updated_at as "updated_at!: DateTime<Utc>",
version as "version!: i64""#,
id,
data.task_attempt_id,
draft_type_str,
data.retry_process_id,
prompt,
data.queued,
variant,
image_ids_json
)
.fetch_one(pool)
.await
.map(Draft::from)
}
pub async fn clear_after_send(
pool: &SqlitePool,
task_attempt_id: Uuid,
draft_type: DraftType,
) -> Result<(), sqlx::Error> {
let draft_type_str = draft_type.as_str();
match draft_type {
DraftType::FollowUp => {
sqlx::query(
r#"UPDATE drafts
SET prompt = '', queued = 0, sending = 0, image_ids = NULL, updated_at = CURRENT_TIMESTAMP, version = version + 1
WHERE task_attempt_id = ? AND draft_type = ?"#,
)
.bind(task_attempt_id)
.bind(draft_type_str)
.execute(pool)
.await?;
}
DraftType::Retry => {
Self::delete_by_task_attempt_and_type(pool, task_attempt_id, draft_type).await?;
}
}
Ok(())
}
pub async fn delete_by_task_attempt_and_type(
pool: &SqlitePool,
task_attempt_id: Uuid,
draft_type: DraftType,
) -> Result<(), sqlx::Error> {
sqlx::query(r#"DELETE FROM drafts WHERE task_attempt_id = ? AND draft_type = ?"#)
.bind(task_attempt_id)
.bind(draft_type.as_str())
.execute(pool)
.await?;
Ok(())
}
pub async fn try_mark_sending(
pool: &SqlitePool,
task_attempt_id: Uuid,
draft_type: DraftType,
) -> Result<bool, sqlx::Error> {
let draft_type_str = draft_type.as_str();
let result = sqlx::query(
r#"UPDATE drafts
SET sending = 1, updated_at = CURRENT_TIMESTAMP, version = version + 1
WHERE task_attempt_id = ?
AND draft_type = ?
AND queued = 1
AND sending = 0
AND TRIM(prompt) != ''"#,
)
.bind(task_attempt_id)
.bind(draft_type_str)
.execute(pool)
.await?;
Ok(result.rows_affected() > 0)
}
pub async fn update_partial(
pool: &SqlitePool,
task_attempt_id: Uuid,
draft_type: DraftType,
prompt: Option<String>,
variant: Option<Option<String>>,
image_ids: Option<Vec<Uuid>>,
retry_process_id: Option<Uuid>,
) -> Result<(), sqlx::Error> {
if retry_process_id.is_none()
&& prompt.is_none()
&& variant.is_none()
&& image_ids.is_none()
{
return Ok(());
}
let mut query = QueryBuilder::<Sqlite>::new("UPDATE drafts SET ");
let mut separated = query.separated(", ");
if let Some(rpid) = retry_process_id {
separated.push("retry_process_id = ");
separated.push_bind_unseparated(rpid);
}
if let Some(p) = prompt {
separated.push("prompt = ");
separated.push_bind_unseparated(p);
}
if let Some(v_opt) = variant {
separated.push("variant = ");
match v_opt {
Some(v) => separated.push_bind_unseparated(v),
None => separated.push_bind_unseparated(Option::<String>::None),
};
}
if let Some(ids) = image_ids {
let image_ids_json = serde_json::to_string(&ids).unwrap_or_else(|_| "[]".to_string());
separated.push("image_ids = ");
separated.push_bind_unseparated(image_ids_json);
}
separated.push("updated_at = CURRENT_TIMESTAMP");
separated.push("version = version + 1");
query.push(" WHERE task_attempt_id = ");
query.push_bind(task_attempt_id);
query.push(" AND draft_type = ");
query.push_bind(draft_type.as_str());
query.build().execute(pool).await?;
Ok(())
}
pub async fn set_queued(
pool: &SqlitePool,
task_attempt_id: Uuid,
draft_type: DraftType,
queued: bool,
expected_queued: Option<bool>,
expected_version: Option<i64>,
) -> Result<u64, sqlx::Error> {
let result = sqlx::query(
r#"UPDATE drafts
SET queued = CASE
WHEN ?1 THEN (TRIM(prompt) <> '')
ELSE 0
END,
updated_at = CURRENT_TIMESTAMP,
version = version + 1
WHERE task_attempt_id = ?2
AND draft_type = ?3
AND (?4 IS NULL OR queued = ?4)
AND (?5 IS NULL OR version = ?5)"#,
)
.bind(queued as i64)
.bind(task_attempt_id)
.bind(draft_type.as_str())
.bind(expected_queued.map(|value| value as i64))
.bind(expected_version)
.execute(pool)
.await?;
Ok(result.rows_affected())
}
}