use std::{path::PathBuf, sync::Arc};
use anyhow::Result;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::value::RawValue;
use crate::ext::ExtRequest;
use crate::{
ClientCapabilities, ContentBlock, Error, ExtNotification, ExtResponse, ProtocolVersion,
SessionId,
};
/// Agent-side interface: one async method per request or notification type
/// declared in this module.
///
/// Request methods take the matching `*Request` value and return the matching
/// `*Response` or a crate `Error`. `cancel` and `ext_notification` take
/// notification values and return `()` on success.
// `?Send` asks `async_trait` not to require the generated futures to be `Send`.
#[async_trait::async_trait(?Send)]
pub trait Agent {
/// Handles an `initialize` request.
async fn initialize(&self, args: InitializeRequest) -> Result<InitializeResponse, Error>;
/// Handles an `authenticate` request.
async fn authenticate(&self, args: AuthenticateRequest) -> Result<AuthenticateResponse, Error>;
/// Handles a `session/new` request.
async fn new_session(&self, args: NewSessionRequest) -> Result<NewSessionResponse, Error>;
/// Handles a `session/load` request.
async fn load_session(&self, args: LoadSessionRequest) -> Result<LoadSessionResponse, Error>;
/// Handles a `session/set_mode` request.
async fn set_session_mode(
&self,
args: SetSessionModeRequest,
) -> Result<SetSessionModeResponse, Error>;
/// Handles a `session/prompt` request.
async fn prompt(&self, args: PromptRequest) -> Result<PromptResponse, Error>;
/// Handles a `session/cancel` notification.
async fn cancel(&self, args: CancelNotification) -> Result<(), Error>;
/// Handles a `session/set_model` request.
///
/// Only part of the trait when the `unstable` feature is enabled.
#[cfg(feature = "unstable")]
async fn set_session_model(
&self,
args: SetSessionModelRequest,
) -> Result<SetSessionModelResponse, Error>;
/// Handles a request carried as an `ExtRequest`.
// NOTE(review): presumably the hook for methods outside the fixed set above — confirm in `crate::ext`.
async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse, Error>;
/// Handles a notification carried as an `ExtNotification`.
async fn ext_notification(&self, args: ExtNotification) -> Result<(), Error>;
}
/// Parameters of the `initialize` request.
///
/// Fields are (de)serialized with camelCase keys.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[schemars(extend("x-side" = "agent", "x-method" = INITIALIZE_METHOD_NAME))]
#[serde(rename_all = "camelCase")]
pub struct InitializeRequest {
/// Protocol version stated in the request.
pub protocol_version: ProtocolVersion,
/// Client capabilities; `ClientCapabilities::default()` is used when the key is absent.
#[serde(default)]
pub client_capabilities: ClientCapabilities,
/// Optional free-form JSON, (de)serialized under the key `_meta` and omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none", rename = "_meta")]
pub meta: Option<serde_json::Value>,
}
/// Result of the `initialize` request.
///
/// Fields are (de)serialized with camelCase keys.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[schemars(extend("x-side" = "agent", "x-method" = INITIALIZE_METHOD_NAME))]
#[serde(rename_all = "camelCase")]
pub struct InitializeResponse {
/// Protocol version stated in the response.
pub protocol_version: ProtocolVersion,
/// Agent capabilities; `AgentCapabilities::default()` is used when the key is absent.
#[serde(default)]
pub agent_capabilities: AgentCapabilities,
/// Authentication methods listed by the agent; an empty list when the key is absent.
#[serde(default)]
pub auth_methods: Vec<AuthMethod>,
/// Optional free-form JSON, (de)serialized under the key `_meta` and omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none", rename = "_meta")]
pub meta: Option<serde_json::Value>,
}
/// Parameters of the `authenticate` request.
///
/// Fields are (de)serialized with camelCase keys.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[schemars(extend("x-side" = "agent", "x-method" = AUTHENTICATE_METHOD_NAME))]
#[serde(rename_all = "camelCase")]
pub struct AuthenticateRequest {
/// Identifier of the authentication method to use (key `methodId`).
// NOTE(review): presumably one of the ids from `InitializeResponse::auth_methods` — confirm.
pub method_id: AuthMethodId,
/// Optional free-form JSON, (de)serialized under the key `_meta` and omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none", rename = "_meta")]
pub meta: Option<serde_json::Value>,
}
/// Result of the `authenticate` request; carries nothing but optional metadata.
#[derive(Default, Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
#[schemars(extend("x-side" = "agent", "x-method" = AUTHENTICATE_METHOD_NAME))]
pub struct AuthenticateResponse {
/// Optional free-form JSON, (de)serialized under the key `_meta` and omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none", rename = "_meta")]
pub meta: Option<serde_json::Value>,
}
/// Identifier of an authentication method.
///
/// `#[serde(transparent)]` makes it (de)serialize as the bare inner string.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
#[serde(transparent)]
pub struct AuthMethodId(pub Arc<str>);
/// One authentication method, as listed in `InitializeResponse::auth_methods`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct AuthMethod {
/// Identifier of this method.
pub id: AuthMethodId,
/// Name of this method.
pub name: String,
/// Optional description.
// No `skip_serializing_if` here, so `None` is written out as `"description": null`.
pub description: Option<String>,
/// Optional free-form JSON, (de)serialized under the key `_meta` and omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none", rename = "_meta")]
pub meta: Option<serde_json::Value>,
}
/// Parameters of the `session/new` request.
///
/// Fields are (de)serialized with camelCase keys.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[schemars(extend("x-side" = "agent", "x-method" = SESSION_NEW_METHOD_NAME))]
#[serde(rename_all = "camelCase")]
pub struct NewSessionRequest {
/// Path sent under the key `cwd`.
// NOTE(review): presumably the session's working directory — confirm against the protocol spec.
pub cwd: PathBuf,
/// MCP server descriptions sent with the request (key `mcpServers`); required on input.
pub mcp_servers: Vec<McpServer>,
/// Optional free-form JSON, (de)serialized under the key `_meta` and omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none", rename = "_meta")]
pub meta: Option<serde_json::Value>,
}
/// Result of the `session/new` request.
///
/// Fields are (de)serialized with camelCase keys.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[schemars(extend("x-side" = "agent", "x-method" = SESSION_NEW_METHOD_NAME))]
#[serde(rename_all = "camelCase")]
pub struct NewSessionResponse {
/// Identifier of the session (key `sessionId`).
pub session_id: SessionId,
/// Optional session-mode state; omitted from output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub modes: Option<SessionModeState>,
/// Optional session-model state; omitted from output when `None`.
///
/// The field only exists when the `unstable` feature is enabled.
#[cfg(feature = "unstable")]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub models: Option<SessionModelState>,
/// Optional free-form JSON, (de)serialized under the key `_meta` and omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none", rename = "_meta")]
pub meta: Option<serde_json::Value>,
}
/// Parameters of the `session/load` request.
///
/// Fields are (de)serialized with camelCase keys.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[schemars(extend("x-side" = "agent", "x-method" = SESSION_LOAD_METHOD_NAME))]
#[serde(rename_all = "camelCase")]
pub struct LoadSessionRequest {
/// MCP server descriptions sent with the request (key `mcpServers`); required on input.
pub mcp_servers: Vec<McpServer>,
/// Path sent under the key `cwd`.
// NOTE(review): presumably the session's working directory — confirm against the protocol spec.
pub cwd: PathBuf,
/// Identifier of the session to load (key `sessionId`).
pub session_id: SessionId,
/// Optional free-form JSON, (de)serialized under the key `_meta` and omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none", rename = "_meta")]
pub meta: Option<serde_json::Value>,
}
/// Result of the `session/load` request.
///
/// Every field is optional, so the type also implements `Default`.
#[derive(Default, Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[schemars(extend("x-side" = "agent", "x-method" = SESSION_LOAD_METHOD_NAME))]
#[serde(rename_all = "camelCase")]
pub struct LoadSessionResponse {
/// Optional session-mode state; omitted from output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub modes: Option<SessionModeState>,
/// Optional session-model state; omitted from output when `None`.
///
/// The field only exists when the `unstable` feature is enabled.
#[cfg(feature = "unstable")]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub models: Option<SessionModelState>,
/// Optional free-form JSON, (de)serialized under the key `_meta` and omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none", rename = "_meta")]
pub meta: Option<serde_json::Value>,
}
/// Mode state of a session: the current mode plus the list of available modes.
///
/// Fields are (de)serialized with camelCase keys.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct SessionModeState {
/// Identifier of the current mode (key `currentModeId`).
pub current_mode_id: SessionModeId,
/// Modes listed as available (key `availableModes`).
pub available_modes: Vec<SessionMode>,
/// Optional free-form JSON, (de)serialized under the key `_meta` and omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none", rename = "_meta")]
pub meta: Option<serde_json::Value>,
}
/// One entry of `SessionModeState::available_modes`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct SessionMode {
/// Identifier of this mode.
pub id: SessionModeId,
/// Name of this mode.
pub name: String,
/// Optional description; omitted from output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Optional free-form JSON, (de)serialized under the key `_meta` and omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none", rename = "_meta")]
pub meta: Option<serde_json::Value>,
}
/// Identifier of a session mode.
///
/// `#[serde(transparent)]` makes it (de)serialize as the bare inner string.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
#[serde(transparent)]
pub struct SessionModeId(pub Arc<str>);
/// Displays a mode id as its inner string, verbatim.
impl std::fmt::Display for SessionModeId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Written straight to the output; width/precision flags on `f` are not applied.
        f.write_str(&self.0)
    }
}
/// Parameters of the `session/set_mode` request.
///
/// Fields are (de)serialized with camelCase keys.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[schemars(extend("x-side" = "agent", "x-method" = SESSION_SET_MODE_METHOD_NAME))]
#[serde(rename_all = "camelCase")]
pub struct SetSessionModeRequest {
/// Identifier of the session whose mode is being set (key `sessionId`).
pub session_id: SessionId,
/// Identifier of the mode to set (key `modeId`).
pub mode_id: SessionModeId,
/// Optional free-form JSON, (de)serialized under the key `_meta` and omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none", rename = "_meta")]
pub meta: Option<serde_json::Value>,
}
/// Result of the `session/set_mode` request; carries nothing but optional metadata.
#[derive(Default, Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[schemars(extend("x-side" = "agent", "x-method" = SESSION_SET_MODE_METHOD_NAME))]
#[serde(rename_all = "camelCase")]
pub struct SetSessionModeResponse {
    /// Optional free-form JSON, serialized under the key `_meta` and omitted when `None`.
    ///
    /// The field used to be written as `meta` (and as `null` when empty), unlike
    /// every other type in this module. `alias = "meta"` keeps payloads produced
    /// by that older encoding readable.
    #[serde(skip_serializing_if = "Option::is_none", rename = "_meta", alias = "meta")]
    pub meta: Option<serde_json::Value>,
}
/// Description of an MCP server, in one of three shapes.
///
/// The enum is internally tagged with a `type` key: `Http` is written with
/// `"type": "http"` and `Sse` with `"type": "sse"`. `Stdio` is marked
/// `untagged`, so it is written without a `type` key and is the shape tried
/// when the input carries no matching tag.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum McpServer {
/// Server described by a URL, tagged `"type": "http"`.
#[serde(rename_all = "camelCase")]
Http {
/// Name of the server.
name: String,
/// URL of the server.
url: String,
/// Header name/value pairs for this server.
headers: Vec<HttpHeader>,
},
/// Server described by a URL, tagged `"type": "sse"`.
#[serde(rename_all = "camelCase")]
Sse {
/// Name of the server.
name: String,
/// URL of the server.
url: String,
/// Header name/value pairs for this server.
headers: Vec<HttpHeader>,
},
/// Server described by a command line; serialized without a `type` key.
#[serde(untagged, rename_all = "camelCase")]
Stdio {
/// Name of the server.
name: String,
/// Path of the command.
command: PathBuf,
/// Arguments for the command.
args: Vec<String>,
/// Environment variable name/value pairs for the command.
env: Vec<EnvVariable>,
},
}
/// An environment variable as a name/value pair (used by `McpServer::Stdio`).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct EnvVariable {
/// Variable name.
pub name: String,
/// Variable value.
pub value: String,
/// Optional free-form JSON, (de)serialized under the key `_meta` and omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none", rename = "_meta")]
pub meta: Option<serde_json::Value>,
}
/// An HTTP header as a name/value pair (used by `McpServer::Http` and `McpServer::Sse`).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct HttpHeader {
/// Header name.
pub name: String,
/// Header value.
pub value: String,
/// Optional free-form JSON, (de)serialized under the key `_meta` and omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none", rename = "_meta")]
pub meta: Option<serde_json::Value>,
}
/// Parameters of the `session/prompt` request.
///
/// Fields are (de)serialized with camelCase keys.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[schemars(extend("x-side" = "agent", "x-method" = SESSION_PROMPT_METHOD_NAME))]
#[serde(rename_all = "camelCase")]
pub struct PromptRequest {
/// Identifier of the session the prompt is sent to (key `sessionId`).
pub session_id: SessionId,
/// Content blocks that make up the prompt.
pub prompt: Vec<ContentBlock>,
/// Optional free-form JSON, (de)serialized under the key `_meta` and omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none", rename = "_meta")]
pub meta: Option<serde_json::Value>,
}
/// Result of the `session/prompt` request.
///
/// Fields are (de)serialized with camelCase keys.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[schemars(extend("x-side" = "agent", "x-method" = SESSION_PROMPT_METHOD_NAME))]
#[serde(rename_all = "camelCase")]
pub struct PromptResponse {
/// Reason reported for stopping (key `stopReason`).
pub stop_reason: StopReason,
/// Optional free-form JSON, (de)serialized under the key `_meta` and omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none", rename = "_meta")]
pub meta: Option<serde_json::Value>,
}
/// Stop reason reported in `PromptResponse::stop_reason`.
///
/// Variants are (de)serialized as snake_case strings.
// NOTE(review): the exact meaning of each variant is not visible in this file — confirm against the protocol spec.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
/// Serialized as `end_turn`.
EndTurn,
/// Serialized as `max_tokens`.
MaxTokens,
/// Serialized as `max_turn_requests`.
MaxTurnRequests,
/// Serialized as `refusal`.
Refusal,
/// Serialized as `cancelled`.
Cancelled,
}
/// Model state of a session: the current model plus the list of available models.
///
/// Only compiled when the `unstable` feature is enabled. Fields are
/// (de)serialized with camelCase keys.
#[cfg(feature = "unstable")]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct SessionModelState {
/// Identifier of the current model (key `currentModelId`).
pub current_model_id: ModelId,
/// Models listed as available (key `availableModels`).
pub available_models: Vec<ModelInfo>,
/// Optional free-form JSON, (de)serialized under the key `_meta` and omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none", rename = "_meta")]
pub meta: Option<serde_json::Value>,
}
/// Identifier of a model.
///
/// Only compiled when the `unstable` feature is enabled.
/// `#[serde(transparent)]` makes it (de)serialize as the bare inner string.
#[cfg(feature = "unstable")]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
#[serde(transparent)]
pub struct ModelId(pub Arc<str>);
/// Displays a model id as its inner string, verbatim.
#[cfg(feature = "unstable")]
impl std::fmt::Display for ModelId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Written straight to the output; width/precision flags on `f` are not applied.
        f.write_str(&self.0)
    }
}
/// One entry of `SessionModelState::available_models`.
///
/// Only compiled when the `unstable` feature is enabled.
#[cfg(feature = "unstable")]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct ModelInfo {
/// Identifier of this model (key `modelId`).
pub model_id: ModelId,
/// Name of this model.
pub name: String,
/// Optional description; omitted from output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Optional free-form JSON, (de)serialized under the key `_meta` and omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none", rename = "_meta")]
pub meta: Option<serde_json::Value>,
}
/// Parameters of the `session/set_model` request.
///
/// Only compiled when the `unstable` feature is enabled. Fields are
/// (de)serialized with camelCase keys.
#[cfg(feature = "unstable")]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[schemars(extend("x-side" = "agent", "x-method" = SESSION_SET_MODEL_METHOD_NAME))]
#[serde(rename_all = "camelCase")]
pub struct SetSessionModelRequest {
/// Identifier of the session whose model is being set (key `sessionId`).
pub session_id: SessionId,
/// Identifier of the model to set (key `modelId`).
pub model_id: ModelId,
/// Optional free-form JSON, (de)serialized under the key `_meta` and omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none", rename = "_meta")]
pub meta: Option<serde_json::Value>,
}
/// Result of the `session/set_model` request; carries nothing but optional metadata.
///
/// Only compiled when the `unstable` feature is enabled.
#[cfg(feature = "unstable")]
#[derive(Default, Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[schemars(extend("x-side" = "agent", "x-method" = SESSION_SET_MODEL_METHOD_NAME))]
#[serde(rename_all = "camelCase")]
pub struct SetSessionModelResponse {
/// Optional free-form JSON, (de)serialized under the key `_meta` and omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none", rename = "_meta")]
pub meta: Option<serde_json::Value>,
}
/// Capabilities reported in `InitializeResponse::agent_capabilities`.
///
/// Every non-metadata field falls back to its `Default` when absent from the input.
#[derive(Default, Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct AgentCapabilities {
/// Flag sent as `loadSession`; `false` when absent.
// NOTE(review): presumably signals support for `session/load` — confirm against the protocol spec.
#[serde(default)]
pub load_session: bool,
/// Prompt-related capability flags.
#[serde(default)]
pub prompt_capabilities: PromptCapabilities,
/// MCP-related capability flags.
#[serde(default)]
pub mcp_capabilities: McpCapabilities,
/// Optional free-form JSON, (de)serialized under the key `_meta` and omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none", rename = "_meta")]
pub meta: Option<serde_json::Value>,
}
/// Prompt-related capability flags; each flag is `false` when absent from the input.
// NOTE(review): presumably each flag says which content kinds a prompt may contain — confirm against the protocol spec.
#[derive(Default, Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct PromptCapabilities {
/// Flag sent as `image`.
#[serde(default)]
pub image: bool,
/// Flag sent as `audio`.
#[serde(default)]
pub audio: bool,
/// Flag sent as `embeddedContext`.
#[serde(default)]
pub embedded_context: bool,
/// Optional free-form JSON, (de)serialized under the key `_meta` and omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none", rename = "_meta")]
pub meta: Option<serde_json::Value>,
}
/// MCP-related capability flags; each flag is `false` when absent from the input.
// NOTE(review): presumably these correspond to the `McpServer::Http` / `McpServer::Sse` shapes — confirm.
#[derive(Default, Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct McpCapabilities {
/// Flag sent as `http`.
#[serde(default)]
pub http: bool,
/// Flag sent as `sse`.
#[serde(default)]
pub sse: bool,
/// Optional free-form JSON, (de)serialized under the key `_meta` and omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none", rename = "_meta")]
pub meta: Option<serde_json::Value>,
}
/// Table of agent-side method name strings, one field per method.
///
/// The populated instance is `AGENT_METHOD_NAMES`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentMethodNames {
/// Name of the initialize method.
pub initialize: &'static str,
/// Name of the authenticate method.
pub authenticate: &'static str,
/// Name of the new-session method.
pub session_new: &'static str,
/// Name of the load-session method.
pub session_load: &'static str,
/// Name of the set-session-mode method.
pub session_set_mode: &'static str,
/// Name of the prompt method.
pub session_prompt: &'static str,
/// Name of the cancel notification.
pub session_cancel: &'static str,
/// Name of the set-session-model method; only present with the `unstable` feature.
#[cfg(feature = "unstable")]
pub model_select: &'static str,
}
/// The method names handled on the agent side, filled in from the
/// `*_METHOD_NAME` constants of this module.
pub const AGENT_METHOD_NAMES: AgentMethodNames = AgentMethodNames {
initialize: INITIALIZE_METHOD_NAME,
authenticate: AUTHENTICATE_METHOD_NAME,
session_new: SESSION_NEW_METHOD_NAME,
session_load: SESSION_LOAD_METHOD_NAME,
session_set_mode: SESSION_SET_MODE_METHOD_NAME,
session_prompt: SESSION_PROMPT_METHOD_NAME,
session_cancel: SESSION_CANCEL_METHOD_NAME,
// The field is called `model_select`, but its value is the `session/set_model` name.
#[cfg(feature = "unstable")]
model_select: SESSION_SET_MODEL_METHOD_NAME,
};
/// Method name attached to `InitializeRequest` / `InitializeResponse`.
pub(crate) const INITIALIZE_METHOD_NAME: &str = "initialize";
/// Method name attached to `AuthenticateRequest` / `AuthenticateResponse`.
pub(crate) const AUTHENTICATE_METHOD_NAME: &str = "authenticate";
/// Method name attached to `NewSessionRequest` / `NewSessionResponse`.
pub(crate) const SESSION_NEW_METHOD_NAME: &str = "session/new";
/// Method name attached to `LoadSessionRequest` / `LoadSessionResponse`.
pub(crate) const SESSION_LOAD_METHOD_NAME: &str = "session/load";
/// Method name attached to `SetSessionModeRequest` / `SetSessionModeResponse`.
pub(crate) const SESSION_SET_MODE_METHOD_NAME: &str = "session/set_mode";
/// Method name attached to `PromptRequest` / `PromptResponse`.
pub(crate) const SESSION_PROMPT_METHOD_NAME: &str = "session/prompt";
/// Method name attached to `CancelNotification`.
pub(crate) const SESSION_CANCEL_METHOD_NAME: &str = "session/cancel";
/// Method name attached to `SetSessionModelRequest` / `SetSessionModelResponse`
/// (only with the `unstable` feature).
#[cfg(feature = "unstable")]
pub(crate) const SESSION_SET_MODEL_METHOD_NAME: &str = "session/set_model";
/// Union of the request types declared in this module plus `ExtRequest`.
///
/// The enum is `untagged`: a value is serialized as the bare inner request,
/// and on input serde tries the variants in declaration order and keeps the
/// first one that parses. The schema carries `"x-docs-ignore": true`.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
#[schemars(extend("x-docs-ignore" = true))]
pub enum ClientRequest {
/// An `initialize` request.
InitializeRequest(InitializeRequest),
/// An `authenticate` request.
AuthenticateRequest(AuthenticateRequest),
/// A `session/new` request.
NewSessionRequest(NewSessionRequest),
/// A `session/load` request.
LoadSessionRequest(LoadSessionRequest),
/// A `session/set_mode` request.
SetSessionModeRequest(SetSessionModeRequest),
/// A `session/prompt` request.
PromptRequest(PromptRequest),
/// A `session/set_model` request; only with the `unstable` feature.
#[cfg(feature = "unstable")]
ModelSelectRequest(SetSessionModelRequest),
/// A request carried as an `ExtRequest`.
ExtMethodRequest(ExtRequest),
}
/// Union of the response types declared in this module plus a raw JSON
/// response for extension methods.
///
/// The enum is `untagged`: a value is serialized as the bare inner response,
/// and on input serde tries the variants in declaration order and keeps the
/// first one that parses. The schema carries `"x-docs-ignore": true`.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
#[schemars(extend("x-docs-ignore" = true))]
pub enum AgentResponse {
/// Response to `initialize`.
InitializeResponse(InitializeResponse),
/// Response to `authenticate`.
AuthenticateResponse(#[serde(default)] AuthenticateResponse),
/// Response to `session/new`.
NewSessionResponse(NewSessionResponse),
/// Response to `session/load`.
LoadSessionResponse(#[serde(default)] LoadSessionResponse),
/// Response to `session/set_mode`.
SetSessionModeResponse(#[serde(default)] SetSessionModeResponse),
/// Response to `session/prompt`.
PromptResponse(PromptResponse),
/// Response to `session/set_model`; only with the `unstable` feature.
#[cfg(feature = "unstable")]
ModelSelectResponse(SetSessionModelResponse),
/// Raw, unparsed JSON response; described in the schema as an arbitrary JSON value.
ExtMethodResponse(#[schemars(with = "serde_json::Value")] Arc<RawValue>),
}
/// Union of the notification types handled on the agent side.
///
/// The enum is `untagged`: a value is serialized as the bare inner
/// notification, and on input serde tries the variants in declaration order.
/// The schema carries `"x-docs-ignore": true`.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
#[schemars(extend("x-docs-ignore" = true))]
pub enum ClientNotification {
/// A `session/cancel` notification.
CancelNotification(CancelNotification),
/// A notification carried as an `ExtNotification`.
ExtNotification(ExtNotification),
}
/// Parameters of the `session/cancel` notification.
///
/// Fields are (de)serialized with camelCase keys.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[schemars(extend("x-side" = "agent", "x-method" = SESSION_CANCEL_METHOD_NAME))]
#[serde(rename_all = "camelCase")]
pub struct CancelNotification {
/// Identifier of the session the notification refers to (key `sessionId`).
pub session_id: SessionId,
/// Optional free-form JSON, (de)serialized under the key `_meta` and omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none", rename = "_meta")]
pub meta: Option<serde_json::Value>,
}
/// Round-trip tests pinning the JSON shape of each [`McpServer`] variant.
#[cfg(test)]
mod test_serialization {
    use super::*;
    use serde_json::json;

    /// Builds an [`HttpHeader`] with no `_meta`.
    fn header(name: &str, value: &str) -> HttpHeader {
        HttpHeader {
            name: String::from(name),
            value: String::from(value),
            meta: None,
        }
    }

    /// `Stdio` is untagged: no `type` key on the wire, and it parses back as `Stdio`.
    #[test]
    fn test_mcp_server_stdio_serialization() {
        let expected = json!({
            "name": "test-server",
            "command": "/usr/bin/server",
            "args": ["--port", "3000"],
            "env": [
                {
                    "name": "API_KEY",
                    "value": "secret123"
                }
            ]
        });

        let original = McpServer::Stdio {
            name: String::from("test-server"),
            command: PathBuf::from("/usr/bin/server"),
            args: vec![String::from("--port"), String::from("3000")],
            env: vec![EnvVariable {
                name: String::from("API_KEY"),
                value: String::from("secret123"),
                meta: None,
            }],
        };

        let encoded = serde_json::to_value(&original).unwrap();
        assert_eq!(encoded, expected);

        let McpServer::Stdio {
            name,
            command,
            args,
            env,
        } = serde_json::from_value::<McpServer>(encoded).unwrap()
        else {
            panic!("Expected Stdio variant")
        };
        assert_eq!(name, "test-server");
        assert_eq!(command, PathBuf::from("/usr/bin/server"));
        assert_eq!(args, vec!["--port", "3000"]);
        assert_eq!(env.len(), 1);
        assert_eq!(env[0].name, "API_KEY");
        assert_eq!(env[0].value, "secret123");
    }

    /// `Http` carries `"type": "http"` and keeps header order through a round trip.
    #[test]
    fn test_mcp_server_http_serialization() {
        let expected = json!({
            "type": "http",
            "name": "http-server",
            "url": "https://api.example.com",
            "headers": [
                {
                    "name": "Authorization",
                    "value": "Bearer token123"
                },
                {
                    "name": "Content-Type",
                    "value": "application/json"
                }
            ]
        });

        let original = McpServer::Http {
            name: String::from("http-server"),
            url: String::from("https://api.example.com"),
            headers: vec![
                header("Authorization", "Bearer token123"),
                header("Content-Type", "application/json"),
            ],
        };

        let encoded = serde_json::to_value(&original).unwrap();
        assert_eq!(encoded, expected);

        let McpServer::Http { name, url, headers } =
            serde_json::from_value::<McpServer>(encoded).unwrap()
        else {
            panic!("Expected Http variant")
        };
        assert_eq!(name, "http-server");
        assert_eq!(url, "https://api.example.com");
        assert_eq!(headers.len(), 2);
        assert_eq!(headers[0].name, "Authorization");
        assert_eq!(headers[0].value, "Bearer token123");
        assert_eq!(headers[1].name, "Content-Type");
        assert_eq!(headers[1].value, "application/json");
    }

    /// `Sse` carries `"type": "sse"` and parses back as `Sse`.
    #[test]
    fn test_mcp_server_sse_serialization() {
        let expected = json!({
            "type": "sse",
            "name": "sse-server",
            "url": "https://sse.example.com/events",
            "headers": [
                {
                    "name": "X-API-Key",
                    "value": "apikey456"
                }
            ]
        });

        let original = McpServer::Sse {
            name: String::from("sse-server"),
            url: String::from("https://sse.example.com/events"),
            headers: vec![header("X-API-Key", "apikey456")],
        };

        let encoded = serde_json::to_value(&original).unwrap();
        assert_eq!(encoded, expected);

        let McpServer::Sse { name, url, headers } =
            serde_json::from_value::<McpServer>(encoded).unwrap()
        else {
            panic!("Expected Sse variant")
        };
        assert_eq!(name, "sse-server");
        assert_eq!(url, "https://sse.example.com/events");
        assert_eq!(headers.len(), 1);
        assert_eq!(headers[0].name, "X-API-Key");
        assert_eq!(headers[0].value, "apikey456");
    }
}