diff --git a/kernelguard.py b/kernelguard.py index 275ad4a..79bb805 100644 --- a/kernelguard.py +++ b/kernelguard.py @@ -586,6 +586,336 @@ def _looks_stateful_name(name: str) -> bool: return any(token in lowered for token in ("last", "prev", "cache", "saved", "memo")) +_ENTRYPOINT_METHOD_NAMES = ("__call__", "forward", "run", "solve") + + +def _iter_non_nested_nodes(node: ast.AST): + """Yield descendants without descending into nested function/class scopes.""" + for child in ast.iter_child_nodes(node): + if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Lambda)): + yield child + continue + yield child + yield from _iter_non_nested_nodes(child) + + +def _function_input_names(fn: ast.FunctionDef | ast.AsyncFunctionDef) -> set[str]: + args = list(fn.args.posonlyargs) + list(fn.args.args) + list(fn.args.kwonlyargs) + if fn.name in _ENTRYPOINT_METHOD_NAMES and args and args[0].arg in {"self", "cls"}: + args = args[1:] + names = {arg.arg for arg in args} + if fn.args.vararg is not None: + names.add(fn.args.vararg.arg) + if fn.args.kwarg is not None: + names.add(fn.args.kwarg.arg) + return names + + +def _method_from_class(cls: ast.ClassDef, preferred: tuple[str, ...] = _ENTRYPOINT_METHOD_NAMES): + methods = { + child.name: child + for child in cls.body + if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)) + } + for name in preferred: + if name in methods: + return methods[name] + return None + + +def _factory_returned_function(fn: ast.FunctionDef | ast.AsyncFunctionDef): + nested = { + child.name: child + for child in fn.body + if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)) + } + for stmt in fn.body: + if isinstance(stmt, ast.Return) and isinstance(stmt.value, ast.Name): + returned = nested.get(stmt.value.id) + if returned is not None: + return returned + return None + + +def _entrypoint_function_nodes(facts: SubmissionFacts) -> list[ast.FunctionDef | ast.AsyncFunctionDef]: + """Resolve simple Python callable exports for entrypoint-scoped detectors.""" + tree = facts.ast_tree + if tree is None: + return [] + + functions: dict[str, ast.FunctionDef | ast.AsyncFunctionDef] = {} + classes: dict[str, ast.ClassDef] = {} + instances: dict[str, str] = {} + aliases: dict[str, str] = {} + resolved: list[ast.FunctionDef | ast.AsyncFunctionDef] = [] + seen: set[int] = set() + + for stmt in tree.body: + if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)): + functions[stmt.name] = stmt + elif isinstance(stmt, ast.ClassDef): + classes[stmt.name] = stmt + + def add(fn: ast.FunctionDef | ast.AsyncFunctionDef | None) -> None: + if fn is None or id(fn) in seen: + return + seen.add(id(fn)) + resolved.append(fn) + + def resolve_name(name: str) -> str: + while name in aliases and aliases[name] != name: + name = aliases[name] + return name + + for stmt in tree.body: + if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)) and is_entrypoint_name(stmt.name): + add(stmt) + elif isinstance(stmt, ast.ClassDef) and is_entrypoint_name(stmt.name): + add(_method_from_class(stmt)) + + if not isinstance(stmt, ast.Assign): + continue + target_names = [t.id for t in stmt.targets if isinstance(t, ast.Name)] + if not target_names: + continue + + value = stmt.value + if isinstance(value, ast.Name): + value_name = resolve_name(value.id) + for target in target_names: + aliases[target] = value_name + if is_entrypoint_name(target): + add(functions.get(value_name)) + add(_method_from_class(classes[value_name]) if value_name in classes else None) + elif value_name in classes: + instances[target] = value_name + elif isinstance(value, ast.Call): + callee = value.func + if isinstance(callee, ast.Name): + callee_name = resolve_name(callee.id) + if callee_name in classes: + for target in target_names: + instances[target] = callee_name + if is_entrypoint_name(target): + add(_method_from_class(classes[callee_name])) + elif callee_name == "partial" and value.args and isinstance(value.args[0], ast.Name): + fn = functions.get(resolve_name(value.args[0].id)) + for target in target_names: + if is_entrypoint_name(target): + add(fn) + elif callee_name in functions: + for target in target_names: + if is_entrypoint_name(target): + add(_factory_returned_function(functions[callee_name]) or functions[callee_name]) + elif isinstance(callee, ast.Attribute): + if callee.attr == "partial" and value.args and isinstance(value.args[0], ast.Name): + fn = functions.get(resolve_name(value.args[0].id)) + for target in target_names: + if is_entrypoint_name(target): + add(fn) + owner = callee.value + if isinstance(owner, ast.Call) and isinstance(owner.func, ast.Name): + class_name = resolve_name(owner.func.id) + if class_name in classes and callee.attr in _ENTRYPOINT_METHOD_NAMES: + for target in target_names: + if is_entrypoint_name(target): + add(_method_from_class(classes[class_name], (callee.attr,))) + elif isinstance(value, ast.Attribute): + owner = value.value + if value.attr in _ENTRYPOINT_METHOD_NAMES and isinstance(owner, ast.Name): + owner_name = resolve_name(owner.id) + class_name = instances.get(owner_name, owner_name if owner_name in classes else "") + if class_name in classes: + for target in target_names: + if is_entrypoint_name(target): + add(_method_from_class(classes[class_name], (value.attr,))) + + return resolved + + +def _expr_is_none(expr: ast.AST | None) -> bool: + return isinstance(expr, ast.Constant) and expr.value is None + + +def _static_string(expr: ast.AST | None) -> Optional[str]: + if isinstance(expr, ast.Constant) and isinstance(expr.value, str): + return expr.value + if isinstance(expr, ast.JoinedStr): + parts: list[str] = [] + for value in expr.values: + if not isinstance(value, ast.Constant) or not isinstance(value.value, str): + return None + parts.append(value.value) + return "".join(parts) + if isinstance(expr, ast.BinOp) and isinstance(expr.op, ast.Add): + left = _static_string(expr.left) + right = _static_string(expr.right) + if left is not None and right is not None: + return left + right + if ( + isinstance(expr, ast.Call) + and isinstance(expr.func, ast.Attribute) + and expr.func.attr == "join" + and _static_string(expr.func.value) is not None + and len(expr.args) == 1 + and isinstance(expr.args[0], (ast.List, ast.Tuple)) + ): + parts = [_static_string(elt) for elt in expr.args[0].elts] + if all(part is not None for part in parts): + return _static_string(expr.func.value).join(parts) # type: ignore[arg-type] + return None + + +def _expr_has_benchmark_literal(expr: ast.AST | None) -> bool: + if expr is None: + return False + for node in ast.walk(expr): + if isinstance(node, ast.Constant) and isinstance(node.value, str): + if re.search(r'Ranked\s+Benchmark|BENCHMARK_PASSED|\bbenchmark\b\s*[:=]?|score\s*[:=]', node.value, re.IGNORECASE): + return True + return False + + +def _expr_has_decode_like_call(expr: ast.AST | None, helper_names: set[str] | None = None) -> bool: + if expr is None: + return False + helper_names = helper_names or set() + decode_names = { + "decode", "decompress", "b64decode", "b32decode", "b16decode", + "urlsafe_b64decode", "decodebytes", "decodestring", "unhexlify", + "a2b_hex", "a2b_base64", "bytes", "bytearray", "chr", + } + for node in ast.walk(expr): + if not isinstance(node, ast.Call): + continue + if isinstance(node.func, ast.Name) and (node.func.id in decode_names or node.func.id in helper_names): + return True + if isinstance(node.func, ast.Attribute) and node.func.attr in decode_names: + return True + return False + + +def _expr_contains_input_derived_call(expr: ast.AST | None, input_names: set[str]) -> bool: + if expr is None: + return False + for node in ast.walk(expr): + if isinstance(node, ast.Call): + if isinstance(node.func, ast.Attribute) and _ast_root_name(node.func.value) in input_names: + return True + if isinstance(node.func, ast.Name) and any(_expr_names(arg) & input_names for arg in node.args): + return True + return False + + +def _is_input_float_call(expr: ast.AST | None, input_names: set[str]) -> bool: + return ( + isinstance(expr, ast.Call) + and isinstance(expr.func, ast.Attribute) + and expr.func.attr == "float" + and isinstance(expr.func.value, ast.Name) + and expr.func.value.id in input_names + and not expr.args + and not expr.keywords + ) + + +def _is_input_attr_float_call(expr: ast.AST | None, owner_name: str) -> Optional[str]: + if not ( + isinstance(expr, ast.Call) + and isinstance(expr.func, ast.Attribute) + and expr.func.attr == "float" + and isinstance(expr.func.value, ast.Attribute) + and isinstance(expr.func.value.value, ast.Name) + and expr.func.value.value.id == owner_name + and not expr.args + and not expr.keywords + ): + return None + return expr.func.value.attr + + +def _lambda_input_names(expr: ast.Lambda) -> set[str]: + args = list(expr.args.posonlyargs) + list(expr.args.args) + list(expr.args.kwonlyargs) + if args and args[0].arg in {"self", "cls"}: + args = args[1:] + return {arg.arg for arg in args} + + +def _lambda_returns_input_float(expr: ast.AST | None) -> bool: + if not isinstance(expr, ast.Lambda): + return False + input_names = _lambda_input_names(expr) + return bool(input_names) and _is_input_float_call(expr.body, input_names) + + +def _torch_alias_sets(facts: SubmissionFacts) -> tuple[set[str], dict[str, str]]: + torch_aliases = {"torch"} + imported: dict[str, str] = {} + tree = facts.ast_tree + if tree is None: + return torch_aliases, imported + for node in facts._imports: + for alias in node.names: + if alias.name == "torch": + torch_aliases.add(alias.asname or alias.name) + for node in facts._import_froms: + if node.module != "torch": + continue + for alias in node.names: + imported[alias.asname or alias.name] = alias.name + for stmt in tree.body: + if not isinstance(stmt, ast.Assign): + continue + dotted = _ast_dotted_name(stmt.value) + if dotted and dotted.startswith(tuple(f"{alias}." for alias in torch_aliases)): + attr = dotted.rsplit(".", 1)[-1] + for target in stmt.targets: + if isinstance(target, ast.Name): + imported[target.id] = attr + return torch_aliases, imported + + +def _torch_call_name(call: ast.Call, torch_aliases: set[str], imported: dict[str, str]) -> Optional[str]: + if isinstance(call.func, ast.Name): + return imported.get(call.func.id) + if isinstance(call.func, ast.Attribute) and isinstance(call.func.value, ast.Name): + if call.func.value.id in torch_aliases: + return call.func.attr + return None + + +def _is_transpose_of(expr: ast.AST | None, name: str) -> bool: + return ( + isinstance(expr, ast.Call) + and isinstance(expr.func, ast.Attribute) + and expr.func.attr in {"t", "transpose", "transpose_"} + and isinstance(expr.func.value, ast.Name) + and expr.func.value.id == name + ) + + +def _looks_output_value_name(name: str) -> bool: + lowered = name.lower() + return ( + lowered in {"h", "q", "r", "y", "z"} + or any(token in lowered for token in ("out", "output", "result", "res", "ans")) + ) + + +def _cache_name_allows_output_replay(name: Optional[str]) -> bool: + if not name: + return False + lowered = name.lower() + if any(token in lowered for token in ( + "compile", "compiled", "module", "plan", "workspace", "preprocess", + "scale", "decoded", "meta", "sort", "view", "shape", "quant", "pad", + )): + return False + return any(token in lowered for token in ( + "cache", "reuse", "saved", "memo", "result", "out", "compute", "gemm", "bank", + )) + + # --------------------------------------------------------------------------- # Detectors # --------------------------------------------------------------------------- @@ -625,6 +955,87 @@ def detect_timer_monkeypatch(code: str | SubmissionFacts) -> list[dict]: "severity": "critical", "evidence": "Dummy/Fake Event class defined", }) + tree = facts.ast_tree + if tree is not None: + time_aliases = {"time"} + patch_aliases = {"patch"} + mock_module_aliases = {"mock"} + for node in facts._imports: + for alias in node.names: + if alias.name == "time": + time_aliases.add(alias.asname or alias.name) + elif alias.name == "unittest.mock": + mock_module_aliases.add(alias.asname or alias.name) + for node in facts._import_froms: + if node.module == "unittest.mock": + for alias in node.names: + if alias.name == "patch": + patch_aliases.add(alias.asname or alias.name) + elif node.module == "unittest": + for alias in node.names: + if alias.name == "mock": + mock_module_aliases.add(alias.asname or alias.name) + + clock_attrs = { + "perf_counter", "perf_counter_ns", "monotonic", "monotonic_ns", + "process_time", "process_time_ns", "time", "time_ns", + } + seen_time_patches: set[str] = set() + seen_cuda_timer_patches: set[str] = set() + for node in ast.walk(tree): + if isinstance(node, (ast.Assign, ast.AnnAssign, ast.AugAssign)): + targets = node.targets if isinstance(node, ast.Assign) else [node.target] + for target in targets: + if ( + isinstance(target, ast.Attribute) + and isinstance(target.value, ast.Name) + and target.value.id in time_aliases + and target.attr in clock_attrs + ): + seen_time_patches.add(f"{target.value.id}.{target.attr}") + if isinstance(target, ast.Subscript): + key = _static_string(target.slice) + owner = _ast_dotted_name(target.value) + if owner == "torch.cuda.__dict__" and key in {"Event", "synchronize"}: + seen_cuda_timer_patches.add(f"torch.cuda.__dict__[{key!r}]") + elif ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Name) + and node.func.id == "setattr" + and len(node.args) >= 2 + and isinstance(node.args[0], ast.Name) + and node.args[0].id in time_aliases + ): + attr = _static_string(node.args[1]) + if attr in clock_attrs: + seen_time_patches.add(f"setattr({node.args[0].id}, {attr!r}, ...)") + elif isinstance(node, ast.Call): + func_name = _ast_dotted_name(node.func) + is_patch_call = ( + (isinstance(node.func, ast.Name) and node.func.id in patch_aliases) + or ( + isinstance(node.func, ast.Attribute) + and node.func.attr == "patch" + and isinstance(node.func.value, ast.Name) + and node.func.value.id in mock_module_aliases + ) + or func_name == "unittest.mock.patch" + ) + target = _static_string(node.args[0]) if is_patch_call and node.args else None + if target in {"torch.cuda.Event", "torch.cuda.synchronize", "torch.cuda.Event.elapsed_time"}: + seen_cuda_timer_patches.add(f"unittest.mock.patch({target!r})") + if seen_time_patches: + matches.append({ + "pattern": "TIMER_MONKEYPATCH", + "severity": "critical", + "evidence": "host timing function patched: " + ", ".join(sorted(seen_time_patches)), + }) + if seen_cuda_timer_patches: + matches.append({ + "pattern": "TIMER_MONKEYPATCH", + "severity": "critical", + "evidence": "CUDA timing primitive patched: " + ", ".join(sorted(seen_cuda_timer_patches)), + }) return matches @@ -673,6 +1084,73 @@ def add(pattern: str, severity: str, evidence: str): else: add("STDIO_REDIRECT", "medium", "sys.stderr reassigned") + tree = facts.ast_tree + if tree is not None: + sys_aliases = {"sys"} + os_aliases = {"os"} + os_write_aliases = set() + for node in facts._imports: + for alias in node.names: + if alias.name == "sys": + sys_aliases.add(alias.asname or alias.name) + elif alias.name == "os": + os_aliases.add(alias.asname or alias.name) + for node in facts._import_froms: + if node.module == "os": + for alias in node.names: + if alias.name == "write": + os_write_aliases.add(alias.asname or alias.name) + + print_fns = list(_entrypoint_function_nodes(facts)) + for node in ast.walk(tree): + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == "kernel": + print_fns.append(node) + + for fn in print_fns: + direct_emits: set[str] = set() + for node in _iter_non_nested_nodes(fn): + if not isinstance(node, ast.Call): + continue + payload = " ".join( + part + for arg in node.args + if (part := _static_string(arg)) is not None + ) + spoof_payload = bool( + re.search(r'Ranked\s+Benchmark|BENCHMARK_PASSED|\bbenchmark\b\s*[:=]?|score\s*[:=]', payload, re.IGNORECASE) + or any(_expr_has_benchmark_literal(arg) for arg in node.args) + ) + if isinstance(node.func, ast.Name): + if node.func.id == "print" and spoof_payload: + direct_emits.add("print(...) in entrypoint") + elif node.func.id in os_write_aliases and node.args: + fd = node.args[0] + if isinstance(fd, ast.Constant) and fd.value in (1, 2) and spoof_payload: + direct_emits.add(f"{node.func.id}({fd.value}, ...) in entrypoint") + elif isinstance(node.func, ast.Attribute): + if ( + node.func.attr == "write" + and isinstance(node.func.value, ast.Attribute) + and node.func.value.attr in {"stdout", "stderr"} + and isinstance(node.func.value.value, ast.Name) + and node.func.value.value.id in sys_aliases + and spoof_payload + ): + direct_emits.add(f"{node.func.value.value.id}.{node.func.value.attr}.write(...) in entrypoint") + elif ( + node.func.attr == "write" + and isinstance(node.func.value, ast.Name) + and node.func.value.id in os_aliases + and node.args + ): + fd = node.args[0] + if isinstance(fd, ast.Constant) and fd.value in (1, 2) and spoof_payload: + direct_emits.add(f"{node.func.value.id}.write({fd.value}, ...) in entrypoint") + if direct_emits: + fake_emit = True + add("FAKE_BENCHMARK_EMIT", "critical", "; ".join(sorted(direct_emits))) + break + # Do not keep the old broad PRINT_INJECTION marker; the split rules carry # the action semantics now. return matches @@ -1114,6 +1592,170 @@ def detect_result_caching(code: str | SubmissionFacts) -> list[dict]: return matches +def detect_helper_output_replay_cache(code: str | SubmissionFacts) -> list[dict]: + """Detect helper-level output caches returned by the submitted entrypoint.""" + facts = ensure_submission_facts(code) + tree = facts.ast_tree + if tree is None: + return [] + + functions: dict[str, ast.FunctionDef | ast.AsyncFunctionDef] = { + node.name: node + for node in ast.walk(tree) + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + } + if not functions: + return [] + + torch_aliases, imported_torch = _torch_alias_sets(facts) + tensor_return_ops = { + "matmul", "mm", "bmm", "zeros", "empty", "ones", "full", + "zeros_like", "empty_like", "ones_like", "full_like", + } + + def tensor_output_expr(expr: ast.AST | None, input_names: set[str], output_names: set[str]) -> bool: + if expr is None: + return False + if _expr_names(expr) & (input_names | output_names): + return True + for call in [n for n in ast.walk(expr) if isinstance(n, ast.Call)]: + if _torch_call_name(call, torch_aliases, imported_torch) in tensor_return_ops: + return True + for kw in call.keywords: + if kw.arg == "device" and _static_string(kw.value) == "cuda": + return True + return False + + def cache_lookup(value: ast.AST | None) -> tuple[Optional[str], bool]: + if isinstance(value, ast.Subscript): + cache_name = _ast_root_name(value.value) + return cache_name, _cache_name_allows_output_replay(cache_name) + if isinstance(value, ast.Call) and isinstance(value.func, ast.Attribute): + if value.func.attr in {"get", "pop"}: + cache_name = _ast_root_name(value.func.value) + return cache_name, _cache_name_allows_output_replay(cache_name) + return None, False + + def helper_has_cache_replay(fn: ast.FunctionDef | ast.AsyncFunctionDef) -> Optional[str]: + input_names = _function_input_names(fn) + output_names: set[str] = set() + cache_hits: set[str] = set() + stored_caches: set[str] = set() + + for stmt in _iter_non_nested_nodes(fn): + if not isinstance(stmt, (ast.Assign, ast.AnnAssign)): + continue + targets = stmt.targets if isinstance(stmt, ast.Assign) else [stmt.target] + value = stmt.value + cache_name, is_cache_lookup = cache_lookup(value) + if is_cache_lookup: + for target in targets: + cache_hits.update(_target_names(target)) + if cache_name: + cache_hits.add(cache_name) + continue + if tensor_output_expr(value, input_names, output_names): + for target in targets: + output_names.update(_target_names(target)) + + for stmt in _iter_non_nested_nodes(fn): + if isinstance(stmt, ast.Assign): + for target in stmt.targets: + if not isinstance(target, ast.Subscript): + continue + cache_name = _ast_root_name(target.value) + if not _cache_name_allows_output_replay(cache_name): + continue + if tensor_output_expr(stmt.value, input_names, output_names): + stored_caches.add(cache_name or "") + elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Subscript): + cache_name = _ast_root_name(stmt.target.value) + if _cache_name_allows_output_replay(cache_name) and tensor_output_expr(stmt.value, input_names, output_names): + stored_caches.add(cache_name or "") + + if not stored_caches: + return None + + for stmt in _iter_non_nested_nodes(fn): + if not isinstance(stmt, ast.If): + continue + for inner in stmt.body: + if isinstance(inner, ast.Return) and inner.value is not None: + cache_name, is_cache_lookup = cache_lookup(inner.value) + if is_cache_lookup and (not cache_name or cache_name in stored_caches): + return cache_name or next(iter(stored_caches)) + if _expr_names(inner.value) & cache_hits: + return next(iter(stored_caches)) + return None + + def helper_has_lru_tensor_return(fn: ast.FunctionDef | ast.AsyncFunctionDef) -> bool: + has_lru = False + for deco in fn.decorator_list: + dotted = _ast_dotted_name(deco.func if isinstance(deco, ast.Call) else deco) + if dotted and dotted.endswith("lru_cache"): + has_lru = True + break + if not has_lru: + return False + input_names = _function_input_names(fn) + for stmt in _iter_non_nested_nodes(fn): + if isinstance(stmt, ast.Return) and tensor_output_expr(stmt.value, input_names, set()): + return True + return False + + def returned_helper_calls(fn: ast.FunctionDef | ast.AsyncFunctionDef) -> set[str]: + assigned: dict[str, str] = {} + returned: set[str] = set() + for stmt in _iter_non_nested_nodes(fn): + if isinstance(stmt, ast.Assign) and isinstance(stmt.value, ast.Call) and isinstance(stmt.value.func, ast.Name): + callee = stmt.value.func.id + if callee in functions: + for target in stmt.targets: + for name in _target_names(target): + assigned[name] = callee + elif isinstance(stmt, ast.Return) and stmt.value is not None: + value = stmt.value + if isinstance(value, ast.Call) and isinstance(value.func, ast.Name) and value.func.id in functions: + returned.add(value.func.id) + elif isinstance(value, ast.Name) and value.id in assigned: + returned.add(assigned[value.id]) + elif isinstance(stmt, ast.If): + for inner in stmt.body + stmt.orelse: + if isinstance(inner, ast.Return) and isinstance(inner.value, ast.Name) and inner.value.id in assigned: + returned.add(assigned[inner.value.id]) + return returned + + pending: list[str] = [] + seen: set[str] = set() + for entrypoint in _entrypoint_function_nodes(facts): + pending.extend(sorted(returned_helper_calls(entrypoint))) + + while pending: + name = pending.pop(0) + if name in seen: + continue + seen.add(name) + fn = functions.get(name) + if fn is None: + continue + cache_name = helper_has_cache_replay(fn) + if cache_name is not None: + return [{ + "pattern": "OUTPUT_REPLAY_CACHE", + "severity": "critical", + "evidence": f"{entrypoint_label(facts.entrypoint_name)} returns helper-level cached output from {name} via {cache_name}", + }] + if helper_has_lru_tensor_return(fn): + return [{ + "pattern": "OUTPUT_REPLAY_CACHE", + "severity": "critical", + "evidence": f"{entrypoint_label(facts.entrypoint_name)} returns tensor output from lru_cache helper {name}", + }] + pending.extend(sorted(returned_helper_calls(fn))) + + return [] + + def detect_value_keyed_output_replay(code: str | SubmissionFacts) -> list[dict]: """Detect cached output replay guarded by full tensor-value equality. @@ -1432,47 +2074,704 @@ def detect_trivial_probe(code: str | SubmissionFacts) -> list[dict]: return matches -def detect_torch_compile_cache(code: str | SubmissionFacts) -> list[dict]: - """Pattern 11: torch.compile for pipeline graph caching.""" - facts = ensure_submission_facts(code) - code = facts.python_only - matches = [] - if RE_TORCH_COMPILE.search(code): - matches.append({ - "pattern": "TORCH_COMPILE_CACHE", - "severity": "low", - "evidence": "torch.compile() used", - }) - return matches +def _input_float_return_from_body(body: list[ast.stmt], input_names: set[str]) -> bool: + aliases: set[str] = set() + saw_return = False + for stmt in body: + if isinstance(stmt, (ast.Import, ast.ImportFrom, ast.Pass)): + continue + if ( + isinstance(stmt, ast.Expr) + and isinstance(stmt.value, ast.Constant) + and isinstance(stmt.value.value, str) + ): + continue + if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): + if _is_input_float_call(stmt.value, input_names): + aliases.add(stmt.targets[0].id) + continue + if ( + isinstance(stmt, ast.With) + and stmt.items + and all(_is_nullcontext_call(item.context_expr) for item in stmt.items) + ): + if _input_float_return_from_body(stmt.body, input_names): + saw_return = True + continue + return False + if isinstance(stmt, ast.Return): + if _is_input_float_call(stmt.value, input_names): + saw_return = True + continue + if isinstance(stmt.value, ast.Name) and stmt.value.id in aliases: + saw_return = True + continue + if isinstance(stmt.value, ast.IfExp): + branches = [stmt.value.body, stmt.value.orelse] + if all( + _is_input_float_call(branch, input_names) + or (isinstance(branch, ast.Name) and branch.id in aliases) + for branch in branches + ): + saw_return = True + continue + return False + return False + return saw_return -def detect_cuda_graph_python(code: str | SubmissionFacts) -> list[dict]: - """Pattern 12: Python CUDAGraph caching with data_ptr keys + .replay().""" - facts = ensure_submission_facts(code) - python_only = facts.python_only - matches = [] - signals = [] +def _none_guarded_return(stmt: ast.stmt, names: set[str]) -> bool: + if not isinstance(stmt, ast.If): + return False + test = stmt.test + if not ( + isinstance(test, ast.Compare) + and isinstance(test.left, ast.Name) + and test.left.id in names + and len(test.ops) == 1 + and isinstance(test.ops[0], ast.IsNot) + and len(test.comparators) == 1 + and _expr_is_none(test.comparators[0]) + and len(stmt.body) == 1 + and isinstance(stmt.body[0], ast.Return) + and isinstance(stmt.body[0].value, ast.Name) + and stmt.body[0].value.id == test.left.id + and not stmt.orelse + ): + return False + return not _body_has_calls(stmt.body) - has_graph = bool(RE_CUDA_GRAPH_PYTHON.search(python_only)) - has_replay = bool(RE_CUDA_GRAPH_REPLAY.search(python_only)) - has_cache_key = bool(RE_CUDA_GRAPH_CACHE_KEY.search(python_only)) - has_graph_cache = bool(RE_GRAPH_CACHE_NAME.search(python_only)) - if has_graph: - signals.append("torch.cuda.CUDAGraph()") - if has_replay: - signals.append(".replay()") - if has_cache_key: - signals.append("data_ptr() as graph cache key") - if has_graph_cache: - signals.append("graph cache dict") +def _optional_cache_input_float_return( + body: list[ast.stmt], + data_names: set[str], + cache_names: set[str], +) -> bool: + aliases: set[str] = set() + saw_guard = False + for stmt in body: + if isinstance(stmt, (ast.Import, ast.ImportFrom, ast.Pass)): + continue + if ( + isinstance(stmt, ast.Expr) + and isinstance(stmt.value, ast.Constant) + and isinstance(stmt.value.value, str) + ): + continue + if not saw_guard and _none_guarded_return(stmt, cache_names): + saw_guard = True + continue + if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): + if _is_input_float_call(stmt.value, data_names): + aliases.add(stmt.targets[0].id) + continue + if isinstance(stmt, ast.Return): + return saw_guard and ( + _is_input_float_call(stmt.value, data_names) + or (isinstance(stmt.value, ast.Name) and stmt.value.id in aliases) + ) + return False + return False - # Need at least CUDAGraph + replay or CUDAGraph + cache key. - # CUDA graph replay is a legitimate optimization (reduces kernel launch - # overhead). Keep at medium — exploits using CUDA graphs would also - # have MODULE_MUTATION / GLOBALS_MUTATION which are critical/high. - if has_graph and (has_replay or has_cache_key): - matches.append({ + +def _global_ifexp_input_float_return( + body: list[ast.stmt], + input_names: set[str], + global_names: set[str], + none_inited: set[str], +) -> bool: + for stmt in body: + if isinstance(stmt, (ast.Import, ast.ImportFrom, ast.Pass, ast.Global)): + continue + if ( + isinstance(stmt, ast.Expr) + and isinstance(stmt.value, ast.Constant) + and isinstance(stmt.value.value, str) + ): + continue + if not isinstance(stmt, ast.Return) or not isinstance(stmt.value, ast.IfExp): + return False + expr = stmt.value + if not ( + isinstance(expr.body, ast.Name) + and expr.body.id in global_names + and expr.body.id in none_inited + and _is_input_float_call(expr.orelse, input_names) + ): + return False + test = expr.test + return ( + isinstance(test, ast.Compare) + and isinstance(test.left, ast.Name) + and test.left.id == expr.body.id + and len(test.ops) == 1 + and isinstance(test.ops[0], ast.IsNot) + and len(test.comparators) == 1 + and _expr_is_none(test.comparators[0]) + ) + return False + + +def _nonlocal_input_float_state( + fn: ast.FunctionDef | ast.AsyncFunctionDef, + local_none: set[str], +) -> Optional[str]: + input_names = _function_input_names(fn) + nonlocal_names = { + name + for stmt in fn.body + if isinstance(stmt, ast.Nonlocal) + for name in stmt.names + } + state_names = nonlocal_names & local_none + for state_name in state_names: + has_guard = False + has_store = False + has_return = False + for stmt in fn.body: + if isinstance(stmt, ast.Nonlocal): + continue + if _none_guarded_return(stmt, {state_name}): + has_guard = True + continue + if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1: + target = stmt.targets[0] + if ( + isinstance(target, ast.Name) + and target.id == state_name + and _is_input_float_call(stmt.value, input_names) + ): + has_store = True + continue + if isinstance(stmt, ast.Return) and isinstance(stmt.value, ast.Name) and stmt.value.id == state_name: + has_return = True + continue + return None + if has_guard and has_store and has_return: + return state_name + return None + + +def _is_nullcontext_call(expr: ast.AST | None) -> bool: + return ( + isinstance(expr, ast.Call) + and ( + (isinstance(expr.func, ast.Name) and expr.func.id == "nullcontext") + or ( + isinstance(expr.func, ast.Attribute) + and expr.func.attr == "nullcontext" + and isinstance(expr.func.value, ast.Name) + and expr.func.value.id == "contextlib" + ) + ) + ) + + +def detect_input_passthrough_output(code: str | SubmissionFacts) -> list[dict]: + """Pattern: entrypoint fabricates output by returning input.float().""" + facts = ensure_submission_facts(code) + tree = facts.ast_tree + if tree is None: + return [] + entrypoint_name = entrypoint_label(facts.entrypoint_name) + function_defs: dict[str, ast.FunctionDef | ast.AsyncFunctionDef] = { + node.name: node + for node in ast.walk(tree) + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + } + + for fn in _entrypoint_function_nodes(facts): + input_names = _function_input_names(fn) + if input_names and _input_float_return_from_body(fn.body, input_names): + return [{ + "pattern": "INPUT_PASSTHROUGH_OUTPUT", + "severity": "critical", + "evidence": f"{entrypoint_name} returns an input tensor cast to float without compute", + }] + + positional_args = list(fn.args.posonlyargs) + list(fn.args.args) + if fn.name in _ENTRYPOINT_METHOD_NAMES and positional_args and positional_args[0].arg in {"self", "cls"}: + positional_args = positional_args[1:] + default_args = positional_args[-len(fn.args.defaults):] if fn.args.defaults else [] + cache_names = { + arg.arg + for arg, default in zip(default_args, fn.args.defaults) + if _expr_is_none(default) + } & input_names + data_names = input_names - cache_names + if cache_names and data_names and _optional_cache_input_float_return(fn.body, data_names, cache_names): + return [{ + "pattern": "INPUT_PASSTHROUGH_OUTPUT", + "severity": "critical", + "evidence": f"{entrypoint_name} returns optional cached output or input.float()", + }] + + global_names = { + name + for stmt in fn.body + if isinstance(stmt, ast.Global) + for name in stmt.names + } + if global_names and _global_ifexp_input_float_return( + fn.body, input_names, global_names, facts._none_inited + ): + return [{ + "pattern": "INPUT_PASSTHROUGH_OUTPUT", + "severity": "critical", + "evidence": f"{entrypoint_name} returns global sentinel output or input.float()", + }] + + local_none = { + target.id + for stmt in fn.body + if isinstance(stmt, ast.Assign) + and _expr_is_none(stmt.value) + for target in stmt.targets + if isinstance(target, ast.Name) + } + if local_none: + nonlocal_helpers = { + stmt.name + for stmt in fn.body + if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)) + and _nonlocal_input_float_state(stmt, local_none) is not None + } + for stmt in fn.body: + if ( + isinstance(stmt, ast.Return) + and isinstance(stmt.value, ast.Call) + and isinstance(stmt.value.func, ast.Name) + and stmt.value.func.id in nonlocal_helpers + and len(stmt.value.args) == 1 + and isinstance(stmt.value.args[0], ast.Name) + and stmt.value.args[0].id in input_names + and not stmt.value.keywords + ): + return [{ + "pattern": "INPUT_PASSTHROUGH_OUTPUT", + "severity": "critical", + "evidence": f"{entrypoint_name} delegates to nonlocal state storing input.float()", + }] + + none_slots = { + target.id + for stmt in fn.body + if isinstance(stmt, ast.Assign) + and _expr_is_none(stmt.value) + for target in stmt.targets + if isinstance(target, ast.Name) + } + for stmt in fn.body: + if not ( + isinstance(stmt, ast.If) + and isinstance(stmt.test, ast.Compare) + and isinstance(stmt.test.left, ast.Name) + and stmt.test.left.id in none_slots + and len(stmt.test.ops) == 1 + and len(stmt.test.comparators) == 1 + ): + continue + slot = stmt.test.left.id + assigned_float = any( + isinstance(inner, ast.Assign) + and any(isinstance(t, ast.Name) and t.id == slot for t in inner.targets) + and _is_input_float_call(inner.value, input_names) + for inner in stmt.body + stmt.orelse + ) + returned_slot = any( + isinstance(inner, ast.Return) + and isinstance(inner.value, ast.Name) + and inner.value.id == slot + for inner in stmt.body + stmt.orelse + ) + if assigned_float and returned_slot: + return [{ + "pattern": "INPUT_PASSTHROUGH_OUTPUT", + "severity": "critical", + "evidence": f"{entrypoint_name} lazily returns cached input.float()", + }] + + return [] + + +def detect_indirect_input_passthrough_output(code: str | SubmissionFacts) -> list[dict]: + """Detect indirect entrypoints that still only return input.float().""" + facts = ensure_submission_facts(code) + tree = facts.ast_tree + if tree is None: + return [] + entrypoint_name = entrypoint_label(facts.entrypoint_name) + + functions: dict[str, ast.FunctionDef | ast.AsyncFunctionDef] = { + stmt.name: stmt + for stmt in tree.body + if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)) + } + classes: dict[str, ast.ClassDef] = { + stmt.name: stmt for stmt in tree.body if isinstance(stmt, ast.ClassDef) + } + + def function_returns_input_float(name: str | None) -> bool: + fn = functions.get(name or "") + return fn is not None and _input_float_return_from_body(fn.body, _function_input_names(fn)) + + def callable_value_returns_input_float(expr: ast.AST | None) -> bool: + if _lambda_returns_input_float(expr): + return True + if isinstance(expr, ast.Name): + return function_returns_input_float(expr.id) + if ( + isinstance(expr, ast.Call) + and isinstance(expr.func, ast.Name) + and expr.func.id in {"staticmethod", "classmethod"} + and expr.args + ): + return callable_value_returns_input_float(expr.args[0]) + return False + + def dynamic_type_returns_input_float(expr: ast.AST | None) -> bool: + if not ( + isinstance(expr, ast.Call) + and isinstance(expr.func, ast.Name) + and expr.func.id == "type" + and len(expr.args) >= 3 + and isinstance(expr.args[2], ast.Dict) + ): + return False + for key, value in zip(expr.args[2].keys, expr.args[2].values): + if _static_string(key) == "__call__": + return callable_value_returns_input_float(value) + return False + + # K = type("K", (), {"__call__": lambda self, x: x.float()}); custom_kernel = K() + dynamic_classes: set[str] = set() + for stmt in tree.body: + if not isinstance(stmt, ast.Assign): + continue + if not dynamic_type_returns_input_float(stmt.value): + continue + for target in stmt.targets: + if isinstance(target, ast.Name): + dynamic_classes.add(target.id) + if dynamic_classes: + dynamic_instances: set[str] = set() + for stmt in tree.body: + if not isinstance(stmt, ast.Assign): + continue + if ( + isinstance(stmt.value, ast.Call) + and isinstance(stmt.value.func, ast.Name) + and stmt.value.func.id in dynamic_classes + and not stmt.value.args + and not stmt.value.keywords + ): + for target in stmt.targets: + if isinstance(target, ast.Name): + dynamic_instances.add(target.id) + if is_entrypoint_name(target.id): + return [{ + "pattern": "INPUT_PASSTHROUGH_OUTPUT", + "severity": "critical", + "evidence": f"{entrypoint_name} is a dynamic type() callable returning input.float()", + }] + if any(isinstance(target, ast.Name) and is_entrypoint_name(target.id) for target in stmt.targets): + if isinstance(stmt.value, ast.Name) and stmt.value.id in dynamic_instances: + return [{ + "pattern": "INPUT_PASSTHROUGH_OUTPUT", + "severity": "critical", + "evidence": f"{entrypoint_name} aliases a dynamic type() callable returning input.float()", + }] + + for stmt in tree.body: + if not isinstance(stmt, ast.Assign): + continue + if not any(isinstance(target, ast.Name) and is_entrypoint_name(target.id) for target in stmt.targets): + continue + if ( + isinstance(stmt.value, ast.Call) + and not stmt.value.args + and not stmt.value.keywords + and dynamic_type_returns_input_float(stmt.value.func) + ): + return [{ + "pattern": "INPUT_PASSTHROUGH_OUTPUT", + "severity": "critical", + "evidence": f"{entrypoint_name} directly instantiates a dynamic type() callable returning input.float()", + }] + + # Base.__init_subclass__ installs Child.custom_kernel = lambda x: x.float() + subclass_export_bases: set[str] = set() + for cls in classes.values(): + for item in cls.body: + if not isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)) or item.name != "__init_subclass__": + continue + cls_arg = item.args.args[0].arg if item.args.args else "cls" + for stmt in _iter_non_nested_nodes(item): + if not isinstance(stmt, ast.Assign): + continue + for target in stmt.targets: + if ( + isinstance(target, ast.Attribute) + and target.attr in ENTRYPOINT_NAMES + and isinstance(target.value, ast.Name) + and target.value.id == cls_arg + and callable_value_returns_input_float(stmt.value) + ): + subclass_export_bases.add(cls.name) + if subclass_export_bases: + subclass_exports = { + cls.name + for cls in classes.values() + if any(isinstance(base, ast.Name) and base.id in subclass_export_bases for base in cls.bases) + } + for stmt in tree.body: + if not isinstance(stmt, ast.Assign): + continue + if not any(isinstance(target, ast.Name) and is_entrypoint_name(target.id) for target in stmt.targets): + continue + if ( + isinstance(stmt.value, ast.Attribute) + and stmt.value.attr in ENTRYPOINT_NAMES + and isinstance(stmt.value.value, ast.Name) + and stmt.value.value.id in subclass_exports + ): + return [{ + "pattern": "INPUT_PASSTHROUGH_OUTPUT", + "severity": "critical", + "evidence": f"{entrypoint_name} is installed by __init_subclass__ as input.float()", + }] + + # class D: def __get__(self, obj, cls): return obj.x.float() + # class W: out = D(); def __init__(self, x): self.x = x + # def custom_kernel(x): return W(x).out + descriptor_slots: dict[str, str] = {} + for cls in classes.values(): + get_method = _method_from_class(cls, ("__get__",)) + if get_method is None or len(get_method.args.args) < 2: + continue + obj_name = get_method.args.args[1].arg + for stmt in get_method.body: + if isinstance(stmt, ast.Return): + slot = _is_input_attr_float_call(stmt.value, obj_name) + if slot: + descriptor_slots[cls.name] = slot + break + + if descriptor_slots: + wrapper_descriptor_slots: dict[tuple[str, str], str] = {} + wrapper_init_inputs: dict[str, set[str]] = {} + for cls in classes.values(): + attr_slots: dict[str, str] = {} + for stmt in cls.body: + if not isinstance(stmt, ast.Assign): + continue + if ( + isinstance(stmt.value, ast.Call) + and isinstance(stmt.value.func, ast.Name) + and stmt.value.func.id in descriptor_slots + ): + for target in stmt.targets: + if isinstance(target, ast.Name): + attr_slots[target.id] = descriptor_slots[stmt.value.func.id] + if not attr_slots: + continue + init_method = _method_from_class(cls, ("__init__",)) + if init_method is None or len(init_method.args.args) < 2: + continue + self_name = init_method.args.args[0].arg + assigned_slots: set[str] = set() + for stmt in _iter_non_nested_nodes(init_method): + if not isinstance(stmt, ast.Assign): + continue + if not (_expr_names(stmt.value) & _function_input_names(init_method)): + continue + for target in stmt.targets: + if ( + isinstance(target, ast.Attribute) + and isinstance(target.value, ast.Name) + and target.value.id == self_name + ): + assigned_slots.add(target.attr) + wrapper_init_inputs[cls.name] = assigned_slots + for public_attr, source_slot in attr_slots.items(): + if source_slot in assigned_slots: + wrapper_descriptor_slots[(cls.name, public_attr)] = source_slot + + for fn in _entrypoint_function_nodes(facts): + input_names = _function_input_names(fn) + for stmt in _iter_non_nested_nodes(fn): + if not isinstance(stmt, ast.Return): + continue + value = stmt.value + if not ( + isinstance(value, ast.Attribute) + and isinstance(value.value, ast.Call) + and isinstance(value.value.func, ast.Name) + and value.value.args + and isinstance(value.value.args[0], ast.Name) + and value.value.args[0].id in input_names + and (value.value.func.id, value.attr) in wrapper_descriptor_slots + ): + continue + return [{ + "pattern": "INPUT_PASSTHROUGH_OUTPUT", + "severity": "critical", + "evidence": f"{entrypoint_name} returns descriptor-backed input.float()", + }] + + return [] + + +def detect_input_reduction_output(code: str | SubmissionFacts) -> list[dict]: + """Pattern: entrypoint returns torch.any/all(input) as fake output.""" + facts = ensure_submission_facts(code) + tree = facts.ast_tree + if tree is None: + return [] + torch_aliases, imported = _torch_alias_sets(facts) + entrypoint_name = entrypoint_label(facts.entrypoint_name) + + def reduction_input(expr: ast.AST | None, input_names: set[str]) -> bool: + if not isinstance(expr, ast.Call): + return False + if _torch_call_name(expr, torch_aliases, imported) not in {"any", "all"}: + return False + return ( + len(expr.args) == 1 + and isinstance(expr.args[0], ast.Name) + and expr.args[0].id in input_names + and not expr.keywords + ) + + for fn in _entrypoint_function_nodes(facts): + input_names = _function_input_names(fn) + aliases: set[str] = set() + for stmt in fn.body: + if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): + if reduction_input(stmt.value, input_names): + aliases.add(stmt.targets[0].id) + continue + if isinstance(stmt, ast.Return): + if reduction_input(stmt.value, input_names): + return [{ + "pattern": "INPUT_REDUCTION_OUTPUT", + "severity": "critical", + "evidence": f"{entrypoint_name} returns torch.any/all(input) as output", + }] + if isinstance(stmt.value, ast.Name) and stmt.value.id in aliases: + return [{ + "pattern": "INPUT_REDUCTION_OUTPUT", + "severity": "critical", + "evidence": f"{entrypoint_name} returns torch.any/all(input) alias as output", + }] + return [] + + +def detect_self_matmul_output(code: str | SubmissionFacts) -> list[dict]: + """Pattern: entrypoint returns an input multiplied by itself as fake output.""" + facts = ensure_submission_facts(code) + tree = facts.ast_tree + if tree is None: + return [] + torch_aliases, imported = _torch_alias_sets(facts) + entrypoint_name = entrypoint_label(facts.entrypoint_name) + + def is_self_product(expr: ast.AST | None, input_names: set[str], aliases: set[str]) -> bool: + names = input_names | aliases + if isinstance(expr, ast.BinOp) and isinstance(expr.op, ast.MatMult): + if isinstance(expr.left, ast.Name) and expr.left.id in names: + return _is_transpose_of(expr.right, expr.left.id) or ( + isinstance(expr.right, ast.Name) and expr.right.id == expr.left.id + ) + if not isinstance(expr, ast.Call): + return False + call_name = _torch_call_name(expr, torch_aliases, imported) + if call_name in {"mm", "matmul"} and len(expr.args) >= 2: + first, second = expr.args[:2] + if isinstance(first, ast.Name) and first.id in names: + return _is_transpose_of(second, first.id) or ( + isinstance(second, ast.Name) and second.id == first.id + ) + if call_name == "bmm" and len(expr.args) >= 2: + first, second = expr.args[:2] + if isinstance(first, ast.Name) and first.id in aliases: + return _is_transpose_of(second, first.id) + return False + + for fn in _entrypoint_function_nodes(facts): + input_names = _function_input_names(fn) + aliases: set[str] = set() + result_aliases: set[str] = set() + for stmt in _iter_non_nested_nodes(fn): + if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): + target = stmt.targets[0].id + if ( + isinstance(stmt.value, ast.Call) + and isinstance(stmt.value.func, ast.Attribute) + and stmt.value.func.attr == "unsqueeze" + and isinstance(stmt.value.func.value, ast.Name) + and stmt.value.func.value.id in input_names + ): + aliases.add(target) + if is_self_product(stmt.value, input_names, aliases): + result_aliases.add(target) + elif isinstance(stmt, ast.Return): + if is_self_product(stmt.value, input_names, aliases): + return [{ + "pattern": "SELF_MATMUL_OUTPUT", + "severity": "critical", + "evidence": f"{entrypoint_name} returns input self-matmul output", + }] + if isinstance(stmt.value, ast.Name) and stmt.value.id in result_aliases: + return [{ + "pattern": "SELF_MATMUL_OUTPUT", + "severity": "critical", + "evidence": f"{entrypoint_name} returns input self-matmul alias", + }] + return [] + + +def detect_torch_compile_cache(code: str | SubmissionFacts) -> list[dict]: + """Pattern 11: torch.compile for pipeline graph caching.""" + facts = ensure_submission_facts(code) + code = facts.python_only + matches = [] + if RE_TORCH_COMPILE.search(code): + matches.append({ + "pattern": "TORCH_COMPILE_CACHE", + "severity": "low", + "evidence": "torch.compile() used", + }) + return matches + + +def detect_cuda_graph_python(code: str | SubmissionFacts) -> list[dict]: + """Pattern 12: Python CUDAGraph caching with data_ptr keys + .replay().""" + facts = ensure_submission_facts(code) + python_only = facts.python_only + matches = [] + signals = [] + + has_graph = bool(RE_CUDA_GRAPH_PYTHON.search(python_only)) + has_replay = bool(RE_CUDA_GRAPH_REPLAY.search(python_only)) + has_cache_key = bool(RE_CUDA_GRAPH_CACHE_KEY.search(python_only)) + has_graph_cache = bool(RE_GRAPH_CACHE_NAME.search(python_only)) + + if has_graph: + signals.append("torch.cuda.CUDAGraph()") + if has_replay: + signals.append(".replay()") + if has_cache_key: + signals.append("data_ptr() as graph cache key") + if has_graph_cache: + signals.append("graph cache dict") + + # Need at least CUDAGraph + replay or CUDAGraph + cache key. + # CUDA graph replay is a legitimate optimization (reduces kernel launch + # overhead). Keep at medium — exploits using CUDA graphs would also + # have MODULE_MUTATION / GLOBALS_MUTATION which are critical/high. + if has_graph and (has_replay or has_cache_key): + matches.append({ "pattern": "CUDA_GRAPH_PYTHON", "severity": "medium", "evidence": "Python CUDAGraph caching: " + ", ".join(signals), @@ -1736,6 +3035,70 @@ def detect_runtime_package_install(code: str | SubmissionFacts) -> list[dict]: "severity": "critical", "evidence": "Runtime pip install inside submission code (sandbox violation)", }] + tree = facts.ast_tree + if tree is not None: + os_aliases = {"os"} + subprocess_aliases = {"subprocess"} + socket_aliases = {"socket"} + imported_calls: dict[str, str] = {} + for node in facts._imports: + for alias in node.names: + if alias.name == "os": + os_aliases.add(alias.asname or alias.name) + elif alias.name == "subprocess": + subprocess_aliases.add(alias.asname or alias.name) + elif alias.name == "socket": + socket_aliases.add(alias.asname or alias.name) + for node in facts._import_froms: + if node.module == "subprocess": + for alias in node.names: + if alias.name in {"run", "call", "check_call", "check_output", "Popen"}: + imported_calls[alias.asname or alias.name] = f"subprocess.{alias.name}" + elif node.module == "socket": + for alias in node.names: + if alias.name in {"socket", "create_connection"}: + imported_calls[alias.asname or alias.name] = f"socket.{alias.name}" + + def risky_process_call(node: ast.Call) -> bool: + if any(kw.arg == "shell" and isinstance(kw.value, ast.Constant) and kw.value.value is True for kw in node.keywords): + return True + static_parts = [] + for arg in node.args[:2]: + if (value := _static_string(arg)) is not None: + static_parts.append(value) + elif isinstance(arg, (ast.List, ast.Tuple)): + for elt in arg.elts: + if (value := _static_string(elt)) is not None: + static_parts.append(value) + command = " ".join(static_parts).lower() + return bool(re.search(r'\b(?:pip|install|curl|wget|nc|netcat|ssh|scp|bash|sh|powershell|python\s+-m\s+pip)\b', command)) + + risky_calls: set[str] = set() + for fn in _entrypoint_function_nodes(facts): + for node in _iter_non_nested_nodes(fn): + if not isinstance(node, ast.Call): + continue + if isinstance(node.func, ast.Name) and node.func.id in imported_calls: + target = imported_calls[node.func.id] + if target.startswith("socket.") or risky_process_call(node): + risky_calls.add(target) + elif isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name): + owner = node.func.value.id + attr = node.func.attr + if owner in os_aliases and attr in {"popen", "system"}: + if risky_process_call(node): + risky_calls.add(f"{owner}.{attr}") + elif owner in subprocess_aliases and attr in {"run", "call", "check_call", "check_output", "Popen"}: + if risky_process_call(node): + risky_calls.add(f"{owner}.{attr}") + elif owner in socket_aliases and attr in {"socket", "create_connection"}: + risky_calls.add(f"{owner}.{attr}") + if risky_calls: + return [{ + "pattern": "RUNTIME_PACKAGE_INSTALL", + "severity": "critical", + "evidence": "entrypoint process/network call: " + ", ".join(sorted(risky_calls)), + }] return [] @@ -2069,7 +3432,7 @@ def detect_code_replacement(code: str | SubmissionFacts) -> list[dict]: "generate_input", "check_implementation", "clear_l2_cache", "clear_l2_cache_large", "_clone_data", "ref_kernel", "calculate_stats", "run_single_benchmark", "Stats", -}) +}) | BASE_TRUSTED_FUNCTIONS TRUSTED_HARNESS_NAMES = BASE_TRUSTED_HARNESS_NAMES @@ -2089,8 +3452,67 @@ def detect_harness_runtime_patching(code: str | SubmissionFacts) -> list[dict]: trusted_targets: set[str] = set() globals_write = False setattr_write = False + builtins_aliases = {"__builtins__"} + builtins_write = False + host_aliases: set[str] = set() + namespace_aliases: set[str] = set() + string_aliases: dict[str, str] = {} for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + if alias.name == "builtins": + builtins_aliases.add(alias.asname or alias.name) + + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name): + static_value = _static_string(node.value) + if static_value is not None: + string_aliases[target.id] = static_value + if isinstance(node.value, ast.Name) and node.value.id in builtins_aliases: + builtins_aliases.add(target.id) + if ( + isinstance(node.value, ast.Call) + and isinstance(node.value.func, ast.Attribute) + and isinstance(node.value.func.value, ast.Name) + and node.value.func.value.id == "importlib" + and node.value.func.attr == "import_module" + and node.value.args + and _static_string(node.value.args[0]) == "__main__" + ): + host_aliases.add(target.id) + dynamic_discovery.add("importlib.import_module('__main__')") + elif ( + isinstance(node.value, ast.Call) + and isinstance(node.value.func, ast.Name) + and node.value.func.id == "__import__" + and node.value.args + and _static_string(node.value.args[0]) == "__main__" + ): + host_aliases.add(target.id) + dynamic_discovery.add("__import__('__main__')") + elif ( + isinstance(node.value, ast.Call) + and isinstance(node.value.func, ast.Name) + and node.value.func.id == "vars" + and node.value.args + ): + arg = node.value.args[0] + if isinstance(arg, ast.Name) and arg.id in host_aliases: + namespace_aliases.add(target.id) + elif ( + isinstance(arg, ast.Call) + and isinstance(arg.func, ast.Attribute) + and isinstance(arg.func.value, ast.Name) + and arg.func.value.id == "importlib" + and arg.func.attr == "import_module" + and arg.args + and _static_string(arg.args[0]) == "__main__" + ): + namespace_aliases.add(target.id) + dynamic_discovery.add("vars(importlib.import_module('__main__'))") + if (isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name) and node.value.id == "sys" @@ -2118,14 +3540,40 @@ def detect_harness_runtime_patching(code: str | SubmissionFacts) -> list[dict]: if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == "setattr": if len(node.args) >= 2: name_arg = node.args[1] - if isinstance(name_arg, ast.Constant) and name_arg.value in TRUSTED_HARNESS_NAMES: + name_value = _static_string(name_arg) + if name_value is None and isinstance(name_arg, ast.Name): + name_value = string_aliases.get(name_arg.id) + if isinstance(node.args[0], ast.Name) and node.args[0].id in builtins_aliases: + builtins_write = True + if name_value in TRUSTED_HARNESS_NAMES: setattr_write = True + trusted_targets.add(name_value) elif isinstance(name_arg, ast.Name): setattr_write = True + if isinstance(node.args[0], ast.Name) and node.args[0].id in host_aliases: + dynamic_discovery.add("main module alias") if isinstance(node, ast.Assign): for target in node.targets: if isinstance(target, ast.Subscript): + target_root = _ast_root_name(target.value) + key = _static_string(target.slice) + if target_root in builtins_aliases: + builtins_write = True + if target_root in namespace_aliases: + if key in TRUSTED_HARNESS_NAMES: + globals_write = True + trusted_targets.add(key) + dynamic_discovery.add("vars(main)") + if ( + isinstance(target.value, ast.Attribute) + and target.value.attr == "__dict__" + and _ast_root_name(target.value) in host_aliases + and key in TRUSTED_HARNESS_NAMES + ): + globals_write = True + trusted_targets.add(key) + dynamic_discovery.add("main module __dict__") if any( isinstance(sub, ast.Attribute) and sub.attr == "__globals__" for sub in ast.walk(target) @@ -2133,10 +3581,18 @@ def detect_harness_runtime_patching(code: str | SubmissionFacts) -> list[dict]: globals_write = True # Direct attribute assignment: mod.calculate_stats = fake_fn if isinstance(target, ast.Attribute): + if isinstance(target.value, ast.Name) and target.value.id in builtins_aliases: + builtins_write = True if target.attr in TRUSTED_HARNESS_NAMES: setattr_write = True trusted_targets.add(target.attr) + if builtins_write: + return [{ + "pattern": "HARNESS_RUNTIME_PATCHING", + "severity": "critical", + "evidence": "builtins namespace mutated at runtime", + }] if dynamic_discovery and (globals_write or setattr_write) and len(trusted_targets) >= 1: channels = ", ".join(sorted(dynamic_discovery)) targets = ", ".join(sorted(trusted_targets)) @@ -2478,155 +3934,1451 @@ def detect_shape_output_replay(code: str | SubmissionFacts) -> list[dict]: return [] -def detect_timed_input_replay(code: str | SubmissionFacts) -> list[dict]: - """Pattern: timed-input preprocess/quant replay keyed by shape or pointer.""" +def detect_timed_input_replay(code: str | SubmissionFacts) -> list[dict]: + """Pattern: timed-input preprocess/quant replay keyed by shape or pointer.""" + facts = ensure_submission_facts(code) + active = facts.python_active + indicators = [] + + if re.search(r'_q_fp8_cache', active) and re.search(r'q\.data_ptr\s*\(', active): + indicators.append("q_fp8 cached by q.data_ptr()") + if re.search(r"c\['_q_ptr'\]\s*!=\s*q_ptr", active) and re.search(r"c\['_q_fp8'\]", active): + indicators.append("stateful _q_ptr/_q_fp8 reuse") + if re.search(r'_q_scale', active) and re.search(r'q_scale\s*=\s*_q_scale', active): + indicators.append("reused q_scale buffer") + if re.search(r'_FP8_BUF_CACHE', active) and re.search(r'dynamic_per_tensor_quant', active): + indicators.append("shape-keyed fp8 quant buffer cache") + if re.search(r'_s\s*=\s*\[\s*-?1\s*,\s*None\s*\]', active) and re.search(r'return\s+_s\s*\[\s*1\s*\]', active): + indicators.append("saved output reused from mutable default state") + + # Only promote when the cached object is tied to live input quantization or + # transformed q tensors, not just generic workspace reuse. + if indicators and ( + re.search(r'q_fp8', active) + or re.search(r'dynamic_per_tensor_quant', active) + or "saved output reused from mutable default state" in indicators + ): + return [{ + "pattern": "TIMED_INPUT_REPLAY", + "severity": "critical", + "evidence": "timed-input quantization / transformed-Q replay: " + ", ".join(sorted(set(indicators))), + }] + return [] + + +def detect_pointer_replay(code: str | SubmissionFacts) -> list[dict]: + """Pattern: single-slot output replay keyed by input pointer equality.""" + facts = ensure_submission_facts(code) + tree = facts.ast_tree + if tree is None: + return [] + entrypoint_name = entrypoint_label(facts.entrypoint_name) + + def _is_data_ptr_call(expr: ast.AST | None) -> bool: + return ( + isinstance(expr, ast.Call) + and isinstance(expr.func, ast.Attribute) + and expr.func.attr == "data_ptr" + ) + + for node in ast.walk(tree): + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if not is_entrypoint_name(node.name): + continue + + pointer_aliases: set[str] = set() + saved_ptr = None + saved_out = None + + for child in ast.walk(node): + if isinstance(child, ast.Assign): + if len(child.targets) == 1 and isinstance(child.targets[0], ast.Name): + # Catches bare data_ptr() AND tuples/containers that contain data_ptr() + if _expr_has_data_ptr_fast(child.value, facts._nodes_with_data_ptr): + pointer_aliases.add(child.targets[0].id) + + def _pointer_pair(left: ast.AST, right: ast.AST) -> Optional[tuple[str, Optional[str]]]: + if isinstance(left, ast.Name) and left.id in pointer_aliases and isinstance(right, ast.Name): + return right.id, left.id + if isinstance(right, ast.Name) and right.id in pointer_aliases and isinstance(left, ast.Name): + return left.id, right.id + if _is_data_ptr_call(left) and isinstance(right, ast.Name): + return right.id, None + if _is_data_ptr_call(right) and isinstance(left, ast.Name): + return left.id, None + return None + + for idx, stmt in enumerate(node.body): + if not isinstance(stmt, ast.If): + continue + compare = stmt.test + if not (isinstance(compare, ast.Compare) and len(compare.ops) == 1 and isinstance(compare.ops[0], ast.NotEq)): + continue + pair = _pointer_pair(compare.left, compare.comparators[0]) + if pair is None: + continue + saved_ptr_name, pointer_alias_name = pair + stored_out_names: set[str] = set() + stores_ptr = False + for inner in ast.walk(stmt): + if not isinstance(inner, ast.Assign): + continue + for target in inner.targets: + if not isinstance(target, ast.Name): + continue + if target.id == saved_ptr_name: + if ( + (pointer_alias_name is not None and isinstance(inner.value, ast.Name) and inner.value.id == pointer_alias_name) + or _is_data_ptr_call(inner.value) + or _expr_has_data_ptr_fast(inner.value, facts._nodes_with_data_ptr) + ): + stores_ptr = True + elif _looks_output_value_name(target.id) or _looks_stateful_name(target.id): + stored_out_names.add(target.id) + if not stores_ptr or not stored_out_names: + continue + for follow in node.body[idx + 1:]: + if isinstance(follow, ast.Return) and follow.value is not None: + if _expr_names(follow.value) & stored_out_names: + return [{ + "pattern": "POINTER_REPLAY", + "severity": "critical", + "evidence": f"{entrypoint_name} refreshes saved output on pointer mismatch then replays it", + }] + break + + for child in ast.walk(node): + if not isinstance(child, ast.If): + continue + compare = child.test + if not (isinstance(compare, ast.Compare) and len(compare.ops) == 1): + continue + op = compare.ops[0] + left = compare.left + right = compare.comparators[0] + + if isinstance(op, ast.In): + # Pattern: if cache_key in cache_dict — dict-based pointer cache + if not (isinstance(left, ast.Name) and left.id in pointer_aliases): + continue + dict_name = _ast_root_name(right) + if dict_name is None: + continue + has_return = any( + isinstance(n, ast.Return) and n.value is not None + for n in ast.walk(child) + ) + if not has_return: + continue + dict_written = any( + isinstance(n, ast.Assign) and any( + _ast_root_name(t) == dict_name for t in n.targets + ) + for n in ast.walk(node) + ) or any( + isinstance(n, ast.Assign) and any( + _ast_root_name(t) == dict_name for t in n.targets + ) + for n in tree.body + if isinstance(n, ast.Assign) + ) + if dict_written: + return [{ + "pattern": "POINTER_REPLAY", + "severity": "critical", + "evidence": f"{entrypoint_name} returns saved output from dict keyed by input data_ptr", + }] + continue + + if not isinstance(op, ast.Eq): + continue + pair = _pointer_pair(left, right) + if pair is None: + continue + returns = [stmt for stmt in child.body if isinstance(stmt, ast.Return) and isinstance(stmt.value, ast.Name)] + if not returns: + continue + saved_ptr = pair[0] + saved_out = returns[0].value.id + break + + if not saved_ptr or not saved_out: + continue + + stores_ptr = False + stores_out = False + for child in ast.walk(node): + if not isinstance(child, ast.Assign): + continue + for target in child.targets: + if not isinstance(target, ast.Name): + continue + if target.id == saved_ptr: + if (isinstance(child.value, ast.Name) and child.value.id in pointer_aliases) or _is_data_ptr_call(child.value): + stores_ptr = True + elif target.id == saved_out: + stores_out = True + + if stores_ptr and stores_out: + return [{ + "pattern": "POINTER_REPLAY", + "severity": "critical", + "evidence": f"{entrypoint_name} returns saved output when input data_ptr matches previous pointer", + }] + + return [] + + +def detect_pointer_cache_hit_replay(code: str | SubmissionFacts) -> list[dict]: + """Detect pointer/id-derived cache hits that return cached outputs.""" + facts = ensure_submission_facts(code) + tree = facts.ast_tree + if tree is None: + return [] + entrypoint_name = entrypoint_label(facts.entrypoint_name) + function_defs: dict[str, ast.FunctionDef | ast.AsyncFunctionDef] = { + node.name: node + for node in ast.walk(tree) + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + } + + def cache_ok(name: Optional[str]) -> bool: + if not name: + return False + lowered = name.lower() + if any(token in lowered for token in ("compile", "compiled", "kernel", "module", "plan", "workspace", "preprocess", "scale", "meta")): + return False + return any(token in lowered for token in ("cache", "saved", "memo", "result", "out")) + + def identity_expr(expr: ast.AST | None, input_names: set[str], aliases: set[str]) -> bool: + if expr is None: + return False + if isinstance(expr, ast.Name) and expr.id in input_names: + return True + if _expr_has_data_ptr_fast(expr, facts._nodes_with_data_ptr): + return bool(_expr_names(expr) & (input_names | aliases)) + for node in ast.walk(expr): + if ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Name) + and node.func.id == "id" + and node.args + and (_expr_names(node.args[0]) & (input_names | aliases)) + ): + return True + return False + + def returned_names(body: list[ast.stmt]) -> set[str]: + names: set[str] = set() + for stmt in body: + if isinstance(stmt, ast.Return) and stmt.value is not None: + names.update(_expr_names(stmt.value)) + return names + + def thread_target_names(fn: ast.FunctionDef | ast.AsyncFunctionDef) -> set[str]: + targets: set[str] = set() + for node in _iter_non_nested_nodes(fn): + if not isinstance(node, ast.Call): + continue + is_thread_ctor = ( + (isinstance(node.func, ast.Name) and node.func.id == "Thread") + or (isinstance(node.func, ast.Attribute) and node.func.attr == "Thread") + ) + if not is_thread_ctor: + continue + for kw in node.keywords: + if kw.arg == "target" and isinstance(kw.value, ast.Name): + targets.add(kw.value.id) + return targets + + def helper_identity_cache_stores(helper: ast.FunctionDef | ast.AsyncFunctionDef) -> set[str]: + helper_inputs = _function_input_names(helper) + stores: set[str] = set() + for node in _iter_non_nested_nodes(helper): + if not isinstance(node, ast.Assign): + continue + for target in node.targets: + if not isinstance(target, ast.Subscript): + continue + cache_name = _ast_root_name(target.value) + if not cache_name: + continue + if not (_expr_names(target.slice) & helper_inputs and _expr_names(node.value) & helper_inputs): + continue + lowered = cache_name.lower() + if cache_ok(cache_name) or "mailbox" in lowered: + stores.add(cache_name) + return stores + + for fn in _entrypoint_function_nodes(facts): + input_names = _function_input_names(fn) + aliases = set(input_names) + key_aliases: set[str] = set() + cache_hits: dict[str, tuple[str, str]] = {} + cache_stores: set[tuple[str, str]] = set() + + for stmt in _iter_non_nested_nodes(fn): + if not isinstance(stmt, ast.Assign): + continue + value_names = _expr_names(stmt.value) + for target in stmt.targets: + if isinstance(target, ast.Name): + if value_names & aliases: + aliases.add(target.id) + if identity_expr(stmt.value, input_names, aliases): + key_aliases.add(target.id) + aliases.add(target.id) + + for stmt in _iter_non_nested_nodes(fn): + if isinstance(stmt, ast.Assign): + for target in stmt.targets: + if not isinstance(target, ast.Name): + continue + value = stmt.value + cache_name = None + key_expr = None + if isinstance(value, ast.Call) and isinstance(value.func, ast.Attribute): + if value.func.attr in {"get", "pop"} and value.args: + cache_name = _ast_root_name(value.func.value) + key_expr = value.args[0] + elif isinstance(value, ast.Subscript): + cache_name = _ast_root_name(value.value) + key_expr = value.slice + if cache_ok(cache_name) and ( + (isinstance(key_expr, ast.Name) and key_expr.id in key_aliases) + or identity_expr(key_expr, input_names, aliases) + ): + cache_hits[target.id] = (cache_name or "", target.id) + + for target in stmt.targets: + if not isinstance(target, ast.Subscript): + continue + cache_name = _ast_root_name(target.value) + if not cache_ok(cache_name): + continue + key_expr = target.slice + if not ( + (isinstance(key_expr, ast.Name) and key_expr.id in key_aliases) + or identity_expr(key_expr, input_names, aliases) + ): + continue + if _expr_names(stmt.value) & input_names or _expr_contains_input_derived_call(stmt.value, input_names): + cache_stores.add((cache_name or "", "identity")) + elif any(_looks_output_value_name(name) for name in _expr_names(stmt.value)): + cache_stores.add((cache_name or "", "identity")) + + for stmt in _iter_non_nested_nodes(fn): + if not isinstance(stmt, ast.If) or _body_has_calls(stmt.body): + continue + names = returned_names(stmt.body) + for hit_name, (cache_name, _) in cache_hits.items(): + if hit_name in names and (cache_name, "identity") in cache_stores: + return [{ + "pattern": "POINTER_REPLAY", + "severity": "critical", + "evidence": f"{entrypoint_name} returns cache hit keyed by input pointer/id", + }] + + for stmt in _iter_non_nested_nodes(fn): + if not isinstance(stmt, ast.For): + continue + iter_expr = stmt.iter + if ( + isinstance(iter_expr, ast.Call) + and isinstance(iter_expr.func, ast.Name) + and iter_expr.func.id == "list" + and iter_expr.args + ): + iter_expr = iter_expr.args[0] + if not ( + isinstance(iter_expr, ast.Call) + and isinstance(iter_expr.func, ast.Attribute) + and iter_expr.func.attr == "items" + and isinstance(stmt.target, ast.Tuple) + and len(stmt.target.elts) == 2 + and isinstance(stmt.target.elts[0], ast.Name) + and isinstance(stmt.target.elts[1], ast.Name) + ): + continue + cache_name = _ast_root_name(iter_expr.func.value) + if not cache_ok(cache_name): + continue + key_name = stmt.target.elts[0].id + value_name = stmt.target.elts[1].id + for child in stmt.body: + if not isinstance(child, ast.If) or _body_has_calls(child.body): + continue + if value_name not in returned_names(child.body): + continue + if not identity_expr(child.test, input_names, aliases | {key_name}): + continue + if (cache_name or "", "identity") in cache_stores: + return [{ + "pattern": "POINTER_REPLAY", + "severity": "critical", + "evidence": f"{entrypoint_name} returns cache item whose key matches input pointer/id", + }] + + thread_cache_names: set[str] = set() + for helper_name in thread_target_names(fn): + helper = function_defs.get(helper_name) + if helper is not None: + thread_cache_names.update(helper_identity_cache_stores(helper)) + if thread_cache_names: + for stmt in _iter_non_nested_nodes(fn): + if not isinstance(stmt, ast.Return) or stmt.value is None: + continue + value = stmt.value + cache_name = None + key_expr = None + if isinstance(value, ast.Call) and isinstance(value.func, ast.Attribute): + if value.func.attr in {"get", "pop"} and value.args: + cache_name = _ast_root_name(value.func.value) + key_expr = value.args[0] + elif isinstance(value, ast.Subscript): + cache_name = _ast_root_name(value.value) + key_expr = value.slice + if cache_name not in thread_cache_names: + continue + if ( + (isinstance(key_expr, ast.Name) and key_expr.id in key_aliases) + or identity_expr(key_expr, input_names, aliases) + ): + return [{ + "pattern": "POINTER_REPLAY", + "severity": "critical", + "evidence": f"{entrypoint_name} returns thread-populated cache keyed by input identity", + }] + + return [] + + +def detect_partial_bound_storage_replay(code: str | SubmissionFacts) -> list[dict]: + """Detect functools.partial entrypoints that replay from bound mutable state.""" + facts = ensure_submission_facts(code) + tree = facts.ast_tree + if tree is None: + return [] + + functions: dict[str, ast.FunctionDef | ast.AsyncFunctionDef] = { + node.name: node + for node in ast.walk(tree) + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + } + functools_aliases = {"functools"} + partial_names = {"partial"} + types_aliases = {"types"} + namespace_names = {"SimpleNamespace"} + for node in facts._imports: + for alias in node.names: + if alias.name == "functools": + functools_aliases.add(alias.asname or alias.name) + elif alias.name == "types": + types_aliases.add(alias.asname or alias.name) + for node in facts._import_froms: + if node.module == "functools": + for alias in node.names: + if alias.name == "partial": + partial_names.add(alias.asname or alias.name) + elif node.module == "types": + for alias in node.names: + if alias.name == "SimpleNamespace": + namespace_names.add(alias.asname or alias.name) + + def is_partial_call(expr: ast.AST | None) -> bool: + if not isinstance(expr, ast.Call): + return False + if isinstance(expr.func, ast.Name): + return expr.func.id in partial_names + return ( + isinstance(expr.func, ast.Attribute) + and expr.func.attr == "partial" + and isinstance(expr.func.value, ast.Name) + and expr.func.value.id in functools_aliases + ) + + def mutable_literal(expr: ast.AST | None) -> bool: + if isinstance(expr, (ast.List, ast.Dict, ast.Set)): + return True + if not isinstance(expr, ast.Call): + return False + if isinstance(expr.func, ast.Name) and expr.func.id in namespace_names: + return True + return ( + isinstance(expr.func, ast.Attribute) + and expr.func.attr == "SimpleNamespace" + and isinstance(expr.func.value, ast.Name) + and expr.func.value.id in types_aliases + ) + + mutable_roots: set[str] = set() + for stmt in tree.body: + if not isinstance(stmt, ast.Assign) or not mutable_literal(stmt.value): + continue + for target in stmt.targets: + if isinstance(target, ast.Name): + mutable_roots.add(target.id) + + def mutable_bound_arg(expr: ast.AST | None) -> bool: + return mutable_literal(expr) or ( + isinstance(expr, ast.Name) and expr.id in mutable_roots + ) + + def return_from_bound(expr: ast.AST | None, bound_params: set[str], aliases: set[str]) -> bool: + if isinstance(expr, ast.Name): + return expr.id in aliases + return isinstance(expr, (ast.Subscript, ast.Attribute)) and _ast_root_name(expr) in bound_params + + mutating_methods = { + "add", "append", "clear", "extend", "insert", "pop", "popitem", + "remove", "rotate", "setdefault", "update", + } + + def partial_bindings() -> list[tuple[ast.FunctionDef | ast.AsyncFunctionDef, set[str]]]: + bindings: list[tuple[ast.FunctionDef | ast.AsyncFunctionDef, set[str]]] = [] + for stmt in tree.body: + if not isinstance(stmt, ast.Assign): + continue + if not any(isinstance(target, ast.Name) and is_entrypoint_name(target.id) for target in stmt.targets): + continue + if not is_partial_call(stmt.value) or not stmt.value.args: + continue + target_name = stmt.value.args[0].id if isinstance(stmt.value.args[0], ast.Name) else None + target_fn = functions.get(target_name or "") + if target_fn is None: + continue + positional_params = list(target_fn.args.posonlyargs) + list(target_fn.args.args) + bound_params: set[str] = set() + for arg, value in zip(positional_params, stmt.value.args[1:]): + if mutable_bound_arg(value): + bound_params.add(arg.arg) + for keyword in stmt.value.keywords: + if keyword.arg and mutable_bound_arg(keyword.value): + bound_params.add(keyword.arg) + if bound_params: + bindings.append((target_fn, bound_params)) + return bindings + + def function_replays_bound_storage(fn: ast.FunctionDef | ast.AsyncFunctionDef, bound_params: set[str]) -> bool: + input_params = _function_input_names(fn) - bound_params + if not input_params: + return False + tainted = set(input_params) + stores_bound_output = False + has_fast_return = False + + for stmt in _iter_non_nested_nodes(fn): + if isinstance(stmt, ast.Assign): + if _expr_names(stmt.value) & tainted or _expr_contains_input_derived_call(stmt.value, input_params): + for target in stmt.targets: + tainted.update(_target_names(target)) + for target in stmt.targets: + if _ast_root_name(target) in bound_params and ( + _expr_names(stmt.value) & tainted + or _expr_contains_input_derived_call(stmt.value, input_params) + or any(_looks_output_value_name(name) for name in _expr_names(stmt.value)) + ): + stores_bound_output = True + elif isinstance(stmt, ast.AugAssign): + if _ast_root_name(stmt.target) in bound_params: + stores_bound_output = True + elif ( + isinstance(stmt, ast.Call) + and isinstance(stmt.func, ast.Attribute) + and stmt.func.attr in mutating_methods + and _ast_root_name(stmt.func.value) in bound_params + ): + if ( + any(_expr_names(arg) & tainted for arg in stmt.args) + or any(_expr_contains_input_derived_call(arg, input_params) for arg in stmt.args) + or stmt.func.attr in {"clear", "pop", "popitem", "remove"} + ): + stores_bound_output = True + elif isinstance(stmt, ast.If): + if _body_has_calls(stmt.body): + continue + if _expr_names(stmt.test) & input_params: + continue + if not (_expr_names(stmt.test) & bound_params): + continue + aliases: set[str] = set() + for inner in stmt.body: + if isinstance(inner, ast.Assign): + if isinstance(inner.value, (ast.Subscript, ast.Attribute)) and _ast_root_name(inner.value) in bound_params: + for target in inner.targets: + aliases.update(_target_names(target)) + elif isinstance(inner, ast.Return) and return_from_bound(inner.value, bound_params, aliases): + has_fast_return = True + + return stores_bound_output and has_fast_return + + for fn, bound_params in partial_bindings(): + if function_replays_bound_storage(fn, bound_params): + return [{ + "pattern": "LAST_CALL_REPLAY", + "severity": "critical", + "evidence": f"{entrypoint_label(facts.entrypoint_name)} partial replays mutated bound mutable state", + }] + + return [] + + +def detect_class_pointer_sentinel_replay(code: str | SubmissionFacts) -> list[dict]: + """Detect class __call__ entrypoints replaying state by input data_ptr.""" + facts = ensure_submission_facts(code) + tree = facts.ast_tree + if tree is None: + return [] + + none_roots = set(facts._none_inited) + if not none_roots: + return [] + + def data_ptr_owner(expr: ast.AST | None) -> Optional[str]: + if isinstance(expr, ast.Call) and isinstance(expr.func, ast.Attribute) and expr.func.attr == "data_ptr": + return _ast_root_name(expr.func.value) + return None + + def compare_state_none(expr: ast.AST | None) -> Optional[str]: + if not isinstance(expr, ast.Compare) or len(expr.ops) != 1 or len(expr.comparators) != 1: + return None + if not isinstance(expr.ops[0], (ast.Is, ast.IsNot)): + return None + left_root = _ast_root_name(expr.left) + right_root = _ast_root_name(expr.comparators[0]) + if _expr_is_none(expr.left) and right_root in none_roots: + return right_root + if _expr_is_none(expr.comparators[0]) and left_root in none_roots: + return left_root + return None + + def compare_input_state_ptr(expr: ast.AST | None, input_names: set[str]) -> Optional[str]: + if not isinstance(expr, ast.Compare) or len(expr.ops) != 1 or not isinstance(expr.ops[0], ast.Eq): + return None + if len(expr.comparators) != 1: + return None + left_owner = data_ptr_owner(expr.left) + right_owner = data_ptr_owner(expr.comparators[0]) + if left_owner in input_names and right_owner in none_roots: + return right_owner + if right_owner in input_names and left_owner in none_roots: + return left_owner + return None + + for cls in facts._class_defs: + if not is_entrypoint_name(cls.name): + continue + method = _method_from_class(cls, ("__call__",)) + if method is None: + continue + input_names = _function_input_names(method) + if not input_names: + continue + for stmt in _iter_non_nested_nodes(method): + if not isinstance(stmt, ast.If) or _body_has_calls(stmt.body): + continue + tests = stmt.test.values if isinstance(stmt.test, ast.BoolOp) else [stmt.test] + none_states = {state for test in tests if (state := compare_state_none(test))} + ptr_states = {state for test in tests if (state := compare_input_state_ptr(test, input_names))} + replay_states = none_states & ptr_states + if not replay_states: + continue + returns_replay_state = any( + isinstance(inner, ast.Return) + and inner.value is not None + and _ast_root_name(inner.value) in replay_states + for inner in stmt.body + ) + if returns_replay_state: + return [{ + "pattern": "POINTER_REPLAY", + "severity": "critical", + "evidence": f"class {cls.name}.__call__ replays state guarded by input data_ptr", + }] + + return [] + + +def detect_mutable_default_state_replay(code: str | SubmissionFacts) -> list[dict]: + """Detect persistent default-argument containers used as replay state.""" + facts = ensure_submission_facts(code) + tree = facts.ast_tree + if tree is None: + return [] + entrypoint_name = entrypoint_label(facts.entrypoint_name) + + mutating_methods = { + "add", "append", "clear", "extend", "insert", "pop", "popitem", + "remove", "rotate", "setdefault", "update", + } + + def is_mutable_default(expr: ast.AST | None) -> bool: + return isinstance(expr, (ast.List, ast.Dict, ast.Set)) or ( + isinstance(expr, ast.Tuple) and bool(expr.elts) + ) + + def return_from_default(expr: ast.AST | None, default_names: set[str], aliases: set[str]) -> bool: + if isinstance(expr, ast.Name) and expr.id in aliases: + return True + return isinstance(expr, (ast.Subscript, ast.Attribute, ast.Name)) and _ast_root_name(expr) in default_names + + for fn in _entrypoint_function_nodes(facts): + positional = list(fn.args.posonlyargs) + list(fn.args.args) + default_args = positional[-len(fn.args.defaults):] if fn.args.defaults else [] + default_names = { + arg.arg for arg, default in zip(default_args, fn.args.defaults) if is_mutable_default(default) + } + if not default_names: + continue + input_names = _function_input_names(fn) - default_names + if not input_names: + continue + tainted = set(input_names) + aliases: set[str] = set() + stores_default = False + has_fast_return = False + + for stmt in _iter_non_nested_nodes(fn): + if isinstance(stmt, ast.Assign): + value_tainted = bool(_expr_names(stmt.value) & tainted) or _expr_contains_input_derived_call(stmt.value, input_names) + if value_tainted: + for target in stmt.targets: + tainted.update(_target_names(target)) + for target in stmt.targets: + if isinstance(target, ast.Name) and _ast_root_name(stmt.value) in default_names: + aliases.add(target.id) + if _ast_root_name(target) in default_names and ( + value_tainted + or any(_looks_output_value_name(name) for name in _expr_names(stmt.value)) + ): + stores_default = True + elif ( + isinstance(stmt, ast.Call) + and isinstance(stmt.func, ast.Attribute) + and stmt.func.attr in mutating_methods + and _ast_root_name(stmt.func.value) in default_names + ): + if any((_expr_names(arg) & tainted) or _expr_contains_input_derived_call(arg, input_names) for arg in stmt.args): + stores_default = True + elif isinstance(stmt, ast.If): + if _body_has_calls(stmt.body): + continue + if any( + isinstance(inner, ast.Return) + and return_from_default(inner.value, default_names, aliases) + for inner in stmt.body + ): + has_fast_return = True + + if stores_default and has_fast_return: + return [{ + "pattern": "LAST_CALL_REPLAY", + "severity": "critical", + "evidence": f"{entrypoint_name} replays output from persistent mutable default state", + }] + + return [] + + +def detect_function_attribute_state_replay(code: str | SubmissionFacts) -> list[dict]: + """Detect replay state stored on the exported function object.""" + facts = ensure_submission_facts(code) + tree = facts.ast_tree + if tree is None: + return [] + entrypoint_name = entrypoint_label(facts.entrypoint_name) + + def function_attr(expr: ast.AST | None, fn_name: str) -> Optional[str]: + if ( + isinstance(expr, ast.Attribute) + and isinstance(expr.value, ast.Name) + and expr.value.id == fn_name + ): + return expr.attr + if ( + isinstance(expr, ast.Call) + and isinstance(expr.func, ast.Name) + and expr.func.id == "getattr" + and len(expr.args) >= 2 + and isinstance(expr.args[0], ast.Name) + and expr.args[0].id == fn_name + and isinstance(expr.args[1], ast.Constant) + and isinstance(expr.args[1].value, str) + ): + return expr.args[1].value + return None + + def uses_attr(expr: ast.AST | None, fn_name: str, aliases: set[str]) -> bool: + if expr is None: + return False + if isinstance(expr, ast.Name) and expr.id in aliases: + return True + return any(function_attr(node, fn_name) is not None for node in ast.walk(expr)) + + for fn in _entrypoint_function_nodes(facts): + if not is_entrypoint_name(fn.name): + continue + input_names = _function_input_names(fn) + if not input_names: + continue + aliases: set[str] = set() + stores_attr = False + has_fast_return = False + tainted = set(input_names) + + for stmt in _iter_non_nested_nodes(fn): + if isinstance(stmt, ast.Assign): + if function_attr(stmt.value, fn.name): + for target in stmt.targets: + aliases.update(_target_names(target)) + value_tainted = bool(_expr_names(stmt.value) & tainted) or _expr_contains_input_derived_call(stmt.value, input_names) + if value_tainted: + for target in stmt.targets: + tainted.update(_target_names(target)) + for target in stmt.targets: + if function_attr(target, fn.name) and value_tainted: + stores_attr = True + elif ( + isinstance(stmt, ast.Call) + and isinstance(stmt.func, ast.Name) + and stmt.func.id == "setattr" + and len(stmt.args) >= 3 + and isinstance(stmt.args[0], ast.Name) + and stmt.args[0].id == fn.name + and (bool(_expr_names(stmt.args[2]) & tainted) or _expr_contains_input_derived_call(stmt.args[2], input_names)) + ): + stores_attr = True + elif isinstance(stmt, ast.If): + if _body_has_calls(stmt.body): + continue + if any( + isinstance(inner, ast.Return) and uses_attr(inner.value, fn.name, aliases) + for inner in stmt.body + ): + has_fast_return = True + + if stores_attr and has_fast_return: + return [{ + "pattern": "LAST_CALL_REPLAY", + "severity": "critical", + "evidence": f"{entrypoint_name} replays output stored on its function object", + }] + + return [] + + +def detect_nonlocal_factory_state_replay(code: str | SubmissionFacts) -> list[dict]: + """Detect factory-returned closures that replay nonlocal state.""" + facts = ensure_submission_facts(code) + tree = facts.ast_tree + if tree is None: + return [] + + factories: dict[str, ast.FunctionDef | ast.AsyncFunctionDef] = { + stmt.name: stmt + for stmt in tree.body + if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)) + } + exported_factories: set[str] = set() + for stmt in tree.body: + if not isinstance(stmt, ast.Assign): + continue + if not any(isinstance(target, ast.Name) and is_entrypoint_name(target.id) for target in stmt.targets): + continue + if isinstance(stmt.value, ast.Call) and isinstance(stmt.value.func, ast.Name): + if stmt.value.func.id in factories: + exported_factories.add(stmt.value.func.id) + + def inner_replays_state(inner: ast.FunctionDef | ast.AsyncFunctionDef, state_names: set[str]) -> bool: + nonlocal_names = { + name + for stmt in inner.body + if isinstance(stmt, ast.Nonlocal) + for name in stmt.names + } & state_names + if not nonlocal_names: + return False + input_names = _function_input_names(inner) + tainted = set(input_names) + has_fast_return = False + stores_state = False + for stmt in inner.body: + if isinstance(stmt, ast.Nonlocal): + continue + if isinstance(stmt, ast.If): + if _body_has_calls(stmt.body): + continue + if any( + isinstance(ret, ast.Return) + and isinstance(ret.value, ast.Name) + and ret.value.id in nonlocal_names + for ret in stmt.body + ): + has_fast_return = True + continue + if isinstance(stmt, ast.Assign): + value_tainted = bool(_expr_names(stmt.value) & tainted) or _expr_contains_input_derived_call(stmt.value, input_names) + if value_tainted: + for target in stmt.targets: + tainted.update(_target_names(target)) + if any(isinstance(target, ast.Name) and target.id in nonlocal_names for target in stmt.targets) and value_tainted: + stores_state = True + continue + if isinstance(stmt, ast.Return): + continue + return has_fast_return and stores_state + + for factory_name in exported_factories: + factory = factories[factory_name] + local_none = { + target.id + for stmt in factory.body + if isinstance(stmt, ast.Assign) and _expr_is_none(stmt.value) + for target in stmt.targets + if isinstance(target, ast.Name) + } + if not local_none: + continue + nested = { + stmt.name: stmt + for stmt in factory.body + if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)) + } + returned = _factory_returned_function(factory) + candidates = [returned] if returned is not None else list(nested.values()) + for inner in candidates: + if inner is not None and inner_replays_state(inner, local_none): + return [{ + "pattern": "LAST_CALL_REPLAY", + "severity": "critical", + "evidence": f"{entrypoint_label(facts.entrypoint_name)} replays nonlocal factory state", + }] + + return [] + + +def detect_contextvar_state_replay(code: str | SubmissionFacts) -> list[dict]: + """Detect contextvars used as inter-call output replay storage.""" + facts = ensure_submission_facts(code) + tree = facts.ast_tree + if tree is None: + return [] + contextvars_aliases = {"contextvars"} + contextvar_names = {"ContextVar"} + for node in facts._imports: + for alias in node.names: + if alias.name == "contextvars": + contextvars_aliases.add(alias.asname or alias.name) + for node in facts._import_froms: + if node.module == "contextvars": + for alias in node.names: + if alias.name == "ContextVar": + contextvar_names.add(alias.asname or alias.name) + + context_slots: set[str] = set() + for stmt in tree.body: + if not isinstance(stmt, ast.Assign): + continue + is_contextvar = ( + isinstance(stmt.value, ast.Call) + and ( + (isinstance(stmt.value.func, ast.Name) and stmt.value.func.id in contextvar_names) + or ( + isinstance(stmt.value.func, ast.Attribute) + and stmt.value.func.attr == "ContextVar" + and isinstance(stmt.value.func.value, ast.Name) + and stmt.value.func.value.id in contextvars_aliases + ) + ) + ) + if is_contextvar: + for target in stmt.targets: + if isinstance(target, ast.Name): + context_slots.add(target.id) + if not context_slots: + return [] + + for fn in _entrypoint_function_nodes(facts): + input_names = _function_input_names(fn) + tainted = set(input_names) + get_aliases: set[str] = set() + has_fast_return = False + stores_output = False + for stmt in _iter_non_nested_nodes(fn): + if isinstance(stmt, ast.Assign): + if ( + isinstance(stmt.value, ast.Call) + and isinstance(stmt.value.func, ast.Attribute) + and stmt.value.func.attr == "get" + and _ast_root_name(stmt.value.func.value) in context_slots + ): + for target in stmt.targets: + get_aliases.update(_target_names(target)) + if bool(_expr_names(stmt.value) & tainted) or _expr_contains_input_derived_call(stmt.value, input_names): + for target in stmt.targets: + tainted.update(_target_names(target)) + elif isinstance(stmt, ast.If): + if _body_has_calls(stmt.body): + continue + if any( + isinstance(inner, ast.Return) + and isinstance(inner.value, ast.Name) + and inner.value.id in get_aliases + for inner in stmt.body + ): + has_fast_return = True + elif ( + isinstance(stmt, ast.Call) + and isinstance(stmt.func, ast.Attribute) + and stmt.func.attr == "set" + and _ast_root_name(stmt.func.value) in context_slots + and stmt.args + and (bool(_expr_names(stmt.args[0]) & tainted) or _expr_contains_input_derived_call(stmt.args[0], input_names)) + ): + stores_output = True + if has_fast_return and stores_output: + return [{ + "pattern": "LAST_CALL_REPLAY", + "severity": "critical", + "evidence": f"{entrypoint_label(facts.entrypoint_name)} replays output through ContextVar state", + }] + + return [] + + +def detect_alias_state_replay(code: str | SubmissionFacts) -> list[dict]: + """Detect local aliases to captured state used for output replay.""" + facts = ensure_submission_facts(code) + tree = facts.ast_tree + if tree is None: + return [] + entrypoint_name = entrypoint_label(facts.entrypoint_name) + + mutating_methods = {"append", "extend", "insert", "update", "setdefault", "add", "__setitem__"} + top_level_defs = { + stmt.name + for stmt in tree.body + if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)) + } + + for fn in _entrypoint_function_nodes(facts): + input_names = _function_input_names(fn) + if not input_names: + continue + local_names = { + name + for stmt in _iter_non_nested_nodes(fn) + if isinstance(stmt, ast.Assign) + for target in stmt.targets + for name in _target_names(target) + } + captured_roots = { + name + for name in _expr_names(fn) + if name not in input_names and name not in local_names and name not in top_level_defs + } + aliases: set[str] = set() + tainted = set(input_names) + stores_alias = False + has_fast_return = False + + for stmt in _iter_non_nested_nodes(fn): + if isinstance(stmt, ast.Assign): + if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): + target_name = stmt.targets[0].id + value = stmt.value + value_root = _ast_root_name(value) + if ( + ( + isinstance(value, (ast.Attribute, ast.Subscript)) + and value_root not in input_names + and value_root is not None + ) + or ( + isinstance(value, ast.Call) + and isinstance(value.func, ast.Name) + and value.func.id in {"vars", "list", "tuple"} + and value.args + and _ast_root_name(value.args[0]) not in input_names + ) + ): + aliases.add(target_name) + value_tainted = bool(_expr_names(stmt.value) & tainted) or _expr_contains_input_derived_call(stmt.value, input_names) + if value_tainted: + for target in stmt.targets: + tainted.update(_target_names(target)) + for target in stmt.targets: + if _ast_root_name(target) in aliases and value_tainted: + stores_alias = True + elif ( + isinstance(stmt, ast.Call) + and isinstance(stmt.func, ast.Attribute) + and stmt.func.attr in mutating_methods + and _ast_root_name(stmt.func.value) in aliases + and any((bool(_expr_names(arg) & tainted) or _expr_contains_input_derived_call(arg, input_names)) for arg in stmt.args) + ): + stores_alias = True + elif isinstance(stmt, ast.If): + if _body_has_calls(stmt.body): + continue + if any( + isinstance(inner, ast.Return) + and inner.value is not None + and _ast_root_name(inner.value) in aliases + for inner in stmt.body + ): + has_fast_return = True + + if stores_alias and has_fast_return: + return [{ + "pattern": "LAST_CALL_REPLAY", + "severity": "critical", + "evidence": f"{entrypoint_name} replays output through local alias to captured state", + }] + + return [] + + +def detect_partial_bound_method_replay(code: str | SubmissionFacts) -> list[dict]: + """Detect functools.partial around a bound method that replays self state.""" + facts = ensure_submission_facts(code) + tree = facts.ast_tree + if tree is None: + return [] + + classes = {stmt.name: stmt for stmt in tree.body if isinstance(stmt, ast.ClassDef)} + instances: dict[str, str] = {} + for stmt in tree.body: + if not isinstance(stmt, ast.Assign): + continue + if ( + isinstance(stmt.value, ast.Call) + and isinstance(stmt.value.func, ast.Name) + and stmt.value.func.id in classes + ): + for target in stmt.targets: + if isinstance(target, ast.Name): + instances[target.id] = stmt.value.func.id + + functools_aliases = {"functools"} + partial_names = {"partial"} + for node in facts._imports: + for alias in node.names: + if alias.name == "functools": + functools_aliases.add(alias.asname or alias.name) + for node in facts._import_froms: + if node.module == "functools": + for alias in node.names: + if alias.name == "partial": + partial_names.add(alias.asname or alias.name) + + def is_partial_call(expr: ast.AST | None) -> bool: + return isinstance(expr, ast.Call) and ( + (isinstance(expr.func, ast.Name) and expr.func.id in partial_names) + or ( + isinstance(expr.func, ast.Attribute) + and expr.func.attr == "partial" + and isinstance(expr.func.value, ast.Name) + and expr.func.value.id in functools_aliases + ) + ) + + def resolve_bound_method(expr: ast.AST | None) -> tuple[ast.ClassDef, str] | None: + if not isinstance(expr, ast.Attribute): + return None + owner = expr.value + class_name = None + if isinstance(owner, ast.Name): + class_name = instances.get(owner.id) + elif isinstance(owner, ast.Call) and isinstance(owner.func, ast.Name) and owner.func.id in classes: + class_name = owner.func.id + if class_name in classes: + return classes[class_name], expr.attr + return None + + def method_replays_self(method: ast.FunctionDef | ast.AsyncFunctionDef) -> bool: + if not method.args.args: + return False + self_name = method.args.args[0].arg + input_names = _function_input_names(method) + tainted = set(input_names) + stores_attr = False + has_fast_return = False + for stmt in _iter_non_nested_nodes(method): + if isinstance(stmt, ast.Assign): + value_tainted = bool(_expr_names(stmt.value) & tainted) or _expr_contains_input_derived_call(stmt.value, input_names) + if value_tainted: + for target in stmt.targets: + tainted.update(_target_names(target)) + for target in stmt.targets: + if ( + isinstance(target, ast.Attribute) + and isinstance(target.value, ast.Name) + and target.value.id == self_name + and value_tainted + ): + stores_attr = True + elif isinstance(stmt, ast.If): + if _body_has_calls(stmt.body): + continue + if any( + isinstance(inner, ast.Return) + and isinstance(inner.value, (ast.Attribute, ast.Subscript)) + and _ast_root_name(inner.value) == self_name + for inner in stmt.body + ): + has_fast_return = True + return stores_attr and has_fast_return + + for stmt in tree.body: + if not isinstance(stmt, ast.Assign): + continue + if not any(isinstance(target, ast.Name) and is_entrypoint_name(target.id) for target in stmt.targets): + continue + if not is_partial_call(stmt.value) or not stmt.value.args: + continue + bound = resolve_bound_method(stmt.value.args[0]) + if bound is None: + continue + cls, method_name = bound + method = _method_from_class(cls, (method_name,)) + if method is not None and method_replays_self(method): + return [{ + "pattern": "LAST_CALL_REPLAY", + "severity": "critical", + "evidence": f"{entrypoint_label(facts.entrypoint_name)} is partial-bound to a method replaying self state", + }] + + return [] + + +def detect_generator_send_replay(code: str | SubmissionFacts) -> list[dict]: + """Detect persistent generator send() state machines used for replay.""" facts = ensure_submission_facts(code) - active = facts.python_active - indicators = [] + tree = facts.ast_tree + if tree is None: + return [] - if re.search(r'_q_fp8_cache', active) and re.search(r'q\.data_ptr\s*\(', active): - indicators.append("q_fp8 cached by q.data_ptr()") - if re.search(r"c\['_q_ptr'\]\s*!=\s*q_ptr", active) and re.search(r"c\['_q_fp8'\]", active): - indicators.append("stateful _q_ptr/_q_fp8 reuse") - if re.search(r'_q_scale', active) and re.search(r'q_scale\s*=\s*_q_scale', active): - indicators.append("reused q_scale buffer") - if re.search(r'_FP8_BUF_CACHE', active) and re.search(r'dynamic_per_tensor_quant', active): - indicators.append("shape-keyed fp8 quant buffer cache") - if re.search(r'_s\s*=\s*\[\s*-?1\s*,\s*None\s*\]', active) and re.search(r'return\s+_s\s*\[\s*1\s*\]', active): - indicators.append("saved output reused from mutable default state") + functions: dict[str, ast.FunctionDef | ast.AsyncFunctionDef] = { + stmt.name: stmt + for stmt in tree.body + if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)) + } + + def generator_replays(fn: ast.FunctionDef | ast.AsyncFunctionDef) -> bool: + send_names = { + target.id + for stmt in ast.walk(fn) + if isinstance(stmt, ast.Assign) + and isinstance(stmt.value, (ast.Yield, ast.YieldFrom)) + for target in stmt.targets + if isinstance(target, ast.Name) + } + if not send_names: + return False + tainted = set(send_names) + state_names: set[str] = set() + for stmt in ast.walk(fn): + if isinstance(stmt, ast.Assign): + value_tainted = bool(_expr_names(stmt.value) & tainted) + if value_tainted: + for target in stmt.targets: + names = _target_names(target) + tainted.update(names) + state_names.update(names) + yielded_state = any( + isinstance(stmt, (ast.Yield, ast.YieldFrom)) + and bool(_expr_names(stmt.value) & state_names) + for stmt in ast.walk(fn) + ) + return bool(state_names and yielded_state) + + replay_generators = {name for name, fn in functions.items() if generator_replays(fn)} + if not replay_generators: + return [] + + generator_instances: set[str] = set() + for stmt in tree.body: + if not isinstance(stmt, ast.Assign): + continue + if isinstance(stmt.value, ast.Call) and isinstance(stmt.value.func, ast.Name) and stmt.value.func.id in replay_generators: + for target in stmt.targets: + if isinstance(target, ast.Name): + generator_instances.add(target.id) + + for fn in _entrypoint_function_nodes(facts): + input_names = _function_input_names(fn) + for stmt in _iter_non_nested_nodes(fn): + if ( + isinstance(stmt, ast.Return) + and isinstance(stmt.value, ast.Call) + and isinstance(stmt.value.func, ast.Attribute) + and stmt.value.func.attr == "send" + and _ast_root_name(stmt.value.func.value) in generator_instances + and stmt.value.args + and _expr_names(stmt.value.args[0]) & input_names + ): + return [{ + "pattern": "LAST_CALL_REPLAY", + "severity": "critical", + "evidence": f"{entrypoint_label(facts.entrypoint_name)} sends input into persistent replay generator", + }] - # Only promote when the cached object is tied to live input quantization or - # transformed q tensors, not just generic workspace reuse. - if indicators and ( - re.search(r'q_fp8', active) - or re.search(r'dynamic_per_tensor_quant', active) - or "saved output reused from mutable default state" in indicators - ): - return [{ - "pattern": "TIMED_INPUT_REPLAY", - "severity": "critical", - "evidence": "timed-input quantization / transformed-Q replay: " + ", ".join(sorted(set(indicators))), - }] return [] -def detect_pointer_replay(code: str | SubmissionFacts) -> list[dict]: - """Pattern: single-slot output replay keyed by input pointer equality.""" +def detect_class_self_pointer_replay(code: str | SubmissionFacts) -> list[dict]: + """Detect class callables replaying self.state under input data_ptr equality.""" facts = ensure_submission_facts(code) tree = facts.ast_tree if tree is None: return [] - entrypoint_name = entrypoint_label(facts.entrypoint_name) - - def _is_data_ptr_call(expr: ast.AST | None) -> bool: - return ( - isinstance(expr, ast.Call) - and isinstance(expr.func, ast.Attribute) - and expr.func.attr == "data_ptr" - ) - for node in ast.walk(tree): - if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + exported_classes: set[str] = {cls.name for cls in facts._class_defs if is_entrypoint_name(cls.name)} + for stmt in tree.body: + if not isinstance(stmt, ast.Assign): continue - if not is_entrypoint_name(node.name): + if not any(isinstance(target, ast.Name) and is_entrypoint_name(target.id) for target in stmt.targets): continue + if isinstance(stmt.value, ast.Call) and isinstance(stmt.value.func, ast.Name): + exported_classes.add(stmt.value.func.id) - pointer_aliases: set[str] = set() - saved_ptr = None - saved_out = None - - for child in ast.walk(node): - if isinstance(child, ast.Assign): - if len(child.targets) == 1 and isinstance(child.targets[0], ast.Name): - # Catches bare data_ptr() AND tuples/containers that contain data_ptr() - if _expr_has_data_ptr_fast(child.value, facts._nodes_with_data_ptr): - pointer_aliases.add(child.targets[0].id) + def data_ptr_owner(expr: ast.AST | None) -> Optional[str]: + if isinstance(expr, ast.Call) and isinstance(expr.func, ast.Attribute) and expr.func.attr == "data_ptr": + return _ast_root_name(expr.func.value) + return None - for child in ast.walk(node): - if not isinstance(child, ast.If): - continue - compare = child.test - if not (isinstance(compare, ast.Compare) and len(compare.ops) == 1): + for cls in facts._class_defs: + if cls.name not in exported_classes: + continue + method = _method_from_class(cls, ("__call__",)) + if method is None or not method.args.args: + continue + self_name = method.args.args[0].arg + input_names = _function_input_names(method) + if not input_names: + continue + has_pointer_return = False + for stmt in _iter_non_nested_nodes(method): + if not isinstance(stmt, ast.If) or _body_has_calls(stmt.body): continue - op = compare.ops[0] - left = compare.left - right = compare.comparators[0] - - if isinstance(op, ast.In): - # Pattern: if cache_key in cache_dict — dict-based pointer cache - if not (isinstance(left, ast.Name) and left.id in pointer_aliases): - continue - dict_name = _ast_root_name(right) - if dict_name is None: - continue - has_return = any( - isinstance(n, ast.Return) and n.value is not None - for n in ast.walk(child) - ) - if not has_return: + ptr_match = False + for cmp in ast.walk(stmt.test): + if not isinstance(cmp, ast.Compare) or len(cmp.ops) != 1 or not isinstance(cmp.ops[0], ast.Eq): continue - dict_written = any( - isinstance(n, ast.Assign) and any( - _ast_root_name(t) == dict_name for t in n.targets - ) - for n in ast.walk(node) - ) or any( - isinstance(n, ast.Assign) and any( - _ast_root_name(t) == dict_name for t in n.targets - ) - for n in tree.body - if isinstance(n, ast.Assign) - ) - if dict_written: - return [{ - "pattern": "POINTER_REPLAY", - "severity": "critical", - "evidence": f"{entrypoint_name} returns saved output from dict keyed by input data_ptr", - }] - continue - - if not isinstance(op, ast.Eq): - continue - pair = None - if isinstance(left, ast.Name) and left.id in pointer_aliases and isinstance(right, ast.Name): - pair = (right.id, left.id) - elif isinstance(right, ast.Name) and right.id in pointer_aliases and isinstance(left, ast.Name): - pair = (left.id, right.id) - elif _is_data_ptr_call(left) and isinstance(right, ast.Name): - pair = (right.id, None) - elif _is_data_ptr_call(right) and isinstance(left, ast.Name): - pair = (left.id, None) - if pair is None: + left_owner = data_ptr_owner(cmp.left) + right_owner = data_ptr_owner(cmp.comparators[0]) + if (left_owner == self_name and right_owner in input_names) or (right_owner == self_name and left_owner in input_names): + ptr_match = True + break + if not ptr_match: continue - returns = [stmt for stmt in child.body if isinstance(stmt, ast.Return) and isinstance(stmt.value, ast.Name)] - if not returns: + if any( + isinstance(inner, ast.Return) + and inner.value is not None + and _ast_root_name(inner.value) == self_name + for inner in stmt.body + ): + has_pointer_return = True + break + if not has_pointer_return: + continue + stores_self_state = False + for stmt in _iter_non_nested_nodes(method): + if not isinstance(stmt, ast.Assign): continue - saved_ptr = pair[0] - saved_out = returns[0].value.id - break + value_names = _expr_names(stmt.value) + for target in stmt.targets: + if _ast_root_name(target) == self_name and (value_names & input_names or any(_looks_output_value_name(name) for name in value_names)): + stores_self_state = True + break + if stores_self_state: + break + if stores_self_state: + return [{ + "pattern": "POINTER_REPLAY", + "severity": "critical", + "evidence": f"class {cls.name}.__call__ replays self state guarded by input data_ptr", + }] - if not saved_ptr or not saved_out: - continue + return [] - stores_ptr = False - stores_out = False - for child in ast.walk(node): - if not isinstance(child, ast.Assign): - continue - for target in child.targets: - if not isinstance(target, ast.Name): - continue - if target.id == saved_ptr: - if (isinstance(child.value, ast.Name) and child.value.id in pointer_aliases) or _is_data_ptr_call(child.value): - stores_ptr = True - elif target.id == saved_out: - stores_out = True - if stores_ptr and stores_out: +def detect_lambda_pointer_dispatch_replay(code: str | SubmissionFacts) -> list[dict]: + """Detect lambda entrypoints using operator/setitem pointer-keyed caches.""" + facts = ensure_submission_facts(code) + tree = facts.ast_tree + if tree is None: + return [] + + def identity_expr(expr: ast.AST | None, input_names: set[str]) -> bool: + if expr is None: + return False + if _expr_has_data_ptr_fast(expr, facts._nodes_with_data_ptr) and bool(_expr_names(expr) & input_names): + return True + return any( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Name) + and node.func.id == "id" + and node.args + and bool(_expr_names(node.args[0]) & input_names) + for node in ast.walk(expr) + ) + + for stmt in tree.body: + if not isinstance(stmt, ast.Assign): + continue + if not any(isinstance(target, ast.Name) and is_entrypoint_name(target.id) for target in stmt.targets): + continue + if not isinstance(stmt.value, ast.Lambda): + continue + lambda_node = stmt.value + input_names = _lambda_input_names(lambda_node) + if not input_names: + continue + stores: set[str] = set() + returns: set[str] = set() + for node in ast.walk(lambda_node.body): + if isinstance(node, ast.Call): + cache_name = None + key_expr = None + if ( + isinstance(node.func, ast.Attribute) + and node.func.attr == "__setitem__" + and len(node.args) >= 2 + ): + cache_name = _ast_root_name(node.func.value) + key_expr = node.args[0] + elif ( + isinstance(node.func, ast.Attribute) + and node.func.attr == "setitem" + and len(node.args) >= 3 + ): + cache_name = _ast_root_name(node.args[0]) + key_expr = node.args[1] + elif ( + isinstance(node.func, ast.Name) + and node.func.id == "setitem" + and len(node.args) >= 3 + ): + cache_name = _ast_root_name(node.args[0]) + key_expr = node.args[1] + if cache_name and identity_expr(key_expr, input_names): + stores.add(cache_name) + if isinstance(node.func, ast.Attribute) and node.func.attr == "get" and node.args: + cache_name = _ast_root_name(node.func.value) + if cache_name and identity_expr(node.args[0], input_names): + returns.add(cache_name) + elif isinstance(node, ast.Subscript): + cache_name = _ast_root_name(node.value) + if cache_name and identity_expr(node.slice, input_names): + returns.add(cache_name) + if stores & returns: return [{ "pattern": "POINTER_REPLAY", "severity": "critical", - "evidence": f"{entrypoint_name} returns saved output when input data_ptr matches previous pointer", + "evidence": f"{entrypoint_label(facts.entrypoint_name)} lambda caches outputs by input pointer/id", }] return [] @@ -2794,6 +5546,8 @@ def _has_pointer_cached_return(fn: ast.FunctionDef | ast.AsyncFunctionDef) -> bo continue if not _expr_has_pointer_identity(if_node.test, input_names, pointer_aliases): continue + if _body_has_calls(if_node.body): + continue for stmt in if_node.body: for nested in ast.walk(stmt): if not isinstance(nested, ast.Return) or nested.value is None: @@ -3154,6 +5908,219 @@ def _has_ver(expr: ast.AST | None) -> bool: return [] +def detect_first_call_state_replay(code: str | SubmissionFacts) -> list[dict]: + """Detect first-call/sentinel and captured-state output replay.""" + facts = ensure_submission_facts(code) + tree = facts.ast_tree + if tree is None: + return [] + entrypoint_name = entrypoint_label(facts.entrypoint_name) + + class_none_attrs: set[tuple[str, str]] = set() + for stmt in tree.body: + if not isinstance(stmt, ast.ClassDef): + continue + inherited = { + (stmt.name, attr) + for base in stmt.bases + if isinstance(base, ast.Name) + for cls, attr in class_none_attrs + if cls == base.id + } + class_none_attrs.update(inherited) + for child in stmt.body: + if not isinstance(child, ast.Assign) or not _expr_is_none(child.value): + continue + for target in child.targets: + if isinstance(target, ast.Name): + class_none_attrs.add((stmt.name, target.id)) + + def slot_key(expr: ast.AST | None) -> Optional[str]: + if expr is None: + return None + try: + return ast.unparse(expr) + except Exception: + return ast.dump(expr) + + def is_slot_expr(expr: ast.AST | None, captured_roots: set[str], none_slots: set[str]) -> bool: + root = _ast_root_name(expr) + if root in none_slots: + return isinstance(expr, ast.Name) + if isinstance(expr, ast.Attribute): + if isinstance(expr.value, ast.Name) and (expr.value.id, expr.attr) in class_none_attrs: + return True + return root in captured_roots + if isinstance(expr, ast.Subscript): + return root in captured_roots or root in none_slots + return False + + def input_derived(expr: ast.AST | None, tainted: set[str], input_names: set[str]) -> bool: + names = _expr_names(expr) + return bool(names & tainted) or _expr_contains_input_derived_call(expr, input_names) + + def local_target_names(target: ast.AST | None) -> set[str]: + if isinstance(target, ast.Name): + return {target.id} + if isinstance(target, (ast.Tuple, ast.List)): + names: set[str] = set() + for elt in target.elts: + names.update(local_target_names(elt)) + return names + if isinstance(target, ast.Starred): + return local_target_names(target.value) + return set() + + for fn in _entrypoint_function_nodes(facts): + input_names = _function_input_names(fn) + if not input_names: + continue + + global_names = { + name + for stmt in fn.body + if isinstance(stmt, ast.Global) + for name in stmt.names + } + local_names = { + name + for stmt in _iter_non_nested_nodes(fn) + if isinstance(stmt, (ast.Assign, ast.AnnAssign, ast.AugAssign, ast.For, ast.With)) + for name in ( + local_target_names(stmt.target) if isinstance(stmt, (ast.AnnAssign, ast.AugAssign, ast.For)) + else set().union(*(local_target_names(t) for t in stmt.targets)) if isinstance(stmt, ast.Assign) + else set() + ) + } - global_names + captured_roots = { + name + for name in _expr_names(fn) + if name not in input_names and name not in local_names + } + none_slots = set(facts._none_inited) | { + name for name in captured_roots if name in facts._none_inited + } + + tainted = set(input_names) + aliases: dict[str, str] = {} + stored_slots: set[str] = set() + returned_slots: set[str] = set() + + for stmt in _iter_non_nested_nodes(fn): + if isinstance(stmt, ast.Assign): + if input_derived(stmt.value, tainted, input_names): + for target in stmt.targets: + tainted.update(_target_names(target)) + key = slot_key(target) if is_slot_expr(target, captured_roots, none_slots) else None + if key: + stored_slots.add(key) + for target in stmt.targets: + if isinstance(target, ast.Name) and is_slot_expr(stmt.value, captured_roots, none_slots): + key = slot_key(stmt.value) + if key: + aliases[target.id] = key + for target in stmt.targets: + if is_slot_expr(target, captured_roots, none_slots) and input_derived(stmt.value, tainted, input_names): + key = slot_key(target) + if key: + stored_slots.add(key) + elif isinstance(stmt, ast.AnnAssign): + if input_derived(stmt.value, tainted, input_names): + tainted.update(_target_names(stmt.target)) + key = slot_key(stmt.target) if is_slot_expr(stmt.target, captured_roots, none_slots) else None + if key: + stored_slots.add(key) + elif ( + isinstance(stmt, ast.Call) + and isinstance(stmt.func, ast.Attribute) + and stmt.func.attr in {"append", "extend", "insert", "update", "setdefault", "add"} + and _ast_root_name(stmt.func.value) in captured_roots + and any(input_derived(arg, tainted, input_names) for arg in stmt.args) + ): + key = slot_key(stmt.func.value) + if key: + stored_slots.add(key) + elif isinstance(stmt, ast.Return) and stmt.value is not None: + key = aliases.get(stmt.value.id) if isinstance(stmt.value, ast.Name) else None + if key is None and is_slot_expr(stmt.value, captured_roots, none_slots): + key = slot_key(stmt.value) + if key: + returned_slots.add(key) + + for stmt in _iter_non_nested_nodes(fn): + if not isinstance(stmt, ast.If): + continue + + # First-call branch: if slot is None: slot = input_derived; return slot + body_return_slots: set[str] = set() + body_store_slots: set[str] = set() + for inner in stmt.body: + if isinstance(inner, ast.Assign): + if not input_derived(inner.value, tainted, input_names): + continue + for target in inner.targets: + if is_slot_expr(target, captured_roots, none_slots): + key = slot_key(target) + if key: + body_store_slots.add(key) + elif isinstance(inner, ast.Return) and inner.value is not None: + key = aliases.get(inner.value.id) if isinstance(inner.value, ast.Name) else None + if key is None and is_slot_expr(inner.value, captured_roots, none_slots): + key = slot_key(inner.value) + if key: + body_return_slots.add(key) + if body_store_slots & body_return_slots: + return [{ + "pattern": "LAST_CALL_REPLAY", + "severity": "critical", + "evidence": f"{entrypoint_name} lazily stores and returns input-derived output state", + }] + if body_store_slots: + for idx, top_stmt in enumerate(fn.body): + if top_stmt is not stmt: + continue + for follow in fn.body[idx + 1:]: + if isinstance(follow, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + continue + if any(isinstance(node, ast.Call) for node in ast.walk(follow)): + break + if isinstance(follow, ast.Return) and follow.value is not None: + key = aliases.get(follow.value.id) if isinstance(follow.value, ast.Name) else None + if key is None and is_slot_expr(follow.value, captured_roots, none_slots): + key = slot_key(follow.value) + if key in body_store_slots: + return [{ + "pattern": "LAST_CALL_REPLAY", + "severity": "critical", + "evidence": f"{entrypoint_name} returns first-call input-derived output state", + }] + break + break + + # Replay branch: if state/flag says ready: return slot + if _expr_names(stmt.test) & input_names: + continue + if _body_has_calls(stmt.body): + continue + fast_return_slots: set[str] = set() + for inner in stmt.body: + if not isinstance(inner, ast.Return) or inner.value is None: + continue + key = aliases.get(inner.value.id) if isinstance(inner.value, ast.Name) else None + if key is None and is_slot_expr(inner.value, captured_roots, none_slots): + key = slot_key(inner.value) + if key: + fast_return_slots.add(key) + if fast_return_slots & (stored_slots | returned_slots): + return [{ + "pattern": "LAST_CALL_REPLAY", + "severity": "critical", + "evidence": f"{entrypoint_name} returns captured state filled from prior input-derived output", + }] + + return [] + + RE_OBJECT_ID_DATA = re.compile(r"\bdata_id\s*=\s*id\s*\(\s*data\s*\)") RE_RESULT_BANK_SET = re.compile(r"_superbatch_results\s*\[\s*(?:did|data_id)\s*\]\s*=") RE_RESULT_BANK_RETURN = re.compile(r"return\s+_superbatch_results\s*\[\s*data_id\s*\]") @@ -3385,6 +6352,18 @@ def _scope_has_obfuscation(scope: ast.AST) -> bool: return False +def _decode_like_helpers(tree: ast.AST) -> set[str]: + helpers: set[str] = set() + for node in ast.walk(tree): + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + for child in ast.walk(node): + if _expr_has_decode_like_call(child): + helpers.add(node.name) + break + return helpers + + def detect_dynamic_execution(code: str | SubmissionFacts) -> list[dict]: """AST pattern 6: exec/eval/compile and dynamic import. @@ -3395,10 +6374,72 @@ def detect_dynamic_execution(code: str | SubmissionFacts) -> list[dict]: facts = ensure_submission_facts(code) tree = facts.ast_tree if tree is None: + active = facts.python_active + aliases = set(re.findall(r'^\s*(\w+)\s*=\s*(?:exec|eval)\s*$', active, re.MULTILINE)) + if aliases and any(re.search(rf'\b{re.escape(alias)}\s*\(', active) for alias in aliases): + if re.search(r'\bcustom_kernel\b|def\s+\w+\s*\(', active): + return [{ + "pattern": "OBFUSCATED_EXEC", + "severity": "critical", + "evidence": "aliased exec/eval in syntactically invalid source can hide submitted entrypoint code", + }] return [] matches = [] seen: set[str] = set() + exec_aliases = {"exec": "exec", "eval": "eval"} + decode_helpers = _decode_like_helpers(tree) + string_aliases: dict[str, str] = {} + for node in ast.walk(tree): + if not isinstance(node, ast.Assign): + continue + static_value = _static_string(node.value) + if static_value is not None: + for target in node.targets: + if isinstance(target, ast.Name): + string_aliases[target.id] = static_value + if isinstance(node.value, ast.Name) and node.value.id in {"exec", "eval"}: + for target in node.targets: + if isinstance(target, ast.Name): + exec_aliases[target.id] = node.value.id + + entrypoint_fns = _entrypoint_function_nodes(facts) + + def _call_inside_fn(fn: ast.FunctionDef | ast.AsyncFunctionDef, call: ast.Call) -> bool: + return any(inner is call for inner in _iter_non_nested_nodes(fn)) + + def _expr_uses_namespace(expr: ast.AST | None, namespace_names: set[str]) -> bool: + if expr is None: + return False + for node in ast.walk(expr): + if isinstance(node, ast.Subscript) and _ast_root_name(node.value) in namespace_names: + return True + if isinstance(node, ast.Call): + if isinstance(node.func, ast.Subscript) and _ast_root_name(node.func.value) in namespace_names: + return True + return False + + def _entrypoint_exec_builds_return(call: ast.Call) -> bool: + namespace_arg = call.args[1] if len(call.args) >= 2 else None + namespace_names = {namespace_arg.id} if isinstance(namespace_arg, ast.Name) else set() + payload = _static_string(call.args[0]) if call.args else None + if payload is None and call.args and isinstance(call.args[0], ast.Name): + payload = string_aliases.get(call.args[0].id) + payload_builds_code = bool(payload and re.search(r'\bdef\s+\w+\s*\(|\blambda\b|\bimport\s+\w+', payload)) + + for fn in entrypoint_fns: + if not _call_inside_fn(fn, call): + continue + local_namespaces = set(namespace_names) + for stmt in _iter_non_nested_nodes(fn): + if isinstance(stmt, ast.Assign) and isinstance(stmt.value, ast.Dict): + for target in stmt.targets: + local_namespaces.update(_target_names(target)) + if local_namespaces and payload_builds_code: + for stmt in _iter_non_nested_nodes(fn): + if isinstance(stmt, ast.Return) and _expr_uses_namespace(stmt.value, local_namespaces): + return True + return False # Pre-scan: check if any function scope contains both exec/eval AND # decode/decompress calls (catches split-variable patterns like: @@ -3409,10 +6450,10 @@ def detect_dynamic_execution(code: str | SubmissionFacts) -> list[dict]: continue has_exec = any( isinstance(n, ast.Call) and isinstance(n.func, ast.Name) - and n.func.id in ("exec", "eval") + and n.func.id in exec_aliases for n in ast.walk(scope) ) - if has_exec and _scope_has_obfuscation(scope): + if has_exec and (_scope_has_obfuscation(scope) or _expr_has_decode_like_call(scope, decode_helpers)): obfuscated_scopes.add(id(scope)) for node in ast.walk(tree): @@ -3425,8 +6466,12 @@ def detect_dynamic_execution(code: str | SubmissionFacts) -> list[dict]: elif isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name): func_name = f"{node.func.value.id}.{node.func.attr}" - if func_name in ("exec", "eval") and func_name not in seen: - is_obfuscated = _is_obfuscated_exec(node) + sink_name = exec_aliases.get(func_name, func_name) + if sink_name in ("exec", "eval") and func_name not in seen: + is_obfuscated = _is_obfuscated_exec(node) or ( + bool(node.args) and _expr_has_decode_like_call(node.args[0], decode_helpers) + ) + builds_entrypoint_return = _entrypoint_exec_builds_return(node) # Also check if this exec lives in a scope that has decode/decompress if not is_obfuscated: for scope in ast.walk(tree): @@ -3440,14 +6485,19 @@ def detect_dynamic_execution(code: str | SubmissionFacts) -> list[dict]: break if is_obfuscated: break - if is_obfuscated: + if is_obfuscated or builds_entrypoint_return: key = f"obfuscated_{func_name}" if key not in seen: seen.add(key) + evidence = ( + f"{func_name}() builds and returns generated callable inside {entrypoint_label(facts.entrypoint_name)}" + if builds_entrypoint_return and not is_obfuscated + else f"{func_name}() with encoded/decoded payload (hides exploit code)" + ) matches.append({ "pattern": "OBFUSCATED_EXEC", "severity": "critical", - "evidence": f"{func_name}() with encoded/compressed payload (hides exploit code)", + "evidence": evidence, }) else: seen.add(func_name) @@ -4215,6 +7265,18 @@ class RulePolicy: "TRIVIAL_PROBE", "low_signal", "telemetry", TELEMETRY_ONLY, (), (), "downgrade", ), + "INPUT_PASSTHROUGH_OUTPUT": RulePolicy( + "INPUT_PASSTHROUGH_OUTPUT", "fake_output", "hard", AUTO_FILTER, (), + (), "keep", + ), + "INPUT_REDUCTION_OUTPUT": RulePolicy( + "INPUT_REDUCTION_OUTPUT", "fake_output", "hard", AUTO_FILTER, (), + (), "keep", + ), + "SELF_MATMUL_OUTPUT": RulePolicy( + "SELF_MATMUL_OUTPUT", "fake_output", "hard", AUTO_FILTER, (), + (), "keep", + ), "TORCH_COMPILE_CACHE": RulePolicy( "TORCH_COMPILE_CACHE", "performance_heuristic", "telemetry", TELEMETRY_ONLY, (), (), "downgrade", @@ -4465,14 +7527,32 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool: detect_scaled_mm_ref, detect_decode_mm_ref, detect_result_caching, + detect_helper_output_replay_cache, detect_value_keyed_output_replay, detect_object_output_cache_replay, detect_last_call_replay, + detect_first_call_state_replay, detect_shape_output_replay, detect_timed_input_replay, + detect_pointer_cache_hit_replay, + detect_partial_bound_storage_replay, + detect_class_pointer_sentinel_replay, + detect_mutable_default_state_replay, + detect_function_attribute_state_replay, + detect_nonlocal_factory_state_replay, + detect_contextvar_state_replay, + detect_alias_state_replay, + detect_partial_bound_method_replay, + detect_generator_send_replay, + detect_class_self_pointer_replay, + detect_lambda_pointer_dispatch_replay, detect_cuda_graph_replay, detect_silent_fallback, detect_trivial_probe, + detect_input_passthrough_output, + detect_indirect_input_passthrough_output, + detect_input_reduction_output, + detect_self_matmul_output, detect_torch_compile_cache, detect_cuda_graph_python, detect_hardcoded_shapes, @@ -4508,14 +7588,32 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool: ("scaled_mm_ref", detect_scaled_mm_ref), ("decode_mm_ref", detect_decode_mm_ref), ("result_caching", detect_result_caching), + ("helper_output_replay_cache", detect_helper_output_replay_cache), ("value_keyed_output_replay", detect_value_keyed_output_replay), ("object_output_cache_replay", detect_object_output_cache_replay), ("last_call_replay", detect_last_call_replay), + ("first_call_state_replay", detect_first_call_state_replay), ("shape_output_replay", detect_shape_output_replay), ("timed_input_replay", detect_timed_input_replay), + ("pointer_cache_hit_replay", detect_pointer_cache_hit_replay), + ("partial_bound_storage_replay", detect_partial_bound_storage_replay), + ("class_pointer_sentinel_replay", detect_class_pointer_sentinel_replay), + ("mutable_default_state_replay", detect_mutable_default_state_replay), + ("function_attribute_state_replay", detect_function_attribute_state_replay), + ("nonlocal_factory_state_replay", detect_nonlocal_factory_state_replay), + ("contextvar_state_replay", detect_contextvar_state_replay), + ("alias_state_replay", detect_alias_state_replay), + ("partial_bound_method_replay", detect_partial_bound_method_replay), + ("generator_send_replay", detect_generator_send_replay), + ("class_self_pointer_replay", detect_class_self_pointer_replay), + ("lambda_pointer_dispatch_replay", detect_lambda_pointer_dispatch_replay), ("cuda_graph_replay", detect_cuda_graph_replay), ("silent_fallback", detect_silent_fallback), ("trivial_probe", detect_trivial_probe), + ("input_passthrough_output", detect_input_passthrough_output), + ("indirect_input_passthrough_output", detect_indirect_input_passthrough_output), + ("input_reduction_output", detect_input_reduction_output), + ("self_matmul_output", detect_self_matmul_output), ("torch_compile_cache", detect_torch_compile_cache), ("cuda_graph_python", detect_cuda_graph_python), ("hardcoded_shapes", detect_hardcoded_shapes), @@ -5419,7 +8517,8 @@ def _worker_parquet(args: tuple) -> dict: "EVALUATOR_EXPLOIT", "HARNESS_RUNTIME_PATCHING", "MODULE_MUTATION", "GLOBALS_MUTATION", "CODE_REPLACEMENT", "FRAME_WALK_ACCESS", "FRAME_WALK_MUTATION", "SYS_MODULES_ACCESS", "GLOBALS_ACCESS", "CODE_ACCESS", "TRUSTED_MODULE_IMPORT", - "OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE", + "OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", + "INPUT_PASSTHROUGH_OUTPUT", "INPUT_REDUCTION_OUTPUT", "SELF_MATMUL_OUTPUT", "PREPROCESS_CACHE", "WORKSPACE_CACHE", "RUNNER_PLAN_CACHE", "CUDA_GRAPH_PYTHON", "CUDA_GRAPH_REPLAY", "TIMER_MONKEYPATCH", "FAKE_BENCHMARK_EMIT", "STDIO_REDIRECT", "UNSYNC_MULTISTREAM", "CUDA_EVENT_DISABLE_TIMING", "SCALED_MM_REF", "DECODE_MM_REF", "SILENT_FALLBACK", "REFERENCE_PRECOMPUTE_REPLAY", "TORCH_COMPILE_CACHE", diff --git a/pyproject.toml b/pyproject.toml index 484b707..eca4a78 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "kernelguard" -version = "0.2.3" +version = "0.3.0" description = "Rule-based GPU kernel hack detector." readme = "README.md" requires-python = ">=3.11"