diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index 5a5d34dcd..e19f2a9a1 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -85,6 +85,13 @@ public sealed partial class CopilotClient : IDisposable, IAsyncDisposable private List? _modelsCache; private ServerRpc? _serverRpc; + /// + /// Client-global RPC handlers (e.g. the LLM inference provider adapter), + /// built once at construction when the corresponding option is configured and + /// registered on every connection. Null when no client-global API is enabled. + /// + private readonly ClientGlobalApiHandlers? _clientGlobalApis; + private sealed record LifecycleSubscription(Type EventType, Action Handler); /// @@ -165,6 +172,8 @@ public CopilotClient(CopilotClientOptions? options = null) _logger = _options.Logger ?? NullLogger.Instance; _onListModels = _options.OnListModels; + _clientGlobalApis = BuildClientGlobalApis(); + // Empty mode: validate at construction time that the app supplied a // per-session persistence location. The runtime is mode-agnostic, so // without this check it would silently fall back to ~/.copilot, which @@ -276,6 +285,8 @@ async Task StartCoreAsync(CancellationToken ct) sessionFsTimestamp); } + await ConfigureLlmInferenceAsync(ct); + LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null, "CopilotClient.StartAsync complete. Elapsed={Elapsed}", startTimestamp); @@ -1678,6 +1689,39 @@ await Rpc.SessionFs.SetProviderAsync( cancellationToken: cancellationToken); } + /// + /// Builds the client-global RPC handler bag at construction time. Currently + /// only the LLM inference provider adapter is registered; returns null when no + /// client-global API is configured so the registration is skipped entirely. + /// + private ClientGlobalApiHandlers? BuildClientGlobalApis() + { + var handler = _options.LlmInference?.Handler; + if (handler is null) + { + return null; + } + + return new ClientGlobalApiHandlers + { + LlmInference = new LlmInferenceAdapter(handler, () => _serverRpc), + }; + } + + /// + /// Tells the runtime to route its outbound model-layer requests through this + /// client's LLM inference provider. No-op when interception is not configured. + /// + private async Task ConfigureLlmInferenceAsync(CancellationToken cancellationToken) + { + if (_clientGlobalApis?.LlmInference is null) + { + return; + } + + await Rpc.LlmInference.SetProviderAsync(cancellationToken); + } + private void ConfigureSessionFsHandlers(CopilotSession session, Func? createSessionFsHandler) { if (_options.SessionFs is null) @@ -2072,6 +2116,10 @@ private async Task ConnectToServerAsync(Process? cliProcess, string? var session = GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); return session.ClientSessionApis; }); + if (_clientGlobalApis is not null) + { + ClientGlobalApiRegistration.RegisterClientGlobalApiHandlers(rpc, _clientGlobalApis); + } rpc.StartListening(); LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null, "CopilotClient.ConnectToServerAsync transport setup complete. Elapsed={Elapsed}", diff --git a/dotnet/src/Generated/Rpc.cs b/dotnet/src/Generated/Rpc.cs index d11b8efc7..ff34be16f 100644 --- a/dotnet/src/Generated/Rpc.cs +++ b/dotnet/src/Generated/Rpc.cs @@ -69,10 +69,14 @@ internal sealed class ConnectRequest /// Long context tier pricing (available for models with extended context windows). public sealed class ModelBillingTokenPricesLongContext { - /// AI Credits cost per billing batch of cached tokens. + /// AI Credits cost per billing batch of cache-read tokens. [JsonPropertyName("cachePrice")] public double? CachePrice { get; set; } + /// AI Credits cost per billing batch of cache-write (cache creation) tokens. + [JsonPropertyName("cacheWritePrice")] + public double? CacheWritePrice { get; set; } + /// Prompt token budget (max_prompt_tokens) for the long context tier. The total context window is this value plus the model's max_output_tokens. [JsonPropertyName("contextMax")] public long? ContextMax { get; set; } @@ -93,10 +97,14 @@ public sealed class ModelBillingTokenPrices [JsonPropertyName("batchSize")] public long? BatchSize { get; set; } - /// AI Credits cost per billing batch of cached tokens. + /// AI Credits cost per billing batch of cache-read tokens. [JsonPropertyName("cachePrice")] public double? CachePrice { get; set; } + /// AI Credits cost per billing batch of cache-write (cache creation) tokens. + [JsonPropertyName("cacheWritePrice")] + public double? CacheWritePrice { get; set; } + /// Prompt token budget (max_prompt_tokens) for the default tier. The total context window is this value plus the model's max_output_tokens. [JsonPropertyName("contextMax")] public long? ContextMax { get; set; } @@ -1141,6 +1149,92 @@ internal sealed class SessionFsSetProviderRequest public string SessionStatePath { get; set; } = string.Empty; } +/// Indicates whether the calling client was registered as the LLM inference provider. +[Experimental(Diagnostics.Experimental)] +public sealed class LlmInferenceSetProviderResult +{ + /// Whether the provider was set successfully. + [JsonPropertyName("success")] + public bool Success { get; set; } +} + +/// Whether the start frame was accepted. +[Experimental(Diagnostics.Experimental)] +public sealed class LlmInferenceHttpResponseStartResult +{ + /// True when the response start was matched to a pending request; false when unknown. + [JsonPropertyName("accepted")] + public bool Accepted { get; set; } +} + +/// Response head. +[Experimental(Diagnostics.Experimental)] +internal sealed class LlmInferenceHttpResponseStartRequest +{ + /// Gets or sets the headers value. + [JsonPropertyName("headers")] + public IDictionary> Headers { get => field ??= new Dictionary>(); set; } + + /// Matches the requestId from the originating httpRequestStart frame. + [JsonPropertyName("requestId")] + public string RequestId { get; set; } = string.Empty; + + /// HTTP status code. + [JsonPropertyName("status")] + public long Status { get; set; } + + /// Optional HTTP status reason phrase. + [JsonPropertyName("statusText")] + public string? StatusText { get; set; } +} + +/// Whether the chunk was accepted. +[Experimental(Diagnostics.Experimental)] +public sealed class LlmInferenceHttpResponseChunkResult +{ + /// True when the chunk was matched to a pending request; false when unknown. + [JsonPropertyName("accepted")] + public bool Accepted { get; set; } +} + +/// Set to terminate the response with a transport-level failure. Implies end-of-stream; any further chunks for this requestId are ignored. +[Experimental(Diagnostics.Experimental)] +public sealed class LlmInferenceHttpResponseChunkError +{ + /// Optional machine-readable error code. + [JsonPropertyName("code")] + public string? Code { get; set; } + + /// Human-readable failure description. + [JsonPropertyName("message")] + public string Message { get; set; } = string.Empty; +} + +/// A response body chunk or terminal error. +[Experimental(Diagnostics.Experimental)] +internal sealed class LlmInferenceHttpResponseChunkRequest +{ + /// When true, `data` is base64-encoded bytes. When absent or false, `data` is UTF-8 text. + [JsonPropertyName("binary")] + public bool? Binary { get; set; } + + /// Body byte range. UTF-8 text when `binary` is absent or false; base64-encoded bytes when `binary` is true. May be empty (e.g. when the response body is empty: send a single chunk with empty data and end=true). + [JsonPropertyName("data")] + public string Data { get; set; } = string.Empty; + + /// When true, this is the final body chunk for the response. The runtime treats the response body as complete after receiving an end-marked chunk. + [JsonPropertyName("end")] + public bool? End { get; set; } + + /// Set to terminate the response with a transport-level failure. Implies end-of-stream; any further chunks for this requestId are ignored. + [JsonPropertyName("error")] + public LlmInferenceHttpResponseChunkError? Error { get; set; } + + /// Matches the requestId from the originating httpRequestStart frame. + [JsonPropertyName("requestId")] + public string RequestId { get; set; } = string.Empty; +} + /// Pre-resolved working-directory context for session startup. [Experimental(Diagnostics.Experimental)] public sealed class SessionContext @@ -10384,6 +10478,76 @@ public sealed class CanvasProviderInvokeActionRequest public string SessionId { get; set; } = string.Empty; } +/// Acknowledgement. Returning successfully simply means the SDK accepted the start frame; it does not imply the request will succeed. +[Experimental(Diagnostics.Experimental)] +public sealed class LlmInferenceHttpRequestStartResult +{ +} + +/// The head of an outbound model-layer HTTP request. +[Experimental(Diagnostics.Experimental)] +public sealed class LlmInferenceHttpRequestStartRequest +{ + /// Gets or sets the headers value. + [JsonPropertyName("headers")] + public IDictionary> Headers { get => field ??= new Dictionary>(); set; } + + /// HTTP method, e.g. GET, POST. + [JsonPropertyName("method")] + public string Method { get; set; } = string.Empty; + + /// Opaque runtime-minted id, unique per in-flight request. The SDK uses this to correlate httpRequestChunk frames and to address its httpResponseStart / httpResponseChunk replies back to the runtime. + [JsonPropertyName("requestId")] + public string RequestId { get; set; } = string.Empty; + + /// Id of the runtime session that triggered this request, when one is in scope. Absent for requests issued outside any session (e.g. startup model-catalog or capability resolution). This is a payload field — not a dispatch key — because the client-global API is registered process-wide rather than per session. + [JsonPropertyName("sessionId")] + public string? SessionId { get; set; } + + /// Transport the runtime would otherwise use for this request. `http` (the default when absent) covers plain HTTP and SSE responses; `websocket` indicates a full-duplex message channel where each body chunk maps to one WebSocket message and the `binary` flag distinguishes text from binary frames. The SDK consumer uses this to decide whether to service the request with an HTTP client or a WebSocket client. It is the one piece of request metadata the consumer cannot reliably infer from the URL or headers alone. + [JsonPropertyName("transport")] + public LlmInferenceHttpRequestStartTransport? Transport { get; set; } + + /// Absolute request URL. + [JsonPropertyName("url")] + public string Url { get; set; } = string.Empty; +} + +/// Acknowledgement. The SDK is free to ignore the ack and treat chunk delivery as fire-and-forget. +[Experimental(Diagnostics.Experimental)] +public sealed class LlmInferenceHttpRequestChunkResult +{ +} + +/// A request body chunk or cancellation signal. +[Experimental(Diagnostics.Experimental)] +public sealed class LlmInferenceHttpRequestChunkRequest +{ + /// When true, `data` is base64-encoded bytes. When absent or false, `data` is UTF-8 text. + [JsonPropertyName("binary")] + public bool? Binary { get; set; } + + /// When true, the runtime is cancelling the in-flight request (e.g. upstream consumer aborted). `data` is ignored. Implies end-of-request. + [JsonPropertyName("cancel")] + public bool? Cancel { get; set; } + + /// Optional human-readable reason for the cancellation, propagated for logging. + [JsonPropertyName("cancelReason")] + public string? CancelReason { get; set; } + + /// Body byte range. UTF-8 text when `binary` is absent or false; base64-encoded bytes when `binary` is true. May be empty. + [JsonPropertyName("data")] + public string Data { get; set; } = string.Empty; + + /// When true, this is the final body chunk for the request. The SDK may rely on having received an end-marked chunk before treating the request body as complete. + [JsonPropertyName("end")] + public bool? End { get; set; } + + /// Matches the requestId from the originating httpRequestStart frame. + [JsonPropertyName("requestId")] + public string RequestId { get; set; } = string.Empty; +} + /// Model capability category for grouping in the model picker. [JsonConverter(typeof(Converter))] [DebuggerDisplay("{Value,nq}")] @@ -15947,6 +16111,69 @@ public override void Write(Utf8JsonWriter writer, SessionFsSqliteQueryType value } +/// Transport the runtime would otherwise use for this request. `http` (the default when absent) covers plain HTTP and SSE responses; `websocket` indicates a full-duplex message channel where each body chunk maps to one WebSocket message and the `binary` flag distinguishes text from binary frames. The SDK consumer uses this to decide whether to service the request with an HTTP client or a WebSocket client. It is the one piece of request metadata the consumer cannot reliably infer from the URL or headers alone. +[Experimental(Diagnostics.Experimental)] +[JsonConverter(typeof(Converter))] +[DebuggerDisplay("{Value,nq}")] +public readonly struct LlmInferenceHttpRequestStartTransport : IEquatable +{ + private readonly string? _value; + + /// Initializes a new instance of the struct. + /// The value to associate with this . + [JsonConstructor] + public LlmInferenceHttpRequestStartTransport(string value) + { + ArgumentException.ThrowIfNullOrWhiteSpace(value); + _value = value; + } + + /// Gets the value associated with this . + public string Value => _value ?? string.Empty; + + /// Plain HTTP or SSE response. Each body chunk is an opaque byte range; the response is a status line, headers, and a (possibly streamed) body. + public static LlmInferenceHttpRequestStartTransport Http { get; } = new("http"); + + /// Full-duplex WebSocket channel. Each body chunk maps to exactly one WebSocket message and the `binary` flag distinguishes text from binary frames; request and response chunks flow concurrently. + public static LlmInferenceHttpRequestStartTransport Websocket { get; } = new("websocket"); + + /// Returns a value indicating whether two instances are equivalent. + public static bool operator ==(LlmInferenceHttpRequestStartTransport left, LlmInferenceHttpRequestStartTransport right) => left.Equals(right); + + /// Returns a value indicating whether two instances are not equivalent. + public static bool operator !=(LlmInferenceHttpRequestStartTransport left, LlmInferenceHttpRequestStartTransport right) => !(left == right); + + /// + public override bool Equals(object? obj) => obj is LlmInferenceHttpRequestStartTransport other && Equals(other); + + /// + public bool Equals(LlmInferenceHttpRequestStartTransport other) => string.Equals(Value, other.Value, StringComparison.OrdinalIgnoreCase); + + /// + public override int GetHashCode() => StringComparer.OrdinalIgnoreCase.GetHashCode(Value); + + /// + public override string ToString() => Value; + + /// Provides a for serializing instances. + [EditorBrowsable(EditorBrowsableState.Never)] + public sealed class Converter : JsonConverter + { + /// + public override LlmInferenceHttpRequestStartTransport Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + return new(GeneratedStringEnumJson.ReadValue(ref reader, typeToConvert)); + } + + /// + public override void Write(Utf8JsonWriter writer, LlmInferenceHttpRequestStartTransport value, JsonSerializerOptions options) + { + GeneratedStringEnumJson.WriteValue(writer, value.Value, typeof(LlmInferenceHttpRequestStartTransport)); + } + } +} + + /// Provides server-scoped RPC methods (no session required). public sealed class ServerRpc { @@ -16049,6 +16276,12 @@ internal async Task ConnectAsync(string? token = null, Cancellati Interlocked.CompareExchange(ref field, new(_rpc), null) ?? field; + /// LlmInference APIs. + public ServerLlmInferenceApi LlmInference => + field ?? + Interlocked.CompareExchange(ref field, new(_rpc), null) ?? + field; + /// Sessions APIs. public ServerSessionsApi Sessions => field ?? @@ -16630,6 +16863,59 @@ public async Task SetProviderAsync(string initialCwd } } +/// Provides server-scoped LlmInference APIs. +[Experimental(Diagnostics.Experimental)] +public sealed class ServerLlmInferenceApi +{ + private readonly JsonRpc _rpc; + + internal ServerLlmInferenceApi(JsonRpc rpc) + { + _rpc = rpc; + } + + /// Registers an SDK client as the LLM inference callback provider. + /// The to monitor for cancellation requests. The default is . + /// Indicates whether the calling client was registered as the LLM inference provider. + public async Task SetProviderAsync(CancellationToken cancellationToken = default) + { + return await CopilotClient.InvokeRpcAsync(_rpc, "llmInference.setProvider", [], cancellationToken); + } + + /// Delivers the response head (status + headers) for an in-flight request, correlated by the requestId the runtime supplied in httpRequestStart. Must be called exactly once per request before any httpResponseChunk frames. + /// Matches the requestId from the originating httpRequestStart frame. + /// HTTP status code. + /// The headers parameter. + /// Optional HTTP status reason phrase. + /// The to monitor for cancellation requests. The default is . + /// Whether the start frame was accepted. + public async Task HttpResponseStartAsync(string requestId, long status, IDictionary> headers, string? statusText = null, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(requestId); + ArgumentNullException.ThrowIfNull(headers); + + var request = new LlmInferenceHttpResponseStartRequest { RequestId = requestId, Status = status, Headers = headers, StatusText = statusText }; + return await CopilotClient.InvokeRpcAsync(_rpc, "llmInference.httpResponseStart", [request], cancellationToken); + } + + /// Delivers a body byte range (or a terminal transport error) for an in-flight response, correlated by requestId. Set `end` true on the last chunk. When `error` is set the response terminates with a transport-level failure and the runtime raises an APIConnectionError. + /// Matches the requestId from the originating httpRequestStart frame. + /// Body byte range. UTF-8 text when `binary` is absent or false; base64-encoded bytes when `binary` is true. May be empty (e.g. when the response body is empty: send a single chunk with empty data and end=true). + /// When true, `data` is base64-encoded bytes. When absent or false, `data` is UTF-8 text. + /// When true, this is the final body chunk for the response. The runtime treats the response body as complete after receiving an end-marked chunk. + /// Set to terminate the response with a transport-level failure. Implies end-of-stream; any further chunks for this requestId are ignored. + /// The to monitor for cancellation requests. The default is . + /// Whether the chunk was accepted. + public async Task HttpResponseChunkAsync(string requestId, string data, bool? binary = null, bool? end = null, LlmInferenceHttpResponseChunkError? error = null, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(requestId); + ArgumentNullException.ThrowIfNull(data); + + var request = new LlmInferenceHttpResponseChunkRequest { RequestId = requestId, Data = data, Binary = binary, End = end, Error = error }; + return await CopilotClient.InvokeRpcAsync(_rpc, "llmInference.httpResponseChunk", [request], cancellationToken); + } +} + /// Provides server-scoped Sessions APIs. [Experimental(Diagnostics.Experimental)] public sealed class ServerSessionsApi @@ -20043,6 +20329,53 @@ public static void RegisterClientSessionApiHandlers(JsonRpc rpc, FuncHandles `llmInference` client global API methods. +[Experimental(Diagnostics.Experimental)] +public interface ILlmInferenceHandler +{ + /// Announces an outbound model-layer HTTP request the runtime wants the SDK client to service. Carries the request head only; the body always follows as one or more httpRequestChunk frames keyed by the same requestId, even when the body is empty (a single chunk with end=true). + /// The head of an outbound model-layer HTTP request. + /// The to monitor for cancellation requests. The default is . + /// Acknowledgement. Returning successfully simply means the SDK accepted the start frame; it does not imply the request will succeed. + Task HttpRequestStartAsync(LlmInferenceHttpRequestStartRequest request, CancellationToken cancellationToken = default); + /// Delivers a body byte range (or a cancellation signal) for a request previously announced via httpRequestStart, correlated by requestId. The runtime fires at least one chunk per request — when there is no body, a single chunk with empty data and end=true. Mid-stream the runtime may send a chunk with cancel=true to abort the request; the SDK then stops issuing httpResponseChunk frames and may emit a terminal httpResponseChunk with error set. + /// A request body chunk or cancellation signal. + /// The to monitor for cancellation requests. The default is . + /// Acknowledgement. The SDK is free to ignore the ack and treat chunk delivery as fire-and-forget. + Task HttpRequestChunkAsync(LlmInferenceHttpRequestChunkRequest request, CancellationToken cancellationToken = default); +} + +/// Provides all client global API handler groups for a connection. +public sealed class ClientGlobalApiHandlers +{ + /// Optional handler for LlmInference client global API methods. + public ILlmInferenceHandler? LlmInference { get; set; } +} + +/// Registers client global API handlers on a JSON-RPC connection. +internal static class ClientGlobalApiRegistration +{ + /// + /// Registers handlers for server-to-client global API calls. + /// Unlike client session APIs, these methods carry no implicit + /// sessionId dispatch key — a single set of handlers serves the + /// entire connection. + /// + public static void RegisterClientGlobalApiHandlers(JsonRpc rpc, ClientGlobalApiHandlers handlers) + { + rpc.SetLocalRpcMethod("llmInference.httpRequestStart", (Func>)(async (request, cancellationToken) => + { + var handler = handlers.LlmInference ?? throw new InvalidOperationException("No llmInference client-global handler registered"); + return await handler.HttpRequestStartAsync(request, cancellationToken); + }), singleObjectParam: true); + rpc.SetLocalRpcMethod("llmInference.httpRequestChunk", (Func>)(async (request, cancellationToken) => + { + var handler = handlers.LlmInference ?? throw new InvalidOperationException("No llmInference client-global handler registered"); + return await handler.HttpRequestChunkAsync(request, cancellationToken); + }), singleObjectParam: true); + } +} + [JsonSourceGenerationOptions( JsonSerializerDefaults.Web, AllowOutOfOrderMetadataProperties = true, @@ -20166,6 +20499,7 @@ public static void RegisterClientSessionApiHandlers(JsonRpc rpc, FuncFor HTTP 400 failures only: whether the response carried a structured CAPI error envelope (structured_error, a deterministic validation failure) or no error body (bodyless, the transient gateway/proxy signature). Absent for non-400 failures. + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + [JsonPropertyName("badRequestKind")] + public ModelCallFailureBadRequestKind? BadRequestKind { get; set; } + /// Duration of the failed API call in milliseconds. [JsonConverter(typeof(MillisecondsTimeSpanConverter))] [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] [JsonPropertyName("durationMs")] public TimeSpan? Duration { get; set; } + /// For HTTP 400 failures only: the `code` from the CAPI error envelope (e.g. 'model_max_prompt_tokens_exceeded') identifying which deterministic validation failure occurred. Raw server-controlled string, emitted only through restricted telemetry. Absent for bodyless or non-400 failures. + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + [JsonPropertyName("errorCode")] + public string? ErrorCode { get; set; } + /// Raw provider/runtime error message for restricted telemetry. [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] [JsonPropertyName("errorMessage")] public string? ErrorMessage { get; set; } + /// For HTTP 400 failures only: the `type` from the CAPI error envelope (e.g. 'websocket_error'), a coarser companion to errorCode for envelopes that carry no code. Raw server-controlled string, emitted only through restricted telemetry. Absent for bodyless or non-400 failures. + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + [JsonPropertyName("errorType")] + public string? ErrorType { get; set; } + /// What initiated this API call (e.g., "sub-agent", "mcp-sampling"); absent for user-initiated calls. [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] [JsonPropertyName("initiator")] @@ -3827,6 +3842,16 @@ public sealed partial class AttachmentFile : Attachment [JsonIgnore] public override string Type => "file"; + /// Internal: content-addressed id of the session.binary_asset event holding this attachment's model-facing bytes (e.g. "sha256:..."). Absent externally. + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + [JsonPropertyName("assetId")] + public string? AssetId { get; set; } + + /// Internal: decoded byte length of the attachment's model-facing bytes. Absent externally. + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + [JsonPropertyName("byteLength")] + public long? ByteLength { get; set; } + /// User-facing display name for the attachment. [JsonPropertyName("displayName")] public required string DisplayName { get; set; } @@ -3836,6 +3861,16 @@ public sealed partial class AttachmentFile : Attachment [JsonPropertyName("lineRange")] public AttachmentFileLineRange? LineRange { get; set; } + /// Internal: MIME type of the file's model-facing bytes (post-resize for images). Set when the file's bytes are interned to an asset. Absent externally. + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + [JsonPropertyName("mimeType")] + public string? MimeType { get; set; } + + /// Internal: why model-facing bytes are absent from persistence. Absent externally. + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + [JsonPropertyName("omittedReason")] + public OmittedBinaryOmittedReason? OmittedReason { get; set; } + /// Absolute file path. [JsonPropertyName("path")] public required string Path { get; set; } @@ -3959,10 +3994,21 @@ public sealed partial class AttachmentBlob : Attachment [JsonIgnore] public override string Type => "blob"; - /// Base64-encoded content. + /// Internal: content-addressed id of the session.binary_asset event holding this attachment's model-facing bytes (e.g. "sha256:..."). Absent externally. + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + [JsonPropertyName("assetId")] + public string? AssetId { get; set; } + + /// Internal: decoded byte length of the attachment's model-facing bytes. Absent externally. + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + [JsonPropertyName("byteLength")] + public long? ByteLength { get; set; } + + /// Base64-encoded content. Present on input and for external consumers; replaced by an internal `assetId` reference in persisted events when interned to a content-addressed asset. [Base64String] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] [JsonPropertyName("data")] - public required string Data { get; set; } + public string? Data { get; set; } /// User-facing display name for the attachment. [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] @@ -3972,6 +4018,11 @@ public sealed partial class AttachmentBlob : Attachment /// MIME type of the inline data. [JsonPropertyName("mimeType")] public required string MimeType { get; set; } + + /// Internal: why model-facing bytes are absent from persistence. Absent externally. + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + [JsonPropertyName("omittedReason")] + public OmittedBinaryOmittedReason? OmittedReason { get; set; } } /// Structured context contributed by an extension. Composer pills displayed in the host are forwarded back through session.send.attachments, then rendered into the model prompt as an <extension_context> XML block. @@ -7086,6 +7137,67 @@ public override void Write(Utf8JsonWriter writer, UserMessageAgentMode value, Js } } +/// Why the binary data is absent: it exceeded the inline size limit, or its asset was unavailable. +[JsonConverter(typeof(Converter))] +[DebuggerDisplay("{Value,nq}")] +public readonly struct OmittedBinaryOmittedReason : IEquatable +{ + private readonly string? _value; + + /// Initializes a new instance of the struct. + /// The value to associate with this . + [JsonConstructor] + public OmittedBinaryOmittedReason(string value) + { + ArgumentException.ThrowIfNullOrWhiteSpace(value); + _value = value; + } + + /// Gets the value associated with this . + public string Value => _value ?? string.Empty; + + /// Bytes exceeded the session's inline size limit. + public static OmittedBinaryOmittedReason TooLarge { get; } = new("too_large"); + + /// The referenced binary asset could not be found (e.g. a truncated log). + public static OmittedBinaryOmittedReason AssetUnavailable { get; } = new("asset_unavailable"); + + /// Returns a value indicating whether two instances are equivalent. + public static bool operator ==(OmittedBinaryOmittedReason left, OmittedBinaryOmittedReason right) => left.Equals(right); + + /// Returns a value indicating whether two instances are not equivalent. + public static bool operator !=(OmittedBinaryOmittedReason left, OmittedBinaryOmittedReason right) => !(left == right); + + /// + public override bool Equals(object? obj) => obj is OmittedBinaryOmittedReason other && Equals(other); + + /// + public bool Equals(OmittedBinaryOmittedReason other) => string.Equals(Value, other.Value, StringComparison.OrdinalIgnoreCase); + + /// + public override int GetHashCode() => StringComparer.OrdinalIgnoreCase.GetHashCode(Value); + + /// + public override string ToString() => Value; + + /// Provides a for serializing instances. + [EditorBrowsable(EditorBrowsableState.Never)] + public sealed class Converter : JsonConverter + { + /// + public override OmittedBinaryOmittedReason Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + return new(GeneratedStringEnumJson.ReadValue(ref reader, typeToConvert)); + } + + /// + public override void Write(Utf8JsonWriter writer, OmittedBinaryOmittedReason value, JsonSerializerOptions options) + { + GeneratedStringEnumJson.WriteValue(writer, value.Value, typeof(OmittedBinaryOmittedReason)); + } + } +} + /// Type of GitHub reference. [JsonConverter(typeof(Converter))] [DebuggerDisplay("{Value,nq}")] @@ -7278,6 +7390,67 @@ public override void Write(Utf8JsonWriter writer, AssistantUsageApiEndpoint valu } } +/// For HTTP 400 failures only: whether the response carried a structured CAPI error envelope (structured_error, a deterministic validation failure) or no error body (bodyless, the transient gateway/proxy signature). Absent for non-400 failures. +[JsonConverter(typeof(Converter))] +[DebuggerDisplay("{Value,nq}")] +public readonly struct ModelCallFailureBadRequestKind : IEquatable +{ + private readonly string? _value; + + /// Initializes a new instance of the struct. + /// The value to associate with this . + [JsonConstructor] + public ModelCallFailureBadRequestKind(string value) + { + ArgumentException.ThrowIfNullOrWhiteSpace(value); + _value = value; + } + + /// Gets the value associated with this . + public string Value => _value ?? string.Empty; + + /// The 400 response carried no error body (transient gateway/proxy signature). + public static ModelCallFailureBadRequestKind Bodyless { get; } = new("bodyless"); + + /// The 400 response carried a structured CAPI error envelope (deterministic validation failure). + public static ModelCallFailureBadRequestKind StructuredError { get; } = new("structured_error"); + + /// Returns a value indicating whether two instances are equivalent. + public static bool operator ==(ModelCallFailureBadRequestKind left, ModelCallFailureBadRequestKind right) => left.Equals(right); + + /// Returns a value indicating whether two instances are not equivalent. + public static bool operator !=(ModelCallFailureBadRequestKind left, ModelCallFailureBadRequestKind right) => !(left == right); + + /// + public override bool Equals(object? obj) => obj is ModelCallFailureBadRequestKind other && Equals(other); + + /// + public bool Equals(ModelCallFailureBadRequestKind other) => string.Equals(Value, other.Value, StringComparison.OrdinalIgnoreCase); + + /// + public override int GetHashCode() => StringComparer.OrdinalIgnoreCase.GetHashCode(Value); + + /// + public override string ToString() => Value; + + /// Provides a for serializing instances. + [EditorBrowsable(EditorBrowsableState.Never)] + public sealed class Converter : JsonConverter + { + /// + public override ModelCallFailureBadRequestKind Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + return new(GeneratedStringEnumJson.ReadValue(ref reader, typeToConvert)); + } + + /// + public override void Write(Utf8JsonWriter writer, ModelCallFailureBadRequestKind value, JsonSerializerOptions options) + { + GeneratedStringEnumJson.WriteValue(writer, value.Value, typeof(ModelCallFailureBadRequestKind)); + } + } +} + /// Where the failed model call originated. [JsonConverter(typeof(Converter))] [DebuggerDisplay("{Value,nq}")] @@ -7528,67 +7701,6 @@ public override void Write(Utf8JsonWriter writer, PersistedBinaryImageType value } } -/// Why the binary data is absent: it exceeded the inline size limit, or its asset was unavailable. -[JsonConverter(typeof(Converter))] -[DebuggerDisplay("{Value,nq}")] -public readonly struct OmittedBinaryOmittedReason : IEquatable -{ - private readonly string? _value; - - /// Initializes a new instance of the struct. - /// The value to associate with this . - [JsonConstructor] - public OmittedBinaryOmittedReason(string value) - { - ArgumentException.ThrowIfNullOrWhiteSpace(value); - _value = value; - } - - /// Gets the value associated with this . - public string Value => _value ?? string.Empty; - - /// Bytes exceeded the session's inline size limit. - public static OmittedBinaryOmittedReason TooLarge { get; } = new("too_large"); - - /// The referenced binary asset could not be found (e.g. a truncated log). - public static OmittedBinaryOmittedReason AssetUnavailable { get; } = new("asset_unavailable"); - - /// Returns a value indicating whether two instances are equivalent. - public static bool operator ==(OmittedBinaryOmittedReason left, OmittedBinaryOmittedReason right) => left.Equals(right); - - /// Returns a value indicating whether two instances are not equivalent. - public static bool operator !=(OmittedBinaryOmittedReason left, OmittedBinaryOmittedReason right) => !(left == right); - - /// - public override bool Equals(object? obj) => obj is OmittedBinaryOmittedReason other && Equals(other); - - /// - public bool Equals(OmittedBinaryOmittedReason other) => string.Equals(Value, other.Value, StringComparison.OrdinalIgnoreCase); - - /// - public override int GetHashCode() => StringComparer.OrdinalIgnoreCase.GetHashCode(Value); - - /// - public override string ToString() => Value; - - /// Provides a for serializing instances. - [EditorBrowsable(EditorBrowsableState.Never)] - public sealed class Converter : JsonConverter - { - /// - public override OmittedBinaryOmittedReason Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) - { - return new(GeneratedStringEnumJson.ReadValue(ref reader, typeToConvert)); - } - - /// - public override void Write(Utf8JsonWriter writer, OmittedBinaryOmittedReason value, JsonSerializerOptions options) - { - GeneratedStringEnumJson.WriteValue(writer, value.Value, typeof(OmittedBinaryOmittedReason)); - } - } -} - /// Binary result type discriminator. Use "image" for images and "resource" for other binary data. [JsonConverter(typeof(Converter))] [DebuggerDisplay("{Value,nq}")] diff --git a/dotnet/src/GitHub.Copilot.SDK.csproj b/dotnet/src/GitHub.Copilot.SDK.csproj index 7a9fa2bdc..f37982155 100644 --- a/dotnet/src/GitHub.Copilot.SDK.csproj +++ b/dotnet/src/GitHub.Copilot.SDK.csproj @@ -27,6 +27,10 @@ $(NoWarn);GHCP001 + + + + true diff --git a/dotnet/src/LlmInferenceProvider.cs b/dotnet/src/LlmInferenceProvider.cs new file mode 100644 index 000000000..73b121f17 --- /dev/null +++ b/dotnet/src/LlmInferenceProvider.cs @@ -0,0 +1,628 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using GitHub.Copilot.Rpc; +using System.Collections.Concurrent; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading.Channels; + +namespace GitHub.Copilot; + +/// +/// Transport the runtime would otherwise use to issue an intercepted +/// model-layer request. +/// +[Experimental(Diagnostics.Experimental)] +public enum LlmInferenceTransport +{ + /// + /// Plain HTTP or a streamed SSE response. Each body chunk is an opaque + /// byte range. + /// + Http, + + /// + /// Full-duplex WebSocket channel. Each request-body chunk is one inbound + /// WebSocket message and each response-body write is one outbound message. + /// + WebSocket, +} + +/// +/// An outbound model-layer HTTP (or WebSocket) request the runtime is asking +/// the SDK consumer to service on its behalf. +/// +/// +/// This is a low-level shape: URL / method / headers verbatim, body bytes +/// delivered as an async sequence, and the response delivered through the +/// sink. The runtime does not classify the request +/// (no provider type, endpoint kind, or wire API); consumers that need that +/// information derive it from the URL / headers themselves. +/// +internal sealed class LlmInferenceRequest +{ + /// Opaque runtime-minted id, stable across the request lifecycle. + public required string RequestId { get; init; } + + /// + /// Id of the runtime session that triggered this request, when one is in + /// scope. for out-of-session requests (e.g. startup + /// model catalog). + /// + public string? SessionId { get; init; } + + /// HTTP method (GET, POST, ...). + public required string Method { get; init; } + + /// Absolute request URL. + public required string Url { get; init; } + + /// HTTP request headers, lowercased names mapped to multi-valued lists. + public required IReadOnlyDictionary> Headers { get; init; } + + /// + /// Transport the runtime would otherwise use. + /// covers plain HTTP and SSE responses; + /// indicates a full-duplex message channel. Consumers branch on this to + /// decide whether to service the request with an HTTP client or a WebSocket + /// client. + /// + public LlmInferenceTransport Transport { get; init; } + + /// + /// Request body bytes, yielded as they arrive from the runtime. Always + /// enumerable; an empty body yields zero chunks before completing. For + /// WebSocket transport each element is one inbound message. + /// + public required IAsyncEnumerable> RequestBody { get; init; } + + /// + /// Cancelled when the runtime aborts this in-flight request (e.g. the agent + /// turn was aborted upstream). Pass it straight to HttpClient.SendAsync + /// / your transport so the upstream call is torn down too. After it fires, + /// writes to are ignored. + /// + public CancellationToken CancellationToken { get; init; } + + /// + /// Sink the consumer writes the upstream response into. Call + /// exactly once before + /// writing body chunks, then zero or more + /// + /// calls, and finish with or + /// . + /// + public required LlmInferenceResponseSink ResponseBody { get; init; } +} + +/// Response head passed to . +internal sealed class LlmInferenceResponseInit +{ + /// HTTP status code (101 acknowledges a WebSocket upgrade). + public int Status { get; init; } + + /// Optional HTTP status reason phrase. + public string? StatusText { get; init; } + + /// Response headers, lowercased names mapped to multi-valued lists. + public IReadOnlyDictionary>? Headers { get; init; } +} + +/// +/// Sink the consumer writes the upstream response into. The state machine is +/// strict: once → zero or more WriteAsync → +/// exactly one of or . Calling +/// out of order throws. +/// +internal abstract class LlmInferenceResponseSink +{ + /// Sends the response head (status + headers) back to the runtime. + public abstract Task StartAsync(LlmInferenceResponseInit init); + + /// Sends a binary body chunk (base64-encoded on the wire). + public abstract Task WriteAsync(ReadOnlyMemory data); + + /// Sends a UTF-8 text body chunk. + public abstract Task WriteAsync(string text); + + /// Marks end-of-stream cleanly. + public abstract Task EndAsync(); + + /// Marks end-of-stream with a transport-level failure. + public abstract Task ErrorAsync(string message, string? code = null); +} + +/// +/// Internal seam implemented by and consumed by +/// . The single callback handles both buffered +/// and streaming responses — the implementer calls +/// zero +/// or more times before . +/// +/// +/// Not part of the public API: consumers subclass +/// rather than implementing this directly. It exists so the adapter can drive any +/// handler through one uniform entry point. +/// +internal interface ILlmInferenceProvider +{ + /// + /// Invoked by the adapter once per outbound LLM request. The implementer is + /// responsible for eventually calling either + /// or + /// ; failing to do so leaks + /// runtime state. Throwing surfaces a transport-level failure to the runtime + /// (equivalent to ResponseBody.ErrorAsync(...) when + /// has not yet been called). + /// + Task OnLlmRequestAsync(LlmInferenceRequest request); +} + +/// +/// Adapts an into the generated +/// shape consumed by the SDK's RPC +/// dispatcher. +/// +/// +/// Maintains a per-requestId state table: each httpRequestStart +/// allocates a body channel + response sink and fires +/// in the background. +/// Subsequent httpRequestChunk frames are routed into the channel. The +/// sink translates Start / Write / End / Error calls +/// into outbound llmInference.httpResponseStart / +/// llmInference.httpResponseChunk calls. +/// +internal sealed class LlmInferenceAdapter : ILlmInferenceHandler +{ + private readonly ILlmInferenceProvider _provider; + private readonly Func _getChannel; + private readonly ConcurrentDictionary _pending = new(StringComparer.Ordinal); + + // Defense-in-depth backstop: chunks that arrive before their start frame + // (a reordering the runtime's single ordered dispatch should make + // impossible) are staged here and drained the moment httpRequestStart + // registers the matching state, so a body byte is never silently dropped. + private readonly ConcurrentDictionary> _staged = new(StringComparer.Ordinal); + + internal LlmInferenceAdapter(ILlmInferenceProvider provider, Func getServerRpc) + : this(provider, WrapServerRpc(getServerRpc ?? throw new ArgumentNullException(nameof(getServerRpc)))) + { + } + + internal LlmInferenceAdapter(ILlmInferenceProvider provider, Func getChannel) + { + _provider = provider ?? throw new ArgumentNullException(nameof(provider)); + _getChannel = getChannel ?? throw new ArgumentNullException(nameof(getChannel)); + } + + /// + /// Adapts a getter into a response-channel getter, + /// caching the wrapper so a new one is allocated only when the underlying + /// connection changes (e.g. reconnect). + /// + private static Func WrapServerRpc(Func getServerRpc) + { + ServerRpc? cachedRpc = null; + ILlmInferenceResponseChannel? cachedChannel = null; + return () => + { + var rpc = getServerRpc(); + if (rpc is null) + { + return null; + } + + if (!ReferenceEquals(rpc, cachedRpc)) + { + cachedRpc = rpc; + cachedChannel = new ServerRpcResponseChannel(rpc); + } + + return cachedChannel; + }; + } + + public Task HttpRequestStartAsync(LlmInferenceHttpRequestStartRequest request, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(request); + + var state = new PendingState(); + _pending[request.RequestId] = state; + + if (_staged.TryRemove(request.RequestId, out var stagedChunks)) + { + foreach (var chunk in stagedChunks) + { + RouteChunk(state, chunk); + } + } + + var sink = new AdapterResponseSink(request.RequestId, state, _getChannel, _pending); + state.Sink = sink; + + var transport = request.Transport == LlmInferenceHttpRequestStartTransport.Websocket + ? LlmInferenceTransport.WebSocket + : LlmInferenceTransport.Http; + + var llmRequest = new LlmInferenceRequest + { + RequestId = request.RequestId, + SessionId = request.SessionId, + Method = request.Method, + Url = request.Url, + Headers = ToReadOnlyHeaders(request.Headers), + Transport = transport, + RequestBody = state.Body.ReadAllAsync(state.Abort.Token), + CancellationToken = state.Abort.Token, + ResponseBody = sink, + }; + + // Return from httpRequestStart immediately (after registering state) so + // the runtime's RPC reply is not gated on the consumer's I/O. The actual + // provider work runs asynchronously. + _ = RunProviderAsync(llmRequest, state, sink); + + return Task.FromResult(new LlmInferenceHttpRequestStartResult()); + } + + public Task HttpRequestChunkAsync(LlmInferenceHttpRequestChunkRequest request, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(request); + + if (_pending.TryGetValue(request.RequestId, out var state)) + { + RouteChunk(state, request); + } + else + { + _staged.AddOrUpdate( + request.RequestId, + _ => [request], + (_, list) => + { + list.Add(request); + return list; + }); + } + + return Task.FromResult(new LlmInferenceHttpRequestChunkResult()); + } + + private async Task RunProviderAsync(LlmInferenceRequest request, PendingState state, AdapterResponseSink sink) + { + try + { + await _provider.OnLlmRequestAsync(request).ConfigureAwait(false); + if (!state.Finished) + { + await FailViaSink( + sink, + state, + "LLM inference provider returned without finalising the response (call ResponseBody.EndAsync() or .ErrorAsync()).").ConfigureAwait(false); + } + } + catch (Exception ex) + { + if (state.Cancelled || state.Abort.IsCancellationRequested) + { + // The runtime already cancelled this request; the provider's + // throw is just the abort propagating out of its upstream call. + await FinishCancelled(sink, state).ConfigureAwait(false); + return; + } + + await FailViaSink(sink, state, ex.Message).ConfigureAwait(false); + } + } + + private static async Task FailViaSink(AdapterResponseSink sink, PendingState state, string message) + { + if (state.Finished) + { + return; + } + + try + { + if (!state.Started) + { + await sink.StartAsync(new LlmInferenceResponseInit { Status = 502 }).ConfigureAwait(false); + } + + await sink.ErrorAsync(message).ConfigureAwait(false); + } + catch + { + // Best-effort — the connection may already be dead. + } + } + + private static async Task FinishCancelled(AdapterResponseSink sink, PendingState state) + { + if (state.Finished) + { + return; + } + + try + { + if (!state.Started) + { + await sink.StartAsync(new LlmInferenceResponseInit { Status = 499 }).ConfigureAwait(false); + } + + await sink.ErrorAsync("Request cancelled by runtime", "cancelled").ConfigureAwait(false); + } + catch + { + // Best-effort — the runtime already dropped the request on cancel. + } + } + + private static void RouteChunk(PendingState state, LlmInferenceHttpRequestChunkRequest chunk) + { + if (chunk.Cancel == true) + { + state.Cancelled = true; + state.Abort.Cancel(); + state.Body.PushCancel(chunk.CancelReason); + return; + } + + if (!string.IsNullOrEmpty(chunk.Data)) + { + state.Body.PushChunk(DecodeChunkData(chunk.Data, chunk.Binary == true)); + } + + if (chunk.End == true) + { + state.Body.PushEnd(); + } + } + + private static byte[] DecodeChunkData(string data, bool binary) => + binary ? Convert.FromBase64String(data) : Encoding.UTF8.GetBytes(data); + + private static Dictionary> ToReadOnlyHeaders(IDictionary> headers) + { + var result = new Dictionary>(StringComparer.OrdinalIgnoreCase); + foreach (var (name, values) in headers) + { + result[name] = values as IReadOnlyList ?? [.. values]; + } + + return result; + } + + private sealed class PendingState + { + public BodyChannel Body { get; } = new(); + + public CancellationTokenSource Abort { get; } = new(); + + public bool Started { get; set; } + + public bool Finished { get; set; } + + public bool Cancelled { get; set; } + + public AdapterResponseSink? Sink { get; set; } + } + + /// + /// An unbounded channel of request-body items exposed as an + /// of byte chunks. A cancel item surfaces + /// as an out of the enumerator so + /// the consumer's upstream call is torn down. + /// + private sealed class BodyChannel + { + private readonly Channel _channel = Channel.CreateUnbounded( + new UnboundedChannelOptions { SingleReader = true, SingleWriter = true }); + + public void PushChunk(byte[] data) => _channel.Writer.TryWrite(new Item { Chunk = data }); + + public void PushEnd() => _channel.Writer.TryWrite(new Item { End = true }); + + public void PushCancel(string? reason) => _channel.Writer.TryWrite(new Item { Cancel = true, CancelReason = reason }); + + public async IAsyncEnumerable> ReadAllAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + while (await _channel.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) + { + while (_channel.Reader.TryRead(out var item)) + { + if (item.Cancel) + { + _channel.Writer.TryComplete(); + throw new OperationCanceledException( + item.CancelReason is null + ? "Request cancelled by runtime" + : $"Request cancelled by runtime: {item.CancelReason}"); + } + + if (item.End) + { + _channel.Writer.TryComplete(); + yield break; + } + + if (item.Chunk is { Length: > 0 }) + { + yield return item.Chunk; + } + } + } + } + + private struct Item + { + public byte[]? Chunk; + public bool End; + public bool Cancel; + public string? CancelReason; + } + } + + private sealed class AdapterResponseSink( + string requestId, + PendingState state, + Func getChannel, + ConcurrentDictionary pending) : LlmInferenceResponseSink + { + public override async Task StartAsync(LlmInferenceResponseInit init) + { + ArgumentNullException.ThrowIfNull(init); + + if (state.Started) + { + throw new InvalidOperationException("LLM inference response sink StartAsync() called twice."); + } + + if (state.Finished) + { + throw new InvalidOperationException("LLM inference response sink already finished."); + } + + state.Started = true; + var result = await Channel() + .HttpResponseStartAsync(requestId, init.Status, ToWireHeaders(init.Headers), init.StatusText) + .ConfigureAwait(false); + if (!result.Accepted) + { + RejectedByRuntime(); + } + } + + public override Task WriteAsync(ReadOnlyMemory data) => + WriteChunk(Convert.ToBase64String(data.ToArray()), binary: true); + + public override Task WriteAsync(string text) + { + ArgumentNullException.ThrowIfNull(text); + return WriteChunk(text, binary: false); + } + + public override async Task EndAsync() + { + if (state.Finished) + { + return; + } + + state.Finished = true; + pending.TryRemove(requestId, out _); + await Channel().HttpResponseChunkAsync(requestId, string.Empty, end: true).ConfigureAwait(false); + } + + public override async Task ErrorAsync(string message, string? code = null) + { + ArgumentNullException.ThrowIfNull(message); + + if (state.Finished) + { + return; + } + + state.Finished = true; + pending.TryRemove(requestId, out _); + await Channel() + .HttpResponseChunkAsync( + requestId, + string.Empty, + end: true, + error: new LlmInferenceHttpResponseChunkError { Message = message, Code = code }) + .ConfigureAwait(false); + } + + private async Task WriteChunk(string data, bool binary) + { + if (state.Cancelled) + { + throw new InvalidOperationException("LLM inference request was cancelled by the runtime."); + } + + if (!state.Started) + { + throw new InvalidOperationException("LLM inference response sink WriteAsync() called before StartAsync()."); + } + + if (state.Finished) + { + throw new InvalidOperationException("LLM inference response sink WriteAsync() called after EndAsync()/ErrorAsync()."); + } + + var result = await Channel() + .HttpResponseChunkAsync(requestId, data, binary: binary, end: false) + .ConfigureAwait(false); + if (!result.Accepted) + { + RejectedByRuntime(); + } + } + + private ILlmInferenceResponseChannel Channel() => + getChannel() ?? throw new InvalidOperationException("LLM inference response sink used after RPC connection closed."); + + // The runtime acknowledges every response frame with accepted; accepted: + // false means it has dropped the request (e.g. it cancelled), so we abort + // the provider's upstream work and stop emitting. + private void RejectedByRuntime() + { + if (!state.Cancelled) + { + state.Cancelled = true; + state.Abort.Cancel(); + } + + state.Finished = true; + pending.TryRemove(requestId, out _); + throw new InvalidOperationException("LLM inference response was rejected by the runtime (request no longer active)."); + } + + private static Dictionary> ToWireHeaders(IReadOnlyDictionary>? headers) + { + var result = new Dictionary>(StringComparer.OrdinalIgnoreCase); + if (headers is null) + { + return result; + } + + foreach (var (name, values) in headers) + { + result[name] = values as IList ?? [.. values]; + } + + return result; + } + } +} + +/// +/// Minimal seam over the runtime-bound llmInference server API the +/// adapter uses to push response frames back to the runtime. Extracted as an +/// interface so the adapter can be unit-tested without a live JSON-RPC +/// connection. +/// +internal interface ILlmInferenceResponseChannel +{ + Task HttpResponseStartAsync(string requestId, long status, IDictionary> headers, string? statusText = null); + + Task HttpResponseChunkAsync(string requestId, string data, bool? binary = null, bool? end = null, LlmInferenceHttpResponseChunkError? error = null); +} + +/// +/// Production backed by the generated +/// client. +/// +internal sealed class ServerRpcResponseChannel(ServerRpc serverRpc) : ILlmInferenceResponseChannel +{ + public Task HttpResponseStartAsync(string requestId, long status, IDictionary> headers, string? statusText = null) => + serverRpc.LlmInference.HttpResponseStartAsync(requestId, status, headers, statusText); + + public Task HttpResponseChunkAsync(string requestId, string data, bool? binary = null, bool? end = null, LlmInferenceHttpResponseChunkError? error = null) => + serverRpc.LlmInference.HttpResponseChunkAsync(requestId, data, binary, end, error); +} diff --git a/dotnet/src/LlmRequestHandler.cs b/dotnet/src/LlmRequestHandler.cs new file mode 100644 index 000000000..b44cb9130 --- /dev/null +++ b/dotnet/src/LlmRequestHandler.cs @@ -0,0 +1,747 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using System.Diagnostics.CodeAnalysis; +using System.Net.WebSockets; +using System.Text; + +namespace GitHub.Copilot; + +/// +/// Per-request context handed to every hook. +/// Mirrors the subset of fields that are +/// stable across the request lifetime, letting overrides observe routing / +/// cancellation without re-plumbing the underlying request. +/// +[Experimental(Diagnostics.Experimental)] +public sealed class LlmRequestContext +{ + /// Opaque runtime-minted id, stable across the request lifecycle. + public required string RequestId { get; init; } + + /// Runtime session id that triggered the request, if any. + public string? SessionId { get; init; } + + /// Transport the runtime would otherwise use. + public LlmInferenceTransport Transport { get; init; } + + /// Original request URL. + public required string Url { get; init; } + + /// Original request headers. + public required IReadOnlyDictionary> Headers { get; init; } + + /// + /// Cancelled when the runtime aborts this in-flight request. Subclasses that + /// issue their own I/O should pass this through so the upstream call is torn + /// down too. + /// + public CancellationToken CancellationToken { get; init; } + + internal LlmWebSocketResponseBridge? WebSocketResponse { get; set; } +} + +/// A single WebSocket message exchanged through a hook. +[Experimental(Diagnostics.Experimental)] +public readonly struct LlmWebSocketMessage(ReadOnlyMemory data, bool isBinary) +{ + /// The message payload bytes. + public ReadOnlyMemory Data { get; } = data; + + /// True for a binary frame; false for a UTF-8 text frame. + public bool IsBinary { get; } = isBinary; + + /// Decodes the payload as UTF-8 text. + public string GetText() => Encoding.UTF8.GetString(Data.ToArray()); + + /// Creates a text message from a UTF-8 string. + public static LlmWebSocketMessage Text(string text) => new(Encoding.UTF8.GetBytes(text), isBinary: false); + + /// Creates a binary message from raw bytes. + public static LlmWebSocketMessage Binary(ReadOnlyMemory data) => new(data, isBinary: true); +} + +/// +/// Terminal status for a callback-owned WebSocket connection. +/// +[Experimental(Diagnostics.Experimental)] +public sealed class LlmWebSocketCloseStatus +{ + /// The close description, if any. + public string? Description { get; init; } + + /// + /// Optional error code surfaced to the runtime when the close is a failure + /// rather than a clean end-of-stream. + /// + public string? ErrorCode { get; init; } + + /// The error that terminated the connection, if any. + public Exception? Error { get; init; } + + /// Shared normal-closure instance. + public static LlmWebSocketCloseStatus NormalClosure { get; } = new(); +} + +/// +/// Per-connection WebSocket handler returned by +/// . +/// +[Experimental(Diagnostics.Experimental)] +public abstract class CopilotWebSocketHandler : IAsyncDisposable +{ + private readonly TaskCompletionSource _completion = + new(TaskCreationOptions.RunContinuationsAsynchronously); + private int _closed; + private bool _suppressCloseOnDispose; + + /// Request context for this WebSocket connection. + protected LlmRequestContext Context { get; } + + internal Task Completion => _completion.Task; + + /// + /// Initializes a per-connection handler for the supplied request context. + /// + protected CopilotWebSocketHandler(LlmRequestContext context) + { + Context = context; + _ = context.WebSocketResponse ?? throw new InvalidOperationException("WebSocket response bridge is not attached."); + } + + /// + /// Send a message from the runtime to the upstream connection. + /// + public abstract Task SendRequestMessageAsync(LlmWebSocketMessage message); + + /// + /// Send a message from the upstream connection back to the runtime. + /// Override to mutate or duplicate messages; call base to emit. + /// + public virtual Task SendResponseMessageAsync(LlmWebSocketMessage message) => + Context.WebSocketResponse!.WriteAsync(message); + + /// + /// Close the connection and finalise the runtime-facing response. + /// + public virtual async Task CloseAsync(LlmWebSocketCloseStatus status) + { + if (Interlocked.Exchange(ref _closed, 1) != 0) + { + return; + } + + if (status.Error is not null) + { + await Context.WebSocketResponse! + .ErrorAsync(status.Description ?? status.Error.Message, status.ErrorCode) + .ConfigureAwait(false); + } + else + { + await Context.WebSocketResponse!.EndAsync().ConfigureAwait(false); + } + + _completion.TrySetResult(status); + } + + internal void SuppressCloseOnDispose() => _suppressCloseOnDispose = true; + + internal virtual Task OpenAsync() => Task.CompletedTask; + + /// + public virtual async ValueTask DisposeAsync() + { + GC.SuppressFinalize(this); + if (!_suppressCloseOnDispose && Volatile.Read(ref _closed) == 0) + { + await CloseAsync(LlmWebSocketCloseStatus.NormalClosure).ConfigureAwait(false); + } + } +} + +/// +/// Default pass-through WebSocket handler. Opens the real upstream socket and +/// relays messages unchanged unless a subclass overrides the send methods. +/// +[Experimental(Diagnostics.Experimental)] +public class ForwardingWebSocketHandler : CopilotWebSocketHandler +{ + private readonly string _url; + private readonly IReadOnlyDictionary> _headers; + private WebSocket? _upstream; + private CancellationTokenSource? _pumpCts; + private Task? _responsePump; + + /// + /// Initializes a forwarding handler that will open the upstream socket on + /// demand using the supplied URL/headers (or the values from + /// when omitted). + /// + public ForwardingWebSocketHandler( + LlmRequestContext context, + string? url = null, + IReadOnlyDictionary>? headers = null) + : base(context) + { + _url = url ?? context.Url; + _headers = headers ?? context.Headers; + } + + /// + /// Opens the upstream socket and starts the built-in response pump. + /// + internal override async Task OpenAsync() + { + if (_upstream is not null) + { + return; + } + + var socket = new ClientWebSocket(); + foreach (var (name, values) in _headers) + { + if (s_forbiddenRequestHeaders.Contains(name)) + { + continue; + } + + try + { + socket.Options.SetRequestHeader(name, string.Join(", ", values)); + } + catch + { + // Some headers are managed by the handshake; ignore rejections. + } + } + + await socket.ConnectAsync(LlmWebSocketHelpers.ToWebSocketUri(_url), Context.CancellationToken).ConfigureAwait(false); + _upstream = socket; + _pumpCts = CancellationTokenSource.CreateLinkedTokenSource(Context.CancellationToken); + _responsePump = Task.Run(() => PumpResponsesAsync(_pumpCts.Token), _pumpCts.Token); + } + + /// + /// Sends a message from the runtime to the upstream connection. Subclasses may override to mutate messages. + /// + /// The message to send. + /// A representing the asynchronous operation. + public override Task SendRequestMessageAsync(LlmWebSocketMessage message) + { + if (_upstream?.State != WebSocketState.Open) + { + return Task.CompletedTask; + } + + var type = message.IsBinary ? WebSocketMessageType.Binary : WebSocketMessageType.Text; + return _upstream.SendAsync( + new ArraySegment(message.Data.ToArray()), + type, + endOfMessage: true, + Context.CancellationToken); + } + + /// + public override async Task CloseAsync(LlmWebSocketCloseStatus status) + { + _pumpCts?.Cancel(); + if (_upstream is not null) + { + await LlmWebSocketHelpers.CloseWebSocketQuietlyAsync(_upstream).ConfigureAwait(false); + } + await base.CloseAsync(status).ConfigureAwait(false); + } + + /// + public override async ValueTask DisposeAsync() + { + GC.SuppressFinalize(this); + try + { + await base.DisposeAsync().ConfigureAwait(false); + } + finally + { + _pumpCts?.Cancel(); + _pumpCts?.Dispose(); + _upstream?.Dispose(); + if (_responsePump is not null) + { + await LlmWebSocketHelpers.ObserveQuietlyAsync(_responsePump).ConfigureAwait(false); + } + } + } + + private async Task PumpResponsesAsync(CancellationToken cancellationToken) + { + if (_upstream is null) + { + return; + } + + try + { + while (_upstream.State == WebSocketState.Open) + { + var message = await LlmWebSocketHelpers.ReceiveMessageAsync(_upstream, cancellationToken).ConfigureAwait(false); + if (message is null) + { + break; + } + + await SendResponseMessageAsync(message.Value).ConfigureAwait(false); + } + + await CloseAsync(LlmWebSocketCloseStatus.NormalClosure).ConfigureAwait(false); + } + catch (OperationCanceledException) when (Context.CancellationToken.IsCancellationRequested) + { + // Runtime-side cancellation aborts the request pump; the outer + // handler rethrows that cancellation rather than finalising here. + } + catch (Exception ex) + { + await CloseAsync(new LlmWebSocketCloseStatus + { + Description = ex.Message, + Error = ex, + }).ConfigureAwait(false); + } + } + + // Computed/managed by the HTTP/WS stack; forwarding them verbatim either + // throws or corrupts the request. + private static readonly HashSet s_forbiddenRequestHeaders = new(StringComparer.OrdinalIgnoreCase) + { + "host", + "connection", + "content-length", + "transfer-encoding", + "keep-alive", + "upgrade", + "proxy-connection", + "te", + "trailer", + }; +} + +/// +/// Base class for SDK consumers who want to observe or mutate the LLM inference +/// requests the runtime issues. +/// +[Experimental(Diagnostics.Experimental)] +public class LlmRequestHandler : ILlmInferenceProvider +{ + private static readonly HttpClient s_sharedHttpClient = new(); + + // Computed/managed by the HTTP stack; forwarding them verbatim either throws + // or corrupts the request. + private static readonly HashSet s_forbiddenRequestHeaders = new(StringComparer.OrdinalIgnoreCase) + { + "host", + "connection", + "content-length", + "transfer-encoding", + "keep-alive", + "upgrade", + "proxy-connection", + "te", + "trailer", + }; + + /// + async Task ILlmInferenceProvider.OnLlmRequestAsync(LlmInferenceRequest request) + { + ArgumentNullException.ThrowIfNull(request); + + var wsResponse = new LlmWebSocketResponseBridge(request.ResponseBody); + var ctx = new LlmRequestContext + { + RequestId = request.RequestId, + SessionId = request.SessionId, + Transport = request.Transport, + Url = request.Url, + Headers = request.Headers, + CancellationToken = request.CancellationToken, + }; + ctx.WebSocketResponse = wsResponse; + + if (request.Transport == LlmInferenceTransport.WebSocket) + { + await HandleWebSocketAsync(request, ctx).ConfigureAwait(false); + } + else + { + await HandleHttpAsync(request, ctx).ConfigureAwait(false); + } + } + + /// + /// Issue the upstream HTTP request. Override to mutate the request before + /// calling base, mutate the returned response after, or replace the + /// call entirely. + /// + protected virtual Task SendRequestAsync(HttpRequestMessage request, LlmRequestContext ctx) => + s_sharedHttpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, ctx.CancellationToken); + + /// + /// Open the upstream WebSocket connection. Override to return a custom + /// or to construct a + /// against a rewritten URL. + /// + protected virtual Task OpenWebSocketAsync(LlmRequestContext ctx) => + Task.FromResult(new ForwardingWebSocketHandler(ctx)); + + private async Task HandleHttpAsync(LlmInferenceRequest req, LlmRequestContext ctx) + { + using var request = await BuildHttpRequestAsync(req).ConfigureAwait(false); + using var response = await SendRequestAsync(request, ctx).ConfigureAwait(false); + await StreamResponseToSinkAsync(response, req, ctx).ConfigureAwait(false); + } + + private static async Task BuildHttpRequestAsync(LlmInferenceRequest req) + { + var method = new HttpMethod(req.Method.ToUpperInvariant()); + var message = new HttpRequestMessage(method, req.Url); + + var hasBody = method != HttpMethod.Get && method != HttpMethod.Head; + var body = await DrainAsync(req.RequestBody).ConfigureAwait(false); + if (hasBody && body.Length > 0) + { + message.Content = new ByteArrayContent(body); + } + + foreach (var (name, values) in req.Headers) + { + if (s_forbiddenRequestHeaders.Contains(name)) + { + continue; + } + + if (!message.Headers.TryAddWithoutValidation(name, values)) + { + message.Content ??= new ByteArrayContent([]); + message.Content.Headers.TryAddWithoutValidation(name, values); + } + } + + return message; + } + + private static async Task StreamResponseToSinkAsync(HttpResponseMessage response, LlmInferenceRequest req, LlmRequestContext ctx) + { + await req.ResponseBody.StartAsync(new LlmInferenceResponseInit + { + Status = (int)response.StatusCode, + StatusText = response.ReasonPhrase, + Headers = HeadersToMultiMap(response), + }).ConfigureAwait(false); + +#if NETSTANDARD2_0 + using var stream = await response.Content.ReadAsStreamAsync().ConfigureAwait(false); +#else + using var stream = await response.Content.ReadAsStreamAsync(ctx.CancellationToken).ConfigureAwait(false); +#endif + var buffer = new byte[16 * 1024]; + int read; +#if NETSTANDARD2_0 + while ((read = await stream.ReadAsync(buffer, 0, buffer.Length, ctx.CancellationToken).ConfigureAwait(false)) > 0) + { + await req.ResponseBody.WriteAsync(new ReadOnlyMemory(buffer, 0, read)).ConfigureAwait(false); + } +#else + while ((read = await stream.ReadAsync(buffer.AsMemory(), ctx.CancellationToken).ConfigureAwait(false)) > 0) + { + await req.ResponseBody.WriteAsync(new ReadOnlyMemory(buffer, 0, read)).ConfigureAwait(false); + } +#endif + + await req.ResponseBody.EndAsync().ConfigureAwait(false); + } + + private async Task HandleWebSocketAsync(LlmInferenceRequest req, LlmRequestContext ctx) + { + var handler = await OpenWebSocketAsync(ctx).ConfigureAwait(false); + try + { + await handler.OpenAsync().ConfigureAwait(false); + await ctx.WebSocketResponse!.StartAsync().ConfigureAwait(false); + + var clientPump = Task.Run(async () => + { + await foreach (var chunk in req.RequestBody.WithCancellation(ctx.CancellationToken).ConfigureAwait(false)) + { + await handler.SendRequestMessageAsync(new LlmWebSocketMessage(chunk, isBinary: false)).ConfigureAwait(false); + } + }, ctx.CancellationToken); + + var first = await Task.WhenAny(clientPump, handler.Completion).ConfigureAwait(false); + if (first == clientPump) + { + if (clientPump.IsFaulted || clientPump.IsCanceled) + { + handler.SuppressCloseOnDispose(); + await clientPump.ConfigureAwait(false); + } + + await handler.CloseAsync(LlmWebSocketCloseStatus.NormalClosure).ConfigureAwait(false); + await handler.Completion.ConfigureAwait(false); + return; + } + + var closeStatus = await handler.Completion.ConfigureAwait(false); + if (closeStatus.Error is not null) + { + throw closeStatus.Error; + } + } + finally + { + await handler.DisposeAsync().ConfigureAwait(false); + } + } + + private static async Task DrainAsync(IAsyncEnumerable> stream) + { + using var buffer = new MemoryStream(); + await foreach (var chunk in stream.ConfigureAwait(false)) + { + if (chunk.Length > 0) + { + buffer.Write(chunk.ToArray(), 0, chunk.Length); + } + } + + return buffer.ToArray(); + } + + private static Dictionary> HeadersToMultiMap(HttpResponseMessage response) + { + var result = new Dictionary>(StringComparer.OrdinalIgnoreCase); + foreach (var header in response.Headers) + { + result[header.Key] = [.. header.Value]; + } + + if (response.Content is not null) + { + foreach (var header in response.Content.Headers) + { + result[header.Key] = [.. header.Value]; + } + } + + return result; + } + +} + +internal static class LlmWebSocketHelpers +{ + internal static async Task ReceiveMessageAsync(WebSocket socket, CancellationToken cancellationToken) + { + var buffer = new byte[16 * 1024]; + using var assembled = new MemoryStream(); + WebSocketReceiveResult result; + do + { + try + { + result = await socket.ReceiveAsync(new ArraySegment(buffer), cancellationToken).ConfigureAwait(false); + } + catch (OperationCanceledException) + { + return null; + } + catch (WebSocketException) + { + return null; + } + + if (result.MessageType == WebSocketMessageType.Close) + { + return null; + } + + assembled.Write(buffer, 0, result.Count); + } + while (!result.EndOfMessage); + + return new LlmWebSocketMessage(assembled.ToArray(), result.MessageType == WebSocketMessageType.Binary); + } + + internal static async Task CloseWebSocketQuietlyAsync(WebSocket socket) + { + try + { + if (socket.State is WebSocketState.Open or WebSocketState.CloseReceived) + { + await socket.CloseAsync(WebSocketCloseStatus.NormalClosure, statusDescription: null, CancellationToken.None).ConfigureAwait(false); + } + } + catch + { + // Best-effort; the socket may already be closed. + } + } + + [SuppressMessage("Usage", "CA1031:Do not catch general exception types", Justification = "Best-effort teardown of the losing pump.")] + internal static async Task ObserveQuietlyAsync(Task task) + { + try + { + await task.ConfigureAwait(false); + } + catch + { + // Best-effort teardown only. + } + } + + internal static Uri ToWebSocketUri(string url) + { + var builder = new UriBuilder(url); + if (builder.Scheme.Equals("https", StringComparison.OrdinalIgnoreCase)) + { + builder.Scheme = "wss"; + } + else if (builder.Scheme.Equals("http", StringComparison.OrdinalIgnoreCase)) + { + builder.Scheme = "ws"; + } + + return builder.Uri; + } +} + +internal sealed class LlmWebSocketResponseBridge +{ + private readonly LlmInferenceResponseSink _sink; + private readonly SemaphoreSlim _gate = new(1, 1); + private readonly Queue _pending = new(); + private bool _started; + private bool _completed; + + internal LlmWebSocketResponseBridge(LlmInferenceResponseSink sink) + { + _sink = sink; + } + + internal async Task StartAsync() + { + await _gate.WaitAsync().ConfigureAwait(false); + try + { + if (_started) + { + return; + } + + _started = true; + await _sink.StartAsync(new LlmInferenceResponseInit { Status = 101 }).ConfigureAwait(false); + while (_pending.Count > 0) + { + await ApplyAsync(_pending.Dequeue()).ConfigureAwait(false); + } + } + finally + { + _gate.Release(); + } + } + + internal Task WriteAsync(LlmWebSocketMessage message) => EnqueueOrApplyAsync(PendingAction.Write(message)); + + internal Task EndAsync() => EnqueueOrApplyAsync(PendingAction.End()); + + internal Task ErrorAsync(string message, string? code) => EnqueueOrApplyAsync(PendingAction.Error(message, code)); + + private async Task EnqueueOrApplyAsync(PendingAction action) + { + await _gate.WaitAsync().ConfigureAwait(false); + try + { + if (_completed && action.Kind == PendingActionKind.Write) + { + return; + } + + if (!_started) + { + _pending.Enqueue(action); + if (action.Kind is PendingActionKind.End or PendingActionKind.Error) + { + _completed = true; + } + + return; + } + + await ApplyAsync(action).ConfigureAwait(false); + } + finally + { + _gate.Release(); + } + } + + private async Task ApplyAsync(PendingAction action) + { + if (_completed && action.Kind == PendingActionKind.Write) + { + return; + } + + switch (action.Kind) + { + case PendingActionKind.Write: + if (action.Message!.Value.IsBinary) + { + await _sink.WriteAsync(action.Message.Value.Data).ConfigureAwait(false); + } + else + { + await _sink.WriteAsync(action.Message.Value.GetText()).ConfigureAwait(false); + } + break; + case PendingActionKind.End: + if (_completed) + { + return; + } + + _completed = true; + await _sink.EndAsync().ConfigureAwait(false); + break; + case PendingActionKind.Error: + if (_completed) + { + return; + } + + _completed = true; + await _sink.ErrorAsync(action.ErrorMessage!, action.ErrorCode).ConfigureAwait(false); + break; + } + } + + private readonly record struct PendingAction( + PendingActionKind Kind, + LlmWebSocketMessage? Message = null, + string? ErrorMessage = null, + string? ErrorCode = null) + { + internal static PendingAction Write(LlmWebSocketMessage message) => new(PendingActionKind.Write, message); + internal static PendingAction End() => new(PendingActionKind.End); + internal static PendingAction Error(string message, string? code) => new(PendingActionKind.Error, null, message, code); + } + + private enum PendingActionKind + { + Write, + End, + Error, + } +} diff --git a/dotnet/src/Types.cs b/dotnet/src/Types.cs index d7b326afb..9167c2cf7 100644 --- a/dotnet/src/Types.cs +++ b/dotnet/src/Types.cs @@ -278,6 +278,7 @@ private CopilotClientOptions(CopilotClientOptions? other) UseLoggedInUser = other.UseLoggedInUser; OnListModels = other.OnListModels; SessionFs = other.SessionFs; + LlmInference = other.LlmInference; SessionIdleTimeoutSeconds = other.SessionIdleTimeoutSeconds; EnableRemoteSessions = other.EnableRemoteSessions; Mode = other.Mode; @@ -364,6 +365,17 @@ private CopilotClientOptions(CopilotClientOptions? other) /// public SessionFsConfig? SessionFs { get; set; } + /// + /// Configures interception of the LLM inference requests the runtime would + /// otherwise issue itself (for both CAPI and BYOK providers). When set, the + /// client registers a client-global LLM inference provider on connect, so + /// every model-layer HTTP / WebSocket request is routed to the consumer's + /// (or + /// subclass) instead of the runtime's own outbound call. + /// + [Experimental(Diagnostics.Experimental)] + public LlmInferenceConfig? LlmInference { get; set; } + /// /// OpenTelemetry configuration for the runtime. /// When set to a non- instance, the runtime is started with OpenTelemetry instrumentation enabled. @@ -484,6 +496,21 @@ public sealed class SessionFsConfig public SessionFsSetProviderCapabilities? Capabilities { get; init; } } +/// +/// Configuration for intercepting the LLM inference requests the runtime issues. +/// +[Experimental(Diagnostics.Experimental)] +public sealed class LlmInferenceConfig +{ + /// + /// Handler that services every intercepted model-layer request for the + /// lifetime of the client connection. Subclass + /// and override its hooks to observe, mutate, or fully replace each + /// request/response. + /// + public LlmRequestHandler? Handler { get; set; } +} + /// /// Represents a binary result returned by a tool invocation. /// diff --git a/dotnet/test/E2E/LlmInferenceE2EProvider.cs b/dotnet/test/E2E/LlmInferenceE2EProvider.cs new file mode 100644 index 000000000..25fdadd76 --- /dev/null +++ b/dotnet/test/E2E/LlmInferenceE2EProvider.cs @@ -0,0 +1,167 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using System.Collections.Concurrent; +using System.Net; +using System.Net.Http; +using System.Text; +using System.Text.RegularExpressions; + +namespace GitHub.Copilot.Test.E2E; + +#pragma warning disable GHCP001 // The LLM inference surface is intentionally experimental. + +/// +/// A subclass for e2e tests that records every +/// intercepted request (url + threaded session id) and fully replaces the +/// upstream call with a fabricated, well-formed response for every model-layer +/// endpoint, so an agent turn completes entirely off-network — no upstream +/// server and no CAPI proxy acting as the inference endpoint. +/// +/// +/// +/// This exercises the public extension surface end to end: a consumer subclasses +/// and overrides to +/// short-circuit the upstream HTTP call with any +/// it likes. The base class streams that response back to the runtime. +/// +/// +/// All response bodies are emitted as raw JSON string literals rather than via +/// JsonSerializer: the test project disables reflection-based STJ on +/// net8.0 (JsonSerializerIsReflectionEnabledByDefault=false), so +/// serializing anonymous types would throw at runtime. +/// +/// +internal sealed class RecordingInferenceProvider : LlmRequestHandler +{ + internal const string SyntheticText = "OK from the synthetic stream."; + + private static readonly Regex WantsStreamRegex = new("\"stream\"\\s*:\\s*true", RegexOptions.Compiled); + + private readonly ConcurrentQueue _records = new(); + + public IReadOnlyCollection Records => _records; + + public IReadOnlyList InferenceRequests => + [.. _records.Where(r => IsInferenceUrl(r.Url))]; + + protected override async Task SendRequestAsync(HttpRequestMessage request, LlmRequestContext ctx) + { + var url = request.RequestUri!.ToString(); + _records.Enqueue(new InterceptedRequest(url, ctx.SessionId)); + + var bodyText = request.Content is null + ? string.Empty +#if NET8_0_OR_GREATER + : await request.Content.ReadAsStringAsync(ctx.CancellationToken).ConfigureAwait(false); +#else + : await request.Content.ReadAsStringAsync().ConfigureAwait(false); +#endif + + return IsInferenceUrl(url) + ? BuildInferenceResponse(url, bodyText) + : BuildNonInferenceResponse(url); + } + + internal static bool IsInferenceUrl(string url) + { + var u = url.ToLowerInvariant(); + return u.EndsWith("/chat/completions", StringComparison.Ordinal) + || u.EndsWith("/responses", StringComparison.Ordinal) + || u.EndsWith("/v1/messages", StringComparison.Ordinal) + || u.EndsWith("/messages", StringComparison.Ordinal); + } + + /// + /// Synthesizes a well-formed inference response so the agent turn completes. + /// The runtime selects /responses for both the CAPI and BYOK sessions + /// here; /chat/completions is handled too for robustness. + /// + private static HttpResponseMessage BuildInferenceResponse(string url, string bodyText) + { + var wantsStream = WantsStreamRegex.IsMatch(bodyText); + var u = url.ToLowerInvariant(); + + if (u.Contains("/responses", StringComparison.Ordinal)) + { + return wantsStream + ? Sse(string.Concat(ResponsesStreamEvents)) + : Json(BufferedResponseJson); + } + + if (u.Contains("/chat/completions", StringComparison.Ordinal) && wantsStream) + { + return Sse(string.Concat(ChatCompletionStreamEvents)); + } + + // /chat/completions non-streaming (and any other inference url) — buffered JSON. + return Json(BufferedChatCompletionJson); + } + + /// + /// Serves the non-inference model-layer GETs/POSTs the runtime issues + /// (catalog, model session, policy). These flow through the same callback + /// but carry no session id (they happen outside an agent turn). + /// + private static HttpResponseMessage BuildNonInferenceResponse(string url) + { + var u = url.ToLowerInvariant(); + if (u.EndsWith("/models", StringComparison.Ordinal)) + { + return Json(ModelCatalogJson); + } + + if (u.Contains("/models/session", StringComparison.Ordinal)) + { + return Json("{}"); + } + + if (u.Contains("/policy", StringComparison.Ordinal)) + { + return Json("{\"state\":\"enabled\"}"); + } + + return Json("{}"); + } + + private static HttpResponseMessage Json(string body) => new(HttpStatusCode.OK) + { + Content = new StringContent(body, Encoding.UTF8, "application/json"), + }; + + private static HttpResponseMessage Sse(string body) => new(HttpStatusCode.OK) + { + Content = new StringContent(body, Encoding.UTF8, "text/event-stream"), + }; + + private static readonly string[] ResponsesStreamEvents = + [ + "event: response.created\ndata: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_stub_1\",\"object\":\"response\",\"status\":\"in_progress\",\"output\":[]}}\n\n", + "event: response.output_item.added\ndata: {\"type\":\"response.output_item.added\",\"output_index\":0,\"item\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[]}}\n\n", + "event: response.content_part.added\ndata: {\"type\":\"response.content_part.added\",\"output_index\":0,\"content_index\":0,\"part\":{\"type\":\"output_text\",\"text\":\"\"}}\n\n", + "event: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"output_index\":0,\"content_index\":0,\"delta\":\"" + SyntheticText + "\"}\n\n", + "event: response.output_text.done\ndata: {\"type\":\"response.output_text.done\",\"output_index\":0,\"content_index\":0,\"text\":\"" + SyntheticText + "\"}\n\n", + "event: response.completed\ndata: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_stub_1\",\"object\":\"response\",\"status\":\"completed\",\"output\":[{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"" + SyntheticText + "\"}]}],\"usage\":{\"input_tokens\":5,\"output_tokens\":7,\"total_tokens\":12}}}\n\n", + ]; + + private static readonly string[] ChatCompletionStreamEvents = + [ + "data: {\"id\":\"chatcmpl-stub-1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"claude-sonnet-4.5\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"\"},\"finish_reason\":null}]}\n\n", + "data: {\"id\":\"chatcmpl-stub-1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"claude-sonnet-4.5\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"" + SyntheticText + "\"},\"finish_reason\":null}]}\n\n", + "data: {\"id\":\"chatcmpl-stub-1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"claude-sonnet-4.5\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":7,\"total_tokens\":12}}\n\n", + "data: [DONE]\n\n", + ]; + + private static readonly string BufferedResponseJson = + "{\"id\":\"resp_stub_1\",\"object\":\"response\",\"status\":\"completed\",\"output\":[{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"" + SyntheticText + "\"}]}],\"usage\":{\"input_tokens\":5,\"output_tokens\":7,\"total_tokens\":12}}"; + + private static readonly string BufferedChatCompletionJson = + "{\"id\":\"chatcmpl-stub-1\",\"object\":\"chat.completion\",\"created\":1,\"model\":\"claude-sonnet-4.5\",\"choices\":[{\"index\":0,\"message\":{\"role\":\"assistant\",\"content\":\"" + SyntheticText + "\"},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":7,\"total_tokens\":12}}"; + + private const string ModelCatalogJson = + "{\"data\":[{\"id\":\"claude-sonnet-4.5\",\"name\":\"Claude Sonnet 4.5\",\"object\":\"model\",\"vendor\":\"Anthropic\",\"version\":\"1\",\"preview\":false,\"model_picker_enabled\":true,\"capabilities\":{\"type\":\"chat\",\"family\":\"claude-sonnet-4.5\",\"tokenizer\":\"o200k_base\",\"limits\":{\"max_context_window_tokens\":200000,\"max_output_tokens\":8192},\"supports\":{\"streaming\":true,\"tool_calls\":true,\"parallel_tool_calls\":true,\"vision\":true}}}]}"; +} + +/// A single request the callback intercepted. +internal sealed record InterceptedRequest(string Url, string? SessionId); diff --git a/dotnet/test/E2E/LlmInferenceSessionIdE2ETests.cs b/dotnet/test/E2E/LlmInferenceSessionIdE2ETests.cs new file mode 100644 index 000000000..be1db1de9 --- /dev/null +++ b/dotnet/test/E2E/LlmInferenceSessionIdE2ETests.cs @@ -0,0 +1,107 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using GitHub.Copilot.Test.Harness; +using Xunit; +using Xunit.Abstractions; + +namespace GitHub.Copilot.Test.E2E; + +#pragma warning disable GHCP001 // The LLM inference surface is intentionally experimental. + +/// +/// Asserts the runtime threads its session id into the LLM inference callback +/// for BOTH a CAPI session and a BYOK session. The callback alone services +/// every model-layer request — no upstream server, no CAPI proxy acting as the +/// inference endpoint — so the only source of req.SessionId is the +/// runtime's own per-client threading. +/// +public class LlmInferenceSessionIdE2ETests(E2ETestFixture fixture, ITestOutputHelper output) + : E2ETestBase(fixture, "llm_inference_session_id", output) +{ + private CopilotClient CreateClientWith(RecordingInferenceProvider provider) => + Ctx.CreateClient(options: new CopilotClientOptions + { + Connection = RuntimeConnection.ForStdio(), + LlmInference = new LlmInferenceConfig + { + Handler = provider, + }, + }); + + [Fact] + public async Task Threads_The_Session_Id_Into_A_Capi_Session_Inference_Request() + { + var provider = new RecordingInferenceProvider(); + await using var client = CreateClientWith(provider); + await client.StartAsync(); + + var session = await client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + }); + var capiSessionId = session.SessionId; + + string content; + try + { + var msg = await session.SendAndWaitAsync(new MessageOptions { Prompt = "Say OK." }); + content = msg?.Data.Content ?? string.Empty; + } + finally + { + await session.DisposeAsync(); + } + + var inference = provider.InferenceRequests; + Assert.NotEmpty(inference); + Assert.All(inference, r => Assert.Equal(capiSessionId, r.SessionId)); + + // Validate the final assistant response arrived (guards against truncated captures) + Assert.Contains("OK from the synthetic", content); + } + + [Fact] + public async Task Threads_The_Session_Id_Into_A_Byok_Session_Inference_Request() + { + var provider = new RecordingInferenceProvider(); + await using var client = CreateClientWith(provider); + await client.StartAsync(); + + var session = await client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + // BYOK providers require an explicit model id. + Model = "claude-sonnet-4.5", + Provider = new ProviderConfig + { + Type = "openai", + WireApi = "responses", + BaseUrl = "https://byok.invalid/v1", + ApiKey = "byok-secret", + ModelId = "claude-sonnet-4.5", + WireModel = "claude-sonnet-4.5", + }, + }); + var byokSessionId = session.SessionId; + + string content; + try + { + var msg = await session.SendAndWaitAsync(new MessageOptions { Prompt = "Say OK." }); + content = msg?.Data.Content ?? string.Empty; + } + finally + { + await session.DisposeAsync(); + } + + var inference = provider.InferenceRequests; + Assert.NotEmpty(inference); + Assert.All(inference, r => Assert.Equal(byokSessionId, r.SessionId)); + + // Validate the final assistant response arrived (guards against truncated captures) + Assert.Contains("OK from the synthetic", content); + } +} diff --git a/dotnet/test/GitHub.Copilot.SDK.Test.csproj b/dotnet/test/GitHub.Copilot.SDK.Test.csproj index 4b27df57c..49e117d83 100644 --- a/dotnet/test/GitHub.Copilot.SDK.Test.csproj +++ b/dotnet/test/GitHub.Copilot.SDK.Test.csproj @@ -7,6 +7,13 @@ false true $(NoWarn);GHCP001 + + $(NoWarn);CS0436 @@ -35,7 +42,11 @@ - + diff --git a/dotnet/test/Unit/LlmInference/LlmInferenceAdapterTests.cs b/dotnet/test/Unit/LlmInference/LlmInferenceAdapterTests.cs new file mode 100644 index 000000000..94d50f378 --- /dev/null +++ b/dotnet/test/Unit/LlmInference/LlmInferenceAdapterTests.cs @@ -0,0 +1,197 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using System.Text; +using Xunit; + +namespace GitHub.Copilot.Test.Unit.LlmInference; + +#pragma warning disable GHCP001 // The LLM inference surface is intentionally experimental. + +public class LlmInferenceAdapterTests +{ + private static readonly TimeSpan Timeout = TimeSpan.FromSeconds(10); + + private static LlmInferenceAdapter CreateAdapter(ILlmInferenceProvider provider, RecordingResponseChannel channel) + { + ILlmInferenceResponseChannel current = channel; + return new LlmInferenceAdapter(provider, () => current); + } + + [Fact] + public async Task Stages_request_chunks_that_arrive_before_their_start_frame_and_replays_them_in_order() + { + var received = new List(); + var done = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var provider = new InlineProvider(async req => + { + await foreach (var chunk in req.RequestBody) + { + received.Add(Encoding.UTF8.GetString(chunk.ToArray())); + } + + await req.ResponseBody.StartAsync(new LlmInferenceResponseInit { Status = 200 }); + await req.ResponseBody.EndAsync(); + done.SetResult(); + }); + + var channel = new RecordingResponseChannel(); + var adapter = CreateAdapter(provider, channel); + + // Chunks arrive BEFORE the start frame (a reordering the runtime should + // never produce). They must be staged and replayed once start registers. + await adapter.HttpRequestChunkAsync(LlmFrames.Chunk("r1", "hello ", end: false)); + await adapter.HttpRequestChunkAsync(LlmFrames.Chunk("r1", "world", end: false)); + await adapter.HttpRequestChunkAsync(LlmFrames.Chunk("r1", "", end: true)); + + await adapter.HttpRequestStartAsync(LlmFrames.Start("r1")); + + await done.Task.WaitAsync(Timeout); + Assert.Equal("hello world", string.Concat(received)); + } + + [Fact] + public async Task Emits_a_buffered_response_as_start_then_body_then_terminal_end() + { + var done = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var provider = new InlineProvider(async req => + { + await foreach (var _ in req.RequestBody) + { + // drain + } + + await req.ResponseBody.StartAsync(new LlmInferenceResponseInit + { + Status = 200, + Headers = new Dictionary> { ["content-type"] = ["application/json"] }, + }); + await req.ResponseBody.WriteAsync("OK"); + await req.ResponseBody.EndAsync(); + done.SetResult(); + }); + + var channel = new RecordingResponseChannel(); + var adapter = CreateAdapter(provider, channel); + + await adapter.HttpRequestStartAsync(LlmFrames.Start("r2")); + await adapter.HttpRequestChunkAsync(LlmFrames.Chunk("r2", "", end: true)); + + await done.Task.WaitAsync(Timeout); + + var start = Assert.Single(channel.Starts); + Assert.Equal(200, start.Status); + Assert.Equal("OK", channel.DecodeTextBody()); + + var terminal = Assert.Single(channel.Chunks, c => c.End == true); + Assert.Null(terminal.Error); + } + + [Fact] + public async Task Aborts_the_provider_and_throws_from_write_when_the_runtime_rejects_a_response_frame() + { + var aborted = false; + var writeThrew = false; + var settled = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var provider = new InlineProvider(async req => + { + req.CancellationToken.Register(() => aborted = true); + await foreach (var _ in req.RequestBody) + { + // drain + } + + await req.ResponseBody.StartAsync(new LlmInferenceResponseInit { Status = 200 }); + try + { + await req.ResponseBody.WriteAsync("rejected-chunk"); + } + catch (InvalidOperationException) + { + writeThrew = true; + } + + settled.SetResult(); + }); + + // The runtime accepts the start frame but rejects the body chunk. + var channel = new RecordingResponseChannel(acceptStart: true, acceptChunk: false); + var adapter = CreateAdapter(provider, channel); + + await adapter.HttpRequestStartAsync(LlmFrames.Start("r3")); + await adapter.HttpRequestChunkAsync(LlmFrames.Chunk("r3", "", end: true)); + + await settled.Task.WaitAsync(Timeout); + Assert.True(writeThrew, "write should throw after the runtime rejects the chunk"); + Assert.True(aborted, "the provider's cancellation token should fire on rejection"); + } + + [Fact] + public async Task Surfaces_a_runtime_cancel_chunk_as_a_cancelled_terminal_error() + { + var observedCancellation = false; + var done = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var provider = new InlineProvider(async req => + { + try + { + await foreach (var _ in req.RequestBody) + { + // The cancel frame surfaces as an OperationCanceledException here. + } + } + catch (OperationCanceledException) + { + observedCancellation = true; + throw; + } + finally + { + done.TrySetResult(); + } + }); + + var channel = new RecordingResponseChannel(); + var adapter = CreateAdapter(provider, channel); + + await adapter.HttpRequestStartAsync(LlmFrames.Start("r4")); + await adapter.HttpRequestChunkAsync(LlmFrames.Chunk("r4", cancel: true, cancelReason: "turn aborted")); + + await done.Task.WaitAsync(Timeout); + await channel.Terminal.WaitAsync(Timeout); + Assert.True(observedCancellation, "the request body iterator should throw on a cancel frame"); + + // The adapter finalises a cancelled request as a 499 + error{code:cancelled}. + var terminal = Assert.Single(channel.Chunks, c => c.Error is not null); + Assert.Equal("cancelled", terminal.Error!.Code); + } + + [Fact] + public async Task Threads_the_runtime_session_id_into_the_request() + { + string? observedSessionId = null; + var done = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var provider = new InlineProvider(async req => + { + observedSessionId = req.SessionId; + await foreach (var _ in req.RequestBody) + { + // drain + } + + await req.ResponseBody.StartAsync(new LlmInferenceResponseInit { Status = 200 }); + await req.ResponseBody.EndAsync(); + done.SetResult(); + }); + + var channel = new RecordingResponseChannel(); + var adapter = CreateAdapter(provider, channel); + + await adapter.HttpRequestStartAsync(LlmFrames.Start("r5", sessionId: "session-123")); + await adapter.HttpRequestChunkAsync(LlmFrames.Chunk("r5", "", end: true)); + + await done.Task.WaitAsync(Timeout); + Assert.Equal("session-123", observedSessionId); + } +} diff --git a/dotnet/test/Unit/LlmInference/LlmInferenceHandlerTests.cs b/dotnet/test/Unit/LlmInference/LlmInferenceHandlerTests.cs new file mode 100644 index 000000000..663884781 --- /dev/null +++ b/dotnet/test/Unit/LlmInference/LlmInferenceHandlerTests.cs @@ -0,0 +1,159 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using System.Net; +using System.Net.Http; +using System.Text; +using Xunit; + +namespace GitHub.Copilot.Test.Unit.LlmInference; + +#pragma warning disable GHCP001 // The LLM inference surface is intentionally experimental. + +public class LlmInferenceHandlerTests +{ + private static readonly TimeSpan Timeout = TimeSpan.FromSeconds(10); + + private static Task Dispatch(LlmRequestHandler handler, LlmInferenceRequest request) => + ((ILlmInferenceProvider)handler).OnLlmRequestAsync(request); + + private static async IAsyncEnumerable> AsyncBytes(params string[] chunks) + { + foreach (var chunk in chunks) + { + await Task.Yield(); + yield return Encoding.UTF8.GetBytes(chunk); + } + } + + private static LlmInferenceRequest HttpRequest( + RecordingSink sink, + IAsyncEnumerable> body, + string method = "POST", + string url = "https://upstream.test/v1/chat/completions", + IReadOnlyDictionary>? headers = null) => + new() + { + RequestId = "req-1", + SessionId = "session-1", + Method = method, + Url = url, + Headers = headers ?? new Dictionary>(), + Transport = LlmInferenceTransport.Http, + RequestBody = body, + ResponseBody = sink, + }; + + /// A handler whose upstream call is a canned delegate (no network). + private sealed class StubHandler(Func send) : LlmRequestHandler + { + protected override Task SendRequestAsync(HttpRequestMessage request, LlmRequestContext ctx) => + Task.FromResult(send(request)); + } + + /// A handler that adds a header before calling base.SendRequestAsync. + private sealed class HeaderMutatingHandler(Func send) : LlmRequestHandler + { + protected override Task SendRequestAsync(HttpRequestMessage request, LlmRequestContext ctx) + { + request.Headers.TryAddWithoutValidation("authorization", "Bearer swapped-token"); + return Task.FromResult(send(request)); + } + } + + [Fact] + public async Task Forwards_request_body_and_streams_response_back_to_the_sink() + { + string? forwardedBody = null; + var handler = new StubHandler(req => + { + forwardedBody = req.Content!.ReadAsStringAsync().GetAwaiter().GetResult(); + return new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent("RESPONSE-BODY", Encoding.UTF8, "application/json"), + }; + }); + + var sink = new RecordingSink(); + var request = HttpRequest(sink, AsyncBytes("{\"hello\":", "\"world\"}")); + + await Dispatch(handler, request).WaitAsync(Timeout); + + Assert.Equal("{\"hello\":\"world\"}", forwardedBody); + + var start = Assert.Single(sink.Starts); + Assert.Equal(200, start.Status); + Assert.Equal("RESPONSE-BODY", sink.DecodeBinaryBody()); + Assert.True(sink.Ended); + Assert.Null(sink.Errored); + } + + [Fact] + public async Task Strips_forbidden_request_headers_before_forwarding() + { + var forwarded = new Dictionary(StringComparer.OrdinalIgnoreCase); + var handler = new StubHandler(req => + { + foreach (var header in req.Headers) + { + forwarded[header.Key] = string.Join(",", header.Value); + } + + return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent("ok") }; + }); + + var sink = new RecordingSink(); + var headers = new Dictionary> + { + ["host"] = ["should-be-stripped.test"], + ["x-tenant"] = ["acme"], + }; + var request = HttpRequest(sink, AsyncBytes("body"), headers: headers); + + await Dispatch(handler, request).WaitAsync(Timeout); + + Assert.False(forwarded.ContainsKey("host"), "the forbidden host header must be stripped"); + Assert.Equal("acme", forwarded["x-tenant"]); + } + + [Fact] + public async Task Lets_a_subclass_mutate_the_outbound_request_headers() + { + string? observedAuth = null; + var handler = new HeaderMutatingHandler(req => + { + observedAuth = req.Headers.TryGetValues("authorization", out var values) + ? string.Join(",", values) + : null; + return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent("ok") }; + }); + + var sink = new RecordingSink(); + var request = HttpRequest(sink, AsyncBytes("body")); + + await Dispatch(handler, request).WaitAsync(Timeout); + + Assert.Equal("Bearer swapped-token", observedAuth); + } + + [Fact] + public async Task Propagates_a_non_2xx_status_verbatim_to_the_runtime() + { + var handler = new StubHandler(_ => + new HttpResponseMessage((HttpStatusCode)429) + { + Content = new StringContent("slow down"), + }); + + var sink = new RecordingSink(); + var request = HttpRequest(sink, AsyncBytes()); + + await Dispatch(handler, request).WaitAsync(Timeout); + + var start = Assert.Single(sink.Starts); + Assert.Equal(429, start.Status); + Assert.Equal("slow down", sink.DecodeBinaryBody()); + Assert.True(sink.Ended); + } +} diff --git a/dotnet/test/Unit/LlmInference/LlmInferenceTestInfra.cs b/dotnet/test/Unit/LlmInference/LlmInferenceTestInfra.cs new file mode 100644 index 000000000..65339732a --- /dev/null +++ b/dotnet/test/Unit/LlmInference/LlmInferenceTestInfra.cs @@ -0,0 +1,157 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using GitHub.Copilot.Rpc; +using System.Text; + +namespace GitHub.Copilot.Test.Unit.LlmInference; + +#pragma warning disable GHCP001 // The LLM inference surface is intentionally experimental. + +/// +/// In-memory that records every +/// response frame the adapter emits and lets a test choose what +/// accepted value the runtime returns. +/// +internal sealed class RecordingResponseChannel(bool acceptStart = true, bool acceptChunk = true) : ILlmInferenceResponseChannel +{ + public sealed record StartFrame(long Status, string? StatusText, IDictionary> Headers); + + public sealed record ChunkFrame(string Data, bool? Binary, bool? End, LlmInferenceHttpResponseChunkError? Error); + + public List Starts { get; } = []; + + public List Chunks { get; } = []; + + private readonly TaskCompletionSource _terminal = new(TaskCreationOptions.RunContinuationsAsynchronously); + + /// Completes once a terminal response chunk (end or error) is recorded. + public Task Terminal => _terminal.Task; + + public Task HttpResponseStartAsync(string requestId, long status, IDictionary> headers, string? statusText = null) + { + Starts.Add(new StartFrame(status, statusText, headers)); + return Task.FromResult(new LlmInferenceHttpResponseStartResult { Accepted = acceptStart }); + } + + public Task HttpResponseChunkAsync(string requestId, string data, bool? binary = null, bool? end = null, LlmInferenceHttpResponseChunkError? error = null) + { + Chunks.Add(new ChunkFrame(data, binary, end, error)); + if (end == true || error is not null) + { + _terminal.TrySetResult(); + } + + return Task.FromResult(new LlmInferenceHttpResponseChunkResult { Accepted = acceptChunk }); + } + + /// Concatenates the UTF-8 text of all non-terminal body chunks. + public string DecodeTextBody() + { + var sb = new StringBuilder(); + foreach (var chunk in Chunks) + { + if (chunk.Error is not null || chunk.Data.Length == 0) + { + continue; + } + + sb.Append(chunk.Binary == true + ? Encoding.UTF8.GetString(Convert.FromBase64String(chunk.Data)) + : chunk.Data); + } + + return sb.ToString(); + } +} + +/// An driven by an inline delegate. +internal sealed class InlineProvider(Func handler) : ILlmInferenceProvider +{ + public Task OnLlmRequestAsync(LlmInferenceRequest request) => handler(request); +} + +/// Records everything written to a . +internal sealed class RecordingSink : LlmInferenceResponseSink +{ + public List Starts { get; } = []; + + public List TextWrites { get; } = []; + + public List BinaryWrites { get; } = []; + + public bool Ended { get; private set; } + + public (string Message, string? Code)? Errored { get; private set; } + + /// Concatenates all binary body writes and decodes them as UTF-8. + public string DecodeBinaryBody() => Encoding.UTF8.GetString(BinaryWrites.SelectMany(b => b).ToArray()); + + public override Task StartAsync(LlmInferenceResponseInit init) + { + Starts.Add(init); + return Task.CompletedTask; + } + + public override Task WriteAsync(ReadOnlyMemory data) + { + BinaryWrites.Add(data.ToArray()); + return Task.CompletedTask; + } + + public override Task WriteAsync(string text) + { + TextWrites.Add(text); + return Task.CompletedTask; + } + + public override Task EndAsync() + { + Ended = true; + return Task.CompletedTask; + } + + public override Task ErrorAsync(string message, string? code = null) + { + Errored = (message, code); + return Task.CompletedTask; + } +} + +/// Convenience builders for the generated request frames. +internal static class LlmFrames +{ + public static LlmInferenceHttpRequestStartRequest Start( + string requestId, + string url = "https://example.test/v1/chat", + string method = "POST", + string? sessionId = null, + LlmInferenceHttpRequestStartTransport? transport = null) => + new() + { + RequestId = requestId, + Url = url, + Method = method, + SessionId = sessionId, + Headers = new Dictionary>(), + Transport = transport, + }; + + public static LlmInferenceHttpRequestChunkRequest Chunk( + string requestId, + string data = "", + bool? end = null, + bool? binary = null, + bool? cancel = null, + string? cancelReason = null) => + new() + { + RequestId = requestId, + Data = data, + End = end, + Binary = binary, + Cancel = cancel, + CancelReason = cancelReason, + }; +} diff --git a/go/client.go b/go/client.go index af9044ad9..f2575a646 100644 --- a/go/client.go +++ b/go/client.go @@ -371,6 +371,15 @@ func (c *Client) Start(ctx context.Context) error { } } + // If an LLM inference callback was configured, register as the provider. + if c.options.LlmInference != nil && c.options.LlmInference.Handler != nil { + if _, err := c.RPC.LlmInference.SetProvider(ctx); err != nil { + killErr := c.killProcess() + c.state = stateError + return errors.Join(err, killErr) + } + } + c.state = stateConnected return nil } @@ -2003,6 +2012,15 @@ func (c *Client) setupNotificationHandler() { } return session.clientSessionAPIs }) + if c.options.LlmInference != nil && c.options.LlmInference.Handler != nil { + adapter := newLlmInferenceAdapter(c.options.LlmInference.Handler, func() *rpc.ServerLlmInferenceAPI { + if c.RPC == nil { + return nil + } + return c.RPC.LlmInference + }) + rpc.RegisterClientGlobalAPIHandlers(c.client, &rpc.ClientGlobalAPIHandlers{LlmInference: adapter}) + } } func (c *Client) handleSessionEvent(req sessionEventRequest) { diff --git a/go/go.mod b/go/go.mod index 16114a0ab..586a5d336 100644 --- a/go/go.mod +++ b/go/go.mod @@ -8,6 +8,7 @@ require ( ) require ( + github.com/coder/websocket v1.8.15 github.com/google/uuid v1.6.0 go.opentelemetry.io/otel v1.35.0 go.opentelemetry.io/otel/trace v1.35.0 diff --git a/go/go.sum b/go/go.sum index ec2bbcc1e..e7ac53d5a 100644 --- a/go/go.sum +++ b/go/go.sum @@ -1,3 +1,5 @@ +github.com/coder/websocket v1.8.15 h1:6B2JPeOGlpff2Uz6vOEH1Vzpi0iUz20A+lPVhPHtNUA= +github.com/coder/websocket v1.8.15/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= diff --git a/go/internal/e2e/llm_inference_cancel_e2e_test.go b/go/internal/e2e/llm_inference_cancel_e2e_test.go new file mode 100644 index 000000000..cbeb2bc56 --- /dev/null +++ b/go/internal/e2e/llm_inference_cancel_e2e_test.go @@ -0,0 +1,102 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package e2e + +import ( + "net/http" + "sync" + "sync/atomic" + "testing" + "time" + + copilot "github.com/github/copilot-sdk/go" + "github.com/github/copilot-sdk/go/internal/e2e/testharness" +) + +type llmCancellingHandler struct { + inferenceEntered atomic.Bool + sawAbort atomic.Bool + abortSeen chan struct{} + once sync.Once +} + +func newLlmCancellingHandler() *llmCancellingHandler { + return &llmCancellingHandler{abortSeen: make(chan struct{})} +} + +func (h *llmCancellingHandler) OnLlmRequest(req *copilot.LlmInferenceRequest) error { + served, err := llmServiceNonInference(req) + if err != nil { + return err + } + if served { + return nil + } + if !llmIsInferenceURL(req.URL) { + return llmRespondBuffered(req, 200, http.Header{"content-type": {"application/json"}}, "{}") + } + + // Inference: never produce a response. Wait for the runtime to cancel us, + // recording the abort. + llmDrainRequest(req) + h.inferenceEntered.Store(true) + <-req.Context.Done() + h.sawAbort.Store(true) + h.once.Do(func() { close(h.abortSeen) }) + // Runtime already dropped the request on cancel; the sink error is a no-op. + _ = req.ResponseBody.Error("cancelled by upstream", "cancelled") + return nil +} + +func waitFor(t *testing.T, predicate func() bool, timeout time.Duration) { + t.Helper() + deadline := time.Now().Add(timeout) + for !predicate() { + if time.Now().After(deadline) { + t.Fatal("waitFor timed out") + } + time.Sleep(50 * time.Millisecond) + } +} + +func TestLlmInferenceCancel(t *testing.T) { + ctx := testharness.NewTestContext(t) + handler := newLlmCancellingHandler() + client := newLlmClient(ctx, handler) + t.Cleanup(func() { client.ForceStop() }) + + if err := client.Start(t.Context()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + if _, err := session.Send(t.Context(), copilot.MessageOptions{Prompt: "Say OK."}); err != nil { + t.Fatalf("send failed: %v", err) + } + waitFor(t, handler.inferenceEntered.Load, 60*time.Second) + if err := session.Abort(t.Context()); err != nil { + t.Fatalf("abort failed: %v", err) + } + + select { + case <-handler.abortSeen: + case <-time.After(30 * time.Second): + t.Fatal("Timed out waiting for the consumer to observe runtime cancellation") + } + _ = session.Disconnect() + + if !handler.inferenceEntered.Load() { + t.Fatal("Expected the inference callback to be entered") + } + if !handler.sawAbort.Load() { + t.Fatal("Expected the consumer to observe the runtime-driven cancellation") + } +} diff --git a/go/internal/e2e/llm_inference_consumer_cancel_e2e_test.go b/go/internal/e2e/llm_inference_consumer_cancel_e2e_test.go new file mode 100644 index 000000000..0cda6b665 --- /dev/null +++ b/go/internal/e2e/llm_inference_consumer_cancel_e2e_test.go @@ -0,0 +1,69 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package e2e + +import ( + "net/http" + "sync/atomic" + "testing" + + copilot "github.com/github/copilot-sdk/go" + "github.com/github/copilot-sdk/go/internal/e2e/testharness" +) + +type llmConsumerCancelHandler struct { + inferenceAttempts atomic.Int32 +} + +func (h *llmConsumerCancelHandler) OnLlmRequest(req *copilot.LlmInferenceRequest) error { + served, err := llmServiceNonInference(req) + if err != nil { + return err + } + if served { + return nil + } + if !llmIsInferenceURL(req.URL) { + return llmRespondBuffered(req, 200, http.Header{"content-type": {"application/json"}}, "{}") + } + + // Consumer-initiated cancellation: the consumer's own upstream call was + // aborted, so it tells the runtime to give up on this request. No response + // head is ever produced; the runtime should see a transport failure rather + // than hanging. + llmDrainRequest(req) + h.inferenceAttempts.Add(1) + return req.ResponseBody.Error("upstream call aborted by consumer", "cancelled") +} + +func TestLlmInferenceConsumerCancel(t *testing.T) { + ctx := testharness.NewTestContext(t) + handler := &llmConsumerCancelHandler{} + client := newLlmClient(ctx, handler) + t.Cleanup(func() { client.ForceStop() }) + + if err := client.Start(t.Context()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + _, sendErr := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "Say OK."}) + _ = session.Disconnect() + + // The runtime reached the inference step and the consumer's cancellation + // terminated it (rather than the runtime hanging). + if handler.inferenceAttempts.Load() == 0 { + t.Fatal("Expected the inference callback to be attempted") + } + if sendErr != nil && len(sendErr.Error()) == 0 { + t.Fatal("Expected a non-empty error string when a failure surfaces") + } +} diff --git a/go/internal/e2e/llm_inference_e2e_test.go b/go/internal/e2e/llm_inference_e2e_test.go new file mode 100644 index 000000000..640915891 --- /dev/null +++ b/go/internal/e2e/llm_inference_e2e_test.go @@ -0,0 +1,80 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package e2e + +import ( + "strings" + "sync" + "testing" + + copilot "github.com/github/copilot-sdk/go" + "github.com/github/copilot-sdk/go/internal/e2e/testharness" +) + +// llmRecordingHandler answers every model-layer request with the synthetic +// non-inference fallback (catalog / session / policy, and empty JSON for the +// inference call itself). It records what it intercepted. +type llmRecordingHandler struct { + mu sync.Mutex + received []*copilot.LlmInferenceRequest +} + +func (h *llmRecordingHandler) OnLlmRequest(req *copilot.LlmInferenceRequest) error { + h.mu.Lock() + h.received = append(h.received, req) + h.mu.Unlock() + return llmHandleNonInferenceModelTraffic(req, nil) +} + +func (h *llmRecordingHandler) snapshot() []*copilot.LlmInferenceRequest { + h.mu.Lock() + defer h.mu.Unlock() + return append([]*copilot.LlmInferenceRequest(nil), h.received...) +} + +func TestLlmInferenceCallback(t *testing.T) { + ctx := testharness.NewTestContext(t) + handler := &llmRecordingHandler{} + client := newLlmClient(ctx, handler) + t.Cleanup(func() { client.ForceStop() }) + + if err := client.Start(t.Context()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + // The buffered fallback returns empty JSON for the inference call, which is + // not a valid model response, so the turn fails; swallow that. What we + // assert is that the runtime attempted the callback. + _, _ = session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "Say OK."}) + _ = session.Disconnect() + + received := handler.snapshot() + if len(received) == 0 { + t.Fatal("Expected the runtime to invoke the inference callback") + } + + var sawCatalog bool + for _, r := range received { + if !strings.HasPrefix(r.URL, "http://") && !strings.HasPrefix(r.URL, "https://") { + t.Fatalf("Expected an absolute URL, got %q", r.URL) + } + if strings.HasSuffix(strings.ToLower(r.URL), "/models") { + sawCatalog = true + } + if r.SessionID != "" && len(r.SessionID) == 0 { + t.Fatal("session id should be non-empty when present") + } + } + if !sawCatalog { + t.Fatal("Expected to intercept the /models catalog request") + } +} diff --git a/go/internal/e2e/llm_inference_errors_e2e_test.go b/go/internal/e2e/llm_inference_errors_e2e_test.go new file mode 100644 index 000000000..7264699ab --- /dev/null +++ b/go/internal/e2e/llm_inference_errors_e2e_test.go @@ -0,0 +1,86 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package e2e + +import ( + "errors" + "net/http" + "strings" + "sync" + "testing" + + copilot "github.com/github/copilot-sdk/go" + "github.com/github/copilot-sdk/go/internal/e2e/testharness" +) + +type llmThrowingHandler struct { + mu sync.Mutex + totalCalls int + callsBeforeError int +} + +func (h *llmThrowingHandler) OnLlmRequest(req *copilot.LlmInferenceRequest) error { + h.mu.Lock() + h.totalCalls++ + h.mu.Unlock() + + served, err := llmServiceNonInference(req) + if err != nil { + return err + } + if served { + return nil + } + + url := strings.ToLower(req.URL) + if strings.Contains(url, "/chat/completions") || strings.Contains(url, "/responses") { + llmDrainRequest(req) + h.mu.Lock() + h.callsBeforeError++ + h.mu.Unlock() + return errors.New("synthetic-callback-transport-failure") + } + + return llmRespondBuffered(req, 200, http.Header{"content-type": {"application/json"}}, "{}") +} + +func TestLlmInferenceErrors(t *testing.T) { + ctx := testharness.NewTestContext(t) + handler := &llmThrowingHandler{} + client := newLlmClient(ctx, handler) + t.Cleanup(func() { client.ForceStop() }) + + if err := client.Start(t.Context()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + // The handler raises from the inference callback; the agent layer surfaces + // it as an error or an event rather than hanging. The assertion is loose: + // the inference call was attempted and the runtime did not hang. + _, sendErr := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "Say OK."}) + _ = session.Disconnect() + + handler.mu.Lock() + total := handler.totalCalls + before := handler.callsBeforeError + handler.mu.Unlock() + + if total == 0 { + t.Fatal("Expected the callback to be invoked") + } + if before == 0 { + t.Fatal("Expected the inference callback to be reached and raise") + } + if sendErr != nil && len(sendErr.Error()) == 0 { + t.Fatal("Expected a non-empty error string when an error surfaces") + } +} diff --git a/go/internal/e2e/llm_inference_handler_e2e_test.go b/go/internal/e2e/llm_inference_handler_e2e_test.go new file mode 100644 index 000000000..4767a0fe3 --- /dev/null +++ b/go/internal/e2e/llm_inference_handler_e2e_test.go @@ -0,0 +1,207 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package e2e + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync/atomic" + "testing" + + "github.com/coder/websocket" + copilot "github.com/github/copilot-sdk/go" + "github.com/github/copilot-sdk/go/internal/e2e/testharness" +) + +const ( + llmHandlerHTTPText = "OK from synthetic HTTP upstream." + llmHandlerWSText = "OK from synthetic WS upstream." +) + +type llmHandlerCounters struct { + httpRequests atomic.Int32 + httpResponses atomic.Int32 + wsRequestMessages atomic.Int32 + wsResponseMessages atomic.Int32 + upstreamWSRequests atomic.Int32 +} + +func llmSSEBody(text, respID string) string { + var sb strings.Builder + for _, event := range llmResponsesEvents(text, respID) { + sb.WriteString(llmSSE(event["type"].(string), event)) + } + return sb.String() +} + +// startFakeUpstream brings up a real HTTP upstream (catalog / policy / +// responses-SSE) and a real WebSocket upstream that echoes the ordered +// /responses events per inbound message. +func startFakeUpstream(t *testing.T, counters *llmHandlerCounters) (httpURL, wsURL string) { + t.Helper() + + httpSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + path := strings.ToLower(strings.SplitN(r.URL.Path, "?", 2)[0]) + _ = r.Body.Close + switch { + case strings.HasSuffix(path, "/models"): + w.Header().Set("content-type", "application/json") + _, _ = w.Write([]byte(llmModelCatalog(llmWSSupportedEndpoints))) + case strings.HasSuffix(path, "/models/session"): + w.Header().Set("content-type", "application/json") + _, _ = w.Write([]byte("{}")) + case strings.Contains(path, "/policy"): + w.Header().Set("content-type", "application/json") + _, _ = w.Write([]byte(`{"state":"enabled"}`)) + case strings.HasSuffix(path, "/responses"): + w.Header().Set("content-type", "text/event-stream") + _, _ = w.Write([]byte(llmSSEBody(llmHandlerHTTPText, "resp_stub_http"))) + default: + w.Header().Set("content-type", "application/json") + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"error":"not_found"}`)) + } + })) + t.Cleanup(httpSrv.Close) + + wsSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{InsecureSkipVerify: true}) + if err != nil { + return + } + defer c.Close(websocket.StatusNormalClosure, "") + c.SetReadLimit(-1) + bg := context.Background() + for { + _, _, readErr := c.Read(bg) + if readErr != nil { + return + } + counters.upstreamWSRequests.Add(1) + for _, event := range llmResponsesEvents(llmHandlerWSText, "resp_stub_ws") { + raw, _ := json.Marshal(event) + if err := c.Write(bg, websocket.MessageText, raw); err != nil { + return + } + } + } + })) + t.Cleanup(wsSrv.Close) + + return httpSrv.URL, "ws://" + strings.TrimPrefix(wsSrv.URL, "http://") +} + +type llmRewritingRoundTripper struct { + base *url.URL + counters *llmHandlerCounters + inner http.RoundTripper +} + +func (rt *llmRewritingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + rt.counters.httpRequests.Add(1) + req.URL.Scheme = rt.base.Scheme + req.URL.Host = rt.base.Host + req.Host = rt.base.Host + req.Header.Set("x-test-mutated", "1") + resp, err := rt.inner.RoundTrip(req) + if err != nil { + return nil, err + } + rt.counters.httpResponses.Add(1) + resp.Header.Set("x-test-response-mutated", "1") + return resp, nil +} + +func TestLlmInferenceHandler(t *testing.T) { + ctx := testharness.NewTestContext(t) + counters := &llmHandlerCounters{} + httpURL, wsURL := startFakeUpstream(t, counters) + + httpBase, err := url.Parse(httpURL) + if err != nil { + t.Fatalf("Failed to parse upstream URL: %v", err) + } + wsBase, err := url.Parse(wsURL) + if err != nil { + t.Fatalf("Failed to parse upstream ws URL: %v", err) + } + + handler := &copilot.LlmRequestHandler{ + Transport: &llmRewritingRoundTripper{ + base: httpBase, + counters: counters, + inner: http.DefaultTransport.(*http.Transport).Clone(), + }, + OpenWebSocket: func(rctx *copilot.LlmRequestContext) (copilot.CopilotWebSocketHandler, error) { + parsed, perr := url.Parse(rctx.URL) + if perr != nil { + return nil, perr + } + parsed.Scheme = wsBase.Scheme + parsed.Host = wsBase.Host + fwd := copilot.NewForwardingWebSocketHandler(parsed.String(), rctx.Headers) + fwd.OnSendRequestMessage = func(data []byte) []byte { + counters.wsRequestMessages.Add(1) + return data + } + fwd.OnSendResponseMessage = func(data []byte) []byte { + counters.wsResponseMessages.Add(1) + return data + } + return fwd, nil + }, + } + + client := newLlmClient(ctx, handler, "COPILOT_EXP_COPILOT_CLI_WEBSOCKET_RESPONSES=true") + t.Cleanup(func() { client.ForceStop() }) + + if err := client.Start(t.Context()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + result, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "Say OK."}) + if err != nil { + t.Fatalf("send_and_wait failed: %v", err) + } + _ = session.Disconnect() + + // The HTTP seam fired — the runtime issued model-layer GETs (catalog, + // policy) and possibly a single-shot inference through the RoundTripper. + if counters.httpRequests.Load() == 0 { + t.Fatal("Expected the HTTP RoundTripper to fire") + } + if counters.httpResponses.Load() == 0 { + t.Fatal("Expected the HTTP response mutation to fire") + } + + // The WebSocket seam fired — the main agent turn went over the WS path and + // we observed messages in both directions. + if counters.wsRequestMessages.Load() == 0 { + t.Fatal("Expected runtime → upstream ws messages") + } + if counters.wsResponseMessages.Load() == 0 { + t.Fatal("Expected upstream → runtime ws messages") + } + if counters.upstreamWSRequests.Load() == 0 { + t.Fatal("Expected the upstream WS to receive request messages") + } + + // Validate the final assistant response arrived (guards against truncated captures) + text := assistantText(result) + if !strings.Contains(text, "OK from synthetic") || !strings.Contains(text, "upstream") { + t.Fatalf("Expected synthetic upstream content in assistant reply, got %q", text) + } +} diff --git a/go/internal/e2e/llm_inference_helpers_test.go b/go/internal/e2e/llm_inference_helpers_test.go new file mode 100644 index 000000000..e945f2284 --- /dev/null +++ b/go/internal/e2e/llm_inference_helpers_test.go @@ -0,0 +1,275 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package e2e + +import ( + "encoding/json" + "net/http" + "regexp" + "strings" + + copilot "github.com/github/copilot-sdk/go" + "github.com/github/copilot-sdk/go/internal/e2e/testharness" +) + +// Shared synthetic-upstream helpers for the LLM inference callback e2e tests. +// +// These tests have no recorded snapshots: the registered callback fabricates +// well-formed model responses and the runtime routes all of its model-layer +// HTTP/WebSocket traffic through that callback instead of the CAPI proxy. The +// helpers centralise the synthetic CAPI shapes (model catalog, policy, +// /responses SSE, /chat/completions) so each test focuses on the behaviour it +// is exercising. + +const llmSyntheticText = "OK from the synthetic stream." + +var llmStreamTrueRe = regexp.MustCompile(`"stream"\s*:\s*true`) + +func llmStreamTrue(body string) bool { + return llmStreamTrueRe.MatchString(body) +} + +func llmIsInferenceURL(url string) bool { + u := strings.ToLower(url) + return strings.HasSuffix(u, "/chat/completions") || + strings.HasSuffix(u, "/responses") || + strings.HasSuffix(u, "/v1/messages") || + strings.HasSuffix(u, "/messages") +} + +func llmSSE(eventType string, data map[string]any) string { + raw, _ := json.Marshal(data) + return "event: " + eventType + "\ndata: " + string(raw) + "\n\n" +} + +func llmModelCatalog(supportedEndpoints []string) string { + model := map[string]any{ + "id": "claude-sonnet-4.5", + "name": "Claude Sonnet 4.5", + "object": "model", + "vendor": "Anthropic", + "version": "1", + "preview": false, + "model_picker_enabled": true, + "capabilities": map[string]any{ + "type": "chat", + "family": "claude-sonnet-4.5", + "tokenizer": "o200k_base", + "limits": map[string]any{ + "max_context_window_tokens": 200000, + "max_output_tokens": 8192, + }, + "supports": map[string]any{ + "streaming": true, + "tool_calls": true, + "parallel_tool_calls": true, + "vision": true, + }, + }, + } + if supportedEndpoints != nil { + model["supported_endpoints"] = supportedEndpoints + } + raw, _ := json.Marshal(map[string]any{"data": []any{model}}) + return string(raw) +} + +// llmResponsesEvents returns the ordered /responses event objects the runtime's +// reducer expects. Used raw (one object == one WebSocket message) for the WS +// path and SSE-framed for the HTTP path. +func llmResponsesEvents(text, respID string) []map[string]any { + return []map[string]any{ + { + "type": "response.created", + "response": map[string]any{"id": respID, "object": "response", "status": "in_progress", "output": []any{}}, + }, + { + "type": "response.output_item.added", + "output_index": 0, + "item": map[string]any{"id": "msg_1", "type": "message", "role": "assistant", "content": []any{}}, + }, + { + "type": "response.content_part.added", + "output_index": 0, + "content_index": 0, + "part": map[string]any{"type": "output_text", "text": ""}, + }, + {"type": "response.output_text.delta", "output_index": 0, "content_index": 0, "delta": text}, + {"type": "response.output_text.done", "output_index": 0, "content_index": 0, "text": text}, + { + "type": "response.completed", + "response": map[string]any{ + "id": respID, + "object": "response", + "status": "completed", + "output": []any{ + map[string]any{ + "id": "msg_1", + "type": "message", + "role": "assistant", + "content": []any{map[string]any{"type": "output_text", "text": text}}, + }, + }, + "usage": map[string]any{"input_tokens": 5, "output_tokens": 7, "total_tokens": 12}, + }, + }, + } +} + +func llmDrainRequest(req *copilot.LlmInferenceRequest) string { + var sb strings.Builder + for frame := range req.RequestBody { + sb.Write(frame) + } + return sb.String() +} + +func llmRespondBuffered(req *copilot.LlmInferenceRequest, status int, headers http.Header, body string) error { + llmDrainRequest(req) + if err := req.ResponseBody.Start(copilot.LlmInferenceResponseInit{Status: status, Headers: headers}); err != nil { + return err + } + if body != "" { + if err := req.ResponseBody.Write([]byte(body)); err != nil { + return err + } + } + return req.ResponseBody.End() +} + +// llmServiceNonInference serves the model catalog, model session and policy +// endpoints. Returns true when the request was one of those (and answered). +func llmServiceNonInference(req *copilot.LlmInferenceRequest) (bool, error) { + url := strings.ToLower(req.URL) + switch { + case strings.HasSuffix(url, "/models"): + return true, llmRespondBuffered(req, 200, http.Header{"content-type": {"application/json"}}, llmModelCatalog(nil)) + case strings.Contains(url, "/models/session"): + return true, llmRespondBuffered(req, 200, http.Header{}, "{}") + case strings.Contains(url, "/policy"): + return true, llmRespondBuffered(req, 200, http.Header{}, `{"state":"enabled"}`) + } + return false, nil +} + +// llmHandleNonInferenceModelTraffic serves every non-inference model-layer +// request, including an empty-JSON fallback for anything unrecognised. +func llmHandleNonInferenceModelTraffic(req *copilot.LlmInferenceRequest, supportedEndpoints []string) error { + url := strings.ToLower(req.URL) + switch { + case strings.HasSuffix(url, "/models"): + return llmRespondBuffered(req, 200, http.Header{"content-type": {"application/json"}}, llmModelCatalog(supportedEndpoints)) + case strings.Contains(url, "/models/session"): + return llmRespondBuffered(req, 200, http.Header{}, "{}") + case strings.Contains(url, "/policy"): + return llmRespondBuffered(req, 200, http.Header{}, `{"state":"enabled"}`) + } + return llmRespondBuffered(req, 200, http.Header{"content-type": {"application/json"}}, "{}") +} + +// llmHandleInference synthesizes a well-formed inference response, dispatching +// by URL and the request body's stream flag exactly as a real reverse proxy +// would. +func llmHandleInference(req *copilot.LlmInferenceRequest, text string) error { + body := llmDrainRequest(req) + wantsStream := llmStreamTrue(body) + url := strings.ToLower(req.URL) + + if strings.Contains(url, "/responses") { + events := llmResponsesEvents(text, "resp_stub_1") + if !wantsStream { + if err := req.ResponseBody.Start(copilot.LlmInferenceResponseInit{Status: 200, Headers: http.Header{"content-type": {"application/json"}}}); err != nil { + return err + } + last := events[len(events)-1]["response"] + raw, _ := json.Marshal(last) + if err := req.ResponseBody.Write(raw); err != nil { + return err + } + return req.ResponseBody.End() + } + if err := req.ResponseBody.Start(copilot.LlmInferenceResponseInit{Status: 200, Headers: http.Header{"content-type": {"text/event-stream"}}}); err != nil { + return err + } + for _, event := range events { + if err := req.ResponseBody.Write([]byte(llmSSE(event["type"].(string), event))); err != nil { + return err + } + } + return req.ResponseBody.End() + } + + if strings.Contains(url, "/chat/completions") && wantsStream { + if err := req.ResponseBody.Start(copilot.LlmInferenceResponseInit{Status: 200, Headers: http.Header{"content-type": {"text/event-stream"}}}); err != nil { + return err + } + base := func() map[string]any { + return map[string]any{ + "id": "chatcmpl-stub-1", + "object": "chat.completion.chunk", + "created": 1, + "model": "claude-sonnet-4.5", + } + } + c1 := base() + c1["choices"] = []any{map[string]any{"index": 0, "delta": map[string]any{"role": "assistant", "content": ""}, "finish_reason": nil}} + c2 := base() + c2["choices"] = []any{map[string]any{"index": 0, "delta": map[string]any{"content": text}, "finish_reason": nil}} + c3 := base() + c3["choices"] = []any{map[string]any{"index": 0, "delta": map[string]any{}, "finish_reason": "stop"}} + c3["usage"] = map[string]any{"prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12} + for _, chunk := range []map[string]any{c1, c2, c3} { + raw, _ := json.Marshal(chunk) + if err := req.ResponseBody.Write([]byte("data: " + string(raw) + "\n\n")); err != nil { + return err + } + } + if err := req.ResponseBody.Write([]byte("data: [DONE]\n\n")); err != nil { + return err + } + return req.ResponseBody.End() + } + + if err := req.ResponseBody.Start(copilot.LlmInferenceResponseInit{Status: 200, Headers: http.Header{"content-type": {"application/json"}}}); err != nil { + return err + } + raw, _ := json.Marshal(map[string]any{ + "id": "chatcmpl-stub-1", + "object": "chat.completion", + "created": 1, + "model": "claude-sonnet-4.5", + "choices": []any{ + map[string]any{"index": 0, "message": map[string]any{"role": "assistant", "content": text}, "finish_reason": "stop"}, + }, + "usage": map[string]any{"prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12}, + }) + if err := req.ResponseBody.Write(raw); err != nil { + return err + } + return req.ResponseBody.End() +} + +func assistantText(msg *copilot.SessionEvent) string { + if msg == nil { + return "" + } + if d, ok := msg.Data.(*copilot.AssistantMessageData); ok { + return d.Content + } + return "" +} + +// newLlmClient builds a client wired to handler via LlmInferenceConfig. The +// shared ctx harness client has no inference callback, so each inference test +// owns an isolated client carrying its own handler. extraEnv is appended to the +// spawned runtime's environment (e.g. to flip an ExP flag for the WS transport). +func newLlmClient(ctx *testharness.TestContext, handler copilot.LlmInferenceProvider, extraEnv ...string) *copilot.Client { + return ctx.NewClient(func(o *copilot.ClientOptions) { + o.LlmInference = &copilot.LlmInferenceConfig{Handler: handler} + if len(extraEnv) > 0 { + o.Env = append(o.Env, extraEnv...) + } + }) +} diff --git a/go/internal/e2e/llm_inference_session_id_e2e_test.go b/go/internal/e2e/llm_inference_session_id_e2e_test.go new file mode 100644 index 000000000..b89e107ce --- /dev/null +++ b/go/internal/e2e/llm_inference_session_id_e2e_test.go @@ -0,0 +1,135 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package e2e + +import ( + "strings" + "sync" + "testing" + + copilot "github.com/github/copilot-sdk/go" + "github.com/github/copilot-sdk/go/internal/e2e/testharness" +) + +type interceptedRequest struct { + url string + sessionID string +} + +type llmSessionIDHandler struct { + mu sync.Mutex + records []interceptedRequest +} + +func (h *llmSessionIDHandler) OnLlmRequest(req *copilot.LlmInferenceRequest) error { + h.mu.Lock() + h.records = append(h.records, interceptedRequest{url: req.URL, sessionID: req.SessionID}) + h.mu.Unlock() + if llmIsInferenceURL(req.URL) { + return llmHandleInference(req, llmSyntheticText) + } + return llmHandleNonInferenceModelTraffic(req, nil) +} + +func (h *llmSessionIDHandler) inferenceRecords() []interceptedRequest { + h.mu.Lock() + defer h.mu.Unlock() + var out []interceptedRequest + for _, r := range h.records { + if llmIsInferenceURL(r.url) { + out = append(out, r) + } + } + return out +} + +func TestLlmInferenceSessionID(t *testing.T) { + ctx := testharness.NewTestContext(t) + handler := &llmSessionIDHandler{} + client := newLlmClient(ctx, handler) + t.Cleanup(func() { client.ForceStop() }) + + if err := client.Start(t.Context()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + var capiSessionID string + + t.Run("threads session id into a CAPI session", func(t *testing.T) { + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + capiSessionID = session.SessionID + + result, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "Say OK."}) + if err != nil { + t.Fatalf("send_and_wait failed: %v", err) + } + _ = session.Disconnect() + + inference := handler.inferenceRecords() + if len(inference) == 0 { + t.Fatal("Expected at least one intercepted inference request") + } + for _, r := range inference { + if r.sessionID != capiSessionID { + t.Fatalf("CAPI inference request must carry session id %q, got %q", capiSessionID, r.sessionID) + } + } + + // Validate the final assistant response arrived (guards against truncated captures) + if !strings.Contains(assistantText(result), "OK from the synthetic") { + t.Fatalf("Expected synthetic content in assistant reply, got %q", assistantText(result)) + } + }) + + t.Run("threads session id into a BYOK session", func(t *testing.T) { + before := len(handler.inferenceRecords()) + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + Model: "claude-sonnet-4.5", + Provider: &copilot.ProviderConfig{ + Type: "openai", + WireAPI: "responses", + BaseURL: "https://byok.invalid/v1", + APIKey: "byok-secret", + ModelID: "claude-sonnet-4.5", + WireModel: "claude-sonnet-4.5", + }, + }) + if err != nil { + t.Fatalf("Failed to create BYOK session: %v", err) + } + byokSessionID := session.SessionID + + result, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "Say OK."}) + if err != nil { + t.Fatalf("send_and_wait failed: %v", err) + } + _ = session.Disconnect() + + inference := handler.inferenceRecords() + if len(inference) <= before { + t.Fatal("Expected at least one intercepted BYOK inference request") + } + for _, r := range inference[before:] { + if r.sessionID != byokSessionID { + t.Fatalf("BYOK inference request must carry session id %q, got %q", byokSessionID, r.sessionID) + } + } + + if byokSessionID == capiSessionID { + t.Fatal("Expected per-session ids to differ between turns") + } + + // Validate the final assistant response arrived (guards against truncated captures) + if !strings.Contains(assistantText(result), "OK from the synthetic") { + t.Fatalf("Expected synthetic content in assistant reply, got %q", assistantText(result)) + } + }) +} diff --git a/go/internal/e2e/llm_inference_stream_e2e_test.go b/go/internal/e2e/llm_inference_stream_e2e_test.go new file mode 100644 index 000000000..07605277d --- /dev/null +++ b/go/internal/e2e/llm_inference_stream_e2e_test.go @@ -0,0 +1,74 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package e2e + +import ( + "strings" + "sync" + "testing" + + copilot "github.com/github/copilot-sdk/go" + "github.com/github/copilot-sdk/go/internal/e2e/testharness" +) + +type llmStreamingHandler struct { + mu sync.Mutex + received []*copilot.LlmInferenceRequest +} + +func (h *llmStreamingHandler) OnLlmRequest(req *copilot.LlmInferenceRequest) error { + h.mu.Lock() + h.received = append(h.received, req) + h.mu.Unlock() + if llmIsInferenceURL(req.URL) { + return llmHandleInference(req, llmSyntheticText) + } + return llmHandleNonInferenceModelTraffic(req, nil) +} + +func (h *llmStreamingHandler) inferenceCount() int { + h.mu.Lock() + defer h.mu.Unlock() + n := 0 + for _, r := range h.received { + if llmIsInferenceURL(r.URL) { + n++ + } + } + return n +} + +func TestLlmInferenceStream(t *testing.T) { + ctx := testharness.NewTestContext(t) + handler := &llmStreamingHandler{} + client := newLlmClient(ctx, handler) + t.Cleanup(func() { client.ForceStop() }) + + if err := client.Start(t.Context()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + result, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "Say OK."}) + if err != nil { + t.Fatalf("send_and_wait failed: %v", err) + } + _ = session.Disconnect() + + if handler.inferenceCount() == 0 { + t.Fatal("Expected at least one inference request via the callback") + } + + // Validate the final assistant response arrived (guards against truncated captures) + if !strings.Contains(assistantText(result), "OK from the synthetic") { + t.Fatalf("Expected synthetic content in assistant reply, got %q", assistantText(result)) + } +} diff --git a/go/internal/e2e/llm_inference_websocket_e2e_test.go b/go/internal/e2e/llm_inference_websocket_e2e_test.go new file mode 100644 index 000000000..98ef48f5d --- /dev/null +++ b/go/internal/e2e/llm_inference_websocket_e2e_test.go @@ -0,0 +1,124 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package e2e + +import ( + "encoding/json" + "net/http" + "strings" + "sync" + "sync/atomic" + "testing" + + copilot "github.com/github/copilot-sdk/go" + "github.com/github/copilot-sdk/go/internal/e2e/testharness" +) + +const llmWSText = "OK from the synthetic ws." + +var llmWSSupportedEndpoints = []string{"/responses", "ws:/responses"} + +type llmWebSocketHandler struct { + mu sync.Mutex + received []*copilot.LlmInferenceRequest + wsRequestCount atomic.Int32 +} + +// handleHTTPInference answers single-shot HTTP inference requests (e.g. title +// generation) that don't pick the WebSocket transport. +func (h *llmWebSocketHandler) handleHTTPInference(req *copilot.LlmInferenceRequest) error { + llmDrainRequest(req) + if err := req.ResponseBody.Start(copilot.LlmInferenceResponseInit{Status: 200, Headers: http.Header{"content-type": {"text/event-stream"}}}); err != nil { + return err + } + for _, event := range llmResponsesEvents(llmWSText, "resp_stub_ws_1") { + if err := req.ResponseBody.Write([]byte(llmSSE(event["type"].(string), event))); err != nil { + return err + } + } + return req.ResponseBody.End() +} + +func (h *llmWebSocketHandler) handleWebSocket(req *copilot.LlmInferenceRequest) error { + // Ack the upgrade (status 101-equivalent) before any message flows. + if err := req.ResponseBody.Start(copilot.LlmInferenceResponseInit{Status: 101, Headers: http.Header{}}); err != nil { + return err + } + // One inbound chunk == one WS message (a response.create request). + for range req.RequestBody { + h.wsRequestCount.Add(1) + for _, event := range llmResponsesEvents(llmWSText, "resp_stub_ws_1") { + raw, _ := json.Marshal(event) + if err := req.ResponseBody.Write(raw); err != nil { + return nil + } + } + } + return req.ResponseBody.End() +} + +func (h *llmWebSocketHandler) OnLlmRequest(req *copilot.LlmInferenceRequest) error { + h.mu.Lock() + h.received = append(h.received, req) + h.mu.Unlock() + + if req.Transport == "websocket" { + return h.handleWebSocket(req) + } + if llmIsInferenceURL(req.URL) { + return h.handleHTTPInference(req) + } + return llmHandleNonInferenceModelTraffic(req, llmWSSupportedEndpoints) +} + +func (h *llmWebSocketHandler) wsRequests() int { + h.mu.Lock() + defer h.mu.Unlock() + n := 0 + for _, r := range h.received { + if r.Transport == "websocket" { + n++ + } + } + return n +} + +func TestLlmInferenceWebSocket(t *testing.T) { + ctx := testharness.NewTestContext(t) + handler := &llmWebSocketHandler{} + client := newLlmClient(ctx, handler, "COPILOT_EXP_COPILOT_CLI_WEBSOCKET_RESPONSES=true") + t.Cleanup(func() { client.ForceStop() }) + + if err := client.Start(t.Context()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + result, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "Say OK."}) + if err != nil { + t.Fatalf("send_and_wait failed: %v", err) + } + _ = session.Disconnect() + + // The main agent turn (tools present, not single-shot) selected the + // WebSocket transport and drove it through the callback. + if handler.wsRequests() == 0 { + t.Fatal("Expected at least one websocket request via the callback") + } + if handler.wsRequestCount.Load() == 0 { + t.Fatal("Expected the runtime to send at least one ws message") + } + + // Validate the final assistant response arrived (guards against truncated captures) + if !strings.Contains(assistantText(result), "OK from the synthetic ws") { + t.Fatalf("Expected synthetic ws content in assistant reply, got %q", assistantText(result)) + } +} diff --git a/go/internal/e2e/session_e2e_test.go b/go/internal/e2e/session_e2e_test.go index dc2d54ca8..9e21e82be 100644 --- a/go/internal/e2e/session_e2e_test.go +++ b/go/internal/e2e/session_e2e_test.go @@ -1104,7 +1104,7 @@ func TestSessionBlobAttachmentE2E(t *testing.T) { Prompt: "Describe this image", Attachments: []copilot.Attachment{ &copilot.AttachmentBlob{ - Data: data, + Data: &data, MIMEType: mimeType, DisplayName: &displayName, }, diff --git a/go/llm_inference_provider.go b/go/llm_inference_provider.go new file mode 100644 index 000000000..8c98622fe --- /dev/null +++ b/go/llm_inference_provider.go @@ -0,0 +1,503 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package copilot + +import ( + "context" + "encoding/base64" + "fmt" + "net/http" + "sync" + + "github.com/github/copilot-sdk/go/rpc" +) + +// LlmInferenceRequest is an outbound model-layer request the runtime is asking +// the SDK consumer to service on its behalf. +// +// It is a low-level shape: URL / method / headers verbatim, the request body +// delivered as a stream of frames, and the response written through +// ResponseBody. The runtime does not classify the request (no provider type, +// endpoint kind, or wire API); consumers that need that information derive it +// from the URL and headers. For the idiomatic [net/http] view, use +// [LlmRequestHandler] instead of implementing [LlmInferenceProvider] directly. +type LlmInferenceRequest struct { + // RequestID is an opaque runtime-minted id, stable across the request lifecycle. + RequestID string + // SessionID is the id of the runtime session that triggered this request, or + // empty when the request was issued outside any session (for example the + // startup model catalog). + SessionID string + // Method is the HTTP method (GET, POST, ...). + Method string + // URL is the absolute request URL. + URL string + // Headers are the request headers, multi-valued. + Headers http.Header + // Transport is the transport the runtime would otherwise use: "http" (the + // default, covering plain HTTP and SSE) or "websocket" (a full-duplex + // message channel where each RequestBody frame is one inbound message and + // each ResponseBody write is one outbound message). + Transport string + // RequestBody yields request body frames as they arrive from the runtime. + // The channel is closed when the body ends or the request is cancelled; + // check Context.Err() to distinguish a clean end from a cancellation. + RequestBody <-chan []byte + // Context is cancelled when the runtime cancels this in-flight request (for + // example because the agent turn was aborted upstream). Pass it to the + // outbound call so the upstream is torn down too. + Context context.Context + // ResponseBody is the sink the consumer writes the upstream response into. + // Call Start exactly once before writing body frames, then zero or more + // Write/WriteBinary calls, and finish with End or Error. + ResponseBody LlmInferenceResponseSink +} + +// LlmInferenceResponseInit is the response head passed to +// [LlmInferenceResponseSink.Start]. +type LlmInferenceResponseInit struct { + Status int + StatusText string + Headers http.Header +} + +// LlmInferenceResponseSink is the sink a consumer writes an upstream response +// into. The state machine is strict: Start once, then zero or more +// Write/WriteBinary, then exactly one of End or Error. Calling out of order +// returns an error. +type LlmInferenceResponseSink interface { + // Start sends the response head (status + headers) back to the runtime. + Start(init LlmInferenceResponseInit) error + // Write sends a body frame as UTF-8 text (the common case for JSON / SSE). + Write(data []byte) error + // WriteBinary sends a body frame as binary (base64 on the wire). + WriteBinary(data []byte) error + // End marks end-of-stream cleanly. + End() error + // Error marks end-of-stream with a transport-level failure. code is optional. + Error(message string, code string) error +} + +// LlmInferenceProvider is the low-level registration seam. The SDK consumer +// implements OnLlmRequest; the same callback handles both buffered and +// streaming responses by calling ResponseBody.Write zero or more times before +// End. Most consumers should embed or use [LlmRequestHandler] instead, which +// exposes idiomatic [net/http] request/response seams. +type LlmInferenceProvider interface { + // OnLlmRequest is called once per outbound model-layer request the consumer + // has opted to handle. The consumer must eventually call ResponseBody.End or + // ResponseBody.Error; returning a non-nil error surfaces a transport-level + // failure to the runtime (equivalent to ResponseBody.Error when Start has + // not yet been called). + OnLlmRequest(req *LlmInferenceRequest) error +} + +// LlmInferenceConfig configures a connection-level LLM inference callback. When +// set on [ClientOptions], the client registers as the inference provider on +// connect, and the runtime routes its model-layer HTTP and WebSocket traffic +// through Handler instead of issuing the calls itself. +type LlmInferenceConfig struct { + // Handler services intercepted requests. Use a [*LlmRequestHandler] for the + // idiomatic net/http view, or any type implementing [LlmInferenceProvider] + // for full low-level control. + Handler LlmInferenceProvider +} + +// frameQueue is an unbounded FIFO of body frames, decoupling the RPC dispatch +// goroutine (which only pushes) from the consumer goroutine (which pops). +type frameQueue struct { + mu sync.Mutex + cond *sync.Cond + items [][]byte + done bool +} + +func newFrameQueue() *frameQueue { + q := &frameQueue{} + q.cond = sync.NewCond(&q.mu) + return q +} + +func (q *frameQueue) push(b []byte) { + q.mu.Lock() + if !q.done { + q.items = append(q.items, b) + } + q.cond.Signal() + q.mu.Unlock() +} + +func (q *frameQueue) close() { + q.mu.Lock() + q.done = true + q.cond.Broadcast() + q.mu.Unlock() +} + +func (q *frameQueue) pop() ([]byte, bool) { + q.mu.Lock() + defer q.mu.Unlock() + for len(q.items) == 0 && !q.done { + q.cond.Wait() + } + if len(q.items) > 0 { + b := q.items[0] + q.items = q.items[1:] + return b, true + } + return nil, false +} + +type llmPendingState struct { + mu sync.Mutex + queue *frameQueue + ctx context.Context + cancel context.CancelFunc + started bool + finished bool + cancelled bool +} + +type llmInferenceAdapter struct { + handler LlmInferenceProvider + getRPC func() *rpc.ServerLlmInferenceAPI + + mu sync.Mutex + pending map[string]*llmPendingState + // staged buffers chunks that arrive before their start frame — a reordering + // the runtime's ordered dispatch should make impossible, drained the moment + // the matching start frame registers so a body byte is never dropped. + staged map[string][]*rpc.LlmInferenceHTTPRequestChunkRequest +} + +// newLlmInferenceAdapter adapts an [LlmInferenceProvider] into the generated +// rpc.LlmInferenceHandler consumed by the SDK's RPC dispatcher. +func newLlmInferenceAdapter(handler LlmInferenceProvider, getRPC func() *rpc.ServerLlmInferenceAPI) rpc.LlmInferenceHandler { + return &llmInferenceAdapter{ + handler: handler, + getRPC: getRPC, + pending: make(map[string]*llmPendingState), + staged: make(map[string][]*rpc.LlmInferenceHTTPRequestChunkRequest), + } +} + +func (a *llmInferenceAdapter) HttpRequestStart(params *rpc.LlmInferenceHTTPRequestStartRequest) (*rpc.LlmInferenceHTTPRequestStartResult, error) { + ctx, cancel := context.WithCancel(context.Background()) + queue := newFrameQueue() + bodyCh := make(chan []byte) + state := &llmPendingState{queue: queue, ctx: ctx, cancel: cancel} + + go func() { + defer close(bodyCh) + for { + b, ok := queue.pop() + if !ok { + return + } + select { + case bodyCh <- b: + case <-ctx.Done(): + return + } + } + }() + + a.mu.Lock() + a.pending[params.RequestID] = state + staged := a.staged[params.RequestID] + delete(a.staged, params.RequestID) + a.mu.Unlock() + + for _, chunk := range staged { + a.routeChunk(state, chunk) + } + + transport := "http" + if params.Transport != nil { + transport = string(*params.Transport) + } + sessionID := "" + if params.SessionID != nil { + sessionID = *params.SessionID + } + headers := http.Header{} + for k, v := range params.Headers { + headers[k] = append([]string(nil), v...) + } + sink := &llmResponseSink{requestID: params.RequestID, adapter: a, state: state} + req := &LlmInferenceRequest{ + RequestID: params.RequestID, + SessionID: sessionID, + Method: params.Method, + URL: params.URL, + Headers: headers, + Transport: transport, + RequestBody: bodyCh, + Context: ctx, + ResponseBody: sink, + } + go a.runHandler(req, sink, state) + return &rpc.LlmInferenceHTTPRequestStartResult{}, nil +} + +func (a *llmInferenceAdapter) HttpRequestChunk(params *rpc.LlmInferenceHTTPRequestChunkRequest) (*rpc.LlmInferenceHTTPRequestChunkResult, error) { + a.mu.Lock() + state := a.pending[params.RequestID] + if state == nil { + a.staged[params.RequestID] = append(a.staged[params.RequestID], params) + a.mu.Unlock() + return &rpc.LlmInferenceHTTPRequestChunkResult{}, nil + } + a.mu.Unlock() + a.routeChunk(state, params) + return &rpc.LlmInferenceHTTPRequestChunkResult{}, nil +} + +func (a *llmInferenceAdapter) routeChunk(state *llmPendingState, params *rpc.LlmInferenceHTTPRequestChunkRequest) { + if params.Cancel != nil && *params.Cancel { + state.mu.Lock() + state.cancelled = true + state.mu.Unlock() + state.cancel() + state.queue.close() + return + } + if params.Data != "" { + binary := params.Binary != nil && *params.Binary + if data, err := decodeChunkData(params.Data, binary); err == nil { + state.queue.push(data) + } + } + if params.End != nil && *params.End { + state.queue.close() + } +} + +func (a *llmInferenceAdapter) runHandler(req *LlmInferenceRequest, sink *llmResponseSink, state *llmPendingState) { + err := a.handler.OnLlmRequest(req) + state.mu.Lock() + finished := state.finished + cancelled := state.cancelled + state.mu.Unlock() + if err != nil { + if cancelled || state.ctx.Err() != nil { + a.finishCancelled(sink, state) + return + } + a.failViaSink(sink, state, err.Error()) + return + } + if !finished { + a.failViaSink(sink, state, "LLM inference provider returned without finalising the response (call ResponseBody.End() or .Error())") + } +} + +func (a *llmInferenceAdapter) failViaSink(sink *llmResponseSink, state *llmPendingState, message string) { + state.mu.Lock() + finished := state.finished + started := state.started + state.mu.Unlock() + if finished { + return + } + if !started { + _ = sink.Start(LlmInferenceResponseInit{Status: 502, Headers: http.Header{}}) + } + _ = sink.Error(message, "") +} + +func (a *llmInferenceAdapter) finishCancelled(sink *llmResponseSink, state *llmPendingState) { + state.mu.Lock() + finished := state.finished + started := state.started + state.mu.Unlock() + if finished { + return + } + if !started { + _ = sink.Start(LlmInferenceResponseInit{Status: 499, Headers: http.Header{}}) + } + _ = sink.Error("Request cancelled by runtime", "cancelled") +} + +func (a *llmInferenceAdapter) removePending(requestID string) { + a.mu.Lock() + delete(a.pending, requestID) + a.mu.Unlock() +} + +func decodeChunkData(data string, binary bool) ([]byte, error) { + if binary { + return base64.StdEncoding.DecodeString(data) + } + return []byte(data), nil +} + +type llmResponseSink struct { + requestID string + adapter *llmInferenceAdapter + state *llmPendingState +} + +func (s *llmResponseSink) rpcAPI() (*rpc.ServerLlmInferenceAPI, error) { + r := s.adapter.getRPC() + if r == nil { + return nil, fmt.Errorf("LLM inference response sink used after RPC connection closed") + } + return r, nil +} + +// rejectedByRuntime is invoked when the runtime acknowledges a response frame +// with accepted=false, meaning it has dropped the request (for example because +// it cancelled). It aborts the consumer's upstream work and stops emitting. +func (s *llmResponseSink) rejectedByRuntime() error { + s.state.mu.Lock() + if !s.state.cancelled { + s.state.cancelled = true + s.state.cancel() + } + s.state.finished = true + s.state.mu.Unlock() + s.adapter.removePending(s.requestID) + return fmt.Errorf("LLM inference response was rejected by the runtime (request no longer active)") +} + +func (s *llmResponseSink) Start(init LlmInferenceResponseInit) error { + s.state.mu.Lock() + if s.state.started { + s.state.mu.Unlock() + return fmt.Errorf("LLM inference response sink Start() called twice") + } + if s.state.finished { + s.state.mu.Unlock() + return fmt.Errorf("LLM inference response sink already finished") + } + s.state.started = true + s.state.mu.Unlock() + + api, err := s.rpcAPI() + if err != nil { + return err + } + var statusText *string + if init.StatusText != "" { + st := init.StatusText + statusText = &st + } + headers := map[string][]string(init.Headers) + if headers == nil { + headers = map[string][]string{} + } + result, err := api.HttpResponseStart(context.Background(), &rpc.LlmInferenceHTTPResponseStartRequest{ + RequestID: s.requestID, + Status: int64(init.Status), + StatusText: statusText, + Headers: headers, + }) + if err != nil { + return err + } + if !result.Accepted { + return s.rejectedByRuntime() + } + return nil +} + +func (s *llmResponseSink) Write(data []byte) error { + return s.write(string(data), false) +} + +func (s *llmResponseSink) WriteBinary(data []byte) error { + return s.write(base64.StdEncoding.EncodeToString(data), true) +} + +func (s *llmResponseSink) write(data string, binary bool) error { + s.state.mu.Lock() + cancelled := s.state.cancelled + started := s.state.started + finished := s.state.finished + s.state.mu.Unlock() + if cancelled { + return fmt.Errorf("LLM inference request was cancelled by the runtime") + } + if !started { + return fmt.Errorf("LLM inference response sink Write() called before Start()") + } + if finished { + return fmt.Errorf("LLM inference response sink Write() called after End()/Error()") + } + api, err := s.rpcAPI() + if err != nil { + return err + } + end := false + chunk := &rpc.LlmInferenceHTTPResponseChunkRequest{ + RequestID: s.requestID, + Data: data, + End: &end, + } + if binary { + b := true + chunk.Binary = &b + } + result, err := api.HttpResponseChunk(context.Background(), chunk) + if err != nil { + return err + } + if !result.Accepted { + return s.rejectedByRuntime() + } + return nil +} + +func (s *llmResponseSink) End() error { + s.state.mu.Lock() + if s.state.finished { + s.state.mu.Unlock() + return nil + } + s.state.finished = true + s.state.mu.Unlock() + s.adapter.removePending(s.requestID) + api, err := s.rpcAPI() + if err != nil { + return err + } + end := true + _, err = api.HttpResponseChunk(context.Background(), &rpc.LlmInferenceHTTPResponseChunkRequest{ + RequestID: s.requestID, + Data: "", + End: &end, + }) + return err +} + +func (s *llmResponseSink) Error(message string, code string) error { + s.state.mu.Lock() + if s.state.finished { + s.state.mu.Unlock() + return nil + } + s.state.finished = true + s.state.mu.Unlock() + s.adapter.removePending(s.requestID) + api, err := s.rpcAPI() + if err != nil { + return err + } + end := true + chunkErr := &rpc.LlmInferenceHTTPResponseChunkError{Message: message} + if code != "" { + c := code + chunkErr.Code = &c + } + _, err = api.HttpResponseChunk(context.Background(), &rpc.LlmInferenceHTTPResponseChunkRequest{ + RequestID: s.requestID, + Data: "", + End: &end, + Error: chunkErr, + }) + return err +} diff --git a/go/llm_request_handler.go b/go/llm_request_handler.go new file mode 100644 index 000000000..3852886f2 --- /dev/null +++ b/go/llm_request_handler.go @@ -0,0 +1,442 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package copilot + +import ( + "bytes" + "context" + "io" + "net/http" + "strconv" + "strings" + "sync" + + "github.com/coder/websocket" +) + +// Hop-by-hop and length headers the transport recomputes; forwarding them +// verbatim corrupts the request. +var forbiddenRequestHeaders = map[string]struct{}{ + "host": {}, + "connection": {}, + "content-length": {}, + "transfer-encoding": {}, + "keep-alive": {}, + "upgrade": {}, + "proxy-connection": {}, + "te": {}, + "trailer": {}, +} + +func isForbiddenRequestHeader(name string) bool { + lower := strings.ToLower(name) + if _, ok := forbiddenRequestHeaders[lower]; ok { + return true + } + return strings.HasPrefix(lower, "sec-websocket-") +} + +var sharedHTTPTransport = func() http.RoundTripper { + t := http.DefaultTransport.(*http.Transport).Clone() + t.DisableCompression = true + return t +}() + +// LlmRequestContext is the per-request context handed to every +// [LlmRequestHandler] seam. +type LlmRequestContext struct { + RequestID string + SessionID string + Transport string + URL string + Headers http.Header + // Context is cancelled when the runtime cancels this in-flight request. + Context context.Context +} + +// LlmWebSocketCloseStatus is the terminal status for a callback-owned WebSocket +// connection. +type LlmWebSocketCloseStatus struct { + Description string + Code string + Err error +} + +// LlmRequestHandler is the idiomatic base for consumers that observe or replace +// LLM inference requests. It implements [LlmInferenceProvider] by translating +// each request into Go's canonical net/http types. +// +// HTTP requests are forwarded through Transport (an [http.RoundTripper]); supply +// a custom RoundTripper to mutate the request, post-process the response, or +// replace the call entirely. WebSocket requests are serviced by OpenWebSocket; +// supply one to mutate the handshake or return a fully custom handler. +type LlmRequestHandler struct { + // Transport forwards HTTP requests. When nil a shared default transport is + // used. RoundTrip is called directly, so redirects are not followed. + Transport http.RoundTripper + // OpenWebSocket returns a per-connection WebSocket handler. When nil a + // transparent forwarding connection to the request URL is opened. + OpenWebSocket func(ctx *LlmRequestContext) (CopilotWebSocketHandler, error) +} + +// OnLlmRequest implements [LlmInferenceProvider]. +func (h *LlmRequestHandler) OnLlmRequest(req *LlmInferenceRequest) error { + rctx := &LlmRequestContext{ + RequestID: req.RequestID, + SessionID: req.SessionID, + Transport: req.Transport, + URL: req.URL, + Headers: req.Headers, + Context: req.Context, + } + if req.Transport == "websocket" { + return h.handleWebSocket(req, rctx) + } + return h.handleHTTP(req, rctx) +} + +func (h *LlmRequestHandler) roundTripper() http.RoundTripper { + if h.Transport != nil { + return h.Transport + } + return sharedHTTPTransport +} + +func (h *LlmRequestHandler) handleHTTP(req *LlmInferenceRequest, _ *LlmRequestContext) error { + httpReq, err := buildHTTPRequest(req) + if err != nil { + return err + } + resp, err := h.roundTripper().RoundTrip(httpReq) + if err != nil { + return err + } + defer resp.Body.Close() + return streamResponseToSink(resp, req) +} + +func buildHTTPRequest(req *LlmInferenceRequest) (*http.Request, error) { + body := drainBody(req.RequestBody) + method := strings.ToUpper(req.Method) + var bodyReader io.Reader + if len(body) > 0 && method != http.MethodGet && method != http.MethodHead { + bodyReader = bytes.NewReader(body) + } + httpReq, err := http.NewRequestWithContext(req.Context, method, req.URL, bodyReader) + if err != nil { + return nil, err + } + for name, values := range req.Headers { + if isForbiddenRequestHeader(name) { + continue + } + for _, v := range values { + httpReq.Header.Add(name, v) + } + } + return httpReq, nil +} + +func drainBody(ch <-chan []byte) []byte { + var buf bytes.Buffer + for frame := range ch { + buf.Write(frame) + } + return buf.Bytes() +} + +func streamResponseToSink(resp *http.Response, req *LlmInferenceRequest) error { + init := LlmInferenceResponseInit{ + Status: resp.StatusCode, + StatusText: statusText(resp), + Headers: cloneHeader(resp.Header), + } + if err := req.ResponseBody.Start(init); err != nil { + return err + } + buf := make([]byte, 32*1024) + for { + n, readErr := resp.Body.Read(buf) + if n > 0 { + frame := make([]byte, n) + copy(frame, buf[:n]) + if err := req.ResponseBody.WriteBinary(frame); err != nil { + return err + } + } + if readErr == io.EOF { + break + } + if readErr != nil { + return req.ResponseBody.Error(readErr.Error(), "") + } + } + return req.ResponseBody.End() +} + +func statusText(resp *http.Response) string { + text := strings.TrimSpace(strings.TrimPrefix(resp.Status, strconv.Itoa(resp.StatusCode))) + return text +} + +func cloneHeader(h http.Header) http.Header { + out := http.Header{} + for k, vs := range h { + out[k] = append([]string(nil), vs...) + } + return out +} + +// WebSocketResponseWriter forwards upstream→runtime WebSocket messages back into +// the runtime response. A [CopilotWebSocketHandler] receives one in [CopilotWebSocketHandler.Open]. +type WebSocketResponseWriter interface { + // SendText forwards an upstream text message to the runtime. + SendText(data []byte) error + // SendBinary forwards an upstream binary message to the runtime. + SendBinary(data []byte) error +} + +// CopilotWebSocketHandler is a per-connection WebSocket handler returned by +// [LlmRequestHandler.OpenWebSocket]. The default implementation is +// [ForwardingWebSocketHandler]; a full transport replacement implements this +// interface directly. +type CopilotWebSocketHandler interface { + // Open establishes the connection and starts forwarding upstream→runtime + // messages into resp. It must not block. ctx is cancelled on teardown. + Open(ctx context.Context, resp WebSocketResponseWriter) error + // SendRequestMessage forwards one runtime→upstream message. + SendRequestMessage(ctx context.Context, data []byte) error + // Done is closed when the upstream connection completes (closed or errored). + Done() <-chan struct{} + // Err returns the terminal error after Done is closed, or nil on clean close. + Err() error + // Close tears down the connection. + Close() error +} + +func (h *LlmRequestHandler) handleWebSocket(req *LlmInferenceRequest, rctx *LlmRequestContext) error { + var handler CopilotWebSocketHandler + var err error + if h.OpenWebSocket != nil { + handler, err = h.OpenWebSocket(rctx) + } else { + handler = NewForwardingWebSocketHandler(rctx.URL, rctx.Headers) + } + if err != nil { + return err + } + + writer := &wsResponseWriter{sink: req.ResponseBody} + if err := writer.start(); err != nil { + return err + } + if err := handler.Open(req.Context, writer); err != nil { + return writer.fail(err.Error(), "") + } + defer func() { _ = handler.Close() }() + + clientDone := make(chan struct{}) + go func() { + defer close(clientDone) + for { + select { + case frame, ok := <-req.RequestBody: + if !ok { + return + } + if err := handler.SendRequestMessage(req.Context, frame); err != nil { + return + } + case <-req.Context.Done(): + return + } + } + }() + + select { + case <-handler.Done(): + if e := handler.Err(); e != nil { + return writer.fail(e.Error(), "") + } + return writer.end() + case <-clientDone: + _ = handler.Close() + <-handler.Done() + if e := handler.Err(); e != nil { + return writer.fail(e.Error(), "") + } + return writer.end() + case <-req.Context.Done(): + return writer.fail("Request cancelled by runtime", "cancelled") + } +} + +// wsResponseWriter serialises WebSocket response writes into the sink. +type wsResponseWriter struct { + mu sync.Mutex + sink LlmInferenceResponseSink + started bool + completed bool +} + +func (w *wsResponseWriter) start() error { + w.mu.Lock() + defer w.mu.Unlock() + if w.started { + return nil + } + w.started = true + return w.sink.Start(LlmInferenceResponseInit{Status: 101, Headers: http.Header{}}) +} + +func (w *wsResponseWriter) SendText(data []byte) error { + w.mu.Lock() + defer w.mu.Unlock() + if w.completed { + return nil + } + return w.sink.Write(data) +} + +func (w *wsResponseWriter) SendBinary(data []byte) error { + w.mu.Lock() + defer w.mu.Unlock() + if w.completed { + return nil + } + return w.sink.WriteBinary(data) +} + +func (w *wsResponseWriter) end() error { + w.mu.Lock() + defer w.mu.Unlock() + if w.completed { + return nil + } + w.completed = true + return w.sink.End() +} + +func (w *wsResponseWriter) fail(message string, code string) error { + w.mu.Lock() + defer w.mu.Unlock() + if w.completed { + return nil + } + w.completed = true + return w.sink.Error(message, code) +} + +// ForwardingWebSocketHandler is the default [CopilotWebSocketHandler]: it dials +// the real upstream and runs a receive loop forwarding upstream→runtime +// messages. Set OnSendRequestMessage / OnSendResponseMessage to observe, +// transform, or drop messages in either direction. +type ForwardingWebSocketHandler struct { + URL string + Headers http.Header + // OnSendRequestMessage observes or transforms each runtime→upstream frame. + // Return nil to drop the frame. + OnSendRequestMessage func(data []byte) []byte + // OnSendResponseMessage observes or transforms each upstream→runtime frame. + // Return nil to drop the frame. + OnSendResponseMessage func(data []byte) []byte + + conn *websocket.Conn + resp WebSocketResponseWriter + done chan struct{} + err error + closeOnce sync.Once +} + +// NewForwardingWebSocketHandler creates a forwarding handler targeting url with +// the given handshake headers. +func NewForwardingWebSocketHandler(url string, headers http.Header) *ForwardingWebSocketHandler { + return &ForwardingWebSocketHandler{URL: url, Headers: headers, done: make(chan struct{})} +} + +func (f *ForwardingWebSocketHandler) Open(ctx context.Context, resp WebSocketResponseWriter) error { + f.resp = resp + if f.done == nil { + f.done = make(chan struct{}) + } + opts := &websocket.DialOptions{HTTPHeader: f.dialHeaders()} + conn, _, err := websocket.Dial(ctx, f.URL, opts) + if err != nil { + return err + } + conn.SetReadLimit(-1) + f.conn = conn + go f.receiveLoop(ctx) + return nil +} + +func (f *ForwardingWebSocketHandler) dialHeaders() http.Header { + out := http.Header{} + for name, values := range f.Headers { + if isForbiddenRequestHeader(name) { + continue + } + for _, v := range values { + out.Add(name, v) + } + } + return out +} + +func (f *ForwardingWebSocketHandler) receiveLoop(ctx context.Context) { + defer close(f.done) + for { + typ, data, err := f.conn.Read(ctx) + if err != nil { + if websocket.CloseStatus(err) == websocket.StatusNormalClosure || websocket.CloseStatus(err) == websocket.StatusGoingAway { + f.err = nil + } else if ctx.Err() != nil { + f.err = nil + } else { + f.err = err + } + return + } + out := data + if f.OnSendResponseMessage != nil { + out = f.OnSendResponseMessage(data) + if out == nil { + continue + } + } + if typ == websocket.MessageBinary { + _ = f.resp.SendBinary(out) + } else { + _ = f.resp.SendText(out) + } + } +} + +func (f *ForwardingWebSocketHandler) SendRequestMessage(ctx context.Context, data []byte) error { + out := data + if f.OnSendRequestMessage != nil { + out = f.OnSendRequestMessage(data) + if out == nil { + return nil + } + } + if f.conn == nil { + return nil + } + return f.conn.Write(ctx, websocket.MessageText, out) +} + +func (f *ForwardingWebSocketHandler) Done() <-chan struct{} { return f.done } + +func (f *ForwardingWebSocketHandler) Err() error { return f.err } + +func (f *ForwardingWebSocketHandler) Close() error { + f.closeOnce.Do(func() { + if f.conn != nil { + _ = f.conn.Close(websocket.StatusNormalClosure, "") + } + }) + return nil +} diff --git a/go/rpc/zrpc.go b/go/rpc/zrpc.go index b00bf1c30..db70eb793 100644 --- a/go/rpc/zrpc.go +++ b/go/rpc/zrpc.go @@ -396,12 +396,21 @@ func (r RawAttachmentData) Type() AttachmentType { // Blob attachment with inline base64-encoded data // Experimental: AttachmentBlob is part of an experimental API and may change or be removed. type AttachmentBlob struct { - // Base64-encoded content - Data string `json:"data"` + // Internal: content-addressed id of the session.binary_asset event holding this + // attachment's model-facing bytes (e.g. "sha256:..."). Absent externally. + AssetID *string `json:"assetId,omitempty"` + // Internal: decoded byte length of the attachment's model-facing bytes. Absent externally. + ByteLength *int64 `json:"byteLength,omitempty"` + // Base64-encoded content. Present on input and for external consumers; replaced by an + // internal `assetId` reference in persisted events when interned to a content-addressed + // asset. + Data *string `json:"data,omitempty"` // User-facing display name for the attachment DisplayName *string `json:"displayName,omitempty"` // MIME type of the inline data MIMEType string `json:"mimeType"` + // Internal: why model-facing bytes are absent from persistence. Absent externally. + OmittedReason *OmittedBinaryOmittedReason `json:"omittedReason,omitempty"` } func (AttachmentBlob) attachment() {} @@ -454,10 +463,20 @@ func (AttachmentExtensionContext) Type() AttachmentType { // File attachment // Experimental: AttachmentFile is part of an experimental API and may change or be removed. type AttachmentFile struct { + // Internal: content-addressed id of the session.binary_asset event holding this + // attachment's model-facing bytes (e.g. "sha256:..."). Absent externally. + AssetID *string `json:"assetId,omitempty"` + // Internal: decoded byte length of the attachment's model-facing bytes. Absent externally. + ByteLength *int64 `json:"byteLength,omitempty"` // User-facing display name for the attachment DisplayName string `json:"displayName"` // Optional line range to scope the attachment to a specific section of the file LineRange *AttachmentFileLineRange `json:"lineRange,omitempty"` + // Internal: MIME type of the file's model-facing bytes (post-resize for images). Set when + // the file's bytes are interned to an asset. Absent externally. + MIMEType *string `json:"mimeType,omitempty"` + // Internal: why model-facing bytes are absent from persistence. Absent externally. + OmittedReason *OmittedBinaryOmittedReason `json:"omittedReason,omitempty"` // Absolute file path Path string `json:"path"` } @@ -1916,6 +1935,134 @@ type InstructionSource struct { Type InstructionSourceType `json:"type"` } +// HTTP headers as a map from lowercased header name to a list of values. Multi-valued +// headers (e.g. Set-Cookie) preserve all values. +// Experimental: LlmInferenceHeaders is part of an experimental API and may change or be +// removed. +type LlmInferenceHeaders map[string][]string + +// A request body chunk or cancellation signal. +type LlmInferenceHTTPRequestChunkRequest struct { + // When true, `data` is base64-encoded bytes. When absent or false, `data` is UTF-8 text. + Binary *bool `json:"binary,omitempty"` + // When true, the runtime is cancelling the in-flight request (e.g. upstream consumer + // aborted). `data` is ignored. Implies end-of-request. + Cancel *bool `json:"cancel,omitempty"` + // Optional human-readable reason for the cancellation, propagated for logging. + CancelReason *string `json:"cancelReason,omitempty"` + // Body byte range. UTF-8 text when `binary` is absent or false; base64-encoded bytes when + // `binary` is true. May be empty. + Data string `json:"data"` + // When true, this is the final body chunk for the request. The SDK may rely on having + // received an end-marked chunk before treating the request body as complete. + End *bool `json:"end,omitempty"` + // Matches the requestId from the originating httpRequestStart frame. + RequestID string `json:"requestId"` +} + +// Acknowledgement. The SDK is free to ignore the ack and treat chunk delivery as +// fire-and-forget. +type LlmInferenceHTTPRequestChunkResult struct { +} + +// The head of an outbound model-layer HTTP request. +type LlmInferenceHTTPRequestStartRequest struct { + Headers map[string][]string `json:"headers"` + // HTTP method, e.g. GET, POST. + Method string `json:"method"` + // Opaque runtime-minted id, unique per in-flight request. The SDK uses this to correlate + // httpRequestChunk frames and to address its httpResponseStart / httpResponseChunk replies + // back to the runtime. + RequestID string `json:"requestId"` + // Id of the runtime session that triggered this request, when one is in scope. Absent for + // requests issued outside any session (e.g. startup model-catalog or capability + // resolution). This is a payload field — not a dispatch key — because the client-global API + // is registered process-wide rather than per session. + SessionID *string `json:"sessionId,omitempty"` + // Transport the runtime would otherwise use for this request. `http` (the default when + // absent) covers plain HTTP and SSE responses; `websocket` indicates a full-duplex message + // channel where each body chunk maps to one WebSocket message and the `binary` flag + // distinguishes text from binary frames. The SDK consumer uses this to decide whether to + // service the request with an HTTP client or a WebSocket client. It is the one piece of + // request metadata the consumer cannot reliably infer from the URL or headers alone. + Transport *LlmInferenceHTTPRequestStartTransport `json:"transport,omitempty"` + // Absolute request URL. + URL string `json:"url"` +} + +// Acknowledgement. Returning successfully simply means the SDK accepted the start frame; it +// does not imply the request will succeed. +type LlmInferenceHTTPRequestStartResult struct { +} + +// Set to terminate the response with a transport-level failure. Implies end-of-stream; any +// further chunks for this requestId are ignored. +// Experimental: LlmInferenceHTTPResponseChunkError is part of an experimental API and may +// change or be removed. +type LlmInferenceHTTPResponseChunkError struct { + // Optional machine-readable error code. + Code *string `json:"code,omitempty"` + // Human-readable failure description. + Message string `json:"message"` +} + +// A response body chunk or terminal error. +// Experimental: LlmInferenceHTTPResponseChunkRequest is part of an experimental API and may +// change or be removed. +type LlmInferenceHTTPResponseChunkRequest struct { + // When true, `data` is base64-encoded bytes. When absent or false, `data` is UTF-8 text. + Binary *bool `json:"binary,omitempty"` + // Body byte range. UTF-8 text when `binary` is absent or false; base64-encoded bytes when + // `binary` is true. May be empty (e.g. when the response body is empty: send a single chunk + // with empty data and end=true). + Data string `json:"data"` + // When true, this is the final body chunk for the response. The runtime treats the response + // body as complete after receiving an end-marked chunk. + End *bool `json:"end,omitempty"` + // Set to terminate the response with a transport-level failure. Implies end-of-stream; any + // further chunks for this requestId are ignored. + Error *LlmInferenceHTTPResponseChunkError `json:"error,omitempty"` + // Matches the requestId from the originating httpRequestStart frame. + RequestID string `json:"requestId"` +} + +// Whether the chunk was accepted. +// Experimental: LlmInferenceHTTPResponseChunkResult is part of an experimental API and may +// change or be removed. +type LlmInferenceHTTPResponseChunkResult struct { + // True when the chunk was matched to a pending request; false when unknown. + Accepted bool `json:"accepted"` +} + +// Response head. +// Experimental: LlmInferenceHTTPResponseStartRequest is part of an experimental API and may +// change or be removed. +type LlmInferenceHTTPResponseStartRequest struct { + Headers map[string][]string `json:"headers"` + // Matches the requestId from the originating httpRequestStart frame. + RequestID string `json:"requestId"` + // HTTP status code. + Status int64 `json:"status"` + // Optional HTTP status reason phrase. + StatusText *string `json:"statusText,omitempty"` +} + +// Whether the start frame was accepted. +// Experimental: LlmInferenceHTTPResponseStartResult is part of an experimental API and may +// change or be removed. +type LlmInferenceHTTPResponseStartResult struct { + // True when the response start was matched to a pending request; false when unknown. + Accepted bool `json:"accepted"` +} + +// Indicates whether the calling client was registered as the LLM inference provider. +// Experimental: LlmInferenceSetProviderResult is part of an experimental API and may change +// or be removed. +type LlmInferenceSetProviderResult struct { + // Whether the provider was set successfully + Success bool `json:"success"` +} + // Schema for the `LocalSessionMetadataValue` type. // Experimental: LocalSessionMetadataValue is part of an experimental API and may change or // be removed. @@ -2974,8 +3121,10 @@ type ModelBilling struct { type ModelBillingTokenPrices struct { // Number of tokens per standard billing batch BatchSize *int64 `json:"batchSize,omitempty"` - // AI Credits cost per billing batch of cached tokens + // AI Credits cost per billing batch of cache-read tokens CachePrice *float64 `json:"cachePrice,omitempty"` + // AI Credits cost per billing batch of cache-write (cache creation) tokens. + CacheWritePrice *float64 `json:"cacheWritePrice,omitempty"` // Prompt token budget (max_prompt_tokens) for the default tier. The total context window is // this value plus the model's max_output_tokens. ContextMax *int64 `json:"contextMax,omitempty"` @@ -2989,8 +3138,10 @@ type ModelBillingTokenPrices struct { // Long context tier pricing (available for models with extended context windows) type ModelBillingTokenPricesLongContext struct { - // AI Credits cost per billing batch of cached tokens + // AI Credits cost per billing batch of cache-read tokens CachePrice *float64 `json:"cachePrice,omitempty"` + // AI Credits cost per billing batch of cache-write (cache creation) tokens. + CacheWritePrice *float64 `json:"cacheWritePrice,omitempty"` // Prompt token budget (max_prompt_tokens) for the long context tier. The total context // window is this value plus the model's max_output_tokens. ContextMax *int64 `json:"contextMax,omitempty"` @@ -5291,7 +5442,7 @@ type SandboxConfig struct { AddCurrentWorkingDirectory *bool `json:"addCurrentWorkingDirectory,omitempty"` // Raw `ContainerConfig` (per `@microsoft/mxc-sdk`) passed directly to // `spawnSandboxFromConfig`, bypassing policy merging. - Config map[string]any `json:"config,omitzero"` + Config any `json:"config,omitempty"` // Whether sandboxing is enabled for the session. Enabled bool `json:"enabled"` // User-managed sandbox policy fragment merged into the auto-discovered base policy. @@ -9262,6 +9413,24 @@ const ( InstructionSourceTypeVscode InstructionSourceType = "vscode" ) +// Transport the runtime would otherwise use for this request. `http` (the default when +// absent) covers plain HTTP and SSE responses; `websocket` indicates a full-duplex message +// channel where each body chunk maps to one WebSocket message and the `binary` flag +// distinguishes text from binary frames. The SDK consumer uses this to decide whether to +// service the request with an HTTP client or a WebSocket client. It is the one piece of +// request metadata the consumer cannot reliably infer from the URL or headers alone. +type LlmInferenceHTTPRequestStartTransport string + +const ( + // Plain HTTP or SSE response. Each body chunk is an opaque byte range; the response is a + // status line, headers, and a (possibly streamed) body. + LlmInferenceHTTPRequestStartTransportHTTP LlmInferenceHTTPRequestStartTransport = "http" + // Full-duplex WebSocket channel. Each body chunk maps to exactly one WebSocket message and + // the `binary` flag distinguishes text from binary frames; request and response chunks flow + // concurrently. + LlmInferenceHTTPRequestStartTransportWebsocket LlmInferenceHTTPRequestStartTransport = "websocket" +) + // Allowed values for the `McpAppsHostContextDetailsAvailableDisplayMode` enumeration. // Experimental: MCPAppsHostContextDetailsAvailableDisplayMode is part of an experimental // API and may change or be removed. @@ -9531,6 +9700,19 @@ const ( ModelPolicyStateUnconfigured ModelPolicyState = "unconfigured" ) +// Why the binary data is absent: it exceeded the inline size limit, or its asset was +// unavailable +// Experimental: OmittedBinaryOmittedReason is part of an experimental API and may change or +// be removed. +type OmittedBinaryOmittedReason string + +const ( + // The referenced binary asset could not be found (e.g. a truncated log). + OmittedBinaryOmittedReasonAssetUnavailable OmittedBinaryOmittedReason = "asset_unavailable" + // Bytes exceeded the session's inline size limit. + OmittedBinaryOmittedReasonTooLarge OmittedBinaryOmittedReason = "too_large" +) + // Allowed values for the `OptionsUpdateAdditionalContentExclusionPolicyScope` enumeration. // Experimental: OptionsUpdateAdditionalContentExclusionPolicyScope is part of an // experimental API and may change or be removed. @@ -10660,6 +10842,71 @@ func (a *ServerInstructionsAPI) GetDiscoveryPaths(ctx context.Context, params *I return &result, nil } +// Experimental: ServerLlmInferenceAPI contains experimental APIs that may change or be +// removed. +type ServerLlmInferenceAPI serverAPI + +// HttpResponseChunk delivers a body byte range (or a terminal transport error) for an +// in-flight response, correlated by requestId. Set `end` true on the last chunk. When +// `error` is set the response terminates with a transport-level failure and the runtime +// raises an APIConnectionError. +// +// RPC method: llmInference.httpResponseChunk. +// +// Parameters: A response body chunk or terminal error. +// +// Returns: Whether the chunk was accepted. +func (a *ServerLlmInferenceAPI) HttpResponseChunk(ctx context.Context, params *LlmInferenceHTTPResponseChunkRequest) (*LlmInferenceHTTPResponseChunkResult, error) { + raw, err := a.client.Request(ctx, "llmInference.httpResponseChunk", params) + if err != nil { + return nil, err + } + var result LlmInferenceHTTPResponseChunkResult + if err := json.Unmarshal(raw, &result); err != nil { + return nil, err + } + return &result, nil +} + +// HttpResponseStart delivers the response head (status + headers) for an in-flight request, +// correlated by the requestId the runtime supplied in httpRequestStart. Must be called +// exactly once per request before any httpResponseChunk frames. +// +// RPC method: llmInference.httpResponseStart. +// +// Parameters: Response head. +// +// Returns: Whether the start frame was accepted. +func (a *ServerLlmInferenceAPI) HttpResponseStart(ctx context.Context, params *LlmInferenceHTTPResponseStartRequest) (*LlmInferenceHTTPResponseStartResult, error) { + raw, err := a.client.Request(ctx, "llmInference.httpResponseStart", params) + if err != nil { + return nil, err + } + var result LlmInferenceHTTPResponseStartResult + if err := json.Unmarshal(raw, &result); err != nil { + return nil, err + } + return &result, nil +} + +// SetProvider registers an SDK client as the LLM inference callback provider. +// +// RPC method: llmInference.setProvider. +// +// Returns: Indicates whether the calling client was registered as the LLM inference +// provider. +func (a *ServerLlmInferenceAPI) SetProvider(ctx context.Context) (*LlmInferenceSetProviderResult, error) { + raw, err := a.client.Request(ctx, "llmInference.setProvider", nil) + if err != nil { + return nil, err + } + var result LlmInferenceSetProviderResult + if err := json.Unmarshal(raw, &result); err != nil { + return nil, err + } + return &result, nil +} + type ServerMCPAPI serverAPI // Discovers MCP servers from user, workspace, plugin, and builtin sources. @@ -11726,6 +11973,7 @@ type ServerRPC struct { AgentRegistry *ServerAgentRegistryAPI Agents *ServerAgentsAPI Instructions *ServerInstructionsAPI + LlmInference *ServerLlmInferenceAPI MCP *ServerMCPAPI Models *ServerModelsAPI Plugins *ServerPluginsAPI @@ -11765,6 +12013,7 @@ func NewServerRPC(client *jsonrpc2.Client) *ServerRPC { r.AgentRegistry = (*ServerAgentRegistryAPI)(&r.common) r.Agents = (*ServerAgentsAPI)(&r.common) r.Instructions = (*ServerInstructionsAPI)(&r.common) + r.LlmInference = (*ServerLlmInferenceAPI)(&r.common) r.MCP = (*ServerMCPAPI)(&r.common) r.Models = (*ServerModelsAPI)(&r.common) r.Plugins = (*ServerPluginsAPI)(&r.common) @@ -16750,3 +16999,94 @@ func RegisterClientSessionAPIHandlers(client *jsonrpc2.Client, getHandlers func( return raw, nil }) } + +// Experimental: LlmInferenceHandler contains experimental APIs that may change or be +// removed. +type LlmInferenceHandler interface { + // HttpRequestChunk delivers a body byte range (or a cancellation signal) for a request + // previously announced via httpRequestStart, correlated by requestId. The runtime fires at + // least one chunk per request — when there is no body, a single chunk with empty data and + // end=true. Mid-stream the runtime may send a chunk with cancel=true to abort the request; + // the SDK then stops issuing httpResponseChunk frames and may emit a terminal + // httpResponseChunk with error set. + // + // RPC method: llmInference.httpRequestChunk. + // + // Parameters: A request body chunk or cancellation signal. + // + // Returns: Acknowledgement. The SDK is free to ignore the ack and treat chunk delivery as + // fire-and-forget. + HttpRequestChunk(request *LlmInferenceHTTPRequestChunkRequest) (*LlmInferenceHTTPRequestChunkResult, error) + // HttpRequestStart announces an outbound model-layer HTTP request the runtime wants the SDK + // client to service. Carries the request head only; the body always follows as one or more + // httpRequestChunk frames keyed by the same requestId, even when the body is empty (a + // single chunk with end=true). + // + // RPC method: llmInference.httpRequestStart. + // + // Parameters: The head of an outbound model-layer HTTP request. + // + // Returns: Acknowledgement. Returning successfully simply means the SDK accepted the start + // frame; it does not imply the request will succeed. + HttpRequestStart(request *LlmInferenceHTTPRequestStartRequest) (*LlmInferenceHTTPRequestStartResult, error) +} + +// ClientGlobalAPIHandlers provides all client-global API handler groups. +// +// Unlike client-session handlers these carry no implicit session id dispatch +// key; a single set of handlers serves the entire connection. +type ClientGlobalAPIHandlers struct { + LlmInference LlmInferenceHandler +} + +func clientGlobalHandlerError(err error) *jsonrpc2.Error { + if err == nil { + return nil + } + var rpcErr *jsonrpc2.Error + if errors.As(err, &rpcErr) { + return rpcErr + } + return &jsonrpc2.Error{Code: -32603, Message: err.Error()} +} + +// RegisterClientGlobalAPIHandlers registers handlers for server-to-client client-global API +// calls. +func RegisterClientGlobalAPIHandlers(client *jsonrpc2.Client, handlers *ClientGlobalAPIHandlers) { + client.SetRequestHandler("llmInference.httpRequestChunk", func(params json.RawMessage) (json.RawMessage, *jsonrpc2.Error) { + var request LlmInferenceHTTPRequestChunkRequest + if err := json.Unmarshal(params, &request); err != nil { + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("Invalid params: %v", err)} + } + if handlers == nil || handlers.LlmInference == nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: "No llmInference client-global handler registered"} + } + result, err := handlers.LlmInference.HttpRequestChunk(&request) + if err != nil { + return nil, clientGlobalHandlerError(err) + } + raw, err := json.Marshal(result) + if err != nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: fmt.Sprintf("Failed to marshal response: %v", err)} + } + return raw, nil + }) + client.SetRequestHandler("llmInference.httpRequestStart", func(params json.RawMessage) (json.RawMessage, *jsonrpc2.Error) { + var request LlmInferenceHTTPRequestStartRequest + if err := json.Unmarshal(params, &request); err != nil { + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("Invalid params: %v", err)} + } + if handlers == nil || handlers.LlmInference == nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: "No llmInference client-global handler registered"} + } + result, err := handlers.LlmInference.HttpRequestStart(&request) + if err != nil { + return nil, clientGlobalHandlerError(err) + } + raw, err := json.Marshal(result) + if err != nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: fmt.Sprintf("Failed to marshal response: %v", err)} + } + return raw, nil + }) +} diff --git a/go/rpc/zsession_encoding.go b/go/rpc/zsession_encoding.go index 1cc987402..029070a64 100644 --- a/go/rpc/zsession_encoding.go +++ b/go/rpc/zsession_encoding.go @@ -693,15 +693,6 @@ func matchesOmittedBinaryResult(data []byte) bool { if rawGroup0.OmittedReason == nil { return false } - var rawGroup0String string - if err := json.Unmarshal(rawGroup0.OmittedReason, &rawGroup0String); err != nil { - return false - } - switch rawGroup0String { - case "asset_unavailable", "too_large": - default: - return false - } if rawGroup0.AssetID != nil { return false } @@ -1909,124 +1900,6 @@ func (r *PermissionCompletedData) UnmarshalJSON(data []byte) error { return nil } -func unmarshalElicitationCompletedContent(data []byte) (ElicitationCompletedContent, error) { - if string(data) == "null" { - return nil, nil - } - { - var value string - if err := json.Unmarshal(data, &value); err == nil { - return ElicitationCompletedStringContent(value), nil - } - } - { - var value float64 - if err := json.Unmarshal(data, &value); err == nil { - return ElicitationCompletedNumberContent(value), nil - } - } - { - var value bool - if err := json.Unmarshal(data, &value); err == nil { - return ElicitationCompletedBooleanContent(value), nil - } - } - { - var value []string - if err := json.Unmarshal(data, &value); err == nil { - return ElicitationCompletedStringArrayContent(value), nil - } - } - return nil, errors.New("data did not match any union variant for ElicitationCompletedContent") -} - -func (r *ElicitationCompletedData) UnmarshalJSON(data []byte) error { - type rawElicitationCompletedData struct { - Action *ElicitationCompletedAction `json:"action,omitempty"` - Content map[string]json.RawMessage `json:"content,omitzero"` - RequestID string `json:"requestId"` - } - var raw rawElicitationCompletedData - if err := json.Unmarshal(data, &raw); err != nil { - return err - } - r.Action = raw.Action - if raw.Content != nil { - r.Content = make(map[string]ElicitationCompletedContent, len(raw.Content)) - for key, rawValue := range raw.Content { - value, err := unmarshalElicitationCompletedContent(rawValue) - if err != nil { - return err - } - r.Content[key] = value - } - } - r.RequestID = raw.RequestID - return nil -} - -func (r CustomNotificationPayload) MarshalJSON() ([]byte, error) { - if r.AnyArray != nil { - return json.Marshal(r.AnyArray) - } - if r.AnyMap != nil { - return json.Marshal(r.AnyMap) - } - if r.Bool != nil { - return json.Marshal(r.Bool) - } - if r.Double != nil { - return json.Marshal(r.Double) - } - if r.String != nil { - return json.Marshal(r.String) - } - return []byte("null"), nil -} - -func (r *CustomNotificationPayload) UnmarshalJSON(data []byte) error { - if string(data) == "null" { - *r = CustomNotificationPayload{} - return nil - } - { - var value []any - if err := json.Unmarshal(data, &value); err == nil { - *r = CustomNotificationPayload{AnyArray: value} - return nil - } - } - { - var value map[string]any - if err := json.Unmarshal(data, &value); err == nil { - *r = CustomNotificationPayload{AnyMap: value} - return nil - } - } - { - var value bool - if err := json.Unmarshal(data, &value); err == nil { - *r = CustomNotificationPayload{Bool: &value} - return nil - } - } - { - var value float64 - if err := json.Unmarshal(data, &value); err == nil { - *r = CustomNotificationPayload{Double: &value} - return nil - } - } - { - var value string - if err := json.Unmarshal(data, &value); err == nil { - *r = CustomNotificationPayload{String: &value} - return nil - } - } - return errors.New("data did not match any union variant for CustomNotificationPayload") -} - func (r *SessionExtensionsAttachmentsPushedData) UnmarshalJSON(data []byte) error { type rawSessionExtensionsAttachmentsPushedData struct { Attachments []json.RawMessage `json:"attachments"` diff --git a/go/rpc/zsession_events.go b/go/rpc/zsession_events.go index 8a9d20d37..0c0afe807 100644 --- a/go/rpc/zsession_events.go +++ b/go/rpc/zsession_events.go @@ -403,7 +403,7 @@ type ElicitationCompletedData struct { // The user action: "accept" (submitted form), "decline" (explicitly refused), or "cancel" (dismissed) Action *ElicitationCompletedAction `json:"action,omitempty"` // The submitted form data when action is 'accept'; keys match the requested schema fields - Content map[string]ElicitationCompletedContent `json:"content,omitzero"` + Content map[string]any `json:"content,omitzero"` // Request ID of the resolved elicitation request; clients should dismiss any UI for this request RequestID string `json:"requestId"` } @@ -524,10 +524,16 @@ func (*ExternalToolRequestedData) Type() SessionEventType { type ModelCallFailureData struct { // Completion ID from the model provider (e.g., chatcmpl-abc123) APICallID *string `json:"apiCallId,omitempty"` + // For HTTP 400 failures only: whether the response carried a structured CAPI error envelope (structured_error, a deterministic validation failure) or no error body (bodyless, the transient gateway/proxy signature). Absent for non-400 failures. + BadRequestKind *ModelCallFailureBadRequestKind `json:"badRequestKind,omitempty"` // Duration of the failed API call in milliseconds DurationMs *int64 `json:"durationMs,omitempty"` + // For HTTP 400 failures only: the `code` from the CAPI error envelope (e.g. 'model_max_prompt_tokens_exceeded') identifying which deterministic validation failure occurred. Raw server-controlled string, emitted only through restricted telemetry. Absent for bodyless or non-400 failures. + ErrorCode *string `json:"errorCode,omitempty"` // Raw provider/runtime error message for restricted telemetry ErrorMessage *string `json:"errorMessage,omitempty"` + // For HTTP 400 failures only: the `type` from the CAPI error envelope (e.g. 'websocket_error'), a coarser companion to errorCode for envelopes that carry no code. Raw server-controlled string, emitted only through restricted telemetry. Absent for bodyless or non-400 failures. + ErrorType *string `json:"errorType,omitempty"` // What initiated this API call (e.g., "sub-agent", "mcp-sampling"); absent for user-initiated calls Initiator *string `json:"initiator,omitempty"` // Model identifier used for the failed API call @@ -730,7 +736,7 @@ type SessionCustomNotificationData struct { // Source-defined custom notification name Name string `json:"name"` // Source-defined JSON payload for the custom notification - Payload CustomNotificationPayload `json:"payload"` + Payload any `json:"payload"` // Namespace for the custom notification producer Source string `json:"source"` // Optional source-defined string identifiers describing the payload subject @@ -1755,7 +1761,7 @@ type CanvasRegistryChangedCanvas struct { // Owning extension display name, when available ExtensionName *string `json:"extensionName,omitempty"` // JSON Schema for canvas open input - InputSchema map[string]any `json:"inputSchema,omitzero"` + InputSchema any `json:"inputSchema,omitempty"` } // Schema for the `CanvasRegistryChangedCanvasAction` type. @@ -1763,7 +1769,7 @@ type CanvasRegistryChangedCanvasAction struct { // Action description Description *string `json:"description,omitempty"` // JSON Schema for action input - InputSchema map[string]any `json:"inputSchema,omitzero"` + InputSchema any `json:"inputSchema,omitempty"` // Action name Name string `json:"name"` } @@ -1847,36 +1853,6 @@ type CustomAgentsUpdatedAgent struct { UserInvocable bool `json:"userInvocable"` } -// Source-defined JSON payload for the custom notification -type CustomNotificationPayload struct { - AnyArray []any - AnyMap map[string]any - Bool *bool - Double *float64 - String *string -} - -// Schema for the `ElicitationCompletedContent` type. -type ElicitationCompletedContent interface { - elicitationCompletedContent() -} - -type ElicitationCompletedBooleanContent bool - -func (ElicitationCompletedBooleanContent) elicitationCompletedContent() {} - -type ElicitationCompletedNumberContent float64 - -func (ElicitationCompletedNumberContent) elicitationCompletedContent() {} - -type ElicitationCompletedStringArrayContent []string - -func (ElicitationCompletedStringArrayContent) elicitationCompletedContent() {} - -type ElicitationCompletedStringContent string - -func (ElicitationCompletedStringContent) elicitationCompletedContent() {} - // JSON Schema describing the form fields to present to the user (form mode only) type ElicitationRequestedSchema struct { // Form field definitions, keyed by field name @@ -2071,7 +2047,7 @@ func (PermissionPromptRequestHook) Kind() PermissionPromptRequestKind { // MCP tool invocation permission prompt type PermissionPromptRequestMCP struct { // Arguments to pass to the MCP tool - Args *any `json:"args,omitempty"` + Args any `json:"args,omitempty"` // Name of the MCP server providing the tool ServerName string `json:"serverName"` // Tool call ID that triggered this permission request @@ -3263,6 +3239,16 @@ const ( MCPServerTransportStdio MCPServerTransport = "stdio" ) +// For HTTP 400 failures only: whether the response carried a structured CAPI error envelope (structured_error, a deterministic validation failure) or no error body (bodyless, the transient gateway/proxy signature). Absent for non-400 failures. +type ModelCallFailureBadRequestKind string + +const ( + // The 400 response carried no error body (transient gateway/proxy signature). + ModelCallFailureBadRequestKindBodyless ModelCallFailureBadRequestKind = "bodyless" + // The 400 response carried a structured CAPI error envelope (deterministic validation failure). + ModelCallFailureBadRequestKindStructuredError ModelCallFailureBadRequestKind = "structured_error" +) + // Where the failed model call originated type ModelCallFailureSource string @@ -3275,16 +3261,6 @@ const ( ModelCallFailureSourceTopLevel ModelCallFailureSource = "top_level" ) -// Why the binary data is absent: it exceeded the inline size limit, or its asset was unavailable -type OmittedBinaryOmittedReason string - -const ( - // The referenced binary asset could not be found (e.g. a truncated log). - OmittedBinaryOmittedReasonAssetUnavailable OmittedBinaryOmittedReason = "asset_unavailable" - // Bytes exceeded the session's inline size limit. - OmittedBinaryOmittedReasonTooLarge OmittedBinaryOmittedReason = "too_large" -) - // Binary result type discriminator. Use "image" for images and "resource" for other binary data. type OmittedBinaryType string diff --git a/go/types.go b/go/types.go index ba83c6b6d..4c83950bf 100644 --- a/go/types.go +++ b/go/types.go @@ -116,6 +116,12 @@ type ClientOptions struct { // on connection, routing session-scoped file I/O through per-session // handlers. SessionFS *SessionFSConfig + // LlmInference configures a connection-level LLM inference callback. When + // provided, the client registers as the inference provider on connection, + // and the runtime routes its model-layer HTTP and WebSocket traffic through + // the handler instead of issuing the calls itself. Works for both CAPI and + // BYOK sessions. + LlmInference *LlmInferenceConfig // Telemetry configures OpenTelemetry integration for the runtime. // When non-nil, COPILOT_OTEL_ENABLED=true is set and any populated // fields are mapped to the corresponding environment variables. diff --git a/go/zsession_events.go b/go/zsession_events.go index 944a84b91..03531aa9d 100644 --- a/go/zsession_events.go +++ b/go/zsession_events.go @@ -60,14 +60,8 @@ type ( CompactionCompleteCompactionTokensUsedCopilotUsageTokenDetail = rpc.CompactionCompleteCompactionTokensUsedCopilotUsageTokenDetail ContextTier = rpc.ContextTier CustomAgentsUpdatedAgent = rpc.CustomAgentsUpdatedAgent - CustomNotificationPayload = rpc.CustomNotificationPayload ElicitationCompletedAction = rpc.ElicitationCompletedAction - ElicitationCompletedBooleanContent = rpc.ElicitationCompletedBooleanContent - ElicitationCompletedContent = rpc.ElicitationCompletedContent ElicitationCompletedData = rpc.ElicitationCompletedData - ElicitationCompletedNumberContent = rpc.ElicitationCompletedNumberContent - ElicitationCompletedStringArrayContent = rpc.ElicitationCompletedStringArrayContent - ElicitationCompletedStringContent = rpc.ElicitationCompletedStringContent ElicitationRequestedData = rpc.ElicitationRequestedData ElicitationRequestedMode = rpc.ElicitationRequestedMode ElicitationRequestedSchema = rpc.ElicitationRequestedSchema @@ -100,6 +94,7 @@ type ( MCPServerSource = rpc.MCPServerSource MCPServerStatus = rpc.MCPServerStatus MCPServerTransport = rpc.MCPServerTransport + ModelCallFailureBadRequestKind = rpc.ModelCallFailureBadRequestKind ModelCallFailureData = rpc.ModelCallFailureData ModelCallFailureSource = rpc.ModelCallFailureSource OmittedBinaryOmittedReason = rpc.OmittedBinaryOmittedReason @@ -366,6 +361,8 @@ const ( MCPServerTransportMemory = rpc.MCPServerTransportMemory MCPServerTransportSSE = rpc.MCPServerTransportSSE MCPServerTransportStdio = rpc.MCPServerTransportStdio + ModelCallFailureBadRequestKindBodyless = rpc.ModelCallFailureBadRequestKindBodyless + ModelCallFailureBadRequestKindStructuredError = rpc.ModelCallFailureBadRequestKindStructuredError ModelCallFailureSourceMCPSampling = rpc.ModelCallFailureSourceMCPSampling ModelCallFailureSourceSubagent = rpc.ModelCallFailureSourceSubagent ModelCallFailureSourceTopLevel = rpc.ModelCallFailureSourceTopLevel diff --git a/nodejs/package-lock.json b/nodejs/package-lock.json index 5e1ae09ee..99ccf7ee3 100644 --- a/nodejs/package-lock.json +++ b/nodejs/package-lock.json @@ -16,6 +16,7 @@ "devDependencies": { "@platformatic/vfs": "^0.3.0", "@types/node": "^25.2.0", + "@types/ws": "^8.18.1", "@typescript-eslint/eslint-plugin": "^8.54.0", "@typescript-eslint/parser": "^8.54.0", "esbuild": "^0.28.1", @@ -29,7 +30,8 @@ "semver": "^7.7.3", "tsx": "^4.20.6", "typescript": "^5.0.0", - "vitest": "^4.0.18" + "vitest": "^4.0.18", + "ws": "^8.21.0" }, "engines": { "node": "^20.19.0 || >=22.12.0" @@ -1329,6 +1331,16 @@ "undici-types": "~7.18.0" } }, + "node_modules/@types/ws": { + "version": "8.18.1", + "resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.18.1.tgz", + "integrity": "sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/node": "*" + } + }, "node_modules/@typescript-eslint/eslint-plugin": { "version": "8.56.1", "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-8.56.1.tgz", @@ -3959,6 +3971,28 @@ "dev": true, "license": "MIT" }, + "node_modules/ws": { + "version": "8.21.0", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.21.0.tgz", + "integrity": "sha512-Vsp28b7DRcimFQvrqu2Wek3z1iYxDCWqHYB8Qsnk/S4RfaCQzPGPyBNuVjJV3cd6UiKtUtp6sNM77gWvzcCH+g==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10.0.0" + }, + "peerDependencies": { + "bufferutil": "^4.0.1", + "utf-8-validate": ">=5.0.2" + }, + "peerDependenciesMeta": { + "bufferutil": { + "optional": true + }, + "utf-8-validate": { + "optional": true + } + } + }, "node_modules/yaml": { "version": "2.9.0", "resolved": "https://registry.npmjs.org/yaml/-/yaml-2.9.0.tgz", diff --git a/nodejs/package.json b/nodejs/package.json index 11dc978bc..d0f2bd5f7 100644 --- a/nodejs/package.json +++ b/nodejs/package.json @@ -63,6 +63,7 @@ "devDependencies": { "@platformatic/vfs": "^0.3.0", "@types/node": "^25.2.0", + "@types/ws": "^8.18.1", "@typescript-eslint/eslint-plugin": "^8.54.0", "@typescript-eslint/parser": "^8.54.0", "esbuild": "^0.28.1", @@ -76,7 +77,8 @@ "semver": "^7.7.3", "tsx": "^4.20.6", "typescript": "^5.0.0", - "vitest": "^4.0.18" + "vitest": "^4.0.18", + "ws": "^8.21.0" }, "engines": { "node": "^20.19.0 || >=22.12.0" diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index 6b4aca13e..ac556b91c 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -29,12 +29,14 @@ import { import { createServerRpc, createInternalServerRpc, + registerClientGlobalApiHandlers, registerClientSessionApiHandlers, } from "./generated/rpc.js"; import type { OpenCanvasInstance, SessionUpdateOptionsParams } from "./generated/rpc.js"; import { getSdkProtocolVersion } from "./sdkProtocolVersion.js"; import { CopilotSession } from "./session.js"; import { createSessionFsAdapter, type SessionFsProvider } from "./sessionFsProvider.js"; +import { createLlmInferenceAdapter } from "./llmInferenceProvider.js"; import { getTraceContext } from "./telemetry.js"; import { ToolSet } from "./toolSet.js"; import type { @@ -60,6 +62,7 @@ import type { SessionCapabilities, SessionEvent, SessionFsConfig, + LlmInferenceConfig, SessionLifecycleEvent, SessionLifecycleEventType, SessionLifecycleHandler, @@ -389,6 +392,8 @@ export class CopilotClient { private negotiatedProtocolVersion: number | null = null; /** Connection-level session filesystem config, set via constructor option. */ private sessionFsConfig: SessionFsConfig | null = null; + private llmInferenceConfig: LlmInferenceConfig | null = null; + private llmInferenceHandlers: import("./generated/rpc.js").ClientGlobalApiHandlers = {}; /** * Typed server-scoped RPC methods. @@ -500,6 +505,8 @@ export class CopilotClient { this.onListModels = options.onListModels; this.onGetTraceContext = options.onGetTraceContext; this.sessionFsConfig = options.sessionFs ?? null; + this.llmInferenceConfig = options.llmInference ?? null; + this.setupLlmInference(); const effectiveEnv = options.env ?? process.env; this.resolvedEnv = effectiveEnv; @@ -616,6 +623,27 @@ export class CopilotClient { session.clientSessionApis.sessionFs = createSessionFsAdapter(provider); } + private setupLlmInference(): void { + if (!this.llmInferenceConfig) { + return; + } + const provider = this.llmInferenceConfig.handler; + if (!provider) { + throw new Error( + "handler is required on client options.llmInference when llmInference is enabled." + ); + } + this.llmInferenceHandlers = { + llmInference: createLlmInferenceAdapter(provider, () => { + if (!this.connection) { + return undefined; + } + this._rpc ??= createServerRpc(this.connection); + return this._rpc; + }), + }; + } + /** * Starts the CLI server and establishes a connection. * @@ -663,6 +691,13 @@ export class CopilotClient { }); } + // If an LLM inference provider was configured, register it. + // The runtime will then route outbound model HTTP requests + // through the registered handler for the duration of each session. + if (this.llmInferenceConfig) { + await this.connection!.sendRequest("llmInference.setProvider", {}); + } + this.state = "connected"; } catch (error) { this.state = "error"; @@ -2333,6 +2368,11 @@ export class CopilotClient { return session.clientSessionApis; }); + // Register client *global* API handlers (e.g. LLM inference) on the + // same connection. These methods carry no implicit sessionId dispatch + // — the runtime calls into a single handler for the whole connection. + registerClientGlobalApiHandlers(this.connection, this.llmInferenceHandlers); + this.connection.onClose(() => { this.state = "disconnected"; }); diff --git a/nodejs/src/generated/rpc.ts b/nodejs/src/generated/rpc.ts index ba27a2d52..2ce64d43d 100644 --- a/nodejs/src/generated/rpc.ts +++ b/nodejs/src/generated/rpc.ts @@ -501,6 +501,18 @@ export type InstructionSourceLocation = | "working-directory" /** Instructions live in plugin-provided configuration. */ | "plugin"; +/** + * Transport the runtime would otherwise use for this request. `http` (the default when absent) covers plain HTTP and SSE responses; `websocket` indicates a full-duplex message channel where each body chunk maps to one WebSocket message and the `binary` flag distinguishes text from binary frames. The SDK consumer uses this to decide whether to service the request with an HTTP client or a WebSocket client. It is the one piece of request metadata the consumer cannot reliably infer from the URL or headers alone. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "LlmInferenceHttpRequestStartTransport". + */ +/** @experimental */ +export type LlmInferenceHttpRequestStartTransport = + /** Plain HTTP or SSE response. Each body chunk is an opaque byte range; the response is a status line, headers, and a (possibly streamed) body. */ + | "http" + /** Full-duplex WebSocket channel. Each body chunk maps to exactly one WebSocket message and the `binary` flag distinguishes text from binary frames; request and response chunks flow concurrently. */ + | "websocket"; /** * Repository host type * @@ -3648,7 +3660,6 @@ export interface ExternalToolTextResultForLlm { * Structured content blocks from the tool */ contents?: ExternalToolTextResultForLlmContent[]; - [k: string]: unknown | undefined; } /** * Binary result returned by a tool for the model @@ -4303,6 +4314,196 @@ export interface InstructionSource { */ projectPath?: string; } +/** + * HTTP headers as a map from lowercased header name to a list of values. Multi-valued headers (e.g. Set-Cookie) preserve all values. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "LlmInferenceHeaders". + */ +/** @experimental */ +export interface LlmInferenceHeaders { + [k: string]: string[] | undefined; +} +/** + * A request body chunk or cancellation signal. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "LlmInferenceHttpRequestChunkRequest". + */ +/** @experimental */ +export interface LlmInferenceHttpRequestChunkRequest { + /** + * Matches the requestId from the originating httpRequestStart frame. + */ + requestId: string; + /** + * Body byte range. UTF-8 text when `binary` is absent or false; base64-encoded bytes when `binary` is true. May be empty. + */ + data: string; + /** + * When true, `data` is base64-encoded bytes. When absent or false, `data` is UTF-8 text. + */ + binary?: boolean; + /** + * When true, this is the final body chunk for the request. The SDK may rely on having received an end-marked chunk before treating the request body as complete. + */ + end?: boolean; + /** + * When true, the runtime is cancelling the in-flight request (e.g. upstream consumer aborted). `data` is ignored. Implies end-of-request. + */ + cancel?: boolean; + /** + * Optional human-readable reason for the cancellation, propagated for logging. + */ + cancelReason?: string; +} +/** + * Acknowledgement. The SDK is free to ignore the ack and treat chunk delivery as fire-and-forget. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "LlmInferenceHttpRequestChunkResult". + */ +/** @experimental */ +export interface LlmInferenceHttpRequestChunkResult {} +/** + * The head of an outbound model-layer HTTP request. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "LlmInferenceHttpRequestStartRequest". + */ +/** @experimental */ +export interface LlmInferenceHttpRequestStartRequest { + /** + * Opaque runtime-minted id, unique per in-flight request. The SDK uses this to correlate httpRequestChunk frames and to address its httpResponseStart / httpResponseChunk replies back to the runtime. + */ + requestId: string; + /** + * Id of the runtime session that triggered this request, when one is in scope. Absent for requests issued outside any session (e.g. startup model-catalog or capability resolution). This is a payload field — not a dispatch key — because the client-global API is registered process-wide rather than per session. + */ + sessionId?: string; + /** + * HTTP method, e.g. GET, POST. + */ + method: string; + /** + * Absolute request URL. + */ + url: string; + headers: LlmInferenceHeaders; + transport?: LlmInferenceHttpRequestStartTransport; +} +/** + * Acknowledgement. Returning successfully simply means the SDK accepted the start frame; it does not imply the request will succeed. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "LlmInferenceHttpRequestStartResult". + */ +/** @experimental */ +export interface LlmInferenceHttpRequestStartResult {} +/** + * Set to terminate the response with a transport-level failure. Implies end-of-stream; any further chunks for this requestId are ignored. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "LlmInferenceHttpResponseChunkError". + */ +/** @experimental */ +export interface LlmInferenceHttpResponseChunkError { + /** + * Human-readable failure description. + */ + message: string; + /** + * Optional machine-readable error code. + */ + code?: string; +} +/** + * A response body chunk or terminal error. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "LlmInferenceHttpResponseChunkRequest". + */ +/** @experimental */ +export interface LlmInferenceHttpResponseChunkRequest { + /** + * Matches the requestId from the originating httpRequestStart frame. + */ + requestId: string; + /** + * Body byte range. UTF-8 text when `binary` is absent or false; base64-encoded bytes when `binary` is true. May be empty (e.g. when the response body is empty: send a single chunk with empty data and end=true). + */ + data: string; + /** + * When true, `data` is base64-encoded bytes. When absent or false, `data` is UTF-8 text. + */ + binary?: boolean; + /** + * When true, this is the final body chunk for the response. The runtime treats the response body as complete after receiving an end-marked chunk. + */ + end?: boolean; + error?: LlmInferenceHttpResponseChunkError; +} +/** + * Whether the chunk was accepted. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "LlmInferenceHttpResponseChunkResult". + */ +/** @experimental */ +export interface LlmInferenceHttpResponseChunkResult { + /** + * True when the chunk was matched to a pending request; false when unknown. + */ + accepted: boolean; +} +/** + * Response head. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "LlmInferenceHttpResponseStartRequest". + */ +/** @experimental */ +export interface LlmInferenceHttpResponseStartRequest { + /** + * Matches the requestId from the originating httpRequestStart frame. + */ + requestId: string; + /** + * HTTP status code. + */ + status: number; + /** + * Optional HTTP status reason phrase. + */ + statusText?: string; + headers: LlmInferenceHeaders; +} +/** + * Whether the start frame was accepted. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "LlmInferenceHttpResponseStartResult". + */ +/** @experimental */ +export interface LlmInferenceHttpResponseStartResult { + /** + * True when the response start was matched to a pending request; false when unknown. + */ + accepted: boolean; +} +/** + * Indicates whether the calling client was registered as the LLM inference provider. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "LlmInferenceSetProviderResult". + */ +/** @experimental */ +export interface LlmInferenceSetProviderResult { + /** + * Whether the provider was set successfully + */ + success: boolean; +} /** * Schema for the `LocalSessionMetadataValue` type. * @@ -4944,7 +5145,6 @@ export interface McpServerAuthConfigRedirectPort { * Fixed port for the OAuth redirect callback server. */ redirectPort?: number; - [k: string]: unknown | undefined; } /** * Remote MCP server configuration accessed over HTTP or SSE. @@ -5153,7 +5353,9 @@ export interface McpExecuteSamplingParams { /** * The original MCP JSON-RPC request ID (string or number). Used by the runtime to correlate the inference with the originating MCP request for telemetry; this is distinct from `requestId` (which is the schema-level cancellation handle). */ - mcpRequestId: string | number; + mcpRequestId: { + [k: string]: unknown | undefined; + }; request: McpExecuteSamplingRequest; } /** @@ -5996,9 +6198,13 @@ export interface ModelBillingTokenPrices { */ outputPrice?: number; /** - * AI Credits cost per billing batch of cached tokens + * AI Credits cost per billing batch of cache-read tokens */ cachePrice?: number; + /** + * AI Credits cost per billing batch of cache-write (cache creation) tokens. + */ + cacheWritePrice?: number; /** * Number of tokens per standard billing batch */ @@ -6025,9 +6231,13 @@ export interface ModelBillingTokenPricesLongContext { */ outputPrice?: number; /** - * AI Credits cost per billing batch of cached tokens + * AI Credits cost per billing batch of cache-read tokens */ cachePrice?: number; + /** + * AI Credits cost per billing batch of cache-write (cache creation) tokens. + */ + cacheWritePrice?: number; /** * Prompt token budget (max_prompt_tokens) for the long context tier. The total context window is this value plus the model's max_output_tokens. */ @@ -6313,9 +6523,8 @@ export interface NameSetRequest { /** @experimental */ export interface OptionsUpdateAdditionalContentExclusionPolicy { rules: OptionsUpdateAdditionalContentExclusionPolicyRule[]; - last_updated_at: string | number; + last_updated_at: unknown; scope: OptionsUpdateAdditionalContentExclusionPolicyScope; - [k: string]: unknown | undefined; } /** * Schema for the `OptionsUpdateAdditionalContentExclusionPolicyRule` type. @@ -6329,7 +6538,6 @@ export interface OptionsUpdateAdditionalContentExclusionPolicyRule { ifAnyMatch?: string[]; ifNoneMatch?: string[]; source: OptionsUpdateAdditionalContentExclusionPolicyRuleSource; - [k: string]: unknown | undefined; } /** * Schema for the `OptionsUpdateAdditionalContentExclusionPolicyRuleSource` type. @@ -7325,9 +7533,8 @@ export interface PermissionRulesSet { /** @experimental */ export interface PermissionsConfigureAdditionalContentExclusionPolicy { rules: PermissionsConfigureAdditionalContentExclusionPolicyRule[]; - last_updated_at: string | number; + last_updated_at: unknown; scope: PermissionsConfigureAdditionalContentExclusionPolicyScope; - [k: string]: unknown | undefined; } /** * Schema for the `PermissionsConfigureAdditionalContentExclusionPolicyRule` type. @@ -7341,7 +7548,6 @@ export interface PermissionsConfigureAdditionalContentExclusionPolicyRule { ifAnyMatch?: string[]; ifNoneMatch?: string[]; source: PermissionsConfigureAdditionalContentExclusionPolicyRuleSource; - [k: string]: unknown | undefined; } /** * Schema for the `PermissionsConfigureAdditionalContentExclusionPolicyRuleSource` type. @@ -9687,7 +9893,7 @@ export interface SessionFsSqliteQueryRequest { * Optional named bind parameters */ params?: { - [k: string]: (string | number | null) | undefined; + [k: string]: unknown | undefined; }; } /** @@ -10221,9 +10427,8 @@ export interface SessionOpenOptions { /** @experimental */ export interface SessionOpenOptionsAdditionalContentExclusionPolicy { rules: SessionOpenOptionsAdditionalContentExclusionPolicyRule[]; - last_updated_at: string | number; + last_updated_at: unknown; scope: SessionOpenOptionsAdditionalContentExclusionPolicyScope; - [k: string]: unknown | undefined; } /** * Schema for the `SessionOpenOptionsAdditionalContentExclusionPolicyRule` type. @@ -10237,7 +10442,6 @@ export interface SessionOpenOptionsAdditionalContentExclusionPolicyRule { ifAnyMatch?: string[]; ifNoneMatch?: string[]; source: SessionOpenOptionsAdditionalContentExclusionPolicyRuleSource; - [k: string]: unknown | undefined; } /** * Schema for the `SessionOpenOptionsAdditionalContentExclusionPolicyRuleSource` type. @@ -13573,6 +13777,34 @@ export function createServerRpc(connection: MessageConnection) { connection.sendRequest("sessionFs.setProvider", params), }, /** @experimental */ + llmInference: { + /** + * Registers an SDK client as the LLM inference callback provider. + * + * @returns Indicates whether the calling client was registered as the LLM inference provider. + */ + setProvider: async (): Promise => + connection.sendRequest("llmInference.setProvider", {}), + /** + * Delivers the response head (status + headers) for an in-flight request, correlated by the requestId the runtime supplied in httpRequestStart. Must be called exactly once per request before any httpResponseChunk frames. + * + * @param params Response head. + * + * @returns Whether the start frame was accepted. + */ + httpResponseStart: async (params: LlmInferenceHttpResponseStartRequest): Promise => + connection.sendRequest("llmInference.httpResponseStart", params), + /** + * Delivers a body byte range (or a terminal transport error) for an in-flight response, correlated by requestId. Set `end` true on the last chunk. When `error` is set the response terminates with a transport-level failure and the runtime raises an APIConnectionError. + * + * @param params A response body chunk or terminal error. + * + * @returns Whether the chunk was accepted. + */ + httpResponseChunk: async (params: LlmInferenceHttpResponseChunkRequest): Promise => + connection.sendRequest("llmInference.httpResponseChunk", params), + }, + /** @experimental */ sessions: { /** * Creates or resumes a local session and returns the opened session ID. @@ -15537,3 +15769,52 @@ export function registerClientSessionApiHandlers( return handler.invoke(params); }); } + +/** Handler for `llmInference` client global API methods. */ +/** @experimental */ +export interface LlmInferenceHandler { + /** + * Announces an outbound model-layer HTTP request the runtime wants the SDK client to service. Carries the request head only; the body always follows as one or more httpRequestChunk frames keyed by the same requestId, even when the body is empty (a single chunk with end=true). + * + * @param params The head of an outbound model-layer HTTP request. + * + * @returns Acknowledgement. Returning successfully simply means the SDK accepted the start frame; it does not imply the request will succeed. + */ + httpRequestStart(params: LlmInferenceHttpRequestStartRequest): Promise; + /** + * Delivers a body byte range (or a cancellation signal) for a request previously announced via httpRequestStart, correlated by requestId. The runtime fires at least one chunk per request — when there is no body, a single chunk with empty data and end=true. Mid-stream the runtime may send a chunk with cancel=true to abort the request; the SDK then stops issuing httpResponseChunk frames and may emit a terminal httpResponseChunk with error set. + * + * @param params A request body chunk or cancellation signal. + * + * @returns Acknowledgement. The SDK is free to ignore the ack and treat chunk delivery as fire-and-forget. + */ + httpRequestChunk(params: LlmInferenceHttpRequestChunkRequest): Promise; +} + +/** All client global API handler groups. */ +export interface ClientGlobalApiHandlers { + llmInference?: LlmInferenceHandler; +} + +/** + * Register client global API handlers on a JSON-RPC connection. + * The server calls these methods to delegate work to the client. + * Unlike session-scoped client APIs, these methods carry no implicit + * `sessionId` dispatch key — a single set of handlers serves the entire + * connection. + */ +export function registerClientGlobalApiHandlers( + connection: MessageConnection, + handlers: ClientGlobalApiHandlers, +): void { + connection.onRequest("llmInference.httpRequestStart", async (params: LlmInferenceHttpRequestStartRequest) => { + const handler = handlers.llmInference; + if (!handler) throw new Error("No llmInference client-global handler registered"); + return handler.httpRequestStart(params); + }); + connection.onRequest("llmInference.httpRequestChunk", async (params: LlmInferenceHttpRequestChunkRequest) => { + const handler = handlers.llmInference; + if (!handler) throw new Error("No llmInference client-global handler registered"); + return handler.httpRequestChunk(params); + }); +} diff --git a/nodejs/src/generated/session-events.ts b/nodejs/src/generated/session-events.ts index 96f871783..b42bd304c 100644 --- a/nodejs/src/generated/session-events.ts +++ b/nodejs/src/generated/session-events.ts @@ -212,6 +212,14 @@ export type Attachment = | AttachmentGitHubReference | AttachmentBlob | AttachmentExtensionContext; +/** + * Why the binary data is absent: it exceeded the inline size limit, or its asset was unavailable + */ +export type OmittedBinaryOmittedReason = + /** Bytes exceeded the session's inline size limit. */ + | "too_large" + /** The referenced binary asset could not be found (e.g. a truncated log). */ + | "asset_unavailable"; /** * Type of GitHub reference */ @@ -242,6 +250,14 @@ export type AssistantUsageApiEndpoint = | "/responses" /** WebSocket Responses API endpoint. */ | "ws:/responses"; +/** + * For HTTP 400 failures only: whether the response carried a structured CAPI error envelope (structured_error, a deterministic validation failure) or no error body (bodyless, the transient gateway/proxy signature). Absent for non-400 failures. + */ +export type ModelCallFailureBadRequestKind = + /** The 400 response carried no error body (transient gateway/proxy signature). */ + | "bodyless" + /** The 400 response carried a structured CAPI error envelope (deterministic validation failure). */ + | "structured_error"; /** * Where the failed model call originated */ @@ -283,14 +299,6 @@ export type PersistedBinaryImageType = | "image" /** Other binary resource data. */ | "resource"; -/** - * Why the binary data is absent: it exceeded the inline size limit, or its asset was unavailable - */ -export type OmittedBinaryOmittedReason = - /** Bytes exceeded the session's inline size limit. */ - | "too_large" - /** The referenced binary asset could not be found (e.g. a truncated log). */ - | "asset_unavailable"; /** * Binary result type discriminator. Use "image" for images and "resource" for other binary data. */ @@ -479,22 +487,6 @@ export type ElicitationCompletedAction = | "decline" /** The user dismissed the request. */ | "cancel"; -/** - * Schema for the `ElicitationCompletedContent` type. - */ -export type ElicitationCompletedContent = (string | number | boolean | string[]) | undefined; -/** - * Source-defined JSON payload for the custom notification - */ -export type CustomNotificationPayload = - | string - | number - | boolean - | null - | unknown[] - | { - [k: string]: unknown | undefined; - }; /** * The user's auto-mode-switch choice */ @@ -2292,11 +2284,24 @@ export interface UserMessageData { * File attachment */ export interface AttachmentFile { + /** + * Internal: content-addressed id of the session.binary_asset event holding this attachment's model-facing bytes (e.g. "sha256:..."). Absent externally. + */ + assetId?: string; + /** + * Internal: decoded byte length of the attachment's model-facing bytes. Absent externally. + */ + byteLength?: number; /** * User-facing display name for the attachment */ displayName: string; lineRange?: AttachmentFileLineRange; + /** + * Internal: MIME type of the file's model-facing bytes (post-resize for images). Set when the file's bytes are interned to an asset. Absent externally. + */ + mimeType?: string; + omittedReason?: OmittedBinaryOmittedReason; /** * Absolute file path */ @@ -2422,9 +2427,17 @@ export interface AttachmentGitHubReference { */ export interface AttachmentBlob { /** - * Base64-encoded content + * Internal: content-addressed id of the session.binary_asset event holding this attachment's model-facing bytes (e.g. "sha256:..."). Absent externally. */ - data: string; + assetId?: string; + /** + * Internal: decoded byte length of the attachment's model-facing bytes. Absent externally. + */ + byteLength?: number; + /** + * Base64-encoded content. Present on input and for external consumers; replaced by an internal `assetId` reference in persisted events when interned to a content-addressed asset. + */ + data?: string; /** * User-facing display name for the attachment */ @@ -2433,6 +2446,7 @@ export interface AttachmentBlob { * MIME type of the inline data */ mimeType: string; + omittedReason?: OmittedBinaryOmittedReason; /** * Attachment type discriminator */ @@ -3238,14 +3252,23 @@ export interface ModelCallFailureData { * Completion ID from the model provider (e.g., chatcmpl-abc123) */ apiCallId?: string; + badRequestKind?: ModelCallFailureBadRequestKind; /** * Duration of the failed API call in milliseconds */ durationMs?: number; + /** + * For HTTP 400 failures only: the `code` from the CAPI error envelope (e.g. 'model_max_prompt_tokens_exceeded') identifying which deterministic validation failure occurred. Raw server-controlled string, emitted only through restricted telemetry. Absent for bodyless or non-400 failures. + */ + errorCode?: string; /** * Raw provider/runtime error message for restricted telemetry */ errorMessage?: string; + /** + * For HTTP 400 failures only: the `type` from the CAPI error envelope (e.g. 'websocket_error'), a coarser companion to errorCode for envelopes that carry no code. Raw server-controlled string, emitted only through restricted telemetry. Absent for bodyless or non-400 failures. + */ + errorType?: string; /** * What initiated this API call (e.g., "sub-agent", "mcp-sampling"); absent for user-initiated calls */ @@ -3971,27 +3994,19 @@ export interface ToolExecutionCompleteUIResourceMetaUIPermissions { /** * Schema for the `ToolExecutionCompleteUIResourceMetaUIPermissionsCamera` type. */ -export interface ToolExecutionCompleteUIResourceMetaUIPermissionsCamera { - [k: string]: unknown | undefined; -} +export interface ToolExecutionCompleteUIResourceMetaUIPermissionsCamera {} /** * Schema for the `ToolExecutionCompleteUIResourceMetaUIPermissionsClipboardWrite` type. */ -export interface ToolExecutionCompleteUIResourceMetaUIPermissionsClipboardWrite { - [k: string]: unknown | undefined; -} +export interface ToolExecutionCompleteUIResourceMetaUIPermissionsClipboardWrite {} /** * Schema for the `ToolExecutionCompleteUIResourceMetaUIPermissionsGeolocation` type. */ -export interface ToolExecutionCompleteUIResourceMetaUIPermissionsGeolocation { - [k: string]: unknown | undefined; -} +export interface ToolExecutionCompleteUIResourceMetaUIPermissionsGeolocation {} /** * Schema for the `ToolExecutionCompleteUIResourceMetaUIPermissionsMicrophone` type. */ -export interface ToolExecutionCompleteUIResourceMetaUIPermissionsMicrophone { - [k: string]: unknown | undefined; -} +export interface ToolExecutionCompleteUIResourceMetaUIPermissionsMicrophone {} /** * Tool definition metadata, present for MCP tools with MCP Apps support */ @@ -5260,7 +5275,12 @@ export interface PermissionPromptRequestRead { * MCP tool invocation permission prompt */ export interface PermissionPromptRequestMcp { - args?: unknown; + /** + * Arguments to pass to the MCP tool + */ + args?: { + [k: string]: unknown | undefined; + }; /** * Prompt kind discriminator */ @@ -5878,7 +5898,6 @@ export interface ElicitationRequestedData { * URL to open in the user's browser (url mode only) */ url?: string; - [k: string]: unknown | undefined; } /** * JSON Schema describing the form fields to present to the user (form mode only) @@ -5945,6 +5964,12 @@ export interface ElicitationCompletedData { */ requestId: string; } +/** + * Schema for the `ElicitationCompletedContent` type. + */ +export interface ElicitationCompletedContent { + [k: string]: unknown | undefined; +} /** * Session event "sampling.requested". Sampling request from an MCP server; contains the server name and a requestId for correlation */ @@ -5982,7 +6007,9 @@ export interface SamplingRequestedData { /** * The JSON-RPC request ID from the MCP protocol */ - mcpRequestId: string | number; + mcpRequestId: { + [k: string]: unknown | undefined; + }; /** * Unique identifier for this sampling request; used to respond via session.respondToSampling() */ @@ -5991,7 +6018,6 @@ export interface SamplingRequestedData { * Name of the MCP server that initiated the sampling request */ serverName: string; - [k: string]: unknown | undefined; } /** * Session event "sampling.completed". Sampling request completion notification signaling UI dismissal @@ -6185,6 +6211,12 @@ export interface CustomNotificationData { */ version?: number; } +/** + * Source-defined JSON payload for the custom notification + */ +export interface CustomNotificationPayload { + [k: string]: unknown | undefined; +} /** * Optional source-defined string identifiers describing the payload subject */ @@ -7443,7 +7475,6 @@ export interface McpAppToolCallCompleteError { */ export interface McpAppToolCallCompleteToolMeta { ui?: McpAppToolCallCompleteToolMetaUI; - [k: string]: unknown | undefined; } /** * Schema for the `McpAppToolCallCompleteToolMetaUI` type. @@ -7457,5 +7488,4 @@ export interface McpAppToolCallCompleteToolMetaUI { * Tool visibility per SEP-1865 (typically a subset of `["model","app"]`) */ visibility?: string[]; - [k: string]: unknown | undefined; } diff --git a/nodejs/src/index.ts b/nodejs/src/index.ts index 9b266fc9c..9fa6fc4eb 100644 --- a/nodejs/src/index.ts +++ b/nodejs/src/index.ts @@ -28,6 +28,10 @@ export { approveAll, convertMcpCallToolResult, createSessionFsAdapter, + CopilotWebSocketHandler, + ForwardingWebSocketHandler, + LlmRequestHandler, + LlmWebSocketCloseStatus, SYSTEM_MESSAGE_SECTIONS, } from "./types.js"; // Re-export the generated session-event types (every *Event interface and @@ -125,6 +129,11 @@ export type { SessionFsSqliteQueryResult, SessionFsSqliteQueryType, SessionFsSqliteProvider, + LlmInferenceConfig, + LlmInferenceRequest, + LlmInferenceResponseInit, + LlmInferenceResponseSink, + LlmRequestContext, SystemMessageAppendConfig, SystemMessageConfig, SystemMessageCustomizeConfig, diff --git a/nodejs/src/llmInferenceProvider.ts b/nodejs/src/llmInferenceProvider.ts new file mode 100644 index 000000000..4e43900b2 --- /dev/null +++ b/nodejs/src/llmInferenceProvider.ts @@ -0,0 +1,437 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import type { + LlmInferenceHandler, + LlmInferenceHeaders, + LlmInferenceHttpRequestChunkRequest, + LlmInferenceHttpRequestChunkResult, + LlmInferenceHttpRequestStartRequest, + LlmInferenceHttpRequestStartResult, +} from "./generated/rpc.js"; +import type { createServerRpc } from "./generated/rpc.js"; + +type ServerRpc = ReturnType; + +/** + * An outbound model-layer HTTP request the runtime is asking the SDK + * consumer to handle on its behalf. + * + * This is a low-level shape: URL / method / headers verbatim, body bytes + * delivered as an async iterable, response delivered through the + * {@link LlmInferenceResponseSink}. The runtime does not classify the + * request (no provider type, endpoint kind, wire API). Consumers that + * need that information derive it themselves from the URL / headers. + */ +export interface LlmInferenceRequest { + /** Opaque runtime-minted id, stable across the request lifecycle. */ + requestId: string; + /** + * Id of the runtime session that triggered this request, when one is + * in scope. Absent for out-of-session requests (e.g. startup model + * catalog). + */ + sessionId?: string; + /** HTTP method (`GET`, `POST`, ...). */ + method: string; + /** Absolute URL. */ + url: string; + /** HTTP request headers, multi-valued. */ + headers: LlmInferenceHeaders; + /** + * Transport the runtime would otherwise use for this request. + * `"http"` (the default) covers plain HTTP and SSE responses; + * `"websocket"` indicates a full-duplex message channel where each + * {@link requestBody} chunk is one inbound WebSocket message and each + * {@link responseBody} write is one outbound message. Consumers branch + * on this to decide whether to service the request with an HTTP client + * or a WebSocket client. + */ + transport: "http" | "websocket"; + /** + * Request body bytes, yielded as they arrive from the runtime. + * Always iterable; an empty body yields zero chunks before completing. + */ + requestBody: AsyncIterable; + /** + * Aborts when the runtime cancels this in-flight request (e.g. the + * agent turn was aborted upstream). Pass it straight to `fetch` / + * `HttpClient.SendAsync` / your transport so the upstream call is torn + * down too. After it fires, writes to {@link responseBody} are ignored. + */ + signal: AbortSignal; + /** + * Sink the consumer writes the upstream response into. Call + * {@link LlmInferenceResponseSink.start} exactly once before writing + * body chunks, then one or more {@link LlmInferenceResponseSink.write} + * calls, and finish with {@link LlmInferenceResponseSink.end} or + * {@link LlmInferenceResponseSink.error}. + */ + responseBody: LlmInferenceResponseSink; +} + +/** Response head passed to {@link LlmInferenceResponseSink.start}. */ +export interface LlmInferenceResponseInit { + status: number; + statusText?: string; + headers?: LlmInferenceHeaders; +} + +/** + * Sink the consumer writes the upstream response into. The state machine + * is strict: `start` once → 0..N `write` → exactly one of `end` or + * `error`. Calling out of order throws. + */ +export interface LlmInferenceResponseSink { + /** Send the response head (status + headers) back to the runtime. */ + start(init: LlmInferenceResponseInit): Promise; + /** + * Send a body chunk. `string` is encoded as UTF-8; `Uint8Array` is sent + * as binary (base64 on the wire). + */ + write(data: string | Uint8Array): Promise; + /** Mark end-of-stream cleanly. */ + end(): Promise; + /** Mark end-of-stream with a transport-level failure. */ + error(error: { message: string; code?: string }): Promise; +} + +/** + * Interface for an LLM inference provider. The SDK consumer implements + * `onLlmRequest`. The same callback handles both buffered and streaming + * responses — the consumer just calls `responseBody.write` zero or more + * times before `end`. + * + * Use {@link createLlmInferenceAdapter} to convert an + * {@link LlmInferenceProvider} into the {@link LlmInferenceHandler} the + * SDK's RPC layer registers. + */ +export interface LlmInferenceProvider { + /** + * Called by the runtime once per outbound LLM HTTP request the + * consumer has opted to handle. The consumer is responsible for + * eventually calling either `responseBody.end()` or + * `responseBody.error(...)`; failing to do so leaks runtime state. + * Throwing surfaces a transport-level failure to the runtime + * (equivalent to `responseBody.error({ message: err.message })` + * provided `start` has not yet been called). + */ + onLlmRequest(request: LlmInferenceRequest): Promise | void; +} + +interface BodyQueueItem { + chunk?: Uint8Array; + end?: boolean; + cancel?: { reason?: string }; +} + +interface BodyQueue { + push(item: BodyQueueItem): void; + iterable: AsyncIterable; +} + +function makeBodyQueue(): BodyQueue { + const buffer: BodyQueueItem[] = []; + let waker: (() => void) | null = null; + let done = false; + const wake = (): void => { + const w = waker; + waker = null; + w?.(); + }; + return { + push(item: BodyQueueItem): void { + buffer.push(item); + wake(); + }, + iterable: { + [Symbol.asyncIterator](): AsyncIterator { + return { + async next(): Promise> { + if (done) { + return { value: undefined, done: true }; + } + while (buffer.length === 0) { + await new Promise((resolve) => { + waker = resolve; + }); + } + const item = buffer.shift()!; + if (item.cancel) { + done = true; + const reason = item.cancel.reason + ? `Request cancelled by runtime: ${item.cancel.reason}` + : "Request cancelled by runtime"; + throw new Error(reason); + } + if (item.end) { + done = true; + return { value: undefined, done: true }; + } + return { value: item.chunk ?? new Uint8Array(), done: false }; + }, + }; + }, + }, + }; +} + +const sharedTextEncoder = new TextEncoder(); + +function decodeChunkData(data: string, binary: boolean): Uint8Array { + if (binary) { + return new Uint8Array(Buffer.from(data, "base64")); + } + return sharedTextEncoder.encode(data); +} + +interface PendingState { + queue: BodyQueue; + started: boolean; + finished: boolean; + abort: AbortController; + cancelled: boolean; +} + +/** + * Adapt an {@link LlmInferenceProvider} into the generated + * {@link LlmInferenceHandler} shape consumed by the SDK's RPC dispatcher. + * + * Maintains a per-`requestId` state table: each `httpRequestStart` + * allocates a body queue + response sink and fires + * `provider.onLlmRequest` in the background. Subsequent `httpRequestChunk` + * frames are routed into the queue. The sink translates `start` / + * `write` / `end` / `error` calls into outbound + * `serverRpc.llmInference.httpResponseStart` / `httpResponseChunk` calls. + * + * The handler returns from `httpRequestStart` immediately (synchronously + * after registering state) so the runtime's RPC reply is not gated on the + * consumer's I/O. The actual provider work runs asynchronously. + */ +export function createLlmInferenceAdapter( + provider: LlmInferenceProvider, + getServerRpc: () => ServerRpc | undefined +): LlmInferenceHandler { + const pending = new Map(); + // Defense-in-depth backstop: chunks that arrive before their `start` + // frame (a reordering the runtime's single ordered dispatch should make + // impossible) are staged here keyed by requestId and drained the moment + // `httpRequestStart` registers the matching state, so a body byte is + // never silently dropped. + const staged = new Map(); + + function routeChunk(state: PendingState, params: LlmInferenceHttpRequestChunkRequest): void { + if (params.cancel) { + state.cancelled = true; + state.abort.abort(); + state.queue.push({ cancel: { reason: params.cancelReason } }); + return; + } + if (params.data && params.data.length > 0) { + state.queue.push({ chunk: decodeChunkData(params.data, !!params.binary) }); + } + if (params.end) { + state.queue.push({ end: true }); + } + } + + function makeSink(requestId: string, state: PendingState): LlmInferenceResponseSink { + const rpc = (): ServerRpc => { + const r = getServerRpc(); + if (!r) { + throw new Error("LLM inference response sink used after RPC connection closed."); + } + return r; + }; + // The runtime acknowledges every response frame with `accepted`. + // `accepted: false` means it has dropped the request (e.g. it + // cancelled), so we abort the provider's upstream work and stop + // emitting — there is no consumer for further frames. + const rejectedByRuntime = (): never => { + if (!state.cancelled) { + state.cancelled = true; + state.abort.abort(); + } + state.finished = true; + pending.delete(requestId); + throw new Error( + "LLM inference response was rejected by the runtime (request no longer active)." + ); + }; + return { + async start(init: LlmInferenceResponseInit): Promise { + if (state.started) { + throw new Error("LLM inference response sink.start() called twice."); + } + if (state.finished) { + throw new Error("LLM inference response sink already finished."); + } + state.started = true; + const result = await rpc().llmInference.httpResponseStart({ + requestId, + status: init.status, + statusText: init.statusText, + headers: init.headers ?? {}, + }); + if (!result.accepted) { + rejectedByRuntime(); + } + }, + async write(data: string | Uint8Array): Promise { + if (state.cancelled) { + throw new Error("LLM inference request was cancelled by the runtime."); + } + if (!state.started) { + throw new Error("LLM inference response sink.write() called before start()."); + } + if (state.finished) { + throw new Error( + "LLM inference response sink.write() called after end()/error()." + ); + } + const isString = typeof data === "string"; + const result = await rpc().llmInference.httpResponseChunk({ + requestId, + data: isString ? data : Buffer.from(data).toString("base64"), + binary: !isString, + end: false, + }); + if (!result.accepted) { + rejectedByRuntime(); + } + }, + async end(): Promise { + if (state.finished) { + return; + } + state.finished = true; + pending.delete(requestId); + await rpc().llmInference.httpResponseChunk({ + requestId, + data: "", + end: true, + }); + }, + async error(err: { message: string; code?: string }): Promise { + if (state.finished) { + return; + } + state.finished = true; + pending.delete(requestId); + await rpc().llmInference.httpResponseChunk({ + requestId, + data: "", + end: true, + error: { message: err.message, code: err.code }, + }); + }, + }; + } + + async function failViaSink( + sink: LlmInferenceResponseSink, + state: PendingState, + message: string + ): Promise { + if (state.finished) { + return; + } + try { + if (!state.started) { + await sink.start({ status: 502, headers: {} }); + } + await sink.error({ message }); + } catch { + // Best-effort — the connection may already be dead. + } + } + + async function finishCancelled( + sink: LlmInferenceResponseSink, + state: PendingState + ): Promise { + if (state.finished) { + return; + } + try { + if (!state.started) { + await sink.start({ status: 499, headers: {} }); + } + await sink.error({ message: "Request cancelled by runtime", code: "cancelled" }); + } catch { + // Best-effort — the runtime already dropped the request on cancel. + } + } + + return { + async httpRequestStart( + params: LlmInferenceHttpRequestStartRequest + ): Promise { + const state: PendingState = { + queue: makeBodyQueue(), + started: false, + finished: false, + abort: new AbortController(), + cancelled: false, + }; + pending.set(params.requestId, state); + const stagedChunks = staged.get(params.requestId); + if (stagedChunks) { + staged.delete(params.requestId); + for (const chunk of stagedChunks) { + routeChunk(state, chunk); + } + } + const sink = makeSink(params.requestId, state); + const request: LlmInferenceRequest = { + requestId: params.requestId, + sessionId: params.sessionId, + method: params.method, + url: params.url, + headers: params.headers, + transport: params.transport ?? "http", + requestBody: state.queue.iterable, + signal: state.abort.signal, + responseBody: sink, + }; + void (async () => { + try { + await provider.onLlmRequest(request); + if (!state.finished) { + await failViaSink( + sink, + state, + "LLM inference provider returned without finalising the response (call responseBody.end() or .error())." + ); + } + } catch (err) { + if (state.cancelled || state.abort.signal.aborted) { + // The runtime already cancelled this request; the + // provider's throw is just the abort propagating + // out of its upstream call. Acknowledge with a + // terminal cancelled error if we still can. + await finishCancelled(sink, state); + return; + } + const message = err instanceof Error ? err.message : String(err); + await failViaSink(sink, state, message); + } + })(); + return {}; + }, + async httpRequestChunk( + params: LlmInferenceHttpRequestChunkRequest + ): Promise { + const state = pending.get(params.requestId); + if (!state) { + const buffered = staged.get(params.requestId) ?? []; + buffered.push(params); + staged.set(params.requestId, buffered); + return {}; + } + routeChunk(state, params); + return {}; + }, + }; +} diff --git a/nodejs/src/llmRequestHandler.ts b/nodejs/src/llmRequestHandler.ts new file mode 100644 index 000000000..1640183b3 --- /dev/null +++ b/nodejs/src/llmRequestHandler.ts @@ -0,0 +1,469 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import type { LlmInferenceHeaders } from "./generated/rpc.js"; +import type { LlmInferenceProvider, LlmInferenceRequest, LlmInferenceResponseSink } from "./llmInferenceProvider.js"; + +const sharedTextDecoder = new TextDecoder("utf-8", { fatal: false }); +const kBridge = Symbol("llmWebSocketResponseBridge"); +const kCompletion = Symbol("llmWebSocketCompletion"); +const kOpen = Symbol("llmWebSocketOpen"); +const kSuppressCloseOnDispose = Symbol("llmWebSocketSuppressCloseOnDispose"); + +type InternalContext = LlmRequestContext & { [kBridge]: LlmWebSocketResponseBridge }; + +/** + * Per-request context handed to every {@link LlmRequestHandler} hook. + * + * @experimental + */ +export interface LlmRequestContext { + readonly requestId: string; + readonly sessionId?: string; + readonly transport: "http" | "websocket"; + readonly url: string; + readonly headers: LlmInferenceHeaders; + readonly signal: AbortSignal; +} + +/** + * Terminal status for a callback-owned WebSocket connection. + * + * @experimental + */ +export class LlmWebSocketCloseStatus { + static readonly normalClosure = new LlmWebSocketCloseStatus(); + + constructor( + readonly description?: string, + readonly errorCode?: string, + readonly error?: Error + ) {} +} + +/** + * Per-connection WebSocket handler returned by {@link LlmRequestHandler.openWebSocket}. + * + * @experimental + */ +export abstract class CopilotWebSocketHandler implements AsyncDisposable { + readonly #response: LlmWebSocketResponseBridge; + readonly #completion: Promise; + #resolveCompletion!: (status: LlmWebSocketCloseStatus) => void; + #closed = false; + [kSuppressCloseOnDispose] = false; + + protected readonly context: LlmRequestContext; + + protected constructor(context: LlmRequestContext) { + this.context = context; + const bridge = (context as Partial)[kBridge]; + if (!bridge) { + throw new Error("WebSocket response bridge is not attached"); + } + this.#response = bridge; + this.#completion = new Promise((resolve) => { + this.#resolveCompletion = resolve; + }); + } + + async sendResponseMessage(data: string | Uint8Array): Promise { + await this.#response.write(data); + } + + async close(status: LlmWebSocketCloseStatus = LlmWebSocketCloseStatus.normalClosure): Promise { + if (this.#closed) { + return; + } + this.#closed = true; + if (status.error) { + await this.#response.error({ + message: status.description ?? status.error.message, + code: status.errorCode, + }); + } else { + await this.#response.end(); + } + this.#resolveCompletion(status); + } + + abstract sendRequestMessage(data: string | Uint8Array): Promise | void; + + async [Symbol.asyncDispose](): Promise { + if (!this[kSuppressCloseOnDispose] && !this.#closed) { + await this.close(LlmWebSocketCloseStatus.normalClosure); + } + } + + /** @internal */ + get [kCompletion](): Promise { + return this.#completion; + } + + /** @internal */ + async [kOpen](): Promise {} +} + +/** + * Default pass-through WebSocket handler backed by the WHATWG `WebSocket`. + * + * @experimental + */ +export class ForwardingWebSocketHandler extends CopilotWebSocketHandler { + readonly #url: string; + #upstream: WebSocket | null = null; + + constructor(context: LlmRequestContext, url = context.url) { + super(context); + this.#url = url; + } + + override sendRequestMessage(data: string | Uint8Array): void { + if (this.#upstream?.readyState !== WebSocket.OPEN) { + return; + } + this.#upstream.send(data); + } + + /** @internal */ + override async [kOpen](): Promise { + if (this.#upstream) { + return; + } + const upstream = new WebSocket(this.#url); + upstream.binaryType = "arraybuffer"; + this.#upstream = upstream; + upstream.addEventListener("message", (event) => { + void this.sendResponseMessage(normalizeWsData(event.data)).catch(async (err: unknown) => { + await this.close( + new LlmWebSocketCloseStatus( + err instanceof Error ? err.message : String(err), + undefined, + err instanceof Error ? err : new Error(String(err)) + ) + ); + }); + }); + upstream.addEventListener("close", () => { + void this.close(LlmWebSocketCloseStatus.normalClosure); + }); + upstream.addEventListener("error", () => { + void this.close(new LlmWebSocketCloseStatus("WebSocket error", undefined, new Error("WebSocket error"))); + }); + await new Promise((resolve, reject) => { + if (upstream.readyState === WebSocket.OPEN) { + resolve(); + return; + } + upstream.addEventListener("open", () => resolve(), { once: true }); + upstream.addEventListener("error", () => reject(new Error("WebSocket error")), { once: true }); + }); + } + + override async close( + status: LlmWebSocketCloseStatus = LlmWebSocketCloseStatus.normalClosure + ): Promise { + try { + if ( + this.#upstream?.readyState === WebSocket.OPEN || + this.#upstream?.readyState === WebSocket.CONNECTING + ) { + this.#upstream?.close(); + } + } catch { + // Best-effort; the socket may already be closed. + } + await super.close(status); + } + + override async [Symbol.asyncDispose](): Promise { + try { + await super[Symbol.asyncDispose](); + } finally { + try { + this.#upstream?.close(); + } catch { + // Best-effort. + } + } + } +} + +/** + * Base class for SDK consumers who want to observe or mutate the LLM + * inference requests the runtime issues. + * + * @experimental + */ +export class LlmRequestHandler implements LlmInferenceProvider { + async onLlmRequest(req: LlmInferenceRequest): Promise { + const bridge = new LlmWebSocketResponseBridge(req.responseBody); + const ctx: InternalContext = { + requestId: req.requestId, + sessionId: req.sessionId, + transport: req.transport, + url: req.url, + headers: req.headers, + signal: req.signal, + [kBridge]: bridge, + }; + + if (req.transport === "websocket") { + await this.#handleWebSocket(req, ctx); + } else { + await this.#handleHttp(req, ctx); + } + } + + protected sendRequest(request: Request, ctx: LlmRequestContext): Promise { + return fetch(request, { signal: ctx.signal }); + } + + protected openWebSocket(ctx: LlmRequestContext): Promise { + return Promise.resolve(new ForwardingWebSocketHandler(ctx)); + } + + async #handleHttp(req: LlmInferenceRequest, ctx: LlmRequestContext): Promise { + const request = await buildFetchRequest(req); + const response = await this.sendRequest(request, ctx); + await streamResponseToSink(response, req); + } + + async #handleWebSocket(req: LlmInferenceRequest, ctx: InternalContext): Promise { + const handler = await this.openWebSocket(ctx); + try { + await handler[kOpen](); + await ctx[kBridge].start(); + + let cancelled: unknown; + const clientSettled = (async () => { + for await (const chunk of req.requestBody) { + await handler.sendRequestMessage(decodeFrame(chunk)); + } + return "client-complete" as const; + })().catch((err) => { + cancelled = err; + return "client-error" as const; + }); + + const first = await Promise.race([ + clientSettled, + handler[kCompletion].then(() => "server-done" as const), + ]); + + if (first === "client-error") { + handler[kSuppressCloseOnDispose] = true; + throw cancelled instanceof Error ? cancelled : new Error(String(cancelled)); + } + + if (first === "client-complete") { + await handler.close(LlmWebSocketCloseStatus.normalClosure); + await handler[kCompletion]; + return; + } + + const status = await handler[kCompletion]; + if (status.error) { + throw status.error; + } + } finally { + await handler[Symbol.asyncDispose](); + } + } +} + +const FORBIDDEN_REQUEST_HEADERS = new Set([ + "host", + "connection", + "content-length", + "transfer-encoding", + "keep-alive", + "upgrade", + "proxy-connection", + "te", + "trailer", +]); + +async function buildFetchRequest(req: LlmInferenceRequest): Promise { + const headers = new Headers(); + for (const [name, values] of Object.entries(req.headers)) { + if (!values) { + continue; + } + if (FORBIDDEN_REQUEST_HEADERS.has(name.toLowerCase())) { + continue; + } + for (const value of values) { + headers.append(name, value); + } + } + + const method = req.method.toUpperCase(); + const hasBody = method !== "GET" && method !== "HEAD"; + + let body: Uint8Array | undefined; + if (hasBody) { + const buffered = await drainAsync(req.requestBody); + if (buffered.length > 0) { + body = buffered; + } + } else { + await drainAsync(req.requestBody); + } + + return new Request(req.url, { method, headers, body }); +} + +async function drainAsync(stream: AsyncIterable): Promise { + const parts: Uint8Array[] = []; + let total = 0; + for await (const chunk of stream) { + parts.push(chunk); + total += chunk.byteLength; + } + if (parts.length === 0) { + return new Uint8Array(0); + } + if (parts.length === 1) { + return parts[0]; + } + const out = new Uint8Array(total); + let off = 0; + for (const part of parts) { + out.set(part, off); + off += part.byteLength; + } + return out; +} + +async function streamResponseToSink(response: Response, req: LlmInferenceRequest): Promise { + const headers = headersToMultiMap(response.headers); + await req.responseBody.start({ + status: response.status, + statusText: response.statusText || undefined, + headers, + }); + + const body = response.body; + if (!body) { + await req.responseBody.end(); + return; + } + + const reader = body.getReader(); + try { + for (;;) { + const { value, done } = await reader.read(); + if (done) { + break; + } + if (value && value.byteLength > 0) { + await req.responseBody.write(value); + } + } + await req.responseBody.end(); + } finally { + reader.releaseLock(); + } +} + +function headersToMultiMap(headers: Headers): LlmInferenceHeaders { + const out: Record = {}; + headers.forEach((value, name) => { + if (name.toLowerCase() === "set-cookie") { + return; + } + const list = out[name] ?? (out[name] = []); + list.push(value); + }); + const setCookies = headers.getSetCookie(); + if (setCookies.length > 0) { + out["set-cookie"] = setCookies; + } + return out; +} + +function decodeFrame(chunk: Uint8Array): string { + return sharedTextDecoder.decode(chunk); +} + +function normalizeWsData(data: unknown): string | Uint8Array { + if (typeof data === "string") { + return data; + } + if (data instanceof Uint8Array) { + return data; + } + if (data instanceof ArrayBuffer) { + return new Uint8Array(data); + } + return new Uint8Array(); +} + +class LlmWebSocketResponseBridge { + readonly #sink: LlmInferenceResponseSink; + readonly #pending: Array<() => Promise> = []; + #started = false; + #completed = false; + #serial: Promise = Promise.resolve(); + + constructor(sink: LlmInferenceResponseSink) { + this.#sink = sink; + } + + async start(): Promise { + await this.#enqueue(async () => { + if (this.#started) { + return; + } + this.#started = true; + await this.#sink.start({ status: 101, headers: {} }); + while (this.#pending.length > 0) { + await this.#pending.shift()!(); + } + }); + } + + async write(data: string | Uint8Array): Promise { + await this.#enqueueOrBuffer(async () => { + if (!this.#completed) { + await this.#sink.write(data); + } + }); + } + + async end(): Promise { + await this.#enqueueOrBuffer(async () => { + if (this.#completed) { + return; + } + this.#completed = true; + await this.#sink.end(); + }); + } + + async error(error: { message: string; code?: string }): Promise { + await this.#enqueueOrBuffer(async () => { + if (this.#completed) { + return; + } + this.#completed = true; + await this.#sink.error(error); + }); + } + + async #enqueueOrBuffer(action: () => Promise): Promise { + if (!this.#started) { + this.#pending.push(action); + return; + } + await this.#enqueue(action); + } + + async #enqueue(action: () => Promise): Promise { + const run = this.#serial.then(action, action); + this.#serial = run.catch(() => {}); + await run; + } +} diff --git a/nodejs/src/sessionFsProvider.ts b/nodejs/src/sessionFsProvider.ts index 7e959849e..5ab0bf2c5 100644 --- a/nodejs/src/sessionFsProvider.ts +++ b/nodejs/src/sessionFsProvider.ts @@ -96,7 +96,7 @@ export interface SessionFsProvider { } function normalizeSqliteParams( - params?: Record + params?: Record ): Record | undefined { if (!params) { return undefined; @@ -104,9 +104,16 @@ function normalizeSqliteParams( const normalized: Record = {}; for (const [key, value] of Object.entries(params)) { - if (value !== undefined) { + if (value === undefined) { + continue; + } + if (value === null || typeof value === "string" || typeof value === "number") { normalized[key] = value; + continue; } + throw new Error( + `Invalid SQLite bind parameter "${key}": expected string, number, or null but got ${typeof value}` + ); } return normalized; } @@ -212,11 +219,10 @@ export function createSessionFsAdapter(provider: SessionFsProvider): SessionFsHa if (!provider.sqlite) { throw new Error("SQLite is not supported by this provider"); } - const result = await provider.sqlite.query( - queryType, - query, - normalizeSqliteParams(bindParams) - ); + // The generated schema types bind-param values as `unknown` (the runtime + // emits an opaque map); normalizeSqliteParams validates each value at runtime + // and narrows to the scalar string/number/null shape SQLite accepts. + const result = await provider.sqlite.query(queryType, query, normalizeSqliteParams(bindParams)); return result ?? { rows: [], columns: [], rowsAffected: 0 }; }, sqliteExists: async () => { diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts index f198a88b3..fceebd2c5 100644 --- a/nodejs/src/types.ts +++ b/nodejs/src/types.ts @@ -9,6 +9,7 @@ // Import and re-export generated session event types import type { Canvas } from "./canvas.js"; import type { SessionFsProvider } from "./sessionFsProvider.js"; +import type { LlmRequestHandler } from "./llmRequestHandler.js"; import type { ReasoningSummary, SessionEvent as GeneratedSessionEvent, @@ -33,6 +34,19 @@ export type { SessionFsFileInfo } from "./sessionFsProvider.js"; export type { SessionFsSqliteQueryResult } from "./sessionFsProvider.js"; export type { SessionFsSqliteQueryType } from "./sessionFsProvider.js"; export type { SessionFsSqliteProvider } from "./sessionFsProvider.js"; +export type { + LlmInferenceRequest, + LlmInferenceResponseInit, + LlmInferenceResponseSink, +} from "./llmInferenceProvider.js"; +export type { LlmInferenceHeaders } from "./generated/rpc.js"; +export type { LlmRequestContext } from "./llmRequestHandler.js"; +export { + CopilotWebSocketHandler, + ForwardingWebSocketHandler, + LlmRequestHandler, + LlmWebSocketCloseStatus, +} from "./llmRequestHandler.js"; /** * Options for creating a CopilotClient @@ -305,6 +319,27 @@ export interface CopilotClientOptions { */ sessionFs?: SessionFsConfig; + /** + * Custom LLM inference callback provider (experimental). + * + * When provided, the client registers as the runtime's LLM inference + * provider on connection: every outbound model-layer request the runtime + * would otherwise have issued itself — plain HTTP, streaming SSE, and + * WebSocket — is dispatched back to the callback over JSON-RPC. The + * callback returns the response verbatim, exactly as if the runtime had + * issued the request itself. + * + * v1 notes: + * - HTTP (buffered and streaming SSE) and WebSocket transports are all + * intercepted. The callback receives a `transport` discriminator and a + * symmetric request-body stream / response-body sink for both. + * - The callback is set process-globally on the runtime; the same + * provider is invoked for every session created on this client. + * + * @experimental + */ + llmInference?: LlmInferenceConfig; + /** * Server-wide idle timeout for sessions in seconds. * Sessions without activity for this duration are automatically cleaned up. @@ -2465,6 +2500,28 @@ export interface SessionFsConfig { }; } +/** + * Configuration for a custom LLM inference callback provider + * (experimental). + * + * @experimental + */ +export interface LlmInferenceConfig { + /** + * The handler that services LLM inference requests. The runtime routes + * all outbound model HTTP and WebSocket requests through this handler + * for the lifetime of the client, regardless of which session triggered + * them. + * + * Subclass {@link LlmRequestHandler} and override the hooks you need; + * an instance that overrides nothing is a transparent pass-through. + * + * Per-request session correlation is available on + * {@link LlmInferenceRequest.sessionId}. + */ + handler?: LlmRequestHandler; +} + /** * Filter options for listing sessions */ diff --git a/nodejs/test/e2e/llm_inference.e2e.test.ts b/nodejs/test/e2e/llm_inference.e2e.test.ts new file mode 100644 index 000000000..0d4898b92 --- /dev/null +++ b/nodejs/test/e2e/llm_inference.e2e.test.ts @@ -0,0 +1,131 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { describe, expect, it } from "vitest"; +import { approveAll, LlmRequestHandler, type LlmInferenceRequest } from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; + +/** + * Drain the request body and reply with a single buffered response. The + * unified callback supports both buffered and streaming uniformly — for + * non-streaming responses, the consumer writes the whole body once and + * calls `end`. + */ +async function respondBuffered( + req: LlmInferenceRequest, + init: { status: number; headers?: Record }, + body: string, +): Promise { + for await (const _chunk of req.requestBody) { + // discard — the runtime always sends at least one chunk (with end:true). + } + await req.responseBody.start(init); + if (body.length > 0) { + await req.responseBody.write(body); + } + await req.responseBody.end(); +} + +async function handleNonStreaming(req: LlmInferenceRequest): Promise { + const url = req.url.toLowerCase(); + if (url.endsWith("/models")) { + return respondBuffered( + req, + { status: 200, headers: { "content-type": ["application/json"] } }, + JSON.stringify({ + data: [ + { + id: "claude-sonnet-4.5", + name: "Claude Sonnet 4.5", + object: "model", + vendor: "Anthropic", + version: "1", + preview: false, + model_picker_enabled: true, + capabilities: { + type: "chat", + family: "claude-sonnet-4.5", + tokenizer: "o200k_base", + limits: { max_context_window_tokens: 200000, max_output_tokens: 8192 }, + supports: { streaming: true, tool_calls: true, parallel_tool_calls: true, vision: true }, + }, + }, + ], + }), + ); + } + if (url.includes("/models/session")) { + return respondBuffered(req, { status: 200, headers: {} }, "{}"); + } + if (url.includes("/policy")) { + return respondBuffered(req, { status: 200, headers: {} }, JSON.stringify({ state: "enabled" })); + } + return respondBuffered( + req, + { status: 200, headers: { "content-type": ["application/json"] } }, + "{}", + ); +} + +describe("LLM inference callback", async () => { + const received: LlmInferenceRequest[] = []; + + const { copilotClient: client } = await createSdkTestContext({ + copilotClientOptions: { + llmInference: { + handler: new (class extends LlmRequestHandler { + override async onLlmRequest(req): Promise { + received.push(req); + await handleNonStreaming(req); + } + })(), + }, + }, + }); + + it("registers the provider on connect without erroring", async () => { + await client.start(); + expect(client).toBeDefined(); + }); + + it( + "invokes the callback for non-streaming model-layer requests and threads sessionId through", + async () => { + const baselineLength = received.length; + const session = await client.createSession({ onPermissionRequest: approveAll }); + try { + // Drive a turn so model-layer traffic (catalog, + // session-intent, inference) flows through the callback. + // We swallow errors here — the buffered handler returns + // empty JSON for inference, which is not a valid model + // response; the agent will surface a transport error. + // What we care about is that the runtime *attempted* to + // call the callback for the model-layer endpoints. + try { + await session.sendAndWait({ prompt: "Say OK." }); + } catch { + // expected — see comment above + } + } finally { + await session.disconnect(); + } + + expect(received.length).toBeGreaterThan(baselineLength); + const newRequests = received.slice(baselineLength); + for (const r of newRequests) { + expect(r.url).toMatch(/^https?:\/\//); + expect(typeof r.method).toBe("string"); + } + + const catalog = newRequests.find((r) => r.url.toLowerCase().endsWith("/models")); + expect(catalog, "expected to intercept the /models catalog request").toBeDefined(); + + const inSession = newRequests.find((r) => typeof r.sessionId === "string"); + if (inSession) { + expect(inSession.sessionId).toMatch(/[a-zA-Z0-9-]+/); + } + }, + 90_000, + ); +}); diff --git a/nodejs/test/e2e/llm_inference_cancel.e2e.test.ts b/nodejs/test/e2e/llm_inference_cancel.e2e.test.ts new file mode 100644 index 000000000..72f1471c0 --- /dev/null +++ b/nodejs/test/e2e/llm_inference_cancel.e2e.test.ts @@ -0,0 +1,164 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { describe, expect, it } from "vitest"; +import { approveAll, LlmRequestHandler, type LlmInferenceRequest } from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; + +async function drainRequest(req: LlmInferenceRequest): Promise { + for await (const _chunk of req.requestBody) { + // discard + } +} + +async function respondBuffered( + req: LlmInferenceRequest, + init: { status: number; headers?: Record }, + body: string, +): Promise { + await drainRequest(req); + await req.responseBody.start(init); + if (body.length > 0) { + await req.responseBody.write(body); + } + await req.responseBody.end(); +} + +async function serviceNonInference(req: LlmInferenceRequest): Promise { + const url = req.url.toLowerCase(); + if (url.endsWith("/models")) { + await respondBuffered( + req, + { status: 200, headers: { "content-type": ["application/json"] } }, + JSON.stringify({ + data: [ + { + id: "claude-sonnet-4.5", + name: "Claude Sonnet 4.5", + object: "model", + vendor: "Anthropic", + version: "1", + preview: false, + model_picker_enabled: true, + capabilities: { + type: "chat", + family: "claude-sonnet-4.5", + tokenizer: "o200k_base", + limits: { max_context_window_tokens: 200000, max_output_tokens: 8192 }, + supports: { streaming: true, tool_calls: true, parallel_tool_calls: true, vision: true }, + }, + }, + ], + }), + ); + return true; + } + if (url.includes("/models/session")) { + await respondBuffered(req, { status: 200, headers: {} }, "{}"); + return true; + } + if (url.includes("/policy")) { + await respondBuffered(req, { status: 200, headers: {} }, JSON.stringify({ state: "enabled" })); + return true; + } + return false; +} + +async function waitFor(predicate: () => boolean, timeoutMs: number): Promise { + const start = Date.now(); + while (!predicate()) { + if (Date.now() - start > timeoutMs) { + throw new Error("waitFor timed out"); + } + await new Promise((resolve) => setTimeout(resolve, 50)); + } +} + +/** + * Verifies the runtime → consumer cancellation path: when an in-flight + * turn is aborted via `session.abort()`, the runtime cancels the + * callback-served inference request and the consumer observes + * `req.signal.aborted` so it can tear down its upstream call. + */ +describe("LLM inference callback — cancellation", async () => { + let inferenceEntered = false; + let sawAbort = false; + let resolveAbortSeen: (() => void) | undefined; + const abortSeen = new Promise((resolve) => { + resolveAbortSeen = resolve; + }); + + const { copilotClient: client } = await createSdkTestContext({ + copilotClientOptions: { + llmInference: { + handler: new (class extends LlmRequestHandler { + override async onLlmRequest(req: LlmInferenceRequest): Promise { + if (await serviceNonInference(req)) { + return; + } + const url = req.url.toLowerCase(); + const isInference = + url.includes("/chat/completions") || + url.includes("/responses") || + url.endsWith("/messages") || + url.endsWith("/v1/messages"); + if (!isInference) { + await respondBuffered( + req, + { status: 200, headers: { "content-type": ["application/json"] } }, + "{}", + ); + return; + } + + // Inference: never produce a response. Wait for the + // runtime to cancel us, recording the abort. + await drainRequest(req); + inferenceEntered = true; + await new Promise((resolve) => { + if (req.signal.aborted) { + resolve(); + return; + } + req.signal.addEventListener("abort", () => resolve(), { once: true }); + }); + sawAbort = true; + resolveAbortSeen?.(); + try { + await req.responseBody.error({ message: "cancelled by upstream", code: "cancelled" }); + } catch { + // Runtime already dropped the request on cancel. + } + } + })(), + }, + }, + }); + + it( + "propagates runtime cancellation to the consumer's req.signal", + async () => { + await client.start(); + const session = await client.createSession({ onPermissionRequest: approveAll }); + try { + await session.send({ prompt: "Say OK." }); + await waitFor(() => inferenceEntered, 60_000); + await session.abort(); + await Promise.race([ + abortSeen, + new Promise((_resolve, reject) => + setTimeout(() => reject(new Error("timed out waiting for abort")), 30_000), + ), + ]); + } finally { + await session.disconnect(); + } + + // The consumer observed the runtime-driven cancellation. + expect(inferenceEntered).toBe(true); + expect(sawAbort).toBe(true); + }, + 120_000, + ); +}); diff --git a/nodejs/test/e2e/llm_inference_consumer_cancel.e2e.test.ts b/nodejs/test/e2e/llm_inference_consumer_cancel.e2e.test.ts new file mode 100644 index 000000000..c504bdd2b --- /dev/null +++ b/nodejs/test/e2e/llm_inference_consumer_cancel.e2e.test.ts @@ -0,0 +1,147 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { describe, expect, it } from "vitest"; +import { approveAll, LlmRequestHandler, type LlmInferenceRequest } from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; + +async function drainRequest(req: LlmInferenceRequest): Promise { + for await (const _chunk of req.requestBody) { + // discard + } +} + +async function respondBuffered( + req: LlmInferenceRequest, + init: { status: number; headers?: Record }, + body: string, +): Promise { + await drainRequest(req); + await req.responseBody.start(init); + if (body.length > 0) { + await req.responseBody.write(body); + } + await req.responseBody.end(); +} + +async function serviceNonInference(req: LlmInferenceRequest): Promise { + const url = req.url.toLowerCase(); + if (url.endsWith("/models")) { + await respondBuffered( + req, + { status: 200, headers: { "content-type": ["application/json"] } }, + JSON.stringify({ + data: [ + { + id: "claude-sonnet-4.5", + name: "Claude Sonnet 4.5", + object: "model", + vendor: "Anthropic", + version: "1", + preview: false, + model_picker_enabled: true, + capabilities: { + type: "chat", + family: "claude-sonnet-4.5", + tokenizer: "o200k_base", + limits: { max_context_window_tokens: 200000, max_output_tokens: 8192 }, + supports: { streaming: true, tool_calls: true, parallel_tool_calls: true, vision: true }, + }, + }, + ], + }), + ); + return true; + } + if (url.includes("/models/session")) { + await respondBuffered(req, { status: 200, headers: {} }, "{}"); + return true; + } + if (url.includes("/policy")) { + await respondBuffered(req, { status: 200, headers: {} }, JSON.stringify({ state: "enabled" })); + return true; + } + return false; +} + +function isInferenceUrl(url: string): boolean { + const u = url.toLowerCase(); + return ( + u.includes("/chat/completions") || + u.includes("/responses") || + u.endsWith("/messages") || + u.endsWith("/v1/messages") + ); +} + +/** + * Verifies the consumer → runtime cancellation path: when the consumer + * itself decides to abort the upstream call (e.g. its own + * `AbortController` fired, or the upstream socket dropped), it signals the + * runtime via `responseBody.error({ code: "cancelled" })`. The runtime + * must surface that faithfully as a request failure rather than hanging + * waiting for a response head/body. + */ +describe("LLM inference callback — consumer-initiated cancellation", async () => { + let inferenceAttempts = 0; + + const { copilotClient: client } = await createSdkTestContext({ + copilotClientOptions: { + llmInference: { + handler: new (class extends LlmRequestHandler { + override async onLlmRequest(req: LlmInferenceRequest): Promise { + if (await serviceNonInference(req)) { + return; + } + if (!isInferenceUrl(req.url)) { + await respondBuffered( + req, + { status: 200, headers: { "content-type": ["application/json"] } }, + "{}", + ); + return; + } + + // Consumer-initiated cancellation: the consumer's own + // upstream call was aborted, so it tells the runtime to + // give up on this request. No response head is ever + // produced; the runtime should see a transport failure. + await drainRequest(req); + inferenceAttempts += 1; + await req.responseBody.error({ + message: "upstream call aborted by consumer", + code: "cancelled", + }); + } + })(), + }, + }, + }); + + it( + "surfaces a consumer-signalled cancellation to the runtime", + async () => { + await client.start(); + const session = await client.createSession({ onPermissionRequest: approveAll }); + + let caught: unknown; + try { + await session.sendAndWait({ prompt: "Say OK." }); + } catch (err) { + caught = err; + } finally { + await session.disconnect(); + } + + // The runtime reached the inference step and the consumer's + // cancellation terminated it (rather than the runtime hanging). + expect(inferenceAttempts).toBeGreaterThan(0); + if (caught) { + const message = caught instanceof Error ? caught.message : String(caught); + expect(message.length).toBeGreaterThan(0); + } + }, + 90_000, + ); +}); diff --git a/nodejs/test/e2e/llm_inference_errors.e2e.test.ts b/nodejs/test/e2e/llm_inference_errors.e2e.test.ts new file mode 100644 index 000000000..4d8c84643 --- /dev/null +++ b/nodejs/test/e2e/llm_inference_errors.e2e.test.ts @@ -0,0 +1,147 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { describe, expect, it } from "vitest"; +import { approveAll, LlmRequestHandler, type LlmInferenceRequest } from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; + +async function drainRequest(req: LlmInferenceRequest): Promise { + for await (const _chunk of req.requestBody) { + // discard + } +} + +async function respondBuffered( + req: LlmInferenceRequest, + init: { status: number; headers?: Record }, + body: string, +): Promise { + await drainRequest(req); + await req.responseBody.start(init); + if (body.length > 0) { + await req.responseBody.write(body); + } + await req.responseBody.end(); +} + +/** + * Verifies that errors thrown (or signalled via `responseBody.error`) by + * the LLM inference callback surface to the SDK consumer as transport + * failures, so the runtime's existing retry / error-reporting machinery + * handles them uniformly. + */ +describe("LLM inference callback — error mapping", async () => { + let callsBeforeError = 0; + let totalCalls = 0; + + const { copilotClient: client } = await createSdkTestContext({ + copilotClientOptions: { + llmInference: { + handler: new (class extends LlmRequestHandler { + override async onLlmRequest(req: LlmInferenceRequest): Promise { + totalCalls += 1; + const url = req.url.toLowerCase(); + + // Service models / session / policy normally so the + // agent can reach the inference step. + if (url.endsWith("/models")) { + await respondBuffered( + req, + { status: 200, headers: { "content-type": ["application/json"] } }, + JSON.stringify({ + data: [ + { + id: "claude-sonnet-4.5", + name: "Claude Sonnet 4.5", + object: "model", + vendor: "Anthropic", + version: "1", + preview: false, + model_picker_enabled: true, + capabilities: { + type: "chat", + family: "claude-sonnet-4.5", + tokenizer: "o200k_base", + limits: { + max_context_window_tokens: 200000, + max_output_tokens: 8192, + }, + supports: { + streaming: true, + tool_calls: true, + parallel_tool_calls: true, + vision: true, + }, + }, + }, + ], + }), + ); + return; + } + if (url.includes("/models/session")) { + await respondBuffered(req, { status: 200, headers: {} }, "{}"); + return; + } + if (url.includes("/policy")) { + await respondBuffered( + req, + { status: 200, headers: {} }, + JSON.stringify({ state: "enabled" }), + ); + return; + } + + // Inference: throw a transport-level error from the + // callback. The adapter converts this into a + // terminal `httpResponseChunk` with `error` set, so + // the runtime surfaces it as `APIConnectionError`. + if (url.includes("/chat/completions") || url.includes("/responses")) { + await drainRequest(req); + callsBeforeError += 1; + throw new Error("synthetic-callback-transport-failure"); + } + + await respondBuffered( + req, + { status: 200, headers: { "content-type": ["application/json"] } }, + "{}", + ); + } + })(), + }, + }, + }); + + it( + "surfaces a callback-thrown error to the SDK consumer", + async () => { + await client.start(); + const session = await client.createSession({ onPermissionRequest: approveAll }); + + let caught: unknown; + try { + await session.sendAndWait({ prompt: "Say OK." }); + } catch (err) { + caught = err; + } finally { + await session.disconnect(); + } + + // The agent layer typically wraps inference failures in its + // own error type and may convert them to an event rather than + // a thrown exception, so the assertion is loose: either we + // caught an error referencing the callback failure, or the + // inference call was attempted at least once and the runtime + // did NOT hang waiting for a response. + expect(totalCalls).toBeGreaterThan(0); + expect(callsBeforeError).toBeGreaterThan(0); + if (caught) { + const message = caught instanceof Error ? caught.message : String(caught); + expect(message.length).toBeGreaterThan(0); + } + }, + 90_000, + ); +}); diff --git a/nodejs/test/e2e/llm_inference_handler.e2e.test.ts b/nodejs/test/e2e/llm_inference_handler.e2e.test.ts new file mode 100644 index 000000000..e8fcc7529 --- /dev/null +++ b/nodejs/test/e2e/llm_inference_handler.e2e.test.ts @@ -0,0 +1,390 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { createServer, IncomingMessage, Server as HttpServer, ServerResponse } from "http"; +import { AddressInfo } from "net"; +import { afterAll, describe, expect, it } from "vitest"; +import { WebSocket as WsClient, WebSocketServer } from "ws"; +import { + approveAll, + CopilotWebSocketHandler, + LlmRequestHandler, + LlmWebSocketCloseStatus, + type LlmRequestContext, +} from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; + +const HTTP_TEXT = "OK from synthetic HTTP upstream."; +const WS_TEXT = "OK from synthetic WS upstream."; + +/** + * Stand up an in-process upstream that speaks the real CAPI shapes the + * runtime needs: model catalog, policy, `/responses` SSE for HTTP + * inference, and a WebSocket endpoint at `/responses` that answers each + * inbound `response.create` with the ordered `/responses` events the + * reducer expects. + * + * Returned `url` is what the handler subclass rewrites every + * intercepted request to point at — the runtime never talks to this + * server directly; the handler does, on the runtime's behalf. + */ +async function startFakeUpstream(): Promise<{ + url: string; + server: HttpServer; + wsRequestCount: () => number; + close: () => Promise; +}> { + let wsRequests = 0; + + const httpServer = createServer((req, res) => { + const url = new URL(req.url ?? "/", `http://${req.headers.host ?? "localhost"}`); + if (url.pathname === "/models" && req.method === "GET") { + sendJson(res, 200, { + data: [ + { + id: "claude-sonnet-4.5", + name: "Claude Sonnet 4.5", + object: "model", + vendor: "Anthropic", + version: "1", + preview: false, + model_picker_enabled: true, + supported_endpoints: ["/responses", "ws:/responses"], + capabilities: { + type: "chat", + family: "claude-sonnet-4.5", + tokenizer: "o200k_base", + limits: { + max_context_window_tokens: 200000, + max_output_tokens: 8192, + }, + supports: { + streaming: true, + tool_calls: true, + parallel_tool_calls: true, + vision: true, + }, + }, + }, + ], + }); + return; + } + if (url.pathname.endsWith("/models/session")) { + sendJson(res, 200, {}); + return; + } + if (url.pathname.includes("/policy")) { + sendJson(res, 200, { state: "enabled" }); + return; + } + if (url.pathname.endsWith("/responses") && req.method === "POST") { + // Single-shot HTTP inference (e.g. title generation). SSE + // events the `responses-client.ts` reducer accepts. + drainBody(req) + .then(() => { + res.writeHead(200, { + "content-type": "text/event-stream", + "cache-control": "no-cache", + }); + for (const event of buildResponsesEvents(HTTP_TEXT, "resp_stub_http")) { + res.write(`event: ${event.type}\ndata: ${JSON.stringify(event)}\n\n`); + } + res.end(); + }) + .catch(() => { + res.writeHead(500).end(); + }); + return; + } + // Anything else: not found. + res.writeHead(404, { "content-type": "application/json" }); + res.end(JSON.stringify({ error: "not_found", path: url.pathname })); + }); + + const wss = new WebSocketServer({ server: httpServer, path: "/responses" }); + wss.on("connection", (socket) => { + socket.on("message", (raw) => { + wsRequests++; + // For each `response.create` request the runtime sends, + // answer with the ordered `/responses` event objects — one + // event per outbound WS message, raw JSON (NOT SSE-framed). + for (const event of buildResponsesEvents(WS_TEXT, "resp_stub_ws")) { + socket.send(JSON.stringify(event)); + } + void raw; + }); + }); + + await new Promise((resolve) => httpServer.listen(0, "127.0.0.1", resolve)); + const port = (httpServer.address() as AddressInfo).port; + const url = `http://127.0.0.1:${port}`; + + return { + url, + server: httpServer, + wsRequestCount: () => wsRequests, + async close() { + wss.clients.forEach((c) => c.terminate()); + await new Promise((resolve) => wss.close(() => resolve())); + await new Promise((resolve) => httpServer.close(() => resolve())); + }, + }; +} + +function sendJson(res: ServerResponse, status: number, body: unknown): void { + res.writeHead(status, { "content-type": "application/json" }); + res.end(JSON.stringify(body)); +} + +async function drainBody(req: IncomingMessage): Promise { + const parts: Buffer[] = []; + for await (const chunk of req) { + parts.push(chunk as Buffer); + } + return Buffer.concat(parts); +} + +function buildResponsesEvents(text: string, id: string): Array> { + return [ + { + type: "response.created", + response: { id, object: "response", status: "in_progress", output: [] }, + }, + { + type: "response.output_item.added", + output_index: 0, + item: { id: "msg_1", type: "message", role: "assistant", content: [] }, + }, + { + type: "response.content_part.added", + output_index: 0, + content_index: 0, + part: { type: "output_text", text: "" }, + }, + { type: "response.output_text.delta", output_index: 0, content_index: 0, delta: text }, + { type: "response.output_text.done", output_index: 0, content_index: 0, text }, + { + type: "response.completed", + response: { + id, + object: "response", + status: "completed", + output: [ + { + id: "msg_1", + type: "message", + role: "assistant", + content: [{ type: "output_text", text }], + }, + ], + usage: { input_tokens: 5, output_tokens: 7, total_tokens: 12 }, + }, + }, + ]; +} + +/** +interface Counters { + httpRequests: number; + httpResponses: number; + wsRequestMessages: number; + wsResponseMessages: number; +} + +/** + * Single handler subclass that services BOTH transports against the + * per-test fake upstream. Demonstrates mutation in each direction: + * + * - HTTP: rewrites the URL to point at the test server, adds an + * `X-Test-Mutated` header to the outbound request, and adds an + * `X-Test-Response-Mutated` header on the way back. The test server + * echoes the request header into a counter so we can assert it + * actually arrived upstream. + * - WebSocket: rewrites the WS URL similarly, opens with the `ws` + * package inside a custom per-connection handler, and observes + * message counts in both directions. + */ +class TestHandler extends LlmRequestHandler { + constructor( + private readonly upstreamUrl: string, + private readonly counters: Counters + ) { + super(); + } + + private rewriteUrl(originalUrl: string): string { + const parsed = new URL(originalUrl); + const upstream = new URL(this.upstreamUrl); + parsed.protocol = upstream.protocol; + parsed.host = upstream.host; + return parsed.toString(); + } + + private rewriteWsUrl(originalUrl: string): string { + const parsed = new URL(originalUrl); + const upstream = new URL(this.upstreamUrl); + // The upstream URL is http(s); flip to ws(s) for the WS open. + parsed.protocol = upstream.protocol === "https:" ? "wss:" : "ws:"; + parsed.host = upstream.host; + return parsed.toString(); + } + + protected override async sendRequest(request: Request, _ctx: LlmRequestContext): Promise { + this.counters.httpRequests++; + const rewritten = this.rewriteUrl(request.url); + const requestHeaders = new Headers(request.headers); + requestHeaders.set("x-test-mutated", "1"); + const rewrittenRequest = new Request(rewritten, { + method: request.method, + headers: requestHeaders, + body: request.body, + // @ts-expect-error duplex is required by undici when streaming a body + duplex: "half", + }); + const response = await fetch(rewrittenRequest, { signal: _ctx.signal }); + this.counters.httpResponses++; + const responseHeaders = new Headers(response.headers); + responseHeaders.set("x-test-response-mutated", "1"); + return new Response(response.body, { + status: response.status, + statusText: response.statusText, + headers: responseHeaders, + }); + } + + protected override async openWebSocket(ctx: LlmRequestContext): Promise { + return TestSocketHandler.connect(this.rewriteWsUrl(ctx.url), ctx, this.counters); + } +} + +class TestSocketHandler extends CopilotWebSocketHandler { + static async connect( + url: string, + ctx: LlmRequestContext, + counters: Counters + ): Promise { + const client = new WsClient(url); + await new Promise((resolve, reject) => { + client.once("open", () => resolve()); + client.once("error", (err) => reject(err)); + }); + return new TestSocketHandler(client, ctx, counters); + } + + private constructor( + private readonly client: WsClient, + ctx: LlmRequestContext, + private readonly counters: Counters + ) { + super(ctx); + this.client.on("message", (data, isBinary) => { + this.counters.wsResponseMessages++; + void this.sendResponseMessage(isBinary ? (data as Buffer) : data.toString("utf-8")); + }); + this.client.once("close", () => { + void this.close(); + }); + this.client.once("error", (err) => { + void this.close(new LlmWebSocketCloseStatus(err.message, undefined, err as Error)); + }); + const onAbort = (): void => { + try { + this.client.close(); + } catch { + /* best-effort */ + } + }; + ctx.signal.addEventListener("abort", onAbort, { once: true }); + this.client.once("close", () => ctx.signal.removeEventListener("abort", onAbort)); + } + + override sendRequestMessage(data: string | Uint8Array): void { + this.counters.wsRequestMessages++; + if (this.client.readyState !== WsClient.OPEN) { + return; + } + this.client.send(data); + } + + override async [Symbol.asyncDispose](): Promise { + try { + await super[Symbol.asyncDispose](); + } finally { + try { + this.client.close(); + } catch { + /* best-effort */ + } + } + } +} + +describe("LlmRequestHandler — single subclass handles HTTP + WebSocket", async () => { + const upstream = await startFakeUpstream(); + const counters: Counters = { + httpRequests: 0, + httpResponses: 0, + wsRequestMessages: 0, + wsResponseMessages: 0, + }; + + const { copilotClient: client, env } = await createSdkTestContext({ + copilotClientOptions: { + llmInference: { + handler: new TestHandler(upstream.url, counters), + }, + }, + }); + + // Enable the WebSocket Responses transport in the spawned runtime so + // the main agent turn picks the WS path; single-shot calls (title + // generation) still go over HTTP through the same subclass. + env.COPILOT_EXP_COPILOT_CLI_WEBSOCKET_RESPONSES = "true"; + + afterAll(async () => { + await upstream.close(); + }); + + it("services both an HTTP turn and a WebSocket turn end-to-end via one handler", async () => { + await client.start(); + const session = await client.createSession({ onPermissionRequest: approveAll }); + let resultJson = ""; + try { + const result = await session.sendAndWait({ prompt: "Say OK." }); + resultJson = JSON.stringify(result); + } finally { + await session.disconnect(); + } + + // The HTTP hooks fired — the runtime issued model-layer GETs + // (catalog, policy) and possibly a single-shot inference. + expect(counters.httpRequests, "expected sendRequest to fire").toBeGreaterThan(0); + expect(counters.httpResponses, "expected sendRequest response mutation to fire").toBeGreaterThan( + 0 + ); + + // The WebSocket hooks fired — the main agent turn went over + // the WS path and we observed messages in both directions. + expect( + counters.wsRequestMessages, + "expected sendRequestMessage (runtime → upstream) to fire" + ).toBeGreaterThan(0); + expect( + counters.wsResponseMessages, + "expected sendResponseMessage (upstream → runtime) to fire" + ).toBeGreaterThan(0); + expect( + upstream.wsRequestCount(), + "expected upstream WS to receive request messages" + ).toBeGreaterThan(0); + + // The synthetic content from the upstream surfaced in the + // assistant turn — proves the full chain (runtime → handler + // → upstream → handler → runtime) is intact for the + // transport the main agent turn used. + // Validate the final assistant response arrived (guards against truncated captures) + expect(resultJson).toMatch(/OK from synthetic (HTTP|WS) upstream/); + }, 90_000); +}); diff --git a/nodejs/test/e2e/llm_inference_session_id.e2e.test.ts b/nodejs/test/e2e/llm_inference_session_id.e2e.test.ts new file mode 100644 index 000000000..8637f7b6e --- /dev/null +++ b/nodejs/test/e2e/llm_inference_session_id.e2e.test.ts @@ -0,0 +1,335 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { describe, expect, it } from "vitest"; +import { approveAll, LlmRequestHandler, type LlmInferenceRequest } from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; + +const SYNTHETIC_TEXT = "OK from the synthetic stream."; + +async function drainRequest(req: LlmInferenceRequest): Promise { + const parts: Buffer[] = []; + for await (const chunk of req.requestBody) { + parts.push(Buffer.from(chunk)); + } + return Buffer.concat(parts).toString("utf-8"); +} + +async function respondBuffered( + req: LlmInferenceRequest, + init: { status: number; headers?: Record }, + body: string +): Promise { + await drainRequest(req); + await req.responseBody.start(init); + if (body.length > 0) { + await req.responseBody.write(body); + } + await req.responseBody.end(); +} + +/** + * Serve the model-layer GETs/POSTs the runtime issues that are not + * inference (catalog, model session, policy). These flow through the same + * callback but carry no session id (they happen outside an agent turn). + */ +async function handleNonInferenceModelTraffic(req: LlmInferenceRequest): Promise { + const url = req.url.toLowerCase(); + if (url.endsWith("/models")) { + await respondBuffered( + req, + { status: 200, headers: { "content-type": ["application/json"] } }, + JSON.stringify({ + data: [ + { + id: "claude-sonnet-4.5", + name: "Claude Sonnet 4.5", + object: "model", + vendor: "Anthropic", + version: "1", + preview: false, + model_picker_enabled: true, + capabilities: { + type: "chat", + family: "claude-sonnet-4.5", + tokenizer: "o200k_base", + limits: { max_context_window_tokens: 200000, max_output_tokens: 8192 }, + supports: { + streaming: true, + tool_calls: true, + parallel_tool_calls: true, + vision: true, + }, + }, + }, + ], + }) + ); + return; + } + if (url.includes("/models/session")) { + await respondBuffered(req, { status: 200, headers: {} }, "{}"); + return; + } + if (url.includes("/policy")) { + await respondBuffered(req, { status: 200, headers: {} }, JSON.stringify({ state: "enabled" })); + return; + } + await respondBuffered(req, { status: 200, headers: { "content-type": ["application/json"] } }, "{}"); +} + +/** + * Synthesize a well-formed inference response so the agent turn completes. + * The runtime selects `/responses` for both the CAPI and BYOK sessions + * here; `/chat/completions` is handled too for robustness. The consumer + * fabricates the response directly — there is no upstream server and the + * CAPI record/replay proxy is never the inference endpoint. + */ +async function handleInference(req: LlmInferenceRequest): Promise { + const bodyText = await drainRequest(req); + const wantsStream = /"stream"\s*:\s*true/.test(bodyText); + const url = req.url.toLowerCase(); + + // `/responses` streams via SSE only when the request asked for it + // (`stream: true`). BYOK turns whose config-derived model doesn't + // advertise streaming issue a buffered request expecting a single + // JSON `response` object, so branch on the flag exactly as a real + // upstream would. + if (url.includes("/responses")) { + if (!wantsStream) { + await req.responseBody.start({ + status: 200, + headers: { "content-type": ["application/json"] }, + }); + await req.responseBody.write( + JSON.stringify({ + id: "resp_stub_1", + object: "response", + status: "completed", + output: [ + { + id: "msg_1", + type: "message", + role: "assistant", + content: [{ type: "output_text", text: SYNTHETIC_TEXT }], + }, + ], + usage: { input_tokens: 5, output_tokens: 7, total_tokens: 12 }, + }) + ); + await req.responseBody.end(); + return; + } + await req.responseBody.start({ + status: 200, + headers: { "content-type": ["text/event-stream"] }, + }); + const id = "resp_stub_1"; + const events: string[] = [ + `event: response.created\ndata: ${JSON.stringify({ + type: "response.created", + response: { id, object: "response", status: "in_progress", output: [] }, + })}\n\n`, + `event: response.output_item.added\ndata: ${JSON.stringify({ + type: "response.output_item.added", + output_index: 0, + item: { id: "msg_1", type: "message", role: "assistant", content: [] }, + })}\n\n`, + `event: response.content_part.added\ndata: ${JSON.stringify({ + type: "response.content_part.added", + output_index: 0, + content_index: 0, + part: { type: "output_text", text: "" }, + })}\n\n`, + `event: response.output_text.delta\ndata: ${JSON.stringify({ + type: "response.output_text.delta", + output_index: 0, + content_index: 0, + delta: SYNTHETIC_TEXT, + })}\n\n`, + `event: response.output_text.done\ndata: ${JSON.stringify({ + type: "response.output_text.done", + output_index: 0, + content_index: 0, + text: SYNTHETIC_TEXT, + })}\n\n`, + `event: response.completed\ndata: ${JSON.stringify({ + type: "response.completed", + response: { + id, + object: "response", + status: "completed", + output: [ + { + id: "msg_1", + type: "message", + role: "assistant", + content: [{ type: "output_text", text: SYNTHETIC_TEXT }], + }, + ], + usage: { input_tokens: 5, output_tokens: 7, total_tokens: 12 }, + }, + })}\n\n`, + ]; + for (const event of events) { + await req.responseBody.write(event); + } + await req.responseBody.end(); + return; + } + + if (url.includes("/chat/completions") && wantsStream) { + await req.responseBody.start({ + status: 200, + headers: { "content-type": ["text/event-stream"] }, + }); + const base = { id: "chatcmpl-stub-1", object: "chat.completion.chunk", created: 1, model: "claude-sonnet-4.5" }; + const events: string[] = [ + `data: ${JSON.stringify({ + ...base, + choices: [{ index: 0, delta: { role: "assistant", content: "" }, finish_reason: null }], + })}\n\n`, + `data: ${JSON.stringify({ + ...base, + choices: [{ index: 0, delta: { content: SYNTHETIC_TEXT }, finish_reason: null }], + })}\n\n`, + `data: ${JSON.stringify({ + ...base, + choices: [{ index: 0, delta: {}, finish_reason: "stop" }], + usage: { prompt_tokens: 5, completion_tokens: 7, total_tokens: 12 }, + })}\n\n`, + `data: [DONE]\n\n`, + ]; + for (const event of events) { + await req.responseBody.write(event); + } + await req.responseBody.end(); + return; + } + + // /chat/completions non-streaming — buffered JSON. + await req.responseBody.start({ status: 200, headers: { "content-type": ["application/json"] } }); + await req.responseBody.write( + JSON.stringify({ + id: "chatcmpl-stub-1", + object: "chat.completion", + created: 1, + model: "claude-sonnet-4.5", + choices: [ + { index: 0, message: { role: "assistant", content: SYNTHETIC_TEXT }, finish_reason: "stop" }, + ], + usage: { prompt_tokens: 5, completion_tokens: 7, total_tokens: 12 }, + }) + ); + await req.responseBody.end(); +} + +interface InterceptedRequest { + url: string; + sessionId?: string; +} + +function isInferenceUrl(url: string): boolean { + const u = url.toLowerCase(); + return ( + u.endsWith("/chat/completions") || + u.endsWith("/responses") || + u.endsWith("/v1/messages") || + u.endsWith("/messages") + ); +} + +/** + * Asserts the runtime threads its session id into the LLM inference + * callback for BOTH a CAPI session and a BYOK session. The callback alone + * services every model-layer request — no upstream server, no CAPI proxy + * acting as the inference endpoint — so the only source of `req.sessionId` + * is the runtime's own per-client threading. + */ +describe("LLM inference callback threads the runtime session id (CAPI + BYOK)", async () => { + const records: InterceptedRequest[] = []; + + const { copilotClient: client } = await createSdkTestContext({ + copilotClientOptions: { + llmInference: { + handler: new (class extends LlmRequestHandler { + override async onLlmRequest(req: LlmInferenceRequest): Promise { + records.push({ url: req.url, sessionId: req.sessionId }); + if (isInferenceUrl(req.url)) { + await handleInference(req); + } else { + await handleNonInferenceModelTraffic(req); + } + } + })(), + }, + }, + }); + + let capiSessionId: string | undefined; + + it("threads the session id into a CAPI session's inference request", async () => { + await client.start(); + const baseline = records.length; + const session = await client.createSession({ onPermissionRequest: approveAll }); + capiSessionId = session.sessionId; + let resultJson = ""; + try { + const result = await session.sendAndWait({ prompt: "Say OK." }); + resultJson = JSON.stringify(result); + } finally { + await session.disconnect(); + } + + const inference = records.slice(baseline).filter((r) => isInferenceUrl(r.url)); + expect(inference.length, "expected at least one intercepted inference request").toBeGreaterThan(0); + for (const r of inference) { + expect(r.sessionId, "CAPI inference request must carry the runtime session id").toBe( + session.sessionId + ); + } + + // Validate the final assistant response arrived (guards against truncated captures) + expect(resultJson).toMatch(/OK from the synthetic/); + }, 90_000); + + it("threads the session id into a BYOK session's inference request", async () => { + await client.start(); + const baseline = records.length; + const session = await client.createSession({ + onPermissionRequest: approveAll, + // BYOK providers require an explicit model id. + model: "claude-sonnet-4.5", + provider: { + type: "openai", + wireApi: "responses", + baseUrl: "https://byok.invalid/v1", + apiKey: "byok-secret", + modelId: "claude-sonnet-4.5", + wireModel: "claude-sonnet-4.5", + }, + }); + const byokSessionId = session.sessionId; + let resultJson = ""; + try { + const result = await session.sendAndWait({ prompt: "Say OK." }); + resultJson = JSON.stringify(result); + } finally { + await session.disconnect(); + } + + const inference = records.slice(baseline).filter((r) => isInferenceUrl(r.url)); + expect(inference.length, "expected at least one intercepted BYOK inference request").toBeGreaterThan(0); + for (const r of inference) { + expect(r.sessionId, "BYOK inference request must carry the runtime session id").toBe(byokSessionId); + } + + // Session ids are per-session, so the two turns must differ — proves + // we assert against a real, request-specific id, not a constant. + expect(byokSessionId).not.toBe(capiSessionId); + + // Validate the final assistant response arrived (guards against truncated captures) + expect(resultJson).toMatch(/OK from the synthetic/); + }, 90_000); +}); diff --git a/nodejs/test/e2e/llm_inference_stream.e2e.test.ts b/nodejs/test/e2e/llm_inference_stream.e2e.test.ts new file mode 100644 index 000000000..db25cf41f --- /dev/null +++ b/nodejs/test/e2e/llm_inference_stream.e2e.test.ts @@ -0,0 +1,260 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { describe, expect, it } from "vitest"; +import { approveAll, LlmRequestHandler, type LlmInferenceRequest } from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; + +async function drainRequest(req: LlmInferenceRequest): Promise { + const parts: Buffer[] = []; + for await (const chunk of req.requestBody) { + parts.push(Buffer.from(chunk)); + } + return Buffer.concat(parts).toString("utf-8"); +} + +async function respondBuffered( + req: LlmInferenceRequest, + init: { status: number; headers?: Record }, + body: string, +): Promise { + await drainRequest(req); + await req.responseBody.start(init); + if (body.length > 0) { + await req.responseBody.write(body); + } + await req.responseBody.end(); +} + +async function handleNonInferenceModelTraffic(req: LlmInferenceRequest): Promise { + const url = req.url.toLowerCase(); + if (url.endsWith("/models")) { + await respondBuffered( + req, + { status: 200, headers: { "content-type": ["application/json"] } }, + JSON.stringify({ + data: [ + { + id: "claude-sonnet-4.5", + name: "Claude Sonnet 4.5", + object: "model", + vendor: "Anthropic", + version: "1", + preview: false, + model_picker_enabled: true, + capabilities: { + type: "chat", + family: "claude-sonnet-4.5", + tokenizer: "o200k_base", + limits: { max_context_window_tokens: 200000, max_output_tokens: 8192 }, + supports: { streaming: true, tool_calls: true, parallel_tool_calls: true, vision: true }, + }, + }, + ], + }), + ); + return; + } + if (url.includes("/models/session")) { + await respondBuffered(req, { status: 200, headers: {} }, "{}"); + return; + } + if (url.includes("/policy")) { + await respondBuffered(req, { status: 200, headers: {} }, JSON.stringify({ state: "enabled" })); + return; + } + await respondBuffered(req, { status: 200, headers: { "content-type": ["application/json"] } }, "{}"); +} + +/** + * Synthesizes a minimal but well-formed response for the runtime's + * inference request. The runtime calls the buffered code path for + * `/chat/completions` and the streaming code path for `/responses`, but + * the unified callback has no field telling the consumer which — the + * consumer dispatches by URL. + */ +async function handleInference(req: LlmInferenceRequest): Promise { + const bodyText = await drainRequest(req); + const wantsStream = /"stream"\s*:\s*true/.test(bodyText); + const url = req.url.toLowerCase(); + + if (url.includes("/responses")) { + await req.responseBody.start({ + status: 200, + headers: { "content-type": ["text/event-stream"] }, + }); + const id = "resp_stub_1"; + const events: string[] = [ + `event: response.created\ndata: ${JSON.stringify({ + type: "response.created", + response: { id, object: "response", status: "in_progress", output: [] }, + })}\n\n`, + `event: response.output_item.added\ndata: ${JSON.stringify({ + type: "response.output_item.added", + output_index: 0, + item: { id: "msg_1", type: "message", role: "assistant", content: [] }, + })}\n\n`, + `event: response.content_part.added\ndata: ${JSON.stringify({ + type: "response.content_part.added", + output_index: 0, + content_index: 0, + part: { type: "output_text", text: "" }, + })}\n\n`, + `event: response.output_text.delta\ndata: ${JSON.stringify({ + type: "response.output_text.delta", + output_index: 0, + content_index: 0, + delta: "OK from the synthetic stream.", + })}\n\n`, + `event: response.output_text.done\ndata: ${JSON.stringify({ + type: "response.output_text.done", + output_index: 0, + content_index: 0, + text: "OK from the synthetic stream.", + })}\n\n`, + `event: response.completed\ndata: ${JSON.stringify({ + type: "response.completed", + response: { + id, + object: "response", + status: "completed", + output: [ + { + id: "msg_1", + type: "message", + role: "assistant", + content: [{ type: "output_text", text: "OK from the synthetic stream." }], + }, + ], + usage: { input_tokens: 5, output_tokens: 7, total_tokens: 12 }, + }, + })}\n\n`, + ]; + for (const event of events) { + await req.responseBody.write(event); + } + await req.responseBody.end(); + return; + } + + if (url.includes("/chat/completions") && wantsStream) { + await req.responseBody.start({ + status: 200, + headers: { "content-type": ["text/event-stream"] }, + }); + const base = { + id: "chatcmpl-stub-1", + object: "chat.completion.chunk", + created: 1, + model: "claude-sonnet-4.5", + }; + const events: string[] = [ + `data: ${JSON.stringify({ + ...base, + choices: [{ index: 0, delta: { role: "assistant", content: "" }, finish_reason: null }], + })}\n\n`, + `data: ${JSON.stringify({ + ...base, + choices: [ + { + index: 0, + delta: { content: "OK from the synthetic stream." }, + finish_reason: null, + }, + ], + })}\n\n`, + `data: ${JSON.stringify({ + ...base, + choices: [{ index: 0, delta: {}, finish_reason: "stop" }], + usage: { prompt_tokens: 5, completion_tokens: 7, total_tokens: 12 }, + })}\n\n`, + `data: [DONE]\n\n`, + ]; + for (const event of events) { + await req.responseBody.write(event); + } + await req.responseBody.end(); + return; + } + + // /chat/completions non-streaming — buffered JSON. (body already drained above) + await req.responseBody.start({ status: 200, headers: { "content-type": ["application/json"] } }); + await req.responseBody.write( + JSON.stringify({ + id: "chatcmpl-stub-1", + object: "chat.completion", + created: 1, + model: "claude-sonnet-4.5", + choices: [ + { + index: 0, + message: { role: "assistant", content: "OK from the synthetic stream." }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 5, completion_tokens: 7, total_tokens: 12 }, + }), + ); + await req.responseBody.end(); +} + +describe("LLM inference callback — fully mocked streaming", async () => { + const received: LlmInferenceRequest[] = []; + + const { copilotClient: client } = await createSdkTestContext({ + copilotClientOptions: { + llmInference: { + handler: new (class extends LlmRequestHandler { + override async onLlmRequest(req: LlmInferenceRequest): Promise { + received.push(req); + const url = req.url.toLowerCase(); + const isInference = + url.includes("/chat/completions") || + url.endsWith("/responses") || + url.endsWith("/v1/messages") || + url.endsWith("/messages"); + if (isInference) { + await handleInference(req); + } else { + await handleNonInferenceModelTraffic(req); + } + } + })(), + }, + }, + }); + + it( + "completes a full user→assistant turn entirely via the callback (chunked SSE response)", + async () => { + await client.start(); + const session = await client.createSession({ onPermissionRequest: approveAll }); + let resultJson = ""; + try { + const result = await session.sendAndWait({ prompt: "Say OK." }); + resultJson = JSON.stringify(result); + } finally { + await session.disconnect(); + } + + // At least one inference request flowed through the callback. + const inferenceReqs = received.filter((r) => { + const u = r.url.toLowerCase(); + return ( + u.endsWith("/chat/completions") || + u.endsWith("/responses") || + u.endsWith("/v1/messages") || + u.endsWith("/messages") + ); + }); + expect(inferenceReqs.length, "expected at least one inference request via the callback").toBeGreaterThan( + 0, + ); + + // The synthetic content surfaced in the assistant response. + expect(resultJson).toMatch(/OK from the synthetic/); + }, + 90_000, + ); +}); diff --git a/nodejs/test/e2e/llm_inference_websocket.e2e.test.ts b/nodejs/test/e2e/llm_inference_websocket.e2e.test.ts new file mode 100644 index 000000000..440124784 --- /dev/null +++ b/nodejs/test/e2e/llm_inference_websocket.e2e.test.ts @@ -0,0 +1,226 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { describe, expect, it } from "vitest"; +import { approveAll, LlmRequestHandler, type LlmInferenceRequest } from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; + +const WS_TEXT = "OK from the synthetic ws."; + +async function drainRequest(req: LlmInferenceRequest): Promise { + const parts: Buffer[] = []; + for await (const chunk of req.requestBody) { + parts.push(Buffer.from(chunk)); + } + return Buffer.concat(parts).toString("utf-8"); +} + +async function respondBuffered( + req: LlmInferenceRequest, + init: { status: number; headers?: Record }, + body: string, +): Promise { + await drainRequest(req); + await req.responseBody.start(init); + if (body.length > 0) { + await req.responseBody.write(body); + } + await req.responseBody.end(); +} + +/** + * The fake model catalog advertises both `/responses` and `ws:/responses` + * so `pickModelProtocol` selects the Responses wire API and `ai-client.ts` + * is allowed to pick the WebSocket transport (the feature flag is enabled + * via the env var below). No `/v1/messages`, otherwise the model would be + * routed to the Anthropic Messages API instead. + */ +async function handleNonInferenceModelTraffic(req: LlmInferenceRequest): Promise { + const url = req.url.toLowerCase(); + if (url.endsWith("/models")) { + await respondBuffered( + req, + { status: 200, headers: { "content-type": ["application/json"] } }, + JSON.stringify({ + data: [ + { + id: "claude-sonnet-4.5", + name: "Claude Sonnet 4.5", + object: "model", + vendor: "Anthropic", + version: "1", + preview: false, + model_picker_enabled: true, + supported_endpoints: ["/responses", "ws:/responses"], + capabilities: { + type: "chat", + family: "claude-sonnet-4.5", + tokenizer: "o200k_base", + limits: { max_context_window_tokens: 200000, max_output_tokens: 8192 }, + supports: { streaming: true, tool_calls: true, parallel_tool_calls: true, vision: true }, + }, + }, + ], + }), + ); + return; + } + if (url.includes("/models/session")) { + await respondBuffered(req, { status: 200, headers: {} }, "{}"); + return; + } + if (url.includes("/policy")) { + await respondBuffered(req, { status: 200, headers: {} }, JSON.stringify({ state: "enabled" })); + return; + } + await respondBuffered(req, { status: 200, headers: { "content-type": ["application/json"] } }, "{}"); +} + +/** + * Synthesizes the `/responses` SSE event stream for the HTTP code path + * (single-shot inference requests — e.g. title generation — that don't + * pick the WebSocket transport). + */ +async function handleHttpInference(req: LlmInferenceRequest): Promise { + await drainRequest(req); + await req.responseBody.start({ + status: 200, + headers: { "content-type": ["text/event-stream"] }, + }); + for (const event of buildResponsesEvents()) { + await req.responseBody.write(`event: ${event.type}\ndata: ${JSON.stringify(event)}\n\n`); + } + await req.responseBody.end(); +} + +/** + * Builds the ordered `/responses` event objects the reducer expects. + * Used raw (one object = one WS message) for the WebSocket path and + * SSE-framed for the HTTP path. + */ +function buildResponsesEvents(): Array> { + const id = "resp_stub_ws_1"; + return [ + { type: "response.created", response: { id, object: "response", status: "in_progress", output: [] } }, + { + type: "response.output_item.added", + output_index: 0, + item: { id: "msg_1", type: "message", role: "assistant", content: [] }, + }, + { + type: "response.content_part.added", + output_index: 0, + content_index: 0, + part: { type: "output_text", text: "" }, + }, + { type: "response.output_text.delta", output_index: 0, content_index: 0, delta: WS_TEXT }, + { type: "response.output_text.done", output_index: 0, content_index: 0, text: WS_TEXT }, + { + type: "response.completed", + response: { + id, + object: "response", + status: "completed", + output: [ + { + id: "msg_1", + type: "message", + role: "assistant", + content: [{ type: "output_text", text: WS_TEXT }], + }, + ], + usage: { input_tokens: 5, output_tokens: 7, total_tokens: 12 }, + }, + }, + ]; +} + +/** + * Full-duplex WebSocket handler. The runtime opens the channel + * (`transport === "websocket"`), the consumer acks the upgrade, then + * pumps bidirectionally: every inbound `response.create` request the + * runtime sends is answered with the ordered `/responses` event objects, + * one event per outbound WS message (raw JSON, *not* SSE-framed). The + * connection is reused across turns; it stays open until the runtime + * closes it, at which point `req.requestBody` throws and we stop. + */ +async function handleWebSocket(req: LlmInferenceRequest, onRequest: () => void): Promise { + // Ack the upgrade (status 101-equivalent) before any message flows. + await req.responseBody.start({ status: 101, headers: {} }); + try { + for await (const _outbound of req.requestBody) { + onRequest(); + for (const event of buildResponsesEvents()) { + await req.responseBody.write(JSON.stringify(event)); + } + } + } catch { + // Expected: the runtime cancels the request body when it closes the + // socket at session teardown. Nothing more to do. + } +} + +describe("LLM inference callback — full-duplex WebSocket transport", async () => { + const received: LlmInferenceRequest[] = []; + let wsRequestCount = 0; + + const { copilotClient: client, env } = await createSdkTestContext({ + copilotClientOptions: { + llmInference: { + handler: new (class extends LlmRequestHandler { + override async onLlmRequest(req: LlmInferenceRequest): Promise { + received.push(req); + if (req.transport === "websocket") { + await handleWebSocket(req, () => { + wsRequestCount++; + }); + return; + } + const url = req.url.toLowerCase(); + const isInference = + url.includes("/chat/completions") || + url.endsWith("/responses") || + url.endsWith("/v1/messages") || + url.endsWith("/messages"); + if (isInference) { + await handleHttpInference(req); + } else { + await handleNonInferenceModelTraffic(req); + } + } + })(), + }, + }, + }); + + // Enable the WebSocket Responses transport in the spawned runtime. The + // harness env object is the same one passed to the CLI subprocess, so + // mutating it here flips the ExP flag for this test file's client. + env.COPILOT_EXP_COPILOT_CLI_WEBSOCKET_RESPONSES = "true"; + + it( + "completes a user→assistant turn over the WebSocket transport via the callback", + async () => { + await client.start(); + const session = await client.createSession({ onPermissionRequest: approveAll }); + let resultJson = ""; + try { + const result = await session.sendAndWait({ prompt: "Say OK." }); + resultJson = JSON.stringify(result); + } finally { + await session.disconnect(); + } + + // The main agent turn (tools present, not single-shot) selected the + // WebSocket transport and drove it through the callback. + const wsReqs = received.filter((r) => r.transport === "websocket"); + expect(wsReqs.length, "expected at least one websocket request via the callback").toBeGreaterThan(0); + expect(wsRequestCount, "expected the runtime to send at least one ws message").toBeGreaterThan(0); + + // The synthetic content surfaced in the assistant response. + expect(resultJson).toMatch(/OK from the synthetic ws/); + }, + 90_000, + ); +}); diff --git a/nodejs/test/llm_inference_callbacks.test.ts b/nodejs/test/llm_inference_callbacks.test.ts new file mode 100644 index 000000000..061082ca6 --- /dev/null +++ b/nodejs/test/llm_inference_callbacks.test.ts @@ -0,0 +1,294 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { describe, expect, it } from "vitest"; +import { + CopilotWebSocketHandler, + LlmRequestHandler, + type LlmInferenceRequest, + type LlmInferenceResponseInit, + type LlmInferenceResponseSink, + type LlmRequestContext, + LlmWebSocketCloseStatus, +} from "../src/index.js"; +import { + createLlmInferenceAdapter, + type LlmInferenceProvider, +} from "../src/llmInferenceProvider.js"; + +/** + * Minimal fake of the server RPC surface the adapter uses to send response + * frames back to the runtime. Records every frame and lets the test decide + * what `accepted` value the runtime returns. + */ +function makeFakeServerRpc(accepted: { start?: boolean; chunk?: boolean } = {}): { + rpc: () => ReturnType; + starts: LlmInferenceResponseInit[]; + chunks: { data: string; binary?: boolean; end?: boolean; error?: unknown }[]; +} { + const starts: LlmInferenceResponseInit[] = []; + const chunks: { data: string; binary?: boolean; end?: boolean; error?: unknown }[] = []; + function createFakeRpc() { + return { + llmInference: { + async httpResponseStart(p: { + status: number; + statusText?: string; + headers: Record; + }) { + starts.push({ status: p.status, statusText: p.statusText, headers: p.headers }); + return { accepted: accepted.start ?? true }; + }, + async httpResponseChunk(p: { + data: string; + binary?: boolean; + end?: boolean; + error?: unknown; + }) { + chunks.push({ data: p.data, binary: p.binary, end: p.end, error: p.error }); + return { accepted: accepted.chunk ?? true }; + }, + }, + }; + } + const single = createFakeRpc(); + return { rpc: () => single, starts, chunks }; +} + +describe("createLlmInferenceAdapter", () => { + it("stages body chunks that arrive before their start frame and replays them in order", async () => { + const received: string[] = []; + let resolveDone: () => void; + const done = new Promise((r) => { + resolveDone = r; + }); + const provider: LlmInferenceProvider = { + async onLlmRequest(req: LlmInferenceRequest) { + const decoder = new TextDecoder(); + for await (const chunk of req.requestBody) { + received.push(decoder.decode(chunk)); + } + await req.responseBody.start({ status: 200, headers: {} }); + await req.responseBody.end(); + resolveDone(); + }, + }; + const fake = makeFakeServerRpc(); + const handler = createLlmInferenceAdapter(provider, () => fake.rpc() as never); + + // Chunks arrive BEFORE the start frame (simulating a reordering the + // runtime should never actually produce). They must be staged and + // delivered once the start frame registers the request. + await handler.httpRequestChunk({ + requestId: "r1", + data: "hello ", + binary: false, + end: false, + }); + await handler.httpRequestChunk({ + requestId: "r1", + data: "world", + binary: false, + end: false, + }); + await handler.httpRequestChunk({ requestId: "r1", data: "", end: true }); + + await handler.httpRequestStart({ + requestId: "r1", + method: "POST", + url: "https://example.test/v1/chat", + headers: {}, + transport: "http", + }); + + await done; + expect(received.join("")).toBe("hello world"); + }); + + it("aborts the provider when the runtime rejects a response frame (accepted=false)", async () => { + let aborted = false; + let writeThrew = false; + let finished: () => void; + const settled = new Promise((r) => { + finished = r; + }); + const provider: LlmInferenceProvider = { + async onLlmRequest(req: LlmInferenceRequest) { + req.signal.addEventListener("abort", () => { + aborted = true; + }); + for await (const _ of req.requestBody) { + // drain + } + await req.responseBody.start({ status: 200, headers: {} }); + try { + await req.responseBody.write("rejected-chunk"); + } catch { + writeThrew = true; + } + finished(); + }, + }; + const fake = makeFakeServerRpc({ start: true, chunk: false }); + const handler = createLlmInferenceAdapter(provider, () => fake.rpc() as never); + + await handler.httpRequestStart({ + requestId: "r2", + method: "POST", + url: "https://example.test/v1/chat", + headers: {}, + transport: "http", + }); + await handler.httpRequestChunk({ requestId: "r2", data: "", end: true }); + + await settled; + expect(writeThrew).toBe(true); + expect(aborted).toBe(true); + }); +}); + +/** + * Controllable fake of a callback-owned WebSocket connection. The test drives + * messages, close, and error explicitly. + */ +class FakeSocketHandler extends CopilotWebSocketHandler { + sent: (string | Uint8Array)[] = []; + + override sendRequestMessage(data: string | Uint8Array): void { + this.sent.push(data); + } + + async emitMessage(data: string | Uint8Array): Promise { + await this.sendResponseMessage(data); + } + + async closeFromUpstream(): Promise { + await this.close(); + } + + async failFromUpstream(error: Error): Promise { + await this.close(new LlmWebSocketCloseStatus(error.message, undefined, error)); + } +} + +interface RecordingSink extends LlmInferenceResponseSink { + starts: LlmInferenceResponseInit[]; + writes: (string | Uint8Array)[]; + ended: boolean; + errored?: { message: string; code?: string }; +} + +function makeRecordingSink(): RecordingSink { + const sink: RecordingSink = { + starts: [], + writes: [], + ended: false, + async start(init) { + sink.starts.push(init); + }, + async write(data) { + sink.writes.push(data); + }, + async end() { + sink.ended = true; + }, + async error(err) { + sink.errored = err; + }, + }; + return sink; +} + +/** Async-iterable request body that yields nothing until the test releases it. */ +function gatedRequestBody(): { body: AsyncIterable; release: () => void } { + let release!: () => void; + const gate = new Promise((r) => { + release = r; + }); + return { + release, + body: { + async *[Symbol.asyncIterator]() { + await gate; + }, + }, + }; +} + +describe("LlmRequestHandler WebSocket dispatch", () => { + it("finalises the response when the upstream closes while the request stream is still open", async () => { + let upstream!: FakeSocketHandler; + class Handler extends LlmRequestHandler { + protected override openWebSocket(ctx: LlmRequestContext): CopilotWebSocketHandler { + upstream = new FakeSocketHandler(ctx); + return upstream; + } + } + const handler = new Handler(); + const sink = makeRecordingSink(); + const gated = gatedRequestBody(); + const abort = new AbortController(); + const req: LlmInferenceRequest = { + requestId: "ws1", + method: "GET", + url: "wss://example.test/responses", + headers: {}, + transport: "websocket", + requestBody: gated.body, + signal: abort.signal, + responseBody: sink, + }; + + const turn = handler.onLlmRequest(req); + + // Let the handler register its listeners and ack the upgrade, then + // deliver an upstream message and close the socket — all while the + // request body is still parked (no runtime → upstream frames yet). + await new Promise((r) => setTimeout(r, 10)); + await upstream.emitMessage("server-event-1"); + await upstream.closeFromUpstream(); + + // The turn must resolve (not hang) because the upstream terminated. + await turn; + + expect(sink.starts).toEqual([{ status: 101, headers: {} }]); + expect(sink.writes).toContain("server-event-1"); + expect(sink.ended).toBe(true); + + gated.release(); + }); + + it("surfaces an upstream error as a thrown failure", async () => { + let upstream!: FakeSocketHandler; + class Handler extends LlmRequestHandler { + protected override openWebSocket(ctx: LlmRequestContext): CopilotWebSocketHandler { + upstream = new FakeSocketHandler(ctx); + return upstream; + } + } + const handler = new Handler(); + const sink = makeRecordingSink(); + const gated = gatedRequestBody(); + const abort = new AbortController(); + const req: LlmInferenceRequest = { + requestId: "ws2", + method: "GET", + url: "wss://example.test/responses", + headers: {}, + transport: "websocket", + requestBody: gated.body, + signal: abort.signal, + responseBody: sink, + }; + + const turn = handler.onLlmRequest(req); + await new Promise((r) => setTimeout(r, 10)); + await upstream.failFromUpstream(new Error("upstream exploded")); + + await expect(turn).rejects.toThrow("upstream exploded"); + expect(sink.ended).toBe(false); + + gated.release(); + }); +}); diff --git a/python/copilot/__init__.py b/python/copilot/__init__.py index 1bda91072..3c48f2440 100644 --- a/python/copilot/__init__.py +++ b/python/copilot/__init__.py @@ -148,6 +148,22 @@ SessionFsSqliteQueryResult, create_session_fs_adapter, ) +from .llm_inference_provider import ( + LlmInferenceConfig, + LlmInferenceHeaders, + LlmInferenceProvider, + LlmInferenceRequest, + LlmInferenceResponseInit, + LlmInferenceResponseSink, + create_llm_inference_adapter, +) +from .llm_request_handler import ( + CopilotWebSocketHandler, + ForwardingWebSocketHandler, + LlmRequestContext, + LlmRequestHandler, + LlmWebSocketCloseStatus, +) from .tools import ( Tool, ToolBinaryResult, @@ -186,6 +202,7 @@ "CopilotClient", "CopilotClientMode", "CopilotSession", + "CopilotWebSocketHandler", "CreateSessionFsHandler", "ElicitationContext", "ElicitationHandler", @@ -198,11 +215,21 @@ "ExitPlanModeRequest", "ExitPlanModeResult", "ExtensionInfo", + "ForwardingWebSocketHandler", "GetAuthStatusResponse", "GetStatusResponse", "InfiniteSessionConfig", "InputOptions", "LargeToolOutputConfig", + "LlmInferenceConfig", + "LlmInferenceHeaders", + "LlmInferenceProvider", + "LlmInferenceRequest", + "LlmInferenceResponseInit", + "LlmInferenceResponseSink", + "LlmRequestContext", + "LlmRequestHandler", + "LlmWebSocketCloseStatus", "LogLevel", "MCPHTTPServerConfig", "MCPServerConfig", @@ -297,6 +324,7 @@ "UserPromptSubmittedHookInput", "UserPromptSubmittedHookOutput", "convert_mcp_call_tool_result", + "create_llm_inference_adapter", "create_session_fs_adapter", "define_tool", ] diff --git a/python/copilot/client.py b/python/copilot/client.py index 2c407149c..f4a64719e 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -62,6 +62,7 @@ ExtensionInfo, ) from .generated.rpc import ( + ClientGlobalApiHandlers, ClientSessionApiHandlers, ModelBillingTokenPrices, ModelBillingTokenPricesLongContext, # noqa: F401 @@ -71,6 +72,7 @@ _ConnectRequest, _InternalServerRpc, from_datetime, + register_client_global_api_handlers, register_client_session_api_handlers, ) from .generated.session_events import ( @@ -106,6 +108,7 @@ _PermissionHandlerFn, ) from .session_fs_provider import SessionFsProvider, create_session_fs_adapter +from .llm_inference_provider import LlmInferenceConfig, create_llm_inference_adapter from .tools import Tool logger = logging.getLogger(__name__) @@ -352,6 +355,7 @@ class _CopilotClientOptions: use_logged_in_user: bool | None = None telemetry: TelemetryConfig | None = None session_fs: SessionFsConfig | None = None + llm_inference: LlmInferenceConfig | None = None session_idle_timeout_seconds: int | None = None enable_remote_sessions: bool = False on_list_models: Callable[[], list[ModelInfo] | Awaitable[list[ModelInfo]]] | None = None @@ -1049,6 +1053,7 @@ def __init__( use_logged_in_user: bool | None = None, telemetry: TelemetryConfig | None = None, session_fs: SessionFsConfig | None = None, + llm_inference: LlmInferenceConfig | None = None, session_idle_timeout_seconds: int | None = None, enable_remote_sessions: bool = False, on_list_models: Callable[[], list[ModelInfo] | Awaitable[list[ModelInfo]]] | None = None, @@ -1083,6 +1088,10 @@ def __init__( telemetry. session_fs: Connection-level session filesystem provider configuration. + llm_inference: Connection-level LLM inference callback + configuration. When set, the supplied handler services every + model-layer HTTP/WebSocket request the runtime would otherwise + issue (both BYOK and CAPI). session_idle_timeout_seconds: Server-wide session idle timeout in seconds. Sessions without activity for this duration are automatically cleaned up. Set to ``None`` or ``0`` to disable. @@ -1119,6 +1128,7 @@ def __init__( use_logged_in_user=use_logged_in_user, telemetry=telemetry, session_fs=session_fs, + llm_inference=llm_inference, session_idle_timeout_seconds=session_idle_timeout_seconds, enable_remote_sessions=enable_remote_sessions, on_list_models=on_list_models, @@ -1209,6 +1219,7 @@ def __init__( if options.session_fs is not None: _validate_session_fs_config(options.session_fs) self._session_fs_config = options.session_fs + self._llm_inference_config = options.llm_inference @property def rpc(self) -> ServerRpc: @@ -1361,6 +1372,9 @@ async def start(self) -> None: session_fs_start, ) + if self._llm_inference_config is not None: + await self._set_llm_inference_provider() + self._state = "connected" log_timing( logger, @@ -3532,6 +3546,7 @@ def handle_notification(method: str, params: dict): "systemMessage.transform", self._handle_system_message_transform ) register_client_session_api_handlers(self._client, self._get_client_session_handlers) + self._register_llm_inference_handlers() # Start listening for messages loop = asyncio.get_running_loop() @@ -3651,6 +3666,7 @@ def handle_notification(method: str, params: dict): "systemMessage.transform", self._handle_system_message_transform ) register_client_session_api_handlers(self._client, self._get_client_session_handlers) + self._register_llm_inference_handlers() # Start listening for messages loop = asyncio.get_running_loop() @@ -3723,6 +3739,22 @@ async def _set_session_fs_provider(self) -> None: await self._client.request("sessionFs.setProvider", params) + def _register_llm_inference_handlers(self) -> None: + if self._llm_inference_config is None or not self._client: + return + adapter = create_llm_inference_adapter( + self._llm_inference_config.handler, + lambda: self._rpc.llm_inference if self._rpc is not None else None, + ) + register_client_global_api_handlers( + self._client, ClientGlobalApiHandlers(llm_inference=adapter) + ) + + async def _set_llm_inference_provider(self) -> None: + if self._llm_inference_config is None or self._rpc is None: + return + await self._rpc.llm_inference.set_provider() + def _get_client_session_handlers(self, session_id: str) -> ClientSessionApiHandlers: with self._sessions_lock: session = self._sessions.get(session_id) diff --git a/python/copilot/generated/rpc.py b/python/copilot/generated/rpc.py index 6822fd7ff..e35d6a228 100644 --- a/python/copilot/generated/rpc.py +++ b/python/copilot/generated/rpc.py @@ -1864,6 +1864,212 @@ def to_dict(self) -> dict: result["projectPaths"] = from_union([lambda x: from_list(from_str, x), from_none], self.project_paths) return result +@dataclass +class LlmInferenceHTTPRequestChunkRequest: + """A request body chunk or cancellation signal.""" + + data: str + """Body byte range. UTF-8 text when `binary` is absent or false; base64-encoded bytes when + `binary` is true. May be empty. + """ + request_id: str + """Matches the requestId from the originating httpRequestStart frame.""" + + binary: bool | None = None + """When true, `data` is base64-encoded bytes. When absent or false, `data` is UTF-8 text.""" + + cancel: bool | None = None + """When true, the runtime is cancelling the in-flight request (e.g. upstream consumer + aborted). `data` is ignored. Implies end-of-request. + """ + cancel_reason: str | None = None + """Optional human-readable reason for the cancellation, propagated for logging.""" + + end: bool | None = None + """When true, this is the final body chunk for the request. The SDK may rely on having + received an end-marked chunk before treating the request body as complete. + """ + + @staticmethod + def from_dict(obj: Any) -> 'LlmInferenceHTTPRequestChunkRequest': + assert isinstance(obj, dict) + data = from_str(obj.get("data")) + request_id = from_str(obj.get("requestId")) + binary = from_union([from_bool, from_none], obj.get("binary")) + cancel = from_union([from_bool, from_none], obj.get("cancel")) + cancel_reason = from_union([from_str, from_none], obj.get("cancelReason")) + end = from_union([from_bool, from_none], obj.get("end")) + return LlmInferenceHTTPRequestChunkRequest(data, request_id, binary, cancel, cancel_reason, end) + + def to_dict(self) -> dict: + result: dict = {} + result["data"] = from_str(self.data) + result["requestId"] = from_str(self.request_id) + if self.binary is not None: + result["binary"] = from_union([from_bool, from_none], self.binary) + if self.cancel is not None: + result["cancel"] = from_union([from_bool, from_none], self.cancel) + if self.cancel_reason is not None: + result["cancelReason"] = from_union([from_str, from_none], self.cancel_reason) + if self.end is not None: + result["end"] = from_union([from_bool, from_none], self.end) + return result + +@dataclass +class LlmInferenceHTTPRequestChunkResult: + """Acknowledgement. The SDK is free to ignore the ack and treat chunk delivery as + fire-and-forget. + """ + @staticmethod + def from_dict(obj: Any) -> 'LlmInferenceHTTPRequestChunkResult': + assert isinstance(obj, dict) + return LlmInferenceHTTPRequestChunkResult() + + def to_dict(self) -> dict: + result: dict = {} + return result + +class LlmInferenceHTTPRequestStartTransport(Enum): + """Transport the runtime would otherwise use for this request. `http` (the default when + absent) covers plain HTTP and SSE responses; `websocket` indicates a full-duplex message + channel where each body chunk maps to one WebSocket message and the `binary` flag + distinguishes text from binary frames. The SDK consumer uses this to decide whether to + service the request with an HTTP client or a WebSocket client. It is the one piece of + request metadata the consumer cannot reliably infer from the URL or headers alone. + """ + HTTP = "http" + WEBSOCKET = "websocket" + +@dataclass +class LlmInferenceHTTPRequestStartResult: + """Acknowledgement. Returning successfully simply means the SDK accepted the start frame; it + does not imply the request will succeed. + """ + @staticmethod + def from_dict(obj: Any) -> 'LlmInferenceHTTPRequestStartResult': + assert isinstance(obj, dict) + return LlmInferenceHTTPRequestStartResult() + + def to_dict(self) -> dict: + result: dict = {} + return result + +# Experimental: this type is part of an experimental API and may change or be removed. +@dataclass +class LlmInferenceHTTPResponseChunkError: + """Set to terminate the response with a transport-level failure. Implies end-of-stream; any + further chunks for this requestId are ignored. + """ + message: str + """Human-readable failure description.""" + + code: str | None = None + """Optional machine-readable error code.""" + + @staticmethod + def from_dict(obj: Any) -> 'LlmInferenceHTTPResponseChunkError': + assert isinstance(obj, dict) + message = from_str(obj.get("message")) + code = from_union([from_str, from_none], obj.get("code")) + return LlmInferenceHTTPResponseChunkError(message, code) + + def to_dict(self) -> dict: + result: dict = {} + result["message"] = from_str(self.message) + if self.code is not None: + result["code"] = from_union([from_str, from_none], self.code) + return result + +# Experimental: this type is part of an experimental API and may change or be removed. +@dataclass +class LlmInferenceHTTPResponseChunkResult: + """Whether the chunk was accepted.""" + + accepted: bool + """True when the chunk was matched to a pending request; false when unknown.""" + + @staticmethod + def from_dict(obj: Any) -> 'LlmInferenceHTTPResponseChunkResult': + assert isinstance(obj, dict) + accepted = from_bool(obj.get("accepted")) + return LlmInferenceHTTPResponseChunkResult(accepted) + + def to_dict(self) -> dict: + result: dict = {} + result["accepted"] = from_bool(self.accepted) + return result + +# Experimental: this type is part of an experimental API and may change or be removed. +@dataclass +class LlmInferenceHTTPResponseStartRequest: + """Response head.""" + + headers: dict[str, list[str]] + request_id: str + """Matches the requestId from the originating httpRequestStart frame.""" + + status: int + """HTTP status code.""" + + status_text: str | None = None + """Optional HTTP status reason phrase.""" + + @staticmethod + def from_dict(obj: Any) -> 'LlmInferenceHTTPResponseStartRequest': + assert isinstance(obj, dict) + headers = from_dict(lambda x: from_list(from_str, x), obj.get("headers")) + request_id = from_str(obj.get("requestId")) + status = from_int(obj.get("status")) + status_text = from_union([from_str, from_none], obj.get("statusText")) + return LlmInferenceHTTPResponseStartRequest(headers, request_id, status, status_text) + + def to_dict(self) -> dict: + result: dict = {} + result["headers"] = from_dict(lambda x: from_list(from_str, x), self.headers) + result["requestId"] = from_str(self.request_id) + result["status"] = from_int(self.status) + if self.status_text is not None: + result["statusText"] = from_union([from_str, from_none], self.status_text) + return result + +# Experimental: this type is part of an experimental API and may change or be removed. +@dataclass +class LlmInferenceHTTPResponseStartResult: + """Whether the start frame was accepted.""" + + accepted: bool + """True when the response start was matched to a pending request; false when unknown.""" + + @staticmethod + def from_dict(obj: Any) -> 'LlmInferenceHTTPResponseStartResult': + assert isinstance(obj, dict) + accepted = from_bool(obj.get("accepted")) + return LlmInferenceHTTPResponseStartResult(accepted) + + def to_dict(self) -> dict: + result: dict = {} + result["accepted"] = from_bool(self.accepted) + return result + +# Experimental: this type is part of an experimental API and may change or be removed. +@dataclass +class LlmInferenceSetProviderResult: + """Indicates whether the calling client was registered as the LLM inference provider.""" + + success: bool + """Whether the provider was set successfully""" + + @staticmethod + def from_dict(obj: Any) -> 'LlmInferenceSetProviderResult': + assert isinstance(obj, dict) + success = from_bool(obj.get("success")) + return LlmInferenceSetProviderResult(success) + + def to_dict(self) -> dict: + result: dict = {} + result["success"] = from_bool(self.success) + return result + # Experimental: this type is part of an experimental API and may change or be removed. class HostType(Enum): """Repository host type @@ -3136,7 +3342,10 @@ class ModelBillingTokenPricesLongContext: """Long context tier pricing (available for models with extended context windows)""" cache_price: float | None = None - """AI Credits cost per billing batch of cached tokens""" + """AI Credits cost per billing batch of cache-read tokens""" + + cache_write_price: float | None = None + """AI Credits cost per billing batch of cache-write (cache creation) tokens.""" context_max: int | None = None """Prompt token budget (max_prompt_tokens) for the long context tier. The total context @@ -3152,15 +3361,18 @@ class ModelBillingTokenPricesLongContext: def from_dict(obj: Any) -> 'ModelBillingTokenPricesLongContext': assert isinstance(obj, dict) cache_price = from_union([from_float, from_none], obj.get("cachePrice")) + cache_write_price = from_union([from_float, from_none], obj.get("cacheWritePrice")) context_max = from_union([from_int, from_none], obj.get("contextMax")) input_price = from_union([from_float, from_none], obj.get("inputPrice")) output_price = from_union([from_float, from_none], obj.get("outputPrice")) - return ModelBillingTokenPricesLongContext(cache_price, context_max, input_price, output_price) + return ModelBillingTokenPricesLongContext(cache_price, cache_write_price, context_max, input_price, output_price) def to_dict(self) -> dict: result: dict = {} if self.cache_price is not None: result["cachePrice"] = from_union([to_float, from_none], self.cache_price) + if self.cache_write_price is not None: + result["cacheWritePrice"] = from_union([to_float, from_none], self.cache_write_price) if self.context_max is not None: result["contextMax"] = from_union([from_int, from_none], self.context_max) if self.input_price is not None: @@ -10002,6 +10214,107 @@ def to_dict(self) -> dict: result["projectPath"] = from_union([from_str, from_none], self.project_path) return result +@dataclass +class LlmInferenceHTTPRequestStartRequest: + """The head of an outbound model-layer HTTP request.""" + + headers: dict[str, list[str]] + method: str + """HTTP method, e.g. GET, POST.""" + + request_id: str + """Opaque runtime-minted id, unique per in-flight request. The SDK uses this to correlate + httpRequestChunk frames and to address its httpResponseStart / httpResponseChunk replies + back to the runtime. + """ + url: str + """Absolute request URL.""" + + session_id: str | None = None + """Id of the runtime session that triggered this request, when one is in scope. Absent for + requests issued outside any session (e.g. startup model-catalog or capability + resolution). This is a payload field — not a dispatch key — because the client-global API + is registered process-wide rather than per session. + """ + transport: LlmInferenceHTTPRequestStartTransport | None = None + """Transport the runtime would otherwise use for this request. `http` (the default when + absent) covers plain HTTP and SSE responses; `websocket` indicates a full-duplex message + channel where each body chunk maps to one WebSocket message and the `binary` flag + distinguishes text from binary frames. The SDK consumer uses this to decide whether to + service the request with an HTTP client or a WebSocket client. It is the one piece of + request metadata the consumer cannot reliably infer from the URL or headers alone. + """ + + @staticmethod + def from_dict(obj: Any) -> 'LlmInferenceHTTPRequestStartRequest': + assert isinstance(obj, dict) + headers = from_dict(lambda x: from_list(from_str, x), obj.get("headers")) + method = from_str(obj.get("method")) + request_id = from_str(obj.get("requestId")) + url = from_str(obj.get("url")) + session_id = from_union([from_str, from_none], obj.get("sessionId")) + transport = from_union([LlmInferenceHTTPRequestStartTransport, from_none], obj.get("transport")) + return LlmInferenceHTTPRequestStartRequest(headers, method, request_id, url, session_id, transport) + + def to_dict(self) -> dict: + result: dict = {} + result["headers"] = from_dict(lambda x: from_list(from_str, x), self.headers) + result["method"] = from_str(self.method) + result["requestId"] = from_str(self.request_id) + result["url"] = from_str(self.url) + if self.session_id is not None: + result["sessionId"] = from_union([from_str, from_none], self.session_id) + if self.transport is not None: + result["transport"] = from_union([lambda x: to_enum(LlmInferenceHTTPRequestStartTransport, x), from_none], self.transport) + return result + +# Experimental: this type is part of an experimental API and may change or be removed. +@dataclass +class LlmInferenceHTTPResponseChunkRequest: + """A response body chunk or terminal error.""" + + data: str + """Body byte range. UTF-8 text when `binary` is absent or false; base64-encoded bytes when + `binary` is true. May be empty (e.g. when the response body is empty: send a single chunk + with empty data and end=true). + """ + request_id: str + """Matches the requestId from the originating httpRequestStart frame.""" + + binary: bool | None = None + """When true, `data` is base64-encoded bytes. When absent or false, `data` is UTF-8 text.""" + + end: bool | None = None + """When true, this is the final body chunk for the response. The runtime treats the response + body as complete after receiving an end-marked chunk. + """ + error: LlmInferenceHTTPResponseChunkError | None = None + """Set to terminate the response with a transport-level failure. Implies end-of-stream; any + further chunks for this requestId are ignored. + """ + + @staticmethod + def from_dict(obj: Any) -> 'LlmInferenceHTTPResponseChunkRequest': + assert isinstance(obj, dict) + data = from_str(obj.get("data")) + request_id = from_str(obj.get("requestId")) + binary = from_union([from_bool, from_none], obj.get("binary")) + end = from_union([from_bool, from_none], obj.get("end")) + error = from_union([LlmInferenceHTTPResponseChunkError.from_dict, from_none], obj.get("error")) + return LlmInferenceHTTPResponseChunkRequest(data, request_id, binary, end, error) + + def to_dict(self) -> dict: + result: dict = {} + result["data"] = from_str(self.data) + result["requestId"] = from_str(self.request_id) + if self.binary is not None: + result["binary"] = from_union([from_bool, from_none], self.binary) + if self.end is not None: + result["end"] = from_union([from_bool, from_none], self.end) + if self.error is not None: + result["error"] = from_union([lambda x: to_class(LlmInferenceHTTPResponseChunkError, x), from_none], self.error) + return result + # Experimental: this type is part of an experimental API and may change or be removed. @dataclass class SessionContext: @@ -10961,7 +11274,10 @@ class ModelBillingTokenPrices: """Number of tokens per standard billing batch""" cache_price: float | None = None - """AI Credits cost per billing batch of cached tokens""" + """AI Credits cost per billing batch of cache-read tokens""" + + cache_write_price: float | None = None + """AI Credits cost per billing batch of cache-write (cache creation) tokens.""" context_max: int | None = None """Prompt token budget (max_prompt_tokens) for the default tier. The total context window is @@ -10981,11 +11297,12 @@ def from_dict(obj: Any) -> 'ModelBillingTokenPrices': assert isinstance(obj, dict) batch_size = from_union([from_int, from_none], obj.get("batchSize")) cache_price = from_union([from_float, from_none], obj.get("cachePrice")) + cache_write_price = from_union([from_float, from_none], obj.get("cacheWritePrice")) context_max = from_union([from_int, from_none], obj.get("contextMax")) input_price = from_union([from_float, from_none], obj.get("inputPrice")) long_context = from_union([ModelBillingTokenPricesLongContext.from_dict, from_none], obj.get("longContext")) output_price = from_union([from_float, from_none], obj.get("outputPrice")) - return ModelBillingTokenPrices(batch_size, cache_price, context_max, input_price, long_context, output_price) + return ModelBillingTokenPrices(batch_size, cache_price, cache_write_price, context_max, input_price, long_context, output_price) def to_dict(self) -> dict: result: dict = {} @@ -10993,6 +11310,8 @@ def to_dict(self) -> dict: result["batchSize"] = from_union([from_int, from_none], self.batch_size) if self.cache_price is not None: result["cachePrice"] = from_union([to_float, from_none], self.cache_price) + if self.cache_write_price is not None: + result["cacheWritePrice"] = from_union([to_float, from_none], self.cache_write_price) if self.context_max is not None: result["contextMax"] = from_union([from_int, from_none], self.context_max) if self.input_price is not None: @@ -13439,7 +13758,7 @@ class SessionFSSqliteQueryRequest: session_id: str """Target session identifier""" - params: dict[str, float | str | None] | None = None + params: dict[str, Any] | None = None """Optional named bind parameters""" @staticmethod @@ -13448,7 +13767,7 @@ def from_dict(obj: Any) -> 'SessionFSSqliteQueryRequest': query = from_str(obj.get("query")) query_type = SessionFSSqliteQueryType(obj.get("queryType")) session_id = from_str(obj.get("sessionId")) - params = from_union([lambda x: from_dict(lambda x: from_union([from_none, from_float, from_str], x), x), from_none], obj.get("params")) + params = from_union([lambda x: from_dict(lambda x: x, x), from_none], obj.get("params")) return SessionFSSqliteQueryRequest(query, query_type, session_id, params) def to_dict(self) -> dict: @@ -13457,7 +13776,7 @@ def to_dict(self) -> dict: result["queryType"] = to_enum(SessionFSSqliteQueryType, self.query_type) result["sessionId"] = from_str(self.session_id) if self.params is not None: - result["params"] = from_union([lambda x: from_dict(lambda x: from_union([from_none, to_float, from_str], x), x), from_none], self.params) + result["params"] = from_union([lambda x: from_dict(lambda x: x, x), from_none], self.params) return result # Experimental: this type is part of an experimental API and may change or be removed. @@ -16122,7 +16441,7 @@ def to_dict(self) -> dict: class OptionsUpdateAdditionalContentExclusionPolicy: """Schema for the `OptionsUpdateAdditionalContentExclusionPolicy` type.""" - last_updated_at: float | str + last_updated_at: Any rules: list[OptionsUpdateAdditionalContentExclusionPolicyRule] scope: AdditionalContentExclusionPolicyScope """Allowed values for the `OptionsUpdateAdditionalContentExclusionPolicyScope` enumeration.""" @@ -16130,14 +16449,14 @@ class OptionsUpdateAdditionalContentExclusionPolicy: @staticmethod def from_dict(obj: Any) -> 'OptionsUpdateAdditionalContentExclusionPolicy': assert isinstance(obj, dict) - last_updated_at = from_union([from_float, from_str], obj.get("last_updated_at")) + last_updated_at = obj.get("last_updated_at") rules = from_list(OptionsUpdateAdditionalContentExclusionPolicyRule.from_dict, obj.get("rules")) scope = AdditionalContentExclusionPolicyScope(obj.get("scope")) return OptionsUpdateAdditionalContentExclusionPolicy(last_updated_at, rules, scope) def to_dict(self) -> dict: result: dict = {} - result["last_updated_at"] = from_union([to_float, from_str], self.last_updated_at) + result["last_updated_at"] = self.last_updated_at result["rules"] = from_list(lambda x: to_class(OptionsUpdateAdditionalContentExclusionPolicyRule, x), self.rules) result["scope"] = to_enum(AdditionalContentExclusionPolicyScope, self.scope) return result @@ -16147,7 +16466,7 @@ def to_dict(self) -> dict: class PermissionsConfigureAdditionalContentExclusionPolicy: """Schema for the `PermissionsConfigureAdditionalContentExclusionPolicy` type.""" - last_updated_at: float | str + last_updated_at: Any rules: list[PermissionsConfigureAdditionalContentExclusionPolicyRule] scope: AdditionalContentExclusionPolicyScope """Allowed values for the `PermissionsConfigureAdditionalContentExclusionPolicyScope` @@ -16157,14 +16476,14 @@ class PermissionsConfigureAdditionalContentExclusionPolicy: @staticmethod def from_dict(obj: Any) -> 'PermissionsConfigureAdditionalContentExclusionPolicy': assert isinstance(obj, dict) - last_updated_at = from_union([from_float, from_str], obj.get("last_updated_at")) + last_updated_at = obj.get("last_updated_at") rules = from_list(PermissionsConfigureAdditionalContentExclusionPolicyRule.from_dict, obj.get("rules")) scope = AdditionalContentExclusionPolicyScope(obj.get("scope")) return PermissionsConfigureAdditionalContentExclusionPolicy(last_updated_at, rules, scope) def to_dict(self) -> dict: result: dict = {} - result["last_updated_at"] = from_union([to_float, from_str], self.last_updated_at) + result["last_updated_at"] = self.last_updated_at result["rules"] = from_list(lambda x: to_class(PermissionsConfigureAdditionalContentExclusionPolicyRule, x), self.rules) result["scope"] = to_enum(AdditionalContentExclusionPolicyScope, self.scope) return result @@ -16595,7 +16914,7 @@ def to_dict(self) -> dict: class SessionOpenOptionsAdditionalContentExclusionPolicy: """Schema for the `SessionOpenOptionsAdditionalContentExclusionPolicy` type.""" - last_updated_at: float | str + last_updated_at: Any rules: list[SessionOpenOptionsAdditionalContentExclusionPolicyRule] scope: AdditionalContentExclusionPolicyScope """Allowed values for the `SessionOpenOptionsAdditionalContentExclusionPolicyScope` @@ -16605,14 +16924,14 @@ class SessionOpenOptionsAdditionalContentExclusionPolicy: @staticmethod def from_dict(obj: Any) -> 'SessionOpenOptionsAdditionalContentExclusionPolicy': assert isinstance(obj, dict) - last_updated_at = from_union([from_float, from_str], obj.get("last_updated_at")) + last_updated_at = obj.get("last_updated_at") rules = from_list(SessionOpenOptionsAdditionalContentExclusionPolicyRule.from_dict, obj.get("rules")) scope = AdditionalContentExclusionPolicyScope(obj.get("scope")) return SessionOpenOptionsAdditionalContentExclusionPolicy(last_updated_at, rules, scope) def to_dict(self) -> dict: result: dict = {} - result["last_updated_at"] = from_union([to_float, from_str], self.last_updated_at) + result["last_updated_at"] = self.last_updated_at result["rules"] = from_list(lambda x: to_class(SessionOpenOptionsAdditionalContentExclusionPolicyRule, x), self.rules) result["scope"] = to_enum(AdditionalContentExclusionPolicyScope, self.scope) return result @@ -18240,7 +18559,7 @@ class SandboxConfig: add_current_working_directory: bool | None = None """Whether to auto-add the current working directory to readwritePaths. Default: true.""" - config: dict[str, Any] | None = None + config: Any = None """Raw `ContainerConfig` (per `@microsoft/mxc-sdk`) passed directly to `spawnSandboxFromConfig`, bypassing policy merging. """ @@ -18252,7 +18571,7 @@ def from_dict(obj: Any) -> 'SandboxConfig': assert isinstance(obj, dict) enabled = from_bool(obj.get("enabled")) add_current_working_directory = from_union([from_bool, from_none], obj.get("addCurrentWorkingDirectory")) - config = from_union([lambda x: from_dict(lambda x: x, x), from_none], obj.get("config")) + config = obj.get("config") user_policy = from_union([SandboxConfigUserPolicy.from_dict, from_none], obj.get("userPolicy")) return SandboxConfig(enabled, add_current_working_directory, config, user_policy) @@ -18262,7 +18581,7 @@ def to_dict(self) -> dict: if self.add_current_working_directory is not None: result["addCurrentWorkingDirectory"] = from_union([from_bool, from_none], self.add_current_working_directory) if self.config is not None: - result["config"] = from_union([lambda x: from_dict(lambda x: x, x), from_none], self.config) + result["config"] = self.config if self.user_policy is not None: result["userPolicy"] = from_union([lambda x: to_class(SandboxConfigUserPolicy, x), from_none], self.user_policy) return result @@ -19460,11 +19779,6 @@ def to_dict(self) -> dict: class MCPExecuteSamplingParams: """Identifiers and raw MCP CreateMessageRequest params used to run a sampling inference.""" - mcp_request_id: float | str - """The original MCP JSON-RPC request ID (string or number). Used by the runtime to correlate - the inference with the originating MCP request for telemetry; this is distinct from - `requestId` (which is the schema-level cancellation handle). - """ request: dict[str, Any] """Raw MCP CreateMessageRequest params, as received in the `sampling.requested` event. Treated as opaque at the schema layer; the runtime converts the embedded MCP messages @@ -19478,10 +19792,15 @@ class MCPExecuteSamplingParams: server_name: str """Name of the MCP server that initiated the sampling request""" + mcp_request_id: Any = None + """The original MCP JSON-RPC request ID (string or number). Used by the runtime to correlate + the inference with the originating MCP request for telemetry; this is distinct from + `requestId` (which is the schema-level cancellation handle). + """ @staticmethod def from_dict(obj: Any) -> 'MCPExecuteSamplingParams': assert isinstance(obj, dict) - mcp_request_id = from_union([from_float, from_str], obj.get("mcpRequestId")) + mcp_request_id = obj.get("mcpRequestId") request = from_dict(lambda x: x, obj.get("request")) request_id = from_str(obj.get("requestId")) server_name = from_str(obj.get("serverName")) @@ -19489,7 +19808,7 @@ def from_dict(obj: Any) -> 'MCPExecuteSamplingParams': def to_dict(self) -> dict: result: dict = {} - result["mcpRequestId"] = from_union([to_float, from_str], self.mcp_request_id) + result["mcpRequestId"] = self.mcp_request_id result["request"] = from_dict(lambda x: x, self.request) result["requestId"] = from_str(self.request_id) result["serverName"] = from_str(self.server_name) @@ -20264,6 +20583,18 @@ class RPC: instruction_source: InstructionSource instruction_source_location: InstructionLocation instruction_source_type: InstructionSourceType + llm_inference_headers: dict[str, list[str]] + llm_inference_http_request_chunk_request: LlmInferenceHTTPRequestChunkRequest + llm_inference_http_request_chunk_result: LlmInferenceHTTPRequestChunkResult + llm_inference_http_request_start_request: LlmInferenceHTTPRequestStartRequest + llm_inference_http_request_start_result: LlmInferenceHTTPRequestStartResult + llm_inference_http_request_start_transport: LlmInferenceHTTPRequestStartTransport + llm_inference_http_response_chunk_error: LlmInferenceHTTPResponseChunkError + llm_inference_http_response_chunk_request: LlmInferenceHTTPResponseChunkRequest + llm_inference_http_response_chunk_result: LlmInferenceHTTPResponseChunkResult + llm_inference_http_response_start_request: LlmInferenceHTTPResponseStartRequest + llm_inference_http_response_start_result: LlmInferenceHTTPResponseStartResult + llm_inference_set_provider_result: LlmInferenceSetProviderResult local_session_metadata_value: LocalSessionMetadataValue log_request: LogRequest log_result: LogResult @@ -21006,6 +21337,18 @@ def from_dict(obj: Any) -> 'RPC': instruction_source = InstructionSource.from_dict(obj.get("InstructionSource")) instruction_source_location = InstructionLocation(obj.get("InstructionSourceLocation")) instruction_source_type = InstructionSourceType(obj.get("InstructionSourceType")) + llm_inference_headers = from_dict(lambda x: from_list(from_str, x), obj.get("LlmInferenceHeaders")) + llm_inference_http_request_chunk_request = LlmInferenceHTTPRequestChunkRequest.from_dict(obj.get("LlmInferenceHttpRequestChunkRequest")) + llm_inference_http_request_chunk_result = LlmInferenceHTTPRequestChunkResult.from_dict(obj.get("LlmInferenceHttpRequestChunkResult")) + llm_inference_http_request_start_request = LlmInferenceHTTPRequestStartRequest.from_dict(obj.get("LlmInferenceHttpRequestStartRequest")) + llm_inference_http_request_start_result = LlmInferenceHTTPRequestStartResult.from_dict(obj.get("LlmInferenceHttpRequestStartResult")) + llm_inference_http_request_start_transport = LlmInferenceHTTPRequestStartTransport(obj.get("LlmInferenceHttpRequestStartTransport")) + llm_inference_http_response_chunk_error = LlmInferenceHTTPResponseChunkError.from_dict(obj.get("LlmInferenceHttpResponseChunkError")) + llm_inference_http_response_chunk_request = LlmInferenceHTTPResponseChunkRequest.from_dict(obj.get("LlmInferenceHttpResponseChunkRequest")) + llm_inference_http_response_chunk_result = LlmInferenceHTTPResponseChunkResult.from_dict(obj.get("LlmInferenceHttpResponseChunkResult")) + llm_inference_http_response_start_request = LlmInferenceHTTPResponseStartRequest.from_dict(obj.get("LlmInferenceHttpResponseStartRequest")) + llm_inference_http_response_start_result = LlmInferenceHTTPResponseStartResult.from_dict(obj.get("LlmInferenceHttpResponseStartResult")) + llm_inference_set_provider_result = LlmInferenceSetProviderResult.from_dict(obj.get("LlmInferenceSetProviderResult")) local_session_metadata_value = LocalSessionMetadataValue.from_dict(obj.get("LocalSessionMetadataValue")) log_request = LogRequest.from_dict(obj.get("LogRequest")) log_result = LogResult.from_dict(obj.get("LogResult")) @@ -21595,7 +21938,7 @@ def from_dict(obj: Any) -> 'RPC': subagent_settings = from_union([SubagentSettings.from_dict, from_none], obj.get("SubagentSettings")) task_progress = from_union([TaskProgress.from_dict, from_none], obj.get("TaskProgress")) workspace_summary = from_union([WorkspaceSummary.from_dict, from_none], obj.get("WorkspaceSummary")) - return RPC(abort_request, abort_result, account_get_quota_request, account_get_quota_result, account_quota_snapshot, agent_discovery_path, agent_discovery_path_list, agent_discovery_path_scope, agent_get_current_result, agent_info, agent_info_source, agent_list, agent_registry_live_target_entry, agent_registry_live_target_entry_attention_kind, agent_registry_live_target_entry_kind, agent_registry_live_target_entry_last_terminal_event, agent_registry_live_target_entry_status, agent_registry_log_capture, agent_registry_log_capture_open_error_reason, agent_registry_spawn_error, agent_registry_spawn_permission_mode, agent_registry_spawn_registry_timeout, agent_registry_spawn_request, agent_registry_spawn_result, agent_registry_spawn_spawned, agent_registry_spawn_validation_error, agent_registry_spawn_validation_error_field, agent_registry_spawn_validation_error_reason, agent_reload_result, agents_discover_request, agent_select_request, agent_select_result, agents_get_discovery_paths_request, allow_all_permission_set_result, allow_all_permission_state, api_key_auth_info, auth_info, auth_info_type, cancel_user_requested_shell_command_result, canvas_action, canvas_action_invoke_request, canvas_action_invoke_result, canvas_close_request, canvas_host_context, canvas_host_context_capabilities, canvas_instance_availability, canvas_json_schema, canvas_list, canvas_list_open_result, canvas_open_request, canvas_provider_close_request, canvas_provider_invoke_action_request, canvas_provider_open_request, canvas_provider_open_result, canvas_session_context, command_list, commands_handle_pending_command_request, commands_handle_pending_command_result, commands_invoke_request, commands_list_request, commands_respond_to_queued_command_request, commands_respond_to_queued_command_result, configure_session_extensions_params, connected_remote_session_metadata, connected_remote_session_metadata_kind, connected_remote_session_metadata_repository, connect_remote_session_params, connect_request, connect_result, content_filter_mode, copilot_api_token_auth_info, copilot_user_response, copilot_user_response_endpoints, copilot_user_response_quota_snapshots, copilot_user_response_quota_snapshots_chat, copilot_user_response_quota_snapshots_completions, copilot_user_response_quota_snapshots_premium_interactions, current_model, current_tool_metadata, discovered_canvas, discovered_mcp_server, discovered_mcp_server_type, enqueue_command_params, enqueue_command_result, env_auth_info, event_log_read_request, event_log_release_interest_result, event_log_tail_result, event_log_types, events_agent_scope, events_cursor_status, events_read_result, execute_command_params, execute_command_result, extension, extension_context_push_input, extension_list, extensions_disable_request, extensions_enable_request, extension_source, extension_status, external_tool_result, external_tool_text_result_for_llm, external_tool_text_result_for_llm_binary_results_for_llm, external_tool_text_result_for_llm_binary_results_for_llm_type, external_tool_text_result_for_llm_content, external_tool_text_result_for_llm_content_audio, external_tool_text_result_for_llm_content_image, external_tool_text_result_for_llm_content_resource, external_tool_text_result_for_llm_content_resource_details, external_tool_text_result_for_llm_content_resource_link, external_tool_text_result_for_llm_content_resource_link_icon, external_tool_text_result_for_llm_content_resource_link_icon_theme, external_tool_text_result_for_llm_content_terminal, external_tool_text_result_for_llm_content_text, filter_mapping, fleet_start_request, fleet_start_result, folder_trust_add_params, folder_trust_check_params, folder_trust_check_result, gh_cli_auth_info, handle_pending_tool_call_request, handle_pending_tool_call_result, history_abort_manual_compaction_result, history_cancel_background_compaction_result, history_compact_context_window, history_compact_request, history_compact_result, history_summarize_for_handoff_result, history_truncate_request, history_truncate_result, hmac_auth_info, installed_plugin, installed_plugin_info, installed_plugin_source, installed_plugin_source_git_hub, installed_plugin_source_local, installed_plugin_source_url, instruction_discovery_path, instruction_discovery_path_kind, instruction_discovery_path_list, instruction_discovery_path_location, instructions_discover_request, instructions_get_discovery_paths_request, instructions_get_sources_result, instruction_source, instruction_source_location, instruction_source_type, local_session_metadata_value, log_request, log_result, lsp_initialize_request, marketplace_add_result, marketplace_browse_result, marketplace_info, marketplace_list_result, marketplace_plugin_info, marketplace_refresh_entry, marketplace_refresh_result, marketplace_remove_result, mcp_allowed_server, mcp_apps_call_tool_request, mcp_apps_diagnose_capability, mcp_apps_diagnose_request, mcp_apps_diagnose_result, mcp_apps_diagnose_server, mcp_apps_host_context, mcp_apps_host_context_details, mcp_apps_host_context_details_available_display_mode, mcp_apps_host_context_details_display_mode, mcp_apps_host_context_details_platform, mcp_apps_host_context_details_theme, mcp_apps_list_tools_request, mcp_apps_list_tools_result, mcp_apps_read_resource_request, mcp_apps_read_resource_result, mcp_apps_resource_content, mcp_apps_set_host_context_details, mcp_apps_set_host_context_details_available_display_mode, mcp_apps_set_host_context_details_display_mode, mcp_apps_set_host_context_details_platform, mcp_apps_set_host_context_details_theme, mcp_apps_set_host_context_request, mcp_cancel_sampling_execution_params, mcp_cancel_sampling_execution_result, mcp_config_add_request, mcp_config_disable_request, mcp_config_enable_request, mcp_config_list, mcp_config_remove_request, mcp_config_update_request, mcp_configure_git_hub_request, mcp_configure_git_hub_result, mcp_disable_request, mcp_discover_request, mcp_discover_result, mcp_enable_request, mcp_execute_sampling_params, mcp_execute_sampling_request, mcp_execute_sampling_result, mcp_filtered_server, mcp_host_state, mcp_is_server_running_request, mcp_is_server_running_result, mcp_list_tools_request, mcp_list_tools_result, mcp_oauth_login_request, mcp_oauth_login_result, mcp_oauth_respond_request, mcp_oauth_respond_result, mcp_register_external_client_request, mcp_reload_with_config_request, mcp_remove_git_hub_result, mcp_restart_server_request, mcp_sampling_execution_action, mcp_sampling_execution_result, mcp_server, mcp_server_auth_config, mcp_server_auth_config_redirect_port, mcp_server_config, mcp_server_config_defer_tools, mcp_server_config_http, mcp_server_config_http_oauth_grant_type, mcp_server_config_http_type, mcp_server_config_stdio, mcp_server_failure_info, mcp_server_list, mcp_server_needs_auth_info, mcp_set_env_value_mode_details, mcp_set_env_value_mode_params, mcp_set_env_value_mode_result, mcp_start_server_request, mcp_start_servers_result, mcp_stop_server_request, mcp_tools, mcp_unregister_external_client_request, memory_configuration, metadata_context_info_request, metadata_context_info_result, metadata_is_processing_result, metadata_recompute_context_tokens_request, metadata_recompute_context_tokens_result, metadata_record_context_change_request, metadata_record_context_change_result, metadata_set_working_directory_request, metadata_set_working_directory_result, metadata_snapshot_current_mode, metadata_snapshot_remote_metadata, metadata_snapshot_remote_metadata_repository, metadata_snapshot_remote_metadata_task_type, model, model_billing, model_billing_token_prices, model_billing_token_prices_long_context, model_capabilities, model_capabilities_limits, model_capabilities_limits_vision, model_capabilities_override, model_capabilities_override_limits, model_capabilities_override_limits_vision, model_capabilities_override_supports, model_capabilities_supports, model_list, model_list_request, model_picker_category, model_picker_price_category, model_policy, model_policy_state, model_set_reasoning_effort_request, model_set_reasoning_effort_result, models_list_request, model_switch_to_request, model_switch_to_result, mode_set_request, named_provider_config, name_get_result, name_set_auto_request, name_set_auto_result, name_set_request, open_canvas_instance, options_update_additional_content_exclusion_policy, options_update_additional_content_exclusion_policy_rule, options_update_additional_content_exclusion_policy_rule_source, options_update_additional_content_exclusion_policy_scope, options_update_context_tier, options_update_env_value_mode, options_update_reasoning_summary, options_update_tool_filter_precedence, pending_permission_request, pending_permission_request_list, permission_decision, permission_decision_approved, permission_decision_approved_for_location, permission_decision_approved_for_session, permission_decision_approve_for_location, permission_decision_approve_for_location_approval, permission_decision_approve_for_location_approval_commands, permission_decision_approve_for_location_approval_custom_tool, permission_decision_approve_for_location_approval_extension_management, permission_decision_approve_for_location_approval_extension_permission_access, permission_decision_approve_for_location_approval_mcp, permission_decision_approve_for_location_approval_mcp_sampling, permission_decision_approve_for_location_approval_memory, permission_decision_approve_for_location_approval_read, permission_decision_approve_for_location_approval_write, permission_decision_approve_for_session, permission_decision_approve_for_session_approval, permission_decision_approve_for_session_approval_commands, permission_decision_approve_for_session_approval_custom_tool, permission_decision_approve_for_session_approval_extension_management, permission_decision_approve_for_session_approval_extension_permission_access, permission_decision_approve_for_session_approval_mcp, permission_decision_approve_for_session_approval_mcp_sampling, permission_decision_approve_for_session_approval_memory, permission_decision_approve_for_session_approval_read, permission_decision_approve_for_session_approval_write, permission_decision_approve_once, permission_decision_approve_permanently, permission_decision_cancelled, permission_decision_denied_by_content_exclusion_policy, permission_decision_denied_by_permission_request_hook, permission_decision_denied_by_rules, permission_decision_denied_interactively_by_user, permission_decision_denied_no_approval_rule_and_could_not_request_from_user, permission_decision_reject, permission_decision_request, permission_decision_user_not_available, permission_location_add_tool_approval_params, permission_location_apply_params, permission_location_apply_result, permission_location_resolve_params, permission_location_resolve_result, permission_location_type, permission_paths_add_params, permission_paths_allowed_check_params, permission_paths_allowed_check_result, permission_paths_config, permission_paths_list, permission_paths_update_primary_params, permission_paths_workspace_check_params, permission_paths_workspace_check_result, permission_prompt_shown_notification, permission_request_result, permission_rules_set, permissions_configure_additional_content_exclusion_policy, permissions_configure_additional_content_exclusion_policy_rule, permissions_configure_additional_content_exclusion_policy_rule_source, permissions_configure_additional_content_exclusion_policy_scope, permissions_configure_params, permissions_configure_result, permissions_folder_trust_add_trusted_result, permissions_get_allow_all_request, permissions_locations_add_tool_approval_details, permissions_locations_add_tool_approval_details_commands, permissions_locations_add_tool_approval_details_custom_tool, permissions_locations_add_tool_approval_details_extension_management, permissions_locations_add_tool_approval_details_extension_permission_access, permissions_locations_add_tool_approval_details_mcp, permissions_locations_add_tool_approval_details_mcp_sampling, permissions_locations_add_tool_approval_details_memory, permissions_locations_add_tool_approval_details_read, permissions_locations_add_tool_approval_details_write, permissions_locations_add_tool_approval_result, permissions_modify_rules_params, permissions_modify_rules_result, permissions_modify_rules_scope, permissions_notify_prompt_shown_result, permissions_paths_add_result, permissions_paths_list_request, permissions_paths_update_primary_result, permissions_pending_requests_request, permissions_reset_session_approvals_request, permissions_reset_session_approvals_result, permissions_set_allow_all_request, permissions_set_allow_all_source, permissions_set_approve_all_request, permissions_set_approve_all_result, permissions_set_approve_all_source, permissions_set_required_request, permissions_set_required_result, permissions_urls_set_unrestricted_mode_result, permission_urls_config, permission_urls_set_unrestricted_mode_params, ping_request, ping_result, plan_read_result, plan_read_sql_todos_result, plan_read_sql_todos_with_dependencies_result, plan_sql_todo_dependency, plan_sql_todos_row, plan_update_request, plugin, plugin_install_result, plugin_list, plugin_list_result, plugins_disable_request, plugins_enable_request, plugins_install_request, plugins_marketplaces_add_request, plugins_marketplaces_browse_request, plugins_marketplaces_refresh_request, plugins_marketplaces_remove_request, plugins_reload_request, plugins_uninstall_request, plugins_update_request, plugin_update_all_entry, plugin_update_all_result, plugin_update_result, poll_spawned_sessions_result, provider_config, provider_config_azure, provider_config_type, provider_config_wire_api, provider_endpoint, provider_endpoint_type, provider_endpoint_wire_api, provider_get_endpoint_request, provider_model_config, provider_session_token, push_attachment, push_attachment_blob, push_attachment_directory, push_attachment_file, push_attachment_file_line_range, push_attachment_git_hub_reference, push_attachment_git_hub_reference_type, push_attachment_selection, push_attachment_selection_details, push_attachment_selection_details_end, push_attachment_selection_details_start, queued_command_handled, queued_command_not_handled, queued_command_result, queue_pending_items, queue_pending_items_kind, queue_pending_items_result, queue_remove_most_recent_result, register_event_interest_params, register_event_interest_result, register_extension_tools_params, register_extension_tools_result, release_event_interest_params, remote_control_config, remote_control_config_existing_mc_session, remote_control_status, remote_control_status_active, remote_control_status_connecting, remote_control_status_error, remote_control_status_off, remote_control_status_result, remote_control_stop_result, remote_control_transfer_result, remote_enable_request, remote_enable_result, remote_notify_steerable_changed_request, remote_notify_steerable_changed_result, remote_session_connection_result, remote_session_metadata_repository, remote_session_metadata_task_type, remote_session_metadata_value, remote_session_mode, remote_session_repository, sandbox_config, sandbox_config_user_policy, sandbox_config_user_policy_experimental, sandbox_config_user_policy_experimental_seatbelt, sandbox_config_user_policy_filesystem, sandbox_config_user_policy_network, schedule_entry, schedule_list, schedule_stop_request, schedule_stop_result, secrets_add_filter_values_request, secrets_add_filter_values_result, send_agent_mode, send_attachments_to_message_params, send_mode, send_request, send_result, server_agent_list, server_instruction_source_list, server_skill, server_skill_list, session_activity, session_auth_status, session_bulk_delete_result, session_capability, session_context, session_context_host_type, session_enrich_metadata_result, session_fs_append_file_request, session_fs_error, session_fs_error_code, session_fs_exists_request, session_fs_exists_result, session_fs_mkdir_request, session_fs_readdir_request, session_fs_readdir_result, session_fs_readdir_with_types_entry, session_fs_readdir_with_types_entry_type, session_fs_readdir_with_types_request, session_fs_readdir_with_types_result, session_fs_read_file_request, session_fs_read_file_result, session_fs_rename_request, session_fs_rm_request, session_fs_set_provider_capabilities, session_fs_set_provider_conventions, session_fs_set_provider_request, session_fs_set_provider_result, session_fs_sqlite_exists_request, session_fs_sqlite_exists_result, session_fs_sqlite_query_request, session_fs_sqlite_query_result, session_fs_sqlite_query_type, session_fs_stat_request, session_fs_stat_result, session_fs_write_file_request, session_installed_plugin, session_installed_plugin_source, session_installed_plugin_source_git_hub, session_installed_plugin_source_local, session_installed_plugin_source_url, session_list, session_list_entry, session_list_filter, session_load_deferred_repo_hooks_result, session_log_level, session_mcp_apps_call_tool_result, session_metadata_snapshot, session_mode, session_model_list, session_open_options, session_open_options_additional_content_exclusion_policy, session_open_options_additional_content_exclusion_policy_rule, session_open_options_additional_content_exclusion_policy_rule_source, session_open_options_additional_content_exclusion_policy_scope, session_open_options_env_value_mode, session_open_options_reasoning_summary, session_open_params, session_open_result, session_prune_result, sessions_bulk_delete_request, sessions_check_in_use_request, sessions_check_in_use_result, sessions_close_request, sessions_close_result, sessions_enrich_metadata_request, session_set_credentials_params, session_set_credentials_result, sessions_find_by_prefix_request, sessions_find_by_prefix_result, sessions_find_by_task_id_request, sessions_find_by_task_id_result, sessions_fork_request, sessions_fork_result, sessions_get_board_entry_count_request, sessions_get_board_entry_count_result, sessions_get_event_file_path_request, sessions_get_event_file_path_result, sessions_get_last_for_context_request, sessions_get_last_for_context_result, sessions_get_persisted_remote_steerable_request, sessions_get_persisted_remote_steerable_result, session_sizes, sessions_list_request, sessions_load_deferred_repo_hooks_request, sessions_open_attach, sessions_open_cloud, sessions_open_create, sessions_open_handoff, sessions_open_handoff_task_type, sessions_open_progress, sessions_open_progress_status, sessions_open_progress_step, sessions_open_remote, sessions_open_resume, sessions_open_resume_last, sessions_open_status, session_source, sessions_poll_spawned_sessions_event, sessions_poll_spawned_sessions_request, sessions_prune_old_request, sessions_register_extension_tools_on_session_options, sessions_release_lock_request, sessions_release_lock_result, sessions_reload_plugin_hooks_request, sessions_reload_plugin_hooks_result, sessions_save_request, sessions_save_result, sessions_set_additional_plugins_request, sessions_set_additional_plugins_result, sessions_set_remote_control_steering_request, sessions_start_remote_control_request, sessions_stop_remote_control_request, sessions_transfer_remote_control_request, session_telemetry_engagement, session_update_options_params, session_update_options_result, session_working_directory_context, session_working_directory_context_host_type, shell_cancel_user_requested_request, shell_exec_request, shell_exec_result, shell_execute_user_requested_request, shell_kill_request, shell_kill_result, shell_kill_signal, shutdown_request, skill, skill_discovery_path, skill_discovery_path_list, skill_discovery_scope, skill_list, skills_config_set_disabled_skills_request, skills_disable_request, skills_discover_request, skills_enable_request, skills_get_discovery_paths_request, skills_get_invoked_result, skills_invoked_skill, skills_load_diagnostics, slash_command_agent_prompt_result, slash_command_completed_result, slash_command_info, slash_command_input, slash_command_input_completion, slash_command_invocation_result, slash_command_kind, slash_command_select_subcommand_option, slash_command_select_subcommand_result, slash_command_text_result, subagent_settings_entry, subagent_settings_entry_context_tier, task_agent_info, task_agent_progress, task_execution_mode, task_info, task_list, task_progress_line, tasks_cancel_request, tasks_cancel_result, tasks_get_current_promotable_result, tasks_get_progress_request, tasks_get_progress_result, task_shell_info, task_shell_info_attachment_mode, task_shell_progress, tasks_promote_current_to_background_result, tasks_promote_to_background_request, tasks_promote_to_background_result, tasks_refresh_result, tasks_remove_request, tasks_remove_result, tasks_send_message_request, tasks_send_message_result, tasks_start_agent_request, tasks_start_agent_result, task_status, tasks_wait_for_pending_result, telemetry_set_feature_overrides_request, token_auth_info, tool, tool_list, tools_get_current_metadata_result, tools_initialize_and_validate_result, tools_list_request, tools_update_subagent_settings_result, ui_auto_mode_switch_response, ui_elicitation_array_any_of_field, ui_elicitation_array_any_of_field_items, ui_elicitation_array_any_of_field_items_any_of, ui_elicitation_array_enum_field, ui_elicitation_array_enum_field_items, ui_elicitation_field_value, ui_elicitation_request, ui_elicitation_response, ui_elicitation_response_action, ui_elicitation_response_content, ui_elicitation_result, ui_elicitation_schema, ui_elicitation_schema_property, ui_elicitation_schema_property_boolean, ui_elicitation_schema_property_number, ui_elicitation_schema_property_number_type, ui_elicitation_schema_property_string, ui_elicitation_schema_property_string_format, ui_elicitation_string_enum_field, ui_elicitation_string_one_of_field, ui_elicitation_string_one_of_field_one_of, ui_ephemeral_query_request, ui_ephemeral_query_result, ui_exit_plan_mode_action, ui_exit_plan_mode_response, ui_handle_pending_auto_mode_switch_request, ui_handle_pending_elicitation_request, ui_handle_pending_exit_plan_mode_request, ui_handle_pending_result, ui_handle_pending_sampling_request, ui_handle_pending_sampling_response, ui_handle_pending_user_input_request, ui_register_direct_auto_mode_switch_handler_result, ui_unregister_direct_auto_mode_switch_handler_request, ui_unregister_direct_auto_mode_switch_handler_result, ui_user_input_response, update_subagent_settings_request, usage_get_metrics_result, usage_metrics_code_changes, usage_metrics_model_metric, usage_metrics_model_metric_requests, usage_metrics_model_metric_token_detail, usage_metrics_model_metric_usage, usage_metrics_token_detail, user_auth_info, user_requested_shell_command_result, workspace_diff_file_change, workspace_diff_file_change_type, workspace_diff_mode, workspace_diff_result, workspaces_checkpoints, workspaces_create_file_request, workspaces_diff_request, workspaces_get_workspace_result, workspaces_list_checkpoints_result, workspaces_list_files_result, workspaces_read_checkpoint_request, workspaces_read_checkpoint_result, workspaces_read_file_request, workspaces_read_file_result, workspaces_save_large_paste_request, workspaces_save_large_paste_result, workspace_summary_host_type, workspaces_workspace_details_host_type, session_context_info, subagent_settings, task_progress, workspace_summary) + return RPC(abort_request, abort_result, account_get_quota_request, account_get_quota_result, account_quota_snapshot, agent_discovery_path, agent_discovery_path_list, agent_discovery_path_scope, agent_get_current_result, agent_info, agent_info_source, agent_list, agent_registry_live_target_entry, agent_registry_live_target_entry_attention_kind, agent_registry_live_target_entry_kind, agent_registry_live_target_entry_last_terminal_event, agent_registry_live_target_entry_status, agent_registry_log_capture, agent_registry_log_capture_open_error_reason, agent_registry_spawn_error, agent_registry_spawn_permission_mode, agent_registry_spawn_registry_timeout, agent_registry_spawn_request, agent_registry_spawn_result, agent_registry_spawn_spawned, agent_registry_spawn_validation_error, agent_registry_spawn_validation_error_field, agent_registry_spawn_validation_error_reason, agent_reload_result, agents_discover_request, agent_select_request, agent_select_result, agents_get_discovery_paths_request, allow_all_permission_set_result, allow_all_permission_state, api_key_auth_info, auth_info, auth_info_type, cancel_user_requested_shell_command_result, canvas_action, canvas_action_invoke_request, canvas_action_invoke_result, canvas_close_request, canvas_host_context, canvas_host_context_capabilities, canvas_instance_availability, canvas_json_schema, canvas_list, canvas_list_open_result, canvas_open_request, canvas_provider_close_request, canvas_provider_invoke_action_request, canvas_provider_open_request, canvas_provider_open_result, canvas_session_context, command_list, commands_handle_pending_command_request, commands_handle_pending_command_result, commands_invoke_request, commands_list_request, commands_respond_to_queued_command_request, commands_respond_to_queued_command_result, configure_session_extensions_params, connected_remote_session_metadata, connected_remote_session_metadata_kind, connected_remote_session_metadata_repository, connect_remote_session_params, connect_request, connect_result, content_filter_mode, copilot_api_token_auth_info, copilot_user_response, copilot_user_response_endpoints, copilot_user_response_quota_snapshots, copilot_user_response_quota_snapshots_chat, copilot_user_response_quota_snapshots_completions, copilot_user_response_quota_snapshots_premium_interactions, current_model, current_tool_metadata, discovered_canvas, discovered_mcp_server, discovered_mcp_server_type, enqueue_command_params, enqueue_command_result, env_auth_info, event_log_read_request, event_log_release_interest_result, event_log_tail_result, event_log_types, events_agent_scope, events_cursor_status, events_read_result, execute_command_params, execute_command_result, extension, extension_context_push_input, extension_list, extensions_disable_request, extensions_enable_request, extension_source, extension_status, external_tool_result, external_tool_text_result_for_llm, external_tool_text_result_for_llm_binary_results_for_llm, external_tool_text_result_for_llm_binary_results_for_llm_type, external_tool_text_result_for_llm_content, external_tool_text_result_for_llm_content_audio, external_tool_text_result_for_llm_content_image, external_tool_text_result_for_llm_content_resource, external_tool_text_result_for_llm_content_resource_details, external_tool_text_result_for_llm_content_resource_link, external_tool_text_result_for_llm_content_resource_link_icon, external_tool_text_result_for_llm_content_resource_link_icon_theme, external_tool_text_result_for_llm_content_terminal, external_tool_text_result_for_llm_content_text, filter_mapping, fleet_start_request, fleet_start_result, folder_trust_add_params, folder_trust_check_params, folder_trust_check_result, gh_cli_auth_info, handle_pending_tool_call_request, handle_pending_tool_call_result, history_abort_manual_compaction_result, history_cancel_background_compaction_result, history_compact_context_window, history_compact_request, history_compact_result, history_summarize_for_handoff_result, history_truncate_request, history_truncate_result, hmac_auth_info, installed_plugin, installed_plugin_info, installed_plugin_source, installed_plugin_source_git_hub, installed_plugin_source_local, installed_plugin_source_url, instruction_discovery_path, instruction_discovery_path_kind, instruction_discovery_path_list, instruction_discovery_path_location, instructions_discover_request, instructions_get_discovery_paths_request, instructions_get_sources_result, instruction_source, instruction_source_location, instruction_source_type, llm_inference_headers, llm_inference_http_request_chunk_request, llm_inference_http_request_chunk_result, llm_inference_http_request_start_request, llm_inference_http_request_start_result, llm_inference_http_request_start_transport, llm_inference_http_response_chunk_error, llm_inference_http_response_chunk_request, llm_inference_http_response_chunk_result, llm_inference_http_response_start_request, llm_inference_http_response_start_result, llm_inference_set_provider_result, local_session_metadata_value, log_request, log_result, lsp_initialize_request, marketplace_add_result, marketplace_browse_result, marketplace_info, marketplace_list_result, marketplace_plugin_info, marketplace_refresh_entry, marketplace_refresh_result, marketplace_remove_result, mcp_allowed_server, mcp_apps_call_tool_request, mcp_apps_diagnose_capability, mcp_apps_diagnose_request, mcp_apps_diagnose_result, mcp_apps_diagnose_server, mcp_apps_host_context, mcp_apps_host_context_details, mcp_apps_host_context_details_available_display_mode, mcp_apps_host_context_details_display_mode, mcp_apps_host_context_details_platform, mcp_apps_host_context_details_theme, mcp_apps_list_tools_request, mcp_apps_list_tools_result, mcp_apps_read_resource_request, mcp_apps_read_resource_result, mcp_apps_resource_content, mcp_apps_set_host_context_details, mcp_apps_set_host_context_details_available_display_mode, mcp_apps_set_host_context_details_display_mode, mcp_apps_set_host_context_details_platform, mcp_apps_set_host_context_details_theme, mcp_apps_set_host_context_request, mcp_cancel_sampling_execution_params, mcp_cancel_sampling_execution_result, mcp_config_add_request, mcp_config_disable_request, mcp_config_enable_request, mcp_config_list, mcp_config_remove_request, mcp_config_update_request, mcp_configure_git_hub_request, mcp_configure_git_hub_result, mcp_disable_request, mcp_discover_request, mcp_discover_result, mcp_enable_request, mcp_execute_sampling_params, mcp_execute_sampling_request, mcp_execute_sampling_result, mcp_filtered_server, mcp_host_state, mcp_is_server_running_request, mcp_is_server_running_result, mcp_list_tools_request, mcp_list_tools_result, mcp_oauth_login_request, mcp_oauth_login_result, mcp_oauth_respond_request, mcp_oauth_respond_result, mcp_register_external_client_request, mcp_reload_with_config_request, mcp_remove_git_hub_result, mcp_restart_server_request, mcp_sampling_execution_action, mcp_sampling_execution_result, mcp_server, mcp_server_auth_config, mcp_server_auth_config_redirect_port, mcp_server_config, mcp_server_config_defer_tools, mcp_server_config_http, mcp_server_config_http_oauth_grant_type, mcp_server_config_http_type, mcp_server_config_stdio, mcp_server_failure_info, mcp_server_list, mcp_server_needs_auth_info, mcp_set_env_value_mode_details, mcp_set_env_value_mode_params, mcp_set_env_value_mode_result, mcp_start_server_request, mcp_start_servers_result, mcp_stop_server_request, mcp_tools, mcp_unregister_external_client_request, memory_configuration, metadata_context_info_request, metadata_context_info_result, metadata_is_processing_result, metadata_recompute_context_tokens_request, metadata_recompute_context_tokens_result, metadata_record_context_change_request, metadata_record_context_change_result, metadata_set_working_directory_request, metadata_set_working_directory_result, metadata_snapshot_current_mode, metadata_snapshot_remote_metadata, metadata_snapshot_remote_metadata_repository, metadata_snapshot_remote_metadata_task_type, model, model_billing, model_billing_token_prices, model_billing_token_prices_long_context, model_capabilities, model_capabilities_limits, model_capabilities_limits_vision, model_capabilities_override, model_capabilities_override_limits, model_capabilities_override_limits_vision, model_capabilities_override_supports, model_capabilities_supports, model_list, model_list_request, model_picker_category, model_picker_price_category, model_policy, model_policy_state, model_set_reasoning_effort_request, model_set_reasoning_effort_result, models_list_request, model_switch_to_request, model_switch_to_result, mode_set_request, named_provider_config, name_get_result, name_set_auto_request, name_set_auto_result, name_set_request, open_canvas_instance, options_update_additional_content_exclusion_policy, options_update_additional_content_exclusion_policy_rule, options_update_additional_content_exclusion_policy_rule_source, options_update_additional_content_exclusion_policy_scope, options_update_context_tier, options_update_env_value_mode, options_update_reasoning_summary, options_update_tool_filter_precedence, pending_permission_request, pending_permission_request_list, permission_decision, permission_decision_approved, permission_decision_approved_for_location, permission_decision_approved_for_session, permission_decision_approve_for_location, permission_decision_approve_for_location_approval, permission_decision_approve_for_location_approval_commands, permission_decision_approve_for_location_approval_custom_tool, permission_decision_approve_for_location_approval_extension_management, permission_decision_approve_for_location_approval_extension_permission_access, permission_decision_approve_for_location_approval_mcp, permission_decision_approve_for_location_approval_mcp_sampling, permission_decision_approve_for_location_approval_memory, permission_decision_approve_for_location_approval_read, permission_decision_approve_for_location_approval_write, permission_decision_approve_for_session, permission_decision_approve_for_session_approval, permission_decision_approve_for_session_approval_commands, permission_decision_approve_for_session_approval_custom_tool, permission_decision_approve_for_session_approval_extension_management, permission_decision_approve_for_session_approval_extension_permission_access, permission_decision_approve_for_session_approval_mcp, permission_decision_approve_for_session_approval_mcp_sampling, permission_decision_approve_for_session_approval_memory, permission_decision_approve_for_session_approval_read, permission_decision_approve_for_session_approval_write, permission_decision_approve_once, permission_decision_approve_permanently, permission_decision_cancelled, permission_decision_denied_by_content_exclusion_policy, permission_decision_denied_by_permission_request_hook, permission_decision_denied_by_rules, permission_decision_denied_interactively_by_user, permission_decision_denied_no_approval_rule_and_could_not_request_from_user, permission_decision_reject, permission_decision_request, permission_decision_user_not_available, permission_location_add_tool_approval_params, permission_location_apply_params, permission_location_apply_result, permission_location_resolve_params, permission_location_resolve_result, permission_location_type, permission_paths_add_params, permission_paths_allowed_check_params, permission_paths_allowed_check_result, permission_paths_config, permission_paths_list, permission_paths_update_primary_params, permission_paths_workspace_check_params, permission_paths_workspace_check_result, permission_prompt_shown_notification, permission_request_result, permission_rules_set, permissions_configure_additional_content_exclusion_policy, permissions_configure_additional_content_exclusion_policy_rule, permissions_configure_additional_content_exclusion_policy_rule_source, permissions_configure_additional_content_exclusion_policy_scope, permissions_configure_params, permissions_configure_result, permissions_folder_trust_add_trusted_result, permissions_get_allow_all_request, permissions_locations_add_tool_approval_details, permissions_locations_add_tool_approval_details_commands, permissions_locations_add_tool_approval_details_custom_tool, permissions_locations_add_tool_approval_details_extension_management, permissions_locations_add_tool_approval_details_extension_permission_access, permissions_locations_add_tool_approval_details_mcp, permissions_locations_add_tool_approval_details_mcp_sampling, permissions_locations_add_tool_approval_details_memory, permissions_locations_add_tool_approval_details_read, permissions_locations_add_tool_approval_details_write, permissions_locations_add_tool_approval_result, permissions_modify_rules_params, permissions_modify_rules_result, permissions_modify_rules_scope, permissions_notify_prompt_shown_result, permissions_paths_add_result, permissions_paths_list_request, permissions_paths_update_primary_result, permissions_pending_requests_request, permissions_reset_session_approvals_request, permissions_reset_session_approvals_result, permissions_set_allow_all_request, permissions_set_allow_all_source, permissions_set_approve_all_request, permissions_set_approve_all_result, permissions_set_approve_all_source, permissions_set_required_request, permissions_set_required_result, permissions_urls_set_unrestricted_mode_result, permission_urls_config, permission_urls_set_unrestricted_mode_params, ping_request, ping_result, plan_read_result, plan_read_sql_todos_result, plan_read_sql_todos_with_dependencies_result, plan_sql_todo_dependency, plan_sql_todos_row, plan_update_request, plugin, plugin_install_result, plugin_list, plugin_list_result, plugins_disable_request, plugins_enable_request, plugins_install_request, plugins_marketplaces_add_request, plugins_marketplaces_browse_request, plugins_marketplaces_refresh_request, plugins_marketplaces_remove_request, plugins_reload_request, plugins_uninstall_request, plugins_update_request, plugin_update_all_entry, plugin_update_all_result, plugin_update_result, poll_spawned_sessions_result, provider_config, provider_config_azure, provider_config_type, provider_config_wire_api, provider_endpoint, provider_endpoint_type, provider_endpoint_wire_api, provider_get_endpoint_request, provider_model_config, provider_session_token, push_attachment, push_attachment_blob, push_attachment_directory, push_attachment_file, push_attachment_file_line_range, push_attachment_git_hub_reference, push_attachment_git_hub_reference_type, push_attachment_selection, push_attachment_selection_details, push_attachment_selection_details_end, push_attachment_selection_details_start, queued_command_handled, queued_command_not_handled, queued_command_result, queue_pending_items, queue_pending_items_kind, queue_pending_items_result, queue_remove_most_recent_result, register_event_interest_params, register_event_interest_result, register_extension_tools_params, register_extension_tools_result, release_event_interest_params, remote_control_config, remote_control_config_existing_mc_session, remote_control_status, remote_control_status_active, remote_control_status_connecting, remote_control_status_error, remote_control_status_off, remote_control_status_result, remote_control_stop_result, remote_control_transfer_result, remote_enable_request, remote_enable_result, remote_notify_steerable_changed_request, remote_notify_steerable_changed_result, remote_session_connection_result, remote_session_metadata_repository, remote_session_metadata_task_type, remote_session_metadata_value, remote_session_mode, remote_session_repository, sandbox_config, sandbox_config_user_policy, sandbox_config_user_policy_experimental, sandbox_config_user_policy_experimental_seatbelt, sandbox_config_user_policy_filesystem, sandbox_config_user_policy_network, schedule_entry, schedule_list, schedule_stop_request, schedule_stop_result, secrets_add_filter_values_request, secrets_add_filter_values_result, send_agent_mode, send_attachments_to_message_params, send_mode, send_request, send_result, server_agent_list, server_instruction_source_list, server_skill, server_skill_list, session_activity, session_auth_status, session_bulk_delete_result, session_capability, session_context, session_context_host_type, session_enrich_metadata_result, session_fs_append_file_request, session_fs_error, session_fs_error_code, session_fs_exists_request, session_fs_exists_result, session_fs_mkdir_request, session_fs_readdir_request, session_fs_readdir_result, session_fs_readdir_with_types_entry, session_fs_readdir_with_types_entry_type, session_fs_readdir_with_types_request, session_fs_readdir_with_types_result, session_fs_read_file_request, session_fs_read_file_result, session_fs_rename_request, session_fs_rm_request, session_fs_set_provider_capabilities, session_fs_set_provider_conventions, session_fs_set_provider_request, session_fs_set_provider_result, session_fs_sqlite_exists_request, session_fs_sqlite_exists_result, session_fs_sqlite_query_request, session_fs_sqlite_query_result, session_fs_sqlite_query_type, session_fs_stat_request, session_fs_stat_result, session_fs_write_file_request, session_installed_plugin, session_installed_plugin_source, session_installed_plugin_source_git_hub, session_installed_plugin_source_local, session_installed_plugin_source_url, session_list, session_list_entry, session_list_filter, session_load_deferred_repo_hooks_result, session_log_level, session_mcp_apps_call_tool_result, session_metadata_snapshot, session_mode, session_model_list, session_open_options, session_open_options_additional_content_exclusion_policy, session_open_options_additional_content_exclusion_policy_rule, session_open_options_additional_content_exclusion_policy_rule_source, session_open_options_additional_content_exclusion_policy_scope, session_open_options_env_value_mode, session_open_options_reasoning_summary, session_open_params, session_open_result, session_prune_result, sessions_bulk_delete_request, sessions_check_in_use_request, sessions_check_in_use_result, sessions_close_request, sessions_close_result, sessions_enrich_metadata_request, session_set_credentials_params, session_set_credentials_result, sessions_find_by_prefix_request, sessions_find_by_prefix_result, sessions_find_by_task_id_request, sessions_find_by_task_id_result, sessions_fork_request, sessions_fork_result, sessions_get_board_entry_count_request, sessions_get_board_entry_count_result, sessions_get_event_file_path_request, sessions_get_event_file_path_result, sessions_get_last_for_context_request, sessions_get_last_for_context_result, sessions_get_persisted_remote_steerable_request, sessions_get_persisted_remote_steerable_result, session_sizes, sessions_list_request, sessions_load_deferred_repo_hooks_request, sessions_open_attach, sessions_open_cloud, sessions_open_create, sessions_open_handoff, sessions_open_handoff_task_type, sessions_open_progress, sessions_open_progress_status, sessions_open_progress_step, sessions_open_remote, sessions_open_resume, sessions_open_resume_last, sessions_open_status, session_source, sessions_poll_spawned_sessions_event, sessions_poll_spawned_sessions_request, sessions_prune_old_request, sessions_register_extension_tools_on_session_options, sessions_release_lock_request, sessions_release_lock_result, sessions_reload_plugin_hooks_request, sessions_reload_plugin_hooks_result, sessions_save_request, sessions_save_result, sessions_set_additional_plugins_request, sessions_set_additional_plugins_result, sessions_set_remote_control_steering_request, sessions_start_remote_control_request, sessions_stop_remote_control_request, sessions_transfer_remote_control_request, session_telemetry_engagement, session_update_options_params, session_update_options_result, session_working_directory_context, session_working_directory_context_host_type, shell_cancel_user_requested_request, shell_exec_request, shell_exec_result, shell_execute_user_requested_request, shell_kill_request, shell_kill_result, shell_kill_signal, shutdown_request, skill, skill_discovery_path, skill_discovery_path_list, skill_discovery_scope, skill_list, skills_config_set_disabled_skills_request, skills_disable_request, skills_discover_request, skills_enable_request, skills_get_discovery_paths_request, skills_get_invoked_result, skills_invoked_skill, skills_load_diagnostics, slash_command_agent_prompt_result, slash_command_completed_result, slash_command_info, slash_command_input, slash_command_input_completion, slash_command_invocation_result, slash_command_kind, slash_command_select_subcommand_option, slash_command_select_subcommand_result, slash_command_text_result, subagent_settings_entry, subagent_settings_entry_context_tier, task_agent_info, task_agent_progress, task_execution_mode, task_info, task_list, task_progress_line, tasks_cancel_request, tasks_cancel_result, tasks_get_current_promotable_result, tasks_get_progress_request, tasks_get_progress_result, task_shell_info, task_shell_info_attachment_mode, task_shell_progress, tasks_promote_current_to_background_result, tasks_promote_to_background_request, tasks_promote_to_background_result, tasks_refresh_result, tasks_remove_request, tasks_remove_result, tasks_send_message_request, tasks_send_message_result, tasks_start_agent_request, tasks_start_agent_result, task_status, tasks_wait_for_pending_result, telemetry_set_feature_overrides_request, token_auth_info, tool, tool_list, tools_get_current_metadata_result, tools_initialize_and_validate_result, tools_list_request, tools_update_subagent_settings_result, ui_auto_mode_switch_response, ui_elicitation_array_any_of_field, ui_elicitation_array_any_of_field_items, ui_elicitation_array_any_of_field_items_any_of, ui_elicitation_array_enum_field, ui_elicitation_array_enum_field_items, ui_elicitation_field_value, ui_elicitation_request, ui_elicitation_response, ui_elicitation_response_action, ui_elicitation_response_content, ui_elicitation_result, ui_elicitation_schema, ui_elicitation_schema_property, ui_elicitation_schema_property_boolean, ui_elicitation_schema_property_number, ui_elicitation_schema_property_number_type, ui_elicitation_schema_property_string, ui_elicitation_schema_property_string_format, ui_elicitation_string_enum_field, ui_elicitation_string_one_of_field, ui_elicitation_string_one_of_field_one_of, ui_ephemeral_query_request, ui_ephemeral_query_result, ui_exit_plan_mode_action, ui_exit_plan_mode_response, ui_handle_pending_auto_mode_switch_request, ui_handle_pending_elicitation_request, ui_handle_pending_exit_plan_mode_request, ui_handle_pending_result, ui_handle_pending_sampling_request, ui_handle_pending_sampling_response, ui_handle_pending_user_input_request, ui_register_direct_auto_mode_switch_handler_result, ui_unregister_direct_auto_mode_switch_handler_request, ui_unregister_direct_auto_mode_switch_handler_result, ui_user_input_response, update_subagent_settings_request, usage_get_metrics_result, usage_metrics_code_changes, usage_metrics_model_metric, usage_metrics_model_metric_requests, usage_metrics_model_metric_token_detail, usage_metrics_model_metric_usage, usage_metrics_token_detail, user_auth_info, user_requested_shell_command_result, workspace_diff_file_change, workspace_diff_file_change_type, workspace_diff_mode, workspace_diff_result, workspaces_checkpoints, workspaces_create_file_request, workspaces_diff_request, workspaces_get_workspace_result, workspaces_list_checkpoints_result, workspaces_list_files_result, workspaces_read_checkpoint_request, workspaces_read_checkpoint_result, workspaces_read_file_request, workspaces_read_file_result, workspaces_save_large_paste_request, workspaces_save_large_paste_result, workspace_summary_host_type, workspaces_workspace_details_host_type, session_context_info, subagent_settings, task_progress, workspace_summary) def to_dict(self) -> dict: result: dict = {} @@ -21748,6 +22091,18 @@ def to_dict(self) -> dict: result["InstructionSource"] = to_class(InstructionSource, self.instruction_source) result["InstructionSourceLocation"] = to_enum(InstructionLocation, self.instruction_source_location) result["InstructionSourceType"] = to_enum(InstructionSourceType, self.instruction_source_type) + result["LlmInferenceHeaders"] = from_dict(lambda x: from_list(from_str, x), self.llm_inference_headers) + result["LlmInferenceHttpRequestChunkRequest"] = to_class(LlmInferenceHTTPRequestChunkRequest, self.llm_inference_http_request_chunk_request) + result["LlmInferenceHttpRequestChunkResult"] = to_class(LlmInferenceHTTPRequestChunkResult, self.llm_inference_http_request_chunk_result) + result["LlmInferenceHttpRequestStartRequest"] = to_class(LlmInferenceHTTPRequestStartRequest, self.llm_inference_http_request_start_request) + result["LlmInferenceHttpRequestStartResult"] = to_class(LlmInferenceHTTPRequestStartResult, self.llm_inference_http_request_start_result) + result["LlmInferenceHttpRequestStartTransport"] = to_enum(LlmInferenceHTTPRequestStartTransport, self.llm_inference_http_request_start_transport) + result["LlmInferenceHttpResponseChunkError"] = to_class(LlmInferenceHTTPResponseChunkError, self.llm_inference_http_response_chunk_error) + result["LlmInferenceHttpResponseChunkRequest"] = to_class(LlmInferenceHTTPResponseChunkRequest, self.llm_inference_http_response_chunk_request) + result["LlmInferenceHttpResponseChunkResult"] = to_class(LlmInferenceHTTPResponseChunkResult, self.llm_inference_http_response_chunk_result) + result["LlmInferenceHttpResponseStartRequest"] = to_class(LlmInferenceHTTPResponseStartRequest, self.llm_inference_http_response_start_request) + result["LlmInferenceHttpResponseStartResult"] = to_class(LlmInferenceHTTPResponseStartResult, self.llm_inference_http_response_start_result) + result["LlmInferenceSetProviderResult"] = to_class(LlmInferenceSetProviderResult, self.llm_inference_set_provider_result) result["LocalSessionMetadataValue"] = to_class(LocalSessionMetadataValue, self.local_session_metadata_value) result["LogRequest"] = to_class(LogRequest, self.log_request) result["LogResult"] = to_class(LogResult, self.log_result) @@ -22565,6 +22920,7 @@ def _load_TaskInfo(obj: Any) -> "TaskInfo": FilterMapping = dict InstructionDiscoveryPathLocation = InstructionLocation InstructionSourceLocation = InstructionLocation +LlmInferenceHeaders = dict McpAppsHostContextDetailsAvailableDisplayMode = MCPAppsDisplayMode McpAppsHostContextDetailsDisplayMode = MCPAppsDisplayMode McpAppsHostContextDetailsTheme = Theme @@ -22879,6 +23235,26 @@ async def set_provider(self, params: SessionFSSetProviderRequest, *, timeout: fl return SessionFSSetProviderResult.from_dict(await self._client.request("sessionFs.setProvider", params_dict, **_timeout_kwargs(timeout))) +# Experimental: this API group is experimental and may change or be removed. +class ServerLlmInferenceApi: + def __init__(self, client: "JsonRpcClient"): + self._client = client + + async def set_provider(self, *, timeout: float | None = None) -> LlmInferenceSetProviderResult: + "Registers an SDK client as the LLM inference callback provider.\n\nReturns:\n Indicates whether the calling client was registered as the LLM inference provider." + return LlmInferenceSetProviderResult.from_dict(await self._client.request("llmInference.setProvider", {}, **_timeout_kwargs(timeout))) + + async def http_response_start(self, params: LlmInferenceHTTPResponseStartRequest, *, timeout: float | None = None) -> LlmInferenceHTTPResponseStartResult: + "Delivers the response head (status + headers) for an in-flight request, correlated by the requestId the runtime supplied in httpRequestStart. Must be called exactly once per request before any httpResponseChunk frames.\n\nArgs:\n params: Response head.\n\nReturns:\n Whether the start frame was accepted." + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} + return LlmInferenceHTTPResponseStartResult.from_dict(await self._client.request("llmInference.httpResponseStart", params_dict, **_timeout_kwargs(timeout))) + + async def http_response_chunk(self, params: LlmInferenceHTTPResponseChunkRequest, *, timeout: float | None = None) -> LlmInferenceHTTPResponseChunkResult: + "Delivers a body byte range (or a terminal transport error) for an in-flight response, correlated by requestId. Set `end` true on the last chunk. When `error` is set the response terminates with a transport-level failure and the runtime raises an APIConnectionError.\n\nArgs:\n params: A response body chunk or terminal error.\n\nReturns:\n Whether the chunk was accepted." + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} + return LlmInferenceHTTPResponseChunkResult.from_dict(await self._client.request("llmInference.httpResponseChunk", params_dict, **_timeout_kwargs(timeout))) + + # Experimental: this API group is experimental and may change or be removed. class ServerSessionsApi: def __init__(self, client: "JsonRpcClient"): @@ -23025,6 +23401,7 @@ def __init__(self, client: "JsonRpcClient"): self.user = ServerUserApi(client) self.runtime = ServerRuntimeApi(client) self.session_fs = ServerSessionFsApi(client) + self.llm_inference = ServerLlmInferenceApi(client) self.sessions = ServerSessionsApi(client) self.agent_registry = ServerAgentRegistryApi(client) @@ -24455,6 +24832,44 @@ async def handle_canvas_action_invoke(params: dict) -> dict | None: return result.value if hasattr(result, 'value') else result client.set_request_handler("canvas.action.invoke", handle_canvas_action_invoke) +# Experimental: this API group is experimental and may change or be removed. +class LlmInferenceHandler(Protocol): + async def http_request_start(self, params: LlmInferenceHTTPRequestStartRequest) -> LlmInferenceHTTPRequestStartResult: + "Announces an outbound model-layer HTTP request the runtime wants the SDK client to service. Carries the request head only; the body always follows as one or more httpRequestChunk frames keyed by the same requestId, even when the body is empty (a single chunk with end=true).\n\nArgs:\n params: The head of an outbound model-layer HTTP request.\n\nReturns:\n Acknowledgement. Returning successfully simply means the SDK accepted the start frame; it does not imply the request will succeed." + pass + async def http_request_chunk(self, params: LlmInferenceHTTPRequestChunkRequest) -> LlmInferenceHTTPRequestChunkResult: + "Delivers a body byte range (or a cancellation signal) for a request previously announced via httpRequestStart, correlated by requestId. The runtime fires at least one chunk per request — when there is no body, a single chunk with empty data and end=true. Mid-stream the runtime may send a chunk with cancel=true to abort the request; the SDK then stops issuing httpResponseChunk frames and may emit a terminal httpResponseChunk with error set.\n\nArgs:\n params: A request body chunk or cancellation signal.\n\nReturns:\n Acknowledgement. The SDK is free to ignore the ack and treat chunk delivery as fire-and-forget." + pass + +@dataclass +class ClientGlobalApiHandlers: + llm_inference: LlmInferenceHandler | None = None + +def register_client_global_api_handlers( + client: "JsonRpcClient", + handlers: ClientGlobalApiHandlers, +) -> None: + """Register client-global request handlers on a JSON-RPC connection. + + Unlike client-session handlers these methods carry no implicit + session_id dispatch key; a single set of handlers serves the entire + connection. + """ + async def handle_llm_inference_http_request_start(params: dict) -> dict | None: + request = LlmInferenceHTTPRequestStartRequest.from_dict(params) + handler = handlers.llm_inference + if handler is None: raise RuntimeError("No llm_inference client-global handler registered") + result = await handler.http_request_start(request) + return result.to_dict() + client.set_request_handler("llmInference.httpRequestStart", handle_llm_inference_http_request_start) + async def handle_llm_inference_http_request_chunk(params: dict) -> dict | None: + request = LlmInferenceHTTPRequestChunkRequest.from_dict(params) + handler = handlers.llm_inference + if handler is None: raise RuntimeError("No llm_inference client-global handler registered") + result = await handler.http_request_chunk(request) + return result.to_dict() + client.set_request_handler("llmInference.httpRequestChunk", handle_llm_inference_http_request_chunk) + __all__ = [ "APIKeyAuthInfo", "APIKeyAuthInfoType", @@ -24524,6 +24939,7 @@ async def handle_canvas_action_invoke(params: dict) -> dict | None: "CanvasProviderOpenRequest", "CanvasProviderOpenResult", "CanvasSessionContext", + "ClientGlobalApiHandlers", "ClientSessionApiHandlers", "CommandList", "CommandsApi", @@ -24638,6 +25054,19 @@ async def handle_canvas_action_invoke(params: dict) -> dict | None: "InstructionsGetDiscoveryPathsRequest", "InstructionsGetSourcesResult", "KindEnum", + "LlmInferenceHTTPRequestChunkRequest", + "LlmInferenceHTTPRequestChunkResult", + "LlmInferenceHTTPRequestStartRequest", + "LlmInferenceHTTPRequestStartResult", + "LlmInferenceHTTPRequestStartTransport", + "LlmInferenceHTTPResponseChunkError", + "LlmInferenceHTTPResponseChunkRequest", + "LlmInferenceHTTPResponseChunkResult", + "LlmInferenceHTTPResponseStartRequest", + "LlmInferenceHTTPResponseStartResult", + "LlmInferenceHandler", + "LlmInferenceHeaders", + "LlmInferenceSetProviderResult", "LocalSessionMetadataValue", "LogRequest", "LogResult", @@ -25036,6 +25465,7 @@ async def handle_canvas_action_invoke(params: dict) -> dict | None: "ServerAgentsApi", "ServerInstructionSourceList", "ServerInstructionsApi", + "ServerLlmInferenceApi", "ServerMcpApi", "ServerMcpConfigApi", "ServerModelsApi", diff --git a/python/copilot/generated/session_events.py b/python/copilot/generated/session_events.py index 0a3c81816..237bc2795 100644 --- a/python/copilot/generated/session_events.py +++ b/python/copilot/generated/session_events.py @@ -996,30 +996,46 @@ def to_dict(self) -> dict: @dataclass class AttachmentBlob: "Blob attachment with inline base64-encoded data" - data: str mime_type: str type: ClassVar[str] = "blob" + asset_id: str | None = None + byte_length: int | None = None + data: str | None = None display_name: str | None = None + omitted_reason: OmittedBinaryOmittedReason | None = None @staticmethod def from_dict(obj: Any) -> "AttachmentBlob": assert isinstance(obj, dict) - data = from_str(obj.get("data")) mime_type = from_str(obj.get("mimeType")) + asset_id = from_union([from_none, from_str], obj.get("assetId")) + byte_length = from_union([from_none, from_int], obj.get("byteLength")) + data = from_union([from_none, from_str], obj.get("data")) display_name = from_union([from_none, from_str], obj.get("displayName")) + omitted_reason = from_union([from_none, lambda x: parse_enum(OmittedBinaryOmittedReason, x)], obj.get("omittedReason")) return AttachmentBlob( - data=data, mime_type=mime_type, + asset_id=asset_id, + byte_length=byte_length, + data=data, display_name=display_name, + omitted_reason=omitted_reason, ) def to_dict(self) -> dict: result: dict = {} - result["data"] = from_str(self.data) result["mimeType"] = from_str(self.mime_type) result["type"] = self.type + if self.asset_id is not None: + result["assetId"] = from_union([from_none, from_str], self.asset_id) + if self.byte_length is not None: + result["byteLength"] = from_union([from_none, to_int], self.byte_length) + if self.data is not None: + result["data"] = from_union([from_none, from_str], self.data) if self.display_name is not None: result["displayName"] = from_union([from_none, from_str], self.display_name) + if self.omitted_reason is not None: + result["omittedReason"] = from_union([from_none, lambda x: to_enum(OmittedBinaryOmittedReason, x)], self.omitted_reason) return result @@ -1098,18 +1114,30 @@ class AttachmentFile: display_name: str path: str type: ClassVar[str] = "file" + asset_id: str | None = None + byte_length: int | None = None line_range: AttachmentFileLineRange | None = None + mime_type: str | None = None + omitted_reason: OmittedBinaryOmittedReason | None = None @staticmethod def from_dict(obj: Any) -> "AttachmentFile": assert isinstance(obj, dict) display_name = from_str(obj.get("displayName")) path = from_str(obj.get("path")) + asset_id = from_union([from_none, from_str], obj.get("assetId")) + byte_length = from_union([from_none, from_int], obj.get("byteLength")) line_range = from_union([from_none, AttachmentFileLineRange.from_dict], obj.get("lineRange")) + mime_type = from_union([from_none, from_str], obj.get("mimeType")) + omitted_reason = from_union([from_none, lambda x: parse_enum(OmittedBinaryOmittedReason, x)], obj.get("omittedReason")) return AttachmentFile( display_name=display_name, path=path, + asset_id=asset_id, + byte_length=byte_length, line_range=line_range, + mime_type=mime_type, + omitted_reason=omitted_reason, ) def to_dict(self) -> dict: @@ -1117,8 +1145,16 @@ def to_dict(self) -> dict: result["displayName"] = from_str(self.display_name) result["path"] = from_str(self.path) result["type"] = self.type + if self.asset_id is not None: + result["assetId"] = from_union([from_none, from_str], self.asset_id) + if self.byte_length is not None: + result["byteLength"] = from_union([from_none, to_int], self.byte_length) if self.line_range is not None: result["lineRange"] = from_union([from_none, lambda x: to_class(AttachmentFileLineRange, x)], self.line_range) + if self.mime_type is not None: + result["mimeType"] = from_union([from_none, from_str], self.mime_type) + if self.omitted_reason is not None: + result["omittedReason"] = from_union([from_none, lambda x: to_enum(OmittedBinaryOmittedReason, x)], self.omitted_reason) return result @@ -1345,7 +1381,7 @@ class CanvasRegistryChangedCanvas: extension_id: str actions: list[CanvasRegistryChangedCanvasAction] | None = None extension_name: str | None = None - input_schema: dict[str, Any] | None = None + input_schema: Any = None @staticmethod def from_dict(obj: Any) -> "CanvasRegistryChangedCanvas": @@ -1356,7 +1392,7 @@ def from_dict(obj: Any) -> "CanvasRegistryChangedCanvas": extension_id = from_str(obj.get("extensionId")) actions = from_union([from_none, lambda x: from_list(CanvasRegistryChangedCanvasAction.from_dict, x)], obj.get("actions")) extension_name = from_union([from_none, from_str], obj.get("extensionName")) - input_schema = from_union([from_none, lambda x: from_dict(lambda x: x, x)], obj.get("inputSchema")) + input_schema = obj.get("inputSchema") return CanvasRegistryChangedCanvas( canvas_id=canvas_id, description=description, @@ -1378,7 +1414,7 @@ def to_dict(self) -> dict: if self.extension_name is not None: result["extensionName"] = from_union([from_none, from_str], self.extension_name) if self.input_schema is not None: - result["inputSchema"] = from_union([from_none, lambda x: from_dict(lambda x: x, x)], self.input_schema) + result["inputSchema"] = self.input_schema return result @@ -1387,14 +1423,14 @@ class CanvasRegistryChangedCanvasAction: "Schema for the `CanvasRegistryChangedCanvasAction` type." name: str description: str | None = None - input_schema: dict[str, Any] | None = None + input_schema: Any = None @staticmethod def from_dict(obj: Any) -> "CanvasRegistryChangedCanvasAction": assert isinstance(obj, dict) name = from_str(obj.get("name")) description = from_union([from_none, from_str], obj.get("description")) - input_schema = from_union([from_none, lambda x: from_dict(lambda x: x, x)], obj.get("inputSchema")) + input_schema = obj.get("inputSchema") return CanvasRegistryChangedCanvasAction( name=name, description=description, @@ -1407,7 +1443,7 @@ def to_dict(self) -> dict: if self.description is not None: result["description"] = from_union([from_none, from_str], self.description) if self.input_schema is not None: - result["inputSchema"] = from_union([from_none, lambda x: from_dict(lambda x: x, x)], self.input_schema) + result["inputSchema"] = self.input_schema return result @@ -2462,8 +2498,11 @@ class ModelCallFailureData: "Failed LLM API call metadata for telemetry" source: ModelCallFailureSource api_call_id: str | None = None + bad_request_kind: ModelCallFailureBadRequestKind | None = None duration: timedelta | None = None + error_code: str | None = None error_message: str | None = None + error_type: str | None = None initiator: str | None = None model: str | None = None provider_call_id: str | None = None @@ -2475,8 +2514,11 @@ def from_dict(obj: Any) -> "ModelCallFailureData": assert isinstance(obj, dict) source = parse_enum(ModelCallFailureSource, obj.get("source")) api_call_id = from_union([from_none, from_str], obj.get("apiCallId")) + bad_request_kind = from_union([from_none, lambda x: parse_enum(ModelCallFailureBadRequestKind, x)], obj.get("badRequestKind")) duration = from_union([from_none, from_timedelta], obj.get("durationMs")) + error_code = from_union([from_none, from_str], obj.get("errorCode")) error_message = from_union([from_none, from_str], obj.get("errorMessage")) + error_type = from_union([from_none, from_str], obj.get("errorType")) initiator = from_union([from_none, from_str], obj.get("initiator")) model = from_union([from_none, from_str], obj.get("model")) provider_call_id = from_union([from_none, from_str], obj.get("providerCallId")) @@ -2485,8 +2527,11 @@ def from_dict(obj: Any) -> "ModelCallFailureData": return ModelCallFailureData( source=source, api_call_id=api_call_id, + bad_request_kind=bad_request_kind, duration=duration, + error_code=error_code, error_message=error_message, + error_type=error_type, initiator=initiator, model=model, provider_call_id=provider_call_id, @@ -2499,10 +2544,16 @@ def to_dict(self) -> dict: result["source"] = to_enum(ModelCallFailureSource, self.source) if self.api_call_id is not None: result["apiCallId"] = from_union([from_none, from_str], self.api_call_id) + if self.bad_request_kind is not None: + result["badRequestKind"] = from_union([from_none, lambda x: to_enum(ModelCallFailureBadRequestKind, x)], self.bad_request_kind) if self.duration is not None: result["durationMs"] = from_union([from_none, to_timedelta_int], self.duration) + if self.error_code is not None: + result["errorCode"] = from_union([from_none, from_str], self.error_code) if self.error_message is not None: result["errorMessage"] = from_union([from_none, from_str], self.error_message) + if self.error_type is not None: + result["errorType"] = from_union([from_none, from_str], self.error_type) if self.initiator is not None: result["initiator"] = from_union([from_none, from_str], self.initiator) if self.model is not None: @@ -2940,7 +2991,7 @@ class PermissionPromptRequestMcp: server_name: str tool_name: str tool_title: str - args: Any | None = None + args: Any = None tool_call_id: str | None = None @staticmethod @@ -2949,7 +3000,7 @@ def from_dict(obj: Any) -> "PermissionPromptRequestMcp": server_name = from_str(obj.get("serverName")) tool_name = from_str(obj.get("toolName")) tool_title = from_str(obj.get("toolTitle")) - args = from_union([from_none, lambda x: x], obj.get("args")) + args = obj.get("args") tool_call_id = from_union([from_none, from_str], obj.get("toolCallId")) return PermissionPromptRequestMcp( server_name=server_name, @@ -2966,7 +3017,7 @@ def to_dict(self) -> dict: result["toolName"] = from_str(self.tool_name) result["toolTitle"] = from_str(self.tool_title) if self.args is not None: - result["args"] = from_union([from_none, lambda x: x], self.args) + result["args"] = self.args if self.tool_call_id is not None: result["toolCallId"] = from_union([from_none, from_str], self.tool_call_id) return result @@ -7357,6 +7408,14 @@ class McpServerTransport(Enum): MEMORY = "memory" +class ModelCallFailureBadRequestKind(Enum): + "For HTTP 400 failures only: whether the response carried a structured CAPI error envelope (structured_error, a deterministic validation failure) or no error body (bodyless, the transient gateway/proxy signature). Absent for non-400 failures." + # The 400 response carried no error body (transient gateway/proxy signature). + BODYLESS = "bodyless" + # The 400 response carried a structured CAPI error envelope (deterministic validation failure). + STRUCTURED_ERROR = "structured_error" + + class ModelCallFailureSource(Enum): "Where the failed model call originated" # Model call from the top-level agent. @@ -7787,6 +7846,7 @@ def session_event_to_dict(x: SessionEvent) -> Any: "McpServerStatus", "McpServerTransport", "McpServersLoadedServer", + "ModelCallFailureBadRequestKind", "ModelCallFailureData", "ModelCallFailureSource", "OmittedBinaryOmittedReason", diff --git a/python/copilot/llm_inference_provider.py b/python/copilot/llm_inference_provider.py new file mode 100644 index 000000000..5e7af8310 --- /dev/null +++ b/python/copilot/llm_inference_provider.py @@ -0,0 +1,421 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# -------------------------------------------------------------------------------------------- + +"""Low-level LLM inference provider types and the RPC adapter. + +The SDK consumer implements :class:`LlmInferenceProvider` (usually by +subclassing the idiomatic :class:`~copilot.llm_request_handler.LlmRequestHandler`). +:func:`create_llm_inference_adapter` converts a provider into an object that +conforms to the generated :class:`~copilot.generated.rpc.LlmInferenceHandler` +protocol, wiring the inbound ``httpRequestStart`` / ``httpRequestChunk`` frames +into the provider and translating the provider's response writes back into +outbound ``httpResponseStart`` / ``httpResponseChunk`` RPCs. +""" + +from __future__ import annotations + +import asyncio +import base64 +from collections.abc import AsyncIterator, Awaitable, Callable +from dataclasses import dataclass, field +from typing import Protocol, runtime_checkable + +from .generated.rpc import ( + LlmInferenceHTTPRequestChunkRequest, + LlmInferenceHTTPRequestChunkResult, + LlmInferenceHTTPRequestStartRequest, + LlmInferenceHTTPRequestStartResult, + LlmInferenceHTTPResponseChunkError, + LlmInferenceHTTPResponseChunkRequest, + LlmInferenceHTTPResponseStartRequest, + ServerLlmInferenceApi, +) + +# Headers are multi-valued: a header name maps to a list of values. +LlmInferenceHeaders = dict[str, list[str]] + + +@dataclass +class LlmInferenceResponseInit: + """Response head passed to :meth:`LlmInferenceResponseSink.start`.""" + + status: int + status_text: str | None = None + headers: LlmInferenceHeaders | None = None + + +@runtime_checkable +class LlmInferenceResponseSink(Protocol): + """Sink the consumer writes the upstream response into. + + The state machine is strict: ``start`` once, then zero or more ``write`` + calls, finishing with exactly one of ``end`` or ``error``. Calling out of + order raises. + """ + + async def start(self, init: LlmInferenceResponseInit) -> None: + """Send the response head (status + headers) back to the runtime.""" + ... + + async def write(self, data: str | bytes) -> None: + """Send a body chunk. ``str`` is encoded as UTF-8; ``bytes`` is sent as binary.""" + ... + + async def end(self) -> None: + """Mark end-of-stream cleanly.""" + ... + + async def error(self, message: str, code: str | None = None) -> None: + """Mark end-of-stream with a transport-level failure.""" + ... + + +@dataclass +class LlmInferenceRequest: + """An outbound model-layer HTTP request the runtime is asking the SDK to handle. + + This is a low-level shape: URL / method / headers verbatim, body bytes + delivered as an async iterator, response delivered through + :attr:`response_body`. The runtime does not classify the request; consumers + that need a provider type or endpoint kind derive it from the URL / headers. + """ + + request_id: str + """Opaque runtime-minted id, stable across the request lifecycle.""" + + method: str + """HTTP method (``GET``, ``POST``, ...).""" + + url: str + """Absolute URL.""" + + headers: LlmInferenceHeaders + """HTTP request headers, multi-valued.""" + + transport: str + """``"http"`` (plain HTTP / SSE) or ``"websocket"`` (full-duplex channel).""" + + request_body: AsyncIterator[bytes] + """Request body bytes, yielded as they arrive. Empty bodies yield zero chunks.""" + + cancel_event: asyncio.Event + """Set when the runtime cancels this in-flight request. Pass it through to + your transport so the upstream call is torn down too. After it fires, writes + to :attr:`response_body` are ignored.""" + + response_body: LlmInferenceResponseSink + """Sink the consumer writes the upstream response into.""" + + session_id: str | None = None + """Id of the runtime session that triggered this request, when in scope. + Absent for out-of-session requests (e.g. the startup model catalog).""" + + +@runtime_checkable +class LlmInferenceProvider(Protocol): + """Interface for an LLM inference provider. + + The consumer implements :meth:`on_llm_request`. The same callback handles + both buffered and streaming responses; the consumer just calls + ``response_body.write`` zero or more times before ``end``. + """ + + async def on_llm_request(self, request: LlmInferenceRequest) -> None: + """Service a single outbound LLM HTTP request. + + The consumer must eventually call either ``response_body.end()`` or + ``response_body.error(...)``; failing to do so leaks runtime state. + Raising surfaces a transport-level failure to the runtime. + """ + ... + + +@dataclass +class LlmInferenceConfig: + """Connection-level LLM inference callback configuration. + + Passed as the ``llm_inference`` client option. The ``handler`` is registered + process-wide and invoked for every model-layer HTTP/WebSocket request the + runtime would otherwise issue, for both BYOK and CAPI traffic. + """ + + handler: LlmInferenceProvider + + + +@dataclass +class _BodyItem: + chunk: bytes | None = None + end: bool = False + cancel: bool = False + cancel_reason: str | None = None + + +class _BodyQueue: + """An async iterator of request-body byte chunks fed by the runtime.""" + + def __init__(self) -> None: + self._queue: asyncio.Queue[_BodyItem] = asyncio.Queue() + self._done = False + + def push(self, item: _BodyItem) -> None: + self._queue.put_nowait(item) + + def __aiter__(self) -> AsyncIterator[bytes]: + return self + + async def __anext__(self) -> bytes: + if self._done: + raise StopAsyncIteration + item = await self._queue.get() + if item.cancel: + self._done = True + reason = ( + f"Request cancelled by runtime: {item.cancel_reason}" + if item.cancel_reason + else "Request cancelled by runtime" + ) + raise RuntimeError(reason) + if item.end: + self._done = True + raise StopAsyncIteration + return item.chunk if item.chunk is not None else b"" + + +@dataclass +class _PendingState: + queue: _BodyQueue + cancel_event: asyncio.Event + started: bool = False + finished: bool = False + cancelled: bool = False + task: asyncio.Task[None] | None = field(default=None) + + +def _decode_chunk_data(data: str, binary: bool) -> bytes: + if binary: + return base64.b64decode(data) + return data.encode("utf-8") + + +class _RuntimeRejectedError(RuntimeError): + """Raised when the runtime drops an in-flight request (``accepted: False``).""" + + +def create_llm_inference_adapter( + provider: LlmInferenceProvider, + get_server_rpc: Callable[[], ServerLlmInferenceApi | None], +) -> "_LlmInferenceAdapter": + """Adapt an :class:`LlmInferenceProvider` into the generated handler shape. + + Maintains a per-``request_id`` state table: each ``http_request_start`` + allocates a body queue + response sink and fires ``provider.on_llm_request`` + in the background. Subsequent ``http_request_chunk`` frames are routed into + the queue. The sink translates ``start`` / ``write`` / ``end`` / ``error`` + calls into outbound ``httpResponseStart`` / ``httpResponseChunk`` RPCs. + + ``http_request_start`` returns immediately after registering state so the + runtime's RPC reply is not gated on the consumer's I/O. + """ + return _LlmInferenceAdapter(provider, get_server_rpc) + + +class _LlmInferenceAdapter: + def __init__( + self, + provider: LlmInferenceProvider, + get_server_rpc: Callable[[], ServerLlmInferenceApi | None], + ) -> None: + self._provider = provider + self._get_server_rpc = get_server_rpc + self._pending: dict[str, _PendingState] = {} + # Defense-in-depth backstop: chunks that arrive before their start frame + # (a reordering the runtime's single ordered dispatch should make + # impossible) are staged here and drained the moment the matching + # http_request_start registers state, so a body byte is never dropped. + self._staged: dict[str, list[LlmInferenceHTTPRequestChunkRequest]] = {} + + def _route_chunk(self, state: _PendingState, params: LlmInferenceHTTPRequestChunkRequest) -> None: + if params.cancel: + state.cancelled = True + state.cancel_event.set() + state.queue.push(_BodyItem(cancel=True, cancel_reason=params.cancel_reason)) + return + if params.data: + state.queue.push(_BodyItem(chunk=_decode_chunk_data(params.data, bool(params.binary)))) + if params.end: + state.queue.push(_BodyItem(end=True)) + + def _require_rpc(self) -> ServerLlmInferenceApi: + rpc = self._get_server_rpc() + if rpc is None: + raise RuntimeError("LLM inference response sink used after RPC connection closed.") + return rpc + + def _make_sink(self, request_id: str, state: _PendingState) -> LlmInferenceResponseSink: + adapter = self + + def reject() -> None: + # The runtime acknowledges every response frame with ``accepted``. + # ``accepted: False`` means it has dropped the request, so we abort + # the provider's upstream work and stop emitting. + if not state.cancelled: + state.cancelled = True + state.cancel_event.set() + state.finished = True + adapter._pending.pop(request_id, None) + raise _RuntimeRejectedError( + "LLM inference response was rejected by the runtime (request no longer active)." + ) + + class _Sink: + async def start(self, init: LlmInferenceResponseInit) -> None: + if state.started: + raise RuntimeError("LLM inference response sink.start() called twice.") + if state.finished: + raise RuntimeError("LLM inference response sink already finished.") + state.started = True + result = await adapter._require_rpc().http_response_start( + LlmInferenceHTTPResponseStartRequest( + headers=init.headers or {}, + request_id=request_id, + status=init.status, + status_text=init.status_text, + ) + ) + if not result.accepted: + reject() + + async def write(self, data: str | bytes) -> None: + if state.cancelled: + raise RuntimeError("LLM inference request was cancelled by the runtime.") + if not state.started: + raise RuntimeError("LLM inference response sink.write() called before start().") + if state.finished: + raise RuntimeError("LLM inference response sink.write() called after end()/error().") + is_binary = isinstance(data, bytes | bytearray) + payload = ( + base64.b64encode(bytes(data)).decode("ascii") + if is_binary + else str(data) + ) + result = await adapter._require_rpc().http_response_chunk( + LlmInferenceHTTPResponseChunkRequest( + data=payload, + request_id=request_id, + binary=is_binary or None, + end=False, + ) + ) + if not result.accepted: + reject() + + async def end(self) -> None: + if state.finished: + return + state.finished = True + adapter._pending.pop(request_id, None) + await adapter._require_rpc().http_response_chunk( + LlmInferenceHTTPResponseChunkRequest(data="", request_id=request_id, end=True) + ) + + async def error(self, message: str, code: str | None = None) -> None: + if state.finished: + return + state.finished = True + adapter._pending.pop(request_id, None) + await adapter._require_rpc().http_response_chunk( + LlmInferenceHTTPResponseChunkRequest( + data="", + request_id=request_id, + end=True, + error=LlmInferenceHTTPResponseChunkError(message=message, code=code), + ) + ) + + return _Sink() + + async def _fail_via_sink( + self, sink: LlmInferenceResponseSink, state: _PendingState, message: str + ) -> None: + if state.finished: + return + try: + if not state.started: + await sink.start(LlmInferenceResponseInit(status=502)) + await sink.error(message) + except Exception: + # Best-effort — the connection may already be dead. + pass + + async def _finish_cancelled(self, sink: LlmInferenceResponseSink, state: _PendingState) -> None: + if state.finished: + return + try: + if not state.started: + await sink.start(LlmInferenceResponseInit(status=499)) + await sink.error("Request cancelled by runtime", code="cancelled") + except Exception: + # Best-effort — the runtime already dropped the request on cancel. + pass + + async def _run_provider( + self, request: LlmInferenceRequest, sink: LlmInferenceResponseSink, state: _PendingState + ) -> None: + try: + await self._provider.on_llm_request(request) + if not state.finished: + await self._fail_via_sink( + sink, + state, + "LLM inference provider returned without finalising the response " + "(call response_body.end() or .error()).", + ) + except _RuntimeRejectedError: + # The runtime already dropped the request; nothing more to emit. + pass + except Exception as exc: + if state.cancelled or state.cancel_event.is_set(): + await self._finish_cancelled(sink, state) + return + await self._fail_via_sink(sink, state, str(exc)) + + async def http_request_start( + self, params: LlmInferenceHTTPRequestStartRequest + ) -> LlmInferenceHTTPRequestStartResult: + state = _PendingState(queue=_BodyQueue(), cancel_event=asyncio.Event()) + self._pending[params.request_id] = state + + staged = self._staged.pop(params.request_id, None) + if staged: + for chunk in staged: + self._route_chunk(state, chunk) + + sink = self._make_sink(params.request_id, state) + transport = ( + params.transport.value if params.transport is not None else "http" + ) + request = LlmInferenceRequest( + request_id=params.request_id, + session_id=params.session_id, + method=params.method, + url=params.url, + headers=params.headers, + transport=transport, + request_body=state.queue, + cancel_event=state.cancel_event, + response_body=sink, + ) + state.task = asyncio.create_task(self._run_provider(request, sink, state)) + return LlmInferenceHTTPRequestStartResult() + + async def http_request_chunk( + self, params: LlmInferenceHTTPRequestChunkRequest + ) -> LlmInferenceHTTPRequestChunkResult: + state = self._pending.get(params.request_id) + if state is None: + self._staged.setdefault(params.request_id, []).append(params) + return LlmInferenceHTTPRequestChunkResult() + self._route_chunk(state, params) + return LlmInferenceHTTPRequestChunkResult() diff --git a/python/copilot/llm_request_handler.py b/python/copilot/llm_request_handler.py new file mode 100644 index 000000000..775110ff3 --- /dev/null +++ b/python/copilot/llm_request_handler.py @@ -0,0 +1,415 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# -------------------------------------------------------------------------------------------- + +"""Idiomatic, httpx-based base class for servicing LLM inference requests. + +Most consumers subclass :class:`LlmRequestHandler` and override a single seam: + +* HTTP — override :meth:`LlmRequestHandler.send_request` to mutate the + :class:`httpx.Request`, post-process the :class:`httpx.Response`, or replace + the call entirely. The default forwards via a shared :class:`httpx.AsyncClient`. +* WebSocket — override :meth:`LlmRequestHandler.open_web_socket` to return a + per-connection :class:`CopilotWebSocketHandler`. The default opens a + transparent forwarding connection. + +Consumers who need full control can instead override +:meth:`LlmRequestHandler.on_llm_request` and drive the low-level +:class:`~copilot.llm_inference_provider.LlmInferenceRequest` directly. +""" + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterator +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from .llm_inference_provider import ( + LlmInferenceHeaders, + LlmInferenceProvider, + LlmInferenceRequest, + LlmInferenceResponseInit, + LlmInferenceResponseSink, +) + +if TYPE_CHECKING: + import httpx + + +# Hop-by-hop and length headers the transport recomputes; forwarding them +# verbatim corrupts the request. +_FORBIDDEN_REQUEST_HEADERS = frozenset( + { + "host", + "connection", + "content-length", + "transfer-encoding", + "keep-alive", + "upgrade", + "proxy-connection", + "te", + "trailer", + } +) + +_shared_http_client: "httpx.AsyncClient | None" = None + + +def _get_shared_http_client() -> "httpx.AsyncClient": + global _shared_http_client + if _shared_http_client is None: + import httpx + + _shared_http_client = httpx.AsyncClient(timeout=None, follow_redirects=False) + return _shared_http_client + + +@dataclass +class LlmRequestContext: + """Per-request context handed to every :class:`LlmRequestHandler` hook.""" + + request_id: str + transport: str + url: str + headers: LlmInferenceHeaders + cancel_event: asyncio.Event + session_id: str | None = None + _bridge: "_LlmWebSocketResponseBridge | None" = field(default=None, repr=False) + + +@dataclass +class LlmWebSocketCloseStatus: + """Terminal status for a callback-owned WebSocket connection.""" + + description: str | None = None + error_code: str | None = None + error: BaseException | None = None + + @classmethod + def normal_closure(cls) -> "LlmWebSocketCloseStatus": + return cls() + + +class CopilotWebSocketHandler: + """Per-connection WebSocket handler returned by :meth:`LlmRequestHandler.open_web_socket`. + + Subclass and override :meth:`send_request_message` (runtime → upstream) to + mutate, drop, or inject messages, and :meth:`send_response_message` + (upstream → runtime) for the reverse direction. A full transport + replacement overrides :meth:`open` to stand up its own connection and + receive loop. + """ + + def __init__(self, context: LlmRequestContext) -> None: + bridge = context._bridge + if bridge is None: + raise RuntimeError("WebSocket response bridge is not attached") + self.context = context + self._response = bridge + self._completion: asyncio.Future[LlmWebSocketCloseStatus] = ( + asyncio.get_event_loop().create_future() + ) + self._closed = False + self._suppress_close_on_dispose = False + + async def send_response_message(self, data: str | bytes) -> None: + """Forward an upstream message to the runtime response.""" + await self._response.write(data) + + async def send_request_message(self, data: str | bytes) -> None: + """Forward a runtime message to the upstream connection. Override to mutate.""" + raise NotImplementedError + + async def close(self, status: LlmWebSocketCloseStatus | None = None) -> None: + """Initiate close: end the runtime response and resolve completion.""" + if self._closed: + return + self._closed = True + status = status or LlmWebSocketCloseStatus.normal_closure() + if status.error is not None: + await self._response.error( + status.description or str(status.error), status.error_code + ) + else: + await self._response.end() + if not self._completion.done(): + self._completion.set_result(status) + + async def open(self) -> None: + """Establish the connection. Default is a no-op for custom transports.""" + + async def aclose(self) -> None: + """Final resource cleanup; closes normally if not already closed.""" + if not self._suppress_close_on_dispose and not self._closed: + await self.close(LlmWebSocketCloseStatus.normal_closure()) + + +class ForwardingWebSocketHandler(CopilotWebSocketHandler): + """Default pass-through WebSocket handler backed by the ``websockets`` library.""" + + def __init__(self, context: LlmRequestContext, url: str | None = None) -> None: + super().__init__(context) + self._url = url or context.url + self._upstream: Any | None = None + self._receive_task: asyncio.Task[None] | None = None + + async def send_request_message(self, data: str | bytes) -> None: + if self._upstream is None: + return + await self._upstream.send(data) + + async def open(self) -> None: + if self._upstream is not None: + return + try: + import websockets + except ImportError as exc: # pragma: no cover - optional dependency + raise RuntimeError( + "WebSocket forwarding requires the 'websockets' package. " + "Install it or override open_web_socket()." + ) from exc + + headers = [ + (name, value) + for name, values in self.context.headers.items() + if name.lower() not in _FORBIDDEN_REQUEST_HEADERS + for value in (values or []) + ] + self._upstream = await websockets.connect(self._url, additional_headers=headers) + self._receive_task = asyncio.create_task(self._receive_loop()) + + async def _receive_loop(self) -> None: + try: + async for message in self._upstream: # type: ignore[union-attr] + await self.send_response_message(message) + await self.close(LlmWebSocketCloseStatus.normal_closure()) + except asyncio.CancelledError: + raise + except Exception as exc: + await self.close(LlmWebSocketCloseStatus(description=str(exc), error=exc)) + + async def close(self, status: LlmWebSocketCloseStatus | None = None) -> None: + if self._upstream is not None: + try: + await self._upstream.close() + except Exception: + # Best-effort; the socket may already be closed. + pass + await super().close(status) + + async def aclose(self) -> None: + try: + await super().aclose() + finally: + if self._receive_task is not None: + self._receive_task.cancel() + if self._upstream is not None: + try: + await self._upstream.close() + except Exception: + pass + + +class LlmRequestHandler(LlmInferenceProvider): + """Base class for consumers that observe or replace LLM inference requests.""" + + async def on_llm_request(self, request: LlmInferenceRequest) -> None: + bridge = _LlmWebSocketResponseBridge(request.response_body) + ctx = LlmRequestContext( + request_id=request.request_id, + session_id=request.session_id, + transport=request.transport, + url=request.url, + headers=request.headers, + cancel_event=request.cancel_event, + _bridge=bridge, + ) + if request.transport == "websocket": + await self._handle_web_socket(request, ctx) + else: + await self._handle_http(request, ctx) + + async def send_request(self, request: "httpx.Request", ctx: LlmRequestContext) -> "httpx.Response": + """Send an HTTP request. Override to mutate request/response or replace the call.""" + return await _get_shared_http_client().send(request, stream=True) + + async def open_web_socket(self, ctx: LlmRequestContext) -> CopilotWebSocketHandler: + """Open a per-connection WebSocket handler. Override to mutate or replace.""" + return ForwardingWebSocketHandler(ctx) + + async def _handle_http(self, req: LlmInferenceRequest, ctx: LlmRequestContext) -> None: + request = await _build_httpx_request(req) + await _run_cancellable( + self._forward_http(request, req, ctx), req.cancel_event + ) + + async def _forward_http( + self, request: "httpx.Request", req: LlmInferenceRequest, ctx: LlmRequestContext + ) -> None: + response = await self.send_request(request, ctx) + try: + await _stream_response_to_sink(response, req) + finally: + await response.aclose() + + async def _handle_web_socket(self, req: LlmInferenceRequest, ctx: LlmRequestContext) -> None: + handler = await self.open_web_socket(ctx) + assert ctx._bridge is not None + try: + await handler.open() + await ctx._bridge.start() + + async def pump_client() -> str: + async for chunk in req.request_body: + await handler.send_request_message(_decode_frame(chunk)) + return "client-complete" + + client_task = asyncio.create_task(pump_client()) + completion = asyncio.ensure_future(handler._completion) + done, _ = await asyncio.wait( + {client_task, completion}, return_when=asyncio.FIRST_COMPLETED + ) + + if client_task in done and client_task.exception() is not None: + handler._suppress_close_on_dispose = True + raise client_task.exception() # type: ignore[misc] + + if client_task in done: + await handler.close(LlmWebSocketCloseStatus.normal_closure()) + await handler._completion + return + + status = await handler._completion + if status.error is not None: + raise status.error + finally: + await handler.aclose() + + +async def _run_cancellable(coro: Any, cancel_event: asyncio.Event) -> None: + """Run ``coro`` but abort it (and raise) when ``cancel_event`` fires.""" + task = asyncio.ensure_future(coro) + waiter = asyncio.ensure_future(cancel_event.wait()) + try: + done, _ = await asyncio.wait( + {task, waiter}, return_when=asyncio.FIRST_COMPLETED + ) + if task in done: + exc = task.exception() + if exc is not None: + raise exc + return + # Cancellation fired first. + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + raise RuntimeError("Request cancelled by runtime") + finally: + if not waiter.done(): + waiter.cancel() + + +async def _build_httpx_request(req: LlmInferenceRequest) -> "httpx.Request": + import httpx + + header_pairs = [ + (name, value) + for name, values in req.headers.items() + if name.lower() not in _FORBIDDEN_REQUEST_HEADERS + for value in (values or []) + ] + method = req.method.upper() + has_body = method not in ("GET", "HEAD") + body = await _drain_async(req.request_body) + content = body if (has_body and body) else None + return httpx.Request(method, req.url, headers=header_pairs, content=content) + + +async def _drain_async(stream: AsyncIterator[bytes]) -> bytes: + parts: list[bytes] = [] + async for chunk in stream: + if chunk: + parts.append(chunk) + return b"".join(parts) + + +async def _stream_response_to_sink(response: "httpx.Response", req: LlmInferenceRequest) -> None: + await req.response_body.start( + LlmInferenceResponseInit( + status=response.status_code, + status_text=response.reason_phrase or None, + headers=_headers_to_multi_map(response.headers), + ) + ) + async for chunk in response.aiter_raw(): + if chunk: + await req.response_body.write(chunk) + await req.response_body.end() + + +def _headers_to_multi_map(headers: Any) -> LlmInferenceHeaders: + out: dict[str, list[str]] = {} + for name, value in headers.multi_items(): + out.setdefault(name, []).append(value) + return out + + +def _decode_frame(chunk: bytes) -> str: + return chunk.decode("utf-8", errors="replace") + + +class _LlmWebSocketResponseBridge: + """Serialises WebSocket response writes into the sink, buffering until start.""" + + def __init__(self, sink: LlmInferenceResponseSink) -> None: + self._sink = sink + self._pending: list[Any] = [] + self._started = False + self._completed = False + self._lock = asyncio.Lock() + + async def start(self) -> None: + async with self._lock: + if self._started: + return + self._started = True + await self._sink.start(LlmInferenceResponseInit(status=101, headers={})) + pending = self._pending + self._pending = [] + for action in pending: + await action() + + async def write(self, data: str | bytes) -> None: + async def action() -> None: + if not self._completed: + await self._sink.write(data) + + await self._enqueue_or_buffer(action) + + async def end(self) -> None: + async def action() -> None: + if self._completed: + return + self._completed = True + await self._sink.end() + + await self._enqueue_or_buffer(action) + + async def error(self, message: str, code: str | None = None) -> None: + async def action() -> None: + if self._completed: + return + self._completed = True + await self._sink.error(message, code) + + await self._enqueue_or_buffer(action) + + async def _enqueue_or_buffer(self, action: Any) -> None: + if not self._started: + self._pending.append(action) + return + async with self._lock: + await action() diff --git a/python/e2e/_llm_inference_helpers.py b/python/e2e/_llm_inference_helpers.py new file mode 100644 index 000000000..c19d5ba0f --- /dev/null +++ b/python/e2e/_llm_inference_helpers.py @@ -0,0 +1,320 @@ +"""Shared fixtures and synthetic-upstream helpers for the LLM inference +callback e2e tests. + +The ``llm_inference*`` tests have no recorded snapshots: the registered +callback fabricates well-formed model responses and the runtime routes all of +its model-layer HTTP/WebSocket traffic through that callback instead of the +CAPI proxy. These helpers centralise the synthetic CAPI shapes (model catalog, +policy, ``/responses`` SSE, ``/chat/completions``) so each test file can focus +on the behaviour it is exercising. + +The leading underscore keeps pytest from collecting this module as a test. +""" + +from __future__ import annotations + +import json +import os +import re + +import pytest_asyncio + +from copilot import ( + CopilotClient, + LlmInferenceConfig, + LlmInferenceRequest, + LlmInferenceResponseInit, + LlmRequestHandler, + RuntimeConnection, +) +from copilot.generated.session_events import AssistantMessageData + +from .testharness import E2ETestContext + +SYNTHETIC_TEXT = "OK from the synthetic stream." + + +def sse(event: str, data: dict) -> str: + """Frame a single Server-Sent Events message: ``event:``/``data:`` + blank line.""" + return f"event: {event}\ndata: {json.dumps(data)}\n\n" + + +def stream_true(body_text: str) -> bool: + return re.search(r'"stream"\s*:\s*true', body_text) is not None + + +def is_inference_url(url: str) -> bool: + u = url.lower() + return ( + u.endswith("/chat/completions") + or u.endswith("/responses") + or u.endswith("/v1/messages") + or u.endswith("/messages") + ) + + +def model_catalog(supported_endpoints: list[str] | None = None) -> dict: + """The synthetic ``/models`` catalog payload. + + Passing ``supported_endpoints=["/responses", "ws:/responses"]`` lets the + runtime pick the WebSocket Responses transport (when the matching ExP flag + is enabled). + """ + model: dict = { + "id": "claude-sonnet-4.5", + "name": "Claude Sonnet 4.5", + "object": "model", + "vendor": "Anthropic", + "version": "1", + "preview": False, + "model_picker_enabled": True, + "capabilities": { + "type": "chat", + "family": "claude-sonnet-4.5", + "tokenizer": "o200k_base", + "limits": {"max_context_window_tokens": 200000, "max_output_tokens": 8192}, + "supports": { + "streaming": True, + "tool_calls": True, + "parallel_tool_calls": True, + "vision": True, + }, + }, + } + if supported_endpoints is not None: + model["supported_endpoints"] = supported_endpoints + return {"data": [model]} + + +def responses_events(text: str, resp_id: str = "resp_stub_1") -> list[dict]: + """The ordered ``/responses`` event objects the runtime's reducer expects. + + Used raw (one object == one WebSocket message) for the WS path and + SSE-framed for the HTTP path. + """ + return [ + { + "type": "response.created", + "response": {"id": resp_id, "object": "response", "status": "in_progress", "output": []}, + }, + { + "type": "response.output_item.added", + "output_index": 0, + "item": {"id": "msg_1", "type": "message", "role": "assistant", "content": []}, + }, + { + "type": "response.content_part.added", + "output_index": 0, + "content_index": 0, + "part": {"type": "output_text", "text": ""}, + }, + {"type": "response.output_text.delta", "output_index": 0, "content_index": 0, "delta": text}, + {"type": "response.output_text.done", "output_index": 0, "content_index": 0, "text": text}, + { + "type": "response.completed", + "response": { + "id": resp_id, + "object": "response", + "status": "completed", + "output": [ + { + "id": "msg_1", + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": text}], + } + ], + "usage": {"input_tokens": 5, "output_tokens": 7, "total_tokens": 12}, + }, + }, + ] + + +async def drain_request(req: LlmInferenceRequest) -> str: + parts: list[bytes] = [] + async for chunk in req.request_body: + parts.append(chunk) + return b"".join(parts).decode("utf-8") + + +async def respond_buffered( + req: LlmInferenceRequest, status: int, headers: dict[str, list[str]], body: str +) -> None: + await drain_request(req) + await req.response_body.start(LlmInferenceResponseInit(status=status, headers=headers)) + if body: + await req.response_body.write(body) + await req.response_body.end() + + +async def service_non_inference(req: LlmInferenceRequest) -> bool: + """Serve the model catalog, model session and policy endpoints. + + Returns ``True`` when the request was one of those (and has been answered), + ``False`` otherwise so the caller can decide how to handle it. + """ + url = req.url.lower() + if url.endswith("/models"): + await respond_buffered( + req, 200, {"content-type": ["application/json"]}, json.dumps(model_catalog()) + ) + return True + if "/models/session" in url: + await respond_buffered(req, 200, {}, "{}") + return True + if "/policy" in url: + await respond_buffered(req, 200, {}, json.dumps({"state": "enabled"})) + return True + return False + + +async def handle_non_inference_model_traffic( + req: LlmInferenceRequest, supported_endpoints: list[str] | None = None +) -> None: + """Serve every non-inference model-layer request, including an empty-JSON + fallback for anything unrecognised.""" + url = req.url.lower() + if url.endswith("/models"): + await respond_buffered( + req, + 200, + {"content-type": ["application/json"]}, + json.dumps(model_catalog(supported_endpoints)), + ) + return + if "/models/session" in url: + await respond_buffered(req, 200, {}, "{}") + return + if "/policy" in url: + await respond_buffered(req, 200, {}, json.dumps({"state": "enabled"})) + return + await respond_buffered(req, 200, {"content-type": ["application/json"]}, "{}") + + +async def handle_inference(req: LlmInferenceRequest, text: str = SYNTHETIC_TEXT) -> None: + """Synthesize a well-formed inference response. + + Dispatches by URL and the request body's ``stream`` flag: ``/responses`` + streams an SSE event sequence (or returns a buffered Responses object when + ``stream`` is false), ``/chat/completions`` streams chat-completion chunks + (or returns a buffered completion). The unified callback carries no field + telling the consumer which code path the runtime took, so it dispatches by + URL exactly as a real reverse proxy would. + """ + body_text = await drain_request(req) + wants_stream = stream_true(body_text) + url = req.url.lower() + + if "/responses" in url: + if not wants_stream: + await req.response_body.start( + LlmInferenceResponseInit(status=200, headers={"content-type": ["application/json"]}) + ) + await req.response_body.write(json.dumps(responses_events(text)[-1]["response"])) + await req.response_body.end() + return + await req.response_body.start( + LlmInferenceResponseInit(status=200, headers={"content-type": ["text/event-stream"]}) + ) + for event in responses_events(text): + await req.response_body.write(sse(event["type"], event)) + await req.response_body.end() + return + + if "/chat/completions" in url and wants_stream: + await req.response_body.start( + LlmInferenceResponseInit(status=200, headers={"content-type": ["text/event-stream"]}) + ) + base = { + "id": "chatcmpl-stub-1", + "object": "chat.completion.chunk", + "created": 1, + "model": "claude-sonnet-4.5", + } + chunks = [ + {**base, "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": None}]}, + {**base, "choices": [{"index": 0, "delta": {"content": text}, "finish_reason": None}]}, + { + **base, + "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12}, + }, + ] + for chunk in chunks: + await req.response_body.write("data: " + json.dumps(chunk) + "\n\n") + await req.response_body.write("data: [DONE]\n\n") + await req.response_body.end() + return + + await req.response_body.start( + LlmInferenceResponseInit(status=200, headers={"content-type": ["application/json"]}) + ) + await req.response_body.write( + json.dumps( + { + "id": "chatcmpl-stub-1", + "object": "chat.completion", + "created": 1, + "model": "claude-sonnet-4.5", + "choices": [ + {"index": 0, "message": {"role": "assistant", "content": text}, "finish_reason": "stop"} + ], + "usage": {"prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12}, + } + ) + ) + await req.response_body.end() + + +def assistant_text(event) -> str: + if event is not None and isinstance(event.data, AssistantMessageData): + return event.data.content + return "" + + +def build_isolated_client( + ctx: E2ETestContext, + handler: LlmRequestHandler, + extra_env: dict[str, str] | None = None, +) -> CopilotClient: + """Build a CopilotClient wired to ``handler`` via ``LlmInferenceConfig``. + + The shared ``ctx`` fixture's client has no inference callback, so each + inference test owns an isolated client carrying its own handler. + ``extra_env`` is merged into the spawned runtime's environment (e.g. to + flip an ExP flag for the WebSocket transport). + """ + github_token = ( + "fake-token-for-e2e-tests" if os.environ.get("GITHUB_ACTIONS") == "true" else None + ) + env = ctx.get_env() + if extra_env: + env = {**env, **extra_env} + return CopilotClient( + connection=RuntimeConnection.for_stdio(path=ctx.cli_path), + working_directory=ctx.work_dir, + env=env, + github_token=github_token, + llm_inference=LlmInferenceConfig(handler=handler), + ) + + +def isolated_client_fixture(make_handler, extra_env: dict[str, str] | None = None): + """Build a module-scoped pytest-asyncio fixture yielding ``(client, handler)``. + + ``make_handler`` is a zero-arg callable returning a fresh handler instance. + """ + + @pytest_asyncio.fixture(loop_scope="module") + async def _fixture(ctx: E2ETestContext): + handler = make_handler() + client = build_isolated_client(ctx, handler, extra_env) + try: + yield client, handler + finally: + try: + await client.stop() + except Exception: + pass + + return _fixture diff --git a/python/e2e/test_llm_inference_cancel_e2e.py b/python/e2e/test_llm_inference_cancel_e2e.py new file mode 100644 index 000000000..5a9c68310 --- /dev/null +++ b/python/e2e/test_llm_inference_cancel_e2e.py @@ -0,0 +1,86 @@ +"""E2E test for the runtime → consumer cancellation path. + +Mirrors ``nodejs/test/e2e/llm_inference_cancel.e2e.test.ts``. When an in-flight +turn is aborted via ``session.abort()``, the runtime cancels the +callback-served inference request; the consumer observes ``req.cancel_event`` +firing so it can tear down its upstream call. +""" + +from __future__ import annotations + +import asyncio + +import pytest + +from copilot import LlmInferenceRequest, LlmRequestHandler +from copilot.session import PermissionHandler + +from ._llm_inference_helpers import ( + drain_request, + is_inference_url, + isolated_client_fixture, + respond_buffered, + service_non_inference, +) +from .testharness import E2ETestContext # noqa: F401 (ctx fixture dependency) + +pytestmark = pytest.mark.asyncio(loop_scope="module") + + +async def _wait_for(predicate, timeout_s: float) -> None: + loop = asyncio.get_event_loop() + start = loop.time() + while not predicate(): + if loop.time() - start > timeout_s: + raise TimeoutError("wait_for timed out") + await asyncio.sleep(0.05) + + +class _CancellingHandler(LlmRequestHandler): + def __init__(self) -> None: + self.inference_entered = False + self.saw_abort = False + self.abort_seen = asyncio.Event() + + async def on_llm_request(self, req: LlmInferenceRequest) -> None: + if await service_non_inference(req): + return + if not is_inference_url(req.url): + await respond_buffered(req, 200, {"content-type": ["application/json"]}, "{}") + return + + # Inference: never produce a response. Wait for the runtime to cancel + # us, recording the abort. + await drain_request(req) + self.inference_entered = True + await req.cancel_event.wait() + self.saw_abort = True + self.abort_seen.set() + try: + await req.response_body.error("cancelled by upstream", code="cancelled") + except Exception: + # Runtime already dropped the request on cancel. + pass + + +cancel_client = isolated_client_fixture(_CancellingHandler) + + +class TestLlmInferenceCancel: + async def test_propagates_runtime_cancellation_to_consumer(self, cancel_client): + client, handler = cancel_client + await client.start() + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all + ) + try: + await session.send("Say OK.") + await _wait_for(lambda: handler.inference_entered, 60.0) + await session.abort() + await asyncio.wait_for(handler.abort_seen.wait(), timeout=30.0) + finally: + await session.disconnect() + + # The consumer observed the runtime-driven cancellation. + assert handler.inference_entered is True + assert handler.saw_abort is True diff --git a/python/e2e/test_llm_inference_consumer_cancel_e2e.py b/python/e2e/test_llm_inference_consumer_cancel_e2e.py new file mode 100644 index 000000000..8b5e2c167 --- /dev/null +++ b/python/e2e/test_llm_inference_consumer_cancel_e2e.py @@ -0,0 +1,71 @@ +"""E2E test for the consumer → runtime cancellation path. + +Mirrors ``nodejs/test/e2e/llm_inference_consumer_cancel.e2e.test.ts``. When the +consumer itself aborts the upstream call, it signals the runtime via +``response_body.error(code="cancelled")``. The runtime must surface that +faithfully as a request failure rather than hanging waiting for a response. +""" + +from __future__ import annotations + +import pytest + +from copilot import LlmInferenceRequest, LlmRequestHandler +from copilot.session import PermissionHandler + +from ._llm_inference_helpers import ( + drain_request, + is_inference_url, + isolated_client_fixture, + respond_buffered, + service_non_inference, +) +from .testharness import E2ETestContext # noqa: F401 (ctx fixture dependency) + +pytestmark = pytest.mark.asyncio(loop_scope="module") + + +class _ConsumerCancelHandler(LlmRequestHandler): + def __init__(self) -> None: + self.inference_attempts = 0 + + async def on_llm_request(self, req: LlmInferenceRequest) -> None: + if await service_non_inference(req): + return + if not is_inference_url(req.url): + await respond_buffered(req, 200, {"content-type": ["application/json"]}, "{}") + return + + # Consumer-initiated cancellation: the consumer's own upstream call was + # aborted, so it tells the runtime to give up on this request. No + # response head is ever produced; the runtime should see a transport + # failure rather than hanging. + await drain_request(req) + self.inference_attempts += 1 + await req.response_body.error("upstream call aborted by consumer", code="cancelled") + + +consumer_cancel_client = isolated_client_fixture(_ConsumerCancelHandler) + + +class TestLlmInferenceConsumerCancel: + async def test_surfaces_consumer_signalled_cancellation(self, consumer_cancel_client): + client, handler = consumer_cancel_client + await client.start() + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all + ) + + caught: BaseException | None = None + try: + await session.send_and_wait("Say OK.") + except BaseException as err: # noqa: BLE001 + caught = err + finally: + await session.disconnect() + + # The runtime reached the inference step and the consumer's + # cancellation terminated it (rather than the runtime hanging). + assert handler.inference_attempts > 0 + if caught is not None: + assert len(str(caught)) > 0 diff --git a/python/e2e/test_llm_inference_e2e.py b/python/e2e/test_llm_inference_e2e.py new file mode 100644 index 000000000..1a2b739a3 --- /dev/null +++ b/python/e2e/test_llm_inference_e2e.py @@ -0,0 +1,73 @@ +"""E2E tests for the LLM inference callback (basic round-trip). + +Mirrors ``nodejs/test/e2e/llm_inference.e2e.test.ts``. The handler fabricates +synthetic model responses, so the runtime routes its model-layer HTTP through +the SDK callback instead of the CAPI proxy. No recorded snapshot is needed. +""" + +from __future__ import annotations + +import pytest + +from copilot import LlmInferenceRequest, LlmRequestHandler +from copilot.session import PermissionHandler + +from ._llm_inference_helpers import ( + handle_non_inference_model_traffic, + isolated_client_fixture, +) +from .testharness import E2ETestContext # noqa: F401 (ctx fixture dependency) + +pytestmark = pytest.mark.asyncio(loop_scope="module") + + +class _RecordingHandler(LlmRequestHandler): + def __init__(self) -> None: + self.received: list[LlmInferenceRequest] = [] + + async def on_llm_request(self, req: LlmInferenceRequest) -> None: + self.received.append(req) + await handle_non_inference_model_traffic(req) + + +llm_client = isolated_client_fixture(_RecordingHandler) + + +class TestLlmInferenceCallback: + async def test_registers_the_provider_on_connect_without_erroring(self, llm_client): + client, _ = llm_client + await client.start() + assert client is not None + + async def test_invokes_callback_for_model_layer_requests_and_threads_session_id( + self, llm_client + ): + client, handler = llm_client + await client.start() + baseline = len(handler.received) + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all + ) + try: + # The buffered handler returns empty JSON for inference, which is + # not a valid model response; swallow the resulting transport error. + # What we assert is that the runtime *attempted* the callback. + try: + await session.send_and_wait("Say OK.") + except Exception: + pass + finally: + await session.disconnect() + + assert len(handler.received) > baseline + new_requests = handler.received[baseline:] + for r in new_requests: + assert r.url.startswith("http://") or r.url.startswith("https://") + assert isinstance(r.method, str) + + catalog = next((r for r in new_requests if r.url.lower().endswith("/models")), None) + assert catalog is not None, "expected to intercept the /models catalog request" + + in_session = next((r for r in new_requests if isinstance(r.session_id, str)), None) + if in_session is not None: + assert in_session.session_id diff --git a/python/e2e/test_llm_inference_errors_e2e.py b/python/e2e/test_llm_inference_errors_e2e.py new file mode 100644 index 000000000..63b5bfac6 --- /dev/null +++ b/python/e2e/test_llm_inference_errors_e2e.py @@ -0,0 +1,75 @@ +"""E2E test asserting callback-raised errors surface to the SDK consumer as +transport failures. + +Mirrors ``nodejs/test/e2e/llm_inference_errors.e2e.test.ts``. The handler +services the model catalog / session / policy normally so the agent reaches the +inference step, then raises from the inference callback. The adapter converts +that into a terminal ``http_response_chunk`` carrying ``error``, so the runtime +surfaces it through its existing error machinery rather than hanging. +""" + +from __future__ import annotations + +import pytest + +from copilot import LlmInferenceRequest, LlmRequestHandler +from copilot.session import PermissionHandler + +from ._llm_inference_helpers import ( + drain_request, + isolated_client_fixture, + respond_buffered, + service_non_inference, +) +from .testharness import E2ETestContext # noqa: F401 (ctx fixture dependency) + +pytestmark = pytest.mark.asyncio(loop_scope="module") + + +class _ThrowingHandler(LlmRequestHandler): + def __init__(self) -> None: + self.total_calls = 0 + self.calls_before_error = 0 + + async def on_llm_request(self, req: LlmInferenceRequest) -> None: + self.total_calls += 1 + url = req.url.lower() + + if await service_non_inference(req): + return + + if "/chat/completions" in url or "/responses" in url: + await drain_request(req) + self.calls_before_error += 1 + raise RuntimeError("synthetic-callback-transport-failure") + + await respond_buffered(req, 200, {"content-type": ["application/json"]}, "{}") + + +errors_client = isolated_client_fixture(_ThrowingHandler) + + +class TestLlmInferenceErrors: + async def test_surfaces_callback_thrown_error_to_consumer(self, errors_client): + client, handler = errors_client + await client.start() + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all + ) + + caught: BaseException | None = None + try: + await session.send_and_wait("Say OK.") + except BaseException as err: # noqa: BLE001 + caught = err + finally: + await session.disconnect() + + # The agent layer typically wraps inference failures in its own error + # type and may convert them to an event rather than a thrown exception, + # so the assertion is loose: the inference call was attempted at least + # once and the runtime did NOT hang. + assert handler.total_calls > 0 + assert handler.calls_before_error > 0 + if caught is not None: + assert len(str(caught)) > 0 diff --git a/python/e2e/test_llm_inference_handler_e2e.py b/python/e2e/test_llm_inference_handler_e2e.py new file mode 100644 index 000000000..6b3da99cf --- /dev/null +++ b/python/e2e/test_llm_inference_handler_e2e.py @@ -0,0 +1,271 @@ +"""E2E test for the idiomatic ``LlmRequestHandler`` forwarding seams. + +Mirrors ``nodejs/test/e2e/llm_inference_handler.e2e.test.ts``. A single handler +subclass services BOTH transports against a per-test fake upstream: + +* HTTP — :meth:`send_request` rewrites the request to the local HTTP upstream, + mutates an outbound and a response header, and forwards via httpx. +* WebSocket — :meth:`open_web_socket` rewrites the URL to the local WebSocket + upstream and returns a forwarding handler that counts messages in both + directions. + +Unlike the other inference tests (which fabricate responses inline), this one +exercises the default httpx / ``websockets`` forwarding machinery against a +real socket, proving the full chain runtime → handler → upstream → handler → +runtime is intact for whichever transport the agent turn selects. +""" + +from __future__ import annotations + +import asyncio +import json +import os +import threading +from dataclasses import dataclass, field +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer + +import httpx +import pytest +import pytest_asyncio +import websockets +from websockets.asyncio.server import serve as ws_serve + +from copilot import ( + CopilotClient, + ForwardingWebSocketHandler, + LlmInferenceConfig, + LlmRequestContext, + LlmRequestHandler, + RuntimeConnection, +) +from copilot.session import PermissionHandler + +from ._llm_inference_helpers import assistant_text, model_catalog, responses_events +from .testharness import E2ETestContext + +pytestmark = pytest.mark.asyncio(loop_scope="module") + +HTTP_TEXT = "OK from synthetic HTTP upstream." +WS_TEXT = "OK from synthetic WS upstream." + + +@dataclass +class _Counters: + http_requests: int = 0 + http_responses: int = 0 + ws_request_messages: int = 0 + ws_response_messages: int = 0 + + +@dataclass +class _Upstream: + http_url: str + ws_url: str + _http_server: ThreadingHTTPServer + _http_thread: threading.Thread + _ws_server: object + ws_requests: list[int] = field(default_factory=lambda: [0]) + + @property + def ws_request_count(self) -> int: + return self.ws_requests[0] + + async def close(self) -> None: + self._http_server.shutdown() + self._http_thread.join(timeout=5) + self._http_server.server_close() + self._ws_server.close() # type: ignore[attr-defined] + await self._ws_server.wait_closed() # type: ignore[attr-defined] + + +def _sse_body(text: str, resp_id: str) -> bytes: + out = "".join( + f"event: {event['type']}\ndata: {json.dumps(event)}\n\n" + for event in responses_events(text, resp_id) + ) + return out.encode("utf-8") + + +async def _start_fake_upstream() -> _Upstream: + class _Handler(BaseHTTPRequestHandler): + def log_message(self, *_args): # noqa: ANN002 - silence default logging + pass + + def _send(self, status: int, content_type: str, body: bytes) -> None: + self.send_response(status) + self.send_header("content-type", content_type) + self.send_header("content-length", str(len(body))) + self.end_headers() + self.wfile.write(body) + + def _route(self) -> None: + path = self.path.split("?", 1)[0].lower() + length = int(self.headers.get("content-length") or 0) + if length: + self.rfile.read(length) + if path.endswith("/models"): + self._send( + 200, + "application/json", + json.dumps( + model_catalog(supported_endpoints=["/responses", "ws:/responses"]) + ).encode("utf-8"), + ) + return + if path.endswith("/models/session"): + self._send(200, "application/json", b"{}") + return + if "/policy" in path: + self._send(200, "application/json", json.dumps({"state": "enabled"}).encode("utf-8")) + return + if path.endswith("/responses"): + self._send(200, "text/event-stream", _sse_body(HTTP_TEXT, "resp_stub_http")) + return + self._send(404, "application/json", json.dumps({"error": "not_found", "path": path}).encode("utf-8")) + + def do_GET(self): # noqa: N802 + self._route() + + def do_POST(self): # noqa: N802 + self._route() + + http_server = ThreadingHTTPServer(("127.0.0.1", 0), _Handler) + http_port = http_server.server_address[1] + http_thread = threading.Thread(target=http_server.serve_forever, daemon=True) + http_thread.start() + + ws_requests = [0] + + async def ws_handler(connection) -> None: + async for _raw in connection: + ws_requests[0] += 1 + for event in responses_events(WS_TEXT, "resp_stub_ws"): + await connection.send(json.dumps(event)) + + ws_server = await ws_serve(ws_handler, "127.0.0.1", 0) + ws_port = ws_server.sockets[0].getsockname()[1] + + return _Upstream( + http_url=f"http://127.0.0.1:{http_port}", + ws_url=f"ws://127.0.0.1:{ws_port}", + _http_server=http_server, + _http_thread=http_thread, + _ws_server=ws_server, + ws_requests=ws_requests, + ) + + +class _CountingSocketHandler(ForwardingWebSocketHandler): + """Forwarding WebSocket handler that counts messages in both directions.""" + + def __init__(self, ctx: LlmRequestContext, url: str, counters: _Counters) -> None: + super().__init__(ctx, url=url) + self._counters = counters + + async def send_request_message(self, data: str | bytes) -> None: + self._counters.ws_request_messages += 1 + await super().send_request_message(data) + + async def send_response_message(self, data: str | bytes) -> None: + self._counters.ws_response_messages += 1 + await super().send_response_message(data) + + +class _TestHandler(LlmRequestHandler): + def __init__(self, upstream: _Upstream, counters: _Counters) -> None: + self._upstream = upstream + self._counters = counters + self._client = httpx.AsyncClient(timeout=None, follow_redirects=False) + + def _rewrite_http(self, url: httpx.URL) -> httpx.URL: + up = httpx.URL(self._upstream.http_url) + return url.copy_with(scheme=up.scheme, host=up.host, port=up.port) + + def _rewrite_ws(self, url: str) -> str: + parsed = httpx.URL(url) + up = httpx.URL(self._upstream.ws_url) + return str(parsed.copy_with(scheme=up.scheme, host=up.host, port=up.port)) + + async def send_request(self, request: httpx.Request, ctx: LlmRequestContext) -> httpx.Response: + self._counters.http_requests += 1 + headers = dict(request.headers) + headers["x-test-mutated"] = "1" + rewritten = httpx.Request( + request.method, + self._rewrite_http(request.url), + headers=headers, + content=request.content, + ) + response = await self._client.send(rewritten, stream=True) + self._counters.http_responses += 1 + response.headers["x-test-response-mutated"] = "1" + return response + + async def open_web_socket(self, ctx: LlmRequestContext): + return _CountingSocketHandler(ctx, self._rewrite_ws(ctx.url), self._counters) + + async def aclose(self) -> None: + await self._client.aclose() + + +@dataclass +class _HandlerFixture: + client: CopilotClient + upstream: _Upstream + counters: _Counters + + +@pytest_asyncio.fixture(loop_scope="module") +async def handler_fixture(ctx: E2ETestContext): + upstream = await _start_fake_upstream() + counters = _Counters() + handler = _TestHandler(upstream, counters) + github_token = ( + "fake-token-for-e2e-tests" if os.environ.get("GITHUB_ACTIONS") == "true" else None + ) + env = {**ctx.get_env(), "COPILOT_EXP_COPILOT_CLI_WEBSOCKET_RESPONSES": "true"} + client = CopilotClient( + connection=RuntimeConnection.for_stdio(path=ctx.cli_path), + working_directory=ctx.work_dir, + env=env, + github_token=github_token, + llm_inference=LlmInferenceConfig(handler=handler), + ) + try: + yield _HandlerFixture(client=client, upstream=upstream, counters=counters) + finally: + try: + await client.stop() + except Exception: + pass + await handler.aclose() + await upstream.close() + + +class TestLlmInferenceHandler: + async def test_services_http_and_websocket_via_one_handler(self, handler_fixture): + fx = handler_fixture + await fx.client.start() + session = await fx.client.create_session( + on_permission_request=PermissionHandler.approve_all + ) + text = "" + try: + result = await session.send_and_wait("Say OK.") + text = assistant_text(result) + finally: + await session.disconnect() + + # The HTTP seam fired — the runtime issued model-layer GETs (catalog, + # policy) and possibly a single-shot inference through send_request. + assert fx.counters.http_requests > 0, "expected send_request to fire" + assert fx.counters.http_responses > 0, "expected send_request response mutation to fire" + + # The WebSocket seam fired — the main agent turn went over the WS path + # and we observed messages in both directions. + assert fx.counters.ws_request_messages > 0, "expected runtime → upstream ws messages" + assert fx.counters.ws_response_messages > 0, "expected upstream → runtime ws messages" + assert fx.upstream.ws_request_count > 0, "expected upstream WS to receive request messages" + + # Validate the final assistant response arrived (guards against truncated captures) + assert "OK from synthetic" in text and "upstream" in text diff --git a/python/e2e/test_llm_inference_session_id_e2e.py b/python/e2e/test_llm_inference_session_id_e2e.py new file mode 100644 index 000000000..35dbfea83 --- /dev/null +++ b/python/e2e/test_llm_inference_session_id_e2e.py @@ -0,0 +1,115 @@ +"""E2E tests asserting the runtime threads its session id into the LLM +inference callback for both CAPI and BYOK sessions. + +Mirrors ``nodejs/test/e2e/llm_inference_session_id.e2e.test.ts``. The callback +alone services every model-layer request (no upstream server, no CAPI proxy +acting as the inference endpoint), so the only source of ``req.session_id`` is +the runtime's own per-client threading. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +import pytest + +from copilot import LlmInferenceRequest, LlmRequestHandler +from copilot.session import PermissionHandler + +from ._llm_inference_helpers import ( + assistant_text, + handle_inference, + handle_non_inference_model_traffic, + is_inference_url, + isolated_client_fixture, +) +from .testharness import E2ETestContext # noqa: F401 (ctx fixture dependency) + +pytestmark = pytest.mark.asyncio(loop_scope="module") + + +@dataclass +class _InterceptedRequest: + url: str + session_id: str | None + + +class _SessionIdHandler(LlmRequestHandler): + def __init__(self) -> None: + self.records: list[_InterceptedRequest] = [] + + async def on_llm_request(self, req: LlmInferenceRequest) -> None: + self.records.append(_InterceptedRequest(url=req.url, session_id=req.session_id)) + if is_inference_url(req.url): + await handle_inference(req) + else: + await handle_non_inference_model_traffic(req) + + +session_id_client = isolated_client_fixture(_SessionIdHandler) + + +class TestLlmInferenceSessionId: + capi_session_id: str | None = None + + async def test_threads_session_id_into_capi_session(self, session_id_client): + client, handler = session_id_client + await client.start() + baseline = len(handler.records) + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all + ) + TestLlmInferenceSessionId.capi_session_id = session.session_id + text = "" + try: + result = await session.send_and_wait("Say OK.") + text = assistant_text(result) + finally: + await session.disconnect() + + inference = [r for r in handler.records[baseline:] if is_inference_url(r.url)] + assert len(inference) > 0, "expected at least one intercepted inference request" + for r in inference: + assert r.session_id == session.session_id, ( + "CAPI inference request must carry the runtime session id" + ) + + # Validate the final assistant response arrived (guards against truncated captures) + assert "OK from the synthetic" in text + + async def test_threads_session_id_into_byok_session(self, session_id_client): + client, handler = session_id_client + await client.start() + baseline = len(handler.records) + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all, + model="claude-sonnet-4.5", + provider={ + "type": "openai", + "wire_api": "responses", + "base_url": "https://byok.invalid/v1", + "api_key": "byok-secret", + "model_id": "claude-sonnet-4.5", + "wire_model": "claude-sonnet-4.5", + }, + ) + byok_session_id = session.session_id + text = "" + try: + result = await session.send_and_wait("Say OK.") + text = assistant_text(result) + finally: + await session.disconnect() + + inference = [r for r in handler.records[baseline:] if is_inference_url(r.url)] + assert len(inference) > 0, "expected at least one intercepted BYOK inference request" + for r in inference: + assert r.session_id == byok_session_id, ( + "BYOK inference request must carry the runtime session id" + ) + + # Session ids are per-session, so the two turns must differ. + assert byok_session_id != TestLlmInferenceSessionId.capi_session_id + + # Validate the final assistant response arrived (guards against truncated captures) + assert "OK from the synthetic" in text diff --git a/python/e2e/test_llm_inference_stream_e2e.py b/python/e2e/test_llm_inference_stream_e2e.py new file mode 100644 index 000000000..e08a6a752 --- /dev/null +++ b/python/e2e/test_llm_inference_stream_e2e.py @@ -0,0 +1,62 @@ +"""E2E test for the LLM inference callback over a fully-mocked streaming +response. + +Mirrors ``nodejs/test/e2e/llm_inference_stream.e2e.test.ts``. The callback +services every model-layer request and answers the inference call with a +chunked SSE event stream; the test asserts the synthetic content surfaces in +the assistant turn. +""" + +from __future__ import annotations + +import pytest + +from copilot import LlmInferenceRequest, LlmRequestHandler +from copilot.session import PermissionHandler + +from ._llm_inference_helpers import ( + assistant_text, + handle_inference, + handle_non_inference_model_traffic, + is_inference_url, + isolated_client_fixture, +) +from .testharness import E2ETestContext # noqa: F401 (ctx fixture dependency) + +pytestmark = pytest.mark.asyncio(loop_scope="module") + + +class _StreamingHandler(LlmRequestHandler): + def __init__(self) -> None: + self.received: list[LlmInferenceRequest] = [] + + async def on_llm_request(self, req: LlmInferenceRequest) -> None: + self.received.append(req) + if is_inference_url(req.url): + await handle_inference(req) + else: + await handle_non_inference_model_traffic(req) + + +stream_client = isolated_client_fixture(_StreamingHandler) + + +class TestLlmInferenceStream: + async def test_completes_a_turn_via_chunked_sse_response(self, stream_client): + client, handler = stream_client + await client.start() + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all + ) + text = "" + try: + result = await session.send_and_wait("Say OK.") + text = assistant_text(result) + finally: + await session.disconnect() + + inference = [r for r in handler.received if is_inference_url(r.url)] + assert len(inference) > 0, "expected at least one inference request via the callback" + + # Validate the final assistant response arrived (guards against truncated captures) + assert "OK from the synthetic" in text diff --git a/python/e2e/test_llm_inference_websocket_e2e.py b/python/e2e/test_llm_inference_websocket_e2e.py new file mode 100644 index 000000000..16473aefa --- /dev/null +++ b/python/e2e/test_llm_inference_websocket_e2e.py @@ -0,0 +1,108 @@ +"""E2E test for the LLM inference callback over the full-duplex WebSocket +transport. + +Mirrors ``nodejs/test/e2e/llm_inference_websocket.e2e.test.ts``. The fake model +catalog advertises ``/responses`` and ``ws:/responses`` so the runtime selects +the Responses wire API and is allowed to pick the WebSocket transport (the ExP +flag is enabled via the env var below). The handler services the WS channel by +answering each inbound ``response.create`` message with the ordered +``/responses`` event objects — one event per outbound WS message, raw JSON +(not SSE-framed). +""" + +from __future__ import annotations + +import json + +import pytest + +from copilot import LlmInferenceRequest, LlmInferenceResponseInit, LlmRequestHandler +from copilot.session import PermissionHandler + +from ._llm_inference_helpers import ( + assistant_text, + drain_request, + handle_non_inference_model_traffic, + is_inference_url, + isolated_client_fixture, + responses_events, +) +from .testharness import E2ETestContext # noqa: F401 (ctx fixture dependency) + +pytestmark = pytest.mark.asyncio(loop_scope="module") + +WS_TEXT = "OK from the synthetic ws." + + +async def _handle_http_inference(req: LlmInferenceRequest) -> None: + """Synthesize the ``/responses`` SSE stream for single-shot HTTP inference + requests (e.g. title generation) that don't pick the WebSocket transport.""" + await drain_request(req) + await req.response_body.start( + LlmInferenceResponseInit(status=200, headers={"content-type": ["text/event-stream"]}) + ) + for event in responses_events(WS_TEXT, "resp_stub_ws_1"): + await req.response_body.write(f"event: {event['type']}\ndata: {json.dumps(event)}\n\n") + await req.response_body.end() + + +class _WebSocketHandler(LlmRequestHandler): + def __init__(self) -> None: + self.received: list[LlmInferenceRequest] = [] + self.ws_request_count = 0 + + async def _handle_web_socket(self, req: LlmInferenceRequest) -> None: + # Ack the upgrade (status 101-equivalent) before any message flows. + await req.response_body.start(LlmInferenceResponseInit(status=101, headers={})) + try: + # One inbound chunk == one WS message (a `response.create` request). + async for _outbound in req.request_body: + self.ws_request_count += 1 + for event in responses_events(WS_TEXT, "resp_stub_ws_1"): + await req.response_body.write(json.dumps(event)) + except Exception: + # Expected: the runtime cancels the request body when it closes the + # socket at session teardown. Nothing more to do. + pass + + async def on_llm_request(self, req: LlmInferenceRequest) -> None: + self.received.append(req) + if req.transport == "websocket": + await self._handle_web_socket(req) + return + if is_inference_url(req.url): + await _handle_http_inference(req) + else: + await handle_non_inference_model_traffic( + req, supported_endpoints=["/responses", "ws:/responses"] + ) + + +ws_client = isolated_client_fixture( + _WebSocketHandler, + extra_env={"COPILOT_EXP_COPILOT_CLI_WEBSOCKET_RESPONSES": "true"}, +) + + +class TestLlmInferenceWebSocket: + async def test_completes_a_turn_over_the_websocket_transport(self, ws_client): + client, handler = ws_client + await client.start() + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all + ) + text = "" + try: + result = await session.send_and_wait("Say OK.") + text = assistant_text(result) + finally: + await session.disconnect() + + # The main agent turn (tools present, not single-shot) selected the + # WebSocket transport and drove it through the callback. + ws_reqs = [r for r in handler.received if r.transport == "websocket"] + assert len(ws_reqs) > 0, "expected at least one websocket request via the callback" + assert handler.ws_request_count > 0, "expected the runtime to send at least one ws message" + + # Validate the final assistant response arrived (guards against truncated captures) + assert "OK from the synthetic ws" in text diff --git a/python/pyproject.toml b/python/pyproject.toml index 596e07be2..ea15b2d71 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -28,6 +28,7 @@ classifiers = [ dependencies = [ "python-dateutil>=2.9.0.post0", "pydantic>=2.0", + "httpx>=0.24.0", ] [project.urls] @@ -41,7 +42,7 @@ dev = [ "pytest>=7.0.0", "pytest-asyncio>=0.21.0", "pytest-timeout>=2.0.0", - "httpx>=0.24.0", + "websockets>=12.0", "opentelemetry-sdk>=1.0.0", ] telemetry = [ diff --git a/rust/src/generated/api_types.rs b/rust/src/generated/api_types.rs index cd352a7fc..d85fd3924 100644 --- a/rust/src/generated/api_types.rs +++ b/rust/src/generated/api_types.rs @@ -89,6 +89,12 @@ pub mod rpc_methods { pub const RUNTIME_SHUTDOWN: &str = "runtime.shutdown"; /// `sessionFs.setProvider` pub const SESSIONFS_SETPROVIDER: &str = "sessionFs.setProvider"; + /// `llmInference.setProvider` + pub const LLMINFERENCE_SETPROVIDER: &str = "llmInference.setProvider"; + /// `llmInference.httpResponseStart` + pub const LLMINFERENCE_HTTPRESPONSESTART: &str = "llmInference.httpResponseStart"; + /// `llmInference.httpResponseChunk` + pub const LLMINFERENCE_HTTPRESPONSECHUNK: &str = "llmInference.httpResponseChunk"; /// `sessions.open` pub const SESSIONS_OPEN: &str = "sessions.open"; /// `sessions.fork` @@ -1335,13 +1341,23 @@ pub struct ApiKeyAuthInfo { #[derive(Debug, Clone, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct AttachmentBlob { - /// Base64-encoded content - pub data: String, + /// Internal: content-addressed id of the session.binary_asset event holding this attachment's model-facing bytes (e.g. "sha256:..."). Absent externally. + #[serde(skip_serializing_if = "Option::is_none")] + pub asset_id: Option, + /// Internal: decoded byte length of the attachment's model-facing bytes. Absent externally. + #[serde(skip_serializing_if = "Option::is_none")] + pub byte_length: Option, + /// Base64-encoded content. Present on input and for external consumers; replaced by an internal `assetId` reference in persisted events when interned to a content-addressed asset. + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, /// User-facing display name for the attachment #[serde(skip_serializing_if = "Option::is_none")] pub display_name: Option, /// MIME type of the inline data pub mime_type: String, + /// Internal: why model-facing bytes are absent from persistence. Absent externally. + #[serde(skip_serializing_if = "Option::is_none")] + pub omitted_reason: Option, /// Attachment type discriminator pub r#type: AttachmentBlobType, } @@ -1423,11 +1439,23 @@ pub struct AttachmentFileLineRange { #[derive(Debug, Clone, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct AttachmentFile { + /// Internal: content-addressed id of the session.binary_asset event holding this attachment's model-facing bytes (e.g. "sha256:..."). Absent externally. + #[serde(skip_serializing_if = "Option::is_none")] + pub asset_id: Option, + /// Internal: decoded byte length of the attachment's model-facing bytes. Absent externally. + #[serde(skip_serializing_if = "Option::is_none")] + pub byte_length: Option, /// User-facing display name for the attachment pub display_name: String, /// Optional line range to scope the attachment to a specific section of the file #[serde(skip_serializing_if = "Option::is_none")] pub line_range: Option, + /// Internal: MIME type of the file's model-facing bytes (post-resize for images). Set when the file's bytes are interned to an asset. Absent externally. + #[serde(skip_serializing_if = "Option::is_none")] + pub mime_type: Option, + /// Internal: why model-facing bytes are absent from persistence. Absent externally. + #[serde(skip_serializing_if = "Option::is_none")] + pub omitted_reason: Option, /// Absolute file path pub path: String, /// Attachment type discriminator @@ -3295,6 +3323,167 @@ pub struct InstructionsGetSourcesResult { pub sources: Vec, } +/// A request body chunk or cancellation signal. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LlmInferenceHttpRequestChunkRequest { + /// When true, `data` is base64-encoded bytes. When absent or false, `data` is UTF-8 text. + #[serde(skip_serializing_if = "Option::is_none")] + pub binary: Option, + /// When true, the runtime is cancelling the in-flight request (e.g. upstream consumer aborted). `data` is ignored. Implies end-of-request. + #[serde(skip_serializing_if = "Option::is_none")] + pub cancel: Option, + /// Optional human-readable reason for the cancellation, propagated for logging. + #[serde(skip_serializing_if = "Option::is_none")] + pub cancel_reason: Option, + /// Body byte range. UTF-8 text when `binary` is absent or false; base64-encoded bytes when `binary` is true. May be empty. + pub data: String, + /// When true, this is the final body chunk for the request. The SDK may rely on having received an end-marked chunk before treating the request body as complete. + #[serde(skip_serializing_if = "Option::is_none")] + pub end: Option, + /// Matches the requestId from the originating httpRequestStart frame. + pub request_id: RequestId, +} + +/// Acknowledgement. The SDK is free to ignore the ack and treat chunk delivery as fire-and-forget. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LlmInferenceHttpRequestChunkResult {} + +/// The head of an outbound model-layer HTTP request. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LlmInferenceHttpRequestStartRequest { + pub headers: HashMap>, + /// HTTP method, e.g. GET, POST. + pub method: String, + /// Opaque runtime-minted id, unique per in-flight request. The SDK uses this to correlate httpRequestChunk frames and to address its httpResponseStart / httpResponseChunk replies back to the runtime. + pub request_id: RequestId, + /// Id of the runtime session that triggered this request, when one is in scope. Absent for requests issued outside any session (e.g. startup model-catalog or capability resolution). This is a payload field — not a dispatch key — because the client-global API is registered process-wide rather than per session. + #[serde(skip_serializing_if = "Option::is_none")] + pub session_id: Option, + /// Transport the runtime would otherwise use for this request. `http` (the default when absent) covers plain HTTP and SSE responses; `websocket` indicates a full-duplex message channel where each body chunk maps to one WebSocket message and the `binary` flag distinguishes text from binary frames. The SDK consumer uses this to decide whether to service the request with an HTTP client or a WebSocket client. It is the one piece of request metadata the consumer cannot reliably infer from the URL or headers alone. + #[serde(skip_serializing_if = "Option::is_none")] + pub transport: Option, + /// Absolute request URL. + pub url: String, +} + +/// Acknowledgement. Returning successfully simply means the SDK accepted the start frame; it does not imply the request will succeed. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LlmInferenceHttpRequestStartResult {} + +/// Set to terminate the response with a transport-level failure. Implies end-of-stream; any further chunks for this requestId are ignored. +/// +///
+/// +/// **Experimental.** This type is part of an experimental wire-protocol surface +/// and may change or be removed in future SDK or CLI releases. +/// +///
+#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LlmInferenceHttpResponseChunkError { + /// Optional machine-readable error code. + #[serde(skip_serializing_if = "Option::is_none")] + pub code: Option, + /// Human-readable failure description. + pub message: String, +} + +/// A response body chunk or terminal error. +/// +///
+/// +/// **Experimental.** This type is part of an experimental wire-protocol surface +/// and may change or be removed in future SDK or CLI releases. +/// +///
+#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LlmInferenceHttpResponseChunkRequest { + /// When true, `data` is base64-encoded bytes. When absent or false, `data` is UTF-8 text. + #[serde(skip_serializing_if = "Option::is_none")] + pub binary: Option, + /// Body byte range. UTF-8 text when `binary` is absent or false; base64-encoded bytes when `binary` is true. May be empty (e.g. when the response body is empty: send a single chunk with empty data and end=true). + pub data: String, + /// When true, this is the final body chunk for the response. The runtime treats the response body as complete after receiving an end-marked chunk. + #[serde(skip_serializing_if = "Option::is_none")] + pub end: Option, + /// Set to terminate the response with a transport-level failure. Implies end-of-stream; any further chunks for this requestId are ignored. + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + /// Matches the requestId from the originating httpRequestStart frame. + pub request_id: RequestId, +} + +/// Whether the chunk was accepted. +/// +///
+/// +/// **Experimental.** This type is part of an experimental wire-protocol surface +/// and may change or be removed in future SDK or CLI releases. +/// +///
+#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LlmInferenceHttpResponseChunkResult { + /// True when the chunk was matched to a pending request; false when unknown. + pub accepted: bool, +} + +/// Response head. +/// +///
+/// +/// **Experimental.** This type is part of an experimental wire-protocol surface +/// and may change or be removed in future SDK or CLI releases. +/// +///
+#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LlmInferenceHttpResponseStartRequest { + pub headers: HashMap>, + /// Matches the requestId from the originating httpRequestStart frame. + pub request_id: RequestId, + /// HTTP status code. + pub status: i64, + /// Optional HTTP status reason phrase. + #[serde(skip_serializing_if = "Option::is_none")] + pub status_text: Option, +} + +/// Whether the start frame was accepted. +/// +///
+/// +/// **Experimental.** This type is part of an experimental wire-protocol surface +/// and may change or be removed in future SDK or CLI releases. +/// +///
+#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LlmInferenceHttpResponseStartResult { + /// True when the response start was matched to a pending request; false when unknown. + pub accepted: bool, +} + +/// Indicates whether the calling client was registered as the LLM inference provider. +/// +///
+/// +/// **Experimental.** This type is part of an experimental wire-protocol surface +/// and may change or be removed in future SDK or CLI releases. +/// +///
+#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LlmInferenceSetProviderResult { + /// Whether the provider was set successfully + pub success: bool, +} + /// Pre-resolved working-directory context for session startup. /// ///
@@ -4882,9 +5071,12 @@ pub struct MetadataSnapshotRemoteMetadata { #[derive(Debug, Clone, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ModelBillingTokenPricesLongContext { - /// AI Credits cost per billing batch of cached tokens + /// AI Credits cost per billing batch of cache-read tokens #[serde(skip_serializing_if = "Option::is_none")] pub cache_price: Option, + /// AI Credits cost per billing batch of cache-write (cache creation) tokens. + #[serde(skip_serializing_if = "Option::is_none")] + pub cache_write_price: Option, /// Prompt token budget (max_prompt_tokens) for the long context tier. The total context window is this value plus the model's max_output_tokens. #[serde(skip_serializing_if = "Option::is_none")] pub context_max: Option, @@ -4903,9 +5095,12 @@ pub struct ModelBillingTokenPrices { /// Number of tokens per standard billing batch #[serde(skip_serializing_if = "Option::is_none")] pub batch_size: Option, - /// AI Credits cost per billing batch of cached tokens + /// AI Credits cost per billing batch of cache-read tokens #[serde(skip_serializing_if = "Option::is_none")] pub cache_price: Option, + /// AI Credits cost per billing batch of cache-write (cache creation) tokens. + #[serde(skip_serializing_if = "Option::is_none")] + pub cache_write_price: Option, /// Prompt token budget (max_prompt_tokens) for the default tier. The total context window is this value plus the model's max_output_tokens. #[serde(skip_serializing_if = "Option::is_none")] pub context_max: Option, @@ -8293,7 +8488,7 @@ pub struct SandboxConfig { pub add_current_working_directory: Option, /// Raw `ContainerConfig` (per `@microsoft/mxc-sdk`) passed directly to `spawnSandboxFromConfig`, bypassing policy merging. #[serde(skip_serializing_if = "Option::is_none")] - pub config: Option>, + pub config: Option, /// Whether sandboxing is enabled for the session. pub enabled: bool, /// User-managed sandbox policy fragment merged into the auto-discovered base policy. @@ -16431,6 +16626,16 @@ pub struct CanvasOpenResult { pub url: Option, } +/// HTTP headers as a map from lowercased header name to a list of values. Multi-valued headers (e.g. Set-Cookie) preserve all values. +/// +///
+/// +/// **Experimental.** This type is part of an experimental wire-protocol surface +/// and may change or be removed in future SDK or CLI releases. +/// +///
+pub type LlmInferenceHeaders = HashMap>; + /// MCP CreateMessageResult payload (with optional 'tools' extension), present when action='success'. Treated as opaque at the schema layer; consumers should construct/consume it per the MCP CreateMessageResult shape. /// ///
@@ -16789,6 +16994,28 @@ pub enum ApiKeyAuthInfoType { ApiKey, } +/// Why the binary data is absent: it exceeded the inline size limit, or its asset was unavailable +/// +///
+/// +/// **Experimental.** This type is part of an experimental wire-protocol surface +/// and may change or be removed in future SDK or CLI releases. +/// +///
+#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub enum OmittedBinaryOmittedReason { + /// Bytes exceeded the session's inline size limit. + #[serde(rename = "too_large")] + TooLarge, + /// The referenced binary asset could not be found (e.g. a truncated log). + #[serde(rename = "asset_unavailable")] + AssetUnavailable, + /// Unknown variant for forward compatibility. + #[default] + #[serde(other)] + Unknown, +} + /// Attachment type discriminator #[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] pub enum AttachmentBlobType { @@ -17405,6 +17632,21 @@ pub enum InstructionSourceType { Unknown, } +/// Transport the runtime would otherwise use for this request. `http` (the default when absent) covers plain HTTP and SSE responses; `websocket` indicates a full-duplex message channel where each body chunk maps to one WebSocket message and the `binary` flag distinguishes text from binary frames. The SDK consumer uses this to decide whether to service the request with an HTTP client or a WebSocket client. It is the one piece of request metadata the consumer cannot reliably infer from the URL or headers alone. +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub enum LlmInferenceHttpRequestStartTransport { + /// Plain HTTP or SSE response. Each body chunk is an opaque byte range; the response is a status line, headers, and a (possibly streamed) body. + #[serde(rename = "http")] + Http, + /// Full-duplex WebSocket channel. Each body chunk maps to exactly one WebSocket message and the `binary` flag distinguishes text from binary frames; request and response chunks flow concurrently. + #[serde(rename = "websocket")] + Websocket, + /// Unknown variant for forward compatibility. + #[default] + #[serde(other)] + Unknown, +} + /// Repository host type /// ///
diff --git a/rust/src/generated/rpc.rs b/rust/src/generated/rpc.rs index 82319dbbf..40af0a10e 100644 --- a/rust/src/generated/rpc.rs +++ b/rust/src/generated/rpc.rs @@ -49,6 +49,13 @@ impl<'a> ClientRpc<'a> { } } + /// `llmInference.*` sub-namespace. + pub fn llm_inference(&self) -> ClientRpcLlmInference<'a> { + ClientRpcLlmInference { + client: self.client, + } + } + /// `mcp.*` sub-namespace. pub fn mcp(&self) -> ClientRpcMcp<'a> { ClientRpcMcp { @@ -386,6 +393,106 @@ impl<'a> ClientRpcInstructions<'a> { } } +/// `llmInference.*` RPCs. +#[derive(Clone, Copy)] +pub struct ClientRpcLlmInference<'a> { + pub(crate) client: &'a Client, +} + +impl<'a> ClientRpcLlmInference<'a> { + /// Registers an SDK client as the LLM inference callback provider. + /// + /// Wire method: `llmInference.setProvider`. + /// + /// # Returns + /// + /// Indicates whether the calling client was registered as the LLM inference provider. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn set_provider(&self) -> Result { + let wire_params = serde_json::json!({}); + let _value = self + .client + .call(rpc_methods::LLMINFERENCE_SETPROVIDER, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Delivers the response head (status + headers) for an in-flight request, correlated by the requestId the runtime supplied in httpRequestStart. Must be called exactly once per request before any httpResponseChunk frames. + /// + /// Wire method: `llmInference.httpResponseStart`. + /// + /// # Parameters + /// + /// * `params` - Response head. + /// + /// # Returns + /// + /// Whether the start frame was accepted. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn http_response_start( + &self, + params: LlmInferenceHttpResponseStartRequest, + ) -> Result { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call( + rpc_methods::LLMINFERENCE_HTTPRESPONSESTART, + Some(wire_params), + ) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Delivers a body byte range (or a terminal transport error) for an in-flight response, correlated by requestId. Set `end` true on the last chunk. When `error` is set the response terminates with a transport-level failure and the runtime raises an APIConnectionError. + /// + /// Wire method: `llmInference.httpResponseChunk`. + /// + /// # Parameters + /// + /// * `params` - A response body chunk or terminal error. + /// + /// # Returns + /// + /// Whether the chunk was accepted. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn http_response_chunk( + &self, + params: LlmInferenceHttpResponseChunkRequest, + ) -> Result { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call( + rpc_methods::LLMINFERENCE_HTTPRESPONSECHUNK, + Some(wire_params), + ) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + /// `mcp.*` RPCs. #[derive(Clone, Copy)] pub struct ClientRpcMcp<'a> { diff --git a/rust/src/generated/session_events.rs b/rust/src/generated/session_events.rs index 639b466d1..7fdd48ac2 100644 --- a/rust/src/generated/session_events.rs +++ b/rust/src/generated/session_events.rs @@ -1513,12 +1513,21 @@ pub struct ModelCallFailureData { /// Completion ID from the model provider (e.g., chatcmpl-abc123) #[serde(skip_serializing_if = "Option::is_none")] pub api_call_id: Option, + /// For HTTP 400 failures only: whether the response carried a structured CAPI error envelope (structured_error, a deterministic validation failure) or no error body (bodyless, the transient gateway/proxy signature). Absent for non-400 failures. + #[serde(skip_serializing_if = "Option::is_none")] + pub bad_request_kind: Option, /// Duration of the failed API call in milliseconds #[serde(skip_serializing_if = "Option::is_none")] pub duration_ms: Option, + /// For HTTP 400 failures only: the `code` from the CAPI error envelope (e.g. 'model_max_prompt_tokens_exceeded') identifying which deterministic validation failure occurred. Raw server-controlled string, emitted only through restricted telemetry. Absent for bodyless or non-400 failures. + #[serde(skip_serializing_if = "Option::is_none")] + pub error_code: Option, /// Raw provider/runtime error message for restricted telemetry #[serde(skip_serializing_if = "Option::is_none")] pub error_message: Option, + /// For HTTP 400 failures only: the `type` from the CAPI error envelope (e.g. 'websocket_error'), a coarser companion to errorCode for envelopes that carry no code. Raw server-controlled string, emitted only through restricted telemetry. Absent for bodyless or non-400 failures. + #[serde(skip_serializing_if = "Option::is_none")] + pub error_type: Option, /// What initiated this API call (e.g., "sub-agent", "mcp-sampling"); absent for user-initiated calls #[serde(skip_serializing_if = "Option::is_none")] pub initiator: Option, @@ -3372,7 +3381,7 @@ pub struct CanvasRegistryChangedCanvasAction { pub description: Option, /// JSON Schema for action input #[serde(skip_serializing_if = "Option::is_none")] - pub input_schema: Option>, + pub input_schema: Option, /// Action name pub name: String, } @@ -3397,7 +3406,7 @@ pub struct CanvasRegistryChangedCanvas { pub extension_name: Option, /// JSON Schema for canvas open input #[serde(skip_serializing_if = "Option::is_none")] - pub input_schema: Option>, + pub input_schema: Option, } /// Session event "session.canvas.registry_changed". @@ -3708,6 +3717,21 @@ pub enum AssistantUsageApiEndpoint { Unknown, } +/// For HTTP 400 failures only: whether the response carried a structured CAPI error envelope (structured_error, a deterministic validation failure) or no error body (bodyless, the transient gateway/proxy signature). Absent for non-400 failures. +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub enum ModelCallFailureBadRequestKind { + /// The 400 response carried no error body (transient gateway/proxy signature). + #[serde(rename = "bodyless")] + Bodyless, + /// The 400 response carried a structured CAPI error envelope (deterministic validation failure). + #[serde(rename = "structured_error")] + StructuredError, + /// Unknown variant for forward compatibility. + #[default] + #[serde(other)] + Unknown, +} + /// Where the failed model call originated #[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] pub enum ModelCallFailureSource { diff --git a/scripts/codegen/csharp.ts b/scripts/codegen/csharp.ts index 7b0f64a77..e1ceea5b1 100644 --- a/scripts/codegen/csharp.ts +++ b/scripts/codegen/csharp.ts @@ -2297,6 +2297,142 @@ function emitClientSessionApiRegistration(clientSchema: Record, return lines; } +/** + * Emit C# handler interfaces + a process-wide registration for client + * *global* API groups. + * + * Unlike client-session APIs, these methods carry no implicit `sessionId` + * dispatch key. The SDK consumer registers a single process-wide handler set + * via `RegisterClientGlobalApiHandlers`; the runtime dispatcher routes each + * incoming call to the registered handler regardless of which (if any) + * runtime session triggered it. + */ +function emitClientGlobalApiRegistration(clientSchema: Record, classes: string[]): string[] { + const lines: string[] = []; + const groups = collectClientGroups(clientSchema); + + for (const { methods } of groups) { + for (const method of methods) { + const resultSchema = getMethodResultSchema(method); + if (!isVoidSchema(resultSchema) && !isOpaqueJson(resultSchema)) { + emitRpcResultType(resultTypeName(method), resultSchema!, "public", classes); + } + + const effectiveParams = resolveMethodParamsSchema(method); + if (effectiveParams?.properties && Object.keys(effectiveParams.properties).length > 0) { + const paramsClass = emitRpcClass(paramsTypeName(method), effectiveParams, "public", classes); + if (paramsClass) classes.push(paramsClass); + } + } + } + + for (const { groupName, groupNode, methods } of groups) { + const interfaceName = clientHandlerInterfaceName(groupName); + const groupExperimental = isNodeFullyExperimental(groupNode); + const groupDeprecated = isNodeFullyDeprecated(groupNode); + lines.push(`/// Handles \`${groupName}\` client global API methods.`); + if (groupExperimental) { + pushExperimentalAttribute(lines); + } + if (groupDeprecated) { + pushObsoleteAttributes(lines); + } + lines.push(`public interface ${interfaceName}`); + lines.push(`{`); + for (const method of methods) { + const effectiveParams = resolveMethodParamsSchema(method); + const hasParams = !!effectiveParams?.properties && Object.keys(effectiveParams.properties).length > 0; + const resultSchema = getMethodResultSchema(method); + const taskType = resultTaskType(method); + pushRpcMethodXmlDocs( + lines, + method, + " ", + [ + ...(hasParams ? [{ name: "request", description: rpcParamsDescription(method, effectiveParams) }] : []), + { name: "cancellationToken", description: CANCELLATION_TOKEN_DESCRIPTION, escapeDescription: false }, + ], + resultSchema, + `Handles "${method.rpcMethod}".` + ); + if (method.stability === "experimental" && !groupExperimental) { + pushExperimentalAttribute(lines, " "); + } + if (method.deprecated && !groupDeprecated) { + pushObsoleteAttributes(lines, " "); + } + if (hasParams) { + lines.push(` ${taskType} ${clientHandlerMethodName(method.rpcMethod)}(${paramsTypeName(method)} request, CancellationToken cancellationToken = default);`); + } else { + lines.push(` ${taskType} ${clientHandlerMethodName(method.rpcMethod)}(CancellationToken cancellationToken = default);`); + } + } + lines.push(`}`); + lines.push(""); + } + + lines.push(`/// Provides all client global API handler groups for a connection.`); + lines.push(`public sealed class ClientGlobalApiHandlers`); + lines.push(`{`); + for (const { groupName } of groups) { + lines.push(` /// Optional handler for ${toPascalCase(groupName)} client global API methods.`); + lines.push(` public ${clientHandlerInterfaceName(groupName)}? ${toPascalCase(groupName)} { get; set; }`); + lines.push(""); + } + if (lines[lines.length - 1] === "") lines.pop(); + lines.push(`}`); + lines.push(""); + + lines.push(`/// Registers client global API handlers on a JSON-RPC connection.`); + lines.push(`internal static class ClientGlobalApiRegistration`); + lines.push(`{`); + lines.push(` /// `); + lines.push(` /// Registers handlers for server-to-client global API calls.`); + lines.push(` /// Unlike client session APIs, these methods carry no implicit`); + lines.push(` /// sessionId dispatch key — a single set of handlers serves the`); + lines.push(` /// entire connection.`); + lines.push(` /// `); + lines.push(` public static void RegisterClientGlobalApiHandlers(JsonRpc rpc, ClientGlobalApiHandlers handlers)`); + lines.push(` {`); + for (const { groupName, methods } of groups) { + for (const method of methods) { + const handlerProperty = toPascalCase(groupName); + const handlerMethod = clientHandlerMethodName(method.rpcMethod); + const effectiveParams = resolveMethodParamsSchema(method); + const hasParams = !!effectiveParams?.properties && Object.keys(effectiveParams.properties).length > 0; + const resultSchema = getMethodResultSchema(method); + const paramsClass = paramsTypeName(method); + const taskType = handlerTaskType(method); + + if (hasParams) { + lines.push(` rpc.SetLocalRpcMethod("${method.rpcMethod}", (Func<${paramsClass}, CancellationToken, ${taskType}>)(async (request, cancellationToken) =>`); + lines.push(` {`); + lines.push(` var handler = handlers.${handlerProperty} ?? throw new InvalidOperationException("No ${groupName} client-global handler registered");`); + if (!isVoidSchema(resultSchema)) { + lines.push(` return await handler.${handlerMethod}(request, cancellationToken);`); + } else { + lines.push(` await handler.${handlerMethod}(request, cancellationToken);`); + } + lines.push(` }), singleObjectParam: true);`); + } else { + lines.push(` rpc.SetLocalRpcMethod("${method.rpcMethod}", (Func)(async cancellationToken =>`); + lines.push(` {`); + lines.push(` var handler = handlers.${handlerProperty} ?? throw new InvalidOperationException("No ${groupName} client-global handler registered");`); + if (!isVoidSchema(resultSchema)) { + lines.push(` return await handler.${handlerMethod}(cancellationToken);`); + } else { + lines.push(` await handler.${handlerMethod}(cancellationToken);`); + } + lines.push(` }));`); + } + } + } + lines.push(` }`); + lines.push(`}`); + + return lines; +} + function generateRpcCode( schema: ApiSchema, externalJsonSerializableRefs: Map> = new Map(), @@ -2315,6 +2451,7 @@ function generateRpcCode( ...collectRpcMethods(schema.server || {}), ...collectRpcMethods(schema.session || {}), ...collectRpcMethods(schema.clientSession || {}), + ...collectRpcMethods(schema.clientGlobal || {}), ]; for (const name of collectRpcMethodReferencedDefinitionNames( allMethods.filter((method) => method.stability !== "experimental"), @@ -2343,6 +2480,9 @@ function generateRpcCode( let clientSessionParts: string[] = []; if (schema.clientSession) clientSessionParts = emitClientSessionApiRegistration(schema.clientSession, classes); + let clientGlobalParts: string[] = []; + if (schema.clientGlobal) clientGlobalParts = emitClientGlobalApiRegistration(schema.clientGlobal, classes); + const lines: string[] = []; lines.push(`${COPYRIGHT} @@ -2368,6 +2508,7 @@ namespace GitHub.Copilot.Rpc; for (const part of serverRpcParts) lines.push(part, ""); for (const part of sessionRpcParts) lines.push(part, ""); if (clientSessionParts.length > 0) lines.push(...clientSessionParts, ""); + if (clientGlobalParts.length > 0) lines.push(...clientGlobalParts, ""); // Add JsonSerializerContext for AOT/trimming support const typeNames = [...emittedRpcClassSchemas.keys(), ...emittedRpcEnumResultTypes].sort(); diff --git a/scripts/codegen/go.ts b/scripts/codegen/go.ts index 9c74977fb..5403fb444 100644 --- a/scripts/codegen/go.ts +++ b/scripts/codegen/go.ts @@ -3952,7 +3952,7 @@ async function generateRpc(schemaPath?: string): Promise { if (generatedTypeCode.includes("time.Time")) { imports.push(`"time"`); } - if (schema.clientSession) { + if (schema.clientSession || schema.clientGlobal) { imports.push(`"errors"`, `"fmt"`); } imports.push(`"github.com/github/copilot-sdk/go/internal/jsonrpc2"`); @@ -3987,6 +3987,10 @@ async function generateRpc(schemaPath?: string): Promise { emitClientSessionApiRegistration(lines, schema.clientSession, resolveType, generatedRpcCode.discriminatedUnions); } + if (schema.clientGlobal) { + emitClientGlobalApiRegistration(lines, schema.clientGlobal, resolveType, generatedRpcCode.discriminatedUnions); + } + const outPath = await writeGeneratedFile("go/rpc/zrpc.go", wrapGeneratedGoComments(lines.join("\n"))); console.log(` ✓ ${outPath}`); @@ -4348,7 +4352,106 @@ function emitClientSessionApiRegistration(lines: string[], clientSchema: Record< lines.push(``); } -// ── Main ──────────────────────────────────────────────────────────────────── +function emitClientGlobalApiRegistration(lines: string[], clientSchema: Record, resolveType: (name: string) => string, unionInfos: Map): void { + const groups = collectClientGroups(clientSchema); + + for (const { groupName, groupNode, methods } of groups) { + const interfaceName = clientHandlerInterfaceName(groupName); + const groupExperimental = isNodeFullyExperimental(groupNode); + const groupDeprecated = isNodeFullyDeprecated(groupNode); + if (groupDeprecated) { + pushGoComment(lines, `Deprecated: ${interfaceName} contains deprecated APIs that will be removed in a future version.`); + } + if (groupExperimental) { + pushGoExperimentalApiComment(lines, interfaceName); + } + lines.push(`type ${interfaceName} interface {`); + for (const method of methods) { + const resultSchema = getMethodResultSchema(method); + pushGoRpcMethodComment( + lines, + clientHandlerMethodName(method.rpcMethod), + method, + resultSchema, + goRpcParamsDescription(method, getMethodParamsSchema(method)), + "\t", + "handles" + ); + if (method.deprecated && !groupDeprecated) { + pushGoComment(lines, `Deprecated: ${clientHandlerMethodName(method.rpcMethod)} is deprecated and will be removed in a future version.`, "\t"); + } + if (method.stability === "experimental" && !groupExperimental) { + pushGoExperimentalMethodComment(lines, clientHandlerMethodName(method.rpcMethod), "\t"); + } + const paramsType = resolveType(goParamsTypeName(method)); + const nullableInner = resultSchema ? getNullableInner(resultSchema) : undefined; + let returnType: string; + if (isOpaqueJson(resultSchema)) { + returnType = "any"; + } else { + const resultType = nullableInner + ? resolveType(goNullableResultTypeName(method, nullableInner)) + : resolveType(goResultTypeName(method)); + returnType = unionInfos.has(resultType) ? resultType : `*${resultType}`; + } + lines.push(`\t${clientHandlerMethodName(method.rpcMethod)}(request *${paramsType}) (${returnType}, error)`); + } + lines.push(`}`); + lines.push(``); + } + + lines.push(`// ClientGlobalAPIHandlers provides all client-global API handler groups.`); + lines.push(`//`); + lines.push(`// Unlike client-session handlers these carry no implicit session id dispatch`); + lines.push(`// key; a single set of handlers serves the entire connection.`); + lines.push(`type ClientGlobalAPIHandlers struct {`); + for (const { groupName } of groups) { + lines.push(`\t${toGoFieldName(groupName)} ${clientHandlerInterfaceName(groupName)}`); + } + lines.push(`}`); + lines.push(``); + + lines.push(`func clientGlobalHandlerError(err error) *jsonrpc2.Error {`); + lines.push(`\tif err == nil {`); + lines.push(`\t\treturn nil`); + lines.push(`\t}`); + lines.push(`\tvar rpcErr *jsonrpc2.Error`); + lines.push(`\tif errors.As(err, &rpcErr) {`); + lines.push(`\t\treturn rpcErr`); + lines.push(`\t}`); + lines.push(`\treturn &jsonrpc2.Error{Code: -32603, Message: err.Error()}`); + lines.push(`}`); + lines.push(``); + + lines.push(`// RegisterClientGlobalAPIHandlers registers handlers for server-to-client client-global API calls.`); + lines.push(`func RegisterClientGlobalAPIHandlers(client *jsonrpc2.Client, handlers *ClientGlobalAPIHandlers) {`); + for (const { groupName, methods } of groups) { + const handlerField = toGoFieldName(groupName); + for (const method of methods) { + const paramsType = resolveType(goParamsTypeName(method)); + lines.push(`\tclient.SetRequestHandler("${method.rpcMethod}", func(params json.RawMessage) (json.RawMessage, *jsonrpc2.Error) {`); + lines.push(`\t\tvar request ${paramsType}`); + lines.push(`\t\tif err := json.Unmarshal(params, &request); err != nil {`); + lines.push(`\t\t\treturn nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("Invalid params: %v", err)}`); + lines.push(`\t\t}`); + lines.push(`\t\tif handlers == nil || handlers.${handlerField} == nil {`); + lines.push(`\t\t\treturn nil, &jsonrpc2.Error{Code: -32603, Message: "No ${groupName} client-global handler registered"}`); + lines.push(`\t\t}`); + lines.push(`\t\tresult, err := handlers.${handlerField}.${clientHandlerMethodName(method.rpcMethod)}(&request)`); + lines.push(`\t\tif err != nil {`); + lines.push(`\t\t\treturn nil, clientGlobalHandlerError(err)`); + lines.push(`\t\t}`); + lines.push(`\t\traw, err := json.Marshal(result)`); + lines.push(`\t\tif err != nil {`); + lines.push(`\t\t\treturn nil, &jsonrpc2.Error{Code: -32603, Message: fmt.Sprintf("Failed to marshal response: %v", err)}`); + lines.push(`\t\t}`); + lines.push(`\t\treturn raw, nil`); + lines.push(`\t})`); + } + } + lines.push(`}`); + lines.push(``); +} async function generate(sessionSchemaPath?: string, apiSchemaPath?: string): Promise { let apiSchemaForSharing: ApiSchema | undefined; diff --git a/scripts/codegen/python.ts b/scripts/codegen/python.ts index 783ac8244..e0c4e2141 100644 --- a/scripts/codegen/python.ts +++ b/scripts/codegen/python.ts @@ -3190,6 +3190,9 @@ def _patch_model_capabilities(data: dict) -> dict: if (schema.clientSession) { emitClientSessionApiRegistration(lines, schema.clientSession, resolveType); } + if (schema.clientGlobal) { + emitClientGlobalApiRegistration(lines, schema.clientGlobal, resolveType); + } // Patch models.list to normalize capabilities before deserialization let finalCode = lines.join("\n"); @@ -3712,7 +3715,107 @@ function emitClientSessionRegistrationMethod( lines.push(` client.set_request_handler("${method.rpcMethod}", ${handlerVariableName})`); } -// ── Main ──────────────────────────────────────────────────────────────────── +function emitClientGlobalApiRegistration( + lines: string[], + node: Record, + resolveType: (name: string) => string +): void { + const groups = Object.entries(node).filter(([, value]) => typeof value === "object" && value !== null && !isRpcMethod(value)); + + for (const [groupName, groupNode] of groups) { + const handlerName = `${toPascalCase(groupName)}Handler`; + const groupExperimental = isNodeFullyExperimental(groupNode as Record); + const groupDeprecated = isNodeFullyDeprecated(groupNode as Record); + if (groupDeprecated) { + lines.push(`# Deprecated: this API group is deprecated and will be removed in a future version.`); + } + if (groupExperimental) { + pushPyExperimentalApiGroupComment(lines); + } + lines.push(`class ${handlerName}(Protocol):`); + const methods = collectRpcMethods(groupNode as Record); + for (const method of methods) { + // Client-global handler methods reuse the session handler shape; the + // only difference is dispatch (no implicit session_id key). + emitClientSessionHandlerMethod(lines, method, resolveType, groupExperimental, groupDeprecated); + } + lines.push(``); + } + + lines.push(`@dataclass`); + lines.push(`class ClientGlobalApiHandlers:`); + if (groups.length === 0) { + lines.push(` pass`); + } else { + for (const [groupName] of groups) { + lines.push(` ${toSnakeCase(groupName)}: ${toPascalCase(groupName)}Handler | None = None`); + } + } + lines.push(``); + + lines.push(`def register_client_global_api_handlers(`); + lines.push(` client: "JsonRpcClient",`); + lines.push(` handlers: ClientGlobalApiHandlers,`); + lines.push(`) -> None:`); + lines.push(` """Register client-global request handlers on a JSON-RPC connection.`); + lines.push(``); + lines.push(` Unlike client-session handlers these methods carry no implicit`); + lines.push(` session_id dispatch key; a single set of handlers serves the entire`); + lines.push(` connection.`); + lines.push(` """`); + if (groups.length === 0) { + lines.push(` return`); + } else { + for (const [groupName, groupNode] of groups) { + const methods = collectRpcMethods(groupNode as Record); + for (const method of methods) { + emitClientGlobalRegistrationMethod(lines, groupName, method, resolveType); + } + } + } + lines.push(``); +} + +function emitClientGlobalRegistrationMethod( + lines: string[], + groupName: string, + method: RpcMethod, + resolveType: (name: string) => string +): void { + const rpcSegments = method.rpcMethod.split("."); + const handlerVariableName = `handle_${rpcSegments.map(toSnakeCase).join("_")}`; + const paramsType = resolveType(pythonParamsTypeName(method)); + const resultSchema = getMethodResultSchema(method); + const nullableInner = resultSchema ? getNullableInner(resultSchema) : undefined; + const hasResult = !isVoidSchema(resultSchema) && !nullableInner; + const handlerField = toSnakeCase(groupName); + const handlerMethod = clientSessionHandlerMethodName(method.rpcMethod); + + lines.push(` async def ${handlerVariableName}(params: dict) -> dict | None:`); + lines.push(` request = ${paramsType}.from_dict(params)`); + lines.push(` handler = handlers.${handlerField}`); + lines.push(` if handler is None: raise RuntimeError("No ${handlerField} client-global handler registered")`); + if (hasResult) { + lines.push(` result = await handler.${handlerMethod}(request)`); + if (isObjectSchema(resultSchema)) { + lines.push(` return result.to_dict()`); + } else { + lines.push(` return result.value if hasattr(result, 'value') else result`); + } + } else if (nullableInner) { + lines.push(` result = await handler.${handlerMethod}(request)`); + const resolvedInner = resolveSchema(nullableInner, rpcDefinitions) ?? nullableInner; + if (isObjectSchema(resolvedInner) || nullableInner.$ref) { + lines.push(` return result.to_dict() if result is not None else None`); + } else { + lines.push(` return result`); + } + } else { + lines.push(` await handler.${handlerMethod}(request)`); + lines.push(` return None`); + } + lines.push(` client.set_request_handler("${method.rpcMethod}", ${handlerVariableName})`); +} async function generate(sessionSchemaPath?: string, apiSchemaPath?: string): Promise { await generateSessionEvents(sessionSchemaPath); diff --git a/scripts/codegen/typescript.ts b/scripts/codegen/typescript.ts index bba360b47..1303a4979 100644 --- a/scripts/codegen/typescript.ts +++ b/scripts/codegen/typescript.ts @@ -516,7 +516,8 @@ import type { MessageConnection } from "vscode-jsonrpc/node.js"; const allMethods = [...collectRpcMethods(schema.server || {}), ...collectRpcMethods(schema.session || {})]; const clientSessionMethods = collectRpcMethods(schema.clientSession || {}); - const rpcMethods = [...allMethods, ...clientSessionMethods]; + const clientGlobalMethods = collectRpcMethods(schema.clientGlobal || {}); + const rpcMethods = [...allMethods, ...clientSessionMethods, ...clientGlobalMethods]; const seenBlocks = new Map(); // Build a single combined schema with shared definitions and all method types. @@ -717,6 +718,13 @@ function hasInternalMethods(node: Record): boolean { lines.push(...emitClientSessionApiRegistration(schema.clientSession)); } + // Generate client *global* API handler interfaces and registration function. + // Unlike client-session APIs, these methods do not carry a `sessionId` dispatch + // key — the SDK consumer registers a single process-wide handler per group. + if (schema.clientGlobal) { + lines.push(...emitClientGlobalApiRegistration(schema.clientGlobal)); + } + const outPath = await writeGeneratedFile("nodejs/src/generated/rpc.ts", lines.join("\n")); console.log(` ✓ ${outPath}`); } @@ -926,6 +934,105 @@ function emitClientSessionApiRegistration(clientSchema: Record) return lines; } +/** + * Generate handler interfaces and a registration function for client *global* + * API groups. + * + * Unlike client-session APIs, these methods carry no implicit `sessionId` + * dispatch key. The SDK consumer registers a single process-wide handler set + * via `registerClientGlobalApiHandlers`; the runtime dispatcher routes each + * incoming call to the registered handler regardless of which (if any) + * runtime session triggered it. + */ +function emitClientGlobalApiRegistration(clientSchema: Record): string[] { + const lines: string[] = []; + const groups = collectClientGroups(clientSchema); + + for (const [groupName, methods] of groups) { + const interfaceName = toPascalCase(groupName) + "Handler"; + const groupDeprecated = isNodeFullyDeprecated(clientSchema[groupName] as Record); + const groupExperimental = isNodeFullyExperimental(clientSchema[groupName] as Record); + if (groupDeprecated) { + lines.push(`/** @deprecated Handler for \`${groupName}\` client global API methods. */`); + } else if (groupExperimental) { + lines.push(`/** Handler for \`${groupName}\` client global API methods. */`); + lines.push(TS_EXPERIMENTAL_JSDOC); + } else { + lines.push(`/** Handler for \`${groupName}\` client global API methods. */`); + } + lines.push(`export interface ${interfaceName} {`); + for (const method of methods) { + const name = handlerMethodName(method.rpcMethod); + const hasParams = hasSchemaPayload(getMethodParamsSchema(method)); + const pType = hasParams ? paramsTypeName(method) : ""; + const rType = tsResultType(method); + + pushTsRpcMethodJsDoc(lines, " ", method, { + summaryFallback: `Handles \`${method.rpcMethod}\`.`, + paramsName: hasParams ? "params" : undefined, + paramsDescription: rpcParamsDescription(method, getMethodParamsSchema(method)), + includeDeprecated: method.deprecated && !groupDeprecated, + includeExperimental: method.stability === "experimental" && !groupExperimental, + }); + if (hasParams) { + lines.push(` ${name}(params: ${pType}): Promise<${rType}>;`); + } else { + lines.push(` ${name}(): Promise<${rType}>;`); + } + } + lines.push(`}`); + lines.push(""); + } + + lines.push(`/** All client global API handler groups. */`); + lines.push(`export interface ClientGlobalApiHandlers {`); + for (const [groupName] of groups) { + const interfaceName = toPascalCase(groupName) + "Handler"; + lines.push(` ${groupName}?: ${interfaceName};`); + } + lines.push(`}`); + lines.push(""); + + lines.push(`/**`); + lines.push(` * Register client global API handlers on a JSON-RPC connection.`); + lines.push(` * The server calls these methods to delegate work to the client.`); + lines.push(` * Unlike session-scoped client APIs, these methods carry no implicit`); + lines.push(` * \`sessionId\` dispatch key — a single set of handlers serves the entire`); + lines.push(` * connection.`); + lines.push(` */`); + lines.push(`export function registerClientGlobalApiHandlers(`); + lines.push(` connection: MessageConnection,`); + lines.push(` handlers: ClientGlobalApiHandlers,`); + lines.push(`): void {`); + + for (const [groupName, methods] of groups) { + for (const method of methods) { + const name = handlerMethodName(method.rpcMethod); + const pType = paramsTypeName(method); + const hasParams = hasSchemaPayload(getMethodParamsSchema(method)); + + if (hasParams) { + lines.push(` connection.onRequest("${method.rpcMethod}", async (params: ${pType}) => {`); + lines.push(` const handler = handlers.${groupName};`); + lines.push(` if (!handler) throw new Error("No ${groupName} client-global handler registered");`); + lines.push(` return handler.${name}(params);`); + lines.push(` });`); + } else { + lines.push(` connection.onRequest("${method.rpcMethod}", async () => {`); + lines.push(` const handler = handlers.${groupName};`); + lines.push(` if (!handler) throw new Error("No ${groupName} client-global handler registered");`); + lines.push(` return handler.${name}();`); + lines.push(` });`); + } + } + } + + lines.push(`}`); + lines.push(""); + + return lines; +} + // ── Main ──────────────────────────────────────────────────────────────────── async function generate(sessionSchemaPath?: string, apiSchemaPath?: string): Promise { diff --git a/scripts/codegen/utils.ts b/scripts/codegen/utils.ts index 3917dad44..726485c93 100644 --- a/scripts/codegen/utils.ts +++ b/scripts/codegen/utils.ts @@ -470,6 +470,7 @@ export interface ApiSchema { server?: Record; session?: Record; clientSession?: Record; + clientGlobal?: Record; } export function isRpcMethod(node: unknown): node is RpcMethod { @@ -519,6 +520,7 @@ export function fixNullableRequiredRefsInApiSchema(schema: ApiSchema): ApiSchema server: walkApiNode(schema.server), session: walkApiNode(schema.session), clientSession: walkApiNode(schema.clientSession), + clientGlobal: walkApiNode(schema.clientGlobal), }; }