[go: up one dir, main page]

rig/
tool.rs

1//! Module defining tool related structs and traits.
2//!
3//! The [Tool] trait defines a simple interface for creating tools that can be used
4//! by [Agents](crate::agent::Agent).
5//!
6//! The [ToolEmbedding] trait extends the [Tool] trait to allow for tools that can be
7//! stored in a vector store and RAGged.
8//!
9//! The [ToolSet] struct is a collection of tools that can be used by an [Agent](crate::agent::Agent)
10//! and optionally RAGged.
11
12use std::{collections::HashMap, pin::Pin};
13
14use futures::Future;
15use serde::{Deserialize, Serialize};
16
17use crate::{
18    completion::{self, ToolDefinition},
19    embeddings::{embed::EmbedError, tool::ToolSchema},
20};
21
22#[derive(Debug, thiserror::Error)]
23pub enum ToolError {
24    /// Error returned by the tool
25    #[error("ToolCallError: {0}")]
26    ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),
27
28    #[error("JsonError: {0}")]
29    JsonError(#[from] serde_json::Error),
30}
31
32/// Trait that represents a simple LLM tool
33///
34/// # Example
35/// ```
36/// use rig::{
37///     completion::ToolDefinition,
38///     tool::{ToolSet, Tool},
39/// };
40///
41/// #[derive(serde::Deserialize)]
42/// struct AddArgs {
43///     x: i32,
44///     y: i32,
45/// }
46///
47/// #[derive(Debug, thiserror::Error)]
48/// #[error("Math error")]
49/// struct MathError;
50///
51/// #[derive(serde::Deserialize, serde::Serialize)]
52/// struct Adder;
53///
54/// impl Tool for Adder {
55///     const NAME: &'static str = "add";
56///
57///     type Error = MathError;
58///     type Args = AddArgs;
59///     type Output = i32;
60///
61///     async fn definition(&self, _prompt: String) -> ToolDefinition {
62///         ToolDefinition {
63///             name: "add".to_string(),
64///             description: "Add x and y together".to_string(),
65///             parameters: serde_json::json!({
66///                 "type": "object",
67///                 "properties": {
68///                     "x": {
69///                         "type": "number",
70///                         "description": "The first number to add"
71///                     },
72///                     "y": {
73///                         "type": "number",
74///                         "description": "The second number to add"
75///                     }
76///                 }
77///             })
78///         }
79///     }
80///
81///     async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
82///         let result = args.x + args.y;
83///         Ok(result)
84///     }
85/// }
86/// ```
87pub trait Tool: Sized + Send + Sync {
88    /// The name of the tool. This name should be unique.
89    const NAME: &'static str;
90
91    /// The error type of the tool.
92    type Error: std::error::Error + Send + Sync + 'static;
93    /// The arguments type of the tool.
94    type Args: for<'a> Deserialize<'a> + Send + Sync;
95    /// The output type of the tool.
96    type Output: Serialize;
97
98    /// A method returning the name of the tool.
99    fn name(&self) -> String {
100        Self::NAME.to_string()
101    }
102
103    /// A method returning the tool definition. The user prompt can be used to
104    /// tailor the definition to the specific use case.
105    fn definition(&self, _prompt: String) -> impl Future<Output = ToolDefinition> + Send + Sync;
106
107    /// The tool execution method.
108    /// Both the arguments and return value are a String since these values are meant to
109    /// be the output and input of LLM models (respectively)
110    fn call(
111        &self,
112        args: Self::Args,
113    ) -> impl Future<Output = Result<Self::Output, Self::Error>> + Send + Sync;
114}
115
116/// Trait that represents an LLM tool that can be stored in a vector store and RAGged
117pub trait ToolEmbedding: Tool {
118    type InitError: std::error::Error + Send + Sync + 'static;
119
120    /// Type of the tool' context. This context will be saved and loaded from the
121    /// vector store when ragging the tool.
122    /// This context can be used to store the tool's static configuration and local
123    /// context.
124    type Context: for<'a> Deserialize<'a> + Serialize;
125
126    /// Type of the tool's state. This state will be passed to the tool when initializing it.
127    /// This state can be used to pass runtime arguments to the tool such as clients,
128    /// API keys and other configuration.
129    type State: Send;
130
131    /// A method returning the documents that will be used as embeddings for the tool.
132    /// This allows for a tool to be retrieved from multiple embedding "directions".
133    /// If the tool will not be RAGged, this method should return an empty vector.
134    fn embedding_docs(&self) -> Vec<String>;
135
136    /// A method returning the context of the tool.
137    fn context(&self) -> Self::Context;
138
139    /// A method to initialize the tool from the context, and a state.
140    fn init(state: Self::State, context: Self::Context) -> Result<Self, Self::InitError>;
141}
142
143/// Wrapper trait to allow for dynamic dispatch of simple tools
144pub trait ToolDyn: Send + Sync {
145    fn name(&self) -> String;
146
147    fn definition(
148        &self,
149        prompt: String,
150    ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>>;
151
152    fn call(
153        &self,
154        args: String,
155    ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + Sync + '_>>;
156}
157
158impl<T: Tool> ToolDyn for T {
159    fn name(&self) -> String {
160        self.name()
161    }
162
163    fn definition(
164        &self,
165        prompt: String,
166    ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>> {
167        Box::pin(<Self as Tool>::definition(self, prompt))
168    }
169
170    fn call(
171        &self,
172        args: String,
173    ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + Sync + '_>> {
174        Box::pin(async move {
175            match serde_json::from_str(&args) {
176                Ok(args) => <Self as Tool>::call(self, args)
177                    .await
178                    .map_err(|e| ToolError::ToolCallError(Box::new(e)))
179                    .and_then(|output| {
180                        serde_json::to_string(&output).map_err(ToolError::JsonError)
181                    }),
182                Err(e) => Err(ToolError::JsonError(e)),
183            }
184        })
185    }
186}
187
/// A tool described by an MCP server and executed through an MCP client.
#[cfg(feature = "mcp")]
pub struct McpTool<T: mcp_core::transport::Transport> {
    /// The MCP tool description (name, optional description, input schema).
    definition: mcp_core::types::Tool,
    /// Client through which calls to this tool are forwarded.
    client: mcp_core::client::Client<T>,
}

#[cfg(feature = "mcp")]
impl<T: mcp_core::transport::Transport> McpTool<T> {
    /// Pairs an MCP tool `definition` with the `client` that will execute it.
    pub fn from_mcp_server(
        definition: mcp_core::types::Tool,
        client: mcp_core::client::Client<T>,
    ) -> Self {
        McpTool { definition, client }
    }
}
206
/// Builds a [`ToolDefinition`] from a borrowed MCP tool description by cloning its fields.
/// A missing description becomes an empty string.
#[cfg(feature = "mcp")]
impl From<&mcp_core::types::Tool> for ToolDefinition {
    fn from(tool: &mcp_core::types::Tool) -> Self {
        let description = tool.description.clone().unwrap_or_default();
        ToolDefinition {
            name: tool.name.clone(),
            description,
            parameters: tool.input_schema.clone(),
        }
    }
}
217
/// Converts an owned MCP tool description into a [`ToolDefinition`] without cloning.
/// A missing description becomes an empty string.
#[cfg(feature = "mcp")]
impl From<mcp_core::types::Tool> for ToolDefinition {
    fn from(tool: mcp_core::types::Tool) -> Self {
        let mcp_core::types::Tool {
            name,
            description,
            input_schema,
            ..
        } = tool;
        ToolDefinition {
            name,
            description: description.unwrap_or_default(),
            parameters: input_schema,
        }
    }
}
228
/// Error produced while calling an [`McpTool`], carrying a human-readable message.
#[cfg(feature = "mcp")]
#[derive(Debug, thiserror::Error)]
#[error("MCP tool error: {0}")]
pub struct McpToolError(String);
233
/// Lets `?` lift an [`McpToolError`] into [`ToolError::ToolCallError`].
#[cfg(feature = "mcp")]
impl From<McpToolError> for ToolError {
    fn from(err: McpToolError) -> Self {
        let boxed: Box<dyn std::error::Error + Send + Sync> = Box::new(err);
        Self::ToolCallError(boxed)
    }
}
240
#[cfg(feature = "mcp")]
impl<T> ToolDyn for McpTool<T>
where
    T: mcp_core::transport::Transport,
{
    fn name(&self) -> String {
        self.definition.name.clone()
    }

    /// Returns the definition taken from the MCP tool description. `_prompt` is ignored.
    fn definition(
        &self,
        _prompt: String,
    ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>> {
        // `input_schema` is already a JSON value, so reuse the `From<&Tool>` conversion
        // instead of round-tripping it through `serde_json::to_value`.
        Box::pin(async move { ToolDefinition::from(&self.definition) })
    }

    /// Forwards the call to the MCP server through the client.
    ///
    /// `args` must be a JSON document. A blank string is treated as "no arguments" and sent
    /// as an empty JSON object. Malformed JSON yields [`ToolError::JsonError`] rather than
    /// being silently replaced by `null`.
    ///
    /// Client failures, and responses flagged `is_error`, become
    /// [`ToolError::ToolCallError`]. On success, every content part is concatenated:
    /// - text verbatim;
    /// - images and audio as `data:<mime>;base64,<data>` URIs;
    /// - embedded resources as their URI, prefixed with `data:<mime>;` when a MIME type is
    ///   present.
    fn call(
        &self,
        args: String,
    ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + Sync + '_>> {
        Box::pin(async move {
            let arguments: serde_json::Value = if args.trim().is_empty() {
                serde_json::Value::Object(serde_json::Map::new())
            } else {
                serde_json::from_str(&args)?
            };

            let result = self
                .client
                .call_tool(&self.definition.name, Some(arguments))
                .await
                .map_err(|e| McpToolError(format!("Tool returned an error: {e}")))?;

            if result.is_error.unwrap_or(false) {
                // Only the first content item is inspected for the error message.
                let message = match result.content.first() {
                    Some(mcp_core::types::ToolResponseContent::Text(text_content)) => {
                        text_content.text.clone()
                    }
                    Some(_) => "Unsupported error type".to_string(),
                    None => "No error message returned".to_string(),
                };
                return Err(McpToolError(message).into());
            }

            Ok(result
                .content
                .into_iter()
                .map(|content| match content {
                    mcp_core::types::ToolResponseContent::Text(text_content) => text_content.text,
                    mcp_core::types::ToolResponseContent::Image(image_content) => format!(
                        "data:{};base64,{}",
                        image_content.mime_type, image_content.data
                    ),
                    mcp_core::types::ToolResponseContent::Audio(audio_content) => format!(
                        "data:{};base64,{}",
                        audio_content.mime_type, audio_content.data
                    ),
                    mcp_core::types::ToolResponseContent::Resource(embedded_resource) => {
                        let prefix = embedded_resource
                            .resource
                            .mime_type
                            .map(|m| format!("data:{m};"))
                            .unwrap_or_default();
                        format!("{prefix}{}", embedded_resource.resource.uri)
                    }
                })
                .collect::<String>())
        })
    }
}
328
329/// Wrapper trait to allow for dynamic dispatch of raggable tools
330pub trait ToolEmbeddingDyn: ToolDyn {
331    fn context(&self) -> serde_json::Result<serde_json::Value>;
332
333    fn embedding_docs(&self) -> Vec<String>;
334}
335
336impl<T: ToolEmbedding> ToolEmbeddingDyn for T {
337    fn context(&self) -> serde_json::Result<serde_json::Value> {
338        serde_json::to_value(self.context())
339    }
340
341    fn embedding_docs(&self) -> Vec<String> {
342        self.embedding_docs()
343    }
344}
345
346pub(crate) enum ToolType {
347    Simple(Box<dyn ToolDyn>),
348    Embedding(Box<dyn ToolEmbeddingDyn>),
349}
350
351impl ToolType {
352    pub fn name(&self) -> String {
353        match self {
354            ToolType::Simple(tool) => tool.name(),
355            ToolType::Embedding(tool) => tool.name(),
356        }
357    }
358
359    pub async fn definition(&self, prompt: String) -> ToolDefinition {
360        match self {
361            ToolType::Simple(tool) => tool.definition(prompt).await,
362            ToolType::Embedding(tool) => tool.definition(prompt).await,
363        }
364    }
365
366    pub async fn call(&self, args: String) -> Result<String, ToolError> {
367        match self {
368            ToolType::Simple(tool) => tool.call(args).await,
369            ToolType::Embedding(tool) => tool.call(args).await,
370        }
371    }
372}
373
374#[derive(Debug, thiserror::Error)]
375pub enum ToolSetError {
376    /// Error returned by the tool
377    #[error("ToolCallError: {0}")]
378    ToolCallError(#[from] ToolError),
379
380    #[error("ToolNotFoundError: {0}")]
381    ToolNotFoundError(String),
382
383    // TODO: Revisit this
384    #[error("JsonError: {0}")]
385    JsonError(#[from] serde_json::Error),
386}
387
388/// A struct that holds a set of tools
389#[derive(Default)]
390pub struct ToolSet {
391    pub(crate) tools: HashMap<String, ToolType>,
392}
393
394impl ToolSet {
395    /// Create a new ToolSet from a list of tools
396    pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
397        let mut toolset = Self::default();
398        tools.into_iter().for_each(|tool| {
399            toolset.add_tool(tool);
400        });
401        toolset
402    }
403
404    /// Create a toolset builder
405    pub fn builder() -> ToolSetBuilder {
406        ToolSetBuilder::default()
407    }
408
409    /// Check if the toolset contains a tool with the given name
410    pub fn contains(&self, toolname: &str) -> bool {
411        self.tools.contains_key(toolname)
412    }
413
414    /// Add a tool to the toolset
415    pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
416        self.tools
417            .insert(tool.name(), ToolType::Simple(Box::new(tool)));
418    }
419
420    /// Merge another toolset into this one
421    pub fn add_tools(&mut self, toolset: ToolSet) {
422        self.tools.extend(toolset.tools);
423    }
424
425    pub(crate) fn get(&self, toolname: &str) -> Option<&ToolType> {
426        self.tools.get(toolname)
427    }
428
429    /// Call a tool with the given name and arguments
430    pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
431        if let Some(tool) = self.tools.get(toolname) {
432            tracing::info!(target: "rig",
433                "Calling tool {toolname} with args:\n{}",
434                serde_json::to_string_pretty(&args).unwrap()
435            );
436            Ok(tool.call(args).await?)
437        } else {
438            Err(ToolSetError::ToolNotFoundError(toolname.to_string()))
439        }
440    }
441
442    /// Get the documents of all the tools in the toolset
443    pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
444        let mut docs = Vec::new();
445        for tool in self.tools.values() {
446            match tool {
447                ToolType::Simple(tool) => {
448                    docs.push(completion::Document {
449                        id: tool.name(),
450                        text: format!(
451                            "\
452                            Tool: {}\n\
453                            Definition: \n\
454                            {}\
455                        ",
456                            tool.name(),
457                            serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
458                        ),
459                        additional_props: HashMap::new(),
460                    });
461                }
462                ToolType::Embedding(tool) => {
463                    docs.push(completion::Document {
464                        id: tool.name(),
465                        text: format!(
466                            "\
467                            Tool: {}\n\
468                            Definition: \n\
469                            {}\
470                        ",
471                            tool.name(),
472                            serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
473                        ),
474                        additional_props: HashMap::new(),
475                    });
476                }
477            }
478        }
479        Ok(docs)
480    }
481
482    /// Convert tools in self to objects of type ToolSchema.
483    /// This is necessary because when adding tools to the EmbeddingBuilder because all
484    /// documents added to the builder must all be of the same type.
485    pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> {
486        self.tools
487            .values()
488            .filter_map(|tool_type| {
489                if let ToolType::Embedding(tool) = tool_type {
490                    Some(ToolSchema::try_from(&**tool))
491                } else {
492                    None
493                }
494            })
495            .collect::<Result<Vec<_>, _>>()
496    }
497}
498
499#[derive(Default)]
500pub struct ToolSetBuilder {
501    tools: Vec<ToolType>,
502}
503
504impl ToolSetBuilder {
505    pub fn static_tool(mut self, tool: impl ToolDyn + 'static) -> Self {
506        self.tools.push(ToolType::Simple(Box::new(tool)));
507        self
508    }
509
510    pub fn dynamic_tool(mut self, tool: impl ToolEmbeddingDyn + 'static) -> Self {
511        self.tools.push(ToolType::Embedding(Box::new(tool)));
512        self
513    }
514
515    pub fn build(self) -> ToolSet {
516        ToolSet {
517            tools: self
518                .tools
519                .into_iter()
520                .map(|tool| (tool.name(), tool))
521                .collect(),
522        }
523    }
524}