Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions kernelguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2197,6 +2197,111 @@ def _has_ver(expr: ast.AST | None) -> bool:
for n in ast.walk(expr)
)

mutating_methods = {
"add", "append", "clear", "extend", "insert", "pop", "popitem",
"remove", "rotate", "setdefault", "update",
}
functions_by_name = {
child.name: child
for child in ast.walk(tree)
if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef))
}
factory_names: set[str] = set()
for stmt in getattr(tree, "body", []):
if not isinstance(stmt, ast.Assign):
continue
if not any(
isinstance(target, ast.Name) and is_entrypoint_name(target.id)
for target in stmt.targets
):
continue
if isinstance(stmt.value, ast.Call) and isinstance(stmt.value.func, ast.Name):
factory_names.add(stmt.value.func.id)

for factory_name in factory_names:
factory = functions_by_name.get(factory_name)
if factory is None:
continue

factory_locals = {
target.id
for stmt in factory.body
if isinstance(stmt, ast.Assign)
for target in stmt.targets
if isinstance(target, ast.Name)
}
returned_names = {
stmt.value.id
for stmt in factory.body
if isinstance(stmt, ast.Return)
and isinstance(stmt.value, ast.Name)
}
inner_functions = [
stmt for stmt in factory.body
if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef))
and stmt.name in returned_names
]

for inner in inner_functions:
mutated_subscripts: set[str] = set()
mutated_attrs: set[tuple[str, str]] = set()
for child in ast.walk(inner):
if isinstance(child, ast.Assign):
for target in child.targets:
target_root = _ast_root_name(target)
if target_root not in factory_locals:
continue
if isinstance(target, ast.Attribute):
mutated_attrs.add((target_root, target.attr))
else:
mutated_subscripts.add(target_root)
elif isinstance(child, ast.AugAssign):
target_root = _ast_root_name(child.target)
if target_root in factory_locals:
mutated_subscripts.add(target_root)
elif (isinstance(child, ast.Call)
and isinstance(child.func, ast.Attribute)
and child.func.attr in mutating_methods):
target_root = _ast_root_name(child.func.value)
if target_root in factory_locals:
mutated_subscripts.add(target_root)

input_params = {
arg.arg
for args in (inner.args.posonlyargs, inner.args.args, inner.args.kwonlyargs)
for arg in args
}
for child in ast.walk(inner):
if not isinstance(child, ast.If):
continue
if _body_has_calls(child.body):
continue
if _expr_names(child.test) & input_params:
continue

returns_replay = False
for stmt in child.body:
if not isinstance(stmt, ast.Return) or stmt.value is None:
continue
if (isinstance(stmt.value, ast.Subscript)
and _ast_root_name(stmt.value) in mutated_subscripts):
returns_replay = True
if isinstance(stmt.value, ast.Attribute):
attr_key = (_ast_root_name(stmt.value), stmt.value.attr)
if attr_key in mutated_attrs:
returns_replay = True
if not returns_replay:
continue

return [{
"pattern": "LAST_CALL_REPLAY",
"severity": "critical",
"evidence": (
f"{entrypoint_name} is returned from a factory and "
"replays output from mutated closure state"
),
}]

for node in ast.walk(tree):
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
Expand Down