Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions kernelguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2459,6 +2459,86 @@ def _has_ver(expr: ast.AST | None) -> bool:
),
}]

# Fourth pass: replay through next(iter(state.values())) after the
# entrypoint mutates that captured mapping.
params = {
arg.arg
for args in (node.args.posonlyargs, node.args.args, node.args.kwonlyargs)
for arg in args
}
if node.args.vararg is not None:
params.add(node.args.vararg.arg)
if node.args.kwarg is not None:
params.add(node.args.kwarg.arg)

global_names = {
name
for stmt in node.body
if isinstance(stmt, ast.Global)
for name in stmt.names
}
local_names = {
target.id
for stmt in node.body
if isinstance(stmt, (ast.Assign, ast.AnnAssign, ast.AugAssign))
for target in (
stmt.targets if isinstance(stmt, ast.Assign)
else [stmt.target]
)
if isinstance(target, ast.Name) and target.id not in global_names
}
captured = _expr_names(node) - params - local_names

def _values_iter_root(expr: ast.AST | None) -> Optional[str]:
if not (isinstance(expr, ast.Call)
and isinstance(expr.func, ast.Name)
and expr.func.id == "next"
and len(expr.args) == 1):
return None
iter_call = expr.args[0]
if not (isinstance(iter_call, ast.Call)
and isinstance(iter_call.func, ast.Name)
and iter_call.func.id == "iter"
and len(iter_call.args) == 1):
return None
values_call = iter_call.args[0]
if not (isinstance(values_call, ast.Call)
and isinstance(values_call.func, ast.Attribute)
and values_call.func.attr == "values"):
return None
return _ast_root_name(values_call.func.value)

mutated_captures: set[str] = set()
for child in ast.walk(node):
if isinstance(child, ast.Assign):
for target in child.targets:
target_root = _ast_root_name(target)
if target_root in captured and not isinstance(target, ast.Name):
mutated_captures.add(target_root)

if mutated_captures:
for child in ast.walk(node):
if not isinstance(child, ast.If):
continue
if _expr_names(child.test) & params:
continue
returned_roots = {
_values_iter_root(stmt.value)
for stmt in child.body
if isinstance(stmt, ast.Return)
} - {None}
if not (returned_roots & mutated_captures):
continue

return [{
"pattern": "LAST_CALL_REPLAY",
"severity": "critical",
"evidence": (
f"{entrypoint_name} replays output through values() "
"iteration on mutated captured state"
),
}]

return []


Expand Down