From d348a31acc3c7ec4727f39c0fc2d36a2dfb40490 Mon Sep 17 00:00:00 2001 From: Prasanna Date: Sat, 2 May 2026 20:08:42 +0530 Subject: [PATCH] Detect factory closure replay --- kernelguard.py | 105 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) diff --git a/kernelguard.py b/kernelguard.py index f086847..9f99aa3 100644 --- a/kernelguard.py +++ b/kernelguard.py @@ -2197,6 +2197,111 @@ def _has_ver(expr: ast.AST | None) -> bool: for n in ast.walk(expr) ) + mutating_methods = { + "add", "append", "clear", "extend", "insert", "pop", "popitem", + "remove", "rotate", "setdefault", "update", + } + functions_by_name = { + child.name: child + for child in ast.walk(tree) + if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)) + } + factory_names: set[str] = set() + for stmt in getattr(tree, "body", []): + if not isinstance(stmt, ast.Assign): + continue + if not any( + isinstance(target, ast.Name) and is_entrypoint_name(target.id) + for target in stmt.targets + ): + continue + if isinstance(stmt.value, ast.Call) and isinstance(stmt.value.func, ast.Name): + factory_names.add(stmt.value.func.id) + + for factory_name in factory_names: + factory = functions_by_name.get(factory_name) + if factory is None: + continue + + factory_locals = { + target.id + for stmt in factory.body + if isinstance(stmt, ast.Assign) + for target in stmt.targets + if isinstance(target, ast.Name) + } + returned_names = { + stmt.value.id + for stmt in factory.body + if isinstance(stmt, ast.Return) + and isinstance(stmt.value, ast.Name) + } + inner_functions = [ + stmt for stmt in factory.body + if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)) + and stmt.name in returned_names + ] + + for inner in inner_functions: + mutated_subscripts: set[str] = set() + mutated_attrs: set[tuple[str, str]] = set() + for child in ast.walk(inner): + if isinstance(child, ast.Assign): + for target in child.targets: + target_root = _ast_root_name(target) + if target_root not in factory_locals: + continue + if isinstance(target, ast.Attribute): + mutated_attrs.add((target_root, target.attr)) + else: + mutated_subscripts.add(target_root) + elif isinstance(child, ast.AugAssign): + target_root = _ast_root_name(child.target) + if target_root in factory_locals: + mutated_subscripts.add(target_root) + elif (isinstance(child, ast.Call) + and isinstance(child.func, ast.Attribute) + and child.func.attr in mutating_methods): + target_root = _ast_root_name(child.func.value) + if target_root in factory_locals: + mutated_subscripts.add(target_root) + + input_params = { + arg.arg + for args in (inner.args.posonlyargs, inner.args.args, inner.args.kwonlyargs) + for arg in args + } + for child in ast.walk(inner): + if not isinstance(child, ast.If): + continue + if _body_has_calls(child.body): + continue + if _expr_names(child.test) & input_params: + continue + + returns_replay = False + for stmt in child.body: + if not isinstance(stmt, ast.Return) or stmt.value is None: + continue + if (isinstance(stmt.value, ast.Subscript) + and _ast_root_name(stmt.value) in mutated_subscripts): + returns_replay = True + if isinstance(stmt.value, ast.Attribute): + attr_key = (_ast_root_name(stmt.value), stmt.value.attr) + if attr_key in mutated_attrs: + returns_replay = True + if not returns_replay: + continue + + return [{ + "pattern": "LAST_CALL_REPLAY", + "severity": "critical", + "evidence": ( + f"{entrypoint_name} is returned from a factory and " + "replays output from mutated closure state" + ), + }] + for node in ast.walk(tree): if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): continue