From d11efa0b0cb55178532fc571b15576dbebb9199c Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Sun, 14 Jun 2026 14:03:53 +0200 Subject: [PATCH 1/5] Add SDK MCP OAuth host token handlers Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dotnet/src/Client.cs | 15 + dotnet/src/Generated/Rpc.cs | 112 ++++++- dotnet/src/Generated/SessionEvents.cs | 94 +++++- dotnet/src/Session.cs | 68 +++++ dotnet/src/Types.cs | 70 +++++ .../test/Unit/ClientSessionLifetimeTests.cs | 152 +++++++++ dotnet/test/Unit/PublicDtoTests.cs | 19 ++ .../Unit/SessionEventSerializationTests.cs | 4 + go/client.go | 30 ++ go/client_test.go | 288 ++++++++++++++++++ go/rpc/zrpc.go | 104 +++++++ go/rpc/zrpc_encoding.go | 83 +++++ go/rpc/zsession_events.go | 28 +- go/session.go | 87 ++++++ go/session_test.go | 104 +++++++ go/types.go | 52 ++++ go/zsession_events.go | 15 +- java/scripts/codegen/java.ts | 9 +- .../generated/McpOauthCompletedEvent.java | 4 +- .../generated/McpOauthCompletedOutcome.java | 37 +++ .../generated/McpOauthRequiredEvent.java | 6 +- ...McpOauthRequiredWwwAuthenticateParams.java | 31 ++ ...SessionEventLogRegisterInterestParams.java | 2 +- .../generated/rpc/SessionMcpOauthApi.java | 16 + ...ionMcpOauthHandlePendingRequestParams.java | 34 +++ ...ionMcpOauthHandlePendingRequestResult.java | 30 ++ .../rpc/SessionMcpOauthRespondParams.java | 2 +- .../com/github/copilot/CopilotClient.java | 28 +- .../com/github/copilot/CopilotSession.java | 77 +++++ .../github/copilot/SessionRequestBuilder.java | 6 + .../github/copilot/rpc/McpAuthHandler.java | 24 ++ .../github/copilot/rpc/McpAuthRequest.java | 22 ++ .../com/github/copilot/rpc/McpAuthResult.java | 32 ++ .../com/github/copilot/rpc/McpAuthToken.java | 13 + .../copilot/rpc/ResumeSessionConfig.java | 24 ++ .../com/github/copilot/rpc/SessionConfig.java | 27 ++ .../McpAuthInterestRegistrationTest.java | 246 +++++++++++++++ nodejs/src/client.ts | 18 +- nodejs/src/generated/rpc.ts | 108 ++++++- nodejs/src/generated/session-events.ts | 31 +- nodejs/src/session.ts | 56 +++- nodejs/src/types.ts | 79 ++++- nodejs/test/client.test.ts | 176 +++++++++++ nodejs/test/e2e/mcp_oauth.e2e.test.ts | 166 ++++++++++ python/copilot/__init__.py | 12 + python/copilot/client.py | 21 ++ python/copilot/generated/rpc.py | 135 ++++++++ python/copilot/generated/session_events.py | 47 ++- python/copilot/session.py | 150 +++++++++ python/test_client.py | 255 ++++++++++++++++ rust/src/generated/api_types.rs | 144 ++++++++- rust/src/generated/rpc.rs | 38 ++- rust/src/generated/session_events.rs | 38 ++- rust/src/handler.rs | 96 +++++- rust/src/session.rs | 117 ++++++- rust/src/types.rs | 34 ++- rust/tests/session_test.rs | 263 +++++++++++++++- test/harness/test-mcp-oauth-server.mjs | 216 +++++++++++++ 58 files changed, 4143 insertions(+), 52 deletions(-) create mode 100644 java/src/generated/java/com/github/copilot/generated/McpOauthCompletedOutcome.java create mode 100644 java/src/generated/java/com/github/copilot/generated/McpOauthRequiredWwwAuthenticateParams.java create mode 100644 java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthHandlePendingRequestParams.java create mode 100644 java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthHandlePendingRequestResult.java create mode 100644 java/src/main/java/com/github/copilot/rpc/McpAuthHandler.java create mode 100644 java/src/main/java/com/github/copilot/rpc/McpAuthRequest.java create mode 100644 java/src/main/java/com/github/copilot/rpc/McpAuthResult.java create mode 100644 java/src/main/java/com/github/copilot/rpc/McpAuthToken.java create mode 100644 java/src/test/java/com/github/copilot/McpAuthInterestRegistrationTest.java create mode 100644 nodejs/test/e2e/mcp_oauth.e2e.test.ts create mode 100644 test/harness/test-mcp-oauth-server.mjs diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index 4e8715bd5..16f628f99 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -572,6 +572,7 @@ private CopilotSession InitializeSession( this); session.RegisterTools(config.Tools ?? []); session.RegisterPermissionHandler(config.OnPermissionRequest); + session.RegisterMcpAuthHandler(config.OnMcpAuthRequest); session.RegisterCommands(config.Commands); session.RegisterElicitationHandler(config.OnElicitationRequest); session.RegisterExitPlanModeHandler(config.OnExitPlanModeRequest); @@ -878,6 +879,10 @@ public async Task CreateSessionAsync(SessionConfig config, Cance transformCallbacks, hasHooks, "CopilotClient.CreateSessionAsync"); + if (config.OnMcpAuthRequest is not null) + { + await session.Rpc.EventLog.RegisterInterestAsync("mcp.oauth_required", cancellationToken); + } } try { @@ -987,6 +992,12 @@ public async Task CreateSessionAsync(SessionConfig config, Cance $"session.create returned sessionId {response.SessionId} but the caller requested {localSessionId}."); } + // Local IDs registered before create; server-assigned IDs can only register now. + if (localSessionId is null && config.OnMcpAuthRequest is not null) + { + await session.Rpc.EventLog.RegisterInterestAsync("mcp.oauth_required", cancellationToken); + } + session.WorkspacePath = response.WorkspacePath; session.SetCapabilities(response.Capabilities); session.SetOpenCanvases(response.OpenCanvases); @@ -1073,6 +1084,10 @@ public async Task ResumeSessionAsync(string sessionId, ResumeSes transformCallbacks, hasHooks, "CopilotClient.ResumeSessionAsync"); + if (config.OnMcpAuthRequest is not null) + { + await session.Rpc.EventLog.RegisterInterestAsync("mcp.oauth_required", cancellationToken); + } try { diff --git a/dotnet/src/Generated/Rpc.cs b/dotnet/src/Generated/Rpc.cs index 4b6cdab93..5ace8cf82 100644 --- a/dotnet/src/Generated/Rpc.cs +++ b/dotnet/src/Generated/Rpc.cs @@ -5193,7 +5193,7 @@ internal sealed class McpOauthRespondRequest [JsonPropertyName("provider")] internal JsonElement? Provider { get; set; } - /// OAuth request identifier from mcp.oauth_required. + /// OAuth request identifier for the pending request. [JsonPropertyName("requestId")] public string RequestId { get; set; } = string.Empty; @@ -5202,6 +5202,87 @@ internal sealed class McpOauthRespondRequest public string SessionId { get; set; } = string.Empty; } +/// Indicates whether the pending MCP OAuth response was accepted. +[Experimental(Diagnostics.Experimental)] +public sealed class McpOauthHandlePendingResult +{ + /// Whether the response was accepted. False if the request was unknown, timed out, or already resolved. + [JsonPropertyName("success")] + public bool Success { get; set; } +} + +/// Host response to the pending OAuth request. +/// Polymorphic base type discriminated by kind. +[Experimental(Diagnostics.Experimental)] +[JsonPolymorphic( + TypeDiscriminatorPropertyName = "kind", + UnknownDerivedTypeHandling = JsonUnknownDerivedTypeHandling.FallBackToBaseType)] +[JsonDerivedType(typeof(McpOauthPendingRequestResponseToken), "token")] +[JsonDerivedType(typeof(McpOauthPendingRequestResponseCancelled), "cancelled")] +public partial class McpOauthPendingRequestResponse +{ + /// The type discriminator. + [JsonPropertyName("kind")] + public virtual string Kind { get; set; } = string.Empty; +} + + +/// Schema for the `McpOauthPendingRequestResponseToken` type. +/// The token variant of . +[Experimental(Diagnostics.Experimental)] +public partial class McpOauthPendingRequestResponseToken : McpOauthPendingRequestResponse +{ + /// + [JsonIgnore] + public override string Kind => "token"; + + /// Access token acquired by the SDK host. + [JsonPropertyName("accessToken")] + public required string AccessToken { get; set; } + + /// Token lifetime in seconds, if known. + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + [JsonPropertyName("expiresIn")] + public long? ExpiresIn { get; set; } + + /// Refresh token supplied by the host, if available. + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + [JsonPropertyName("refreshToken")] + public string? RefreshToken { get; set; } + + /// OAuth token type. Defaults to Bearer when omitted. + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + [JsonPropertyName("tokenType")] + public string? TokenType { get; set; } +} + +/// Schema for the `McpOauthPendingRequestResponseCancelled` type. +/// The cancelled variant of . +[Experimental(Diagnostics.Experimental)] +public partial class McpOauthPendingRequestResponseCancelled : McpOauthPendingRequestResponse +{ + /// + [JsonIgnore] + public override string Kind => "cancelled"; +} + +/// Pending MCP OAuth request ID and host-provided token or cancellation response. +[Experimental(Diagnostics.Experimental)] +internal sealed class McpOauthHandlePendingRequest +{ + /// OAuth request identifier for the pending request. + [JsonPropertyName("requestId")] + public string RequestId { get; set; } = string.Empty; + + /// Host response to the pending OAuth request. + [JsonPropertyName("result")] + public McpOauthPendingRequestResponse Result { get => field ??= new(); set; } + + /// Target session identifier. + [JsonPropertyName("sessionId")] + public string SessionId { get; set; } = string.Empty; +} + /// OAuth authorization URL the caller should open, or empty when cached tokens already authenticated the server. [Experimental(Diagnostics.Experimental)] public sealed class McpOauthLoginResult @@ -6539,6 +6620,7 @@ public sealed class SlashCommandInfo /// Canonical command name without a leading slash. [JsonPropertyName("name")] public string Name { get; set; } = string.Empty; + } /// Slash commands available in the session, after applying any include/exclude filters. @@ -9267,7 +9349,7 @@ public sealed class RegisterEventInterestResult [Experimental(Diagnostics.Experimental)] internal sealed class RegisterEventInterestParams { - /// The event type the consumer wants the runtime to treat as 'observed' for behavior-switching gating. Some runtime code paths inspect whether any consumer is interested in a specific event type and choose a different implementation accordingly (e.g. `mcp.oauth_required`: when interest is registered the runtime delegates the full interactive OAuth flow to the consumer; when no interest is registered the runtime installs a browserless fallback that silently reuses cached tokens). SDK clients that long-poll events do NOT automatically appear as listeners to these gating checks — they must explicitly call `registerInterest` for each event type they want the runtime to count as having a consumer. Multiple registrations for the same event type from the same or different consumers are tracked independently and must each be released. See: `mcp.oauth_required`, `sampling.requested`, `auto_mode_switch.requested`, `user_input.requested`, `elicitation.requested`, `command.queued`, `exit_plan_mode.requested`. + /// The event type the consumer wants the runtime to treat as observed for behavior-switching gating. Multiple registrations for the same event type from the same or different consumers are tracked independently and must each be released. [JsonPropertyName("eventType")] public string EventType { get; set; } = string.Empty; @@ -17507,8 +17589,8 @@ internal McpOauthApi(CopilotSession session) _session = session; } - /// Responds to a pending MCP OAuth provider request. Marked internal because the `provider` argument is an in-process OAuthClientProvider instance that cannot be carried over the wire; the public OAuth surface will route the response through a wire-clean handshake once the CLI moves on top of the SDK. - /// OAuth request identifier from mcp.oauth_required. + /// Responds to a pending MCP OAuth request with an in-process provider. Conceptually similar to handlePendingRequest, but marked internal because this legacy CLI-only path takes a live OAuthClientProvider instance that cannot be carried over the wire. Once the CLI is replatformed on the SDK and can use handlePendingRequest, this API should be removed. + /// OAuth request identifier for the pending request. /// In-process OAuthClientProvider instance, or omitted to deny. Marked internal: cannot be serialized across the JSON-RPC boundary. /// The to monitor for cancellation requests. The default is . /// Empty result after recording the MCP OAuth response. @@ -17521,6 +17603,21 @@ internal async Task RespondAsync(string requestId, object return await CopilotClient.InvokeRpcAsync(_session.Rpc, "session.mcp.oauth.respond", [request], cancellationToken); } + /// Resolves a pending MCP OAuth request with a host-provided token or cancellation. + /// OAuth request identifier for the pending request. + /// Host response to the pending OAuth request. + /// The to monitor for cancellation requests. The default is . + /// Indicates whether the pending MCP OAuth response was accepted. + public async Task HandlePendingRequestAsync(string requestId, McpOauthPendingRequestResponse result, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(requestId); + ArgumentNullException.ThrowIfNull(result); + _session.ThrowIfDisposed(); + + var request = new McpOauthHandlePendingRequest { SessionId = _session.SessionId, RequestId = requestId, Result = result }; + return await CopilotClient.InvokeRpcAsync(_session.Rpc, "session.mcp.oauth.handlePendingRequest", [request], cancellationToken); + } + /// Starts OAuth authentication for a remote MCP server. /// Name of the remote MCP server to authenticate. /// When true, clears any cached OAuth token for the server and runs a full new authorization. Use when the user explicitly wants to switch accounts or believes their session is stuck. @@ -18829,7 +18926,7 @@ public async Task TailAsync(CancellationToken cancellationTo } /// Registers consumer interest in an event type for runtime gating purposes. - /// The event type the consumer wants the runtime to treat as 'observed' for behavior-switching gating. Some runtime code paths inspect whether any consumer is interested in a specific event type and choose a different implementation accordingly (e.g. `mcp.oauth_required`: when interest is registered the runtime delegates the full interactive OAuth flow to the consumer; when no interest is registered the runtime installs a browserless fallback that silently reuses cached tokens). SDK clients that long-poll events do NOT automatically appear as listeners to these gating checks — they must explicitly call `registerInterest` for each event type they want the runtime to count as having a consumer. Multiple registrations for the same event type from the same or different consumers are tracked independently and must each be released. See: `mcp.oauth_required`, `sampling.requested`, `auto_mode_switch.requested`, `user_input.requested`, `elicitation.requested`, `command.queued`, `exit_plan_mode.requested`. + /// The event type the consumer wants the runtime to treat as observed for behavior-switching gating. Multiple registrations for the same event type from the same or different consumers are tracked independently and must each be released. /// The to monitor for cancellation requests. The default is . /// Opaque handle representing an event-type interest registration. public async Task RegisterInterestAsync(string eventType, CancellationToken cancellationToken = default) @@ -19269,9 +19366,11 @@ public static void RegisterClientSessionApiHandlers(JsonRpc rpc, FuncOAuth authentication request for an MCP server. public sealed partial class McpOauthRequiredData { - /// Unique identifier for this OAuth request; used to respond via session.respondToMcpOAuth(). + /// Unique identifier for this OAuth request; used to respond via session.mcp.oauth.handlePendingRequest. [JsonPropertyName("requestId")] public required string RequestId { get; set; } @@ -3068,11 +3068,19 @@ public sealed partial class McpOauthRequiredData [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] [JsonPropertyName("staticClientConfig")] public McpOauthRequiredStaticClientConfig? StaticClientConfig { get; set; } + + /// Parsed parameters from the WWW-Authenticate header that the SDK host uses for RFC 9728 protected-resource metadata discovery. + [JsonPropertyName("wwwAuthenticateParams")] + public required McpOauthRequiredWwwAuthenticateParams WwwAuthenticateParams { get; set; } } /// MCP OAuth request completion notification. public sealed partial class McpOauthCompletedData { + /// How the pending OAuth request was completed. + [JsonPropertyName("outcome")] + public required McpOauthCompletedOutcome Outcome { get; set; } + /// Request ID of the resolved OAuth request. [JsonPropertyName("requestId")] public required string RequestId { get; set; } @@ -5781,6 +5789,25 @@ public sealed partial class McpOauthRequiredStaticClientConfig public bool? PublicClient { get; set; } } +/// Parsed parameters from the WWW-Authenticate header that the SDK host uses for RFC 9728 protected-resource metadata discovery. +/// Nested data type for McpOauthRequiredWwwAuthenticateParams. +public sealed partial class McpOauthRequiredWwwAuthenticateParams +{ + /// Parsed OAuth error from the WWW-Authenticate header, if present. + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + [JsonPropertyName("error")] + public string? Error { get; set; } + + /// Parsed resource_metadata URL from the WWW-Authenticate header. + [JsonPropertyName("resourceMetadataUrl")] + public required string ResourceMetadataUrl { get; set; } + + /// Parsed OAuth scope from the WWW-Authenticate header, if present. + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + [JsonPropertyName("scope")] + public string? Scope { get; set; } +} + /// Schema for the `CommandsChangedCommand` type. /// Nested data type for CommandsChangedCommand. public sealed partial class CommandsChangedCommand @@ -7667,6 +7694,70 @@ public override void Write(Utf8JsonWriter writer, ElicitationCompletedAction val } } +/// How the pending OAuth request was completed. +[JsonConverter(typeof(Converter))] +[DebuggerDisplay("{Value,nq}")] +public readonly struct McpOauthCompletedOutcome : IEquatable +{ + private readonly string? _value; + + /// Initializes a new instance of the struct. + /// The value to associate with this . + [JsonConstructor] + public McpOauthCompletedOutcome(string value) + { + ArgumentException.ThrowIfNullOrWhiteSpace(value); + _value = value; + } + + /// Gets the value associated with this . + public string Value => _value ?? string.Empty; + + /// The pending OAuth request was resolved with a host-provided token/provider. + public static McpOauthCompletedOutcome Token { get; } = new("token"); + + /// The pending OAuth request was cancelled or declined without a token/provider. + public static McpOauthCompletedOutcome Cancelled { get; } = new("cancelled"); + + /// The pending OAuth request timed out before any client responded. + public static McpOauthCompletedOutcome Timeout { get; } = new("timeout"); + + /// Returns a value indicating whether two instances are equivalent. + public static bool operator ==(McpOauthCompletedOutcome left, McpOauthCompletedOutcome right) => left.Equals(right); + + /// Returns a value indicating whether two instances are not equivalent. + public static bool operator !=(McpOauthCompletedOutcome left, McpOauthCompletedOutcome right) => !(left == right); + + /// + public override bool Equals(object? obj) => obj is McpOauthCompletedOutcome other && Equals(other); + + /// + public bool Equals(McpOauthCompletedOutcome other) => string.Equals(Value, other.Value, StringComparison.OrdinalIgnoreCase); + + /// + public override int GetHashCode() => StringComparer.OrdinalIgnoreCase.GetHashCode(Value); + + /// + public override string ToString() => Value; + + /// Provides a for serializing instances. + [EditorBrowsable(EditorBrowsableState.Never)] + public sealed class Converter : JsonConverter + { + /// + public override McpOauthCompletedOutcome Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + return new(GeneratedStringEnumJson.ReadValue(ref reader, typeToConvert)); + } + + /// + public override void Write(Utf8JsonWriter writer, McpOauthCompletedOutcome value, JsonSerializerOptions options) + { + GeneratedStringEnumJson.WriteValue(writer, value.Value, typeof(McpOauthCompletedOutcome)); + } + } +} + /// The user's auto-mode-switch choice. [JsonConverter(typeof(Converter))] [DebuggerDisplay("{Value,nq}")] @@ -8368,6 +8459,7 @@ public override void Write(Utf8JsonWriter writer, CanvasOpenedAvailability value [JsonSerializable(typeof(McpOauthRequiredData))] [JsonSerializable(typeof(McpOauthRequiredEvent))] [JsonSerializable(typeof(McpOauthRequiredStaticClientConfig))] +[JsonSerializable(typeof(McpOauthRequiredWwwAuthenticateParams))] [JsonSerializable(typeof(McpServersLoadedServer))] [JsonSerializable(typeof(ModelCallFailureData))] [JsonSerializable(typeof(ModelCallFailureEvent))] diff --git a/dotnet/src/Session.cs b/dotnet/src/Session.cs index 095c1abf7..1d0d9572a 100644 --- a/dotnet/src/Session.cs +++ b/dotnet/src/Session.cs @@ -62,6 +62,7 @@ public sealed partial class CopilotSession : IAsyncDisposable private readonly CopilotClient _parentClient; private volatile Func>? _permissionHandler; + private volatile Func>? _mcpAuthHandler; private volatile Func>? _userInputHandler; private volatile Func>? _elicitationHandler; private volatile Func>? _exitPlanModeHandler; @@ -561,6 +562,11 @@ internal void RegisterPermissionHandler(Func>? handler) + { + _mcpAuthHandler = handler; + } + /// /// Handles a permission request from the Copilot CLI. /// @@ -636,6 +642,34 @@ private async Task HandleBroadcastEventAsync(SessionEvent sessionEvent) break; } + case McpOauthRequiredEvent authEvent: + { + var data = authEvent.Data; + if (string.IsNullOrEmpty(data.RequestId)) + return; + + var handler = _mcpAuthHandler; + if (handler is null) + { + _logger.LogWarning( + "Received MCP OAuth request without a registered MCP auth handler. " + + "SessionId={SessionId}, RequestId={RequestId}", + SessionId, + data.RequestId); + return; + } + + await ExecuteMcpAuthAndRespondAsync(data.RequestId, new McpAuthContext + { + SessionId = SessionId, + ServerName = data.ServerName, + ServerUrl = data.ServerUrl, + WwwAuthenticateParams = data.WwwAuthenticateParams, + StaticClientConfig = data.StaticClientConfig + }, handler); + break; + } + case CommandExecuteEvent cmdEvent: { var data = cmdEvent.Data; @@ -705,6 +739,40 @@ await HandleElicitationRequestAsync( } } + private async Task ExecuteMcpAuthAndRespondAsync( + string requestId, + McpAuthContext context, + Func> handler) + { + try + { + var result = await handler(context); + McpOauthPendingRequestResponse response = + result is { Cancelled: false, Token: { } token } + ? new McpOauthPendingRequestResponseToken + { + AccessToken = token.AccessToken, + TokenType = token.TokenType, + RefreshToken = token.RefreshToken, + ExpiresIn = token.ExpiresIn + } + : new McpOauthPendingRequestResponseCancelled(); + + await Rpc.Mcp.Oauth.HandlePendingRequestAsync(requestId, response); + } + catch (Exception) + { + try + { + await Rpc.Mcp.Oauth.HandlePendingRequestAsync(requestId, new McpOauthPendingRequestResponseCancelled()); + } + catch (Exception rpcEx) when (rpcEx is IOException or ObjectDisposedException) + { + // Connection lost or RPC error — nothing we can do. + } + } + } + /// /// Executes a tool handler and sends the result back via the HandlePendingToolCall RPC. /// diff --git a/dotnet/src/Types.cs b/dotnet/src/Types.cs index 7a2ad2951..78244e353 100644 --- a/dotnet/src/Types.cs +++ b/dotnet/src/Types.cs @@ -1106,6 +1106,67 @@ public sealed class ElicitationContext public string? Url { get; set; } } +/// +/// Context for an MCP OAuth request callback. +/// +[Experimental(Diagnostics.Experimental)] +public sealed class McpAuthContext +{ + /// Identifier of the session that triggered the MCP OAuth request. + public string SessionId { get; set; } = string.Empty; + + /// Display name of the MCP server that requires OAuth. + public string ServerName { get; set; } = string.Empty; + + /// URL of the MCP server that requires OAuth. + public string ServerUrl { get; set; } = string.Empty; + + /// Parsed WWW-Authenticate parameters from the MCP server. + public McpOauthRequiredWwwAuthenticateParams WwwAuthenticateParams { get; set; } = + new() { ResourceMetadataUrl = string.Empty }; + + /// Static OAuth client configuration, if the server specifies one. + public McpOauthRequiredStaticClientConfig? StaticClientConfig { get; set; } +} + +/// +/// Host-provided OAuth token data for a pending MCP OAuth request. +/// +[Experimental(Diagnostics.Experimental)] +public sealed class McpAuthToken +{ + /// Access token acquired by the SDK host. + public required string AccessToken { get; set; } + + /// OAuth token type. Defaults to Bearer when omitted. + public string? TokenType { get; set; } + + /// Refresh token supplied by the host, if available. + public string? RefreshToken { get; set; } + + /// Token lifetime in seconds, if known. + public long? ExpiresIn { get; set; } +} + +/// +/// Result returned by an MCP auth request handler. +/// +[Experimental(Diagnostics.Experimental)] +public sealed class McpAuthResult +{ + /// Whether the request should be cancelled instead of resolved with a token. + public bool Cancelled { get; set; } + + /// Host-provided token data. Ignored when is true. + public McpAuthToken? Token { get; set; } + + /// Create a token result. + public static McpAuthResult FromToken(McpAuthToken token) => new() { Token = token }; + + /// Create a cancellation result. + public static McpAuthResult Cancel() => new() { Cancelled = true }; +} + // ============================================================================ // Session Capabilities // ============================================================================ @@ -2470,6 +2531,7 @@ protected SessionConfigBase(SessionConfigBase? other) OnElicitationRequest = other.OnElicitationRequest; OnEvent = other.OnEvent; OnExitPlanModeRequest = other.OnExitPlanModeRequest; + OnMcpAuthRequest = other.OnMcpAuthRequest; OnPermissionRequest = other.OnPermissionRequest; OnUserInputRequest = other.OnUserInputRequest; Provider = other.Provider; @@ -2884,6 +2946,14 @@ protected SessionConfigBase(SessionConfigBase? other) [JsonIgnore] public ICanvasHandler? CanvasHandler { get; set; } #pragma warning restore GHCP001 + + /// + /// Optional handler for MCP OAuth requests from MCP servers. + /// When provided, the SDK can satisfy MCP server OAuth requests with host-provided token data or cancellation. + /// + [Experimental(Diagnostics.Experimental)] + [JsonIgnore] + public Func>? OnMcpAuthRequest { get; set; } } /// diff --git a/dotnet/test/Unit/ClientSessionLifetimeTests.cs b/dotnet/test/Unit/ClientSessionLifetimeTests.cs index c52148a03..48d26fb71 100644 --- a/dotnet/test/Unit/ClientSessionLifetimeTests.cs +++ b/dotnet/test/Unit/ClientSessionLifetimeTests.cs @@ -15,6 +15,8 @@ namespace GitHub.Copilot.Test.Unit; public sealed class ClientSessionLifetimeTests { + private sealed record RpcRequestRecord(string Method, JsonElement Params); + [Fact] public async Task Dropped_Session_Remains_Rooted_By_Client() { @@ -136,6 +138,124 @@ public async Task ResumeSessionAsync_Throws_When_Same_Client_Already_Tracks_Sess AssertSessionCount(client, sessions: 1); } + [Fact] + public async Task CreateSessionAsync_Registers_McpAuth_Interest_Only_When_Handler_Configured() + { + await using var server = await FakeCopilotServer.StartAsync(); + await using var client = new CopilotClient(new CopilotClientOptions { Connection = RuntimeConnection.ForUri(server.Url) }); + + await using var withoutAuth = await client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + OnEvent = _ => { } + }); + + Assert.DoesNotContain(server.Requests, request => + request.Method == "session.eventLog.registerInterest" + && request.Params.GetProperty("eventType").GetString() == "mcp.oauth_required"); + Assert.Contains(server.Requests, request => + request.Method == "session.create" + && request.Params.GetProperty("requestPermission").GetBoolean()); + + server.ClearRequests(); + + await using var withAuth = await client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + OnMcpAuthRequest = _ => Task.FromResult(McpAuthResult.Cancel()) + }); + + Assert.Collection( + server.Requests.Take(2), + request => + { + Assert.Equal("session.eventLog.registerInterest", request.Method); + Assert.Equal("mcp.oauth_required", request.Params.GetProperty("eventType").GetString()); + }, + request => Assert.Equal("session.create", request.Method)); + } + + [Fact] + public async Task CreateSessionAsync_Registers_McpAuth_Interest_After_Cloud_Create_When_Handler_Configured() + { + await using var server = await FakeCopilotServer.StartAsync(); + await using var client = new CopilotClient(new CopilotClientOptions { Connection = RuntimeConnection.ForUri(server.Url) }); + var cloud = new CloudSessionOptions + { + Repository = new CloudSessionRepository + { + Owner = "github", + Name = "copilot-sdk", + Branch = "main" + } + }; + + await using var withoutAuth = await client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + Cloud = cloud + }); + + Assert.DoesNotContain(server.Requests, request => + request.Method == "session.eventLog.registerInterest" + && request.Params.GetProperty("eventType").GetString() == "mcp.oauth_required"); + + server.ClearRequests(); + + await using var withAuth = await client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + OnMcpAuthRequest = _ => Task.FromResult(McpAuthResult.Cancel()), + Cloud = cloud + }); + + Assert.Collection( + server.Requests.Take(2), + request => Assert.Equal("session.create", request.Method), + request => + { + Assert.Equal("session.eventLog.registerInterest", request.Method); + Assert.Equal("mcp.oauth_required", request.Params.GetProperty("eventType").GetString()); + }); + } + + [Fact] + public async Task ResumeSessionAsync_Registers_McpAuth_Interest_Only_When_Handler_Configured() + { + await using var server = await FakeCopilotServer.StartAsync(); + await using var client = new CopilotClient(new CopilotClientOptions { Connection = RuntimeConnection.ForUri(server.Url) }); + + await using var withoutAuth = await client.ResumeSessionAsync("session-without-auth", new ResumeSessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + OnEvent = _ => { } + }); + + Assert.DoesNotContain(server.Requests, request => + request.Method == "session.eventLog.registerInterest" + && request.Params.GetProperty("eventType").GetString() == "mcp.oauth_required"); + Assert.Contains(server.Requests, request => + request.Method == "session.resume" + && request.Params.GetProperty("requestPermission").GetBoolean()); + + server.ClearRequests(); + + await using var withAuth = await client.ResumeSessionAsync("session-with-auth", new ResumeSessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + OnMcpAuthRequest = _ => Task.FromResult(McpAuthResult.Cancel()) + }); + + Assert.Collection( + server.Requests.Take(2), + request => + { + Assert.Equal("session.eventLog.registerInterest", request.Method); + Assert.Equal("mcp.oauth_required", request.Params.GetProperty("eventType").GetString()); + }, + request => Assert.Equal("session.resume", request.Method)); + } + [Fact] public async Task Generated_Session_Rpc_Throws_When_Session_Disposed() { @@ -194,6 +314,8 @@ private sealed class FakeCopilotServer : IAsyncDisposable private readonly TaskCompletionSource _destroyStarted = new(TaskCreationOptions.RunContinuationsAsynchronously); private readonly TaskCompletionSource _allowDestroy = new(TaskCreationOptions.RunContinuationsAsynchronously); private readonly Task _serverTask; + private readonly List _requests = []; + private readonly object _requestsLock = new(); private string? _lastSessionId; private bool _delayDestroy; @@ -221,6 +343,25 @@ public static Task StartAsync() public Task DestroyStarted => _destroyStarted.Task; + public IReadOnlyList Requests + { + get + { + lock (_requestsLock) + { + return _requests.ToArray(); + } + } + } + + public void ClearRequests() + { + lock (_requestsLock) + { + _requests.Clear(); + } + } + public void DelayDestroy() { _delayDestroy = true; @@ -275,6 +416,13 @@ private async Task HandleRequestAsync(Stream stream, JsonElement request, Cancel var id = idElement.Clone(); var method = request.GetProperty("method").GetString(); + var paramsElement = request.TryGetProperty("params", out var rawParams) + ? rawParams.Clone() + : JsonDocument.Parse("{}").RootElement.Clone(); + lock (_requestsLock) + { + _requests.Add(new RpcRequestRecord(method!, paramsElement)); + } object? result = method switch { "connect" => new Dictionary @@ -285,6 +433,10 @@ private async Task HandleRequestAsync(Stream stream, JsonElement request, Cancel }, "session.create" => CreateSessionResult(request), "session.resume" => CreateSessionResult(request), + "session.eventLog.registerInterest" => new Dictionary + { + ["id"] = "interest-1" + }, "session.send" => new Dictionary { ["messageId"] = "message-1" diff --git a/dotnet/test/Unit/PublicDtoTests.cs b/dotnet/test/Unit/PublicDtoTests.cs index c81a8a7a6..d1918d2b9 100644 --- a/dotnet/test/Unit/PublicDtoTests.cs +++ b/dotnet/test/Unit/PublicDtoTests.cs @@ -20,6 +20,25 @@ namespace GitHub.Copilot.Test.Unit; /// public class PublicDtoTests { + [Fact] + public void McpAuth_Result_Factories_Represent_Token_And_Cancellation() + { + var token = new McpAuthToken + { + AccessToken = "host-token", + TokenType = "Bearer", + ExpiresIn = 3600, + }; + + var tokenResult = McpAuthResult.FromToken(token); + Assert.Same(token, tokenResult.Token); + Assert.False(tokenResult.Cancelled); + + var cancelled = McpAuthResult.Cancel(); + Assert.True(cancelled.Cancelled); + Assert.Null(cancelled.Token); + } + [Fact] public void Public_Dto_Properties_Can_Be_Set_And_Read() { diff --git a/dotnet/test/Unit/SessionEventSerializationTests.cs b/dotnet/test/Unit/SessionEventSerializationTests.cs index 47b4ac3f7..2db690a6c 100644 --- a/dotnet/test/Unit/SessionEventSerializationTests.cs +++ b/dotnet/test/Unit/SessionEventSerializationTests.cs @@ -158,6 +158,10 @@ public class SessionEventSerializationTests GrantType = "client_credentials", PublicClient = false, }, + WwwAuthenticateParams = new McpOauthRequiredWwwAuthenticateParams + { + ResourceMetadataUrl = "https://example.com/.well-known/oauth-protected-resource", + }, }, }, "mcp.oauth_required" diff --git a/go/client.go b/go/client.go index cad460557..861b3d61f 100644 --- a/go/client.go +++ b/go/client.go @@ -732,6 +732,7 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses s.registerTools(config.Tools) s.registerPermissionHandler(config.OnPermissionRequest) + s.registerMCPAuthHandler(config.OnMCPAuthRequest) if config.OnUserInputRequest != nil { s.registerUserInputHandler(config.OnUserInputRequest) } @@ -799,6 +800,14 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses } session = s registeredSessionID = localSessionID + if config.OnMCPAuthRequest != nil { + if _, err := c.client.Request("session.eventLog.registerInterest", map[string]any{ + "sessionId": localSessionID, + "eventType": "mcp.oauth_required", + }); err != nil { + return nil, err + } + } } // For the server-assigned (cloud) path, register the session @@ -860,6 +869,15 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses c.sessionsMux.Unlock() return nil, fmt.Errorf("session.create returned sessionId %s but the caller requested %s", response.SessionID, localSessionID) } + // Local IDs registered before create; server-assigned IDs can only register now. + if localSessionID == "" && config.OnMCPAuthRequest != nil { + if _, err := c.client.Request("session.eventLog.registerInterest", map[string]any{ + "sessionId": session.SessionID, + "eventType": "mcp.oauth_required", + }); err != nil { + return nil, err + } + } session.workspacePath = response.WorkspacePath session.setCapabilities(response.Capabilities) @@ -1024,6 +1042,7 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, session.registerTools(config.Tools) session.registerPermissionHandler(config.OnPermissionRequest) + session.registerMCPAuthHandler(config.OnMCPAuthRequest) if config.OnUserInputRequest != nil { session.registerUserInputHandler(config.OnUserInputRequest) } @@ -1055,6 +1074,17 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, c.sessionsMux.Lock() c.sessions[sessionID] = session c.sessionsMux.Unlock() + if config.OnMCPAuthRequest != nil { + if _, err := c.client.Request("session.eventLog.registerInterest", map[string]any{ + "sessionId": sessionID, + "eventType": "mcp.oauth_required", + }); err != nil { + c.sessionsMux.Lock() + delete(c.sessions, sessionID) + c.sessionsMux.Unlock() + return nil, err + } + } if c.options.SessionFS != nil { if config.CreateSessionFSProvider == nil { diff --git a/go/client_test.go b/go/client_test.go index d5ba47da8..03a62cf78 100644 --- a/go/client_test.go +++ b/go/client_test.go @@ -3,6 +3,8 @@ package copilot import ( "context" "encoding/json" + "fmt" + "io" "os" "os/exec" "path/filepath" @@ -13,6 +15,7 @@ import ( "sync" "testing" + "github.com/github/copilot-sdk/go/internal/jsonrpc2" "github.com/github/copilot-sdk/go/internal/truncbuffer" "github.com/github/copilot-sdk/go/rpc" ) @@ -986,6 +989,291 @@ func TestClient_StartStopRace(t *testing.T) { } } +func TestClient_MCPAuthInterestRegistration(t *testing.T) { + t.Run("create skips MCP OAuth interest without auth handler", func(t *testing.T) { + client, requests, cleanup := newInMemoryClient(t) + defer cleanup() + + session, err := client.CreateSession(t.Context(), &SessionConfig{ + OnPermissionRequest: PermissionHandler.ApproveAll, + OnEvent: func(SessionEvent) {}, + }) + if err != nil { + t.Fatalf("CreateSession failed: %v", err) + } + defer session.Close() + + assertNoMCPAuthInterest(t, requests.snapshot()) + assertRequestMethod(t, requests.snapshot(), "session.create") + assertCreateRequestPermission(t, requests.snapshot()) + }) + + t.Run("create registers MCP OAuth interest before local session create when auth handler is configured", func(t *testing.T) { + client, requests, cleanup := newInMemoryClient(t) + defer cleanup() + + session, err := client.CreateSession(t.Context(), &SessionConfig{ + OnPermissionRequest: PermissionHandler.ApproveAll, + OnMCPAuthRequest: func(MCPAuthRequest, MCPAuthInvocation) (*MCPAuthResult, error) { + return &MCPAuthResult{Kind: "cancelled"}, nil + }, + }) + if err != nil { + t.Fatalf("CreateSession failed: %v", err) + } + defer session.Close() + + snapshot := requests.snapshot() + assertRequestMethod(t, snapshot, "session.eventLog.registerInterest") + if snapshot[0].Method != "session.eventLog.registerInterest" { + t.Fatalf("expected MCP auth interest before session.create, got %s", snapshot[0].Method) + } + if snapshot[1].Method != "session.create" { + t.Fatalf("expected session.create after MCP auth interest, got %s", snapshot[1].Method) + } + assertMCPAuthInterest(t, snapshot[0]) + assertCreateRequestPermission(t, snapshot) + }) + + t.Run("cloud create registers MCP OAuth interest after server assigns id only when auth handler is configured", func(t *testing.T) { + client, requests, cleanup := newInMemoryClient(t) + defer cleanup() + + withoutAuth, err := client.CreateSession(t.Context(), &SessionConfig{ + OnPermissionRequest: PermissionHandler.ApproveAll, + Cloud: &CloudSessionOptions{ + Repository: &CloudSessionRepository{Owner: "github", Name: "copilot-sdk", Branch: "main"}, + }, + }) + if err != nil { + t.Fatalf("CreateSession without auth failed: %v", err) + } + defer withoutAuth.Close() + + assertNoMCPAuthInterest(t, requests.snapshot()) + requests.clear() + + withAuth, err := client.CreateSession(t.Context(), &SessionConfig{ + OnPermissionRequest: PermissionHandler.ApproveAll, + OnMCPAuthRequest: func(MCPAuthRequest, MCPAuthInvocation) (*MCPAuthResult, error) { + return &MCPAuthResult{Kind: "cancelled"}, nil + }, + Cloud: &CloudSessionOptions{ + Repository: &CloudSessionRepository{Owner: "github", Name: "copilot-sdk", Branch: "main"}, + }, + }) + if err != nil { + t.Fatalf("CreateSession with auth failed: %v", err) + } + defer withAuth.Close() + + snapshot := requests.snapshot() + if snapshot[0].Method != "session.create" { + t.Fatalf("expected cloud session.create before MCP auth interest, got %s", snapshot[0].Method) + } + if snapshot[1].Method != "session.eventLog.registerInterest" { + t.Fatalf("expected MCP auth interest after cloud session.create, got %s", snapshot[1].Method) + } + assertMCPAuthInterest(t, snapshot[1]) + }) + + t.Run("resume conditionally registers MCP OAuth interest before session resume", func(t *testing.T) { + client, requests, cleanup := newInMemoryClient(t) + defer cleanup() + + withoutAuth, err := client.ResumeSession(t.Context(), "session-without-auth", &ResumeSessionConfig{ + OnPermissionRequest: PermissionHandler.ApproveAll, + OnEvent: func(SessionEvent) {}, + }) + if err != nil { + t.Fatalf("ResumeSession without auth failed: %v", err) + } + defer withoutAuth.Close() + + assertNoMCPAuthInterest(t, requests.snapshot()) + assertRequestMethod(t, requests.snapshot(), "session.resume") + requests.clear() + + withAuth, err := client.ResumeSession(t.Context(), "session-with-auth", &ResumeSessionConfig{ + OnPermissionRequest: PermissionHandler.ApproveAll, + OnMCPAuthRequest: func(MCPAuthRequest, MCPAuthInvocation) (*MCPAuthResult, error) { + return &MCPAuthResult{Kind: "cancelled"}, nil + }, + }) + if err != nil { + t.Fatalf("ResumeSession with auth failed: %v", err) + } + defer withAuth.Close() + + snapshot := requests.snapshot() + if snapshot[0].Method != "session.eventLog.registerInterest" { + t.Fatalf("expected MCP auth interest before session.resume, got %s", snapshot[0].Method) + } + if snapshot[1].Method != "session.resume" { + t.Fatalf("expected session.resume after MCP auth interest, got %s", snapshot[1].Method) + } + assertMCPAuthInterest(t, snapshot[0]) + }) +} + +type recordedRequest struct { + Method string + Params map[string]any +} + +type requestRecorder struct { + mu sync.Mutex + requests []recordedRequest +} + +func (r *requestRecorder) append(request recordedRequest) { + r.mu.Lock() + defer r.mu.Unlock() + r.requests = append(r.requests, request) +} + +func (r *requestRecorder) snapshot() []recordedRequest { + r.mu.Lock() + defer r.mu.Unlock() + out := make([]recordedRequest, len(r.requests)) + copy(out, r.requests) + return out +} + +func (r *requestRecorder) clear() { + r.mu.Lock() + defer r.mu.Unlock() + r.requests = nil +} + +func newInMemoryClient(t *testing.T) (*Client, *requestRecorder, func()) { + t.Helper() + + stdinR, stdinW := io.Pipe() + stdoutR, stdoutW := io.Pipe() + rpcClient := jsonrpc2.NewClient(stdinW, stdoutR) + rpcClient.Start() + + client := NewClient(&ClientOptions{}) + client.client = rpcClient + client.RPC = rpc.NewServerRPC(rpcClient) + client.state = stateConnected + + requests := &requestRecorder{} + done := make(chan struct{}) + go serveInMemoryRuntime(t, stdinR, stdoutW, requests, done) + + cleanup := func() { + rpcClient.Stop() + stdinR.Close() + stdinW.Close() + stdoutR.Close() + stdoutW.Close() + <-done + } + return client, requests, cleanup +} + +func serveInMemoryRuntime(t *testing.T, stdinR *io.PipeReader, stdoutW *io.PipeWriter, requests *requestRecorder, done chan<- struct{}) { + t.Helper() + defer close(done) + + serverAssignedSessions := 0 + for { + frame, err := readTestJSONRPCFrame(stdinR) + if err != nil { + return + } + + var request struct { + ID json.RawMessage `json:"id"` + Method string `json:"method"` + Params map[string]any `json:"params"` + } + if err := json.Unmarshal(frame, &request); err != nil { + t.Errorf("failed to unmarshal JSON-RPC request: %v", err) + return + } + requests.append(recordedRequest{Method: request.Method, Params: request.Params}) + + result := map[string]any{} + switch request.Method { + case "session.create", "session.resume": + sessionID, _ := request.Params["sessionId"].(string) + if sessionID == "" { + serverAssignedSessions++ + sessionID = fmt.Sprintf("server-assigned-session-%d", serverAssignedSessions) + } + result = map[string]any{"sessionId": sessionID, "workspacePath": nil} + case "session.eventLog.registerInterest": + result = map[string]any{"id": "interest-1"} + case "session.options.update": + result = map[string]any{"success": true} + case "session.skills.reload", "session.destroy": + result = map[string]any{} + default: + t.Errorf("unexpected JSON-RPC method %s", request.Method) + return + } + + response := map[string]any{ + "jsonrpc": "2.0", + "id": json.RawMessage(request.ID), + "result": result, + } + data, err := json.Marshal(response) + if err != nil { + t.Errorf("failed to marshal JSON-RPC response: %v", err) + return + } + if _, err := fmt.Fprintf(stdoutW, "Content-Length: %d\r\n\r\n%s", len(data), data); err != nil { + return + } + } +} + +func assertRequestMethod(t *testing.T, requests []recordedRequest, method string) { + t.Helper() + for _, request := range requests { + if request.Method == method { + return + } + } + t.Fatalf("expected %s request in %+v", method, requests) +} + +func assertNoMCPAuthInterest(t *testing.T, requests []recordedRequest) { + t.Helper() + for _, request := range requests { + if request.Method == "session.eventLog.registerInterest" && request.Params["eventType"] == "mcp.oauth_required" { + t.Fatalf("did not expect MCP auth interest registration in %+v", requests) + } + } +} + +func assertMCPAuthInterest(t *testing.T, request recordedRequest) { + t.Helper() + if request.Method != "session.eventLog.registerInterest" { + t.Fatalf("expected registerInterest request, got %s", request.Method) + } + if request.Params["eventType"] != "mcp.oauth_required" { + t.Fatalf("expected mcp.oauth_required interest, got %v", request.Params["eventType"]) + } +} + +func assertCreateRequestPermission(t *testing.T, requests []recordedRequest) { + t.Helper() + for _, request := range requests { + if request.Method == "session.create" { + if request.Params["requestPermission"] != true { + t.Fatalf("expected create requestPermission=true, got %v", request.Params["requestPermission"]) + } + return + } + } + t.Fatalf("session.create request not found in %+v", requests) +} + func TestCreateSessionRequest_Commands(t *testing.T) { t.Run("forwards commands in session.create RPC", func(t *testing.T) { req := createSessionRequest{ diff --git a/go/rpc/zrpc.go b/go/rpc/zrpc.go index b1d9df103..034d05a7a 100644 --- a/go/rpc/zrpc.go +++ b/go/rpc/zrpc.go @@ -2426,6 +2426,73 @@ type MCPOauthLoginResult struct { AuthorizationURL *string `json:"authorizationUrl,omitempty"` } +// Pending MCP OAuth request ID and host-provided token or cancellation response. +// Experimental: MCPOauthHandlePendingRequest is part of an experimental API and may change +// or be removed. +type MCPOauthHandlePendingRequest struct { + // OAuth request identifier for the pending request. + RequestID string `json:"requestId"` + // Host response to the pending OAuth request. + Result MCPOauthPendingRequestResponse `json:"result"` +} + +// Indicates whether the pending MCP OAuth response was accepted. +// Experimental: MCPOauthHandlePendingResult is part of an experimental API and may change +// or be removed. +type MCPOauthHandlePendingResult struct { + // Whether the response was accepted. False if the request was unknown, timed out, or + // already resolved. + Success bool `json:"success"` +} + +// Host response to the pending OAuth request. +// Experimental: MCPOauthPendingRequestResponse is part of an experimental API and may +// change or be removed. +type MCPOauthPendingRequestResponse interface { + mcpOauthPendingRequestResponse() + Kind() MCPOauthPendingRequestResponseKind +} + +type RawMCPOauthPendingRequestResponseData struct { + Discriminator MCPOauthPendingRequestResponseKind + Raw json.RawMessage +} + +func (RawMCPOauthPendingRequestResponseData) mcpOauthPendingRequestResponse() {} +func (r RawMCPOauthPendingRequestResponseData) Kind() MCPOauthPendingRequestResponseKind { + return r.Discriminator +} + +// Schema for the `McpOauthPendingRequestResponseCancelled` type. +// Experimental: MCPOauthPendingRequestResponseCancelled is part of an experimental API and +// may change or be removed. +type MCPOauthPendingRequestResponseCancelled struct { +} + +func (MCPOauthPendingRequestResponseCancelled) mcpOauthPendingRequestResponse() {} +func (MCPOauthPendingRequestResponseCancelled) Kind() MCPOauthPendingRequestResponseKind { + return MCPOauthPendingRequestResponseKindCancelled +} + +// Schema for the `McpOauthPendingRequestResponseToken` type. +// Experimental: MCPOauthPendingRequestResponseToken is part of an experimental API and may +// change or be removed. +type MCPOauthPendingRequestResponseToken struct { + // Access token acquired by the SDK host + AccessToken string `json:"accessToken"` + // Token lifetime in seconds, if known. + ExpiresIn *int64 `json:"expiresIn,omitempty"` + // Refresh token supplied by the host, if available. + RefreshToken *string `json:"refreshToken,omitempty"` + // OAuth token type. Defaults to Bearer when omitted. + TokenType *string `json:"tokenType,omitempty"` +} + +func (MCPOauthPendingRequestResponseToken) mcpOauthPendingRequestResponse() {} +func (MCPOauthPendingRequestResponseToken) Kind() MCPOauthPendingRequestResponseKind { + return MCPOauthPendingRequestResponseKindToken +} + // MCP OAuth request id and optional provider response. // Experimental: MCPOauthRespondRequest is part of an experimental API and may change or be // removed. @@ -2445,6 +2512,18 @@ type MCPOauthRespondRequest struct { type MCPOauthRespondResult struct { } +// Allowed values for the `McpOauthPendingRequestResponse` discriminator. +// Experimental: MCPOauthPendingRequestResponseKind is part of an experimental API and may +// change or be removed. +type MCPOauthPendingRequestResponseKind string + +const ( + // Schema for the `McpOauthPendingRequestResponseCancelled` type. + MCPOauthPendingRequestResponseKindCancelled MCPOauthPendingRequestResponseKind = "cancelled" + // Schema for the `McpOauthPendingRequestResponseToken` type. + MCPOauthPendingRequestResponseKindToken MCPOauthPendingRequestResponseKind = "token" +) + // Registration parameters for an external MCP client. // Experimental: MCPRegisterExternalClientRequest is part of an experimental API and may // change or be removed. @@ -12724,6 +12803,31 @@ func (s *MCPAPI) Apps() *MCPAppsAPI { // Experimental: MCPOauthAPI contains experimental APIs that may change or be removed. type MCPOauthAPI sessionAPI +// HandlePendingRequest resolves a pending MCP OAuth request with a host-provided token or +// cancellation. +// +// RPC method: session.mcp.oauth.handlePendingRequest. +// +// Parameters: Pending MCP OAuth request ID and host-provided token or cancellation response. +// +// Returns: Indicates whether the pending MCP OAuth response was accepted. +func (a *MCPOauthAPI) HandlePendingRequest(ctx context.Context, params *MCPOauthHandlePendingRequest) (*MCPOauthHandlePendingResult, error) { + req := map[string]any{"sessionId": a.sessionID} + if params != nil { + req["requestId"] = params.RequestID + req["result"] = params.Result + } + raw, err := a.client.Request("session.mcp.oauth.handlePendingRequest", req) + if err != nil { + return nil, err + } + var result MCPOauthHandlePendingResult + if err := json.Unmarshal(raw, &result); err != nil { + return nil, err + } + return &result, nil +} + // Login starts OAuth authentication for a remote MCP server. // // RPC method: session.mcp.oauth.login. diff --git a/go/rpc/zrpc_encoding.go b/go/rpc/zrpc_encoding.go index bf455dc7e..ff65ac048 100644 --- a/go/rpc/zrpc_encoding.go +++ b/go/rpc/zrpc_encoding.go @@ -1076,6 +1076,89 @@ func (r *MCPConfigUpdateRequest) UnmarshalJSON(data []byte) error { return nil } +func unmarshalMCPOauthPendingRequestResponse(data []byte) (MCPOauthPendingRequestResponse, error) { + if string(data) == "null" { + return nil, nil + } + type rawUnion struct { + Kind MCPOauthPendingRequestResponseKind `json:"kind"` + } + var raw rawUnion + if err := json.Unmarshal(data, &raw); err != nil { + return nil, err + } + + switch raw.Kind { + case MCPOauthPendingRequestResponseKindCancelled: + var d MCPOauthPendingRequestResponseCancelled + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case MCPOauthPendingRequestResponseKindToken: + var d MCPOauthPendingRequestResponseToken + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + default: + return &RawMCPOauthPendingRequestResponseData{Discriminator: raw.Kind, Raw: data}, nil + } +} + +func (r RawMCPOauthPendingRequestResponseData) MarshalJSON() ([]byte, error) { + if r.Raw != nil { + return r.Raw, nil + } + return json.Marshal(struct { + Kind MCPOauthPendingRequestResponseKind `json:"kind"` + }{ + Kind: r.Discriminator, + }) +} + +func (r MCPOauthPendingRequestResponseCancelled) MarshalJSON() ([]byte, error) { + type alias MCPOauthPendingRequestResponseCancelled + return json.Marshal(struct { + Kind MCPOauthPendingRequestResponseKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r MCPOauthPendingRequestResponseToken) MarshalJSON() ([]byte, error) { + type alias MCPOauthPendingRequestResponseToken + return json.Marshal(struct { + Kind MCPOauthPendingRequestResponseKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r *MCPOauthHandlePendingRequest) UnmarshalJSON(data []byte) error { + type rawMCPOauthHandlePendingRequest struct { + RequestID string `json:"requestId"` + Result json.RawMessage `json:"result"` + } + var raw rawMCPOauthHandlePendingRequest + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + r.RequestID = raw.RequestID + if raw.Result != nil { + value, err := unmarshalMCPOauthPendingRequestResponse(raw.Result) + if err != nil { + return err + } + r.Result = value + } + return nil +} + func unmarshalPermissionDecision(data []byte) (PermissionDecision, error) { if string(data) == "null" { return nil, nil diff --git a/go/rpc/zsession_events.go b/go/rpc/zsession_events.go index 98d008256..535d6b3e3 100644 --- a/go/rpc/zsession_events.go +++ b/go/rpc/zsession_events.go @@ -645,6 +645,8 @@ func (*MCPAppToolCallCompleteData) Type() SessionEventType { // MCP OAuth request completion notification type MCPOauthCompletedData struct { + // How the pending OAuth request was completed + Outcome MCPOauthCompletedOutcome `json:"outcome"` // Request ID of the resolved OAuth request RequestID string `json:"requestId"` } @@ -688,7 +690,7 @@ func (*SessionRemoteSteerableChangedData) Type() SessionEventType { // OAuth authentication request for an MCP server type MCPOauthRequiredData struct { - // Unique identifier for this OAuth request; used to respond via session.respondToMcpOAuth() + // Unique identifier for this OAuth request; used to respond via session.mcp.oauth.handlePendingRequest RequestID string `json:"requestId"` // Display name of the MCP server that requires OAuth ServerName string `json:"serverName"` @@ -696,6 +698,8 @@ type MCPOauthRequiredData struct { ServerURL string `json:"serverUrl"` // Static OAuth client configuration, if the server specifies one StaticClientConfig *MCPOauthRequiredStaticClientConfig `json:"staticClientConfig,omitempty"` + // Parsed parameters from the WWW-Authenticate header that the SDK host uses for RFC 9728 protected-resource metadata discovery. + WwwAuthenticateParams MCPOauthRequiredWwwAuthenticateParams `json:"wwwAuthenticateParams"` } func (*MCPOauthRequiredData) sessionEventData() {} @@ -1905,6 +1909,16 @@ type MCPOauthRequiredStaticClientConfig struct { PublicClient *bool `json:"publicClient,omitempty"` } +// Parsed parameters from the WWW-Authenticate header that the SDK host uses for RFC 9728 protected-resource metadata discovery. +type MCPOauthRequiredWwwAuthenticateParams struct { + // Parsed OAuth error from the WWW-Authenticate header, if present + Error *string `json:"error,omitempty"` + // Parsed resource_metadata URL from the WWW-Authenticate header + ResourceMetadataURL string `json:"resourceMetadataUrl"` + // Parsed OAuth scope from the WWW-Authenticate header, if present + Scope *string `json:"scope,omitempty"` +} + // Schema for the `McpServersLoadedServer` type. type MCPServersLoadedServer struct { // Error message if the server failed to connect @@ -3061,6 +3075,18 @@ const ( HandoffSourceTypeRemote HandoffSourceType = "remote" ) +// How the pending OAuth request was completed +type MCPOauthCompletedOutcome string + +const ( + // The pending OAuth request was cancelled or declined without a token/provider. + MCPOauthCompletedOutcomeCancelled MCPOauthCompletedOutcome = "cancelled" + // The pending OAuth request timed out before any client responded. + MCPOauthCompletedOutcomeTimeout MCPOauthCompletedOutcome = "timeout" + // The pending OAuth request was resolved with a host-provided token/provider. + MCPOauthCompletedOutcomeToken MCPOauthCompletedOutcome = "token" +) + // Optional non-default OAuth grant type. When set to 'client_credentials', the OAuth flow runs headlessly using the client_id + keychain-stored secret (no browser, no callback server). type MCPOauthRequiredStaticClientConfigGrantType string diff --git a/go/session.go b/go/session.go index ca67cb2c8..265d18343 100644 --- a/go/session.go +++ b/go/session.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "log" "sync" "time" @@ -61,6 +62,8 @@ type Session struct { toolHandlersM sync.RWMutex permissionHandler PermissionHandlerFunc permissionMux sync.RWMutex + mcpAuthHandler MCPAuthHandler + mcpAuthMu sync.RWMutex userInputHandler UserInputHandler userInputMux sync.RWMutex exitPlanModeHandler ExitPlanModeRequestHandler @@ -863,6 +866,46 @@ func (s *Session) getElicitationHandler() ElicitationHandler { return s.elicitationHandler } +func (s *Session) registerMCPAuthHandler(handler MCPAuthHandler) { + s.mcpAuthMu.Lock() + defer s.mcpAuthMu.Unlock() + s.mcpAuthHandler = handler +} + +func (s *Session) getMCPAuthHandler() MCPAuthHandler { + s.mcpAuthMu.RLock() + defer s.mcpAuthMu.RUnlock() + return s.mcpAuthHandler +} + +func (s *Session) handleMCPAuthRequest(request MCPAuthRequest) { + handler := s.getMCPAuthHandler() + if handler == nil { + return + } + + ctx := context.Background() + cancel := &rpc.MCPOauthPendingRequestResponseCancelled{} + result, err := handler(request, MCPAuthInvocation{SessionID: s.SessionID}) + if err != nil || result == nil || result.Kind == "cancelled" || result.Token == nil { + s.RPC.MCP.Oauth().HandlePendingRequest(ctx, &rpc.MCPOauthHandlePendingRequest{ + RequestID: request.RequestID, + Result: cancel, + }) + return + } + + s.RPC.MCP.Oauth().HandlePendingRequest(ctx, &rpc.MCPOauthHandlePendingRequest{ + RequestID: request.RequestID, + Result: &rpc.MCPOauthPendingRequestResponseToken{ + AccessToken: result.Token.AccessToken, + TokenType: result.Token.TokenType, + RefreshToken: result.Token.RefreshToken, + ExpiresIn: result.Token.ExpiresIn, + }, + }) +} + // handleElicitationRequest dispatches an elicitation.requested event to the registered handler // and sends the result back via the RPC layer. Auto-cancels on error. func (s *Session) handleElicitationRequest(elicitCtx ElicitationContext, requestID string) { @@ -1309,6 +1352,50 @@ func (s *Session) handleBroadcastEvent(event SessionEvent) { } s.executePermissionAndRespond(d.RequestID, d.PermissionRequest, handler) + case *MCPOauthRequiredData: + handler := s.getMCPAuthHandler() + if d.RequestID == "" { + return + } + if handler == nil { + log.Printf( + "Received MCP OAuth request without a registered MCP auth handler. SessionId=%s, RequestId=%s", + s.SessionID, + d.RequestID, + ) + return + } + var staticClientConfig *MCPAuthStaticClientConfig + if d.StaticClientConfig != nil { + var grantType string + if d.StaticClientConfig.GrantType != nil { + grantType = string(*d.StaticClientConfig.GrantType) + } + staticClientConfig = &MCPAuthStaticClientConfig{ + ClientID: d.StaticClientConfig.ClientID, + GrantType: grantType, + PublicClient: d.StaticClientConfig.PublicClient, + } + } + var scope, oauthError string + if d.WwwAuthenticateParams.Scope != nil { + scope = *d.WwwAuthenticateParams.Scope + } + if d.WwwAuthenticateParams.Error != nil { + oauthError = *d.WwwAuthenticateParams.Error + } + s.handleMCPAuthRequest(MCPAuthRequest{ + RequestID: d.RequestID, + ServerName: d.ServerName, + ServerURL: d.ServerURL, + WwwAuthenticateParams: MCPAuthWwwAuthenticateParams{ + ResourceMetadataURL: d.WwwAuthenticateParams.ResourceMetadataURL, + Scope: scope, + Error: oauthError, + }, + StaticClientConfig: staticClientConfig, + }) + case *CommandExecuteData: s.executeCommandAndRespond(d.RequestID, d.CommandName, d.Command, d.Args) diff --git a/go/session_test.go b/go/session_test.go index 15cfbcf57..7cc9caece 100644 --- a/go/session_test.go +++ b/go/session_test.go @@ -60,6 +60,110 @@ func TestSession_SetModelOmitsContextTierWhenUnset(t *testing.T) { } } +func TestSession_MCPAuthRequestSendsHostToken(t *testing.T) { + stdinR, stdinW := io.Pipe() + stdoutR, stdoutW := io.Pipe() + defer stdinR.Close() + defer stdinW.Close() + defer stdoutR.Close() + defer stdoutW.Close() + + client := jsonrpc2.NewClient(stdinW, stdoutR) + client.Start() + defer client.Stop() + + paramsCh := make(chan map[string]any, 1) + errCh := make(chan error, 1) + + go func() { + frame, err := readTestJSONRPCFrame(stdinR) + if err != nil { + errCh <- err + return + } + + var request struct { + ID json.RawMessage `json:"id"` + Method string `json:"method"` + Params map[string]any `json:"params"` + } + if err := json.Unmarshal(frame, &request); err != nil { + errCh <- err + return + } + if request.Method != "session.mcp.oauth.handlePendingRequest" { + errCh <- fmt.Errorf("expected session.mcp.oauth.handlePendingRequest, got %s", request.Method) + return + } + + paramsCh <- request.Params + + response := map[string]any{ + "jsonrpc": "2.0", + "id": json.RawMessage(request.ID), + "result": map[string]any{"success": true}, + } + data, err := json.Marshal(response) + if err != nil { + errCh <- err + return + } + if _, err := fmt.Fprintf(stdoutW, "Content-Length: %d\r\n\r\n%s", len(data), data); err != nil { + errCh <- err + } + }() + + session := &Session{ + SessionID: "session-1", + client: client, + RPC: rpc.NewSessionRPC(client, "session-1"), + } + session.registerMCPAuthHandler(func(request MCPAuthRequest, invocation MCPAuthInvocation) (*MCPAuthResult, error) { + if invocation.SessionID != "session-1" { + t.Fatalf("expected invocation session-1, got %s", invocation.SessionID) + } + if request.RequestID != "oauth-request" { + t.Fatalf("expected oauth-request, got %s", request.RequestID) + } + tokenType := "Bearer" + return &MCPAuthResult{ + Kind: "token", + Token: &MCPAuthToken{ + AccessToken: "host-token", + TokenType: &tokenType, + }, + }, nil + }) + session.handleMCPAuthRequest(MCPAuthRequest{RequestID: "oauth-request"}) + + select { + case params := <-paramsCh: + if params["sessionId"] != "session-1" { + t.Fatalf("expected sessionId session-1, got %v", params["sessionId"]) + } + if params["requestId"] != "oauth-request" { + t.Fatalf("expected requestId oauth-request, got %v", params["requestId"]) + } + result, ok := params["result"].(map[string]any) + if !ok { + t.Fatalf("expected result object, got %T", params["result"]) + } + if result["kind"] != "token" { + t.Fatalf("expected token kind, got %v", result["kind"]) + } + if result["accessToken"] != "host-token" { + t.Fatalf("expected accessToken host-token, got %v", result["accessToken"]) + } + if result["tokenType"] != "Bearer" { + t.Fatalf("expected tokenType Bearer, got %v", result["tokenType"]) + } + case err := <-errCh: + t.Fatal(err) + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for MCP OAuth request") + } +} + func captureSetModelRequest(t *testing.T, opts *SetModelOptions) map[string]any { t.Helper() diff --git a/go/types.go b/go/types.go index 7ffd454a3..d9594fd49 100644 --- a/go/types.go +++ b/go/types.go @@ -309,6 +309,51 @@ type PermissionInvocation struct { SessionID string } +// MCPAuthWwwAuthenticateParams contains parsed parameters from an MCP server's WWW-Authenticate response. +type MCPAuthWwwAuthenticateParams struct { + ResourceMetadataURL string `json:"resourceMetadataUrl"` + Scope string `json:"scope,omitempty"` + Error string `json:"error,omitempty"` +} + +// MCPAuthStaticClientConfig is static OAuth client configuration supplied by an MCP server. +type MCPAuthStaticClientConfig struct { + ClientID string `json:"clientId"` + GrantType string `json:"grantType,omitempty"` + PublicClient *bool `json:"publicClient,omitempty"` +} + +// MCPAuthRequest describes an MCP OAuth request that the SDK host can satisfy with a token. +type MCPAuthRequest struct { + RequestID string `json:"requestId"` + ServerName string `json:"serverName"` + ServerURL string `json:"serverUrl"` + WwwAuthenticateParams MCPAuthWwwAuthenticateParams `json:"wwwAuthenticateParams"` + StaticClientConfig *MCPAuthStaticClientConfig `json:"staticClientConfig,omitempty"` +} + +// MCPAuthToken is host-provided OAuth token data for a pending MCP OAuth request. +type MCPAuthToken struct { + AccessToken string `json:"accessToken"` + TokenType *string `json:"tokenType,omitempty"` + RefreshToken *string `json:"refreshToken,omitempty"` + ExpiresIn *int64 `json:"expiresIn,omitempty"` +} + +// MCPAuthResult is the result returned by an MCP auth request handler. +type MCPAuthResult struct { + Kind string + Token *MCPAuthToken +} + +// MCPAuthInvocation provides context about an MCP auth handler invocation. +type MCPAuthInvocation struct { + SessionID string +} + +// MCPAuthHandler handles MCP OAuth requests from the runtime. +type MCPAuthHandler func(request MCPAuthRequest, invocation MCPAuthInvocation) (*MCPAuthResult, error) + // UserInputRequest represents a request for user input from the agent type UserInputRequest struct { Question string @@ -952,6 +997,10 @@ type SessionConfig struct { // When nil, permission requests are surfaced as events and left pending for the // consumer to resolve via pending permission RPCs. OnPermissionRequest PermissionHandlerFunc + // OnMCPAuthRequest is an optional handler for MCP OAuth requests from MCP servers. + // When provided, the SDK can satisfy MCP server OAuth requests with host-provided + // token data or cancellation. + OnMCPAuthRequest MCPAuthHandler // OnUserInputRequest is a handler for user input requests from the agent (enables ask_user tool) OnUserInputRequest UserInputHandler // Hooks configures hook handlers for session lifecycle events @@ -1324,6 +1373,9 @@ type ResumeSessionConfig struct { // When nil, permission requests are surfaced as events and left pending for the // consumer to resolve via pending permission RPCs. OnPermissionRequest PermissionHandlerFunc + // OnMCPAuthRequest is an optional handler for MCP OAuth requests from MCP servers. + // See SessionConfig.OnMCPAuthRequest. + OnMCPAuthRequest MCPAuthHandler // OnUserInputRequest is a handler for user input requests from the agent (enables ask_user tool) OnUserInputRequest UserInputHandler // Hooks configures hook handlers for session lifecycle events diff --git a/go/zsession_events.go b/go/zsession_events.go index 5a8962efd..ce69dae02 100644 --- a/go/zsession_events.go +++ b/go/zsession_events.go @@ -88,10 +88,12 @@ type ( MCPAppToolCallCompleteToolMeta = rpc.MCPAppToolCallCompleteToolMeta MCPAppToolCallCompleteToolMetaUI = rpc.MCPAppToolCallCompleteToolMetaUI MCPOauthCompletedData = rpc.MCPOauthCompletedData - MCPOauthRequiredData = rpc.MCPOauthRequiredData - MCPOauthRequiredStaticClientConfig = rpc.MCPOauthRequiredStaticClientConfig - MCPOauthRequiredStaticClientConfigGrantType = rpc.MCPOauthRequiredStaticClientConfigGrantType - MCPServersLoadedServer = rpc.MCPServersLoadedServer + MCPOauthCompletedOutcome = rpc.MCPOauthCompletedOutcome + MCPOauthRequiredData = rpc.MCPOauthRequiredData + MCPOauthRequiredStaticClientConfig = rpc.MCPOauthRequiredStaticClientConfig + MCPOauthRequiredStaticClientConfigGrantType = rpc.MCPOauthRequiredStaticClientConfigGrantType + MCPOauthRequiredWwwAuthenticateParams = rpc.MCPOauthRequiredWwwAuthenticateParams + MCPServersLoadedServer = rpc.MCPServersLoadedServer MCPServerSource = rpc.MCPServerSource MCPServerStatus = rpc.MCPServerStatus MCPServerTransport = rpc.MCPServerTransport @@ -326,7 +328,10 @@ const ( ExtensionsLoadedExtensionStatusStarting = rpc.ExtensionsLoadedExtensionStatusStarting HandoffSourceTypeLocal = rpc.HandoffSourceTypeLocal HandoffSourceTypeRemote = rpc.HandoffSourceTypeRemote - MCPOauthRequiredStaticClientConfigGrantTypeClientCredentials = rpc.MCPOauthRequiredStaticClientConfigGrantTypeClientCredentials + MCPOauthCompletedOutcomeCancelled = rpc.MCPOauthCompletedOutcomeCancelled + MCPOauthCompletedOutcomeTimeout = rpc.MCPOauthCompletedOutcomeTimeout + MCPOauthCompletedOutcomeToken = rpc.MCPOauthCompletedOutcomeToken + MCPOauthRequiredStaticClientConfigGrantTypeClientCredentials = rpc.MCPOauthRequiredStaticClientConfigGrantTypeClientCredentials MCPServerSourceBuiltin = rpc.MCPServerSourceBuiltin MCPServerSourcePlugin = rpc.MCPServerSourcePlugin MCPServerSourceUser = rpc.MCPServerSourceUser diff --git a/java/scripts/codegen/java.ts b/java/scripts/codegen/java.ts index 842ba772e..f9aa54243 100644 --- a/java/scripts/codegen/java.ts +++ b/java/scripts/codegen/java.ts @@ -1283,7 +1283,7 @@ function generateRpcClass( return { code: lines.join("\n"), imports }; } -async function generateRpcTypes(schemaPath: string): Promise { +async function generateRpcTypes(schemaPath: string, sessionEventsSchemaPath: string): Promise { console.log("\n🔌 Generating RPC types..."); const schemaContent = await fs.readFile(schemaPath, "utf-8"); const schema = normalizeSchemaBrandCasing(JSON.parse(schemaContent)) as Record & { @@ -1301,7 +1301,6 @@ async function generateRpcTypes(schemaPath: string): Promise { // Load cross-schema definitions (session-events) so that cross-schema $ref values // like "session-events.schema.json#/definitions/Foo" can be resolved. try { - const sessionEventsSchemaPath = await getSessionEventsSchemaPath(); const sessionEventsContent = await fs.readFile(sessionEventsSchemaPath, "utf-8"); const sessionEventsSchema = normalizeSchemaBrandCasing(JSON.parse(sessionEventsContent) as JSONSchema7); crossSchemaDefinitions.set("session-events.schema.json", @@ -2065,13 +2064,13 @@ async function main(): Promise { await fs.rm(generatedOutputDir, { recursive: true, force: true }); await fs.mkdir(generatedOutputDir, { recursive: true }); - const sessionEventsSchemaPath = await getSessionEventsSchemaPath(); + const sessionEventsSchemaPath = process.argv[2] ?? await getSessionEventsSchemaPath(); console.log(`📄 Session events schema: ${sessionEventsSchemaPath}`); - const apiSchemaPath = await getApiSchemaPath(); + const apiSchemaPath = process.argv[3] ?? await getApiSchemaPath(); console.log(`📄 API schema: ${apiSchemaPath}`); await generateSessionEvents(sessionEventsSchemaPath); - await generateRpcTypes(apiSchemaPath); + await generateRpcTypes(apiSchemaPath, sessionEventsSchemaPath); await generateRpcWrappers(apiSchemaPath); console.log("\n✅ Java code generation complete!"); diff --git a/java/src/generated/java/com/github/copilot/generated/McpOauthCompletedEvent.java b/java/src/generated/java/com/github/copilot/generated/McpOauthCompletedEvent.java index 635751b43..8bfa56849 100644 --- a/java/src/generated/java/com/github/copilot/generated/McpOauthCompletedEvent.java +++ b/java/src/generated/java/com/github/copilot/generated/McpOauthCompletedEvent.java @@ -35,7 +35,9 @@ public final class McpOauthCompletedEvent extends SessionEvent { @JsonInclude(JsonInclude.Include.NON_NULL) public record McpOauthCompletedEventData( /** Request ID of the resolved OAuth request */ - @JsonProperty("requestId") String requestId + @JsonProperty("requestId") String requestId, + /** How the pending OAuth request was completed */ + @JsonProperty("outcome") McpOauthCompletedOutcome outcome ) { } } diff --git a/java/src/generated/java/com/github/copilot/generated/McpOauthCompletedOutcome.java b/java/src/generated/java/com/github/copilot/generated/McpOauthCompletedOutcome.java new file mode 100644 index 000000000..e362f43bd --- /dev/null +++ b/java/src/generated/java/com/github/copilot/generated/McpOauthCompletedOutcome.java @@ -0,0 +1,37 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +// AUTO-GENERATED FILE - DO NOT EDIT +// Generated from: session-events.schema.json + +package com.github.copilot.generated; + +import javax.annotation.processing.Generated; + +/** + * How the pending OAuth request was completed + * + * @since 1.0.0 + */ +@javax.annotation.processing.Generated("copilot-sdk-codegen") +public enum McpOauthCompletedOutcome { + /** The {@code token} variant. */ + TOKEN("token"), + /** The {@code cancelled} variant. */ + CANCELLED("cancelled"), + /** The {@code timeout} variant. */ + TIMEOUT("timeout"); + + private final String value; + McpOauthCompletedOutcome(String value) { this.value = value; } + @com.fasterxml.jackson.annotation.JsonValue + public String getValue() { return value; } + @com.fasterxml.jackson.annotation.JsonCreator + public static McpOauthCompletedOutcome fromValue(String value) { + for (McpOauthCompletedOutcome v : values()) { + if (v.value.equals(value)) return v; + } + throw new IllegalArgumentException("Unknown McpOauthCompletedOutcome value: " + value); + } +} diff --git a/java/src/generated/java/com/github/copilot/generated/McpOauthRequiredEvent.java b/java/src/generated/java/com/github/copilot/generated/McpOauthRequiredEvent.java index 02e67a35f..4ebaf351a 100644 --- a/java/src/generated/java/com/github/copilot/generated/McpOauthRequiredEvent.java +++ b/java/src/generated/java/com/github/copilot/generated/McpOauthRequiredEvent.java @@ -34,14 +34,16 @@ public final class McpOauthRequiredEvent extends SessionEvent { @JsonIgnoreProperties(ignoreUnknown = true) @JsonInclude(JsonInclude.Include.NON_NULL) public record McpOauthRequiredEventData( - /** Unique identifier for this OAuth request; used to respond via session.respondToMcpOAuth() */ + /** Unique identifier for this OAuth request; used to respond via session.mcp.oauth.handlePendingRequest */ @JsonProperty("requestId") String requestId, /** Display name of the MCP server that requires OAuth */ @JsonProperty("serverName") String serverName, /** URL of the MCP server that requires OAuth */ @JsonProperty("serverUrl") String serverUrl, /** Static OAuth client configuration, if the server specifies one */ - @JsonProperty("staticClientConfig") McpOauthRequiredStaticClientConfig staticClientConfig + @JsonProperty("staticClientConfig") McpOauthRequiredStaticClientConfig staticClientConfig, + /** Parsed parameters from the WWW-Authenticate header that the SDK host uses for RFC 9728 protected-resource metadata discovery. */ + @JsonProperty("wwwAuthenticateParams") McpOauthRequiredWwwAuthenticateParams wwwAuthenticateParams ) { } } diff --git a/java/src/generated/java/com/github/copilot/generated/McpOauthRequiredWwwAuthenticateParams.java b/java/src/generated/java/com/github/copilot/generated/McpOauthRequiredWwwAuthenticateParams.java new file mode 100644 index 000000000..072b09ab2 --- /dev/null +++ b/java/src/generated/java/com/github/copilot/generated/McpOauthRequiredWwwAuthenticateParams.java @@ -0,0 +1,31 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +// AUTO-GENERATED FILE - DO NOT EDIT +// Generated from: session-events.schema.json + +package com.github.copilot.generated; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import javax.annotation.processing.Generated; + +/** + * Parsed parameters from the WWW-Authenticate header that the SDK host uses for RFC 9728 protected-resource metadata discovery. + * + * @since 1.0.0 + */ +@javax.annotation.processing.Generated("copilot-sdk-codegen") +@JsonInclude(JsonInclude.Include.NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +public record McpOauthRequiredWwwAuthenticateParams( + /** Parsed resource_metadata URL from the WWW-Authenticate header */ + @JsonProperty("resourceMetadataUrl") String resourceMetadataUrl, + /** Parsed OAuth scope from the WWW-Authenticate header, if present */ + @JsonProperty("scope") String scope, + /** Parsed OAuth error from the WWW-Authenticate header, if present */ + @JsonProperty("error") String error +) { +} diff --git a/java/src/generated/java/com/github/copilot/generated/rpc/SessionEventLogRegisterInterestParams.java b/java/src/generated/java/com/github/copilot/generated/rpc/SessionEventLogRegisterInterestParams.java index af0bca43e..74250b75e 100644 --- a/java/src/generated/java/com/github/copilot/generated/rpc/SessionEventLogRegisterInterestParams.java +++ b/java/src/generated/java/com/github/copilot/generated/rpc/SessionEventLogRegisterInterestParams.java @@ -26,7 +26,7 @@ public record SessionEventLogRegisterInterestParams( /** Target session identifier */ @JsonProperty("sessionId") String sessionId, - /** The event type the consumer wants the runtime to treat as 'observed' for behavior-switching gating. Some runtime code paths inspect whether any consumer is interested in a specific event type and choose a different implementation accordingly (e.g. `mcp.oauth_required`: when interest is registered the runtime delegates the full interactive OAuth flow to the consumer; when no interest is registered the runtime installs a browserless fallback that silently reuses cached tokens). SDK clients that long-poll events do NOT automatically appear as listeners to these gating checks — they must explicitly call `registerInterest` for each event type they want the runtime to count as having a consumer. Multiple registrations for the same event type from the same or different consumers are tracked independently and must each be released. See: `mcp.oauth_required`, `sampling.requested`, `auto_mode_switch.requested`, `user_input.requested`, `elicitation.requested`, `command.queued`, `exit_plan_mode.requested`. */ + /** The event type the consumer wants the runtime to treat as observed for behavior-switching gating. Multiple registrations for the same event type from the same or different consumers are tracked independently and must each be released. */ @JsonProperty("eventType") String eventType ) { } diff --git a/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthApi.java b/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthApi.java index 59c4e45a1..95a081206 100644 --- a/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthApi.java +++ b/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthApi.java @@ -46,6 +46,22 @@ public CompletableFuture respond(SessionMcpOauthRespondParams params) { return caller.invoke("session.mcp.oauth.respond", _p, Void.class); } + /** + * Pending MCP OAuth request ID and host-provided token or cancellation response. + *

+ * Note: the {@code sessionId} field in the params record is overridden + * by the session-scoped wrapper; any value provided is ignored. + * + * @apiNote This method is experimental and may change in a future version. + * @since 1.0.0 + */ + @CopilotExperimental + public CompletableFuture handlePendingRequest(SessionMcpOauthHandlePendingRequestParams params) { + com.fasterxml.jackson.databind.node.ObjectNode _p = MAPPER.valueToTree(params); + _p.put("sessionId", this.sessionId); + return caller.invoke("session.mcp.oauth.handlePendingRequest", _p, SessionMcpOauthHandlePendingRequestResult.class); + } + /** * Remote MCP server name and optional overrides controlling reauthentication, OAuth client display name, and the callback success-page copy. *

diff --git a/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthHandlePendingRequestParams.java b/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthHandlePendingRequestParams.java new file mode 100644 index 000000000..5aab57ef9 --- /dev/null +++ b/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthHandlePendingRequestParams.java @@ -0,0 +1,34 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +// AUTO-GENERATED FILE - DO NOT EDIT +// Generated from: api.schema.json + +package com.github.copilot.generated.rpc; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.github.copilot.CopilotExperimental; +import javax.annotation.processing.Generated; + +/** + * Pending MCP OAuth request ID and host-provided token or cancellation response. + * + * @apiNote This method is experimental and may change in a future version. + * @since 1.0.0 + */ +@CopilotExperimental +@javax.annotation.processing.Generated("copilot-sdk-codegen") +@JsonInclude(JsonInclude.Include.NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +public record SessionMcpOauthHandlePendingRequestParams( + /** Target session identifier */ + @JsonProperty("sessionId") String sessionId, + /** OAuth request identifier for the pending request. */ + @JsonProperty("requestId") String requestId, + /** Host response to the pending OAuth request. */ + @JsonProperty("result") Object result +) { +} diff --git a/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthHandlePendingRequestResult.java b/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthHandlePendingRequestResult.java new file mode 100644 index 000000000..a7bca646e --- /dev/null +++ b/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthHandlePendingRequestResult.java @@ -0,0 +1,30 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +// AUTO-GENERATED FILE - DO NOT EDIT +// Generated from: api.schema.json + +package com.github.copilot.generated.rpc; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.github.copilot.CopilotExperimental; +import javax.annotation.processing.Generated; + +/** + * Indicates whether the pending MCP OAuth response was accepted. + * + * @apiNote This method is experimental and may change in a future version. + * @since 1.0.0 + */ +@CopilotExperimental +@javax.annotation.processing.Generated("copilot-sdk-codegen") +@JsonInclude(JsonInclude.Include.NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +public record SessionMcpOauthHandlePendingRequestResult( + /** Whether the response was accepted. False if the request was unknown, timed out, or already resolved. */ + @JsonProperty("success") Boolean success +) { +} diff --git a/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthRespondParams.java b/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthRespondParams.java index 9757a9538..d46890ca9 100644 --- a/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthRespondParams.java +++ b/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthRespondParams.java @@ -26,7 +26,7 @@ public record SessionMcpOauthRespondParams( /** Target session identifier */ @JsonProperty("sessionId") String sessionId, - /** OAuth request identifier from mcp.oauth_required */ + /** OAuth request identifier for the pending request. */ @JsonProperty("requestId") String requestId, /** In-process OAuthClientProvider instance, or omitted to deny. Marked internal: cannot be serialized across the JSON-RPC boundary. */ @JsonProperty("provider") Object provider diff --git a/java/src/main/java/com/github/copilot/CopilotClient.java b/java/src/main/java/com/github/copilot/CopilotClient.java index 50137aefe..c6b156926 100644 --- a/java/src/main/java/com/github/copilot/CopilotClient.java +++ b/java/src/main/java/com/github/copilot/CopilotClient.java @@ -27,6 +27,7 @@ import com.github.copilot.generated.rpc.SessionInstalledPlugin; import com.github.copilot.generated.rpc.ConnectParams; import com.github.copilot.generated.rpc.ServerRpc; +import com.github.copilot.generated.rpc.SessionEventLogRegisterInterestParams; import com.github.copilot.rpc.DeleteSessionResponse; import com.github.copilot.rpc.GetAuthStatusResponse; import com.github.copilot.rpc.GetLastSessionIdResponse; @@ -504,6 +505,7 @@ public CompletableFuture createSession(SessionConfig config) { String[] registeredIdHolder = new String[1]; CopilotSession[] preRegisteredSessionHolder = new CopilotSession[1]; + CompletableFuture preCreateInterest = CompletableFuture.completedFuture(null); // Pre-register non-cloud sessions BEFORE issuing the RPC so any // session-scoped requests the CLI emits during session.create @@ -511,6 +513,11 @@ public CompletableFuture createSession(SessionConfig config) { if (localSessionId != null) { preRegisteredSessionHolder[0] = initializeSession.apply(localSessionId); registeredIdHolder[0] = localSessionId; + if (config.getOnMcpAuthRequest() != null) { + preCreateInterest = preRegisteredSessionHolder[0].getRpc().eventLog + .registerInterest(new SessionEventLogRegisterInterestParams(localSessionId, + "mcp.oauth_required")); + } } var request = SessionRequestBuilder.buildCreateRequest(config, localSessionId); @@ -557,7 +564,8 @@ public CompletableFuture createSession(SessionConfig config) { } long rpcNanos = System.nanoTime(); - return connection.rpc.invoke("session.create", request, CreateSessionResponse.class) + return preCreateInterest.thenCompose(ignored -> connection.rpc.invoke("session.create", request, + CreateSessionResponse.class)) .thenCompose(response -> { String returnedId = response.sessionId(); LoggingHelpers.logTiming(LOG, Level.FINE, @@ -575,14 +583,22 @@ public CompletableFuture createSession(SessionConfig config) { ? preRegisteredSessionHolder[0] : initializeSession.apply(returnedId); registeredIdHolder[0] = returnedId; + // Local IDs registered before create; server-assigned IDs can only register now. + CompletableFuture interest = config.getOnMcpAuthRequest() != null + && preRegisteredSessionHolder[0] == null + ? session.getRpc().eventLog.registerInterest( + new SessionEventLogRegisterInterestParams(returnedId, + "mcp.oauth_required")) + : CompletableFuture.completedFuture(null); session.setWorkspacePath(response.workspacePath()); session.setCapabilities(response.capabilities()); session.setOpenCanvases(response.openCanvases()); - return updateSessionOptionsForMode(session, config.getSkipCustomInstructions().orElse(null), + return interest.thenCompose(ignored -> updateSessionOptionsForMode(session, + config.getSkipCustomInstructions().orElse(null), config.getCustomAgentsLocalOnly().orElse(null), config.getCoauthorEnabled().orElse(null), - config.getManageScheduleEnabled().orElse(null)).thenApply(v -> { + config.getManageScheduleEnabled().orElse(null))).thenApply(v -> { LoggingHelpers.logTiming(LOG, Level.FINE, "CopilotClient.createSession complete. Elapsed={Elapsed}, SessionId=" + session.getSessionId(), @@ -651,6 +667,10 @@ public CompletableFuture resumeSession(String sessionId, ResumeS if (extracted.transformCallbacks() != null) { session.registerTransformCallbacks(extracted.transformCallbacks()); } + CompletableFuture interest = config.getOnMcpAuthRequest() != null + ? session.getRpc().eventLog.registerInterest( + new SessionEventLogRegisterInterestParams(sessionId, "mcp.oauth_required")) + : CompletableFuture.completedFuture(null); var request = SessionRequestBuilder.buildResumeRequest(sessionId, config); if (extracted.wireSystemMessage() != config.getSystemMessage()) { @@ -694,7 +714,7 @@ public CompletableFuture resumeSession(String sessionId, ResumeS } long rpcNanos = System.nanoTime(); - return connection.rpc.invoke("session.resume", request, ResumeSessionResponse.class) + return interest.thenCompose(ignored -> connection.rpc.invoke("session.resume", request, ResumeSessionResponse.class)) .thenCompose(response -> { LoggingHelpers.logTiming(LOG, Level.FINE, "CopilotClient.resumeSession session resume request completed. Elapsed={Elapsed}, SessionId=" diff --git a/java/src/main/java/com/github/copilot/CopilotSession.java b/java/src/main/java/com/github/copilot/CopilotSession.java index fa080c925..2c7f07f4a 100644 --- a/java/src/main/java/com/github/copilot/CopilotSession.java +++ b/java/src/main/java/com/github/copilot/CopilotSession.java @@ -33,6 +33,7 @@ import com.github.copilot.generated.rpc.SessionCommandsHandlePendingCommandParams; import com.github.copilot.generated.rpc.SessionLogParams; import com.github.copilot.generated.rpc.SessionLogLevel; +import com.github.copilot.generated.rpc.SessionMcpOauthHandlePendingRequestParams; import com.github.copilot.generated.rpc.ModelCapabilitiesOverride; import com.github.copilot.generated.rpc.ModelCapabilitiesOverrideLimits; import com.github.copilot.generated.rpc.ModelCapabilitiesOverrideSupports; @@ -49,6 +50,7 @@ import com.github.copilot.generated.CommandExecuteEvent; import com.github.copilot.generated.ElicitationRequestedEvent; import com.github.copilot.generated.ExternalToolRequestedEvent; +import com.github.copilot.generated.McpOauthRequiredEvent; import com.github.copilot.generated.PermissionRequestedEvent; import com.github.copilot.generated.SessionCanvasClosedEvent; import com.github.copilot.generated.SessionCanvasOpenedEvent; @@ -79,6 +81,9 @@ import com.github.copilot.rpc.HookInvocation; import com.github.copilot.rpc.InputOptions; import com.github.copilot.rpc.MessageOptions; +import com.github.copilot.rpc.McpAuthHandler; +import com.github.copilot.rpc.McpAuthRequest; +import com.github.copilot.rpc.McpAuthResult; import com.github.copilot.rpc.PermissionHandler; import com.github.copilot.rpc.PermissionInvocation; import com.github.copilot.rpc.PermissionRequest; @@ -170,6 +175,7 @@ public final class CopilotSession implements AutoCloseable { private final Map toolHandlers = new ConcurrentHashMap<>(); private final Map commandHandlers = new ConcurrentHashMap<>(); private final AtomicReference permissionHandler = new AtomicReference<>(); + private final AtomicReference mcpAuthHandler = new AtomicReference<>(); private final AtomicReference userInputHandler = new AtomicReference<>(); private final AtomicReference elicitationHandler = new AtomicReference<>(); private final AtomicReference exitPlanModeHandler = new AtomicReference<>(); @@ -838,6 +844,19 @@ private void handleBroadcastEventAsync(SessionEvent event) { } executePermissionAndRespondAsync(data.requestId(), MAPPER.convertValue(data.permissionRequest(), PermissionRequest.class), handler); + } else if (event instanceof McpOauthRequiredEvent authEvent) { + var data = authEvent.getData(); + if (data == null || data.requestId() == null) { + return; + } + McpAuthHandler handler = mcpAuthHandler.get(); + if (handler == null) { + LOG.warning(() -> "Received MCP OAuth request without a registered MCP auth handler. SessionId=" + + sessionId + ", RequestId=" + data.requestId()); + return; + } + executeMcpAuthAndRespondAsync(new McpAuthRequest(sessionId, data.requestId(), data.serverName(), + data.serverUrl(), data.wwwAuthenticateParams(), data.staticClientConfig()), handler); } else if (event instanceof CommandExecuteEvent cmdEvent) { var data = cmdEvent.getData(); if (data == null || data.requestId() == null || data.commandName() == null) { @@ -1005,6 +1024,60 @@ private void executePermissionAndRespondAsync(String requestId, PermissionReques } } + private void executeMcpAuthAndRespondAsync(McpAuthRequest request, McpAuthHandler handler) { + Runnable task = () -> { + try { + handler.handle(request).thenAccept(result -> sendMcpAuthResponse(request.requestId(), result)) + .exceptionally(ex -> { + sendMcpAuthResponse(request.requestId(), McpAuthResult.cancelled()); + return null; + }); + } catch (Exception e) { + LOG.log(Level.WARNING, "Error executing MCP auth handler for requestId=" + request.requestId(), e); + sendMcpAuthResponse(request.requestId(), McpAuthResult.cancelled()); + } + }; + try { + if (executor != null) { + CompletableFuture.runAsync(task, executor); + } else { + CompletableFuture.runAsync(task); + } + } catch (RejectedExecutionException e) { + LOG.log(Level.WARNING, + "Executor rejected MCP auth task for requestId=" + request.requestId() + "; running inline", e); + task.run(); + } + } + + private void sendMcpAuthResponse(String requestId, McpAuthResult result) { + try { + Object response; + if (result == null || result.cancelled() || result.token() == null) { + response = Map.of("kind", "cancelled"); + } else { + var token = result.token(); + var tokenResponse = new java.util.HashMap(); + tokenResponse.put("kind", "token"); + tokenResponse.put("accessToken", token.accessToken()); + if (token.tokenType() != null) { + tokenResponse.put("tokenType", token.tokenType()); + } + if (token.refreshToken() != null) { + tokenResponse.put("refreshToken", token.refreshToken()); + } + if (token.expiresIn() != null) { + tokenResponse.put("expiresIn", token.expiresIn()); + } + response = tokenResponse; + } + getRpc().mcp.oauth + .handlePendingRequest(new SessionMcpOauthHandlePendingRequestParams(sessionId, requestId, response)); + } catch (Exception e) { + LOG.log(Level.WARNING, "Error sending MCP auth response for requestId=" + requestId, e); + } + } + /** * Registers custom tool handlers for this session. *

@@ -1268,6 +1341,10 @@ void registerPermissionHandler(PermissionHandler handler) { permissionHandler.set(handler); } + void registerMcpAuthHandler(McpAuthHandler handler) { + mcpAuthHandler.set(handler); + } + /** * Handles a permission request from the Copilot CLI. *

diff --git a/java/src/main/java/com/github/copilot/SessionRequestBuilder.java b/java/src/main/java/com/github/copilot/SessionRequestBuilder.java index ded92a506..cba37b07f 100644 --- a/java/src/main/java/com/github/copilot/SessionRequestBuilder.java +++ b/java/src/main/java/com/github/copilot/SessionRequestBuilder.java @@ -309,6 +309,9 @@ static void configureSession(CopilotSession session, SessionConfig config) { if (config.getOnPermissionRequest() != null) { session.registerPermissionHandler(config.getOnPermissionRequest()); } + if (config.getOnMcpAuthRequest() != null) { + session.registerMcpAuthHandler(config.getOnMcpAuthRequest()); + } if (config.getOnUserInputRequest() != null) { session.registerUserInputHandler(config.getOnUserInputRequest()); } @@ -351,6 +354,9 @@ static void configureSession(CopilotSession session, ResumeSessionConfig config) if (config.getOnPermissionRequest() != null) { session.registerPermissionHandler(config.getOnPermissionRequest()); } + if (config.getOnMcpAuthRequest() != null) { + session.registerMcpAuthHandler(config.getOnMcpAuthRequest()); + } if (config.getOnUserInputRequest() != null) { session.registerUserInputHandler(config.getOnUserInputRequest()); } diff --git a/java/src/main/java/com/github/copilot/rpc/McpAuthHandler.java b/java/src/main/java/com/github/copilot/rpc/McpAuthHandler.java new file mode 100644 index 000000000..e92ef1dd4 --- /dev/null +++ b/java/src/main/java/com/github/copilot/rpc/McpAuthHandler.java @@ -0,0 +1,24 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc; + +import java.util.concurrent.CompletableFuture; + +/** + * Handles MCP OAuth requests from the runtime. + * + * @since 1.0.0 + */ +@FunctionalInterface +public interface McpAuthHandler { + /** + * Handles an MCP OAuth request. + * + * @param request + * the MCP OAuth request context + * @return a future resolving to token data or cancellation + */ + CompletableFuture handle(McpAuthRequest request); +} diff --git a/java/src/main/java/com/github/copilot/rpc/McpAuthRequest.java b/java/src/main/java/com/github/copilot/rpc/McpAuthRequest.java new file mode 100644 index 000000000..ae5222f4c --- /dev/null +++ b/java/src/main/java/com/github/copilot/rpc/McpAuthRequest.java @@ -0,0 +1,22 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc; + +import com.github.copilot.generated.McpOauthRequiredStaticClientConfig; +import com.github.copilot.generated.McpOauthRequiredWwwAuthenticateParams; + +/** + * MCP OAuth request that the SDK host can satisfy with a host-acquired token. + * + * @since 1.0.0 + */ +public record McpAuthRequest( + String sessionId, + String requestId, + String serverName, + String serverUrl, + McpOauthRequiredWwwAuthenticateParams wwwAuthenticateParams, + McpOauthRequiredStaticClientConfig staticClientConfig) { +} diff --git a/java/src/main/java/com/github/copilot/rpc/McpAuthResult.java b/java/src/main/java/com/github/copilot/rpc/McpAuthResult.java new file mode 100644 index 000000000..b8a0acfc4 --- /dev/null +++ b/java/src/main/java/com/github/copilot/rpc/McpAuthResult.java @@ -0,0 +1,32 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc; + +/** + * Result returned by an MCP auth request handler. + * + * @since 1.0.0 + */ +public record McpAuthResult(boolean cancelled, McpAuthToken token) { + /** + * Creates a token result. + * + * @param token + * the host-provided OAuth token data + * @return token result + */ + public static McpAuthResult token(McpAuthToken token) { + return new McpAuthResult(false, token); + } + + /** + * Creates a cancellation result. + * + * @return cancellation result + */ + public static McpAuthResult cancelled() { + return new McpAuthResult(true, null); + } +} diff --git a/java/src/main/java/com/github/copilot/rpc/McpAuthToken.java b/java/src/main/java/com/github/copilot/rpc/McpAuthToken.java new file mode 100644 index 000000000..5df1b33ff --- /dev/null +++ b/java/src/main/java/com/github/copilot/rpc/McpAuthToken.java @@ -0,0 +1,13 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc; + +/** + * Host-provided OAuth token data for a pending MCP OAuth request. + * + * @since 1.0.0 + */ +public record McpAuthToken(String accessToken, String tokenType, String refreshToken, Long expiresIn) { +} diff --git a/java/src/main/java/com/github/copilot/rpc/ResumeSessionConfig.java b/java/src/main/java/com/github/copilot/rpc/ResumeSessionConfig.java index aee27e1b1..b28b13336 100644 --- a/java/src/main/java/com/github/copilot/rpc/ResumeSessionConfig.java +++ b/java/src/main/java/com/github/copilot/rpc/ResumeSessionConfig.java @@ -55,6 +55,7 @@ public class ResumeSessionConfig { private String contextTier; private ModelCapabilitiesOverride modelCapabilities; private PermissionHandler onPermissionRequest; + private McpAuthHandler onMcpAuthRequest; private UserInputHandler onUserInputRequest; private SessionHooks hooks; private String workingDirectory; @@ -550,6 +551,28 @@ public ResumeSessionConfig setOnPermissionRequest(PermissionHandler onPermission return this; } + /** + * Gets the MCP OAuth request handler. + * + * @return the handler, or {@code null} if not set + */ + @JsonIgnore + public McpAuthHandler getOnMcpAuthRequest() { + return onMcpAuthRequest; + } + + /** + * Sets the MCP OAuth request handler. + * + * @param onMcpAuthRequest + * the handler + * @return this config instance for method chaining + */ + public ResumeSessionConfig setOnMcpAuthRequest(McpAuthHandler onMcpAuthRequest) { + this.onMcpAuthRequest = onMcpAuthRequest; + return this; + } + /** * Gets the user input request handler. * @@ -1563,6 +1586,7 @@ public ResumeSessionConfig clone() { copy.onEvent = this.onEvent; copy.commands = this.commands != null ? new ArrayList<>(this.commands) : null; copy.onElicitationRequest = this.onElicitationRequest; + copy.onMcpAuthRequest = this.onMcpAuthRequest; copy.onExitPlanMode = this.onExitPlanMode; copy.onAutoModeSwitch = this.onAutoModeSwitch; copy.enableMcpApps = this.enableMcpApps; diff --git a/java/src/main/java/com/github/copilot/rpc/SessionConfig.java b/java/src/main/java/com/github/copilot/rpc/SessionConfig.java index fa7fd2244..66bd2cbec 100644 --- a/java/src/main/java/com/github/copilot/rpc/SessionConfig.java +++ b/java/src/main/java/com/github/copilot/rpc/SessionConfig.java @@ -55,6 +55,7 @@ public class SessionConfig { private Boolean coauthorEnabled; private Boolean manageScheduleEnabled; private PermissionHandler onPermissionRequest; + private McpAuthHandler onMcpAuthRequest; private UserInputHandler onUserInputRequest; private SessionHooks hooks; private String workingDirectory; @@ -592,6 +593,31 @@ public SessionConfig setOnPermissionRequest(PermissionHandler onPermissionReques return this; } + /** + * Gets the MCP OAuth request handler. + * + * @return the handler, or {@code null} if not set + */ + @JsonIgnore + public McpAuthHandler getOnMcpAuthRequest() { + return onMcpAuthRequest; + } + + /** + * Sets the MCP OAuth request handler. + *

+ * When provided, the SDK can satisfy MCP server OAuth requests with host-provided + * token data or cancellation. + * + * @param onMcpAuthRequest + * the handler + * @return this config instance for method chaining + */ + public SessionConfig setOnMcpAuthRequest(McpAuthHandler onMcpAuthRequest) { + this.onMcpAuthRequest = onMcpAuthRequest; + return this; + } + /** * Gets the user input request handler. * @@ -1686,6 +1712,7 @@ public SessionConfig clone() { copy.onEvent = this.onEvent; copy.commands = this.commands != null ? new ArrayList<>(this.commands) : null; copy.onElicitationRequest = this.onElicitationRequest; + copy.onMcpAuthRequest = this.onMcpAuthRequest; copy.onExitPlanMode = this.onExitPlanMode; copy.onAutoModeSwitch = this.onAutoModeSwitch; copy.enableMcpApps = this.enableMcpApps; diff --git a/java/src/test/java/com/github/copilot/McpAuthInterestRegistrationTest.java b/java/src/test/java/com/github/copilot/McpAuthInterestRegistrationTest.java new file mode 100644 index 000000000..328da7dcd --- /dev/null +++ b/java/src/test/java/com/github/copilot/McpAuthInterestRegistrationTest.java @@ -0,0 +1,246 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot; + +import static org.junit.jupiter.api.Assertions.*; + +import java.io.OutputStream; +import java.net.ServerSocket; +import java.net.Socket; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; + +import org.junit.jupiter.api.Test; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.github.copilot.rpc.CloudSessionOptions; +import com.github.copilot.rpc.CloudSessionRepository; +import com.github.copilot.rpc.CopilotClientOptions; +import com.github.copilot.rpc.McpAuthResult; +import com.github.copilot.rpc.PermissionHandler; +import com.github.copilot.rpc.ResumeSessionConfig; +import com.github.copilot.rpc.SessionConfig; + +class McpAuthInterestRegistrationTest { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + + @Test + void createSessionRegistersMcpAuthInterestOnlyWhenHandlerConfigured() throws Exception { + try (var server = new RecordingRuntime(); + var client = new CopilotClient(new CopilotClientOptions().setCliUrl(server.url()))) { + try (var session = client.createSession(new SessionConfig() + .setOnPermissionRequest(PermissionHandler.APPROVE_ALL) + .setOnEvent(event -> { + })).get()) { + assertNotNull(session); + } + + assertNoMcpAuthInterest(server.requests()); + assertTrue(server.requests().stream().anyMatch(request -> "session.create".equals(request.method()) + && request.params().path("requestPermission").asBoolean())); + + server.clearRequests(); + + try (var session = client.createSession(new SessionConfig() + .setOnPermissionRequest(PermissionHandler.APPROVE_ALL) + .setOnMcpAuthRequest(request -> java.util.concurrent.CompletableFuture + .completedFuture(McpAuthResult.cancelled()))).get()) { + assertNotNull(session); + } + + List requests = server.requests(); + assertEquals("session.eventLog.registerInterest", requests.get(0).method()); + assertEquals("mcp.oauth_required", requests.get(0).params().path("eventType").asText()); + assertEquals("session.create", requests.get(1).method()); + } + } + + @Test + void cloudCreateSessionRegistersMcpAuthInterestAfterCreateOnlyWhenHandlerConfigured() throws Exception { + try (var server = new RecordingRuntime(); + var client = new CopilotClient(new CopilotClientOptions().setCliUrl(server.url()))) { + var cloud = new CloudSessionOptions() + .setRepository(new CloudSessionRepository().setOwner("github").setName("copilot-sdk").setBranch("main")); + + try (var session = client.createSession(new SessionConfig() + .setOnPermissionRequest(PermissionHandler.APPROVE_ALL) + .setCloud(cloud)).get()) { + assertNotNull(session); + } + + assertNoMcpAuthInterest(server.requests()); + server.clearRequests(); + + try (var session = client.createSession(new SessionConfig() + .setOnPermissionRequest(PermissionHandler.APPROVE_ALL) + .setCloud(cloud) + .setOnMcpAuthRequest(request -> java.util.concurrent.CompletableFuture + .completedFuture(McpAuthResult.cancelled()))).get()) { + assertNotNull(session); + } + + List requests = server.requests(); + assertEquals("session.create", requests.get(0).method()); + assertEquals("session.eventLog.registerInterest", requests.get(1).method()); + assertEquals("mcp.oauth_required", requests.get(1).params().path("eventType").asText()); + } + } + + @Test + void resumeSessionRegistersMcpAuthInterestOnlyWhenHandlerConfigured() throws Exception { + try (var server = new RecordingRuntime(); + var client = new CopilotClient(new CopilotClientOptions().setCliUrl(server.url()))) { + try (var session = client.resumeSession("session-without-auth", new ResumeSessionConfig() + .setOnPermissionRequest(PermissionHandler.APPROVE_ALL) + .setOnEvent(event -> { + })).get()) { + assertNotNull(session); + } + + assertNoMcpAuthInterest(server.requests()); + assertTrue(server.requests().stream().anyMatch(request -> "session.resume".equals(request.method()) + && request.params().path("requestPermission").asBoolean())); + + server.clearRequests(); + + try (var session = client.resumeSession("session-with-auth", new ResumeSessionConfig() + .setOnPermissionRequest(PermissionHandler.APPROVE_ALL) + .setOnMcpAuthRequest(request -> java.util.concurrent.CompletableFuture + .completedFuture(McpAuthResult.cancelled()))).get()) { + assertNotNull(session); + } + + List requests = server.requests(); + assertEquals("session.eventLog.registerInterest", requests.get(0).method()); + assertEquals("mcp.oauth_required", requests.get(0).params().path("eventType").asText()); + assertEquals("session.resume", requests.get(1).method()); + } + } + + private static void assertNoMcpAuthInterest(List requests) { + assertFalse(requests.stream().anyMatch(request -> "session.eventLog.registerInterest".equals(request.method()) + && "mcp.oauth_required".equals(request.params().path("eventType").asText()))); + } + + private record RpcRequest(String method, JsonNode params) { + } + + private static final class RecordingRuntime implements AutoCloseable { + private final ServerSocket listener; + private final Thread thread; + private final List requests = new CopyOnWriteArrayList<>(); + private volatile boolean running = true; + + RecordingRuntime() throws Exception { + listener = new ServerSocket(0); + thread = new Thread(this::run, "mcp-auth-interest-test-runtime"); + thread.setDaemon(true); + thread.start(); + } + + String url() { + return "127.0.0.1:" + listener.getLocalPort(); + } + + List requests() { + return List.copyOf(requests); + } + + void clearRequests() { + requests.clear(); + } + + @Override + public void close() throws Exception { + running = false; + listener.close(); + thread.join(2000); + } + + private void run() { + try (Socket socket = listener.accept()) { + var in = socket.getInputStream(); + var out = socket.getOutputStream(); + while (running) { + JsonNode message = readMessage(in); + if (message == null) { + return; + } + String method = message.path("method").asText(); + requests.add(new RpcRequest(method, message.path("params").deepCopy())); + sendResponse(out, message.path("id").asLong(), resultFor(method, message.path("params"))); + } + } catch (Exception ex) { + if (running) { + throw new RuntimeException(ex); + } + } + } + + private static JsonNode resultFor(String method, JsonNode params) { + ObjectNode result = MAPPER.createObjectNode(); + switch (method) { + case "connect" -> { + result.put("ok", true); + result.put("protocolVersion", 3); + result.put("version", "test"); + } + case "session.create", "session.resume" -> { + String sessionId = params.path("sessionId").asText("server-assigned-session"); + if (sessionId.isEmpty()) { + sessionId = "server-assigned-session"; + } + result.put("sessionId", sessionId); + result.putNull("workspacePath"); + result.putNull("capabilities"); + } + case "session.eventLog.registerInterest" -> result.put("id", "interest-1"); + case "session.options.update" -> result.put("success", true); + case "session.skills.reload", "session.destroy" -> { + } + default -> throw new IllegalStateException("Unexpected RPC method " + method); + } + return result; + } + + private static JsonNode readMessage(java.io.InputStream in) throws Exception { + StringBuilder header = new StringBuilder(); + int b; + while ((b = in.read()) != -1) { + header.append((char) b); + if (header.toString().endsWith("\r\n\r\n")) { + break; + } + } + if (b == -1) { + return null; + } + int contentLength = 0; + for (String line : header.toString().split("\r\n")) { + int colon = line.indexOf(':'); + if (colon > 0 && "Content-Length".equals(line.substring(0, colon))) { + contentLength = Integer.parseInt(line.substring(colon + 1).trim()); + } + } + byte[] body = in.readNBytes(contentLength); + return MAPPER.readTree(body); + } + + private static void sendResponse(OutputStream out, long id, JsonNode result) throws Exception { + ObjectNode response = MAPPER.createObjectNode(); + response.put("jsonrpc", "2.0"); + response.put("id", id); + response.set("result", result); + byte[] body = MAPPER.writeValueAsBytes(response); + out.write(("Content-Length: " + body.length + "\r\n\r\n").getBytes(StandardCharsets.UTF_8)); + out.write(body); + out.flush(); + } + } +} diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index 8dc35b8d7..79b51f21b 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -1046,7 +1046,8 @@ export class CopilotClient { sessionId, this.connection!, undefined, - this.onGetTraceContext + this.onGetTraceContext, + { mcpAuthHandler: config.onMcpAuthRequest } ); s.registerTools(config.tools); s.registerCanvases(config.canvases); @@ -1184,6 +1185,12 @@ export class CopilotClient { session = initializeSession(returnedSessionId); registeredId = returnedSessionId; } + if (config.onMcpAuthRequest) { + await this.connection!.sendRequest("session.eventLog.registerInterest", { + sessionId: returnedSessionId, + eventType: "mcp.oauth_required", + }); + } session["_workspacePath"] = workspacePath; session.setCapabilities(capabilities); @@ -1233,7 +1240,8 @@ export class CopilotClient { sessionId, this.connection!, undefined, - this.onGetTraceContext + this.onGetTraceContext, + { mcpAuthHandler: config.onMcpAuthRequest } ); session.registerTools(config.tools); session.registerCanvases(config.canvases); @@ -1270,6 +1278,12 @@ export class CopilotClient { } this.sessions.set(sessionId, session); this.setupSessionFs(session, config); + if (config.onMcpAuthRequest) { + await this.connection!.sendRequest("session.eventLog.registerInterest", { + sessionId, + eventType: "mcp.oauth_required", + }); + } const toolFilterOptions = this.resolveToolFilterOptions(config); diff --git a/nodejs/src/generated/rpc.ts b/nodejs/src/generated/rpc.ts index c17a21d21..62060f9a9 100644 --- a/nodejs/src/generated/rpc.ts +++ b/nodejs/src/generated/rpc.ts @@ -627,6 +627,16 @@ export type McpServerConfigHttpOauthGrantType = | "authorization_code" /** Headless client credentials flow using the configured OAuth client. */ | "client_credentials"; +/** + * Host response to the pending OAuth request. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "McpOauthPendingRequestResponse". + */ +/** @experimental */ +export type McpOauthPendingRequestResponse = + | McpOauthPendingRequestResponseToken + | McpOauthPendingRequestResponseCancelled; /** * Outcome of the sampling inference. 'success' produced a response; 'failure' encountered an error (including agent-side rejection by content filter or criteria); 'cancelled' the caller cancelled this execution via cancelSamplingExecution. * @@ -5090,6 +5100,75 @@ export interface McpTools { */ description?: string; } +/** + * Pending MCP OAuth request ID and host-provided token or cancellation response. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "McpOauthHandlePendingRequest". + */ +/** @experimental */ +export interface McpOauthHandlePendingRequest { + /** + * OAuth request identifier for the pending request. + */ + requestId: string; + result: McpOauthPendingRequestResponse; +} +/** + * Schema for the `McpOauthPendingRequestResponseToken` type. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "McpOauthPendingRequestResponseToken". + */ +/** @experimental */ +export interface McpOauthPendingRequestResponseToken { + /** + * Supplies a host-acquired OAuth access token. + */ + kind: "token"; + /** + * Access token acquired by the SDK host + */ + accessToken: string; + /** + * OAuth token type. Defaults to Bearer when omitted. + */ + tokenType?: string; + /** + * Refresh token supplied by the host, if available. + */ + refreshToken?: string; + /** + * Token lifetime in seconds, if known. + */ + expiresIn?: number; +} +/** + * Schema for the `McpOauthPendingRequestResponseCancelled` type. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "McpOauthPendingRequestResponseCancelled". + */ +/** @experimental */ +export interface McpOauthPendingRequestResponseCancelled { + /** + * Declines or cancels the pending OAuth request. + */ + kind: "cancelled"; +} +/** + * Indicates whether the pending MCP OAuth response was accepted. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "McpOauthHandlePendingResult". + */ +/** @experimental */ +export interface McpOauthHandlePendingResult { + /** + * Whether the response was accepted. False if the request was unknown, timed out, or already resolved. + */ + success: boolean; +} /** * Remote MCP server name and optional overrides controlling reauthentication, OAuth client display name, and the callback success-page copy. * @@ -5138,7 +5217,7 @@ export interface McpOauthLoginResult { /** @internal */ export interface McpOauthRespondRequest { /** - * OAuth request identifier from mcp.oauth_required + * OAuth request identifier for the pending request. */ requestId: string; /** @@ -5383,6 +5462,19 @@ export interface McpUnregisterExternalClientRequest { */ serverName: string; } +/** + * Memory configuration for this session. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "MemoryConfiguration". + */ +/** @experimental */ +export interface MemoryConfiguration { + /** + * Whether memory is enabled for the session. + */ + enabled: boolean; +} /** * Model identifier and token limits used to compute the context-info breakdown. * @@ -8048,7 +8140,7 @@ export interface QueueRemoveMostRecentResult { /** @experimental */ export interface RegisterEventInterestParams { /** - * The event type the consumer wants the runtime to treat as 'observed' for behavior-switching gating. Some runtime code paths inspect whether any consumer is interested in a specific event type and choose a different implementation accordingly (e.g. `mcp.oauth_required`: when interest is registered the runtime delegates the full interactive OAuth flow to the consumer; when no interest is registered the runtime installs a browserless fallback that silently reuses cached tokens). SDK clients that long-poll events do NOT automatically appear as listeners to these gating checks — they must explicitly call `registerInterest` for each event type they want the runtime to count as having a consumer. Multiple registrations for the same event type from the same or different consumers are tracked independently and must each be released. See: `mcp.oauth_required`, `sampling.requested`, `auto_mode_switch.requested`, `user_input.requested`, `elicitation.requested`, `command.queued`, `exit_plan_mode.requested`. + * The event type the consumer wants the runtime to treat as observed for behavior-switching gating. Multiple registrations for the same event type from the same or different consumers are tracked independently and must each be released. */ eventType: string; } @@ -9736,6 +9828,7 @@ export interface SessionOpenOptions { * @experimental */ additionalContentExclusionPolicies?: SessionOpenOptionsAdditionalContentExclusionPolicy[]; + memory?: MemoryConfiguration; /** * Capabilities enabled for this session. */ @@ -13829,6 +13922,15 @@ export function createSessionRpc(connection: MessageConnection, sessionId: strin connection.sendRequest("session.mcp.isServerRunning", { sessionId, ...params }), /** @experimental */ oauth: { + /** + * Resolves a pending MCP OAuth request with a host-provided token or cancellation. + * + * @param params Pending MCP OAuth request ID and host-provided token or cancellation response. + * + * @returns Indicates whether the pending MCP OAuth response was accepted. + */ + handlePendingRequest: async (params: McpOauthHandlePendingRequest): Promise => + connection.sendRequest("session.mcp.oauth.handlePendingRequest", { sessionId, ...params }), /** * Starts OAuth authentication for a remote MCP server. * @@ -14673,7 +14775,7 @@ export function createInternalSessionRpc(connection: MessageConnection, sessionI /** @experimental */ oauth: { /** - * Responds to a pending MCP OAuth provider request. Marked internal because the `provider` argument is an in-process OAuthClientProvider instance that cannot be carried over the wire; the public OAuth surface will route the response through a wire-clean handshake once the CLI moves on top of the SDK. + * Responds to a pending MCP OAuth request with an in-process provider. Conceptually similar to handlePendingRequest, but marked internal because this legacy CLI-only path takes a live OAuthClientProvider instance that cannot be carried over the wire. Once the CLI is replatformed on the SDK and can use handlePendingRequest, this API should be removed. * * @param params MCP OAuth request id and optional provider response. * diff --git a/nodejs/src/generated/session-events.ts b/nodejs/src/generated/session-events.ts index a4fba8f33..9215b178d 100644 --- a/nodejs/src/generated/session-events.ts +++ b/nodejs/src/generated/session-events.ts @@ -428,6 +428,16 @@ export type ElicitationCompletedAction = * Schema for the `ElicitationCompletedContent` type. */ export type ElicitationCompletedContent = (string | number | boolean | string[]) | undefined; +/** + * How the pending OAuth request was completed + */ +export type McpOauthCompletedOutcome = + /** The pending OAuth request was resolved with a host-provided token/provider. */ + | "token" + /** The pending OAuth request was cancelled or declined without a token/provider. */ + | "cancelled" + /** The pending OAuth request timed out before any client responded. */ + | "timeout"; /** * Source-defined JSON payload for the custom notification */ @@ -5778,7 +5788,7 @@ export interface McpOauthRequiredEvent { */ export interface McpOauthRequiredData { /** - * Unique identifier for this OAuth request; used to respond via session.respondToMcpOAuth() + * Unique identifier for this OAuth request; used to respond via session.mcp.oauth.handlePendingRequest */ requestId: string; /** @@ -5790,6 +5800,7 @@ export interface McpOauthRequiredData { */ serverUrl: string; staticClientConfig?: McpOauthRequiredStaticClientConfig; + wwwAuthenticateParams: McpOauthRequiredWwwAuthenticateParams; } /** * Static OAuth client configuration, if the server specifies one @@ -5808,6 +5819,23 @@ export interface McpOauthRequiredStaticClientConfig { */ publicClient?: boolean; } +/** + * Parsed parameters from the WWW-Authenticate header that the SDK host uses for RFC 9728 protected-resource metadata discovery. + */ +export interface McpOauthRequiredWwwAuthenticateParams { + /** + * Parsed OAuth error from the WWW-Authenticate header, if present + */ + error?: string; + /** + * Parsed resource_metadata URL from the WWW-Authenticate header + */ + resourceMetadataUrl: string; + /** + * Parsed OAuth scope from the WWW-Authenticate header, if present + */ + scope?: string; +} /** * Session event "mcp.oauth_completed". MCP OAuth request completion notification */ @@ -5842,6 +5870,7 @@ export interface McpOauthCompletedEvent { * MCP OAuth request completion notification */ export interface McpOauthCompletedData { + outcome: McpOauthCompletedOutcome; /** * Request ID of the resolved OAuth request */ diff --git a/nodejs/src/session.ts b/nodejs/src/session.ts index 8ae19755a..cc2bdee8d 100644 --- a/nodejs/src/session.ts +++ b/nodejs/src/session.ts @@ -10,7 +10,11 @@ import type { MessageConnection } from "vscode-jsonrpc/node.js"; import { ConnectionError, ErrorCodes, ResponseError } from "vscode-jsonrpc/node.js"; import { createSessionRpc } from "./generated/rpc.js"; -import type { ClientSessionApiHandlers, CanvasActionInvokeResult } from "./generated/rpc.js"; +import type { + ClientSessionApiHandlers, + CanvasActionInvokeResult, + McpOauthPendingRequestResponse, +} from "./generated/rpc.js"; import { type Canvas, CanvasError } from "./canvas.js"; import type { OpenCanvasInstance } from "./generated/rpc.js"; import { getTraceContext } from "./telemetry.js"; @@ -28,6 +32,8 @@ import type { ExitPlanModeResult, UiInputOptions, MessageOptions, + McpAuthHandler, + McpAuthRequest, PermissionHandler, PermissionRequest, ContextTier, @@ -124,6 +130,7 @@ export class CopilotSession { private canvases: Map = new Map(); private commandHandlers: Map = new Map(); private permissionHandler?: PermissionHandler; + private mcpAuthHandler?: McpAuthHandler; private userInputHandler?: UserInputHandler; private elicitationHandler?: ElicitationHandler; private exitPlanModeHandler?: ExitPlanModeHandler; @@ -151,9 +158,11 @@ export class CopilotSession { public readonly sessionId: string, private connection: MessageConnection, private _workspacePath?: string, - traceContextProvider?: TraceContextProvider + traceContextProvider?: TraceContextProvider, + options?: { mcpAuthHandler?: McpAuthHandler } ) { this.traceContextProvider = traceContextProvider; + this.mcpAuthHandler = options?.mcpAuthHandler; } /** @@ -479,6 +488,19 @@ export class CopilotSession { if (this.permissionHandler) { void this._executePermissionAndRespond(requestId, permissionRequest); } + } else if (event.type === "mcp.oauth_required") { + const data = event.data as McpAuthRequest | undefined; + if (!data?.requestId) { + return; + } + if (!this.mcpAuthHandler) { + console.warn( + "Received MCP OAuth request without a registered MCP auth handler. " + + `SessionId=${this.sessionId}, RequestId=${data.requestId}` + ); + return; + } + void this._executeMcpAuthAndRespond(data); } else if (event.type === "command.execute") { const { requestId, commandName, command, args } = event.data as { requestId: string; @@ -611,6 +633,7 @@ export class CopilotSession { if (result.kind === "no-result") { return; } + await this.rpc.permissions.handlePendingPermissionRequest({ requestId, result }); } catch (_error) { try { @@ -629,6 +652,35 @@ export class CopilotSession { } } + /** + * Executes an MCP auth handler and sends the result back via RPC. + * @internal + */ + private async _executeMcpAuthAndRespond(request: McpAuthRequest): Promise { + try { + const result = await this.mcpAuthHandler!(request, { sessionId: this.sessionId }); + const response: McpOauthPendingRequestResponse = + result && "accessToken" in result + ? { kind: "token", ...result } + : { kind: "cancelled" }; + await this.rpc.mcp.oauth.handlePendingRequest({ + requestId: request.requestId, + result: response, + }); + } catch (_error) { + try { + await this.rpc.mcp.oauth.handlePendingRequest({ + requestId: request.requestId, + result: { kind: "cancelled" }, + }); + } catch (rpcError) { + if (!(rpcError instanceof ConnectionError || rpcError instanceof ResponseError)) { + throw rpcError; + } + } + } + } + /** * Executes a command handler and sends the result back via RPC. * @internal diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts index 75aa5159f..166f6c861 100644 --- a/nodejs/src/types.ts +++ b/nodejs/src/types.ts @@ -14,8 +14,10 @@ import type { SessionEvent as GeneratedSessionEvent, } from "./generated/session-events.js"; import type { CopilotSession } from "./session.js"; -import type { RemoteSessionMode } from "./generated/rpc.js"; -import type { OpenCanvasInstance } from "./generated/rpc.js"; +import type { + OpenCanvasInstance, + RemoteSessionMode, +} from "./generated/rpc.js"; import type { ToolSet } from "./toolSet.js"; export type { RemoteSessionMode } from "./generated/rpc.js"; export type SessionEvent = GeneratedSessionEvent; @@ -1546,6 +1548,72 @@ export type ReasoningEffort = "low" | "medium" | "high" | "xhigh"; */ export type ContextTier = "default" | "long_context"; +/** Parsed parameters from an MCP server's WWW-Authenticate response. */ +export interface McpAuthWwwAuthenticateParams { + /** Parsed resource_metadata URL used for protected-resource metadata discovery. */ + resourceMetadataUrl: string; + /** Parsed OAuth scope, if present. */ + scope?: string; + /** Parsed OAuth error, if present. */ + error?: string; +} + +/** Static OAuth client configuration supplied by the MCP server, if available. */ +export interface McpAuthStaticClientConfig { + /** OAuth client ID for the server. */ + clientId: string; + /** Optional non-default OAuth grant type. */ + grantType?: "client_credentials"; + /** Whether this is a public OAuth client. */ + publicClient?: boolean; +} + +/** MCP OAuth request that the SDK host can satisfy with a host-acquired token. */ +export interface McpAuthRequest { + /** Unique request identifier used by the SDK when responding. */ + requestId: string; + /** Display name of the MCP server that requires OAuth. */ + serverName: string; + /** URL of the MCP server that requires OAuth. */ + serverUrl: string; + /** Parsed WWW-Authenticate parameters from the MCP server. */ + wwwAuthenticateParams: McpAuthWwwAuthenticateParams; + /** Static OAuth client configuration, if the server specifies one. */ + staticClientConfig?: McpAuthStaticClientConfig; +} + +/** Host-provided OAuth token data for a pending MCP OAuth request. */ +export interface McpAuthToken { + /** Access token acquired by the SDK host. */ + accessToken: string; + /** OAuth token type. Defaults to Bearer when omitted. */ + tokenType?: string; + /** Refresh token supplied by the host, if available. */ + refreshToken?: string; + /** Token lifetime in seconds, if known. */ + expiresIn?: number; +} + +/** + * Result returned by an MCP auth request handler. + * + * Return `null`/`undefined` or `{ kind: "cancelled" }` to cancel the pending + * OAuth request. Return `{ kind: "token", ... }` to provide host-acquired + * OAuth token data. + */ +export type McpAuthResult = ({ kind: "token" } & McpAuthToken) | { kind: "cancelled" }; + +/** Callback invoked when an MCP server requires OAuth and the SDK host opted in. */ +export type McpAuthHandler = ( + request: McpAuthRequest, + context: { sessionId: string } +) => + | McpAuthResult + | McpAuthToken + | null + | undefined + | Promise; + /** * Stable extension identity for session participants that provide canvases. */ @@ -1771,6 +1839,13 @@ export interface SessionConfigBase { */ onPermissionRequest?: PermissionHandler; + /** + * Optional handler for MCP OAuth requests from MCP servers. + * When provided, the SDK can satisfy MCP server OAuth requests with + * host-provided token data or cancellation. + */ + onMcpAuthRequest?: McpAuthHandler; + /** * Handler for user input requests from the agent. * When provided, enables the ask_user tool allowing the agent to ask questions. diff --git a/nodejs/test/client.test.ts b/nodejs/test/client.test.ts index 9352eb627..e48d9805e 100644 --- a/nodejs/test/client.test.ts +++ b/nodejs/test/client.test.ts @@ -23,6 +23,182 @@ describe("CopilotClient", () => { expect(spy).not.toHaveBeenCalled(); }); + it("responds to MCP OAuth requests with host token data", async () => { + const sendRequest = vi.fn(async () => ({ success: true })); + const session = new CopilotSession( + "session-1", + { sendRequest } as any, + undefined, + undefined, + { + mcpAuthHandler: async () => ({ + accessToken: "host-token", + tokenType: "Bearer", + expiresIn: 3600, + }), + } + ); + + await (session as any)._executeMcpAuthAndRespond({ + requestId: "oauth-request", + serverName: "oauth-server", + serverUrl: "https://example.com/mcp", + wwwAuthenticateParams: { + resourceMetadataUrl: "https://example.com/.well-known/oauth-protected-resource", + }, + }); + + expect(sendRequest).toHaveBeenCalledWith("session.mcp.oauth.handlePendingRequest", { + sessionId: "session-1", + requestId: "oauth-request", + result: { + kind: "token", + accessToken: "host-token", + tokenType: "Bearer", + expiresIn: 3600, + }, + }); + }); + + it("registers interest in MCP OAuth required events after create when an auth handler is configured", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + const spy = vi + .spyOn((client as any).connection!, "sendRequest") + .mockImplementation(async (method: string, params: any) => { + if (method === "session.eventLog.registerInterest") { + return { id: "interest-1" }; + } + if (method === "session.create") return { sessionId: params.sessionId }; + throw new Error(`Unexpected method: ${method}`); + }); + + await client.createSession({ + onPermissionRequest: approveAll, + onMcpAuthRequest: () => ({ kind: "cancelled" }), + }); + + expect(spy.mock.calls[0][0]).toBe("session.create"); + expect(spy.mock.calls[1]).toEqual([ + "session.eventLog.registerInterest", + expect.objectContaining({ eventType: "mcp.oauth_required" }), + ]); + expect(spy.mock.calls[1][1].sessionId).toBe(spy.mock.calls[0][1].sessionId); + }); + + it("does not register MCP OAuth interest without an auth handler", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + const spy = vi + .spyOn((client as any).connection!, "sendRequest") + .mockImplementation(async (method: string, params: any) => { + if (method === "session.create") return { sessionId: params.sessionId }; + throw new Error(`Unexpected method: ${method}`); + }); + + await client.createSession({ + onPermissionRequest: approveAll, + onEvent: () => {}, + }); + + expect(spy).not.toHaveBeenCalledWith( + "session.eventLog.registerInterest", + expect.objectContaining({ eventType: "mcp.oauth_required" }) + ); + expect(spy).toHaveBeenCalledWith( + "session.create", + expect.objectContaining({ requestPermission: true }) + ); + }); + + it("registers MCP OAuth interest after cloud create only when an auth handler is configured", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + let cloudCreateCount = 0; + const spy = vi + .spyOn((client as any).connection!, "sendRequest") + .mockImplementation(async (method: string, params: any) => { + if (method === "session.eventLog.registerInterest") { + return { id: "interest-1" }; + } + if (method === "session.create") + return { sessionId: `server-assigned-session-${++cloudCreateCount}` }; + throw new Error(`Unexpected method: ${method}`); + }); + + await client.createSession({ + onPermissionRequest: approveAll, + cloud: { repository: { owner: "github", name: "copilot-sdk", branch: "main" } }, + }); + + expect(spy).not.toHaveBeenCalledWith( + "session.eventLog.registerInterest", + expect.objectContaining({ eventType: "mcp.oauth_required" }) + ); + + spy.mockClear(); + await client.createSession({ + onPermissionRequest: approveAll, + onMcpAuthRequest: () => ({ kind: "cancelled" }), + cloud: { repository: { owner: "github", name: "copilot-sdk", branch: "main" } }, + }); + + expect(spy.mock.calls[0][0]).toBe("session.create"); + expect(spy.mock.calls[1]).toEqual([ + "session.eventLog.registerInterest", + { sessionId: "server-assigned-session-2", eventType: "mcp.oauth_required" }, + ]); + }); + + it("registers MCP OAuth interest before resuming only when an auth handler is configured", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + const spy = vi + .spyOn((client as any).connection!, "sendRequest") + .mockImplementation(async (method: string, params: any) => { + if (method === "session.eventLog.registerInterest") { + return { id: "interest-1" }; + } + if (method === "session.resume") return { sessionId: params.sessionId }; + throw new Error(`Unexpected method: ${method}`); + }); + + await client.resumeSession("session-with-auth", { + onPermissionRequest: approveAll, + onMcpAuthRequest: () => ({ kind: "cancelled" }), + }); + + expect(spy.mock.calls[0]).toEqual([ + "session.eventLog.registerInterest", + { sessionId: "session-with-auth", eventType: "mcp.oauth_required" }, + ]); + expect(spy.mock.calls[1][0]).toBe("session.resume"); + expect(spy.mock.calls[1][1]).toEqual(expect.objectContaining({ requestPermission: true })); + + spy.mockClear(); + await client.resumeSession("session-without-auth", { + onPermissionRequest: approveAll, + onEvent: () => {}, + }); + + expect(spy).not.toHaveBeenCalledWith( + "session.eventLog.registerInterest", + expect.objectContaining({ eventType: "mcp.oauth_required" }) + ); + expect(spy).toHaveBeenCalledWith( + "session.resume", + expect.objectContaining({ sessionId: "session-without-auth", requestPermission: true }) + ); + }); + it("forwards canvas declarations and request flags in session.create", async () => { const client = new CopilotClient(); await client.start(); diff --git a/nodejs/test/e2e/mcp_oauth.e2e.test.ts b/nodejs/test/e2e/mcp_oauth.e2e.test.ts new file mode 100644 index 000000000..2817ac5c3 --- /dev/null +++ b/nodejs/test/e2e/mcp_oauth.e2e.test.ts @@ -0,0 +1,166 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { spawn, type ChildProcessWithoutNullStreams } from "node:child_process"; +import { dirname, resolve } from "node:path"; +import { createInterface } from "node:readline"; +import { fileURLToPath } from "node:url"; +import { describe, expect, it, onTestFinished } from "vitest"; +import type { CopilotSession, MCPServerConfig, McpAuthRequest } from "../../src/index.js"; +import { approveAll } from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; +import { waitForCondition } from "./harness/sdkTestHelper.js"; + +const __filename = fileURLToPath(import.meta.url); +const __dirname = dirname(__filename); +const TEST_MCP_OAUTH_SERVER = resolve(__dirname, "../../../test/harness/test-mcp-oauth-server.mjs"); +const EXPECTED_TOKEN = "sdk-host-token"; + +describe("MCP OAuth host auth", async () => { + const { copilotClient: client } = await createSdkTestContext(); + + it("should satisfy MCP OAuth using host-provided token", { timeout: 120_000 }, async () => { + const oauthServer = await startOAuthMcpServer(); + const serverName = "oauth-protected-mcp"; + let authRequest: McpAuthRequest | undefined; + + const session = await client.createSession({ + onPermissionRequest: approveAll, + onMcpAuthRequest: async (request) => { + authRequest = request; + return { + kind: "token", + accessToken: EXPECTED_TOKEN, + tokenType: "Bearer", + expiresIn: 3600, + }; + }, + mcpServers: { + [serverName]: { + type: "http", + url: `${oauthServer.url}/mcp`, + tools: ["*"], + oauthClientId: "sdk-e2e-client", + oauthPublicClient: true, + } as unknown as MCPServerConfig, + }, + }); + onTestFinished(() => disconnectSession(session)); + + await waitForMcpServerStatus(session, serverName); + + const tools = await session.rpc.mcp.listTools({ serverName }); + expect(tools.tools.map((tool) => tool.name)).toContain("whoami"); + + expect(authRequest).toMatchObject({ + requestId: expect.any(String), + serverName, + serverUrl: `${oauthServer.url}/mcp`, + wwwAuthenticateParams: { + resourceMetadataUrl: `${oauthServer.url}/.well-known/oauth-protected-resource`, + scope: "mcp.read", + error: "invalid_token", + }, + }); + + const requests = await oauthServer.requests(); + expect(requests.some((request) => request.authorization === null)).toBe(true); + expect( + requests.some((request) => request.authorization === `Bearer ${EXPECTED_TOKEN}`) + ).toBe(true); + }); +}); + +async function waitForMcpServerStatus( + session: CopilotSession, + serverName: string, + expectedStatus = "connected" +): Promise { + let lastStatus = ""; + await waitForCondition( + async () => { + const result = await session.rpc.mcp.list(); + const server = result.servers.find((entry) => entry.name === serverName); + lastStatus = server?.status ?? ""; + return server?.status === expectedStatus; + }, + { + timeoutMs: 60_000, + intervalMs: 200, + timeoutMessage: `${serverName} did not reach ${expectedStatus}; last status was ${lastStatus}`, + } + ); +} + +async function startOAuthMcpServer(): Promise<{ + url: string; + requests: () => Promise>; +}> { + const child = spawn(process.execPath, [TEST_MCP_OAUTH_SERVER], { + env: { ...process.env, EXPECTED_TOKEN }, + stdio: ["ignore", "pipe", "pipe"], + }); + onTestFinished(() => stopChild(child)); + + const stderr: string[] = []; + child.stderr.on("data", (chunk) => stderr.push(String(chunk))); + + const url = await new Promise((resolvePromise, reject) => { + const rl = createInterface({ input: child.stdout }); + const timeout = setTimeout(() => { + rl.close(); + reject(new Error(`Timed out waiting for OAuth MCP server. ${stderr.join("")}`)); + }, 10_000); + + child.once("exit", (code, signal) => { + clearTimeout(timeout); + rl.close(); + reject( + new Error( + `OAuth MCP server exited before listening. code=${code} signal=${signal} ${stderr.join("")}` + ) + ); + }); + + rl.on("line", (line) => { + const match = /^Listening: (.+)$/.exec(line); + if (!match) { + return; + } + clearTimeout(timeout); + rl.close(); + resolvePromise(match[1]); + }); + }); + + return { + url, + requests: async () => { + const response = await fetch(`${url}/__requests`); + if (!response.ok) { + throw new Error(`Failed to fetch OAuth MCP requests: ${response.status}`); + } + return response.json(); + }, + }; +} + +async function disconnectSession(session: CopilotSession): Promise { + try { + await session.disconnect(); + } catch { + // Best-effort cleanup. + } +} + +function stopChild(child: ChildProcessWithoutNullStreams): Promise { + if (child.exitCode !== null || child.killed) { + return Promise.resolve(); + } + const exitPromise = new Promise((resolvePromise) => { + child.once("exit", () => resolvePromise()); + }); + child.kill("SIGTERM"); + return exitPromise; +} diff --git a/python/copilot/__init__.py b/python/copilot/__init__.py index 3f1a84d25..9d8cf355c 100644 --- a/python/copilot/__init__.py +++ b/python/copilot/__init__.py @@ -97,6 +97,12 @@ MCPHTTPServerConfig, MCPServerConfig, MCPStdioServerConfig, + McpAuthHandler, + McpAuthRequest, + McpAuthResult, + McpAuthStaticClientConfig, + McpAuthToken, + McpAuthWwwAuthenticateParams, PermissionHandler, PermissionNoResult, PermissionRequestResult, @@ -201,6 +207,12 @@ "MCPHTTPServerConfig", "MCPServerConfig", "MCPStdioServerConfig", + "McpAuthHandler", + "McpAuthRequest", + "McpAuthResult", + "McpAuthStaticClientConfig", + "McpAuthToken", + "McpAuthWwwAuthenticateParams", "ModelBilling", "ModelCapabilities", "ModelCapabilitiesOverride", diff --git a/python/copilot/client.py b/python/copilot/client.py index 7dcec6e8f..e10a25c83 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -87,6 +87,7 @@ InfiniteSessionConfig, LargeToolOutputConfig, MCPServerConfig, + McpAuthHandler, ProviderConfig, ReasoningEffort, ReasoningSummary, @@ -1604,6 +1605,7 @@ async def create_session( on_event: Callable[[SessionEvent], None] | None = None, commands: list[CommandDefinition] | None = None, on_elicitation_request: ElicitationHandler | None = None, + on_mcp_auth_request: McpAuthHandler | None = None, enable_mcp_apps: bool = False, on_exit_plan_mode_request: ExitPlanModeHandler | None = None, on_auto_mode_switch_request: AutoModeSwitchHandler | None = None, @@ -2008,6 +2010,7 @@ def _initialize_session(sid: str) -> CopilotSession: s._register_tools(tools) s._register_commands(commands) s._register_permission_handler(on_permission_request) + s._register_mcp_auth_handler(on_mcp_auth_request) if on_user_input_request: s._register_user_input_handler(on_user_input_request) if on_elicitation_request: @@ -2048,6 +2051,11 @@ def _initialize_session(sid: str) -> CopilotSession: if local_session_id is not None: session = _initialize_session(local_session_id) registered_session_id = local_session_id + if on_mcp_auth_request is not None: + await self._client.request( + "session.eventLog.registerInterest", + {"sessionId": local_session_id, "eventType": "mcp.oauth_required"}, + ) try: rpc_start = time.perf_counter() @@ -2087,6 +2095,12 @@ def _register_inline(raw_response: Any) -> None: f"session.create returned sessionId {response.get('sessionId')} " f"but the caller requested {local_session_id}" ) + # Local IDs registered before create; server-assigned IDs can only register now. + if local_session_id is None and on_mcp_auth_request is not None: + await self._client.request( + "session.eventLog.registerInterest", + {"sessionId": session.session_id, "eventType": "mcp.oauth_required"}, + ) session._workspace_path = response.get("workspacePath") capabilities = response.get("capabilities") session._set_capabilities(capabilities) @@ -2173,6 +2187,7 @@ async def resume_session( on_event: Callable[[SessionEvent], None] | None = None, commands: list[CommandDefinition] | None = None, on_elicitation_request: ElicitationHandler | None = None, + on_mcp_auth_request: McpAuthHandler | None = None, enable_mcp_apps: bool = False, on_exit_plan_mode_request: ExitPlanModeHandler | None = None, on_auto_mode_switch_request: AutoModeSwitchHandler | None = None, @@ -2531,6 +2546,7 @@ async def resume_session( session._register_tools(tools) session._register_commands(commands) session._register_permission_handler(on_permission_request) + session._register_mcp_auth_handler(on_mcp_auth_request) if on_user_input_request: session._register_user_input_handler(on_user_input_request) if on_elicitation_request: @@ -2549,6 +2565,11 @@ async def resume_session( session.on(on_event) with self._sessions_lock: self._sessions[session_id] = session + if on_mcp_auth_request is not None: + await self._client.request( + "session.eventLog.registerInterest", + {"sessionId": session_id, "eventType": "mcp.oauth_required"}, + ) log_timing( logger, logging.DEBUG, diff --git a/python/copilot/generated/rpc.py b/python/copilot/generated/rpc.py index 69bf16007..58b0ef564 100644 --- a/python/copilot/generated/rpc.py +++ b/python/copilot/generated/rpc.py @@ -2575,6 +2575,34 @@ def to_dict(self) -> dict: result["serverName"] = from_str(self.server_name) return result + +class MCPOauthPendingRequestResponseKind(Enum): + CANCELLED = "cancelled" + TOKEN = "token" + + +# Experimental: this type is part of an experimental API and may change or be removed. +@dataclass +class MCPOauthHandlePendingResult: + """Indicates whether the pending MCP OAuth response was accepted.""" + + success: bool + """Whether the response was accepted. False if the request was unknown, timed out, or + already resolved. + """ + + @staticmethod + def from_dict(obj: Any) -> 'MCPOauthHandlePendingResult': + assert isinstance(obj, dict) + success = from_bool(obj.get("success")) + return MCPOauthHandlePendingResult(success) + + def to_dict(self) -> dict: + result: dict = {} + result["success"] = from_bool(self.success) + return result + + # Experimental: this type is part of an experimental API and may change or be removed. @dataclass class MCPOauthLoginRequest: @@ -2648,6 +2676,100 @@ def to_dict(self) -> dict: result["authorizationUrl"] = from_union([from_str, from_none], self.authorization_url) return result + +class MCPOauthPendingRequestResponseCancelledKind(Enum): + CANCELLED = "cancelled" + + +@dataclass +class MCPOauthPendingRequestResponseCancelled: + "Schema for the `McpOauthPendingRequestResponseCancelled` type." + kind: MCPOauthPendingRequestResponseCancelledKind = MCPOauthPendingRequestResponseCancelledKind.CANCELLED + + @staticmethod + def from_dict(obj: Any) -> 'MCPOauthPendingRequestResponseCancelled': + assert isinstance(obj, dict) + kind = MCPOauthPendingRequestResponseCancelledKind(obj.get("kind")) + return MCPOauthPendingRequestResponseCancelled(kind) + + def to_dict(self) -> dict: + result: dict = {} + result["kind"] = to_enum(MCPOauthPendingRequestResponseCancelledKind, self.kind) + return result + + +class TypeEnum(Enum): + TOKEN = "token" + + +@dataclass +class MCPOauthPendingRequestResponseToken: + "Schema for the `McpOauthPendingRequestResponseToken` type." + access_token: str + kind: TypeEnum = TypeEnum.TOKEN + expires_in: int | None = None + refresh_token: str | None = None + token_type: str | None = None + + @staticmethod + def from_dict(obj: Any) -> 'MCPOauthPendingRequestResponseToken': + assert isinstance(obj, dict) + access_token = from_str(obj.get("accessToken")) + kind = TypeEnum(obj.get("kind")) + expires_in = from_union([from_int, from_none], obj.get("expiresIn")) + refresh_token = from_union([from_str, from_none], obj.get("refreshToken")) + token_type = from_union([from_str, from_none], obj.get("tokenType")) + return MCPOauthPendingRequestResponseToken(access_token, kind, expires_in, refresh_token, token_type) + + def to_dict(self) -> dict: + result: dict = {} + result["accessToken"] = from_str(self.access_token) + result["kind"] = to_enum(TypeEnum, self.kind) + if self.expires_in is not None: + result["expiresIn"] = from_union([from_int, from_none], self.expires_in) + if self.refresh_token is not None: + result["refreshToken"] = from_union([from_str, from_none], self.refresh_token) + if self.token_type is not None: + result["tokenType"] = from_union([from_str, from_none], self.token_type) + return result + + +MCPOauthPendingRequestResponse = Union[MCPOauthPendingRequestResponseCancelled, MCPOauthPendingRequestResponseToken] + + +def _load_mcp_oauth_pending_request_response(obj: Any) -> "MCPOauthPendingRequestResponse": + assert isinstance(obj, dict) + kind = obj.get("kind") + match kind: + case "cancelled": return MCPOauthPendingRequestResponseCancelled.from_dict(obj) + case "token": return MCPOauthPendingRequestResponseToken.from_dict(obj) + case _: raise ValueError(f"Unknown MCPOauthPendingRequestResponse kind: {kind!r}") + + +@dataclass +class MCPOauthHandlePendingRequest: + """Pending MCP OAuth request ID and host-provided token or cancellation response.""" + + request_id: str + """OAuth request identifier for the pending request.""" + + result: MCPOauthPendingRequestResponse + """Host response to the pending OAuth request.""" + + @staticmethod + def from_dict(obj: Any) -> 'MCPOauthHandlePendingRequest': + assert isinstance(obj, dict) + request_id = from_str(obj.get("requestId")) + result = _load_mcp_oauth_pending_request_response(obj.get("result")) + return MCPOauthHandlePendingRequest(request_id, result) + + def to_dict(self) -> dict: + result: dict = {} + result["requestId"] = from_str(self.request_id) + result["result"] = to_class(cast(Any, self.result), self.result) + return result + + # Experimental: this type is part of an experimental API and may change or be removed. @dataclass class MCPOauthRespondResult: @@ -22497,6 +22619,12 @@ def __init__(self, client: "JsonRpcClient", session_id: str): self._client = client self._session_id = session_id + async def handle_pending_request(self, params: MCPOauthHandlePendingRequest, *, timeout: float | None = None) -> MCPOauthHandlePendingResult: + "Resolves a pending MCP OAuth request with a host-provided token or cancellation.\n\nArgs:\n params: Pending MCP OAuth request ID and host-provided token or cancellation response.\n\nReturns:\n Indicates whether the pending MCP OAuth response was accepted." + params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} + params_dict["sessionId"] = self._session_id + return MCPOauthHandlePendingResult.from_dict(await self._client.request("session.mcp.oauth.handlePendingRequest", params_dict, **_timeout_kwargs(timeout))) + async def login(self, params: MCPOauthLoginRequest, *, timeout: float | None = None) -> MCPOauthLoginResult: "Starts OAuth authentication for a remote MCP server.\n\nArgs:\n params: Remote MCP server name and optional overrides controlling reauthentication, OAuth client display name, and the callback success-page copy.\n\nReturns:\n OAuth authorization URL the caller should open, or empty when cached tokens already authenticated the server." params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} @@ -23704,8 +23832,15 @@ async def handle_canvas_action_invoke(params: dict) -> dict | None: "MCPIsServerRunningResult", "MCPListToolsRequest", "MCPListToolsResult", + "MCPOauthHandlePendingRequest", + "MCPOauthHandlePendingResult", "MCPOauthLoginRequest", "MCPOauthLoginResult", + "MCPOauthPendingRequestResponse", + "MCPOauthPendingRequestResponseCancelled", + "MCPOauthPendingRequestResponseCancelledKind", + "MCPOauthPendingRequestResponseKind", + "MCPOauthPendingRequestResponseToken", "MCPOauthRespondRequest", "MCPOauthRespondResult", "MCPRegisterExternalClientRequest", diff --git a/python/copilot/generated/session_events.py b/python/copilot/generated/session_events.py index f2e155f49..52fd43512 100644 --- a/python/copilot/generated/session_events.py +++ b/python/copilot/generated/session_events.py @@ -2195,21 +2195,31 @@ def to_dict(self) -> dict: return result +class McpOauthCompletedOutcome(Enum): + CANCELLED = "cancelled" + TIMEOUT = "timeout" + TOKEN = "token" + + @dataclass class McpOauthCompletedData: "MCP OAuth request completion notification" + outcome: McpOauthCompletedOutcome request_id: str @staticmethod def from_dict(obj: Any) -> "McpOauthCompletedData": assert isinstance(obj, dict) + outcome = parse_enum(McpOauthCompletedOutcome, obj.get("outcome")) request_id = from_str(obj.get("requestId")) return McpOauthCompletedData( + outcome=outcome, request_id=request_id, ) def to_dict(self) -> dict: result: dict = {} + result["outcome"] = to_enum(McpOauthCompletedOutcome, self.outcome) result["requestId"] = from_str(self.request_id) return result @@ -2220,6 +2230,7 @@ class McpOauthRequiredData: request_id: str server_name: str server_url: str + www_authenticate_params: McpOauthRequiredWwwAuthenticateParams static_client_config: McpOauthRequiredStaticClientConfig | None = None @staticmethod @@ -2228,11 +2239,13 @@ def from_dict(obj: Any) -> "McpOauthRequiredData": request_id = from_str(obj.get("requestId")) server_name = from_str(obj.get("serverName")) server_url = from_str(obj.get("serverUrl")) + www_authenticate_params = McpOauthRequiredWwwAuthenticateParams.from_dict(obj.get("wwwAuthenticateParams")) static_client_config = from_union([from_none, McpOauthRequiredStaticClientConfig.from_dict], obj.get("staticClientConfig")) return McpOauthRequiredData( request_id=request_id, server_name=server_name, server_url=server_url, + www_authenticate_params=www_authenticate_params, static_client_config=static_client_config, ) @@ -2241,6 +2254,7 @@ def to_dict(self) -> dict: result["requestId"] = from_str(self.request_id) result["serverName"] = from_str(self.server_name) result["serverUrl"] = from_str(self.server_url) + result["wwwAuthenticateParams"] = to_class(McpOauthRequiredWwwAuthenticateParams, self.www_authenticate_params) if self.static_client_config is not None: result["staticClientConfig"] = from_union([from_none, lambda x: to_class(McpOauthRequiredStaticClientConfig, x)], self.static_client_config) return result @@ -2276,7 +2290,36 @@ def to_dict(self) -> dict: @dataclass -class McpServersLoadedServer: +class McpOauthRequiredWwwAuthenticateParams: + "Parsed parameters from the WWW-Authenticate header that the SDK host uses for RFC 9728 protected-resource metadata discovery." + resource_metadata_url: str + error: str | None = None + scope: str | None = None + + @staticmethod + def from_dict(obj: Any) -> "McpOauthRequiredWwwAuthenticateParams": + assert isinstance(obj, dict) + resource_metadata_url = from_str(obj.get("resourceMetadataUrl")) + error = from_union([from_none, from_str], obj.get("error")) + scope = from_union([from_none, from_str], obj.get("scope")) + return McpOauthRequiredWwwAuthenticateParams( + resource_metadata_url=resource_metadata_url, + error=error, + scope=scope, + ) + + def to_dict(self) -> dict: + result: dict = {} + result["resourceMetadataUrl"] = from_str(self.resource_metadata_url) + if self.error is not None: + result["error"] = from_union([from_none, from_str], self.error) + if self.scope is not None: + result["scope"] = from_union([from_none, from_str], self.scope) + return result + + +@dataclass +class MCPServersLoadedServer: "Schema for the `McpServersLoadedServer` type." name: str status: McpServerStatus @@ -7400,8 +7443,10 @@ def session_event_to_dict(x: SessionEvent) -> Any: "McpAppToolCallCompleteToolMeta", "McpAppToolCallCompleteToolMetaUI", "McpOauthCompletedData", + "McpOauthCompletedOutcome", "McpOauthRequiredData", "McpOauthRequiredStaticClientConfig", + "McpOauthRequiredWwwAuthenticateParams", "McpServerSource", "McpServerStatus", "McpServerTransport", diff --git a/python/copilot/session.py b/python/copilot/session.py index 32201870c..672dfa20c 100644 --- a/python/copilot/session.py +++ b/python/copilot/session.py @@ -39,6 +39,9 @@ ExternalToolTextResultForLlm, HandlePendingToolCallRequest, LogRequest, + MCPOauthHandlePendingRequest, + MCPOauthPendingRequestResponseCancelled, + MCPOauthPendingRequestResponseToken, ModelSwitchToRequest, PermissionDecision, PermissionDecisionApproveOnce, @@ -65,6 +68,7 @@ CommandExecuteData, ElicitationRequestedData, ExternalToolRequestedData, + McpOauthRequiredData, PermissionRequest, PermissionRequestedData, SessionCanvasClosedData, @@ -294,6 +298,62 @@ def approve_all( return PermissionDecisionApproveOnce() +# ============================================================================ +# MCP Auth Types +# ============================================================================ + + +class McpAuthWwwAuthenticateParams(TypedDict, total=False): + """Parsed parameters from an MCP server's WWW-Authenticate response.""" + + resourceMetadataUrl: Required[str] + scope: str + error: str + + +class McpAuthStaticClientConfig(TypedDict, total=False): + """Static OAuth client configuration supplied by the MCP server, if available.""" + + clientId: Required[str] + grantType: Literal["client_credentials"] + publicClient: bool + + +class McpAuthRequest(TypedDict, total=False): + """MCP OAuth request that the SDK host can satisfy with a host-acquired token.""" + + requestId: Required[str] + serverName: Required[str] + serverUrl: Required[str] + wwwAuthenticateParams: Required[McpAuthWwwAuthenticateParams] + staticClientConfig: McpAuthStaticClientConfig + + +class McpAuthToken(TypedDict, total=False): + """Host-provided OAuth token data for a pending MCP OAuth request.""" + + accessToken: Required[str] + tokenType: str + refreshToken: str + expiresIn: int + + +class McpAuthResult(TypedDict, total=False): + """Result returned by an MCP auth request handler.""" + + kind: Required[Literal["token", "cancelled"]] + accessToken: str + tokenType: str + refreshToken: str + expiresIn: int + + +McpAuthHandler = Callable[ + [McpAuthRequest, dict[str, str]], + McpAuthResult | McpAuthToken | None | Awaitable[McpAuthResult | McpAuthToken | None], +] + + # ============================================================================ # User Input Request Types # ============================================================================ @@ -1121,6 +1181,8 @@ def __init__( self._tool_handlers_lock = threading.Lock() self._permission_handler: _PermissionHandlerFn | None = None self._permission_handler_lock = threading.Lock() + self._mcp_auth_handler: McpAuthHandler | None = None + self._mcp_auth_handler_lock = threading.Lock() self._user_input_handler: UserInputHandler | None = None self._user_input_handler_lock = threading.Lock() self._exit_plan_mode_handler: ExitPlanModeHandler | None = None @@ -1508,6 +1570,42 @@ def _handle_broadcast_event(self, event: SessionEvent) -> None: ) ) + case McpOauthRequiredData() as data: + with self._mcp_auth_handler_lock: + handler = self._mcp_auth_handler + if not data.request_id: + return + if not handler: + logger.warning( + "Received MCP OAuth request without a registered MCP auth handler. " + "SessionId=%s, RequestId=%s", + self.session_id, + data.request_id, + ) + return + request: McpAuthRequest = { + "requestId": data.request_id, + "serverName": data.server_name, + "serverUrl": data.server_url, + "wwwAuthenticateParams": { + "resourceMetadataUrl": data.www_authenticate_params.resource_metadata_url, + }, + } + if data.www_authenticate_params.scope is not None: + request["wwwAuthenticateParams"]["scope"] = data.www_authenticate_params.scope + if data.www_authenticate_params.error is not None: + request["wwwAuthenticateParams"]["error"] = data.www_authenticate_params.error + if data.static_client_config is not None: + static_client_config: McpAuthStaticClientConfig = { + "clientId": data.static_client_config.client_id, + } + if data.static_client_config.grant_type is not None: + static_client_config["grantType"] = data.static_client_config.grant_type + if data.static_client_config.public_client is not None: + static_client_config["publicClient"] = data.static_client_config.public_client + request["staticClientConfig"] = static_client_config + asyncio.ensure_future(self._execute_mcp_auth_and_respond(request, handler)) + case CommandExecuteData() as data: request_id = data.request_id command_name = data.command_name @@ -1726,6 +1824,53 @@ async def _execute_permission_and_respond( except (JsonRpcError, ProcessExitedError, OSError): pass # Connection lost or RPC error — nothing we can do + async def _execute_mcp_auth_and_respond( + self, + request: McpAuthRequest, + handler: McpAuthHandler, + ) -> None: + """Execute an MCP auth handler and respond via RPC.""" + request_id = request["requestId"] + try: + handler_start = time.perf_counter() + result = handler(request, {"session_id": self.session_id}) + if inspect.isawaitable(result): + result = await result + log_timing( + logger, + logging.DEBUG, + "CopilotSession._execute_mcp_auth_and_respond dispatch", + handler_start, + session_id=self.session_id, + request_id=request_id, + ) + + if result and result.get("kind", "token") == "token": + rpc_result = MCPOauthPendingRequestResponseToken( + access_token=result["accessToken"], + expires_in=result.get("expiresIn"), + refresh_token=result.get("refreshToken"), + token_type=result.get("tokenType"), + ) + else: + rpc_result = MCPOauthPendingRequestResponseCancelled() + await self.rpc.mcp.oauth.handle_pending_request( + MCPOauthHandlePendingRequest( + request_id=request_id, + result=rpc_result, + ) + ) + except Exception: + try: + await self.rpc.mcp.oauth.handle_pending_request( + MCPOauthHandlePendingRequest( + request_id=request_id, + result=MCPOauthPendingRequestResponseCancelled(), + ) + ) + except (JsonRpcError, ProcessExitedError, OSError): + pass # Connection lost or RPC error — nothing we can do + async def _execute_command_and_respond( self, request_id: str, @@ -1888,6 +2033,11 @@ def _register_elicitation_handler(self, handler: ElicitationHandler | None) -> N with self._elicitation_handler_lock: self._elicitation_handler = handler + def _register_mcp_auth_handler(self, handler: McpAuthHandler | None) -> None: + """Register the MCP auth handler for this session.""" + with self._mcp_auth_handler_lock: + self._mcp_auth_handler = handler + def _register_exit_plan_mode_handler(self, handler: ExitPlanModeHandler | None) -> None: """Register the exit-plan-mode handler for this session.""" with self._exit_plan_mode_handler_lock: diff --git a/python/test_client.py b/python/test_client.py index 502d410ab..dd4f90017 100644 --- a/python/test_client.py +++ b/python/test_client.py @@ -4,6 +4,7 @@ This file is for unit tests. Where relevant, prefer to add e2e tests in e2e/*.py instead. """ +import asyncio from datetime import UTC, datetime from unittest.mock import AsyncMock, patch @@ -24,6 +25,12 @@ ModelSupports, ) from copilot.session import PermissionHandler +from copilot.session_events import ( + McpOauthRequiredData, + McpOauthRequiredWwwAuthenticateParams, + SessionEvent, + SessionEventType, +) from e2e.testharness import CLI_PATH @@ -63,6 +70,254 @@ async def test_resume_session_allows_none_permission_handler(self): class TestCreateSessionConfig: + @pytest.mark.asyncio + async def test_mcp_auth_handler_registers_interest_in_create_session(self): + client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH)) + await client.start() + try: + captured: list[tuple[str, dict]] = [] + + async def mock_request(method, params, **kwargs): + captured.append((method, params)) + if method == "session.eventLog.registerInterest": + return {"id": "interest-1"} + if method == "session.create": + result = {"sessionId": params["sessionId"], "workspacePath": None} + callback = kwargs.get("on_response_inline") + if callback is not None: + callback(result) + return result + return {} + + client._client.request = mock_request + await client.create_session( + on_permission_request=PermissionHandler.approve_all, + on_mcp_auth_request=lambda request: {"kind": "cancelled"}, + ) + + interest_method, interest_payload = captured[0] + create_method, create_payload = captured[1] + assert interest_method == "session.eventLog.registerInterest" + assert interest_payload["eventType"] == "mcp.oauth_required" + assert create_method == "session.create" + assert interest_payload["sessionId"] == create_payload["sessionId"] + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_mcp_auth_interest_is_not_registered_without_handler(self): + client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH)) + await client.start() + try: + captured: list[tuple[str, dict]] = [] + + async def mock_request(method, params, **kwargs): + captured.append((method, params)) + if method == "session.create": + result = {"sessionId": params["sessionId"], "workspacePath": None} + callback = kwargs.get("on_response_inline") + if callback is not None: + callback(result) + return result + if method == "session.resume": + return {"sessionId": params["sessionId"], "workspacePath": None} + return {} + + client._client.request = mock_request + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all, + on_event=lambda event: None, + ) + await client.resume_session( + "session-without-auth", + on_permission_request=PermissionHandler.approve_all, + on_event=lambda event: None, + ) + + assert session.session_id + assert not any( + method == "session.eventLog.registerInterest" + and params["eventType"] == "mcp.oauth_required" + for method, params in captured + ) + assert any( + method == "session.create" and params["requestPermission"] is True + for method, params in captured + ) + assert any( + method == "session.resume" and params["requestPermission"] is True + for method, params in captured + ) + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_mcp_auth_handler_registers_interest_before_resume(self): + client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH)) + await client.start() + try: + captured: list[tuple[str, dict]] = [] + + async def mock_request(method, params, **kwargs): + captured.append((method, params)) + if method == "session.eventLog.registerInterest": + return {"id": "interest-1"} + if method == "session.resume": + return {"sessionId": params["sessionId"], "workspacePath": None} + return {} + + client._client.request = mock_request + await client.resume_session( + "session-with-auth", + on_permission_request=PermissionHandler.approve_all, + on_mcp_auth_request=lambda request: {"kind": "cancelled"}, + ) + + interest_method, interest_payload = captured[0] + resume_method, resume_payload = captured[1] + assert interest_method == "session.eventLog.registerInterest" + assert interest_payload == { + "sessionId": "session-with-auth", + "eventType": "mcp.oauth_required", + } + assert resume_method == "session.resume" + assert resume_payload["requestPermission"] is True + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_mcp_auth_handler_registers_interest_after_cloud_create_only_with_handler(self): + client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH)) + await client.start() + try: + captured: list[tuple[str, dict]] = [] + create_count = 0 + + async def mock_request(method, params, **kwargs): + nonlocal create_count + captured.append((method, params)) + if method == "session.eventLog.registerInterest": + return {"id": "interest-1"} + if method == "session.create": + create_count += 1 + result = { + "sessionId": f"server-assigned-session-{create_count}", + "workspacePath": None, + } + callback = kwargs.get("on_response_inline") + if callback is not None: + callback(result) + return result + return {} + + cloud = CloudSessionOptions( + repository=CloudSessionRepository( + owner="github", + name="copilot-sdk", + branch="main", + ) + ) + + client._client.request = mock_request + await client.create_session( + on_permission_request=PermissionHandler.approve_all, + cloud=cloud, + ) + + assert not any( + method == "session.eventLog.registerInterest" + and params["eventType"] == "mcp.oauth_required" + for method, params in captured + ) + + captured.clear() + await client.create_session( + on_permission_request=PermissionHandler.approve_all, + on_mcp_auth_request=lambda request: {"kind": "cancelled"}, + cloud=cloud, + ) + + create_method, _create_payload = captured[0] + interest_method, interest_payload = captured[1] + assert create_method == "session.create" + assert interest_method == "session.eventLog.registerInterest" + assert interest_payload == { + "sessionId": "server-assigned-session-2", + "eventType": "mcp.oauth_required", + } + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_mcp_auth_required_event_sends_host_token(self): + client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH)) + await client.start() + try: + captured: list[tuple[str, dict]] = [] + + async def mock_request(method, params, **kwargs): + if method == "session.mcp.oauth.handlePendingRequest": + captured.append((method, params)) + return {"success": True} + if method == "session.create": + result = {"sessionId": params["sessionId"], "workspacePath": None} + callback = kwargs.get("on_response_inline") + if callback is not None: + callback(result) + return result + if method == "session.eventLog.registerInterest": + return {"id": "interest-1"} + return {} + + client._client.request = mock_request + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all, + on_mcp_auth_request=lambda request: { + "accessToken": "host-token", + "tokenType": "Bearer", + }, + ) + + session._dispatch_event( + SessionEvent( + data=McpOauthRequiredData( + request_id="oauth-request", + server_name="oauth-server", + server_url="https://example.com/mcp", + www_authenticate_params=McpOauthRequiredWwwAuthenticateParams( + resource_metadata_url="https://example.com/.well-known/oauth-protected-resource" + ), + ), + id="evt-1", + timestamp="2026-01-01T00:00:00Z", + type=SessionEventType.MCP_OAUTH_REQUIRED, + ephemeral=True, + parent_id=None, + ) + ) + + for _ in range(200): + if captured: + break + await asyncio.sleep(0.005) + + assert captured == [ + ( + "session.mcp.oauth.handlePendingRequest", + { + "sessionId": session.session_id, + "requestId": "oauth-request", + "result": { + "kind": "token", + "accessToken": "host-token", + "tokenType": "Bearer", + }, + }, + ) + ] + finally: + await client.force_stop() + @pytest.mark.asyncio async def test_create_session_forwards_cloud_options(self): client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH)) diff --git a/rust/src/generated/api_types.rs b/rust/src/generated/api_types.rs index f52c06651..49889c0d5 100644 --- a/rust/src/generated/api_types.rs +++ b/rust/src/generated/api_types.rs @@ -292,6 +292,9 @@ pub mod rpc_methods { pub const SESSION_MCP_ISSERVERRUNNING: &str = "session.mcp.isServerRunning"; /// `session.mcp.oauth.respond` pub const SESSION_MCP_OAUTH_RESPOND: &str = "session.mcp.oauth.respond"; + /// `session.mcp.oauth.handlePendingRequest` + pub const SESSION_MCP_OAUTH_HANDLEPENDINGREQUEST: &str = + "session.mcp.oauth.handlePendingRequest"; /// `session.mcp.oauth.login` pub const SESSION_MCP_OAUTH_LOGIN: &str = "session.mcp.oauth.login"; /// `session.mcp.apps.readResource` @@ -4077,6 +4080,79 @@ pub struct McpListToolsResult { pub tools: Vec, } +/// Schema for the `McpOauthPendingRequestResponseToken` type. +/// +///

+/// +/// **Experimental.** This type is part of an experimental wire-protocol surface +/// and may change or be removed in future SDK or CLI releases. +/// +///
+#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpOauthPendingRequestResponseToken { + /// Access token acquired by the SDK host + pub access_token: String, + /// Token lifetime in seconds, if known. + #[serde(skip_serializing_if = "Option::is_none")] + pub expires_in: Option, + /// Supplies a host-acquired OAuth access token. + pub kind: McpOauthPendingRequestResponseTokenKind, + /// Refresh token supplied by the host, if available. + #[serde(skip_serializing_if = "Option::is_none")] + pub refresh_token: Option, + /// OAuth token type. Defaults to Bearer when omitted. + #[serde(skip_serializing_if = "Option::is_none")] + pub token_type: Option, +} + +/// Schema for the `McpOauthPendingRequestResponseCancelled` type. +/// +///
+/// +/// **Experimental.** This type is part of an experimental wire-protocol surface +/// and may change or be removed in future SDK or CLI releases. +/// +///
+#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpOauthPendingRequestResponseCancelled { + /// Declines or cancels the pending OAuth request. + pub kind: McpOauthPendingRequestResponseCancelledKind, +} + +/// Pending MCP OAuth request ID and host-provided token or cancellation response. +/// +///
+/// +/// **Experimental.** This type is part of an experimental wire-protocol surface +/// and may change or be removed in future SDK or CLI releases. +/// +///
+#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpOauthHandlePendingRequest { + /// OAuth request identifier for the pending request. + pub request_id: RequestId, + /// Host response to the pending OAuth request. + pub result: McpOauthPendingRequestResponse, +} + +/// Indicates whether the pending MCP OAuth response was accepted. +/// +///
+/// +/// **Experimental.** This type is part of an experimental wire-protocol surface +/// and may change or be removed in future SDK or CLI releases. +/// +///
+#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpOauthHandlePendingResult { + /// Whether the response was accepted. False if the request was unknown, timed out, or already resolved. + pub success: bool, +} + /// Remote MCP server name and optional overrides controlling reauthentication, OAuth client display name, and the callback success-page copy. /// ///
@@ -4132,7 +4208,7 @@ pub(crate) struct McpOauthRespondRequest { #[doc(hidden)] #[serde(skip_serializing_if = "Option::is_none")] pub(crate) provider: Option, - /// OAuth request identifier from mcp.oauth_required + /// OAuth request identifier for the pending request. pub request_id: RequestId, } @@ -4464,6 +4540,21 @@ pub(crate) struct McpUnregisterExternalClientRequest { pub server_name: String, } +/// Memory configuration for this session. +/// +///
+/// +/// **Experimental.** This type is part of an experimental wire-protocol surface +/// and may change or be removed in future SDK or CLI releases. +/// +///
+#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct MemoryConfiguration { + /// Whether memory is enabled for the session. + pub enabled: bool, +} + /// Model identifier and token limits used to compute the context-info breakdown. /// ///
@@ -7406,7 +7497,7 @@ pub struct QueueRemoveMostRecentResult { #[derive(Debug, Clone, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct RegisterEventInterestParams { - /// The event type the consumer wants the runtime to treat as 'observed' for behavior-switching gating. Some runtime code paths inspect whether any consumer is interested in a specific event type and choose a different implementation accordingly (e.g. `mcp.oauth_required`: when interest is registered the runtime delegates the full interactive OAuth flow to the consumer; when no interest is registered the runtime installs a browserless fallback that silently reuses cached tokens). SDK clients that long-poll events do NOT automatically appear as listeners to these gating checks — they must explicitly call `registerInterest` for each event type they want the runtime to count as having a consumer. Multiple registrations for the same event type from the same or different consumers are tracked independently and must each be released. See: `mcp.oauth_required`, `sampling.requested`, `auto_mode_switch.requested`, `user_input.requested`, `elicitation.requested`, `command.queued`, `exit_plan_mode.requested`. + /// The event type the consumer wants the runtime to treat as observed for behavior-switching gating. Multiple registrations for the same event type from the same or different consumers are tracked independently and must each be released. pub event_type: String, } @@ -9129,6 +9220,9 @@ pub struct SessionOpenOptions { /// Identifier sent to LSP-style integrations. #[serde(skip_serializing_if = "Option::is_none")] pub lsp_client_name: Option, + /// Memory configuration for this session. + #[serde(skip_serializing_if = "Option::is_none")] + pub memory: Option, /// Initial model identifier. #[serde(skip_serializing_if = "Option::is_none")] pub model: Option, @@ -14152,6 +14246,21 @@ pub struct SessionMcpIsServerRunningResult { #[serde(rename_all = "camelCase")] pub struct SessionMcpOauthRespondResult {} +/// Indicates whether the pending MCP OAuth response was accepted. +/// +///
+/// +/// **Experimental.** This type is part of an experimental wire-protocol surface +/// and may change or be removed in future SDK or CLI releases. +/// +///
+#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionMcpOauthHandlePendingRequestResult { + /// Whether the response was accepted. False if the request was unknown, timed out, or already resolved. + pub success: bool, +} + /// OAuth authorization URL the caller should open, or empty when cached tokens already authenticated the server. /// ///
@@ -16934,6 +17043,37 @@ pub enum McpAppsSetHostContextDetailsTheme { Unknown, } +/// Supplies a host-acquired OAuth access token. +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub enum McpOauthPendingRequestResponseTokenKind { + #[serde(rename = "token")] + #[default] + Token, +} + +/// Declines or cancels the pending OAuth request. +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub enum McpOauthPendingRequestResponseCancelledKind { + #[serde(rename = "cancelled")] + #[default] + Cancelled, +} + +/// Host response to the pending OAuth request. +/// +///
+/// +/// **Experimental.** This type is part of an experimental wire-protocol surface +/// and may change or be removed in future SDK or CLI releases. +/// +///
+#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum McpOauthPendingRequestResponse { + Token(McpOauthPendingRequestResponseToken), + Cancelled(McpOauthPendingRequestResponseCancelled), +} + /// Outcome of the sampling inference. 'success' produced a response; 'failure' encountered an error (including agent-side rejection by content filter or criteria); 'cancelled' the caller cancelled this execution via cancelSamplingExecution. /// ///
diff --git a/rust/src/generated/rpc.rs b/rust/src/generated/rpc.rs index cd5132f48..928d48262 100644 --- a/rust/src/generated/rpc.rs +++ b/rust/src/generated/rpc.rs @@ -4325,7 +4325,7 @@ pub struct SessionRpcMcpOauth<'a> { } impl<'a> SessionRpcMcpOauth<'a> { - /// Responds to a pending MCP OAuth provider request. Marked internal because the `provider` argument is an in-process OAuthClientProvider instance that cannot be carried over the wire; the public OAuth surface will route the response through a wire-clean handshake once the CLI moves on top of the SDK. + /// Responds to a pending MCP OAuth request with an in-process provider. Conceptually similar to handlePendingRequest, but marked internal because this legacy CLI-only path takes a live OAuthClientProvider instance that cannot be carried over the wire. Once the CLI is replatformed on the SDK and can use handlePendingRequest, this API should be removed. /// /// Wire method: `session.mcp.oauth.respond`. /// @@ -4358,6 +4358,42 @@ impl<'a> SessionRpcMcpOauth<'a> { Ok(serde_json::from_value(_value)?) } + /// Resolves a pending MCP OAuth request with a host-provided token or cancellation. + /// + /// Wire method: `session.mcp.oauth.handlePendingRequest`. + /// + /// # Parameters + /// + /// * `params` - Pending MCP OAuth request ID and host-provided token or cancellation response. + /// + /// # Returns + /// + /// Indicates whether the pending MCP OAuth response was accepted. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn handle_pending_request( + &self, + params: McpOauthHandlePendingRequest, + ) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call( + rpc_methods::SESSION_MCP_OAUTH_HANDLEPENDINGREQUEST, + Some(wire_params), + ) + .await?; + Ok(serde_json::from_value(_value)?) + } + /// Starts OAuth authentication for a remote MCP server. /// /// Wire method: `session.mcp.oauth.login`. diff --git a/rust/src/generated/session_events.rs b/rust/src/generated/session_events.rs index e20d7d6ef..71e1613d8 100644 --- a/rust/src/generated/session_events.rs +++ b/rust/src/generated/session_events.rs @@ -2871,11 +2871,25 @@ pub struct McpOauthRequiredStaticClientConfig { pub public_client: Option, } +/// Parsed parameters from the WWW-Authenticate header that the SDK host uses for RFC 9728 protected-resource metadata discovery. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpOauthRequiredWwwAuthenticateParams { + /// Parsed OAuth error from the WWW-Authenticate header, if present + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + /// Parsed resource_metadata URL from the WWW-Authenticate header + pub resource_metadata_url: String, + /// Parsed OAuth scope from the WWW-Authenticate header, if present + #[serde(skip_serializing_if = "Option::is_none")] + pub scope: Option, +} + /// Session event "mcp.oauth_required". OAuth authentication request for an MCP server #[derive(Debug, Clone, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct McpOauthRequiredData { - /// Unique identifier for this OAuth request; used to respond via session.respondToMcpOAuth() + /// Unique identifier for this OAuth request; used to respond via session.mcp.oauth.handlePendingRequest pub request_id: RequestId, /// Display name of the MCP server that requires OAuth pub server_name: String, @@ -2884,12 +2898,16 @@ pub struct McpOauthRequiredData { /// Static OAuth client configuration, if the server specifies one #[serde(skip_serializing_if = "Option::is_none")] pub static_client_config: Option, + /// Parsed parameters from the WWW-Authenticate header that the SDK host uses for RFC 9728 protected-resource metadata discovery. + pub www_authenticate_params: McpOauthRequiredWwwAuthenticateParams, } /// Session event "mcp.oauth_completed". MCP OAuth request completion notification #[derive(Debug, Clone, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct McpOauthCompletedData { + /// How the pending OAuth request was completed + pub outcome: McpOauthCompletedOutcome, /// Request ID of the resolved OAuth request pub request_id: RequestId, } @@ -4232,6 +4250,24 @@ pub enum McpOauthRequiredStaticClientConfigGrantType { ClientCredentials, } +/// How the pending OAuth request was completed +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub enum McpOauthCompletedOutcome { + /// The pending OAuth request was resolved with a host-provided token/provider. + #[serde(rename = "token")] + Token, + /// The pending OAuth request was cancelled or declined without a token/provider. + #[serde(rename = "cancelled")] + Cancelled, + /// The pending OAuth request timed out before any client responded. + #[serde(rename = "timeout")] + Timeout, + /// Unknown variant for forward compatibility. + #[default] + #[serde(other)] + Unknown, +} + /// The user's auto-mode-switch choice #[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] pub enum AutoModeSwitchResponse { diff --git a/rust/src/handler.rs b/rust/src/handler.rs index dadd1706f..ff3781ffe 100644 --- a/rust/src/handler.rs +++ b/rust/src/handler.rs @@ -19,8 +19,13 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; use crate::generated::api_types::{ - PermissionDecision, PermissionDecisionApproveOnce, PermissionDecisionReject, - PermissionDecisionUserNotAvailable, + McpOauthPendingRequestResponse, McpOauthPendingRequestResponseCancelled, + McpOauthPendingRequestResponseCancelledKind, McpOauthPendingRequestResponseToken, + McpOauthPendingRequestResponseTokenKind, PermissionDecision, PermissionDecisionApproveOnce, + PermissionDecisionReject, PermissionDecisionUserNotAvailable, +}; +use crate::session_events::{ + McpOauthRequiredStaticClientConfig, McpOauthRequiredWwwAuthenticateParams, }; use crate::types::{ ElicitationRequest, ElicitationResult, ExitPlanModeData, PermissionRequestData, RequestId, @@ -159,6 +164,73 @@ pub trait ElicitationHandler: Send + Sync + 'static { ) -> ElicitationResult; } +/// MCP OAuth request that the SDK host can satisfy with a host-acquired token. +#[derive(Debug, Clone)] +pub struct McpAuthRequest { + /// Display name of the MCP server that requires OAuth. + pub server_name: String, + /// URL of the MCP server that requires OAuth. + pub server_url: String, + /// Parsed WWW-Authenticate parameters from the MCP server. + pub www_authenticate_params: McpOauthRequiredWwwAuthenticateParams, + /// Static OAuth client configuration, if the server specifies one. + pub static_client_config: Option, +} + +/// Result returned by an MCP auth request handler. +#[derive(Debug, Clone)] +pub enum McpAuthResult { + /// Supplies host-acquired OAuth token data. + Token { + /// Access token acquired by the SDK host. + access_token: String, + /// OAuth token type. Defaults to Bearer when omitted. + token_type: Option, + /// Refresh token supplied by the host, if available. + refresh_token: Option, + /// Token lifetime in seconds, if known. + expires_in: Option, + }, + /// Declines or cancels the pending OAuth request. + Cancelled, +} + +impl McpAuthResult { + pub(crate) fn into_wire(self) -> McpOauthPendingRequestResponse { + match self { + Self::Token { + access_token, + token_type, + refresh_token, + expires_in, + } => McpOauthPendingRequestResponse::Token(McpOauthPendingRequestResponseToken { + access_token, + token_type, + refresh_token, + expires_in, + kind: McpOauthPendingRequestResponseTokenKind::Token, + }), + Self::Cancelled => { + McpOauthPendingRequestResponse::Cancelled(McpOauthPendingRequestResponseCancelled { + kind: McpOauthPendingRequestResponseCancelledKind::Cancelled, + }) + } + } + } +} + +/// Handler for MCP server OAuth requests. +#[async_trait] +pub trait McpAuthHandler: Send + Sync + 'static { + /// Resolve an MCP OAuth request with host token data or cancellation. + async fn handle( + &self, + session_id: SessionId, + request_id: RequestId, + request: McpAuthRequest, + ) -> McpAuthResult; +} + /// Handler for `user_input.requested` events from the `ask_user` tool. /// /// When unset, `requestUserInput: false` goes on the wire and the @@ -266,4 +338,24 @@ mod tests { PermissionResult::Decision(PermissionDecision::Reject(_)) )); } + + #[test] + fn mcp_auth_result_token_converts_to_wire_response() { + let wire = McpAuthResult::Token { + access_token: "host-token".to_string(), + token_type: Some("Bearer".to_string()), + refresh_token: None, + expires_in: Some(3600), + } + .into_wire(); + + match wire { + McpOauthPendingRequestResponse::Token(token) => { + assert_eq!(token.access_token, "host-token"); + assert_eq!(token.token_type.as_deref(), Some("Bearer")); + assert_eq!(token.expires_in, Some(3600)); + } + McpOauthPendingRequestResponse::Cancelled(_) => panic!("expected token response"), + } + } } diff --git a/rust/src/session.rs b/rust/src/session.rs index f387b8627..5103f8413 100644 --- a/rust/src/session.rs +++ b/rust/src/session.rs @@ -11,14 +11,17 @@ use tokio_util::sync::CancellationToken; use tracing::{Instrument, warn}; use crate::canvas::CanvasHandler; -use crate::generated::api_types::{LogRequest, ModelSwitchToRequest, OpenCanvasInstance}; +use crate::generated::api_types::{ + LogRequest, ModelSwitchToRequest, OpenCanvasInstance, RegisterEventInterestParams, rpc_methods, +}; use crate::generated::session_events::{ - CommandExecuteData, ElicitationRequestedData, ExternalToolRequestedData, + CommandExecuteData, ElicitationRequestedData, ExternalToolRequestedData, McpOauthRequiredData, SessionCanvasClosedData, SessionErrorData, SessionEventType, }; use crate::handler::{ AutoModeSwitchHandler, AutoModeSwitchResponse, ElicitationHandler, ExitPlanModeHandler, - PermissionHandler, PermissionResult, UserInputHandler, UserInputResponse, + McpAuthHandler, McpAuthRequest, McpAuthResult, PermissionHandler, PermissionResult, + UserInputHandler, UserInputResponse, }; use crate::hooks::SessionHooks; use crate::session_fs::SessionFsProvider; @@ -48,6 +51,7 @@ use crate::{ pub(crate) struct SessionHandlers { pub permission: Option>, pub elicitation: Option>, + pub mcp_auth: Option>, pub user_input: Option>, pub exit_plan_mode: Option>, pub auto_mode_switch: Option>, @@ -879,6 +883,7 @@ impl Client { let handlers = SessionHandlers { permission: permission_handler, elicitation: runtime.elicitation_handler.take(), + mcp_auth: runtime.mcp_auth_handler.take(), user_input: runtime.user_input_handler.take(), exit_plan_mode: runtime.exit_plan_mode_handler.take(), auto_mode_switch: runtime.auto_mode_switch_handler.take(), @@ -892,6 +897,7 @@ impl Client { let command_handlers = build_command_handler_map(runtime.commands.as_deref()); let canvas_handler = runtime.canvas_handler.take(); let session_fs_provider = runtime.session_fs_provider.take(); + let has_mcp_auth_handler = handlers.mcp_auth.is_some(); if self.inner.session_fs_configured && session_fs_provider.is_none() { return Err(ErrorKind::Session(SessionErrorKind::SessionFsProviderRequired).into()); } @@ -931,6 +937,9 @@ impl Client { { let channels = self.register_session(sid); *inline_stash.lock() = Some((sid.clone(), channels)); + if has_mcp_auth_handler { + register_mcp_auth_interest(self, sid).await?; + } None } else { let client = self.clone(); @@ -1026,6 +1035,10 @@ impl Client { "Client::create_session local setup complete" ); *capabilities.write() = create_result.capabilities.unwrap_or_default(); + // Local IDs registered before create; server-assigned IDs can only register now. + if has_mcp_auth_handler && local_session_id.is_none() { + register_mcp_auth_interest(self, &session_id).await?; + } tracing::debug!( elapsed_ms = total_start.elapsed().as_millis(), @@ -1134,6 +1147,7 @@ impl Client { let handlers = SessionHandlers { permission: permission_handler, elicitation: runtime.elicitation_handler.take(), + mcp_auth: runtime.mcp_auth_handler.take(), user_input: runtime.user_input_handler.take(), exit_plan_mode: runtime.exit_plan_mode_handler.take(), auto_mode_switch: runtime.auto_mode_switch_handler.take(), @@ -1147,6 +1161,7 @@ impl Client { let command_handlers = build_command_handler_map(runtime.commands.as_deref()); let canvas_handler = runtime.canvas_handler.take(); let session_fs_provider = runtime.session_fs_provider.take(); + let has_mcp_auth_handler = handlers.mcp_auth.is_some(); if self.inner.session_fs_configured && session_fs_provider.is_none() { return Err(ErrorKind::Session(SessionErrorKind::SessionFsProviderRequired).into()); } @@ -1164,6 +1179,9 @@ impl Client { let mut params = serde_json::to_value(&wire)?; let trace_ctx = self.resolve_trace_context().await; inject_trace_context(&mut params, &trace_ctx); + if has_mcp_auth_handler { + register_mcp_auth_interest(self, &session_id).await?; + } let capabilities = Arc::new(parking_lot::RwLock::new(SessionCapabilities::default())); let setup_start = Instant::now(); @@ -1468,6 +1486,17 @@ fn notification_permission_payload(result: &PermissionResult) -> Option { } } +async fn register_mcp_auth_interest(client: &Client, session_id: &SessionId) -> Result<(), Error> { + let mut params = serde_json::to_value(RegisterEventInterestParams { + event_type: "mcp.oauth_required".to_string(), + })?; + params["sessionId"] = Value::String(session_id.to_string()); + client + .call(rpc_methods::SESSION_EVENTLOG_REGISTERINTEREST, Some(params)) + .await?; + Ok(()) +} + fn tool_failure_result(message: impl Into) -> ToolResult { let message = message.into(); ToolResult::Expanded(ToolResultExpanded { @@ -1935,6 +1964,88 @@ async fn handle_notification( .instrument(span), ); } + SessionEventType::McpOauthRequired => { + let Some(request_id) = extract_request_id(¬ification.event.data) else { + return; + }; + let Some(mcp_auth_handler) = handlers.mcp_auth.clone() else { + warn!( + session_id = %sid, + request_id = %request_id, + "received MCP OAuth request without a registered MCP auth handler" + ); + return; + }; + let data: McpOauthRequiredData = + match serde_json::from_value(notification.event.data.clone()) { + Ok(d) => d, + Err(e) => { + warn!(error = %e, "failed to deserialize MCP OAuth request"); + return; + } + }; + let request = McpAuthRequest { + server_name: data.server_name, + server_url: data.server_url, + www_authenticate_params: data.www_authenticate_params, + static_client_config: data.static_client_config, + }; + let client = client.clone(); + let sid = session_id.clone(); + let span = tracing::error_span!( + "mcp_auth_request_handler", + session_id = %sid, + request_id = %request_id + ); + tokio::spawn( + async move { + let cancel = McpAuthResult::Cancelled; + let handler_task = tokio::spawn({ + let sid = sid.clone(); + let request_id = request_id.clone(); + let span = tracing::error_span!( + "mcp_auth_callback", + session_id = %sid, + request_id = %request_id + ); + async move { + let handler_start = Instant::now(); + let response = mcp_auth_handler + .handle(sid.clone(), request_id.clone(), request) + .await; + tracing::debug!( + elapsed_ms = handler_start.elapsed().as_millis(), + session_id = %sid, + request_id = %request_id, + "McpAuthHandler::handle dispatch" + ); + response + } + .instrument(span) + }); + let result = match handler_task.await { + Ok(result) => result, + Err(_) => cancel, + }; + let rpc_start = Instant::now(); + let _ = client + .call( + "session.mcp.oauth.handlePendingRequest", + Some(serde_json::json!({ + "sessionId": sid, + "requestId": request_id, + "result": result.into_wire(), + })), + ) + .await; + tracing::debug!( + elapsed_ms = rpc_start.elapsed().as_millis(), + "Session::handle_notification MCP auth response sent" + ); + } + .instrument(span), + ); + } SessionEventType::CommandExecute => { let data: CommandExecuteData = match serde_json::from_value(notification.event.data.clone()) { diff --git a/rust/src/types.rs b/rust/src/types.rs index 8b9b5960a..b7c7e4be1 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -18,8 +18,8 @@ use crate::generated::api_types::OpenCanvasInstance; pub use crate::generated::session_events::ContextTier; use crate::generated::session_events::ReasoningSummary; use crate::handler::{ - AutoModeSwitchHandler, ElicitationHandler, ExitPlanModeHandler, PermissionHandler, - UserInputHandler, + AutoModeSwitchHandler, ElicitationHandler, ExitPlanModeHandler, McpAuthHandler, + PermissionHandler, UserInputHandler, }; use crate::hooks::SessionHooks; pub use crate::session_fs::{ @@ -1330,6 +1330,9 @@ pub struct SessionConfig { /// Optional elicitation-request handler. When `None`, /// `requestElicitation: false` goes on the wire. pub elicitation_handler: Option>, + /// Optional MCP OAuth request handler. When set, the SDK can satisfy MCP + /// server OAuth requests with host-acquired token data or cancellation. + pub mcp_auth_handler: Option>, /// Optional user-input handler. When `None`, /// `requestUserInput: false` goes on the wire and the `ask_user` /// tool is disabled. @@ -1456,6 +1459,14 @@ impl std::fmt::Debug for SessionConfig { "elicitation_handler", &self.elicitation_handler.as_ref().map(|_| ""), ) + .field( + "mcp_auth_handler", + &self.mcp_auth_handler.as_ref().map(|_| ""), + ) + .field( + "mcp_auth_handler", + &self.mcp_auth_handler.as_ref().map(|_| ""), + ) .field( "user_input_handler", &self.user_input_handler.as_ref().map(|_| ""), @@ -1540,6 +1551,7 @@ impl Default for SessionConfig { session_fs_provider: None, permission_handler: None, elicitation_handler: None, + mcp_auth_handler: None, user_input_handler: None, exit_plan_mode_handler: None, auto_mode_switch_handler: None, @@ -1563,6 +1575,7 @@ pub(crate) struct SessionConfigRuntime { pub permission_handler: Option>, pub permission_policy: Option, pub elicitation_handler: Option>, + pub mcp_auth_handler: Option>, pub user_input_handler: Option>, pub exit_plan_mode_handler: Option>, pub auto_mode_switch_handler: Option>, @@ -1685,6 +1698,7 @@ impl SessionConfig { permission_handler: self.permission_handler, permission_policy: self.permission_policy, elicitation_handler: self.elicitation_handler, + mcp_auth_handler: self.mcp_auth_handler, user_input_handler: self.user_input_handler, exit_plan_mode_handler: self.exit_plan_mode_handler, auto_mode_switch_handler: self.auto_mode_switch_handler, @@ -1714,6 +1728,12 @@ impl SessionConfig { self } + /// Install an [`McpAuthHandler`] for host-provided MCP OAuth tokens. + pub fn with_mcp_auth_handler(mut self, handler: Arc) -> Self { + self.mcp_auth_handler = Some(handler); + self + } + /// Install a [`UserInputHandler`]. Required for the `ask_user` tool /// to be enabled. pub fn with_user_input_handler(mut self, handler: Arc) -> Self { @@ -2323,6 +2343,8 @@ pub struct ResumeSessionConfig { /// Optional elicitation handler. See /// [`SessionConfig::elicitation_handler`]. pub elicitation_handler: Option>, + /// Optional MCP OAuth handler. See [`SessionConfig::mcp_auth_handler`]. + pub mcp_auth_handler: Option>, /// Optional user-input handler. See /// [`SessionConfig::user_input_handler`]. pub user_input_handler: Option>, @@ -2565,6 +2587,7 @@ impl ResumeSessionConfig { permission_handler: self.permission_handler, permission_policy: self.permission_policy, elicitation_handler: self.elicitation_handler, + mcp_auth_handler: self.mcp_auth_handler, user_input_handler: self.user_input_handler, exit_plan_mode_handler: self.exit_plan_mode_handler, auto_mode_switch_handler: self.auto_mode_switch_handler, @@ -2638,6 +2661,7 @@ impl ResumeSessionConfig { continue_pending_work: None, permission_handler: None, elicitation_handler: None, + mcp_auth_handler: None, user_input_handler: None, exit_plan_mode_handler: None, auto_mode_switch_handler: None, @@ -2663,6 +2687,12 @@ impl ResumeSessionConfig { self } + /// Install an [`McpAuthHandler`] for host-provided MCP OAuth tokens. + pub fn with_mcp_auth_handler(mut self, handler: Arc) -> Self { + self.mcp_auth_handler = Some(handler); + self + } + /// Install a [`UserInputHandler`] for the resumed session. pub fn with_user_input_handler(mut self, handler: Arc) -> Self { self.user_input_handler = Some(handler); diff --git a/rust/tests/session_test.rs b/rust/tests/session_test.rs index 244885697..3a3c6a136 100644 --- a/rust/tests/session_test.rs +++ b/rust/tests/session_test.rs @@ -9,7 +9,8 @@ use async_trait::async_trait; use github_copilot_sdk::canvas::{CanvasDeclaration, CanvasHandler, CanvasResult}; use github_copilot_sdk::handler::{ ApproveAllHandler, AutoModeSwitchHandler, AutoModeSwitchResponse, ElicitationHandler, - ExitPlanModeHandler, ExitPlanModeResult, UserInputHandler, UserInputResponse, + ExitPlanModeHandler, ExitPlanModeResult, McpAuthHandler, McpAuthRequest, McpAuthResult, + UserInputHandler, UserInputResponse, }; use github_copilot_sdk::rpc::{ CanvasInstanceAvailability, CanvasProviderInvokeActionRequest, CanvasProviderOpenRequest, @@ -17,9 +18,10 @@ use github_copilot_sdk::rpc::{ }; use github_copilot_sdk::session_events::ReasoningSummary; use github_copilot_sdk::types::{ - CommandContext, CommandDefinition, CommandHandler, DeliveryMode, ElicitationRequest, - ElicitationResult, ExitPlanModeData, ExtensionInfo, MessageOptions, RequestId, SessionConfig, - SessionId, SetModelOptions, Tool, ToolInvocation, ToolResult, + CloudSessionOptions, CloudSessionRepository, CommandContext, CommandDefinition, CommandHandler, + DeliveryMode, ElicitationRequest, ElicitationResult, ExitPlanModeData, ExtensionInfo, + MessageOptions, RequestId, SessionConfig, SessionId, SetModelOptions, Tool, ToolInvocation, + ToolResult, }; use github_copilot_sdk::{Client, ContextTier, tool}; use serde_json::Value; @@ -30,6 +32,20 @@ const TIMEOUT: Duration = Duration::from_secs(2); struct TestCanvasHandler; +struct CancelMcpAuthHandler; + +#[async_trait] +impl McpAuthHandler for CancelMcpAuthHandler { + async fn handle( + &self, + _session_id: SessionId, + _request_id: RequestId, + _request: McpAuthRequest, + ) -> McpAuthResult { + McpAuthResult::Cancelled + } +} + #[async_trait] impl CanvasHandler for TestCanvasHandler { async fn on_open( @@ -226,6 +242,245 @@ fn requested_session_id(request: &Value) -> &str { .expect("session request should include sessionId") } +#[tokio::test] +async fn create_session_registers_mcp_auth_interest_only_with_handler() { + let (client, mut server_read, mut server_write) = make_client(); + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session( + SessionConfig::default().with_permission_handler(Arc::new(ApproveAllHandler)), + ) + .await + .unwrap() + } + }); + + let create_req = read_framed(&mut server_read).await; + assert_eq!(create_req["method"], "session.create"); + assert_eq!(create_req["params"]["requestPermission"], true); + let session_id = requested_session_id(&create_req).to_string(); + server_respond_create(&mut server_write, &create_req, &session_id).await; + let session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); + + let no_extra_request = timeout(Duration::from_millis(50), read_framed(&mut server_read)).await; + assert!(no_extra_request.is_err()); + drop(session); + + let (client, mut server_read, mut server_write) = make_client(); + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session( + SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_mcp_auth_handler(Arc::new(CancelMcpAuthHandler)), + ) + .await + .unwrap() + } + }); + + let interest_req = read_framed(&mut server_read).await; + assert_eq!(interest_req["method"], "session.eventLog.registerInterest"); + assert_eq!(interest_req["params"]["eventType"], "mcp.oauth_required"); + let id = interest_req["id"].as_u64().unwrap(); + write_framed( + &mut server_write, + &serde_json::to_vec(&serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { "id": "interest-1" }, + })) + .unwrap(), + ) + .await; + + let create_req = read_framed(&mut server_read).await; + assert_eq!(create_req["method"], "session.create"); + assert_eq!(create_req["params"]["requestPermission"], true); + let session_id = requested_session_id(&create_req).to_string(); + server_respond_create(&mut server_write, &create_req, &session_id).await; + let _session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); +} + +#[tokio::test] +async fn cloud_create_session_registers_mcp_auth_interest_after_create_only_with_handler() { + let cloud = || { + CloudSessionOptions::with_repository( + CloudSessionRepository::new("github", "copilot-sdk").with_branch("main"), + ) + }; + + let (client, mut server_read, mut server_write) = make_client(); + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session( + SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_cloud(cloud()), + ) + .await + .unwrap() + } + }); + + let create_req = read_framed(&mut server_read).await; + assert_eq!(create_req["method"], "session.create"); + assert!(create_req["params"].get("sessionId").is_none()); + assert_eq!(create_req["params"]["requestPermission"], true); + server_respond_create(&mut server_write, &create_req, "server-assigned-session-1").await; + let session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); + let no_extra_request = timeout(Duration::from_millis(50), read_framed(&mut server_read)).await; + assert!(no_extra_request.is_err()); + drop(session); + + let (client, mut server_read, mut server_write) = make_client(); + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session( + SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_mcp_auth_handler(Arc::new(CancelMcpAuthHandler)) + .with_cloud(cloud()), + ) + .await + .unwrap() + } + }); + + let create_req = read_framed(&mut server_read).await; + assert_eq!(create_req["method"], "session.create"); + assert!(create_req["params"].get("sessionId").is_none()); + assert_eq!(create_req["params"]["requestPermission"], true); + server_respond_create(&mut server_write, &create_req, "server-assigned-session-2").await; + + let interest_req = read_framed(&mut server_read).await; + assert_eq!(interest_req["method"], "session.eventLog.registerInterest"); + assert_eq!( + interest_req["params"]["sessionId"], + "server-assigned-session-2" + ); + assert_eq!(interest_req["params"]["eventType"], "mcp.oauth_required"); + let id = interest_req["id"].as_u64().unwrap(); + write_framed( + &mut server_write, + &serde_json::to_vec(&serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { "id": "interest-1" }, + })) + .unwrap(), + ) + .await; + let _session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); +} + +#[tokio::test] +async fn resume_session_registers_mcp_auth_interest_only_with_handler() { + use github_copilot_sdk::types::ResumeSessionConfig; + + let (client, mut server_read, mut server_write) = make_client(); + let resume_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .resume_session( + ResumeSessionConfig::new(SessionId::from("session-without-auth")) + .with_permission_handler(Arc::new(ApproveAllHandler)), + ) + .await + .unwrap() + } + }); + + let resume_req = read_framed(&mut server_read).await; + assert_eq!(resume_req["method"], "session.resume"); + assert_eq!(resume_req["params"]["requestPermission"], true); + server_respond_create(&mut server_write, &resume_req, "session-without-auth").await; + respond_to_reload(&mut server_read, &mut server_write).await; + let session = timeout(TIMEOUT, resume_handle).await.unwrap().unwrap(); + let no_extra_request = timeout(Duration::from_millis(50), read_framed(&mut server_read)).await; + assert!(no_extra_request.is_err()); + drop(session); + + let (client, mut server_read, mut server_write) = make_client(); + let resume_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .resume_session( + ResumeSessionConfig::new(SessionId::from("session-with-auth")) + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_mcp_auth_handler(Arc::new(CancelMcpAuthHandler)), + ) + .await + .unwrap() + } + }); + + let interest_req = read_framed(&mut server_read).await; + assert_eq!(interest_req["method"], "session.eventLog.registerInterest"); + assert_eq!(interest_req["params"]["eventType"], "mcp.oauth_required"); + let id = interest_req["id"].as_u64().unwrap(); + write_framed( + &mut server_write, + &serde_json::to_vec(&serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { "id": "interest-1" }, + })) + .unwrap(), + ) + .await; + + let resume_req = read_framed(&mut server_read).await; + assert_eq!(resume_req["method"], "session.resume"); + assert_eq!(resume_req["params"]["requestPermission"], true); + server_respond_create(&mut server_write, &resume_req, "session-with-auth").await; + respond_to_reload(&mut server_read, &mut server_write).await; + let _session = timeout(TIMEOUT, resume_handle).await.unwrap().unwrap(); +} + +async fn server_respond_create( + writer: &mut (impl AsyncWrite + Unpin), + request: &Value, + session_id: &str, +) { + let id = request["id"].as_u64().unwrap(); + write_framed( + writer, + &serde_json::to_vec(&serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { "sessionId": session_id, "workspacePath": "/tmp/workspace" }, + })) + .unwrap(), + ) + .await; +} + +async fn respond_to_reload( + reader: &mut (impl tokio::io::AsyncRead + Unpin), + writer: &mut (impl AsyncWrite + Unpin), +) { + let reload = read_framed(reader).await; + assert_eq!(reload["method"], "session.skills.reload"); + let id = reload["id"].as_u64().unwrap(); + write_framed( + writer, + &serde_json::to_vec(&serde_json::json!({ "jsonrpc": "2.0", "id": id, "result": {} })) + .unwrap(), + ) + .await; +} + #[tokio::test] async fn session_subscribe_yields_events_observe_only() { let (session, mut server) = create_session_pair().await; diff --git a/test/harness/test-mcp-oauth-server.mjs b/test/harness/test-mcp-oauth-server.mjs new file mode 100644 index 000000000..3a642b55a --- /dev/null +++ b/test/harness/test-mcp-oauth-server.mjs @@ -0,0 +1,216 @@ +#!/usr/bin/env node +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +/** + * Minimal OAuth-protected Streamable HTTP MCP server for SDK E2E tests. + * + * The `/mcp` endpoint returns a WWW-Authenticate challenge until requests include + * `Authorization: Bearer `, then serves enough JSON-RPC MCP + * methods for the runtime to initialize and list/call one tool. + */ + +import http from "node:http"; + +const DEFAULT_EXPECTED_TOKEN = "sdk-host-token"; +const PROTOCOL_VERSION = "2025-03-26"; + +export async function startOAuthMcpServer({ + expectedToken = DEFAULT_EXPECTED_TOKEN, + host = "127.0.0.1", + port = 0, +} = {}) { + const requests = []; + + const server = http.createServer(async (req, res) => { + const url = new URL( + req.url ?? "/", + `http://${req.headers.host ?? `${host}:${port}`}`, + ); + const baseUrl = `http://${req.headers.host}`; + + if (req.method === "GET" && url.pathname === "/__requests") { + respondJson(res, 200, requests); + return; + } + + if ( + req.method === "GET" && + url.pathname === "/.well-known/oauth-protected-resource" + ) { + respondJson(res, 200, { + resource: `${baseUrl}/mcp`, + authorization_servers: [baseUrl], + scopes_supported: ["mcp.read"], + bearer_methods_supported: ["header"], + }); + return; + } + + if ( + req.method === "GET" && + url.pathname === "/.well-known/oauth-authorization-server" + ) { + respondJson(res, 200, { + issuer: baseUrl, + authorization_endpoint: `${baseUrl}/authorize`, + token_endpoint: `${baseUrl}/token`, + response_types_supported: ["code"], + grant_types_supported: ["authorization_code"], + }); + return; + } + + if (url.pathname !== "/mcp") { + respondJson(res, 404, { error: "not_found" }); + return; + } + + const body = await readBody(req); + requests.push({ + method: req.method, + path: url.pathname, + authorization: req.headers.authorization ?? null, + body: body ? JSON.parse(body) : null, + }); + + if (req.headers.authorization !== `Bearer ${expectedToken}`) { + const resourceMetadataUrl = `${baseUrl}/.well-known/oauth-protected-resource`; + res.writeHead(401, { + "www-authenticate": `Bearer resource_metadata="${resourceMetadataUrl}", scope="mcp.read", error="invalid_token"`, + "content-type": "application/json", + }); + res.end(JSON.stringify({ error: "missing_or_invalid_token" })); + return; + } + + if (req.method !== "POST") { + respondJson(res, 405, { error: "method_not_allowed" }); + return; + } + + const message = body ? JSON.parse(body) : undefined; + const response = Array.isArray(message) + ? message.map(handleJsonRpcMessage).filter((item) => item !== undefined) + : handleJsonRpcMessage(message); + + if ( + response === undefined || + (Array.isArray(response) && response.length === 0) + ) { + res.writeHead(202, { "mcp-session-id": "oauth-test-session" }); + res.end(); + return; + } + + res.writeHead(200, { + "content-type": "application/json", + "mcp-session-id": "oauth-test-session", + }); + res.end(JSON.stringify(response)); + }); + + await new Promise((resolve, reject) => { + server.once("error", reject); + server.listen(port, host, () => { + server.off("error", reject); + resolve(); + }); + }); + + const address = server.address(); + if (!address || typeof address === "string") { + throw new Error("Expected TCP server address"); + } + + return { + url: `http://${host}:${address.port}`, + requests, + close: () => + new Promise((resolve, reject) => + server.close((err) => (err ? reject(err) : resolve())), + ), + }; +} + +function handleJsonRpcMessage(message) { + if (!message || typeof message !== "object" || !("id" in message)) { + return undefined; + } + + switch (message.method) { + case "initialize": + return { + jsonrpc: "2.0", + id: message.id, + result: { + protocolVersion: message.params?.protocolVersion ?? PROTOCOL_VERSION, + capabilities: { tools: {} }, + serverInfo: { name: "oauth-test-server", version: "1.0.0" }, + }, + }; + case "tools/list": + return { + jsonrpc: "2.0", + id: message.id, + result: { + tools: [ + { + name: "whoami", + description: "Returns the authenticated test principal.", + inputSchema: { + type: "object", + properties: {}, + additionalProperties: false, + }, + }, + ], + }, + }; + case "tools/call": + return { + jsonrpc: "2.0", + id: message.id, + result: { + content: [{ type: "text", text: "oauth-test-user" }], + isError: false, + }, + }; + default: + return { + jsonrpc: "2.0", + id: message.id, + error: { code: -32601, message: `Method not found: ${message.method}` }, + }; + } +} + +function readBody(req) { + return new Promise((resolve, reject) => { + const chunks = []; + req.on("data", (chunk) => chunks.push(chunk)); + req.on("error", reject); + req.on("end", () => resolve(Buffer.concat(chunks).toString("utf8"))); + }); +} + +function respondJson(res, statusCode, body) { + const data = JSON.stringify(body); + res.writeHead(statusCode, { + "content-type": "application/json", + "content-length": Buffer.byteLength(data), + }); + res.end(data); +} + +if (import.meta.url === `file://${process.argv[1]}`) { + const server = await startOAuthMcpServer({ + expectedToken: process.env.EXPECTED_TOKEN ?? DEFAULT_EXPECTED_TOKEN, + }); + console.log(`Listening: ${server.url}`); + process.on("SIGTERM", async () => { + await server.close(); + process.exit(0); + }); +} From 16109983e7bbf3491eb5f9a19db33218a13868a2 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Sun, 14 Jun 2026 12:04:51 +0000 Subject: [PATCH 2/5] Regenerate Java codegen output Auto-committed by java-codegen-check workflow. --- .../generated/McpOauthCompletedEvent.java | 4 +- .../generated/McpOauthCompletedOutcome.java | 37 ------------------- .../generated/McpOauthRequiredEvent.java | 6 +-- ...McpOauthRequiredWwwAuthenticateParams.java | 31 ---------------- ...SessionEventLogRegisterInterestParams.java | 2 +- .../generated/rpc/SessionMcpOauthApi.java | 16 -------- ...ionMcpOauthHandlePendingRequestParams.java | 34 ----------------- ...ionMcpOauthHandlePendingRequestResult.java | 30 --------------- .../rpc/SessionMcpOauthRespondParams.java | 2 +- 9 files changed, 5 insertions(+), 157 deletions(-) delete mode 100644 java/src/generated/java/com/github/copilot/generated/McpOauthCompletedOutcome.java delete mode 100644 java/src/generated/java/com/github/copilot/generated/McpOauthRequiredWwwAuthenticateParams.java delete mode 100644 java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthHandlePendingRequestParams.java delete mode 100644 java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthHandlePendingRequestResult.java diff --git a/java/src/generated/java/com/github/copilot/generated/McpOauthCompletedEvent.java b/java/src/generated/java/com/github/copilot/generated/McpOauthCompletedEvent.java index 8bfa56849..635751b43 100644 --- a/java/src/generated/java/com/github/copilot/generated/McpOauthCompletedEvent.java +++ b/java/src/generated/java/com/github/copilot/generated/McpOauthCompletedEvent.java @@ -35,9 +35,7 @@ public final class McpOauthCompletedEvent extends SessionEvent { @JsonInclude(JsonInclude.Include.NON_NULL) public record McpOauthCompletedEventData( /** Request ID of the resolved OAuth request */ - @JsonProperty("requestId") String requestId, - /** How the pending OAuth request was completed */ - @JsonProperty("outcome") McpOauthCompletedOutcome outcome + @JsonProperty("requestId") String requestId ) { } } diff --git a/java/src/generated/java/com/github/copilot/generated/McpOauthCompletedOutcome.java b/java/src/generated/java/com/github/copilot/generated/McpOauthCompletedOutcome.java deleted file mode 100644 index e362f43bd..000000000 --- a/java/src/generated/java/com/github/copilot/generated/McpOauthCompletedOutcome.java +++ /dev/null @@ -1,37 +0,0 @@ -/*--------------------------------------------------------------------------------------------- - * Copyright (c) Microsoft Corporation. All rights reserved. - *--------------------------------------------------------------------------------------------*/ - -// AUTO-GENERATED FILE - DO NOT EDIT -// Generated from: session-events.schema.json - -package com.github.copilot.generated; - -import javax.annotation.processing.Generated; - -/** - * How the pending OAuth request was completed - * - * @since 1.0.0 - */ -@javax.annotation.processing.Generated("copilot-sdk-codegen") -public enum McpOauthCompletedOutcome { - /** The {@code token} variant. */ - TOKEN("token"), - /** The {@code cancelled} variant. */ - CANCELLED("cancelled"), - /** The {@code timeout} variant. */ - TIMEOUT("timeout"); - - private final String value; - McpOauthCompletedOutcome(String value) { this.value = value; } - @com.fasterxml.jackson.annotation.JsonValue - public String getValue() { return value; } - @com.fasterxml.jackson.annotation.JsonCreator - public static McpOauthCompletedOutcome fromValue(String value) { - for (McpOauthCompletedOutcome v : values()) { - if (v.value.equals(value)) return v; - } - throw new IllegalArgumentException("Unknown McpOauthCompletedOutcome value: " + value); - } -} diff --git a/java/src/generated/java/com/github/copilot/generated/McpOauthRequiredEvent.java b/java/src/generated/java/com/github/copilot/generated/McpOauthRequiredEvent.java index 4ebaf351a..02e67a35f 100644 --- a/java/src/generated/java/com/github/copilot/generated/McpOauthRequiredEvent.java +++ b/java/src/generated/java/com/github/copilot/generated/McpOauthRequiredEvent.java @@ -34,16 +34,14 @@ public final class McpOauthRequiredEvent extends SessionEvent { @JsonIgnoreProperties(ignoreUnknown = true) @JsonInclude(JsonInclude.Include.NON_NULL) public record McpOauthRequiredEventData( - /** Unique identifier for this OAuth request; used to respond via session.mcp.oauth.handlePendingRequest */ + /** Unique identifier for this OAuth request; used to respond via session.respondToMcpOAuth() */ @JsonProperty("requestId") String requestId, /** Display name of the MCP server that requires OAuth */ @JsonProperty("serverName") String serverName, /** URL of the MCP server that requires OAuth */ @JsonProperty("serverUrl") String serverUrl, /** Static OAuth client configuration, if the server specifies one */ - @JsonProperty("staticClientConfig") McpOauthRequiredStaticClientConfig staticClientConfig, - /** Parsed parameters from the WWW-Authenticate header that the SDK host uses for RFC 9728 protected-resource metadata discovery. */ - @JsonProperty("wwwAuthenticateParams") McpOauthRequiredWwwAuthenticateParams wwwAuthenticateParams + @JsonProperty("staticClientConfig") McpOauthRequiredStaticClientConfig staticClientConfig ) { } } diff --git a/java/src/generated/java/com/github/copilot/generated/McpOauthRequiredWwwAuthenticateParams.java b/java/src/generated/java/com/github/copilot/generated/McpOauthRequiredWwwAuthenticateParams.java deleted file mode 100644 index 072b09ab2..000000000 --- a/java/src/generated/java/com/github/copilot/generated/McpOauthRequiredWwwAuthenticateParams.java +++ /dev/null @@ -1,31 +0,0 @@ -/*--------------------------------------------------------------------------------------------- - * Copyright (c) Microsoft Corporation. All rights reserved. - *--------------------------------------------------------------------------------------------*/ - -// AUTO-GENERATED FILE - DO NOT EDIT -// Generated from: session-events.schema.json - -package com.github.copilot.generated; - -import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonProperty; -import javax.annotation.processing.Generated; - -/** - * Parsed parameters from the WWW-Authenticate header that the SDK host uses for RFC 9728 protected-resource metadata discovery. - * - * @since 1.0.0 - */ -@javax.annotation.processing.Generated("copilot-sdk-codegen") -@JsonInclude(JsonInclude.Include.NON_NULL) -@JsonIgnoreProperties(ignoreUnknown = true) -public record McpOauthRequiredWwwAuthenticateParams( - /** Parsed resource_metadata URL from the WWW-Authenticate header */ - @JsonProperty("resourceMetadataUrl") String resourceMetadataUrl, - /** Parsed OAuth scope from the WWW-Authenticate header, if present */ - @JsonProperty("scope") String scope, - /** Parsed OAuth error from the WWW-Authenticate header, if present */ - @JsonProperty("error") String error -) { -} diff --git a/java/src/generated/java/com/github/copilot/generated/rpc/SessionEventLogRegisterInterestParams.java b/java/src/generated/java/com/github/copilot/generated/rpc/SessionEventLogRegisterInterestParams.java index 74250b75e..af0bca43e 100644 --- a/java/src/generated/java/com/github/copilot/generated/rpc/SessionEventLogRegisterInterestParams.java +++ b/java/src/generated/java/com/github/copilot/generated/rpc/SessionEventLogRegisterInterestParams.java @@ -26,7 +26,7 @@ public record SessionEventLogRegisterInterestParams( /** Target session identifier */ @JsonProperty("sessionId") String sessionId, - /** The event type the consumer wants the runtime to treat as observed for behavior-switching gating. Multiple registrations for the same event type from the same or different consumers are tracked independently and must each be released. */ + /** The event type the consumer wants the runtime to treat as 'observed' for behavior-switching gating. Some runtime code paths inspect whether any consumer is interested in a specific event type and choose a different implementation accordingly (e.g. `mcp.oauth_required`: when interest is registered the runtime delegates the full interactive OAuth flow to the consumer; when no interest is registered the runtime installs a browserless fallback that silently reuses cached tokens). SDK clients that long-poll events do NOT automatically appear as listeners to these gating checks — they must explicitly call `registerInterest` for each event type they want the runtime to count as having a consumer. Multiple registrations for the same event type from the same or different consumers are tracked independently and must each be released. See: `mcp.oauth_required`, `sampling.requested`, `auto_mode_switch.requested`, `user_input.requested`, `elicitation.requested`, `command.queued`, `exit_plan_mode.requested`. */ @JsonProperty("eventType") String eventType ) { } diff --git a/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthApi.java b/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthApi.java index 95a081206..59c4e45a1 100644 --- a/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthApi.java +++ b/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthApi.java @@ -46,22 +46,6 @@ public CompletableFuture respond(SessionMcpOauthRespondParams params) { return caller.invoke("session.mcp.oauth.respond", _p, Void.class); } - /** - * Pending MCP OAuth request ID and host-provided token or cancellation response. - *

- * Note: the {@code sessionId} field in the params record is overridden - * by the session-scoped wrapper; any value provided is ignored. - * - * @apiNote This method is experimental and may change in a future version. - * @since 1.0.0 - */ - @CopilotExperimental - public CompletableFuture handlePendingRequest(SessionMcpOauthHandlePendingRequestParams params) { - com.fasterxml.jackson.databind.node.ObjectNode _p = MAPPER.valueToTree(params); - _p.put("sessionId", this.sessionId); - return caller.invoke("session.mcp.oauth.handlePendingRequest", _p, SessionMcpOauthHandlePendingRequestResult.class); - } - /** * Remote MCP server name and optional overrides controlling reauthentication, OAuth client display name, and the callback success-page copy. *

diff --git a/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthHandlePendingRequestParams.java b/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthHandlePendingRequestParams.java deleted file mode 100644 index 5aab57ef9..000000000 --- a/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthHandlePendingRequestParams.java +++ /dev/null @@ -1,34 +0,0 @@ -/*--------------------------------------------------------------------------------------------- - * Copyright (c) Microsoft Corporation. All rights reserved. - *--------------------------------------------------------------------------------------------*/ - -// AUTO-GENERATED FILE - DO NOT EDIT -// Generated from: api.schema.json - -package com.github.copilot.generated.rpc; - -import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.github.copilot.CopilotExperimental; -import javax.annotation.processing.Generated; - -/** - * Pending MCP OAuth request ID and host-provided token or cancellation response. - * - * @apiNote This method is experimental and may change in a future version. - * @since 1.0.0 - */ -@CopilotExperimental -@javax.annotation.processing.Generated("copilot-sdk-codegen") -@JsonInclude(JsonInclude.Include.NON_NULL) -@JsonIgnoreProperties(ignoreUnknown = true) -public record SessionMcpOauthHandlePendingRequestParams( - /** Target session identifier */ - @JsonProperty("sessionId") String sessionId, - /** OAuth request identifier for the pending request. */ - @JsonProperty("requestId") String requestId, - /** Host response to the pending OAuth request. */ - @JsonProperty("result") Object result -) { -} diff --git a/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthHandlePendingRequestResult.java b/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthHandlePendingRequestResult.java deleted file mode 100644 index a7bca646e..000000000 --- a/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthHandlePendingRequestResult.java +++ /dev/null @@ -1,30 +0,0 @@ -/*--------------------------------------------------------------------------------------------- - * Copyright (c) Microsoft Corporation. All rights reserved. - *--------------------------------------------------------------------------------------------*/ - -// AUTO-GENERATED FILE - DO NOT EDIT -// Generated from: api.schema.json - -package com.github.copilot.generated.rpc; - -import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.github.copilot.CopilotExperimental; -import javax.annotation.processing.Generated; - -/** - * Indicates whether the pending MCP OAuth response was accepted. - * - * @apiNote This method is experimental and may change in a future version. - * @since 1.0.0 - */ -@CopilotExperimental -@javax.annotation.processing.Generated("copilot-sdk-codegen") -@JsonInclude(JsonInclude.Include.NON_NULL) -@JsonIgnoreProperties(ignoreUnknown = true) -public record SessionMcpOauthHandlePendingRequestResult( - /** Whether the response was accepted. False if the request was unknown, timed out, or already resolved. */ - @JsonProperty("success") Boolean success -) { -} diff --git a/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthRespondParams.java b/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthRespondParams.java index d46890ca9..9757a9538 100644 --- a/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthRespondParams.java +++ b/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthRespondParams.java @@ -26,7 +26,7 @@ public record SessionMcpOauthRespondParams( /** Target session identifier */ @JsonProperty("sessionId") String sessionId, - /** OAuth request identifier for the pending request. */ + /** OAuth request identifier from mcp.oauth_required */ @JsonProperty("requestId") String requestId, /** In-process OAuthClientProvider instance, or omitted to deny. Marked internal: cannot be serialized across the JSON-RPC boundary. */ @JsonProperty("provider") Object provider From 79425dcda7ce20ca486ec987952568d3711bdb67 Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Wed, 17 Jun 2026 18:38:57 +0200 Subject: [PATCH 3/5] Expose MCP OAuth resource metadata in SDKs Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dotnet/src/Generated/SessionEvents.cs | 10 ++- dotnet/src/Session.cs | 1 + dotnet/src/Types.cs | 8 ++- .../Unit/SessionEventSerializationTests.cs | 29 +++++++++ go/rpc/zsession_events.go | 7 ++- go/session.go | 30 +++++---- go/session_test.go | 62 ++++++++++++++++++- go/types.go | 3 +- .../generated/McpOauthRequiredEvent.java | 6 +- .../com/github/copilot/CopilotSession.java | 2 +- .../github/copilot/rpc/McpAuthRequest.java | 1 + .../McpAuthInterestRegistrationTest.java | 42 +++++++++++++ nodejs/src/generated/session-events.ts | 6 +- nodejs/src/types.ts | 9 ++- nodejs/test/client.test.ts | 41 ++++++++++-- nodejs/test/e2e/mcp_oauth.e2e.test.ts | 6 ++ python/copilot/generated/session_events.py | 12 +++- python/copilot/session.py | 20 +++--- python/test_client.py | 46 ++++++++++++-- rust/src/generated/session_events.rs | 8 ++- rust/src/handler.rs | 6 +- rust/src/session.rs | 3 +- rust/tests/session_test.rs | 30 ++++++++- 23 files changed, 333 insertions(+), 55 deletions(-) diff --git a/dotnet/src/Generated/SessionEvents.cs b/dotnet/src/Generated/SessionEvents.cs index d4969a5b1..7cab9e269 100644 --- a/dotnet/src/Generated/SessionEvents.cs +++ b/dotnet/src/Generated/SessionEvents.cs @@ -3069,9 +3069,15 @@ public sealed partial class McpOauthRequiredData [JsonPropertyName("staticClientConfig")] public McpOauthRequiredStaticClientConfig? StaticClientConfig { get; set; } - ///

Parsed parameters from the WWW-Authenticate header that the SDK host uses for RFC 9728 protected-resource metadata discovery. + /// Parsed parameters from the WWW-Authenticate header that the SDK host uses for RFC 9728 protected-resource metadata discovery. Omitted for older event producers. + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] [JsonPropertyName("wwwAuthenticateParams")] - public required McpOauthRequiredWwwAuthenticateParams WwwAuthenticateParams { get; set; } + public McpOauthRequiredWwwAuthenticateParams? WwwAuthenticateParams { get; set; } + + /// Raw RFC 9728 protected-resource metadata JSON fetched by the runtime from the MCP server origin, if discovery succeeded. Omitted for older event producers and when metadata discovery fails. + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + [JsonPropertyName("resourceMetadata")] + public string? ResourceMetadata { get; set; } } /// MCP OAuth request completion notification. diff --git a/dotnet/src/Session.cs b/dotnet/src/Session.cs index 1d0d9572a..f84f0c7de 100644 --- a/dotnet/src/Session.cs +++ b/dotnet/src/Session.cs @@ -665,6 +665,7 @@ private async Task HandleBroadcastEventAsync(SessionEvent sessionEvent) ServerName = data.ServerName, ServerUrl = data.ServerUrl, WwwAuthenticateParams = data.WwwAuthenticateParams, + ResourceMetadata = data.ResourceMetadata, StaticClientConfig = data.StaticClientConfig }, handler); break; diff --git a/dotnet/src/Types.cs b/dotnet/src/Types.cs index 78244e353..119fa51e3 100644 --- a/dotnet/src/Types.cs +++ b/dotnet/src/Types.cs @@ -1121,9 +1121,11 @@ public sealed class McpAuthContext /// URL of the MCP server that requires OAuth. public string ServerUrl { get; set; } = string.Empty; - /// Parsed WWW-Authenticate parameters from the MCP server. - public McpOauthRequiredWwwAuthenticateParams WwwAuthenticateParams { get; set; } = - new() { ResourceMetadataUrl = string.Empty }; + /// Parsed WWW-Authenticate parameters from the MCP server, if available. + public McpOauthRequiredWwwAuthenticateParams? WwwAuthenticateParams { get; set; } + + /// Raw RFC 9728 protected-resource metadata JSON fetched by the runtime, if available. + public string? ResourceMetadata { get; set; } /// Static OAuth client configuration, if the server specifies one. public McpOauthRequiredStaticClientConfig? StaticClientConfig { get; set; } diff --git a/dotnet/test/Unit/SessionEventSerializationTests.cs b/dotnet/test/Unit/SessionEventSerializationTests.cs index 2db690a6c..db19558b6 100644 --- a/dotnet/test/Unit/SessionEventSerializationTests.cs +++ b/dotnet/test/Unit/SessionEventSerializationTests.cs @@ -162,6 +162,7 @@ public class SessionEventSerializationTests { ResourceMetadataUrl = "https://example.com/.well-known/oauth-protected-resource", }, + ResourceMetadata = """{"resource":"https://example.com/mcp"}""", }, }, "mcp.oauth_required" @@ -285,6 +286,11 @@ public void SessionEvent_ToJson_RoundTrips_JsonElementBackedPayloads(SessionEven .GetProperty("staticClientConfig") .GetProperty("grantType") .GetString()); + Assert.Equal( + """{"resource":"https://example.com/mcp"}""", + root.GetProperty("data") + .GetProperty("resourceMetadata") + .GetString()); break; case "assistant.message_start": @@ -301,4 +307,27 @@ public void SessionEvent_ToJson_RoundTrips_JsonElementBackedPayloads(SessionEven break; } } + + [Fact] + public void McpOauthRequiredData_Allows_Missing_Optional_Metadata() + { + const string json = """ + { + "id": "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", + "timestamp": "2026-03-15T21:26:54.987Z", + "parentId": null, + "type": "mcp.oauth_required", + "data": { + "requestId": "oauth-request", + "serverName": "oauth-server", + "serverUrl": "https://example.com/mcp" + } + } + """; + + var authEvent = JsonSerializer.Deserialize(json); + Assert.NotNull(authEvent); + Assert.Null(authEvent.Data.WwwAuthenticateParams); + Assert.Null(authEvent.Data.ResourceMetadata); + } } diff --git a/go/rpc/zsession_events.go b/go/rpc/zsession_events.go index 535d6b3e3..c57b7e784 100644 --- a/go/rpc/zsession_events.go +++ b/go/rpc/zsession_events.go @@ -698,8 +698,10 @@ type MCPOauthRequiredData struct { ServerURL string `json:"serverUrl"` // Static OAuth client configuration, if the server specifies one StaticClientConfig *MCPOauthRequiredStaticClientConfig `json:"staticClientConfig,omitempty"` - // Parsed parameters from the WWW-Authenticate header that the SDK host uses for RFC 9728 protected-resource metadata discovery. - WwwAuthenticateParams MCPOauthRequiredWwwAuthenticateParams `json:"wwwAuthenticateParams"` + // Parsed parameters from the WWW-Authenticate header that the SDK host uses for RFC 9728 protected-resource metadata discovery. Omitted for older event producers. + WwwAuthenticateParams *MCPOauthRequiredWwwAuthenticateParams `json:"wwwAuthenticateParams,omitempty"` + // Raw RFC 9728 protected-resource metadata JSON fetched by the runtime from the MCP server origin, if discovery succeeded. Omitted for older event producers and when metadata discovery fails. + ResourceMetadata *string `json:"resourceMetadata,omitempty"` } func (*MCPOauthRequiredData) sessionEventData() {} @@ -1919,6 +1921,7 @@ type MCPOauthRequiredWwwAuthenticateParams struct { Scope *string `json:"scope,omitempty"` } + // Schema for the `McpServersLoadedServer` type. type MCPServersLoadedServer struct { // Error message if the server failed to connect diff --git a/go/session.go b/go/session.go index 265d18343..665879e28 100644 --- a/go/session.go +++ b/go/session.go @@ -1377,24 +1377,30 @@ func (s *Session) handleBroadcastEvent(event SessionEvent) { PublicClient: d.StaticClientConfig.PublicClient, } } - var scope, oauthError string - if d.WwwAuthenticateParams.Scope != nil { - scope = *d.WwwAuthenticateParams.Scope - } - if d.WwwAuthenticateParams.Error != nil { - oauthError = *d.WwwAuthenticateParams.Error - } - s.handleMCPAuthRequest(MCPAuthRequest{ + request := MCPAuthRequest{ RequestID: d.RequestID, ServerName: d.ServerName, ServerURL: d.ServerURL, - WwwAuthenticateParams: MCPAuthWwwAuthenticateParams{ + StaticClientConfig: staticClientConfig, + } + if d.ResourceMetadata != nil { + request.ResourceMetadata = *d.ResourceMetadata + } + if d.WwwAuthenticateParams != nil { + var scope, oauthError string + if d.WwwAuthenticateParams.Scope != nil { + scope = *d.WwwAuthenticateParams.Scope + } + if d.WwwAuthenticateParams.Error != nil { + oauthError = *d.WwwAuthenticateParams.Error + } + request.WwwAuthenticateParams = &MCPAuthWwwAuthenticateParams{ ResourceMetadataURL: d.WwwAuthenticateParams.ResourceMetadataURL, Scope: scope, Error: oauthError, - }, - StaticClientConfig: staticClientConfig, - }) + } + } + s.handleMCPAuthRequest(request) case *CommandExecuteData: s.executeCommandAndRespond(d.RequestID, d.CommandName, d.Command, d.Args) diff --git a/go/session_test.go b/go/session_test.go index 7cc9caece..b20918868 100644 --- a/go/session_test.go +++ b/go/session_test.go @@ -118,7 +118,9 @@ func TestSession_MCPAuthRequestSendsHostToken(t *testing.T) { client: client, RPC: rpc.NewSessionRPC(client, "session-1"), } + var observedRequest MCPAuthRequest session.registerMCPAuthHandler(func(request MCPAuthRequest, invocation MCPAuthInvocation) (*MCPAuthResult, error) { + observedRequest = request if invocation.SessionID != "session-1" { t.Fatalf("expected invocation session-1, got %s", invocation.SessionID) } @@ -134,7 +136,19 @@ func TestSession_MCPAuthRequestSendsHostToken(t *testing.T) { }, }, nil }) - session.handleMCPAuthRequest(MCPAuthRequest{RequestID: "oauth-request"}) + session.handleMCPAuthRequest(MCPAuthRequest{ + RequestID: "oauth-request", + ResourceMetadata: `{"resource":"https://example.com/mcp"}`, + WwwAuthenticateParams: &MCPAuthWwwAuthenticateParams{ + ResourceMetadataURL: "https://example.com/.well-known/oauth-protected-resource", + }, + }) + if observedRequest.ResourceMetadata != `{"resource":"https://example.com/mcp"}` { + t.Fatalf("expected resource metadata to be propagated, got %q", observedRequest.ResourceMetadata) + } + if observedRequest.WwwAuthenticateParams == nil { + t.Fatal("expected WWW-Authenticate params to be propagated") + } select { case params := <-paramsCh: @@ -164,6 +178,52 @@ func TestSession_MCPAuthRequestSendsHostToken(t *testing.T) { } } +func TestMCPAuthRequestAllowsMissingOptionalMetadata(t *testing.T) { + request := MCPAuthRequest{RequestID: "oauth-request"} + if request.ResourceMetadata != "" { + t.Fatalf("expected no resource metadata, got %q", request.ResourceMetadata) + } + if request.WwwAuthenticateParams != nil { + t.Fatalf("expected no WWW-Authenticate params, got %#v", request.WwwAuthenticateParams) + } +} + +func TestMCPOauthRequiredDataAllowsOptionalMetadata(t *testing.T) { + var withMetadata rpc.MCPOauthRequiredData + if err := json.Unmarshal([]byte(`{ + "requestId": "oauth-request", + "serverName": "oauth-server", + "serverUrl": "https://example.com/mcp", + "wwwAuthenticateParams": { + "resourceMetadataUrl": "https://example.com/.well-known/oauth-protected-resource" + }, + "resourceMetadata": "{\"resource\":\"https://example.com/mcp\"}" + }`), &withMetadata); err != nil { + t.Fatal(err) + } + if withMetadata.ResourceMetadata == nil || *withMetadata.ResourceMetadata != `{"resource":"https://example.com/mcp"}` { + t.Fatalf("expected resource metadata, got %#v", withMetadata.ResourceMetadata) + } + if withMetadata.WwwAuthenticateParams == nil { + t.Fatal("expected WWW-Authenticate params") + } + + var withoutMetadata rpc.MCPOauthRequiredData + if err := json.Unmarshal([]byte(`{ + "requestId": "oauth-request", + "serverName": "oauth-server", + "serverUrl": "https://example.com/mcp" + }`), &withoutMetadata); err != nil { + t.Fatal(err) + } + if withoutMetadata.ResourceMetadata != nil { + t.Fatalf("expected no resource metadata, got %#v", withoutMetadata.ResourceMetadata) + } + if withoutMetadata.WwwAuthenticateParams != nil { + t.Fatalf("expected no WWW-Authenticate params, got %#v", withoutMetadata.WwwAuthenticateParams) + } +} + func captureSetModelRequest(t *testing.T, opts *SetModelOptions) map[string]any { t.Helper() diff --git a/go/types.go b/go/types.go index d9594fd49..d528a930a 100644 --- a/go/types.go +++ b/go/types.go @@ -328,7 +328,8 @@ type MCPAuthRequest struct { RequestID string `json:"requestId"` ServerName string `json:"serverName"` ServerURL string `json:"serverUrl"` - WwwAuthenticateParams MCPAuthWwwAuthenticateParams `json:"wwwAuthenticateParams"` + WwwAuthenticateParams *MCPAuthWwwAuthenticateParams `json:"wwwAuthenticateParams,omitempty"` + ResourceMetadata string `json:"resourceMetadata,omitempty"` StaticClientConfig *MCPAuthStaticClientConfig `json:"staticClientConfig,omitempty"` } diff --git a/java/src/generated/java/com/github/copilot/generated/McpOauthRequiredEvent.java b/java/src/generated/java/com/github/copilot/generated/McpOauthRequiredEvent.java index 02e67a35f..0440e4c0d 100644 --- a/java/src/generated/java/com/github/copilot/generated/McpOauthRequiredEvent.java +++ b/java/src/generated/java/com/github/copilot/generated/McpOauthRequiredEvent.java @@ -41,7 +41,11 @@ public record McpOauthRequiredEventData( /** URL of the MCP server that requires OAuth */ @JsonProperty("serverUrl") String serverUrl, /** Static OAuth client configuration, if the server specifies one */ - @JsonProperty("staticClientConfig") McpOauthRequiredStaticClientConfig staticClientConfig + @JsonProperty("staticClientConfig") McpOauthRequiredStaticClientConfig staticClientConfig, + /** Parsed parameters from the WWW-Authenticate header that the SDK host uses for RFC 9728 protected-resource metadata discovery. Omitted for older event producers. */ + @JsonProperty("wwwAuthenticateParams") McpOauthRequiredWwwAuthenticateParams wwwAuthenticateParams, + /** Raw RFC 9728 protected-resource metadata JSON fetched by the runtime from the MCP server origin, if discovery succeeded. Omitted for older event producers and when metadata discovery fails. */ + @JsonProperty("resourceMetadata") String resourceMetadata ) { } } diff --git a/java/src/main/java/com/github/copilot/CopilotSession.java b/java/src/main/java/com/github/copilot/CopilotSession.java index 2c7f07f4a..9a3807df9 100644 --- a/java/src/main/java/com/github/copilot/CopilotSession.java +++ b/java/src/main/java/com/github/copilot/CopilotSession.java @@ -856,7 +856,7 @@ private void handleBroadcastEventAsync(SessionEvent event) { return; } executeMcpAuthAndRespondAsync(new McpAuthRequest(sessionId, data.requestId(), data.serverName(), - data.serverUrl(), data.wwwAuthenticateParams(), data.staticClientConfig()), handler); + data.serverUrl(), data.wwwAuthenticateParams(), data.resourceMetadata(), data.staticClientConfig()), handler); } else if (event instanceof CommandExecuteEvent cmdEvent) { var data = cmdEvent.getData(); if (data == null || data.requestId() == null || data.commandName() == null) { diff --git a/java/src/main/java/com/github/copilot/rpc/McpAuthRequest.java b/java/src/main/java/com/github/copilot/rpc/McpAuthRequest.java index ae5222f4c..a836eaa40 100644 --- a/java/src/main/java/com/github/copilot/rpc/McpAuthRequest.java +++ b/java/src/main/java/com/github/copilot/rpc/McpAuthRequest.java @@ -18,5 +18,6 @@ public record McpAuthRequest( String serverName, String serverUrl, McpOauthRequiredWwwAuthenticateParams wwwAuthenticateParams, + String resourceMetadata, McpOauthRequiredStaticClientConfig staticClientConfig) { } diff --git a/java/src/test/java/com/github/copilot/McpAuthInterestRegistrationTest.java b/java/src/test/java/com/github/copilot/McpAuthInterestRegistrationTest.java index 328da7dcd..30ae0b970 100644 --- a/java/src/test/java/com/github/copilot/McpAuthInterestRegistrationTest.java +++ b/java/src/test/java/com/github/copilot/McpAuthInterestRegistrationTest.java @@ -18,6 +18,7 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ObjectNode; +import com.github.copilot.generated.McpOauthRequiredEvent; import com.github.copilot.rpc.CloudSessionOptions; import com.github.copilot.rpc.CloudSessionRepository; import com.github.copilot.rpc.CopilotClientOptions; @@ -30,6 +31,47 @@ class McpAuthInterestRegistrationTest { private static final ObjectMapper MAPPER = new ObjectMapper(); + @Test + void mcpOauthRequiredEventExposesOptionalResourceMetadata() throws Exception { + var event = MAPPER.readValue(""" + { + "id": "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", + "timestamp": "2026-03-15T21:26:54.987Z", + "parentId": null, + "type": "mcp.oauth_required", + "data": { + "requestId": "oauth-request", + "serverName": "oauth-server", + "serverUrl": "https://example.com/mcp", + "wwwAuthenticateParams": { + "resourceMetadataUrl": "https://example.com/.well-known/oauth-protected-resource" + }, + "resourceMetadata": "{\\"resource\\":\\"https://example.com/mcp\\"}" + } + } + """, McpOauthRequiredEvent.class); + + assertEquals("{\"resource\":\"https://example.com/mcp\"}", event.getData().resourceMetadata()); + assertNotNull(event.getData().wwwAuthenticateParams()); + + var withoutMetadata = MAPPER.readValue(""" + { + "id": "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", + "timestamp": "2026-03-15T21:26:54.987Z", + "parentId": null, + "type": "mcp.oauth_required", + "data": { + "requestId": "oauth-request", + "serverName": "oauth-server", + "serverUrl": "https://example.com/mcp" + } + } + """, McpOauthRequiredEvent.class); + + assertNull(withoutMetadata.getData().resourceMetadata()); + assertNull(withoutMetadata.getData().wwwAuthenticateParams()); + } + @Test void createSessionRegistersMcpAuthInterestOnlyWhenHandlerConfigured() throws Exception { try (var server = new RecordingRuntime(); diff --git a/nodejs/src/generated/session-events.ts b/nodejs/src/generated/session-events.ts index 9215b178d..2410ed250 100644 --- a/nodejs/src/generated/session-events.ts +++ b/nodejs/src/generated/session-events.ts @@ -5800,7 +5800,11 @@ export interface McpOauthRequiredData { */ serverUrl: string; staticClientConfig?: McpOauthRequiredStaticClientConfig; - wwwAuthenticateParams: McpOauthRequiredWwwAuthenticateParams; + wwwAuthenticateParams?: McpOauthRequiredWwwAuthenticateParams; + /** + * Raw RFC 9728 protected-resource metadata JSON fetched by the runtime from the MCP server origin, if discovery succeeded. Omitted for older event producers and when metadata discovery fails. + */ + resourceMetadata?: string; } /** * Static OAuth client configuration, if the server specifies one diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts index 166f6c861..2375c1308 100644 --- a/nodejs/src/types.ts +++ b/nodejs/src/types.ts @@ -14,10 +14,7 @@ import type { SessionEvent as GeneratedSessionEvent, } from "./generated/session-events.js"; import type { CopilotSession } from "./session.js"; -import type { - OpenCanvasInstance, - RemoteSessionMode, -} from "./generated/rpc.js"; +import type { OpenCanvasInstance, RemoteSessionMode } from "./generated/rpc.js"; import type { ToolSet } from "./toolSet.js"; export type { RemoteSessionMode } from "./generated/rpc.js"; export type SessionEvent = GeneratedSessionEvent; @@ -1577,7 +1574,9 @@ export interface McpAuthRequest { /** URL of the MCP server that requires OAuth. */ serverUrl: string; /** Parsed WWW-Authenticate parameters from the MCP server. */ - wwwAuthenticateParams: McpAuthWwwAuthenticateParams; + wwwAuthenticateParams?: McpAuthWwwAuthenticateParams; + /** Raw RFC 9728 protected-resource metadata JSON fetched by the runtime, if available. */ + resourceMetadata?: string; /** Static OAuth client configuration, if the server specifies one. */ staticClientConfig?: McpAuthStaticClientConfig; } diff --git a/nodejs/test/client.test.ts b/nodejs/test/client.test.ts index e48d9805e..8c9f771ae 100644 --- a/nodejs/test/client.test.ts +++ b/nodejs/test/client.test.ts @@ -25,17 +25,21 @@ describe("CopilotClient", () => { it("responds to MCP OAuth requests with host token data", async () => { const sendRequest = vi.fn(async () => ({ success: true })); + let observedRequest: any; const session = new CopilotSession( "session-1", { sendRequest } as any, undefined, undefined, { - mcpAuthHandler: async () => ({ - accessToken: "host-token", - tokenType: "Bearer", - expiresIn: 3600, - }), + mcpAuthHandler: async (request) => { + observedRequest = request; + return { + accessToken: "host-token", + tokenType: "Bearer", + expiresIn: 3600, + }; + }, } ); @@ -46,8 +50,10 @@ describe("CopilotClient", () => { wwwAuthenticateParams: { resourceMetadataUrl: "https://example.com/.well-known/oauth-protected-resource", }, + resourceMetadata: '{"resource":"https://example.com/mcp"}', }); + expect(observedRequest.resourceMetadata).toBe('{"resource":"https://example.com/mcp"}'); expect(sendRequest).toHaveBeenCalledWith("session.mcp.oauth.handlePendingRequest", { sessionId: "session-1", requestId: "oauth-request", @@ -60,6 +66,31 @@ describe("CopilotClient", () => { }); }); + it("passes MCP OAuth requests through when optional metadata is absent", async () => { + let observedRequest: any; + const session = new CopilotSession( + "session-1", + { sendRequest: vi.fn(async () => ({ success: true })) } as any, + undefined, + undefined, + { + mcpAuthHandler: async (request) => { + observedRequest = request; + return { kind: "cancelled" }; + }, + } + ); + + await (session as any)._executeMcpAuthAndRespond({ + requestId: "oauth-request", + serverName: "oauth-server", + serverUrl: "https://example.com/mcp", + }); + + expect(observedRequest.resourceMetadata).toBeUndefined(); + expect(observedRequest.wwwAuthenticateParams).toBeUndefined(); + }); + it("registers interest in MCP OAuth required events after create when an auth handler is configured", async () => { const client = new CopilotClient(); await client.start(); diff --git a/nodejs/test/e2e/mcp_oauth.e2e.test.ts b/nodejs/test/e2e/mcp_oauth.e2e.test.ts index 2817ac5c3..52dc902ed 100644 --- a/nodejs/test/e2e/mcp_oauth.e2e.test.ts +++ b/nodejs/test/e2e/mcp_oauth.e2e.test.ts @@ -62,6 +62,12 @@ describe("MCP OAuth host auth", async () => { scope: "mcp.read", error: "invalid_token", }, + resourceMetadata: JSON.stringify({ + resource: `${oauthServer.url}/mcp`, + authorization_servers: [oauthServer.url], + scopes_supported: ["mcp.read"], + bearer_methods_supported: ["header"], + }), }); const requests = await oauthServer.requests(); diff --git a/python/copilot/generated/session_events.py b/python/copilot/generated/session_events.py index 52fd43512..3f8ee194f 100644 --- a/python/copilot/generated/session_events.py +++ b/python/copilot/generated/session_events.py @@ -2230,8 +2230,9 @@ class McpOauthRequiredData: request_id: str server_name: str server_url: str - www_authenticate_params: McpOauthRequiredWwwAuthenticateParams + www_authenticate_params: McpOauthRequiredWwwAuthenticateParams | None = None static_client_config: McpOauthRequiredStaticClientConfig | None = None + resource_metadata: str | None = None @staticmethod def from_dict(obj: Any) -> "McpOauthRequiredData": @@ -2239,14 +2240,16 @@ def from_dict(obj: Any) -> "McpOauthRequiredData": request_id = from_str(obj.get("requestId")) server_name = from_str(obj.get("serverName")) server_url = from_str(obj.get("serverUrl")) - www_authenticate_params = McpOauthRequiredWwwAuthenticateParams.from_dict(obj.get("wwwAuthenticateParams")) + www_authenticate_params = from_union([from_none, McpOauthRequiredWwwAuthenticateParams.from_dict], obj.get("wwwAuthenticateParams")) static_client_config = from_union([from_none, McpOauthRequiredStaticClientConfig.from_dict], obj.get("staticClientConfig")) + resource_metadata = from_union([from_none, from_str], obj.get("resourceMetadata")) return McpOauthRequiredData( request_id=request_id, server_name=server_name, server_url=server_url, www_authenticate_params=www_authenticate_params, static_client_config=static_client_config, + resource_metadata=resource_metadata, ) def to_dict(self) -> dict: @@ -2254,9 +2257,12 @@ def to_dict(self) -> dict: result["requestId"] = from_str(self.request_id) result["serverName"] = from_str(self.server_name) result["serverUrl"] = from_str(self.server_url) - result["wwwAuthenticateParams"] = to_class(McpOauthRequiredWwwAuthenticateParams, self.www_authenticate_params) + if self.www_authenticate_params is not None: + result["wwwAuthenticateParams"] = from_union([from_none, lambda x: to_class(McpOauthRequiredWwwAuthenticateParams, x)], self.www_authenticate_params) if self.static_client_config is not None: result["staticClientConfig"] = from_union([from_none, lambda x: to_class(McpOauthRequiredStaticClientConfig, x)], self.static_client_config) + if self.resource_metadata is not None: + result["resourceMetadata"] = from_union([from_none, from_str], self.resource_metadata) return result diff --git a/python/copilot/session.py b/python/copilot/session.py index 672dfa20c..629f9999b 100644 --- a/python/copilot/session.py +++ b/python/copilot/session.py @@ -325,7 +325,8 @@ class McpAuthRequest(TypedDict, total=False): requestId: Required[str] serverName: Required[str] serverUrl: Required[str] - wwwAuthenticateParams: Required[McpAuthWwwAuthenticateParams] + wwwAuthenticateParams: McpAuthWwwAuthenticateParams + resourceMetadata: str staticClientConfig: McpAuthStaticClientConfig @@ -1587,14 +1588,17 @@ def _handle_broadcast_event(self, event: SessionEvent) -> None: "requestId": data.request_id, "serverName": data.server_name, "serverUrl": data.server_url, - "wwwAuthenticateParams": { - "resourceMetadataUrl": data.www_authenticate_params.resource_metadata_url, - }, } - if data.www_authenticate_params.scope is not None: - request["wwwAuthenticateParams"]["scope"] = data.www_authenticate_params.scope - if data.www_authenticate_params.error is not None: - request["wwwAuthenticateParams"]["error"] = data.www_authenticate_params.error + if data.www_authenticate_params is not None: + request["wwwAuthenticateParams"] = { + "resourceMetadataUrl": data.www_authenticate_params.resource_metadata_url, + } + if data.www_authenticate_params.scope is not None: + request["wwwAuthenticateParams"]["scope"] = data.www_authenticate_params.scope + if data.www_authenticate_params.error is not None: + request["wwwAuthenticateParams"]["error"] = data.www_authenticate_params.error + if data.resource_metadata is not None: + request["resourceMetadata"] = data.resource_metadata if data.static_client_config is not None: static_client_config: McpAuthStaticClientConfig = { "clientId": data.static_client_config.client_id, diff --git a/python/test_client.py b/python/test_client.py index dd4f90017..2f851e0b4 100644 --- a/python/test_client.py +++ b/python/test_client.py @@ -270,12 +270,19 @@ async def mock_request(method, params, **kwargs): return {} client._client.request = mock_request - session = await client.create_session( - on_permission_request=PermissionHandler.approve_all, - on_mcp_auth_request=lambda request: { + observed_request = None + + def handle_mcp_auth_request(request): + nonlocal observed_request + observed_request = request + return { "accessToken": "host-token", "tokenType": "Bearer", - }, + } + + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all, + on_mcp_auth_request=handle_mcp_auth_request, ) session._dispatch_event( @@ -287,6 +294,7 @@ async def mock_request(method, params, **kwargs): www_authenticate_params=McpOauthRequiredWwwAuthenticateParams( resource_metadata_url="https://example.com/.well-known/oauth-protected-resource" ), + resource_metadata='{"resource":"https://example.com/mcp"}', ), id="evt-1", timestamp="2026-01-01T00:00:00Z", @@ -301,6 +309,11 @@ async def mock_request(method, params, **kwargs): break await asyncio.sleep(0.005) + assert observed_request is not None + assert observed_request["resourceMetadata"] == '{"resource":"https://example.com/mcp"}' + assert observed_request["wwwAuthenticateParams"]["resourceMetadataUrl"] == ( + "https://example.com/.well-known/oauth-protected-resource" + ) assert captured == [ ( "session.mcp.oauth.handlePendingRequest", @@ -315,6 +328,31 @@ async def mock_request(method, params, **kwargs): }, ) ] + + observed_request = None + session._dispatch_event( + SessionEvent( + data=McpOauthRequiredData( + request_id="oauth-request-without-metadata", + server_name="oauth-server", + server_url="https://example.com/mcp", + ), + id="evt-2", + timestamp="2026-01-01T00:00:00Z", + type=SessionEventType.MCP_OAUTH_REQUIRED, + ephemeral=True, + parent_id=None, + ) + ) + + for _ in range(200): + if observed_request is not None: + break + await asyncio.sleep(0.005) + + assert observed_request is not None + assert "resourceMetadata" not in observed_request + assert "wwwAuthenticateParams" not in observed_request finally: await client.force_stop() diff --git a/rust/src/generated/session_events.rs b/rust/src/generated/session_events.rs index 71e1613d8..04fa4e006 100644 --- a/rust/src/generated/session_events.rs +++ b/rust/src/generated/session_events.rs @@ -2898,8 +2898,12 @@ pub struct McpOauthRequiredData { /// Static OAuth client configuration, if the server specifies one #[serde(skip_serializing_if = "Option::is_none")] pub static_client_config: Option, - /// Parsed parameters from the WWW-Authenticate header that the SDK host uses for RFC 9728 protected-resource metadata discovery. - pub www_authenticate_params: McpOauthRequiredWwwAuthenticateParams, + /// Parsed parameters from the WWW-Authenticate header that the SDK host uses for RFC 9728 protected-resource metadata discovery. Omitted for older event producers. + #[serde(skip_serializing_if = "Option::is_none")] + pub www_authenticate_params: Option, + /// Raw RFC 9728 protected-resource metadata JSON fetched by the runtime from the MCP server origin, if discovery succeeded. Omitted for older event producers and when metadata discovery fails. + #[serde(skip_serializing_if = "Option::is_none")] + pub resource_metadata: Option, } /// Session event "mcp.oauth_completed". MCP OAuth request completion notification diff --git a/rust/src/handler.rs b/rust/src/handler.rs index ff3781ffe..568c494d4 100644 --- a/rust/src/handler.rs +++ b/rust/src/handler.rs @@ -171,8 +171,10 @@ pub struct McpAuthRequest { pub server_name: String, /// URL of the MCP server that requires OAuth. pub server_url: String, - /// Parsed WWW-Authenticate parameters from the MCP server. - pub www_authenticate_params: McpOauthRequiredWwwAuthenticateParams, + /// Parsed WWW-Authenticate parameters from the MCP server, if available. + pub www_authenticate_params: Option, + /// Raw RFC 9728 protected-resource metadata JSON fetched by the runtime, if available. + pub resource_metadata: Option, /// Static OAuth client configuration, if the server specifies one. pub static_client_config: Option, } diff --git a/rust/src/session.rs b/rust/src/session.rs index 5103f8413..44f6d9042 100644 --- a/rust/src/session.rs +++ b/rust/src/session.rs @@ -1970,7 +1970,7 @@ async fn handle_notification( }; let Some(mcp_auth_handler) = handlers.mcp_auth.clone() else { warn!( - session_id = %sid, + session_id = %session_id, request_id = %request_id, "received MCP OAuth request without a registered MCP auth handler" ); @@ -1988,6 +1988,7 @@ async fn handle_notification( server_name: data.server_name, server_url: data.server_url, www_authenticate_params: data.www_authenticate_params, + resource_metadata: data.resource_metadata, static_client_config: data.static_client_config, }; let client = client.clone(); diff --git a/rust/tests/session_test.rs b/rust/tests/session_test.rs index 3a3c6a136..d5b3ce50c 100644 --- a/rust/tests/session_test.rs +++ b/rust/tests/session_test.rs @@ -16,7 +16,7 @@ use github_copilot_sdk::rpc::{ CanvasInstanceAvailability, CanvasProviderInvokeActionRequest, CanvasProviderOpenRequest, CanvasProviderOpenResult, OpenCanvasInstance, }; -use github_copilot_sdk::session_events::ReasoningSummary; +use github_copilot_sdk::session_events::{McpOauthRequiredData, ReasoningSummary}; use github_copilot_sdk::types::{ CloudSessionOptions, CloudSessionRepository, CommandContext, CommandDefinition, CommandHandler, DeliveryMode, ElicitationRequest, ElicitationResult, ExitPlanModeData, ExtensionInfo, @@ -236,6 +236,34 @@ fn rand_id() -> u64 { COUNTER.fetch_add(1, Ordering::Relaxed) as u64 } +#[test] +fn mcp_oauth_required_data_allows_optional_metadata() { + let with_metadata: McpOauthRequiredData = serde_json::from_value(serde_json::json!({ + "requestId": "oauth-request", + "serverName": "oauth-server", + "serverUrl": "https://example.com/mcp", + "wwwAuthenticateParams": { + "resourceMetadataUrl": "https://example.com/.well-known/oauth-protected-resource" + }, + "resourceMetadata": "{\"resource\":\"https://example.com/mcp\"}" + })) + .unwrap(); + assert_eq!( + with_metadata.resource_metadata.as_deref(), + Some("{\"resource\":\"https://example.com/mcp\"}") + ); + assert!(with_metadata.www_authenticate_params.is_some()); + + let without_metadata: McpOauthRequiredData = serde_json::from_value(serde_json::json!({ + "requestId": "oauth-request", + "serverName": "oauth-server", + "serverUrl": "https://example.com/mcp" + })) + .unwrap(); + assert!(without_metadata.resource_metadata.is_none()); + assert!(without_metadata.www_authenticate_params.is_none()); +} + fn requested_session_id(request: &Value) -> &str { request["params"]["sessionId"] .as_str() From fe5eb3bc61bc4ce830e965b25751f6cb38fa4bf4 Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Wed, 17 Jun 2026 18:40:39 +0200 Subject: [PATCH 4/5] Temporarily validate MCP OAuth E2E against local runtime Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- nodejs/test/e2e/mcp_oauth.e2e.test.ts | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/nodejs/test/e2e/mcp_oauth.e2e.test.ts b/nodejs/test/e2e/mcp_oauth.e2e.test.ts index 52dc902ed..78a4bf2a2 100644 --- a/nodejs/test/e2e/mcp_oauth.e2e.test.ts +++ b/nodejs/test/e2e/mcp_oauth.e2e.test.ts @@ -8,17 +8,30 @@ import { createInterface } from "node:readline"; import { fileURLToPath } from "node:url"; import { describe, expect, it, onTestFinished } from "vitest"; import type { CopilotSession, MCPServerConfig, McpAuthRequest } from "../../src/index.js"; -import { approveAll } from "../../src/index.js"; +import { approveAll, RuntimeConnection } from "../../src/index.js"; import { createSdkTestContext } from "./harness/sdkTestContext.js"; import { waitForCondition } from "./harness/sdkTestHelper.js"; const __filename = fileURLToPath(import.meta.url); const __dirname = dirname(__filename); const TEST_MCP_OAUTH_SERVER = resolve(__dirname, "../../../test/harness/test-mcp-oauth-server.mjs"); +const LOCAL_RUNTIME_DIR = + "/Users/roji/.copilot/repos/copilot-worktrees/copilot-agent-runtime/roji-ubiquitous-parakeet"; const EXPECTED_TOKEN = "sdk-host-token"; describe("MCP OAuth host auth", async () => { - const { copilotClient: client } = await createSdkTestContext(); + const { copilotClient: client } = await createSdkTestContext({ + copilotClientOptions: { + connection: RuntimeConnection.forStdio({ + path: "/bin/sh", + args: [ + "-c", + `cd ${JSON.stringify(LOCAL_RUNTIME_DIR)} && exec node --enable-source-maps --report-on-fatalerror dist-cli/index.js "$@"`, + "copilot-local-runtime", + ], + }), + }, + }); it("should satisfy MCP OAuth using host-provided token", { timeout: 120_000 }, async () => { const oauthServer = await startOAuthMcpServer(); From 59694fd4baaa6b5b67b7139e8403d532efcdfe4b Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Wed, 17 Jun 2026 19:06:29 +0200 Subject: [PATCH 5/5] Fix MCP OAuth validation issues Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dotnet/src/Client.cs | 5 +- dotnet/src/JsonRpc.cs | 5 +- dotnet/src/Session.cs | 12 ++- .../Unit/SessionEventSerializationTests.cs | 3 +- go/client_test.go | 16 ++-- go/rpc/zsession_events.go | 1 - go/session.go | 6 +- go/types.go | 8 +- ...McpOauthRequiredWwwAuthenticateParams.java | 31 +++++++ .../generated/rpc/SessionMcpOauthApi.java | 18 ++++ ...ionMcpOauthHandlePendingRequestParams.java | 34 +++++++ ...ionMcpOauthHandlePendingRequestResult.java | 30 ++++++ .../com/github/copilot/CopilotClient.java | 21 +++-- .../com/github/copilot/CopilotSession.java | 9 +- .../github/copilot/rpc/McpAuthRequest.java | 9 +- .../com/github/copilot/rpc/McpAuthResult.java | 2 +- .../com/github/copilot/rpc/SessionConfig.java | 4 +- .../McpAuthInterestRegistrationTest.java | 93 ++++++++----------- python/copilot/generated/rpc.py | 4 +- python/copilot/generated/session_events.py | 1 + python/test_client.py | 2 +- 21 files changed, 211 insertions(+), 103 deletions(-) create mode 100644 java/src/generated/java/com/github/copilot/generated/McpOauthRequiredWwwAuthenticateParams.java create mode 100644 java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthHandlePendingRequestParams.java create mode 100644 java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthHandlePendingRequestResult.java diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index 16f628f99..e4a87dd7d 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -2260,7 +2260,10 @@ private async Task PumpAsync(Process process, ILogger logger, CancellationToken Buffer.AppendLine(line); } - logger.LogWarning("[CLI] {Line}", line); + if (logger.IsEnabled(LogLevel.Warning)) + { + logger.LogWarning("[CLI] {Line}", line); + } } } catch (Exception e) when (cancellationToken.IsCancellationRequested diff --git a/dotnet/src/JsonRpc.cs b/dotnet/src/JsonRpc.cs index 912d5a529..eeca2bc8e 100644 --- a/dotnet/src/JsonRpc.cs +++ b/dotnet/src/JsonRpc.cs @@ -470,7 +470,10 @@ private void HandleResponse(JsonElement message, JsonElement idProp) } catch (Exception ex) { - _logger.LogWarning(ex, "Inline response callback for request {RequestId} threw", id); + if (_logger.IsEnabled(LogLevel.Warning)) + { + _logger.LogWarning(ex, "Inline response callback for request {RequestId} threw", id); + } pending.TrySetException(ex); return; } diff --git a/dotnet/src/Session.cs b/dotnet/src/Session.cs index f84f0c7de..fb501db9a 100644 --- a/dotnet/src/Session.cs +++ b/dotnet/src/Session.cs @@ -651,11 +651,13 @@ private async Task HandleBroadcastEventAsync(SessionEvent sessionEvent) var handler = _mcpAuthHandler; if (handler is null) { - _logger.LogWarning( - "Received MCP OAuth request without a registered MCP auth handler. " + - "SessionId={SessionId}, RequestId={RequestId}", - SessionId, - data.RequestId); + if (_logger.IsEnabled(LogLevel.Warning)) + { + _logger.LogWarning( + "Received MCP OAuth request without a registered MCP auth handler. SessionId={SessionId}, RequestId={RequestId}", + SessionId, + data.RequestId); + } return; } diff --git a/dotnet/test/Unit/SessionEventSerializationTests.cs b/dotnet/test/Unit/SessionEventSerializationTests.cs index db19558b6..f871f5618 100644 --- a/dotnet/test/Unit/SessionEventSerializationTests.cs +++ b/dotnet/test/Unit/SessionEventSerializationTests.cs @@ -325,8 +325,7 @@ public void McpOauthRequiredData_Allows_Missing_Optional_Metadata() } """; - var authEvent = JsonSerializer.Deserialize(json); - Assert.NotNull(authEvent); + var authEvent = Assert.IsType(SessionEvent.FromJson(json)); Assert.Null(authEvent.Data.WwwAuthenticateParams); Assert.Null(authEvent.Data.ResourceMetadata); } diff --git a/go/client_test.go b/go/client_test.go index 03a62cf78..1b2fc2517 100644 --- a/go/client_test.go +++ b/go/client_test.go @@ -996,12 +996,12 @@ func TestClient_MCPAuthInterestRegistration(t *testing.T) { session, err := client.CreateSession(t.Context(), &SessionConfig{ OnPermissionRequest: PermissionHandler.ApproveAll, - OnEvent: func(SessionEvent) {}, + OnEvent: func(SessionEvent) {}, }) if err != nil { t.Fatalf("CreateSession failed: %v", err) } - defer session.Close() + defer session.Disconnect() assertNoMCPAuthInterest(t, requests.snapshot()) assertRequestMethod(t, requests.snapshot(), "session.create") @@ -1021,7 +1021,7 @@ func TestClient_MCPAuthInterestRegistration(t *testing.T) { if err != nil { t.Fatalf("CreateSession failed: %v", err) } - defer session.Close() + defer session.Disconnect() snapshot := requests.snapshot() assertRequestMethod(t, snapshot, "session.eventLog.registerInterest") @@ -1048,7 +1048,7 @@ func TestClient_MCPAuthInterestRegistration(t *testing.T) { if err != nil { t.Fatalf("CreateSession without auth failed: %v", err) } - defer withoutAuth.Close() + defer withoutAuth.Disconnect() assertNoMCPAuthInterest(t, requests.snapshot()) requests.clear() @@ -1065,7 +1065,7 @@ func TestClient_MCPAuthInterestRegistration(t *testing.T) { if err != nil { t.Fatalf("CreateSession with auth failed: %v", err) } - defer withAuth.Close() + defer withAuth.Disconnect() snapshot := requests.snapshot() if snapshot[0].Method != "session.create" { @@ -1083,12 +1083,12 @@ func TestClient_MCPAuthInterestRegistration(t *testing.T) { withoutAuth, err := client.ResumeSession(t.Context(), "session-without-auth", &ResumeSessionConfig{ OnPermissionRequest: PermissionHandler.ApproveAll, - OnEvent: func(SessionEvent) {}, + OnEvent: func(SessionEvent) {}, }) if err != nil { t.Fatalf("ResumeSession without auth failed: %v", err) } - defer withoutAuth.Close() + defer withoutAuth.Disconnect() assertNoMCPAuthInterest(t, requests.snapshot()) assertRequestMethod(t, requests.snapshot(), "session.resume") @@ -1103,7 +1103,7 @@ func TestClient_MCPAuthInterestRegistration(t *testing.T) { if err != nil { t.Fatalf("ResumeSession with auth failed: %v", err) } - defer withAuth.Close() + defer withAuth.Disconnect() snapshot := requests.snapshot() if snapshot[0].Method != "session.eventLog.registerInterest" { diff --git a/go/rpc/zsession_events.go b/go/rpc/zsession_events.go index c57b7e784..285098676 100644 --- a/go/rpc/zsession_events.go +++ b/go/rpc/zsession_events.go @@ -1921,7 +1921,6 @@ type MCPOauthRequiredWwwAuthenticateParams struct { Scope *string `json:"scope,omitempty"` } - // Schema for the `McpServersLoadedServer` type. type MCPServersLoadedServer struct { // Error message if the server failed to connect diff --git a/go/session.go b/go/session.go index 665879e28..e5eb22a70 100644 --- a/go/session.go +++ b/go/session.go @@ -1378,9 +1378,9 @@ func (s *Session) handleBroadcastEvent(event SessionEvent) { } } request := MCPAuthRequest{ - RequestID: d.RequestID, - ServerName: d.ServerName, - ServerURL: d.ServerURL, + RequestID: d.RequestID, + ServerName: d.ServerName, + ServerURL: d.ServerURL, StaticClientConfig: staticClientConfig, } if d.ResourceMetadata != nil { diff --git a/go/types.go b/go/types.go index d528a930a..11b2475af 100644 --- a/go/types.go +++ b/go/types.go @@ -325,12 +325,12 @@ type MCPAuthStaticClientConfig struct { // MCPAuthRequest describes an MCP OAuth request that the SDK host can satisfy with a token. type MCPAuthRequest struct { - RequestID string `json:"requestId"` - ServerName string `json:"serverName"` - ServerURL string `json:"serverUrl"` + RequestID string `json:"requestId"` + ServerName string `json:"serverName"` + ServerURL string `json:"serverUrl"` WwwAuthenticateParams *MCPAuthWwwAuthenticateParams `json:"wwwAuthenticateParams,omitempty"` ResourceMetadata string `json:"resourceMetadata,omitempty"` - StaticClientConfig *MCPAuthStaticClientConfig `json:"staticClientConfig,omitempty"` + StaticClientConfig *MCPAuthStaticClientConfig `json:"staticClientConfig,omitempty"` } // MCPAuthToken is host-provided OAuth token data for a pending MCP OAuth request. diff --git a/java/src/generated/java/com/github/copilot/generated/McpOauthRequiredWwwAuthenticateParams.java b/java/src/generated/java/com/github/copilot/generated/McpOauthRequiredWwwAuthenticateParams.java new file mode 100644 index 000000000..8b7b65af4 --- /dev/null +++ b/java/src/generated/java/com/github/copilot/generated/McpOauthRequiredWwwAuthenticateParams.java @@ -0,0 +1,31 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +// AUTO-GENERATED FILE - DO NOT EDIT +// Generated from: session-events.schema.json + +package com.github.copilot.generated; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import javax.annotation.processing.Generated; + +/** + * Parsed parameters from the WWW-Authenticate header that the SDK host uses for RFC 9728 protected-resource metadata discovery. + * + * @since 1.0.0 + */ +@javax.annotation.processing.Generated("copilot-sdk-codegen") +@JsonInclude(JsonInclude.Include.NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +public record McpOauthRequiredWwwAuthenticateParams( + /** Parsed OAuth error from the WWW-Authenticate header, if present */ + @JsonProperty("error") String error, + /** Parsed resource_metadata URL from the WWW-Authenticate header */ + @JsonProperty("resourceMetadataUrl") String resourceMetadataUrl, + /** Parsed OAuth scope from the WWW-Authenticate header, if present */ + @JsonProperty("scope") String scope +) { +} diff --git a/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthApi.java b/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthApi.java index 59c4e45a1..291b104d9 100644 --- a/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthApi.java +++ b/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthApi.java @@ -62,4 +62,22 @@ public CompletableFuture login(SessionMcpOauthLoginP return caller.invoke("session.mcp.oauth.login", _p, SessionMcpOauthLoginResult.class); } + /** + * SDK-safe MCP OAuth pending-request response. + *

+ * Note: the {@code sessionId} field in the params record is overridden + * by the session-scoped wrapper; any value provided is ignored. + * + * @apiNote This method is experimental and may change in a future version. + * @since 1.0.0 + */ + @CopilotExperimental + public CompletableFuture handlePendingRequest( + SessionMcpOauthHandlePendingRequestParams params) { + com.fasterxml.jackson.databind.node.ObjectNode _p = MAPPER.valueToTree(params); + _p.put("sessionId", this.sessionId); + return caller.invoke("session.mcp.oauth.handlePendingRequest", _p, + SessionMcpOauthHandlePendingRequestResult.class); + } + } diff --git a/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthHandlePendingRequestParams.java b/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthHandlePendingRequestParams.java new file mode 100644 index 000000000..c0418e3bb --- /dev/null +++ b/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthHandlePendingRequestParams.java @@ -0,0 +1,34 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +// AUTO-GENERATED FILE - DO NOT EDIT +// Generated from: api.schema.json + +package com.github.copilot.generated.rpc; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.github.copilot.CopilotExperimental; +import javax.annotation.processing.Generated; + +/** + * SDK-safe MCP OAuth pending-request response. + * + * @apiNote This method is experimental and may change in a future version. + * @since 1.0.0 + */ +@CopilotExperimental +@javax.annotation.processing.Generated("copilot-sdk-codegen") +@JsonInclude(JsonInclude.Include.NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +public record SessionMcpOauthHandlePendingRequestParams( + /** Target session identifier */ + @JsonProperty("sessionId") String sessionId, + /** OAuth request identifier from mcp.oauth_required */ + @JsonProperty("requestId") String requestId, + /** Token or cancellation result for the pending OAuth request */ + @JsonProperty("result") Object result +) { +} diff --git a/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthHandlePendingRequestResult.java b/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthHandlePendingRequestResult.java new file mode 100644 index 000000000..def0ad205 --- /dev/null +++ b/java/src/generated/java/com/github/copilot/generated/rpc/SessionMcpOauthHandlePendingRequestResult.java @@ -0,0 +1,30 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +// AUTO-GENERATED FILE - DO NOT EDIT +// Generated from: api.schema.json + +package com.github.copilot.generated.rpc; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.github.copilot.CopilotExperimental; +import javax.annotation.processing.Generated; + +/** + * Result for SDK-safe MCP OAuth pending-request responses. + * + * @apiNote This method is experimental and may change in a future version. + * @since 1.0.0 + */ +@CopilotExperimental +@javax.annotation.processing.Generated("copilot-sdk-codegen") +@JsonInclude(JsonInclude.Include.NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +public record SessionMcpOauthHandlePendingRequestResult( + /** Whether the pending request was found and handled */ + @JsonProperty("success") boolean success +) { +} diff --git a/java/src/main/java/com/github/copilot/CopilotClient.java b/java/src/main/java/com/github/copilot/CopilotClient.java index c6b156926..54c55042d 100644 --- a/java/src/main/java/com/github/copilot/CopilotClient.java +++ b/java/src/main/java/com/github/copilot/CopilotClient.java @@ -514,9 +514,8 @@ public CompletableFuture createSession(SessionConfig config) { preRegisteredSessionHolder[0] = initializeSession.apply(localSessionId); registeredIdHolder[0] = localSessionId; if (config.getOnMcpAuthRequest() != null) { - preCreateInterest = preRegisteredSessionHolder[0].getRpc().eventLog - .registerInterest(new SessionEventLogRegisterInterestParams(localSessionId, - "mcp.oauth_required")); + preCreateInterest = preRegisteredSessionHolder[0].getRpc().eventLog.registerInterest( + new SessionEventLogRegisterInterestParams(localSessionId, "mcp.oauth_required")); } } @@ -564,8 +563,9 @@ public CompletableFuture createSession(SessionConfig config) { } long rpcNanos = System.nanoTime(); - return preCreateInterest.thenCompose(ignored -> connection.rpc.invoke("session.create", request, - CreateSessionResponse.class)) + return preCreateInterest + .thenCompose( + ignored -> connection.rpc.invoke("session.create", request, CreateSessionResponse.class)) .thenCompose(response -> { String returnedId = response.sessionId(); LoggingHelpers.logTiming(LOG, Level.FINE, @@ -583,11 +583,12 @@ public CompletableFuture createSession(SessionConfig config) { ? preRegisteredSessionHolder[0] : initializeSession.apply(returnedId); registeredIdHolder[0] = returnedId; - // Local IDs registered before create; server-assigned IDs can only register now. + // Local IDs registered before create; server-assigned IDs can only register + // now. CompletableFuture interest = config.getOnMcpAuthRequest() != null && preRegisteredSessionHolder[0] == null - ? session.getRpc().eventLog.registerInterest( - new SessionEventLogRegisterInterestParams(returnedId, + ? session.getRpc().eventLog + .registerInterest(new SessionEventLogRegisterInterestParams(returnedId, "mcp.oauth_required")) : CompletableFuture.completedFuture(null); session.setWorkspacePath(response.workspacePath()); @@ -714,7 +715,9 @@ public CompletableFuture resumeSession(String sessionId, ResumeS } long rpcNanos = System.nanoTime(); - return interest.thenCompose(ignored -> connection.rpc.invoke("session.resume", request, ResumeSessionResponse.class)) + return interest + .thenCompose( + ignored -> connection.rpc.invoke("session.resume", request, ResumeSessionResponse.class)) .thenCompose(response -> { LoggingHelpers.logTiming(LOG, Level.FINE, "CopilotClient.resumeSession session resume request completed. Elapsed={Elapsed}, SessionId=" diff --git a/java/src/main/java/com/github/copilot/CopilotSession.java b/java/src/main/java/com/github/copilot/CopilotSession.java index 9a3807df9..f1e25d200 100644 --- a/java/src/main/java/com/github/copilot/CopilotSession.java +++ b/java/src/main/java/com/github/copilot/CopilotSession.java @@ -856,7 +856,8 @@ private void handleBroadcastEventAsync(SessionEvent event) { return; } executeMcpAuthAndRespondAsync(new McpAuthRequest(sessionId, data.requestId(), data.serverName(), - data.serverUrl(), data.wwwAuthenticateParams(), data.resourceMetadata(), data.staticClientConfig()), handler); + data.serverUrl(), data.wwwAuthenticateParams(), data.resourceMetadata(), data.staticClientConfig()), + handler); } else if (event instanceof CommandExecuteEvent cmdEvent) { var data = cmdEvent.getData(); if (data == null || data.requestId() == null || data.commandName() == null) { @@ -1053,7 +1054,7 @@ private void executeMcpAuthAndRespondAsync(McpAuthRequest request, McpAuthHandle private void sendMcpAuthResponse(String requestId, McpAuthResult result) { try { Object response; - if (result == null || result.cancelled() || result.token() == null) { + if (result == null || result.isCancelled() || result.token() == null) { response = Map.of("kind", "cancelled"); } else { var token = result.token(); @@ -1071,8 +1072,8 @@ private void sendMcpAuthResponse(String requestId, McpAuthResult result) { } response = tokenResponse; } - getRpc().mcp.oauth - .handlePendingRequest(new SessionMcpOauthHandlePendingRequestParams(sessionId, requestId, response)); + getRpc().mcp.oauth.handlePendingRequest( + new SessionMcpOauthHandlePendingRequestParams(sessionId, requestId, response)); } catch (Exception e) { LOG.log(Level.WARNING, "Error sending MCP auth response for requestId=" + requestId, e); } diff --git a/java/src/main/java/com/github/copilot/rpc/McpAuthRequest.java b/java/src/main/java/com/github/copilot/rpc/McpAuthRequest.java index a836eaa40..b1d28e2d4 100644 --- a/java/src/main/java/com/github/copilot/rpc/McpAuthRequest.java +++ b/java/src/main/java/com/github/copilot/rpc/McpAuthRequest.java @@ -12,12 +12,7 @@ * * @since 1.0.0 */ -public record McpAuthRequest( - String sessionId, - String requestId, - String serverName, - String serverUrl, - McpOauthRequiredWwwAuthenticateParams wwwAuthenticateParams, - String resourceMetadata, +public record McpAuthRequest(String sessionId, String requestId, String serverName, String serverUrl, + McpOauthRequiredWwwAuthenticateParams wwwAuthenticateParams, String resourceMetadata, McpOauthRequiredStaticClientConfig staticClientConfig) { } diff --git a/java/src/main/java/com/github/copilot/rpc/McpAuthResult.java b/java/src/main/java/com/github/copilot/rpc/McpAuthResult.java index b8a0acfc4..6b7fda34f 100644 --- a/java/src/main/java/com/github/copilot/rpc/McpAuthResult.java +++ b/java/src/main/java/com/github/copilot/rpc/McpAuthResult.java @@ -9,7 +9,7 @@ * * @since 1.0.0 */ -public record McpAuthResult(boolean cancelled, McpAuthToken token) { +public record McpAuthResult(boolean isCancelled, McpAuthToken token) { /** * Creates a token result. * diff --git a/java/src/main/java/com/github/copilot/rpc/SessionConfig.java b/java/src/main/java/com/github/copilot/rpc/SessionConfig.java index 66bd2cbec..8d4f74dfd 100644 --- a/java/src/main/java/com/github/copilot/rpc/SessionConfig.java +++ b/java/src/main/java/com/github/copilot/rpc/SessionConfig.java @@ -606,8 +606,8 @@ public McpAuthHandler getOnMcpAuthRequest() { /** * Sets the MCP OAuth request handler. *

- * When provided, the SDK can satisfy MCP server OAuth requests with host-provided - * token data or cancellation. + * When provided, the SDK can satisfy MCP server OAuth requests with + * host-provided token data or cancellation. * * @param onMcpAuthRequest * the handler diff --git a/java/src/test/java/com/github/copilot/McpAuthInterestRegistrationTest.java b/java/src/test/java/com/github/copilot/McpAuthInterestRegistrationTest.java index 30ae0b970..79f7ffa85 100644 --- a/java/src/test/java/com/github/copilot/McpAuthInterestRegistrationTest.java +++ b/java/src/test/java/com/github/copilot/McpAuthInterestRegistrationTest.java @@ -33,52 +33,39 @@ class McpAuthInterestRegistrationTest { @Test void mcpOauthRequiredEventExposesOptionalResourceMetadata() throws Exception { - var event = MAPPER.readValue(""" + var data = MAPPER.readValue(""" { - "id": "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", - "timestamp": "2026-03-15T21:26:54.987Z", - "parentId": null, - "type": "mcp.oauth_required", - "data": { - "requestId": "oauth-request", - "serverName": "oauth-server", - "serverUrl": "https://example.com/mcp", - "wwwAuthenticateParams": { - "resourceMetadataUrl": "https://example.com/.well-known/oauth-protected-resource" - }, - "resourceMetadata": "{\\"resource\\":\\"https://example.com/mcp\\"}" - } + "requestId": "oauth-request", + "serverName": "oauth-server", + "serverUrl": "https://example.com/mcp", + "wwwAuthenticateParams": { + "resourceMetadataUrl": "https://example.com/.well-known/oauth-protected-resource" + }, + "resourceMetadata": "{\\"resource\\":\\"https://example.com/mcp\\"}" } - """, McpOauthRequiredEvent.class); + """, McpOauthRequiredEvent.McpOauthRequiredEventData.class); - assertEquals("{\"resource\":\"https://example.com/mcp\"}", event.getData().resourceMetadata()); - assertNotNull(event.getData().wwwAuthenticateParams()); + assertEquals("{\"resource\":\"https://example.com/mcp\"}", data.resourceMetadata()); + assertNotNull(data.wwwAuthenticateParams()); var withoutMetadata = MAPPER.readValue(""" { - "id": "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", - "timestamp": "2026-03-15T21:26:54.987Z", - "parentId": null, - "type": "mcp.oauth_required", - "data": { - "requestId": "oauth-request", - "serverName": "oauth-server", - "serverUrl": "https://example.com/mcp" - } + "requestId": "oauth-request", + "serverName": "oauth-server", + "serverUrl": "https://example.com/mcp" } - """, McpOauthRequiredEvent.class); + """, McpOauthRequiredEvent.McpOauthRequiredEventData.class); - assertNull(withoutMetadata.getData().resourceMetadata()); - assertNull(withoutMetadata.getData().wwwAuthenticateParams()); + assertNull(withoutMetadata.resourceMetadata()); + assertNull(withoutMetadata.wwwAuthenticateParams()); } @Test void createSessionRegistersMcpAuthInterestOnlyWhenHandlerConfigured() throws Exception { try (var server = new RecordingRuntime(); var client = new CopilotClient(new CopilotClientOptions().setCliUrl(server.url()))) { - try (var session = client.createSession(new SessionConfig() - .setOnPermissionRequest(PermissionHandler.APPROVE_ALL) - .setOnEvent(event -> { + try (var session = client.createSession( + new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL).setOnEvent(event -> { })).get()) { assertNotNull(session); } @@ -89,10 +76,11 @@ void createSessionRegistersMcpAuthInterestOnlyWhenHandlerConfigured() throws Exc server.clearRequests(); - try (var session = client.createSession(new SessionConfig() - .setOnPermissionRequest(PermissionHandler.APPROVE_ALL) - .setOnMcpAuthRequest(request -> java.util.concurrent.CompletableFuture - .completedFuture(McpAuthResult.cancelled()))).get()) { + try (var session = client + .createSession(new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL) + .setOnMcpAuthRequest(request -> java.util.concurrent.CompletableFuture + .completedFuture(McpAuthResult.cancelled()))) + .get()) { assertNotNull(session); } @@ -107,23 +95,24 @@ void createSessionRegistersMcpAuthInterestOnlyWhenHandlerConfigured() throws Exc void cloudCreateSessionRegistersMcpAuthInterestAfterCreateOnlyWhenHandlerConfigured() throws Exception { try (var server = new RecordingRuntime(); var client = new CopilotClient(new CopilotClientOptions().setCliUrl(server.url()))) { - var cloud = new CloudSessionOptions() - .setRepository(new CloudSessionRepository().setOwner("github").setName("copilot-sdk").setBranch("main")); + var cloud = new CloudSessionOptions().setRepository( + new CloudSessionRepository().setOwner("github").setName("copilot-sdk").setBranch("main")); - try (var session = client.createSession(new SessionConfig() - .setOnPermissionRequest(PermissionHandler.APPROVE_ALL) - .setCloud(cloud)).get()) { + try (var session = client + .createSession( + new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL).setCloud(cloud)) + .get()) { assertNotNull(session); } assertNoMcpAuthInterest(server.requests()); server.clearRequests(); - try (var session = client.createSession(new SessionConfig() - .setOnPermissionRequest(PermissionHandler.APPROVE_ALL) - .setCloud(cloud) - .setOnMcpAuthRequest(request -> java.util.concurrent.CompletableFuture - .completedFuture(McpAuthResult.cancelled()))).get()) { + try (var session = client + .createSession(new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL) + .setCloud(cloud).setOnMcpAuthRequest(request -> java.util.concurrent.CompletableFuture + .completedFuture(McpAuthResult.cancelled()))) + .get()) { assertNotNull(session); } @@ -139,8 +128,7 @@ void resumeSessionRegistersMcpAuthInterestOnlyWhenHandlerConfigured() throws Exc try (var server = new RecordingRuntime(); var client = new CopilotClient(new CopilotClientOptions().setCliUrl(server.url()))) { try (var session = client.resumeSession("session-without-auth", new ResumeSessionConfig() - .setOnPermissionRequest(PermissionHandler.APPROVE_ALL) - .setOnEvent(event -> { + .setOnPermissionRequest(PermissionHandler.APPROVE_ALL).setOnEvent(event -> { })).get()) { assertNotNull(session); } @@ -151,10 +139,11 @@ void resumeSessionRegistersMcpAuthInterestOnlyWhenHandlerConfigured() throws Exc server.clearRequests(); - try (var session = client.resumeSession("session-with-auth", new ResumeSessionConfig() - .setOnPermissionRequest(PermissionHandler.APPROVE_ALL) - .setOnMcpAuthRequest(request -> java.util.concurrent.CompletableFuture - .completedFuture(McpAuthResult.cancelled()))).get()) { + try (var session = client.resumeSession("session-with-auth", + new ResumeSessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL) + .setOnMcpAuthRequest(request -> java.util.concurrent.CompletableFuture + .completedFuture(McpAuthResult.cancelled()))) + .get()) { assertNotNull(session); } diff --git a/python/copilot/generated/rpc.py b/python/copilot/generated/rpc.py index 58b0ef564..ed03a14ca 100644 --- a/python/copilot/generated/rpc.py +++ b/python/copilot/generated/rpc.py @@ -15,7 +15,7 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Any, Protocol, TypeVar, cast +from typing import Any, Protocol, TypeVar, Union, cast from uuid import UUID import dateutil.parser @@ -2766,7 +2766,7 @@ def from_dict(obj: Any) -> 'MCPOauthHandlePendingRequest': def to_dict(self) -> dict: result: dict = {} result["requestId"] = from_str(self.request_id) - result["result"] = to_class(cast(Any, self.result), self.result) + result["result"] = self.result.to_dict() return result diff --git a/python/copilot/generated/session_events.py b/python/copilot/generated/session_events.py index 3f8ee194f..e971f6add 100644 --- a/python/copilot/generated/session_events.py +++ b/python/copilot/generated/session_events.py @@ -7373,6 +7373,7 @@ def session_event_from_dict(s: Any) -> SessionEvent: def session_event_to_dict(x: SessionEvent) -> Any: return x.to_dict() +McpServersLoadedServer = MCPServersLoadedServer __all__ = [ "AbortData", diff --git a/python/test_client.py b/python/test_client.py index 2f851e0b4..edd2db188 100644 --- a/python/test_client.py +++ b/python/test_client.py @@ -272,7 +272,7 @@ async def mock_request(method, params, **kwargs): client._client.request = mock_request observed_request = None - def handle_mcp_auth_request(request): + def handle_mcp_auth_request(request, invocation): nonlocal observed_request observed_request = request return {