1use std::{collections::HashMap, pin::Pin};
13
14use futures::Future;
15use serde::{Deserialize, Serialize};
16
17use crate::{
18 completion::{self, ToolDefinition},
19 embeddings::{embed::EmbedError, tool::ToolSchema},
20};
21
/// Errors produced while invoking a tool through [`ToolDyn::call`].
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    /// The tool itself failed; the underlying error is boxed so that any
    /// `Send + Sync` error type can be carried.
    #[error("ToolCallError: {0}")]
    ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),

    /// Deserializing the tool arguments or serializing the tool output as JSON failed.
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),
}
31
/// A strongly-typed tool that can be described to and invoked by a model.
///
/// Implementors work with typed arguments and output. The blanket [`ToolDyn`]
/// impl in this module performs the JSON (de)serialization so the tool can be
/// stored and called as a trait object.
pub trait Tool: Sized + Send + Sync {
    /// Name of the tool; returned by the default [`Tool::name`] implementation.
    const NAME: &'static str;

    /// Error type returned by [`Tool::call`].
    type Error: std::error::Error + Send + Sync + 'static;
    /// Argument type; deserialized from the JSON argument string by the [`ToolDyn`] impl.
    type Args: for<'a> Deserialize<'a> + Send + Sync;
    /// Output type; serialized to a JSON string by the [`ToolDyn`] impl.
    type Output: Serialize;

    /// Returns the tool name. Defaults to [`Tool::NAME`].
    fn name(&self) -> String {
        Self::NAME.to_string()
    }

    /// Returns the tool's [`ToolDefinition`].
    ///
    /// Callers pass a prompt string (e.g. `ToolSet::documents` passes an empty
    /// string); whether it influences the definition is up to the implementor.
    fn definition(&self, _prompt: String) -> impl Future<Output = ToolDefinition> + Send + Sync;

    /// Executes the tool with already-deserialized arguments.
    fn call(
        &self,
        args: Self::Args,
    ) -> impl Future<Output = Result<Self::Output, Self::Error>> + Send + Sync;
}
115
/// A [`Tool`] that additionally exposes text documents describing itself and a
/// serializable context from which it can be re-initialized.
pub trait ToolEmbedding: Tool {
    /// Error returned by [`ToolEmbedding::init`].
    type InitError: std::error::Error + Send + Sync + 'static;

    /// Serializable context. [`ToolEmbeddingDyn::context`] serializes it to JSON,
    /// and [`ToolEmbedding::init`] accepts a value of this type.
    type Context: for<'a> Deserialize<'a> + Serialize;

    /// Extra runtime state handed to [`ToolEmbedding::init`]; it is not serialized
    /// by this trait (NOTE(review): confirm intended contents with implementors).
    type State: Send;

    /// Documents describing the tool (presumably embedded for retrieval — confirm
    /// against callers of `ToolSchema`).
    fn embedding_docs(&self) -> Vec<String>;

    /// Returns this tool's context; [`ToolEmbedding::init`] takes the same type.
    fn context(&self) -> Self::Context;

    /// Constructs the tool from runtime `state` and a `context` value.
    fn init(state: Self::State, context: Self::Context) -> Result<Self, Self::InitError>;
}
142
/// Object-safe counterpart of [`Tool`] that exchanges arguments and output as strings.
///
/// Every `T: Tool` implements this through a blanket impl; the MCP tool wrapper
/// (behind the `mcp` feature) implements it directly.
pub trait ToolDyn: Send + Sync {
    /// Name under which the tool is registered (used as the key in `ToolSet`).
    fn name(&self) -> String;

    /// Returns a boxed future resolving to the tool's [`ToolDefinition`].
    fn definition(
        &self,
        prompt: String,
    ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>>;

    /// Invokes the tool with JSON-encoded `args`, resolving to the tool's string
    /// output or a [`ToolError`].
    fn call(
        &self,
        args: String,
    ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + Sync + '_>>;
}
157
158impl<T: Tool> ToolDyn for T {
159 fn name(&self) -> String {
160 self.name()
161 }
162
163 fn definition(
164 &self,
165 prompt: String,
166 ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>> {
167 Box::pin(<Self as Tool>::definition(self, prompt))
168 }
169
170 fn call(
171 &self,
172 args: String,
173 ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + Sync + '_>> {
174 Box::pin(async move {
175 match serde_json::from_str(&args) {
176 Ok(args) => <Self as Tool>::call(self, args)
177 .await
178 .map_err(|e| ToolError::ToolCallError(Box::new(e)))
179 .and_then(|output| {
180 serde_json::to_string(&output).map_err(ToolError::JsonError)
181 }),
182 Err(e) => Err(ToolError::JsonError(e)),
183 }
184 })
185 }
186}
187
/// A tool exposed by an MCP server together with the client used to invoke it.
#[cfg(feature = "mcp")]
pub struct McpTool<T: mcp_core::transport::Transport> {
    /// Tool metadata (name, optional description, input schema), presumably as
    /// listed by the MCP server — see `from_mcp_server`.
    definition: mcp_core::types::Tool,
    /// Client through which `call_tool` requests are sent.
    client: mcp_core::client::Client<T>,
}
193
#[cfg(feature = "mcp")]
impl<T> McpTool<T>
where
    T: mcp_core::transport::Transport,
{
    /// Wraps an MCP tool `definition` together with the `client` that will be
    /// used to invoke it.
    pub fn from_mcp_server(
        definition: mcp_core::types::Tool,
        client: mcp_core::client::Client<T>,
    ) -> Self {
        McpTool { client, definition }
    }
}
206
/// Builds a [`ToolDefinition`] from a borrowed MCP tool, cloning its fields.
/// A missing description becomes an empty string.
#[cfg(feature = "mcp")]
impl From<&mcp_core::types::Tool> for ToolDefinition {
    fn from(tool: &mcp_core::types::Tool) -> Self {
        let description = tool.description.clone().unwrap_or_default();
        ToolDefinition {
            name: tool.name.clone(),
            description,
            parameters: tool.input_schema.clone(),
        }
    }
}
217
/// Builds a [`ToolDefinition`] by moving fields out of an owned MCP tool.
/// A missing description becomes an empty string.
#[cfg(feature = "mcp")]
impl From<mcp_core::types::Tool> for ToolDefinition {
    fn from(tool: mcp_core::types::Tool) -> Self {
        let description = tool.description.unwrap_or_default();
        ToolDefinition {
            name: tool.name,
            description,
            parameters: tool.input_schema,
        }
    }
}
228
/// Error message produced when an MCP tool call fails.
///
/// Converted into [`ToolError::ToolCallError`] by the `From` impl in this module.
#[cfg(feature = "mcp")]
#[derive(Debug, thiserror::Error)]
#[error("MCP tool error: {0}")]
pub struct McpToolError(String);
233
/// Boxes an [`McpToolError`] into [`ToolError::ToolCallError`].
#[cfg(feature = "mcp")]
impl From<McpToolError> for ToolError {
    fn from(err: McpToolError) -> Self {
        let boxed: Box<dyn std::error::Error + Send + Sync> = Box::new(err);
        Self::ToolCallError(boxed)
    }
}
240
/// Exposes an MCP server tool through the object-safe [`ToolDyn`] interface.
#[cfg(feature = "mcp")]
impl<T> ToolDyn for McpTool<T>
where
    T: mcp_core::transport::Transport,
{
    /// Returns the tool name from the MCP definition.
    fn name(&self) -> String {
        self.definition.name.clone()
    }

    /// Returns the MCP tool's definition; the prompt is ignored.
    fn definition(
        &self,
        _prompt: String,
    ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>> {
        // Reuse the `From<&mcp_core::types::Tool>` conversion so both paths agree.
        Box::pin(async move { ToolDefinition::from(&self.definition) })
    }

    /// Calls the tool on the MCP server with JSON-encoded `args`.
    ///
    /// The successful response content is flattened into one string:
    /// - Text content is appended as-is.
    /// - Image and audio content become `data:<mime>;base64,<data>`.
    /// - Embedded resources become an optional `data:<mime>;` prefix followed by the URI.
    ///
    /// # Errors
    /// - [`ToolError::JsonError`] if `args` is not valid JSON.
    /// - [`ToolError::ToolCallError`] (wrapping an [`McpToolError`]) if the client call
    ///   fails or the server flags the result as an error.
    fn call(
        &self,
        args: String,
    ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + Sync + '_>> {
        Box::pin(async move {
            // Reject malformed arguments instead of silently sending `null` to the
            // server; this matches the generic `ToolDyn` impl for typed tools.
            let args: serde_json::Value = serde_json::from_str(&args)?;

            let result = self
                .client
                .call_tool(&self.definition.name, Some(args))
                .await
                .map_err(|e| McpToolError(format!("Tool returned an error: {e}")))?;

            if result.is_error.unwrap_or(false) {
                // Only the first content item is inspected for the error message.
                let message = match result.content.first() {
                    Some(mcp_core::types::ToolResponseContent::Text(text_content)) => {
                        text_content.text.clone()
                    }
                    Some(_) => "Unsupported error type".to_string(),
                    None => "No error message returned".to_string(),
                };
                return Err(McpToolError(message).into());
            }

            Ok(result
                .content
                .into_iter()
                .map(|content| match content {
                    mcp_core::types::ToolResponseContent::Text(text_content) => text_content.text,
                    mcp_core::types::ToolResponseContent::Image(image_content) => format!(
                        "data:{};base64,{}",
                        image_content.mime_type, image_content.data
                    ),
                    mcp_core::types::ToolResponseContent::Audio(audio_content) => format!(
                        "data:{};base64,{}",
                        audio_content.mime_type, audio_content.data
                    ),
                    mcp_core::types::ToolResponseContent::Resource(embedded_resource) => {
                        format!(
                            "{}{}",
                            embedded_resource
                                .resource
                                .mime_type
                                .map(|m| format!("data:{m};"))
                                .unwrap_or_default(),
                            embedded_resource.resource.uri
                        )
                    }
                })
                .collect::<String>())
        })
    }
}
328
/// Object-safe counterpart of [`ToolEmbedding`], layered on top of [`ToolDyn`].
pub trait ToolEmbeddingDyn: ToolDyn {
    /// Returns the tool's context serialized as a JSON value.
    fn context(&self) -> serde_json::Result<serde_json::Value>;

    /// Returns the documents describing this tool (see [`ToolEmbedding::embedding_docs`]).
    fn embedding_docs(&self) -> Vec<String>;
}
335
336impl<T: ToolEmbedding> ToolEmbeddingDyn for T {
337 fn context(&self) -> serde_json::Result<serde_json::Value> {
338 serde_json::to_value(self.context())
339 }
340
341 fn embedding_docs(&self) -> Vec<String> {
342 self.embedding_docs()
343 }
344}
345
/// A tool stored in a [`ToolSet`], tagged by whether it also supports embedding.
pub(crate) enum ToolType {
    /// A plain tool (added via `ToolSet::add_tool` or `ToolSetBuilder::static_tool`).
    Simple(Box<dyn ToolDyn>),
    /// A tool that also implements [`ToolEmbeddingDyn`] (added via
    /// `ToolSetBuilder::dynamic_tool`).
    Embedding(Box<dyn ToolEmbeddingDyn>),
}
350
351impl ToolType {
352 pub fn name(&self) -> String {
353 match self {
354 ToolType::Simple(tool) => tool.name(),
355 ToolType::Embedding(tool) => tool.name(),
356 }
357 }
358
359 pub async fn definition(&self, prompt: String) -> ToolDefinition {
360 match self {
361 ToolType::Simple(tool) => tool.definition(prompt).await,
362 ToolType::Embedding(tool) => tool.definition(prompt).await,
363 }
364 }
365
366 pub async fn call(&self, args: String) -> Result<String, ToolError> {
367 match self {
368 ToolType::Simple(tool) => tool.call(args).await,
369 ToolType::Embedding(tool) => tool.call(args).await,
370 }
371 }
372}
373
/// Errors returned by [`ToolSet`] operations.
#[derive(Debug, thiserror::Error)]
pub enum ToolSetError {
    /// The tool was found but invoking it failed.
    #[error("ToolCallError: {0}")]
    ToolCallError(#[from] ToolError),

    /// No tool with the requested name is registered; carries that name.
    #[error("ToolNotFoundError: {0}")]
    ToolNotFoundError(String),

    /// JSON serialization failed (e.g. while rendering tool definitions in `documents`).
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),
}
387
/// A collection of tools keyed by name.
#[derive(Default)]
pub struct ToolSet {
    /// Tools keyed by their name; inserting a tool whose name already exists
    /// replaces the previous entry.
    pub(crate) tools: HashMap<String, ToolType>,
}
393
394impl ToolSet {
395 pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
397 let mut toolset = Self::default();
398 tools.into_iter().for_each(|tool| {
399 toolset.add_tool(tool);
400 });
401 toolset
402 }
403
404 pub fn builder() -> ToolSetBuilder {
406 ToolSetBuilder::default()
407 }
408
409 pub fn contains(&self, toolname: &str) -> bool {
411 self.tools.contains_key(toolname)
412 }
413
414 pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
416 self.tools
417 .insert(tool.name(), ToolType::Simple(Box::new(tool)));
418 }
419
420 pub fn add_tools(&mut self, toolset: ToolSet) {
422 self.tools.extend(toolset.tools);
423 }
424
425 pub(crate) fn get(&self, toolname: &str) -> Option<&ToolType> {
426 self.tools.get(toolname)
427 }
428
429 pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
431 if let Some(tool) = self.tools.get(toolname) {
432 tracing::info!(target: "rig",
433 "Calling tool {toolname} with args:\n{}",
434 serde_json::to_string_pretty(&args).unwrap()
435 );
436 Ok(tool.call(args).await?)
437 } else {
438 Err(ToolSetError::ToolNotFoundError(toolname.to_string()))
439 }
440 }
441
442 pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
444 let mut docs = Vec::new();
445 for tool in self.tools.values() {
446 match tool {
447 ToolType::Simple(tool) => {
448 docs.push(completion::Document {
449 id: tool.name(),
450 text: format!(
451 "\
452 Tool: {}\n\
453 Definition: \n\
454 {}\
455 ",
456 tool.name(),
457 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
458 ),
459 additional_props: HashMap::new(),
460 });
461 }
462 ToolType::Embedding(tool) => {
463 docs.push(completion::Document {
464 id: tool.name(),
465 text: format!(
466 "\
467 Tool: {}\n\
468 Definition: \n\
469 {}\
470 ",
471 tool.name(),
472 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
473 ),
474 additional_props: HashMap::new(),
475 });
476 }
477 }
478 }
479 Ok(docs)
480 }
481
482 pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> {
486 self.tools
487 .values()
488 .filter_map(|tool_type| {
489 if let ToolType::Embedding(tool) = tool_type {
490 Some(ToolSchema::try_from(&**tool))
491 } else {
492 None
493 }
494 })
495 .collect::<Result<Vec<_>, _>>()
496 }
497}
498
/// Builder that accumulates tools and produces a [`ToolSet`].
#[derive(Default)]
pub struct ToolSetBuilder {
    /// Tools in insertion order; when built, later entries replace earlier ones
    /// that share a name.
    tools: Vec<ToolType>,
}
503
504impl ToolSetBuilder {
505 pub fn static_tool(mut self, tool: impl ToolDyn + 'static) -> Self {
506 self.tools.push(ToolType::Simple(Box::new(tool)));
507 self
508 }
509
510 pub fn dynamic_tool(mut self, tool: impl ToolEmbeddingDyn + 'static) -> Self {
511 self.tools.push(ToolType::Embedding(Box::new(tool)));
512 self
513 }
514
515 pub fn build(self) -> ToolSet {
516 ToolSet {
517 tools: self
518 .tools
519 .into_iter()
520 .map(|tool| (tool.name(), tool))
521 .collect(),
522 }
523 }
524}