Skip to content

chore: improve message dispatcher duration #30

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Makefile.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ args = [
"rust-mcp-transport",
]


[tasks.check]
dependencies = ["fmt", "clippy", "test", "doc-test"]

[tasks.clippy-fix]
command = "cargo"
args = ["clippy", "--fix", "--allow-dirty"]
6 changes: 4 additions & 2 deletions crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ impl ClientRuntime {

async fn initialize_request(&self) -> SdkResult<()> {
let request = InitializeRequest::new(self.client_details.clone());
let result: ServerResult = self.request(request.into()).await?.try_into()?;
let result: ServerResult = self.request(request.into(), None).await?.try_into()?;

if let ServerResult::InitializeResult(initialize_result) = result {
// store server details
Expand Down Expand Up @@ -147,7 +147,9 @@ impl McpClient for ClientRuntime {
Err(error_value) => MessageFromClient::Error(error_value),
};
// send the response back with corresponding request id
sender.send(response, Some(jsonrpc_request.id)).await?;
sender
.send(response, Some(jsonrpc_request.id), None)
.await?;
}
ServerMessage::Notification(jsonrpc_notification) => {
self_ref
Expand Down
2 changes: 1 addition & 1 deletion crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ impl McpServer for ServerRuntime {

// send the response back with corresponding request id
sender
.send(response, Some(client_jsonrpc_request.id))
.send(response, Some(client_jsonrpc_request.id), None)
.await?;
}
ClientMessage::Notification(client_jsonrpc_notification) => {
Expand Down
37 changes: 21 additions & 16 deletions crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::{sync::Arc, time::Duration};

use async_trait::async_trait;
use rust_mcp_schema::{
Expand Down Expand Up @@ -170,15 +170,19 @@ pub trait McpClient: Sync + Send {
/// This function sends a `RequestFromClient` message to the server, waits for the response,
/// and handles the result. If the response is empty or of an invalid type, an error is returned.
/// Otherwise, it returns the result from the server.
async fn request(&self, request: RequestFromClient) -> SdkResult<ResultFromServer> {
async fn request(
&self,
request: RequestFromClient,
timeout: Option<Duration>,
) -> SdkResult<ResultFromServer> {
let sender = self.sender().await.read().await;
let sender = sender
.as_ref()
.ok_or(schema_utils::SdkError::connection_closed())?;

// Send the request and receive the response.
let response = sender
.send(MessageFromClient::RequestFromClient(request), None)
.send(MessageFromClient::RequestFromClient(request), None, timeout)
.await?;

let server_message = response.ok_or_else(|| {
Expand All @@ -205,6 +209,7 @@ pub trait McpClient: Sync + Send {
.send(
MessageFromClient::NotificationFromClient(notification),
None,
None,
)
.await?;
Ok(())
Expand All @@ -220,9 +225,9 @@ pub trait McpClient: Sync + Send {
/// # Returns
/// A `SdkResult` containing the `rust_mcp_schema::Result` if the request is successful.
/// If the request or conversion fails, an error is returned.
async fn ping(&self) -> SdkResult<rust_mcp_schema::Result> {
async fn ping(&self, timeout: Option<Duration>) -> SdkResult<rust_mcp_schema::Result> {
let ping_request = PingRequest::new(None);
let response = self.request(ping_request.into()).await?;
let response = self.request(ping_request.into(), timeout).await?;
Ok(response.try_into()?)
}

Expand All @@ -231,13 +236,13 @@ pub trait McpClient: Sync + Send {
params: CompleteRequestParams,
) -> SdkResult<rust_mcp_schema::CompleteResult> {
let request = CompleteRequest::new(params);
let response = self.request(request.into()).await?;
let response = self.request(request.into(), None).await?;
Ok(response.try_into()?)
}

async fn set_logging_level(&self, level: LoggingLevel) -> SdkResult<rust_mcp_schema::Result> {
let request = SetLevelRequest::new(SetLevelRequestParams { level });
let response = self.request(request.into()).await?;
let response = self.request(request.into(), None).await?;
Ok(response.try_into()?)
}

Expand All @@ -246,7 +251,7 @@ pub trait McpClient: Sync + Send {
params: GetPromptRequestParams,
) -> SdkResult<rust_mcp_schema::GetPromptResult> {
let request = GetPromptRequest::new(params);
let response = self.request(request.into()).await?;
let response = self.request(request.into(), None).await?;
Ok(response.try_into()?)
}

Expand All @@ -255,7 +260,7 @@ pub trait McpClient: Sync + Send {
params: Option<ListPromptsRequestParams>,
) -> SdkResult<rust_mcp_schema::ListPromptsResult> {
let request = ListPromptsRequest::new(params);
let response = self.request(request.into()).await?;
let response = self.request(request.into(), None).await?;
Ok(response.try_into()?)
}

Expand All @@ -269,7 +274,7 @@ pub trait McpClient: Sync + Send {
// that excepts an empty params to be passed (like server-everything)
let request =
ListResourcesRequest::new(params.or(Some(ListResourcesRequestParams::default())));
let response = self.request(request.into()).await?;
let response = self.request(request.into(), None).await?;
Ok(response.try_into()?)
}

Expand All @@ -278,7 +283,7 @@ pub trait McpClient: Sync + Send {
params: Option<ListResourceTemplatesRequestParams>,
) -> SdkResult<rust_mcp_schema::ListResourceTemplatesResult> {
let request = ListResourceTemplatesRequest::new(params);
let response = self.request(request.into()).await?;
let response = self.request(request.into(), None).await?;
Ok(response.try_into()?)
}

Expand All @@ -287,7 +292,7 @@ pub trait McpClient: Sync + Send {
params: ReadResourceRequestParams,
) -> SdkResult<rust_mcp_schema::ReadResourceResult> {
let request = ReadResourceRequest::new(params);
let response = self.request(request.into()).await?;
let response = self.request(request.into(), None).await?;
Ok(response.try_into()?)
}

Expand All @@ -296,7 +301,7 @@ pub trait McpClient: Sync + Send {
params: SubscribeRequestParams,
) -> SdkResult<rust_mcp_schema::Result> {
let request = SubscribeRequest::new(params);
let response = self.request(request.into()).await?;
let response = self.request(request.into(), None).await?;
Ok(response.try_into()?)
}

Expand All @@ -305,13 +310,13 @@ pub trait McpClient: Sync + Send {
params: UnsubscribeRequestParams,
) -> SdkResult<rust_mcp_schema::Result> {
let request = UnsubscribeRequest::new(params);
let response = self.request(request.into()).await?;
let response = self.request(request.into(), None).await?;
Ok(response.try_into()?)
}

async fn call_tool(&self, params: CallToolRequestParams) -> SdkResult<CallToolResult> {
let request = CallToolRequest::new(params);
let response = self.request(request.into()).await?;
let response = self.request(request.into(), None).await?;
Ok(response.try_into()?)
}

Expand All @@ -320,7 +325,7 @@ pub trait McpClient: Sync + Send {
params: Option<ListToolsRequestParams>,
) -> SdkResult<rust_mcp_schema::ListToolsResult> {
let request = ListToolsRequest::new(params);
let response = self.request(request.into()).await?;
let response = self.request(request.into(), None).await?;
Ok(response.try_into()?)
}

Expand Down
19 changes: 13 additions & 6 deletions crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::time::Duration;

use async_trait::async_trait;
use rust_mcp_schema::{
schema_utils::{
Expand Down Expand Up @@ -62,14 +64,18 @@ pub trait McpServer: Sync + Send {
/// This function sends a `RequestFromServer` message to the client, waits for the response,
/// and handles the result. If the response is empty or of an invalid type, an error is returned.
/// Otherwise, it returns the result from the client.
async fn request(&self, request: RequestFromServer) -> SdkResult<ResultFromClient> {
async fn request(
&self,
request: RequestFromServer,
timeout: Option<Duration>,
) -> SdkResult<ResultFromClient> {
let sender = self.sender().await;
let sender = sender.read().await;
let sender = sender.as_ref().unwrap();

// Send the request and receive the response.
let response = sender
.send(MessageFromServer::RequestFromServer(request), None)
.send(MessageFromServer::RequestFromServer(request), None, timeout)
.await?;
let client_message = response.ok_or_else(|| {
RpcError::internal_error()
Expand All @@ -95,6 +101,7 @@ pub trait McpServer: Sync + Send {
.send(
MessageFromServer::NotificationFromServer(notification),
None,
None,
)
.await?;
Ok(())
Expand All @@ -110,7 +117,7 @@ pub trait McpServer: Sync + Send {
params: Option<ListRootsRequestParams>,
) -> SdkResult<ListRootsResult> {
let request: ListRootsRequest = ListRootsRequest::new(params);
let response = self.request(request.into()).await?;
let response = self.request(request.into(), None).await?;
ListRootsResult::try_from(response).map_err(|err| err.into())
}

Expand Down Expand Up @@ -178,9 +185,9 @@ pub trait McpServer: Sync + Send {
/// # Returns
/// A `SdkResult` containing the `rust_mcp_schema::Result` if the request is successful.
/// If the request or conversion fails, an error is returned.
async fn ping(&self) -> SdkResult<rust_mcp_schema::Result> {
async fn ping(&self, timeout: Option<Duration>) -> SdkResult<rust_mcp_schema::Result> {
let ping_request = PingRequest::new(None);
let response = self.request(ping_request.into()).await?;
let response = self.request(ping_request.into(), timeout).await?;
Ok(response.try_into()?)
}

Expand All @@ -194,7 +201,7 @@ pub trait McpServer: Sync + Send {
params: CreateMessageRequestParams,
) -> SdkResult<CreateMessageResult> {
let ping_request = CreateMessageRequest::new(params);
let response = self.request(ping_request.into()).await?;
let response = self.request(ping_request.into(), None).await?;
Ok(response.try_into()?)
}

Expand Down
5 changes: 3 additions & 2 deletions crates/rust-mcp-transport/src/mcp_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::{
collections::HashMap,
pin::Pin,
sync::{atomic::AtomicI64, Arc},
time::Duration,
};
use tokio::{
io::{AsyncBufReadExt, BufReader},
Expand All @@ -34,7 +35,7 @@ impl MCPStream {
writable: Mutex<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>>,
error_io: IoStream,
pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<R>>>>,
timeout_msec: u64,
request_timeout: Duration,
shutdown_rx: Receiver<bool>,
) -> (
Pin<Box<dyn Stream<Item = R> + Send>>,
Expand Down Expand Up @@ -62,7 +63,7 @@ impl MCPStream {
pending_requests,
writable,
Arc::new(AtomicI64::new(0)),
timeout_msec,
request_timeout,
);

(stream, sender, error_io)
Expand Down
14 changes: 8 additions & 6 deletions crates/rust-mcp-transport/src/message_dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub struct MessageDispatcher<R> {
pending_requests: Arc<Mutex<HashMap<RequestId, oneshot::Sender<R>>>>,
writable_std: Mutex<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>>,
message_id_counter: Arc<AtomicI64>,
timeout_msec: u64,
request_timeout: Duration,
}

impl<R> MessageDispatcher<R> {
Expand All @@ -38,21 +38,21 @@ impl<R> MessageDispatcher<R> {
/// * `pending_requests` - A thread-safe map for storing pending request IDs and their response channels.
/// * `writable_std` - A mutex-protected, pinned writer (e.g., stdout) for sending serialized messages.
/// * `message_id_counter` - An atomic counter for generating unique request IDs.
/// * `timeout_msec` - The timeout duration in milliseconds for awaiting responses.
/// * `request_timeout` - The timeout duration in milliseconds for awaiting responses.
///
/// # Returns
/// A new `MessageDispatcher` instance configured for MCP message handling.
pub fn new(
pending_requests: Arc<Mutex<HashMap<RequestId, oneshot::Sender<R>>>>,
writable_std: Mutex<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>>,
message_id_counter: Arc<AtomicI64>,
timeout_msec: u64,
request_timeout: Duration,
) -> Self {
Self {
pending_requests,
writable_std,
message_id_counter,
timeout_msec,
request_timeout,
}
}

Expand Down Expand Up @@ -112,6 +112,7 @@ impl McpDispatch<ServerMessage, MessageFromClient> for MessageDispatcher<ServerM
&self,
message: MessageFromClient,
request_id: Option<RequestId>,
request_timeout: Option<Duration>,
) -> TransportResult<Option<ServerMessage>> {
let mut writable_std = self.writable_std.lock().await;

Expand Down Expand Up @@ -148,7 +149,7 @@ impl McpDispatch<ServerMessage, MessageFromClient> for MessageDispatcher<ServerM

if let Some(rx) = rx_response {
// Wait for the response with timeout
match await_timeout(rx, Duration::from_millis(self.timeout_msec)).await {
match await_timeout(rx, request_timeout.unwrap_or(self.request_timeout)).await {
Ok(response) => Ok(Some(response)),
Err(error) => match error {
TransportError::OneshotRecvError(_) => {
Expand Down Expand Up @@ -185,6 +186,7 @@ impl McpDispatch<ClientMessage, MessageFromServer> for MessageDispatcher<ClientM
&self,
message: MessageFromServer,
request_id: Option<RequestId>,
request_timeout: Option<Duration>,
) -> TransportResult<Option<ClientMessage>> {
let mut writable_std = self.writable_std.lock().await;

Expand Down Expand Up @@ -220,7 +222,7 @@ impl McpDispatch<ClientMessage, MessageFromServer> for MessageDispatcher<ClientM
writable_std.flush().await?;

if let Some(rx) = rx_response {
match await_timeout(rx, Duration::from_millis(self.timeout_msec)).await {
match await_timeout(rx, request_timeout.unwrap_or(self.request_timeout)).await {
Ok(response) => Ok(Some(response)),
Err(error) => Err(error),
}
Expand Down
13 changes: 9 additions & 4 deletions crates/rust-mcp-transport/src/transport.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::pin::Pin;
use std::{pin::Pin, time::Duration};

use async_trait::async_trait;
use rust_mcp_schema::{schema_utils::McpMessage, RequestId};
Expand Down Expand Up @@ -29,12 +29,12 @@ pub struct TransportOptions {
///
/// This value defines the maximum amount of time to wait for a response before
/// considering the request as timed out.
pub timeout: u64,
pub timeout: Duration,
}
impl Default for TransportOptions {
fn default() -> Self {
Self {
timeout: DEFAULT_TIMEOUT_MSEC,
timeout: Duration::from_millis(DEFAULT_TIMEOUT_MSEC),
}
}
}
Expand Down Expand Up @@ -84,7 +84,12 @@ where
/// Sends a raw message represented by type `S` and optionally includes a `request_id`.
/// The `request_id` is used when sending a message in response to an MCP request.
/// It should match the `request_id` of the original request.
async fn send(&self, message: S, request_id: Option<RequestId>) -> TransportResult<Option<R>>;
async fn send(
&self,
message: S,
request_id: Option<RequestId>,
request_timeout: Option<Duration>,
) -> TransportResult<Option<R>>;
}

/// A trait representing the transport layer for MCP.
Expand Down
2 changes: 1 addition & 1 deletion examples/simple-mcp-client-core/src/inquiry_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ impl InquiryUtils {
for ping_index in 1..=max_pings {
print!("Ping the server ({} out of {})...", ping_index, max_pings);
std::io::stdout().flush().unwrap();
let ping_result = self.client.ping().await;
let ping_result = self.client.ping(None).await;
print!(
"\rPing the server ({} out of {}) : {}",
ping_index,
Expand Down
2 changes: 1 addition & 1 deletion examples/simple-mcp-client/src/inquiry_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ impl InquiryUtils {
for ping_index in 1..=max_pings {
print!("Ping the server ({} out of {})...", ping_index, max_pings);
std::io::stdout().flush().unwrap();
let ping_result = self.client.ping().await;
let ping_result = self.client.ping(None).await;
print!(
"\rPing the server ({} out of {}) : {}",
ping_index,
Expand Down