diff --git a/kernelguard.py b/kernelguard.py
index f086847..0888667 100644
--- a/kernelguard.py
+++ b/kernelguard.py
@@ -2459,6 +2459,102 @@ def _has_ver(expr: ast.AST | None) -> bool:
                 ),
             }]
 
+    # Fourth pass: replay through next(iter(state.values())) after the
+    # entrypoint mutates that captured mapping.
+    #
+    # Names bound by the signature are collected first so they can be
+    # removed from the captured set and used to skip branches below.
+    params = {
+        arg.arg
+        for args in (node.args.posonlyargs, node.args.args, node.args.kwonlyargs)
+        for arg in args
+    }
+    if node.args.vararg is not None:
+        params.add(node.args.vararg.arg)
+    if node.args.kwarg is not None:
+        params.add(node.args.kwarg.arg)
+
+    # Only statements directly in node.body are inspected here; names
+    # declared global stay out of local_names even when assigned.
+    global_names = {
+        name
+        for stmt in node.body
+        if isinstance(stmt, ast.Global)
+        for name in stmt.names
+    }
+    local_names = {
+        target.id
+        for stmt in node.body
+        if isinstance(stmt, (ast.Assign, ast.AnnAssign, ast.AugAssign))
+        for target in (
+            stmt.targets if isinstance(stmt, ast.Assign)
+            else [stmt.target]
+        )
+        if isinstance(target, ast.Name) and target.id not in global_names
+    }
+    # Whatever _expr_names(node) reports beyond params and top-level locals.
+    captured = _expr_names(node) - params - local_names
+
+    def _values_iter_root(expr: ast.AST | None) -> str | None:
+        """Return _ast_root_name(X) for next(iter(X.values())[, default]).
+
+        Returns None when expr does not have that exact call shape.  The
+        two-argument next(it, default) form is accepted as well because it
+        yields the same first value whenever the mapping is non-empty.
+        """
+        if not (isinstance(expr, ast.Call)
+                and isinstance(expr.func, ast.Name)
+                and expr.func.id == "next"
+                and len(expr.args) in (1, 2)):
+            return None
+        iter_call = expr.args[0]
+        if not (isinstance(iter_call, ast.Call)
+                and isinstance(iter_call.func, ast.Name)
+                and iter_call.func.id == "iter"
+                and len(iter_call.args) == 1):
+            return None
+        values_call = iter_call.args[0]
+        if not (isinstance(values_call, ast.Call)
+                and isinstance(values_call.func, ast.Attribute)
+                and values_call.func.attr == "values"):
+            return None
+        return _ast_root_name(values_call.func.value)
+
+    # A captured name counts as mutated when an assignment target is rooted
+    # at it without being the bare name itself (plain rebinding is ignored).
+    mutated_captures: set[str] = set()
+    for child in ast.walk(node):
+        if isinstance(child, ast.Assign):
+            for target in child.targets:
+                target_root = _ast_root_name(target)
+                if target_root in captured and not isinstance(target, ast.Name):
+                    mutated_captures.add(target_root)
+
+    if mutated_captures:
+        for child in ast.walk(node):
+            if not isinstance(child, ast.If):
+                continue
+            # Branches whose test mentions a parameter are skipped.
+            if _expr_names(child.test) & params:
+                continue
+            # Bare returns yield None from the helper and are dropped below.
+            returned_roots = {
+                _values_iter_root(stmt.value)
+                for stmt in child.body
+                if isinstance(stmt, ast.Return)
+            } - {None}
+            if not (returned_roots & mutated_captures):
+                continue
+
+            return [{
+                "pattern": "LAST_CALL_REPLAY",
+                "severity": "critical",
+                "evidence": (
+                    f"{entrypoint_name} replays output through values() "
+                    "iteration on mutated captured state"
+                ),
+            }]
+
     return []

