Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 108 additions & 1 deletion kernelguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2462,6 +2462,107 @@ def _has_ver(expr: ast.AST | None) -> bool:
return []


def detect_getitem_store_replay(code: str | SubmissionFacts) -> list[dict]:
"""Pattern: custom __getitem__ returns replay state from an object."""
facts = ensure_submission_facts(code)
tree = facts.ast_tree
if tree is None:
return []
entrypoint_name = entrypoint_label(facts.entrypoint_name)

getitem_classes = {
stmt.name
for stmt in tree.body
if isinstance(stmt, ast.ClassDef)
and any(
isinstance(item, ast.FunctionDef) and item.name == "__getitem__"
for item in stmt.body
)
}
instances = {
stmt.targets[0].id: stmt.value.func.id
for stmt in tree.body
if isinstance(stmt, ast.Assign)
and len(stmt.targets) == 1
and isinstance(stmt.targets[0], ast.Name)
and isinstance(stmt.value, ast.Call)
and isinstance(stmt.value.func, ast.Name)
and stmt.value.func.id in getitem_classes
}

def _param_names(node: ast.FunctionDef | ast.AsyncFunctionDef) -> set[str]:
params = {arg.arg for arg in node.args.args}
params.update(arg.arg for arg in node.args.posonlyargs)
params.update(arg.arg for arg in node.args.kwonlyargs)
return params

def _instance_attr(expr: ast.AST | None) -> str | None:
if isinstance(expr, ast.Subscript):
return _instance_attr(expr.value)
if not (
isinstance(expr, ast.Attribute)
and isinstance(expr.value, ast.Name)
and expr.value.id in instances
):
return None
return expr.value.id

def _instance_subscript(expr: ast.AST | None) -> str | None:
if not isinstance(expr, ast.Subscript):
return None
root = _ast_root_name(expr)
if root in instances:
return root
return None

for node in ast.walk(tree):
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
if not is_entrypoint_name(node.name):
continue

params = _param_names(node)
mutated = {
root
for expr in ast.walk(node)
if isinstance(expr, ast.Call)
and isinstance(expr.func, ast.Attribute)
and expr.func.attr in {"append", "extend", "insert", "update", "setdefault"}
and any(bool(_expr_names(arg) & params) for arg in expr.args)
for root in [_instance_attr(expr.func.value)]
if root is not None
}
if not mutated:
continue

for child in ast.walk(node):
if not isinstance(child, ast.If):
continue
if _body_has_calls(child.body):
continue
if _expr_names(child.test) & params:
continue
tested = _expr_names(child.test) & mutated
returned = {
root
for stmt in child.body
if isinstance(stmt, ast.Return)
for root in [_instance_subscript(stmt.value)]
if root is not None
}
if tested & returned:
return [{
"pattern": "GETITEM_STORE_REPLAY",
"severity": "critical",
"evidence": (
f"{entrypoint_name} returns custom __getitem__ state "
"populated from a prior input"
),
}]

return []


RE_OBJECT_ID_DATA = re.compile(r"\bdata_id\s*=\s*id\s*\(\s*data\s*\)")
RE_RESULT_BANK_SET = re.compile(r"_superbatch_results\s*\[\s*(?:did|data_id)\s*\]\s*=")
RE_RESULT_BANK_RETURN = re.compile(r"return\s+_superbatch_results\s*\[\s*data_id\s*\]")
Expand Down Expand Up @@ -3479,6 +3580,10 @@ class RulePolicy:
"LAST_CALL_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
AMD_MANUAL_LAST_CALL_REPLAY_FIXTURES, "keep",
),
"GETITEM_STORE_REPLAY": RulePolicy(
"GETITEM_STORE_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
(), "keep",
),
"SHAPE_OUTPUT_REPLAY": RulePolicy(
"SHAPE_OUTPUT_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
AMD_PACKAGE_SHAPE_REPLAY_FIXTURES, "keep",
Expand Down Expand Up @@ -3762,6 +3867,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
detect_decode_mm_ref,
detect_result_caching,
detect_last_call_replay,
detect_getitem_store_replay,
detect_shape_output_replay,
detect_timed_input_replay,
detect_cuda_graph_replay,
Expand Down Expand Up @@ -3800,6 +3906,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
("decode_mm_ref", detect_decode_mm_ref),
("result_caching", detect_result_caching),
("last_call_replay", detect_last_call_replay),
("getitem_store_replay", detect_getitem_store_replay),
("shape_output_replay", detect_shape_output_replay),
("timed_input_replay", detect_timed_input_replay),
("cuda_graph_replay", detect_cuda_graph_replay),
Expand Down Expand Up @@ -4696,7 +4803,7 @@ def _worker_parquet(args: tuple) -> dict:
"EVALUATOR_EXPLOIT", "HARNESS_RUNTIME_PATCHING", "MODULE_MUTATION", "GLOBALS_MUTATION", "CODE_REPLACEMENT",
"FRAME_WALK_ACCESS", "FRAME_WALK_MUTATION", "SYS_MODULES_ACCESS", "GLOBALS_ACCESS", "CODE_ACCESS",
"TRUSTED_MODULE_IMPORT",
"OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE",
"OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "GETITEM_STORE_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE",
"RUNNER_PLAN_CACHE", "CUDA_GRAPH_PYTHON", "CUDA_GRAPH_REPLAY",
"TIMER_MONKEYPATCH", "FAKE_BENCHMARK_EMIT", "STDIO_REDIRECT", "UNSYNC_MULTISTREAM", "CUDA_EVENT_DISABLE_TIMING",
"SCALED_MM_REF", "DECODE_MM_REF", "SILENT_FALLBACK", "REFERENCE_PRECOMPUTE_REPLAY", "TORCH_COMPILE_CACHE",
Expand Down