Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions rust/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,37 @@ The closure receives the full [`ToolInvocation`](crate::types::ToolInvocation) a

Reach for the `ToolHandler` trait directly when you need shared state across multiple methods or want a named type that shows up by name in stack traces.

### Tool Handler Cancellation

Every `ToolInvocation` carries an optional `cancellation_token: Option<CancellationToken>` that fires when the in-flight handler should stop early. The SDK populates it on dispatch; it's `None` only for invocations you construct yourself (e.g. in tests). Two sources can cancel it:

- **`session.abort().await?`** — cancels all currently in-flight handlers and also sends the `session.abort` RPC to stop the agentic loop.
- **`session.cancel_tool_call(tool_call_id)`** — cancels only the named handler without affecting others or the agentic loop. Returns `true` if an in-flight handler with that ID was found; `false` otherwise.

Handlers that don't need cancellation can ignore the token. Handlers that do long-running work can cooperate:

```rust,ignore
use github_copilot_sdk::tool::ToolHandler;
use github_copilot_sdk::types::ToolInvocation;
use github_copilot_sdk::{Error, ErrorKind, ToolResult};

struct LongRunningTool;

impl ToolHandler for LongRunningTool {
async fn call(&self, inv: ToolInvocation) -> Result<ToolResult, Error> {
let Some(token) = inv.cancellation_token.clone() else {
return do_expensive_work().await;
};
tokio::select! {
_ = token.cancelled() => {
Err(Error::with_message(ErrorKind::Cancelled, "tool call cancelled"))
}
result = do_expensive_work() => result,
}
}
}
```

### Permission Policies

Set a permission policy directly on `SessionConfig` with the chainable builders. They install a synthesized `PermissionHandler` so only permission requests are intercepted; every other event flows through unchanged.
Expand Down
170 changes: 169 additions & 1 deletion rust/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,15 @@ pub struct Session {
/// via [`Session::cancellation_token`] to bind their own work to
/// the session lifetime.
shutdown: CancellationToken,
/// Cancellation tokens for all currently in-flight tool handlers, keyed
/// by `tool_call_id`.
///
/// Each dispatched [`ToolInvocation`](crate::types::ToolInvocation)
/// receives a child token registered here. [`Session::abort`] cancels
/// every token in the map; [`Session::cancel_tool_call`] cancels exactly
/// one. The event-loop task removes the entry once the handler future
/// resolves. Shared with the event loop via `Arc<ParkingLotMutex<…>>`.
in_flight_tool_calls: Arc<ParkingLotMutex<HashMap<String, CancellationToken>>>,
/// Only populated while a `send_and_wait` call is in flight.
///
/// Sync `parking_lot::Mutex` because the lock is never held across an
Expand Down Expand Up @@ -500,12 +509,29 @@ impl Session {

/// Abort the current agent turn.
///
/// Cancels the agentic loop and propagates cancellation to all in-flight
/// tool handlers via the [`CancellationToken`] on each
/// [`ToolInvocation`](crate::types::ToolInvocation). Handlers can check
/// [`is_cancelled()`](CancellationToken::is_cancelled) or `select!` on
/// [`cancelled()`](CancellationToken::cancelled) to stop early.
///
/// To cancel a single handler without aborting the agentic loop, use
/// [`cancel_tool_call`](Self::cancel_tool_call) instead.
///
/// # Cancel safety
///
/// **Cancel-safe.** Single `session.abort` RPC; the underlying
/// [`Client::call`](crate::Client::call) is cancel-safe via the
/// writer-actor.
pub async fn abort(&self) -> Result<(), Error> {
// Cancel all in-flight handlers before sending the RPC so they can
// begin cleanup while the network round-trip is in flight.
{
let guard = self.in_flight_tool_calls.lock();
for token in guard.values() {
token.cancel();
}
}
self.client
Comment thread
gimenete marked this conversation as resolved.
.call(
"session.abort",
Expand All @@ -515,6 +541,25 @@ impl Session {
Ok(())
}

/// Cancel a single in-flight tool handler by its `tool_call_id`.
///
/// Fires only the cancellation token for the named handler and removes it
/// from the in-flight registry, leaving all other handlers and the
/// agentic loop untouched. Use [`abort`](Self::abort) to cancel the full
/// turn.
///
/// Returns `true` if a handler with that ID was found and cancelled,
/// `false` if no matching in-flight handler exists.
pub fn cancel_tool_call(&self, tool_call_id: &str) -> bool {
let mut guard = self.in_flight_tool_calls.lock();
if let Some(token) = guard.remove(tool_call_id) {
token.cancel();
true
} else {
false
}
}

/// Switch to a different model.
///
/// Pass `None` for `opts` if no extra configuration is needed.
Expand Down Expand Up @@ -916,6 +961,7 @@ impl Client {
let idle_waiter = Arc::new(ParkingLotMutex::new(None));
let open_canvases = Arc::new(parking_lot::RwLock::new(Vec::new()));
let shutdown = CancellationToken::new();
let in_flight_tool_calls = Arc::new(ParkingLotMutex::new(HashMap::new()));
let (event_tx, _) = tokio::sync::broadcast::channel(512);

// For cloud sessions (use_server_generated_id), defer session
Expand Down Expand Up @@ -1017,6 +1063,7 @@ impl Client {
open_canvases.clone(),
event_tx.clone(),
shutdown.clone(),
in_flight_tool_calls.clone(),
);
tracing::debug!(
elapsed_ms = setup_start.elapsed().as_millis(),
Expand All @@ -1041,6 +1088,7 @@ impl Client {
client: self.clone(),
event_loop: ParkingLotMutex::new(Some(event_loop)),
shutdown,
in_flight_tool_calls,
idle_waiter,
capabilities,
open_canvases,
Expand Down Expand Up @@ -1173,6 +1221,7 @@ impl Client {
let idle_waiter = Arc::new(ParkingLotMutex::new(None));
let open_canvases = Arc::new(parking_lot::RwLock::new(Vec::new()));
let shutdown = CancellationToken::new();
let in_flight_tool_calls = Arc::new(ParkingLotMutex::new(HashMap::new()));
let (event_tx, _) = tokio::sync::broadcast::channel(512);
let event_loop = spawn_event_loop(
session_id.clone(),
Expand All @@ -1189,6 +1238,7 @@ impl Client {
open_canvases.clone(),
event_tx.clone(),
shutdown.clone(),
in_flight_tool_calls.clone(),
);
let mut registration =
PendingSessionRegistration::new(self.clone(), session_id.clone(), shutdown.clone());
Expand Down Expand Up @@ -1284,6 +1334,7 @@ impl Client {
client: self.clone(),
event_loop: ParkingLotMutex::new(Some(event_loop)),
shutdown,
in_flight_tool_calls,
idle_waiter,
capabilities,
open_canvases,
Expand Down Expand Up @@ -1397,6 +1448,7 @@ fn spawn_event_loop(
open_canvases: Arc<parking_lot::RwLock<Vec<OpenCanvasInstance>>>,
event_tx: tokio::sync::broadcast::Sender<SessionEvent>,
shutdown: CancellationToken,
in_flight_tool_calls: Arc<ParkingLotMutex<HashMap<String, CancellationToken>>>,
) -> JoinHandle<()> {
let crate::router::SessionChannels {
mut notifications,
Expand All @@ -1421,7 +1473,7 @@ fn spawn_event_loop(
_ = shutdown.cancelled() => break,
Some(notification) = notifications.recv() => {
handle_notification(
&session_id, &client, &handlers, &command_handlers, notification, &idle_waiter, &capabilities, &open_canvases, &event_tx,
&session_id, &client, &handlers, &command_handlers, notification, &idle_waiter, &capabilities, &open_canvases, &event_tx, &shutdown, &in_flight_tool_calls,
).await;
}
Some(request) = requests.recv() => {
Expand Down Expand Up @@ -1494,6 +1546,8 @@ async fn handle_notification(
capabilities: &Arc<parking_lot::RwLock<SessionCapabilities>>,
open_canvases: &Arc<parking_lot::RwLock<Vec<OpenCanvasInstance>>>,
event_tx: &tokio::sync::broadcast::Sender<SessionEvent>,
shutdown: &CancellationToken,
in_flight_tool_calls: &Arc<ParkingLotMutex<HashMap<String, CancellationToken>>>,
) {
let dispatch_start = Instant::now();
let event = notification.event.clone();
Expand Down Expand Up @@ -1741,6 +1795,8 @@ async fn handle_notification(
session_id = %sid,
request_id = %request_id
);
let shutdown = shutdown.clone();
let in_flight_tool_calls = in_flight_tool_calls.clone();
tokio::spawn(
async move {
// `tool_name.is_empty()` would have produced a `None`
Expand Down Expand Up @@ -1770,13 +1826,18 @@ async fn handle_notification(
}
let tool_call_id = data.tool_call_id.clone();
let tool_name = data.tool_name.clone();
let cancellation_token = shutdown.child_token();
in_flight_tool_calls
.lock()
.insert(tool_call_id.clone(), cancellation_token.clone());
let invocation = ToolInvocation {
session_id: sid.clone(),
tool_call_id: data.tool_call_id,
tool_name: data.tool_name,
arguments: data
.arguments
.unwrap_or(Value::Object(serde_json::Map::new())),
cancellation_token: Some(cancellation_token),
traceparent: data.traceparent,
tracestate: data.tracestate,
};
Expand All @@ -1785,6 +1846,9 @@ async fn handle_notification(
Ok(r) => r,
Err(e) => tool_failure_result(e.to_string()),
};
// Remove the entry whether the handler succeeded, failed,
// or was cancelled — the token is no longer needed.
in_flight_tool_calls.lock().remove(&tool_call_id);
tracing::debug!(
elapsed_ms = handler_start.elapsed().as_millis(),
session_id = %sid,
Expand Down Expand Up @@ -2320,7 +2384,12 @@ fn inject_transform_sections_resume(

#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::sync::Arc;

use parking_lot::Mutex as ParkingLotMutex;
use serde_json::json;
use tokio_util::sync::CancellationToken;

use super::notification_permission_payload;
use crate::handler::PermissionResult;
Expand Down Expand Up @@ -2349,4 +2418,103 @@ mod tests {
Some(json!({ "kind": "user-not-available" }))
);
}

// Simulate the in-flight map mechanics used by Session without needing a
// real CLI connection.
fn make_map() -> Arc<ParkingLotMutex<HashMap<String, CancellationToken>>> {
Arc::new(ParkingLotMutex::new(HashMap::new()))
}

fn cancel_tool_call(
map: &Arc<ParkingLotMutex<HashMap<String, CancellationToken>>>,
tool_call_id: &str,
) -> bool {
let mut guard = map.lock();
if let Some(token) = guard.remove(tool_call_id) {
token.cancel();
true
} else {
false
}
}

#[test]
fn cancel_tool_call_cancels_only_the_targeted_handler() {
let map = make_map();
let token_a = CancellationToken::new();
let token_b = CancellationToken::new();
map.lock().insert("tc_a".to_string(), token_a.clone());
map.lock().insert("tc_b".to_string(), token_b.clone());

// Cancelling A leaves B untouched.
assert!(cancel_tool_call(&map, "tc_a"));
assert!(token_a.is_cancelled());
assert!(!token_b.is_cancelled());

// The entry is removed from the map.
assert_eq!(map.lock().len(), 1);
assert!(!map.lock().contains_key("tc_a"));
assert!(map.lock().contains_key("tc_b"));
}

#[test]
fn cancel_tool_call_returns_false_for_unknown_id() {
let map = make_map();
assert!(!cancel_tool_call(&map, "nonexistent"));
}

#[test]
fn abort_cancels_all_in_flight_tokens() {
let map = make_map();
let token_a = CancellationToken::new();
let token_b = CancellationToken::new();
map.lock().insert("tc_a".to_string(), token_a.clone());
map.lock().insert("tc_b".to_string(), token_b.clone());

// Simulate abort(): cancel all tokens in the map.
{
let guard = map.lock();
for token in guard.values() {
token.cancel();
}
}

assert!(token_a.is_cancelled());
assert!(token_b.is_cancelled());
}

/// Verify the end-to-end contract: a handler that selects on its
/// `cancellation_token.cancelled()` unblocks when the map entry is
/// cancelled (as `abort()` would do). This exercises the same path the
/// real dispatch code uses — insert token into map, pass child to handler,
/// cancel map entry — without requiring a live CLI connection.
#[tokio::test]
async fn abort_unblocks_handler_awaiting_cancellation() {
let map = make_map();
let shutdown = CancellationToken::new();

// Simulate dispatch: create a child token, register it, hand it to
// the "handler" task.
let token = shutdown.child_token();
map.lock().insert("tc_x".to_string(), token.clone());

let handler = tokio::spawn(async move {
// Handler blocks until its token fires.
token.cancelled().await;
});

// Simulate abort(): cancel every token in the map.
{
let guard = map.lock();
for t in guard.values() {
t.cancel();
}
}

// The handler task must complete promptly once cancelled.
tokio::time::timeout(std::time::Duration::from_secs(1), handler)
.await
.expect("handler should complete within timeout after abort")
.expect("handler task should not panic");
}
}
Loading