Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions .github/workflows/blue-pr-prechecks.yml
Original file line number Diff line number Diff line change
@@ -1,24 +1,20 @@
name: blue-pr-prechecks

on:
pull_request:
types:
- opened
- synchronize
- reopened
- ready_for_review
- edited
workflow_dispatch:

permissions:
contents: read
pull-requests: read

concurrency:
group: blue-pr-prechecks-${{ github.event.pull_request.number }}
group: blue-pr-prechecks-${{ github.event.pull_request.number || github.run_id }}
cancel-in-progress: true

jobs:
prechecks:
# Temporarily disabled while blue submissions are inactive.
if: ${{ false }}
runs-on: ubuntu-latest
timeout-minutes: 20
env:
Expand Down
11 changes: 4 additions & 7 deletions .github/workflows/blue-pr-sync.yml
Original file line number Diff line number Diff line change
@@ -1,23 +1,20 @@
name: blue-pr-sync

on:
pull_request_target:
types:
- opened
- synchronize
- reopened
- edited
workflow_dispatch:

permissions:
contents: read
pull-requests: read

concurrency:
group: blue-pr-sync-${{ github.event.pull_request.number }}
group: blue-pr-sync-${{ github.event.pull_request.number || github.run_id }}
cancel-in-progress: true

jobs:
gate:
# Temporarily disabled while blue submissions are inactive.
if: ${{ false }}
runs-on: ubuntu-latest
timeout-minutes: 10
outputs:
Expand Down
19 changes: 3 additions & 16 deletions .github/workflows/trusted-blue-eval.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,35 +12,22 @@ on:
required: false
default: true
type: boolean
# Auto-fire after blue-pr-sync registers a new patch in the API queue.
# The claim step returns HTTP 204 when the queue is empty, so spurious
# triggers are no-ops.
workflow_run:
workflows: ["blue-pr-sync"]
types: [completed]

permissions:
contents: read

env:
KERNELGUARD_API_BASE_URL: https://kguard.sinatras.dev

# Per-trigger concurrency group so multiple PRs syncing within seconds of
# each other do not cancel each other out (GitHub allows only 1 pending +
# 1 running per group, beyond which new runs auto-cancel). Each
# blue-pr-sync run carries the candidate head SHA on
# event.workflow_run.head_sha; manual workflow_dispatch falls back to the
# unique run id so manual drains never collide with each other either.
# Manual workflow_dispatch falls back to the unique run id.
concurrency:
group: trusted-blue-eval-${{ github.event.workflow_run.head_sha || github.run_id }}
cancel-in-progress: false

jobs:
claim:
# Skip when triggered by a failed blue-pr-sync (rejected gate, static
# security failure, etc.). Manual workflow_dispatch always passes this
# check because workflow_run context is empty in that case.
if: ${{ github.event_name != 'workflow_run' || github.event.workflow_run.conclusion == 'success' }}
# Temporarily disabled while blue submissions are inactive.
if: ${{ false }}
runs-on: ubuntu-latest
timeout-minutes: 10
environment: trusted-blue-eval
Expand Down
157 changes: 156 additions & 1 deletion kernelguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2846,6 +2846,149 @@ def detect_hardcoded_shapes(code: str | SubmissionFacts) -> list[dict]:
return matches


def detect_benchmark_unsafe_algo_dispatch(code: str | SubmissionFacts) -> list[dict]:
"""Shape/data-family dispatch into unsafe specialized algorithms.

Per-shape tuning is legitimate. This detector requires the stronger
combination: benchmark-shaped branches plus data-family classifiers and
fast-path terms that imply reduced precision, truncated work, or a generic
exact fallback for cases outside the public fixture.
"""
facts = ensure_submission_facts(code)
raw_code = facts.raw_code
active_code = facts.python_active
combined = f"{active_code}\n{raw_code}"
low = combined.lower()
matches = []

shape_gate_count = len(re.findall(
r'(?:\(\s*[Bb]\s*,\s*[Nn]\s*\)\s*==\s*\(\s*\d{1,5}\s*,\s*\d{1,5}\s*\)|'
r'\b(?:B|b|batch|N|n)\s*(?:==|!=|<=|>=|<|>)\s*\d{1,5})',
combined,
))
benchmark_constant_count = len(re.findall(
r'\b(?:60|176|352|512|640|768|1024|2048|4096|1280)\b',
combined,
))
family_terms = re.findall(
r'\b(?:classify_512|classify_1024|detect_label|rankdef|rank_def|clustered|'
r'nearrank|near_rank|nearcol|near_col|mixed|rowscale|colnorm|zerofrac|'
r'colrange|tail_max|neardiff|inactive_tail)\b',
low,
)
unsafe_terms = re.findall(
r'\b(?:zero_tail|trunc|truncate|stop_at|switch_at|ncap|ncols|limit|'
r'_trunc_stop|tf32|fp16|bf16|fast_path|unsafe|rankdef)\b',
low,
)
fallback_terms = re.findall(
r'\b(?:torch\.geqrf|torch\.linalg\.(?:qr|lu)|fallback|fall back|exact|generic)\b',
combined,
re.IGNORECASE,
)
raw_native_cache = (
bool(re.search(r'\bnvcc\b', low))
and bool(re.search(r'\b(?:ctypes|CDLL|subprocess\.run|Popen|\.so|\.cu)\b', combined))
and not bool(re.search(r'\bload_inline\s*\(', active_code))
and len(re.findall(r'\b(?:data_ptr|ctypes\.c_void_p|uint64_t|void\s*\*)\b', combined)) >= 4
)

strong_shape_fixture = shape_gate_count >= 8 and benchmark_constant_count >= 20
data_family_routing = len(family_terms) >= 4 and len(unsafe_terms) >= 4
exact_escape_hatch = len(fallback_terms) >= 2

if raw_native_cache and strong_shape_fixture and data_family_routing:
signals = [
"raw native compile/load cache",
f"{shape_gate_count} shape gates",
f"{len(set(family_terms))} data-family terms",
f"{len(set(unsafe_terms))} unsafe/truncation terms",
]
if exact_escape_hatch:
signals.append("exact/generic fallback path")
matches.append({
"pattern": "BENCHMARK_UNSAFE_ALGO_DISPATCH",
"severity": "high",
"evidence": "; ".join(signals),
})

return matches


def detect_input_structure_truncation_dispatch(code: str | SubmissionFacts) -> list[dict]:
"""Input probes routing to partial computation or exact subset repair.

This is not a ban on algorithm selection. It targets classifiers/probes of
input structure that steer only a detected subset into truncated fast paths,
tail-zeroing, or reference/exact repair, which is the static shape of a
public-fixture assumption.
"""
facts = ensure_submission_facts(code)
active_code = facts.python_active
raw_code = facts.raw_code
combined = f"{active_code}\n{raw_code}"
low = combined.lower()
matches = []

probe_terms = re.findall(
r'\b(?:probe_cond|probe_n|_probe|_probe_conditioning|_s3_collinear|'
r'spread_thresh|sparse_thresh|corr_thresh|bit0|bit1|classify_512|'
r'classify_1024|detect_label|colnorm|neardiff|zerofrac|tail_max)\b',
low,
)
subset_terms = re.findall(
r'\b(?:bad_idx|good_idx|bad_count|good_count|index_select|index_copy|'
r'scatter_exact|compact_label|panel_labels|label_bad_cache|gather_.*bad|'
r'compact_.*bad)\b',
low,
)
partial_terms = re.findall(
r'(?:tau\s*\[[^\]]*limit\s*:\s*\]\s*=\s*0|H\s*\[[^\]]*limit\s*:\s*\]\s*=\s*0|'
r'\b(?:zero_tail|trunc|truncate|inactive_tail|stop_at|switch_at|ncap|ncols|limit\s*<\s*n)\b)',
combined,
re.IGNORECASE,
)
repair_terms = re.findall(
r'\b(?:repair|exact|torch\.geqrf|torch\.linalg\.(?:qr|lu)|index_copy|index_select)\b',
combined,
re.IGNORECASE,
)
probe_set = set(probe_terms)
subset_set = set(subset_terms)
partial_set = set(partial_terms)
repair_set = set(repair_terms)

probe_driven_repair = (
bool(probe_set & {"probe_cond", "probe_n", "_probe", "_s3_collinear"})
and len(probe_set) >= 4
and len(subset_set) >= 3
and len(repair_set) >= 3
)
classifier_driven_partial = (
bool(probe_set & {"classify_512", "classify_1024", "detect_label"})
and len(probe_set) >= 4
and len(partial_set) >= 3
and len(subset_set) >= 2
)

if probe_driven_repair or classifier_driven_partial:
signals = [
f"{len(probe_set)} structure/probe terms",
f"{len(subset_set)} subset-routing terms",
]
if partial_set:
signals.append(f"{len(partial_set)} partial/truncation terms")
if repair_set:
signals.append(f"{len(repair_set)} exact-repair terms")
matches.append({
"pattern": "INPUT_STRUCTURE_TRUNCATION_DISPATCH",
"severity": "high",
"evidence": "; ".join(signals),
})

return matches


def detect_unsync_multistream(code: str | SubmissionFacts) -> list[dict]:
"""Pattern 14: Unsynchronized multi-stream dispatch to exploit timing.

Expand Down Expand Up @@ -7289,6 +7432,14 @@ class RulePolicy:
"HARDCODED_SHAPES", "performance_heuristic", "telemetry", TELEMETRY_ONLY, (),
(), "downgrade",
),
"BENCHMARK_UNSAFE_ALGO_DISPATCH": RulePolicy(
"BENCHMARK_UNSAFE_ALGO_DISPATCH", "fixture_specialization", "hard", AUTO_FILTER, (),
(), "keep",
),
"INPUT_STRUCTURE_TRUNCATION_DISPATCH": RulePolicy(
"INPUT_STRUCTURE_TRUNCATION_DISPATCH", "fixture_specialization", "hard", AUTO_FILTER, (),
(), "keep",
),
"UNSYNC_MULTISTREAM": RulePolicy(
"UNSYNC_MULTISTREAM", "timing_manipulation", "telemetry", TELEMETRY_ONLY, (),
(), "downgrade",
Expand Down Expand Up @@ -7556,6 +7707,8 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
detect_torch_compile_cache,
detect_cuda_graph_python,
detect_hardcoded_shapes,
detect_benchmark_unsafe_algo_dispatch,
detect_input_structure_truncation_dispatch,
detect_unsync_multistream,
detect_cuda_event_disable_timing,
detect_token_paste_cuda_api,
Expand Down Expand Up @@ -7617,6 +7770,8 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
("torch_compile_cache", detect_torch_compile_cache),
("cuda_graph_python", detect_cuda_graph_python),
("hardcoded_shapes", detect_hardcoded_shapes),
("benchmark_unsafe_algo_dispatch", detect_benchmark_unsafe_algo_dispatch),
("input_structure_truncation_dispatch", detect_input_structure_truncation_dispatch),
("unsync_multistream", detect_unsync_multistream),
("cuda_event_disable_timing", detect_cuda_event_disable_timing),
("token_paste_cuda_api", detect_token_paste_cuda_api),
Expand Down Expand Up @@ -8522,7 +8677,7 @@ def _worker_parquet(args: tuple) -> dict:
"RUNNER_PLAN_CACHE", "CUDA_GRAPH_PYTHON", "CUDA_GRAPH_REPLAY",
"TIMER_MONKEYPATCH", "FAKE_BENCHMARK_EMIT", "STDIO_REDIRECT", "UNSYNC_MULTISTREAM", "CUDA_EVENT_DISABLE_TIMING",
"SCALED_MM_REF", "DECODE_MM_REF", "SILENT_FALLBACK", "REFERENCE_PRECOMPUTE_REPLAY", "TORCH_COMPILE_CACHE",
"HARDCODED_SHAPES", "TRIVIAL_PROBE",
"HARDCODED_SHAPES", "BENCHMARK_UNSAFE_ALGO_DISPATCH", "INPUT_STRUCTURE_TRUNCATION_DISPATCH", "TRIVIAL_PROBE",
"OBFUSCATED_EXEC", "DYNAMIC_EXECUTION", "MODULE_RELOAD", "THREAD_INJECTION", "LAZY_TENSOR",
"TOKEN_PASTE_CUDA_API", "SEQUENCE_BATCH_GRAPH", "PARTIAL_GRAPH_KEY", "RUNTIME_PACKAGE_INSTALL",
"PRECISION_DOWNGRADE", "SCORE_PHYSICS_FLOOR", "SCORE_IMPOSSIBLE", "SCORE_SUSPECT_FLOOR",
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "kernelguard"
version = "0.3.0"
version = "0.3.1"
description = "Rule-based GPU kernel hack detector."
readme = "README.md"
requires-python = ">=3.11"
Expand Down
Loading