diff --git a/Makefile.toml b/Makefile.toml index 5619e0a..f8b65af 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -29,6 +29,10 @@ args = [ "rust-mcp-transport", ] + +[tasks.check] +dependencies = ["fmt", "clippy", "test", "doc-test"] + [tasks.clippy-fix] command = "cargo" args = ["clippy", "--fix", "--allow-dirty"] diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs index 36cb5dd..2e61e9e 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs @@ -54,7 +54,7 @@ impl ClientRuntime { async fn initialize_request(&self) -> SdkResult<()> { let request = InitializeRequest::new(self.client_details.clone()); - let result: ServerResult = self.request(request.into()).await?.try_into()?; + let result: ServerResult = self.request(request.into(), None).await?.try_into()?; if let ServerResult::InitializeResult(initialize_result) = result { // store server details @@ -147,7 +147,9 @@ impl McpClient for ClientRuntime { Err(error_value) => MessageFromClient::Error(error_value), }; // send the response back with corresponding request id - sender.send(response, Some(jsonrpc_request.id)).await?; + sender + .send(response, Some(jsonrpc_request.id), None) + .await?; } ServerMessage::Notification(jsonrpc_notification) => { self_ref diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index afb0d46..b6828a6 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -107,7 +107,7 @@ impl McpServer for ServerRuntime { // send the response back with corresponding request id sender - .send(response, Some(client_jsonrpc_request.id)) + .send(response, Some(client_jsonrpc_request.id), None) .await?; } ClientMessage::Notification(client_jsonrpc_notification) => { diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs index 909dea9..b3662c8 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; use async_trait::async_trait; use rust_mcp_schema::{ @@ -170,7 +170,11 @@ pub trait McpClient: Sync + Send { /// This function sends a `RequestFromClient` message to the server, waits for the response, /// and handles the result. If the response is empty or of an invalid type, an error is returned. /// Otherwise, it returns the result from the server. - async fn request(&self, request: RequestFromClient) -> SdkResult { + async fn request( + &self, + request: RequestFromClient, + timeout: Option, + ) -> SdkResult { let sender = self.sender().await.read().await; let sender = sender .as_ref() @@ -178,7 +182,7 @@ pub trait McpClient: Sync + Send { // Send the request and receive the response. let response = sender - .send(MessageFromClient::RequestFromClient(request), None) + .send(MessageFromClient::RequestFromClient(request), None, timeout) .await?; let server_message = response.ok_or_else(|| { @@ -205,6 +209,7 @@ pub trait McpClient: Sync + Send { .send( MessageFromClient::NotificationFromClient(notification), None, + None, ) .await?; Ok(()) @@ -220,9 +225,9 @@ pub trait McpClient: Sync + Send { /// # Returns /// A `SdkResult` containing the `rust_mcp_schema::Result` if the request is successful. /// If the request or conversion fails, an error is returned. - async fn ping(&self) -> SdkResult { + async fn ping(&self, timeout: Option) -> SdkResult { let ping_request = PingRequest::new(None); - let response = self.request(ping_request.into()).await?; + let response = self.request(ping_request.into(), timeout).await?; Ok(response.try_into()?) } @@ -231,13 +236,13 @@ pub trait McpClient: Sync + Send { params: CompleteRequestParams, ) -> SdkResult { let request = CompleteRequest::new(params); - let response = self.request(request.into()).await?; + let response = self.request(request.into(), None).await?; Ok(response.try_into()?) } async fn set_logging_level(&self, level: LoggingLevel) -> SdkResult { let request = SetLevelRequest::new(SetLevelRequestParams { level }); - let response = self.request(request.into()).await?; + let response = self.request(request.into(), None).await?; Ok(response.try_into()?) } @@ -246,7 +251,7 @@ pub trait McpClient: Sync + Send { params: GetPromptRequestParams, ) -> SdkResult { let request = GetPromptRequest::new(params); - let response = self.request(request.into()).await?; + let response = self.request(request.into(), None).await?; Ok(response.try_into()?) } @@ -255,7 +260,7 @@ pub trait McpClient: Sync + Send { params: Option, ) -> SdkResult { let request = ListPromptsRequest::new(params); - let response = self.request(request.into()).await?; + let response = self.request(request.into(), None).await?; Ok(response.try_into()?) } @@ -269,7 +274,7 @@ pub trait McpClient: Sync + Send { // that excepts an empty params to be passed (like server-everything) let request = ListResourcesRequest::new(params.or(Some(ListResourcesRequestParams::default()))); - let response = self.request(request.into()).await?; + let response = self.request(request.into(), None).await?; Ok(response.try_into()?) } @@ -278,7 +283,7 @@ pub trait McpClient: Sync + Send { params: Option, ) -> SdkResult { let request = ListResourceTemplatesRequest::new(params); - let response = self.request(request.into()).await?; + let response = self.request(request.into(), None).await?; Ok(response.try_into()?) } @@ -287,7 +292,7 @@ pub trait McpClient: Sync + Send { params: ReadResourceRequestParams, ) -> SdkResult { let request = ReadResourceRequest::new(params); - let response = self.request(request.into()).await?; + let response = self.request(request.into(), None).await?; Ok(response.try_into()?) } @@ -296,7 +301,7 @@ pub trait McpClient: Sync + Send { params: SubscribeRequestParams, ) -> SdkResult { let request = SubscribeRequest::new(params); - let response = self.request(request.into()).await?; + let response = self.request(request.into(), None).await?; Ok(response.try_into()?) } @@ -305,13 +310,13 @@ pub trait McpClient: Sync + Send { params: UnsubscribeRequestParams, ) -> SdkResult { let request = UnsubscribeRequest::new(params); - let response = self.request(request.into()).await?; + let response = self.request(request.into(), None).await?; Ok(response.try_into()?) } async fn call_tool(&self, params: CallToolRequestParams) -> SdkResult { let request = CallToolRequest::new(params); - let response = self.request(request.into()).await?; + let response = self.request(request.into(), None).await?; Ok(response.try_into()?) } @@ -320,7 +325,7 @@ pub trait McpClient: Sync + Send { params: Option, ) -> SdkResult { let request = ListToolsRequest::new(params); - let response = self.request(request.into()).await?; + let response = self.request(request.into(), None).await?; Ok(response.try_into()?) } diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs index c80df0e..79ad5fc 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs @@ -1,3 +1,5 @@ +use std::time::Duration; + use async_trait::async_trait; use rust_mcp_schema::{ schema_utils::{ @@ -62,14 +64,18 @@ pub trait McpServer: Sync + Send { /// This function sends a `RequestFromServer` message to the client, waits for the response, /// and handles the result. If the response is empty or of an invalid type, an error is returned. /// Otherwise, it returns the result from the client. - async fn request(&self, request: RequestFromServer) -> SdkResult { + async fn request( + &self, + request: RequestFromServer, + timeout: Option, + ) -> SdkResult { let sender = self.sender().await; let sender = sender.read().await; let sender = sender.as_ref().unwrap(); // Send the request and receive the response. let response = sender - .send(MessageFromServer::RequestFromServer(request), None) + .send(MessageFromServer::RequestFromServer(request), None, timeout) .await?; let client_message = response.ok_or_else(|| { RpcError::internal_error() @@ -95,6 +101,7 @@ pub trait McpServer: Sync + Send { .send( MessageFromServer::NotificationFromServer(notification), None, + None, ) .await?; Ok(()) @@ -110,7 +117,7 @@ pub trait McpServer: Sync + Send { params: Option, ) -> SdkResult { let request: ListRootsRequest = ListRootsRequest::new(params); - let response = self.request(request.into()).await?; + let response = self.request(request.into(), None).await?; ListRootsResult::try_from(response).map_err(|err| err.into()) } @@ -178,9 +185,9 @@ pub trait McpServer: Sync + Send { /// # Returns /// A `SdkResult` containing the `rust_mcp_schema::Result` if the request is successful. /// If the request or conversion fails, an error is returned. - async fn ping(&self) -> SdkResult { + async fn ping(&self, timeout: Option) -> SdkResult { let ping_request = PingRequest::new(None); - let response = self.request(ping_request.into()).await?; + let response = self.request(ping_request.into(), timeout).await?; Ok(response.try_into()?) } @@ -194,7 +201,7 @@ pub trait McpServer: Sync + Send { params: CreateMessageRequestParams, ) -> SdkResult { let ping_request = CreateMessageRequest::new(params); - let response = self.request(ping_request.into()).await?; + let response = self.request(ping_request.into(), None).await?; Ok(response.try_into()?) } diff --git a/crates/rust-mcp-transport/src/mcp_stream.rs b/crates/rust-mcp-transport/src/mcp_stream.rs index 754e307..9ecd6e6 100644 --- a/crates/rust-mcp-transport/src/mcp_stream.rs +++ b/crates/rust-mcp-transport/src/mcp_stream.rs @@ -9,6 +9,7 @@ use std::{ collections::HashMap, pin::Pin, sync::{atomic::AtomicI64, Arc}, + time::Duration, }; use tokio::{ io::{AsyncBufReadExt, BufReader}, @@ -34,7 +35,7 @@ impl MCPStream { writable: Mutex>>, error_io: IoStream, pending_requests: Arc>>>, - timeout_msec: u64, + request_timeout: Duration, shutdown_rx: Receiver, ) -> ( Pin + Send>>, @@ -62,7 +63,7 @@ impl MCPStream { pending_requests, writable, Arc::new(AtomicI64::new(0)), - timeout_msec, + request_timeout, ); (stream, sender, error_io) diff --git a/crates/rust-mcp-transport/src/message_dispatcher.rs b/crates/rust-mcp-transport/src/message_dispatcher.rs index 8ae5cdd..84ddd61 100644 --- a/crates/rust-mcp-transport/src/message_dispatcher.rs +++ b/crates/rust-mcp-transport/src/message_dispatcher.rs @@ -28,7 +28,7 @@ pub struct MessageDispatcher { pending_requests: Arc>>>, writable_std: Mutex>>, message_id_counter: Arc, - timeout_msec: u64, + request_timeout: Duration, } impl MessageDispatcher { @@ -38,7 +38,7 @@ impl MessageDispatcher { /// * `pending_requests` - A thread-safe map for storing pending request IDs and their response channels. /// * `writable_std` - A mutex-protected, pinned writer (e.g., stdout) for sending serialized messages. /// * `message_id_counter` - An atomic counter for generating unique request IDs. - /// * `timeout_msec` - The timeout duration in milliseconds for awaiting responses. + /// * `request_timeout` - The timeout duration in milliseconds for awaiting responses. /// /// # Returns /// A new `MessageDispatcher` instance configured for MCP message handling. @@ -46,13 +46,13 @@ impl MessageDispatcher { pending_requests: Arc>>>, writable_std: Mutex>>, message_id_counter: Arc, - timeout_msec: u64, + request_timeout: Duration, ) -> Self { Self { pending_requests, writable_std, message_id_counter, - timeout_msec, + request_timeout, } } @@ -112,6 +112,7 @@ impl McpDispatch for MessageDispatcher, + request_timeout: Option, ) -> TransportResult> { let mut writable_std = self.writable_std.lock().await; @@ -148,7 +149,7 @@ impl McpDispatch for MessageDispatcher Ok(Some(response)), Err(error) => match error { TransportError::OneshotRecvError(_) => { @@ -185,6 +186,7 @@ impl McpDispatch for MessageDispatcher, + request_timeout: Option, ) -> TransportResult> { let mut writable_std = self.writable_std.lock().await; @@ -220,7 +222,7 @@ impl McpDispatch for MessageDispatcher Ok(Some(response)), Err(error) => Err(error), } diff --git a/crates/rust-mcp-transport/src/transport.rs b/crates/rust-mcp-transport/src/transport.rs index 06710f0..3efc840 100644 --- a/crates/rust-mcp-transport/src/transport.rs +++ b/crates/rust-mcp-transport/src/transport.rs @@ -1,4 +1,4 @@ -use std::pin::Pin; +use std::{pin::Pin, time::Duration}; use async_trait::async_trait; use rust_mcp_schema::{schema_utils::McpMessage, RequestId}; @@ -29,12 +29,12 @@ pub struct TransportOptions { /// /// This value defines the maximum amount of time to wait for a response before /// considering the request as timed out. - pub timeout: u64, + pub timeout: Duration, } impl Default for TransportOptions { fn default() -> Self { Self { - timeout: DEFAULT_TIMEOUT_MSEC, + timeout: Duration::from_millis(DEFAULT_TIMEOUT_MSEC), } } } @@ -84,7 +84,12 @@ where /// Sends a raw message represented by type `S` and optionally includes a `request_id`. /// The `request_id` is used when sending a message in response to an MCP request. /// It should match the `request_id` of the original request. - async fn send(&self, message: S, request_id: Option) -> TransportResult>; + async fn send( + &self, + message: S, + request_id: Option, + request_timeout: Option, + ) -> TransportResult>; } /// A trait representing the transport layer for MCP. diff --git a/examples/simple-mcp-client-core/src/inquiry_utils.rs b/examples/simple-mcp-client-core/src/inquiry_utils.rs index 7f27ca9..d6db24b 100644 --- a/examples/simple-mcp-client-core/src/inquiry_utils.rs +++ b/examples/simple-mcp-client-core/src/inquiry_utils.rs @@ -204,7 +204,7 @@ impl InquiryUtils { for ping_index in 1..=max_pings { print!("Ping the server ({} out of {})...", ping_index, max_pings); std::io::stdout().flush().unwrap(); - let ping_result = self.client.ping().await; + let ping_result = self.client.ping(None).await; print!( "\rPing the server ({} out of {}) : {}", ping_index, diff --git a/examples/simple-mcp-client/src/inquiry_utils.rs b/examples/simple-mcp-client/src/inquiry_utils.rs index 7f27ca9..d6db24b 100644 --- a/examples/simple-mcp-client/src/inquiry_utils.rs +++ b/examples/simple-mcp-client/src/inquiry_utils.rs @@ -204,7 +204,7 @@ impl InquiryUtils { for ping_index in 1..=max_pings { print!("Ping the server ({} out of {})...", ping_index, max_pings); std::io::stdout().flush().unwrap(); - let ping_result = self.client.ping().await; + let ping_result = self.client.ping(None).await; print!( "\rPing the server ({} out of {}) : {}", ping_index,