Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,27 @@ async def lookup_issue(params: LookupParams) -> str:
# your logic
```

#### Providing Per-Call Tool Context

Pass `on_provide_tool_context` to `create_session` (or `resume_session`) to inject application context into your tool handlers without exposing it to the model. The provider is invoked once per tool call with the `ToolInvocation` (sync or async); its return value is assigned to `invocation.context` before the handler runs. Use it to hand per-request services or state to handlers that would otherwise need a global lookup. `invocation.context` defaults to `None` when no provider is registered, and is never sent over the wire.

```python
from copilot.tools import ToolInvocation

@define_tool(description="List the current user's open issues")
async def my_issues(invocation: ToolInvocation) -> str:
ctx = invocation.context # whatever the provider returned
return await ctx.db.open_issues_for(ctx.user_id)

async with await client.create_session(
on_permission_request=PermissionHandler.approve_all,
model="gpt-5",
tools=[my_issues],
on_provide_tool_context=lambda invocation: build_request_context(),
) as session:
...
```

## Image Support

The SDK supports image attachments via the `attachments` parameter. You can attach images by providing their file path, or by passing base64-encoded data directly using a blob attachment:
Expand Down
2 changes: 2 additions & 0 deletions python/copilot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@
SessionUiApi,
SessionUiCapabilities,
SystemMessageConfig,
ToolContextProvider,
UserInputHandler,
UserInputRequest,
UserInputResponse,
Expand Down Expand Up @@ -275,6 +276,7 @@
"TelemetryConfig",
"Tool",
"ToolBinaryResult",
"ToolContextProvider",
"ToolInvocation",
"ToolResult",
"ToolResultType",
Expand Down
13 changes: 13 additions & 0 deletions python/copilot/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
SessionFsConfig,
SessionHooks,
SystemMessageConfig,
ToolContextProvider,
UserInputHandler,
_PermissionHandlerFn,
)
Expand Down Expand Up @@ -1565,6 +1566,7 @@ async def create_session(
reasoning_summary: ReasoningSummary | None = None,
context_tier: ContextTier | None = None,
tools: list[Tool] | None = None,
on_provide_tool_context: ToolContextProvider | None = None,
system_message: SystemMessageConfig | None = None,
available_tools: list[str] | ToolSet | None = None,
excluded_tools: list[str] | ToolSet | None = None,
Expand Down Expand Up @@ -1639,6 +1641,10 @@ async def create_session(
context_tier: Context window tier for models that support it. Use
``"long_context"`` to pin the session to the long-context tier.
tools: Custom tools to register with the session.
on_provide_tool_context: Optional provider invoked once per tool call
with the ``ToolInvocation``; its return value (awaited when a
coroutine) is assigned to ``ToolInvocation.context`` before the
handler runs.
system_message: System message configuration.
available_tools: Allowlist of tools to enable. When specified, only
these tools will be available. Applies to the full merged tool
Expand Down Expand Up @@ -2008,6 +2014,7 @@ def _initialize_session(sid: str) -> CopilotSession:
)
s._client_session_apis.session_fs = create_session_fs_adapter(fs_provider)
s._register_tools(tools)
s._register_tool_context_provider(on_provide_tool_context)
s._register_commands(commands)
s._register_permission_handler(on_permission_request)
if on_user_input_request:
Expand Down Expand Up @@ -2136,6 +2143,7 @@ async def resume_session(
reasoning_summary: ReasoningSummary | None = None,
context_tier: ContextTier | None = None,
tools: list[Tool] | None = None,
on_provide_tool_context: ToolContextProvider | None = None,
system_message: SystemMessageConfig | None = None,
available_tools: list[str] | ToolSet | None = None,
excluded_tools: list[str] | ToolSet | None = None,
Expand Down Expand Up @@ -2211,6 +2219,10 @@ async def resume_session(
context_tier: Context window tier for models that support it. Use
``"long_context"`` to pin the session to the long-context tier.
tools: Custom tools to register with the session.
on_provide_tool_context: Optional provider invoked once per tool call
with the ``ToolInvocation``; its return value (awaited when a
coroutine) is assigned to ``ToolInvocation.context`` before the
handler runs.
system_message: System message configuration.
available_tools: Allowlist of tools to enable. When specified, only
these tools will be available. Applies to the full merged tool
Expand Down Expand Up @@ -2533,6 +2545,7 @@ async def resume_session(
)
session._client_session_apis.session_fs = create_session_fs_adapter(fs_provider)
session._register_tools(tools)
session._register_tool_context_provider(on_provide_tool_context)
session._register_commands(commands)
session._register_permission_handler(on_permission_request)
if on_user_input_request:
Expand Down
32 changes: 32 additions & 0 deletions python/copilot/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,10 @@ class PermissionNoResult:
PermissionRequestResult = PermissionDecision | PermissionNoResult


ToolContextProvider = Callable[[ToolInvocation], Any]
"""Per-call tool context provider: receives the invocation, returns ``context``."""


_PermissionHandlerFn = Callable[
[PermissionRequest, dict[str, str]],
PermissionRequestResult | Awaitable[PermissionRequestResult],
Expand Down Expand Up @@ -1119,6 +1123,8 @@ def __init__(
self._event_handlers_lock = threading.Lock()
self._tool_handlers: dict[str, ToolHandler] = {}
self._tool_handlers_lock = threading.Lock()
self._tool_context_provider: ToolContextProvider | None = None
self._tool_context_provider_lock = threading.Lock()
self._permission_handler: _PermissionHandlerFn | None = None
self._permission_handler_lock = threading.Lock()
self._user_input_handler: UserInputHandler | None = None
Expand Down Expand Up @@ -1592,6 +1598,13 @@ async def _execute_tool_and_respond(
arguments=arguments,
)

provider = self._get_tool_context_provider()
if provider is not None:
tool_context = provider(invocation)
if inspect.isawaitable(tool_context):
tool_context = await tool_context
invocation.context = tool_context

with trace_context(traceparent, tracestate):
handler_start = time.perf_counter()
result = handler(invocation)
Expand Down Expand Up @@ -1989,6 +2002,25 @@ def _get_tool_handler(self, name: str) -> ToolHandler | None:
with self._tool_handlers_lock:
return self._tool_handlers.get(name)

def _register_tool_context_provider(self, provider: ToolContextProvider | None) -> None:
"""
Register the provider that supplies per-call tool context.

Note:
This method is internal. The provider is typically registered when
creating a session via :meth:`CopilotClient.create_session`.

Args:
provider: The tool context provider, or None to remove it.
"""
with self._tool_context_provider_lock:
self._tool_context_provider = provider

def _get_tool_context_provider(self) -> ToolContextProvider | None:
"""Retrieve the registered tool context provider, if any."""
with self._tool_context_provider_lock:
return self._tool_context_provider

def _register_permission_handler(self, handler: _PermissionHandlerFn | None) -> None:
"""
Register a handler for permission requests.
Expand Down
1 change: 1 addition & 0 deletions python/copilot/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class ToolInvocation:
tool_call_id: str = ""
tool_name: str = ""
arguments: Any = None
context: Any = None


ToolHandler = Callable[[ToolInvocation], ToolResult | Awaitable[ToolResult]]
Expand Down
167 changes: 167 additions & 0 deletions python/test_tool_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
"""Unit tests for the per-call tool context provider.

The provider is registered on a session and invoked once per tool call to
populate ``ToolInvocation.context`` before the handler runs. These tests drive
``CopilotSession._execute_tool_and_respond`` directly with a fake RPC so the
injection path is exercised without a live runtime connection.
"""

from __future__ import annotations

from typing import Any

from copilot import define_tool
from copilot.session import CopilotSession
from copilot.tools import ToolInvocation


class _FakeToolsRpc:
def __init__(self) -> None:
self.calls: list[Any] = []

async def handle_pending_tool_call(self, request: Any) -> None:
self.calls.append(request)


class _FakeRpc:
def __init__(self) -> None:
self.tools = _FakeToolsRpc()


def _session_with_fake_rpc(session_id: str = "sess-1") -> CopilotSession:
session = CopilotSession(session_id, client=None)
session._rpc = _FakeRpc() # type: ignore[assignment]
return session


async def test_provider_value_injected_into_invocation_context():
seen: dict[str, Any] = {}

@define_tool("echo", description="Echo tool")
def echo(invocation: ToolInvocation) -> str:
seen["context"] = invocation.context
return "ok"

session = _session_with_fake_rpc()
session._register_tool_context_provider(lambda inv: {"user": "alice", "tool": inv.tool_name})

await session._execute_tool_and_respond(
request_id="r1",
tool_name="echo",
tool_call_id="c1",
arguments={},
handler=echo.handler,
)

assert seen["context"] == {"user": "alice", "tool": "echo"}


async def test_async_provider_is_awaited():
seen: dict[str, Any] = {}

@define_tool("echo", description="Echo tool")
def echo(invocation: ToolInvocation) -> str:
seen["context"] = invocation.context
return "ok"

async def provider(_: ToolInvocation) -> dict[str, Any]:
return {"async": True}

session = _session_with_fake_rpc()
session._register_tool_context_provider(provider)

await session._execute_tool_and_respond(
request_id="r1",
tool_name="echo",
tool_call_id="c1",
arguments={},
handler=echo.handler,
)

assert seen["context"] == {"async": True}


async def test_provider_receives_full_invocation():
seen: dict[str, ToolInvocation] = {}

@define_tool("echo", description="Echo tool")
def echo(invocation: ToolInvocation) -> str:
return "ok"

def provider(inv: ToolInvocation) -> str:
seen["invocation"] = inv
return "ctx"

session = _session_with_fake_rpc("sess-42")
session._register_tool_context_provider(provider)

await session._execute_tool_and_respond(
request_id="r1",
tool_name="echo",
tool_call_id="call-7",
arguments={"q": "hello"},
handler=echo.handler,
)

inv = seen["invocation"]
assert inv.session_id == "sess-42"
assert inv.tool_name == "echo"
assert inv.tool_call_id == "call-7"
assert inv.arguments == {"q": "hello"}


async def test_no_provider_leaves_context_none():
seen: dict[str, Any] = {}

@define_tool("echo", description="Echo tool")
def echo(invocation: ToolInvocation) -> str:
seen["context"] = invocation.context
return "ok"

session = _session_with_fake_rpc()

await session._execute_tool_and_respond(
request_id="r1",
tool_name="echo",
tool_call_id="c1",
arguments={},
handler=echo.handler,
)

assert seen["context"] is None


async def test_provider_returning_none_leaves_context_none():
seen: dict[str, Any] = {}

@define_tool("echo", description="Echo tool")
def echo(invocation: ToolInvocation) -> str:
seen["context"] = invocation.context
return "ok"

session = _session_with_fake_rpc()
session._register_tool_context_provider(lambda _: None)

await session._execute_tool_and_respond(
request_id="r1",
tool_name="echo",
tool_call_id="c1",
arguments={},
handler=echo.handler,
)

assert seen["context"] is None


def test_register_and_clear_provider_round_trip():
session = CopilotSession("sess-1", client=None)
assert session._get_tool_context_provider() is None

def provider(_: ToolInvocation) -> str:
return "ctx"

session._register_tool_context_provider(provider)
assert session._get_tool_context_provider() is provider

session._register_tool_context_provider(None)
assert session._get_tool_context_provider() is None
43 changes: 43 additions & 0 deletions python/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,49 @@ class Params(BaseModel):
)


class TestToolInvocation:
def test_context_defaults_to_none(self):
inv = ToolInvocation(
session_id="s1",
tool_call_id="c1",
tool_name="t",
arguments={},
)
assert inv.context is None

def test_context_can_be_set(self):
sentinel = object()
inv = ToolInvocation(
session_id="s1",
tool_call_id="c1",
tool_name="t",
arguments={},
context=sentinel,
)
assert inv.context is sentinel

async def test_handler_can_read_context(self):
seen = None

@define_tool("t", description="Reads context")
def tool(invocation: ToolInvocation) -> str:
nonlocal seen
seen = invocation.context
return "ok"

await tool.handler(
ToolInvocation(
session_id="s1",
tool_call_id="c1",
tool_name="t",
arguments={},
context={"user": "alice"},
)
)

assert seen == {"user": "alice"}


class TestNormalizeResult:
def test_none_returns_empty_success(self):
result = _normalize_result(None)
Expand Down