diff --git a/CHANGELOG.md b/CHANGELOG.md index 934bb0b82..799ee9163 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Bug Fixes +- inspect large Keras HDF5 models through file-backed metadata traversal instead of rejecting them at the generic whole-file read cap; aggregate `content_hash` is omitted for these large file-backed HDF5 scans - paginate and bound large Hugging Face repository inventories before streaming so unfiltered scans preserve complete coverage - dispatch logical model directories through their owning scanners, preserving bounded complete SavedModel asset probes before supplemental child-file coverage - stream large Flax MessagePack tensor bodies by declared length without tripping the 512 MiB decode budget diff --git a/modelaudit/core.py b/modelaudit/core.py index 8f32cf21b..cf14c26e9 100644 --- a/modelaudit/core.py +++ b/modelaudit/core.py @@ -69,7 +69,7 @@ def shared_source_sensitive_caches() -> Iterator[None]: merge_inconclusive_flax_msgpack_outcome, merge_safetensors_overlap_analysis, ) -from modelaudit.scanners.base import FORMAT_VALIDATION_CONFIG_KEY, BaseScanner +from modelaudit.scanners.base import DEFAULT_MAX_FILE_READ_SIZE, FORMAT_VALIDATION_CONFIG_KEY, BaseScanner from modelaudit.scanners.mxnet_scanner import MXNET_PREFERRED_XGBOOST_SKIP_PATH_CONFIG_KEY from modelaudit.scanners.safetensors_scanner import MAX_HEADER_BYTES as SAFETENSORS_MAX_HEADER_BYTES from modelaudit.scanners.xgboost_scanner import ( @@ -2603,6 +2603,16 @@ def _should_defer_hash_for_safetensors_header_limit(file_path: str, config: dict return should_defer_safetensors_header_limit_hash(file_path, max_header_bytes) +def _should_defer_hash_for_file_backed_hdf5(file_path: str) -> bool: + """Avoid pre-dispatch whole-file hashing for HDF5 scans handled through h5py metadata traversal.""" + try: + file_size = os.path.getsize(file_path) + except OSError: + return False + + return file_size > DEFAULT_MAX_FILE_READ_SIZE and find_hdf5_signature_offset(file_path) is not None + + def _should_defer_hash_for_max_file_size(file_path: str, config: dict[str, Any]) -> bool: """Avoid hashing files that regular scanning will reject on max_file_size.""" try: @@ -2636,9 +2646,44 @@ def _should_defer_hash_for_max_total_size( return hashed_bytes > max_total_size +_FILE_BACKED_HDF5_UNHASHABLE_PREFIX = "unhashable_file_backed_hdf5_" + + +def _is_file_backed_hdf5_hash_placeholder(content_hash: str) -> bool: + return content_hash.startswith(_FILE_BACKED_HDF5_UNHASHABLE_PREFIX) + + +def _directory_owner_hash_is_unverifiable( + content_hash: str, + *, + allow_file_backed_hdf5: bool, +) -> bool: + if not content_hash.startswith("unhashable_"): + return False + return not (allow_file_backed_hdf5 and _is_file_backed_hdf5_hash_placeholder(content_hash)) + + +def _directory_owner_hash_changed( + before_hash: str | None, + after_hash: str | None, + *, + allow_file_backed_hdf5: bool, +) -> bool: + if before_hash == after_hash: + return False + return not ( + allow_file_backed_hdf5 + and isinstance(before_hash, str) + and isinstance(after_hash, str) + and _is_file_backed_hdf5_hash_placeholder(before_hash) + and _is_file_backed_hdf5_hash_placeholder(after_hash) + ) + + def _is_incomplete_aggregate_hash_placeholder(content_hash: str) -> bool: return content_hash.startswith( ( + _FILE_BACKED_HDF5_UNHASHABLE_PREFIX, "unhashable_max_file_size_", "unhashable_max_total_size_", "unhashable_timeout_", @@ -2680,6 +2725,9 @@ def _hash_files_by_path( if _should_defer_hash_for_safetensors_header_limit(routing_path, hash_config): content_hashes[file_path] = f"unhashable_bounded_safetensors_{id(file_path)}" continue + if _should_defer_hash_for_file_backed_hdf5(routing_path): + content_hashes[file_path] = f"unhashable_file_backed_hdf5_{id(file_path)}" + continue if should_defer_hash_for_pytorch_read_limit(routing_path, hash_config): content_hashes[file_path] = f"unhashable_pytorch_zip_read_limit_{id(file_path)}" continue @@ -2799,6 +2847,7 @@ def _directory_owner_scan_path( config: dict[str, Any], deadline: float, force_staged: bool = False, + require_bound: bool = False, source_paths_by_owner_path: dict[str, str] | None = None, ) -> Iterator[str]: """Yield a bound or hash-verified copied path for logical directory-owner scanning.""" @@ -2807,6 +2856,8 @@ def _directory_owner_scan_path( try: owner_scan_path = scan_path_stack.enter_context(_bound_directory_owner_scan_path(root_path)) except OSError: + if require_bound: + raise owner_scan_path = scan_path_stack.enter_context( _staged_directory_owner_scan_path( root_path, @@ -2818,6 +2869,8 @@ def _directory_owner_scan_path( ), ) else: + if require_bound: + raise OSError("Descriptor-backed directory owner path required for deferred source hashes") owner_scan_path = scan_path_stack.enter_context( _staged_directory_owner_scan_path( root_path, @@ -4239,11 +4292,41 @@ def owner_source_covered_by_child(source: str) -> bool: source: owner_hash_for_source(hashes_by_source, source) or f"unhashable_{id(source)}" for source in owner_sources } - if owner_block_reason is None and any( - hash_value.startswith("unhashable_") for hash_value in owner_hashes_before.values() - ): - owner_block_reason = "directory_owner_snapshot_incomplete" - owner_block_details = {"unhashable_source_count": 1} + file_backed_hdf5_owner_source_count = sum( + _is_file_backed_hdf5_hash_placeholder(hash_value) for hash_value in owner_hashes_before.values() + ) + allow_file_backed_hdf5_owner_hashes = False + if owner_block_reason is None: + unverifiable_owner_hash_count = sum( + _directory_owner_hash_is_unverifiable( + hash_value, + allow_file_backed_hdf5=True, + ) + for hash_value in owner_hashes_before.values() + ) + if unverifiable_owner_hash_count: + owner_block_reason = "directory_owner_snapshot_incomplete" + owner_block_details = {"unhashable_source_count": unverifiable_owner_hash_count} + elif file_backed_hdf5_owner_source_count: + if directory_owner_content_source_paths: + owner_block_reason = "directory_owner_snapshot_incomplete" + owner_block_details = { + "requires_descriptor_bound_owner": True, + "unhashable_source_count": file_backed_hdf5_owner_source_count, + } + else: + try: + with _bound_directory_owner_scan_path(owner_root_path): + pass + except OSError as error: + owner_block_reason = "directory_owner_snapshot_incomplete" + owner_block_details = { + "error_type": type(error).__name__, + "requires_descriptor_bound_owner": True, + "unhashable_source_count": file_backed_hdf5_owner_source_count, + } + else: + allow_file_backed_hdf5_owner_hashes = True owner_snapshot_before_dispatch = directory_owner_initial_snapshot if owner_block_reason is None: @@ -4312,6 +4395,7 @@ def owner_source_covered_by_child(source: str) -> bool: config=owner_hash_config, deadline=start_time + timeout, force_staged=bool(directory_owner_content_source_paths), + require_bound=allow_file_backed_hdf5_owner_hashes, source_paths_by_owner_path=directory_owner_content_source_paths, ) as directory_owner_scan_path: owner_scan_started = True @@ -4401,16 +4485,26 @@ def owner_source_covered_by_child(source: str) -> bool: changed_owner_sources = [ source for source in owner_sources - if owner_hashes_before.get(source) != owner_hashes_after.get(source) + if _directory_owner_hash_changed( + owner_hashes_before.get(source), + owner_hashes_after.get(source), + allow_file_backed_hdf5=allow_file_backed_hdf5_owner_hashes, + ) ] if post_snapshot_reason is None and changed_owner_sources: post_snapshot_reason = "directory_owner_source_changed" post_snapshot_details = {"changed_source_count": len(changed_owner_sources)} - if post_snapshot_reason is None and any( - hash_value.startswith("unhashable_") for hash_value in owner_hashes_after.values() - ): - post_snapshot_reason = "directory_owner_snapshot_incomplete" - post_snapshot_details = {"unhashable_source_count": 1} + if post_snapshot_reason is None: + unverifiable_owner_hash_count = sum( + _directory_owner_hash_is_unverifiable( + hash_value, + allow_file_backed_hdf5=allow_file_backed_hdf5_owner_hashes, + ) + for hash_value in owner_hashes_after.values() + ) + if unverifiable_owner_hash_count: + post_snapshot_reason = "directory_owner_snapshot_incomplete" + post_snapshot_details = {"unhashable_source_count": unverifiable_owner_hash_count} assert directory_owner_result is not None if post_snapshot_reason is not None: @@ -4938,11 +5032,17 @@ def owner_source_covered_by_child(source: str) -> bool: hashed_bytes=top_level_hashed_bytes, ) defer_hash_for_max_file_size = _should_defer_hash_for_max_file_size(target, config) + defer_hash_for_file_backed_hdf5 = _should_defer_hash_for_file_backed_hdf5(target) defer_hash_for_pytorch_read_limit = should_defer_hash_for_pytorch_read_limit( target, config, ) - if defer_hash_for_max_total_size or defer_hash_for_max_file_size or defer_hash_for_pytorch_read_limit: + if ( + defer_hash_for_max_total_size + or defer_hash_for_max_file_size + or defer_hash_for_file_backed_hdf5 + or defer_hash_for_pytorch_read_limit + ): aggregate_hash_complete = False if defer_hash_for_pytorch_read_limit: target_config = dict(target_config) @@ -4951,6 +5051,7 @@ def owner_source_covered_by_child(source: str) -> bool: not _should_defer_hash_for_safetensors_header_limit(target, config) and not defer_hash_for_max_file_size and not defer_hash_for_max_total_size + and not defer_hash_for_file_backed_hdf5 and not defer_hash_for_pytorch_read_limit ): try: @@ -6643,8 +6744,14 @@ def append_streamed_file_hash( hashed_bytes=top_level_hashed_bytes, ) defer_hash_for_max_file_size = _should_defer_hash_for_max_file_size(str(scan_path), scan_config) + defer_hash_for_file_backed_hdf5 = _should_defer_hash_for_file_backed_hdf5(str(scan_path)) defer_hash_for_pytorch_read_limit = should_defer_hash_for_pytorch_read_limit(str(scan_path), scan_config) - if defer_hash_for_max_total_size or defer_hash_for_max_file_size or defer_hash_for_pytorch_read_limit: + if ( + defer_hash_for_max_total_size + or defer_hash_for_max_file_size + or defer_hash_for_file_backed_hdf5 + or defer_hash_for_pytorch_read_limit + ): aggregate_hash_complete = False return None if _should_defer_hash_for_safetensors_header_limit(str(scan_path), scan_config): diff --git a/modelaudit/scanners/keras_h5_scanner.py b/modelaudit/scanners/keras_h5_scanner.py index 18e41fcb0..c132c5214 100644 --- a/modelaudit/scanners/keras_h5_scanner.py +++ b/modelaudit/scanners/keras_h5_scanner.py @@ -5,6 +5,8 @@ import math import os import re +import subprocess +import sys from collections.abc import Callable from contextlib import suppress from typing import Any, ClassVar @@ -32,7 +34,7 @@ redact_evidence_value, redact_untrusted_error_message, ) -from .base import INCONCLUSIVE_SCAN_OUTCOME, BaseScanner, IssueSeverity, ScanResult +from .base import DEFAULT_MAX_FILE_READ_SIZE, INCONCLUSIVE_SCAN_OUTCOME, BaseScanner, IssueSeverity, ScanResult from .keras_utils import ( check_custom_loss_config, check_custom_metric_config, @@ -50,6 +52,186 @@ except Exception: HAS_H5PY = False +_HDF5_ATTRIBUTE_WORKER_CODE = r""" +import base64 +import json +import os +import sys + +for _name in ("OPENBLAS_NUM_THREADS", "OMP_NUM_THREADS", "MKL_NUM_THREADS", "NUMEXPR_NUM_THREADS"): + os.environ.setdefault(_name, "1") + +_request = json.loads(sys.stdin.read()) +_memory_limit = int(_request.get("memory_limit_bytes") or 0) +if _memory_limit > 0: + try: + import resource + + resource.setrlimit(resource.RLIMIT_AS, (_memory_limit, _memory_limit)) + except Exception: + pass + +try: + import h5py + import numpy as np +except BaseException as exc: + print(json.dumps({"status": "error", "error_type": type(exc).__name__, "error": str(exc)[:200]})) + raise SystemExit(0) + + +def _emit(payload): + print(json.dumps(payload, separators=(",", ":"))) + + +def _jsonable(value): + if isinstance(value, bytes): + return {"__bytes__": base64.b64encode(value).decode("ascii")} + if isinstance(value, np.ndarray): + return [_jsonable(item) for item in value.reshape(-1).tolist()] + if isinstance(value, np.generic): + return _jsonable(value.item()) + if isinstance(value, (str, int, float, bool)) or value is None: + return value + if isinstance(value, (list, tuple)): + return [_jsonable(item) for item in value] + return str(value) + + +def _decode_names(value): + if hasattr(value, "tolist"): + value = value.tolist() + if isinstance(value, bytes): + return [value.decode("utf-8", errors="ignore")] + if isinstance(value, str): + return [value] + try: + items = list(value) + except TypeError: + items = [value] + names = [] + for item in items: + if isinstance(item, bytes): + names.append(item.decode("utf-8", errors="ignore")) + elif isinstance(item, str): + names.append(item) + return [name for name in names if name] + + +def _is_variable_string(attr_id): + attr_type = attr_id.get_type() + return attr_type.get_class() == h5py.h5t.STRING and attr_type.is_variable_str() + + +def _read_variable_string(attr_id, *, max_bytes, max_items, mode, max_text_chars): + space = attr_id.get_space() + point_count = int(space.get_simple_extent_npoints()) + if point_count == 0 and mode == "names": + return {"status": "value", "value": [], "truncated": False} + if point_count <= 0 or point_count > max_items: + return {"status": "skipped", "reason": "point_count_exceeded", "point_count": point_count} + + per_item_limit = max(max_bytes // point_count, 1) + if mode == "names": + per_item_limit = min(max_text_chars, per_item_limit) + per_item_bytes = per_item_limit + 1 + if point_count * per_item_bytes > max_bytes + max_items: + return {"status": "skipped", "reason": "storage_size_exceeded", "storage_size": point_count * per_item_bytes} + + dims = tuple(int(dimension) for dimension in space.get_simple_extent_dims()) + buffer_shape = dims if dims else () + memory_type = h5py.h5t.C_S1.copy() + memory_type.set_size(per_item_bytes) + buffer = np.zeros(buffer_shape, dtype=f"S{per_item_bytes}") + attr_id.read(buffer, memory_type) + raw_items = [buffer.item()] if buffer_shape == () else buffer.reshape(-1).tolist() + + values = [] + total_bytes = 0 + truncated = False + for raw_item in raw_items: + raw_bytes = bytes(raw_item) if isinstance(raw_item, bytes) else str(raw_item).encode("utf-8", "ignore") + total_bytes += len(raw_bytes) + if len(raw_bytes) >= per_item_bytes or total_bytes > max_bytes: + truncated = True + values.append(raw_bytes.decode("utf-8", errors="ignore")) + values = [value for value in values if value] + if len(values) > max_items: + values = values[:max_items] + truncated = True + if mode == "names": + return {"status": "value", "value": values, "truncated": truncated} + if truncated: + return {"status": "skipped", "reason": "storage_size_exceeded", "storage_size": max(total_bytes, max_bytes + 1)} + return {"status": "value", "value": values[0] if len(values) == 1 else values, "truncated": truncated} + + +def _read_one_attribute(h5_file, request): + object_path = request["object_path"] + attr_name = request["attr_name"] + mode = request["mode"] + max_bytes = int(request["max_bytes"]) + max_items = int(request["max_items"]) + max_text_chars = int(_request["max_text_chars"]) + + try: + obj = h5_file if object_path in {"", "/"} else h5_file.get(object_path, getlink=False) + if obj is None: + return {"status": "error", "reason": "object_missing"} + if attr_name not in obj.attrs: + return {"status": "missing"} + + attr_id = obj.attrs.get_id(attr_name) + space = attr_id.get_space() + point_count = int(space.get_simple_extent_npoints()) + if point_count > max_items: + return {"status": "skipped", "reason": "point_count_exceeded", "point_count": point_count} + + if _is_variable_string(attr_id): + return _read_variable_string( + attr_id, + max_bytes=max_bytes, + max_items=max_items, + mode=mode, + max_text_chars=max_text_chars, + ) + + storage_size = int(attr_id.get_storage_size()) + if storage_size > max_bytes: + return {"status": "skipped", "reason": "storage_size_exceeded", "storage_size": storage_size} + + value = obj.attrs[attr_name] + if mode == "names": + names = _decode_names(value) + truncated = len(names) > max_items + return {"status": "value", "value": names[:max_items], "truncated": truncated} + return {"status": "value", "value": _jsonable(value), "truncated": False} + except Exception as exc: + return {"status": "error", "error_type": type(exc).__name__, "error": str(exc)[:200]} + + +try: + file_path = _request["file_path"] + with h5py.File(file_path, "r") as h5_file: + requests = _request.get("attributes") + if requests is not None: + _emit({"status": "batch", "results": [_read_one_attribute(h5_file, request) for request in requests]}) + else: + _emit( + _read_one_attribute( + h5_file, + { + "object_path": _request["object_path"], + "attr_name": _request["attr_name"], + "mode": _request["mode"], + "max_bytes": _request["max_bytes"], + "max_items": _request["max_items"], + }, + ) + ) +except Exception as exc: + _emit({"status": "error", "error_type": type(exc).__name__, "error": str(exc)[:200]}) +""" + _KERAS_VERSION_SEPARATOR = r"[._-]?" _KERAS_LOCAL_VERSION_SUFFIX = r"\+[a-z0-9]+(?:[._-][a-z0-9]+)*" _KERAS_PRERELEASE_SUFFIX = ( @@ -180,7 +362,12 @@ class KerasH5Scanner(BaseScanner): name = "keras_h5" description = "Scans Keras H5 model files for suspicious layer configurations" supported_extensions: ClassVar[list[str]] = [".h5", ".hdf5", ".keras"] + # HDF5 inspection is file-backed through h5py and uses bounded metadata/link + # traversal, so total file size is not a whole-file read/memory proxy. + default_max_file_read_size: ClassVar[int] = 0 _JSON_ATTRIBUTE_PARSE_FAILED: ClassVar[object] = object() + _HDF5_ATTRIBUTE_READ_SKIPPED: ClassVar[object] = object() + _HDF5_ATTRIBUTE_MISSING: ClassVar[object] = object() _SAFE_K_BACKEND_LAMBDA_FUNCTIONS: ClassVar[frozenset[str]] = frozenset( {"abs", "elu", "hard_sigmoid", "l2_normalize", "relu", "sigmoid", "softmax", "softplus", "softsign", "tanh"} ) @@ -343,7 +530,15 @@ class KerasH5Scanner(BaseScanner): _MAX_HDF5_LINK_VISITS: ClassVar[int] = 4096 _MAX_HDF5_EXTERNAL_REFERENCE_REPORTS: ClassVar[int] = 20 _MAX_HDF5_EXTERNAL_STORAGE_SEGMENT_REPORTS: ClassVar[int] = 20 + _MAX_HDF5_VIRTUAL_SOURCE_REPORTS: ClassVar[int] = 20 + _MAX_HDF5_VIRTUAL_SOURCE_INSPECTIONS: ClassVar[int] = _MAX_HDF5_LINK_VISITS _MAX_HDF5_REFERENCE_TEXT_CHARS: ClassVar[int] = 4096 + _MAX_HDF5_JSON_ATTRIBUTE_BYTES: ClassVar[int] = 10 * 1024 * 1024 + _MAX_HDF5_NAME_ATTRIBUTE_BYTES: ClassVar[int] = 10 * 1024 * 1024 + _MAX_HDF5_SOFT_LINK_RESOLUTION_DEPTH: ClassVar[int] = 32 + _HDF5_ATTRIBUTE_WORKER_FILE_SIZE_THRESHOLD: ClassVar[int] = DEFAULT_MAX_FILE_READ_SIZE + _HDF5_ATTRIBUTE_WORKER_MEMORY_BYTES: ClassVar[int] = 256 * 1024 * 1024 + _HDF5_ATTRIBUTE_WORKER_TIMEOUT_SECONDS: ClassVar[float] = 15.0 _MAX_SERIALIZED_CONFIG_NODES: ClassVar[int] = 10_000 _MODEL_CONTAINER_CLASSES: ClassVar[frozenset[str]] = frozenset({"Model", "Functional", "Sequential"}) _WRAPPED_LAYER_SCAN_MODEL: ClassVar[dict[str, Any]] = {"class_name": "Sequential", "config": {"layers": []}} @@ -360,6 +555,10 @@ def __init__(self, config: dict[str, Any] | None = None): self.suspicious_config_props.extend(config["suspicious_config_properties"]) self._current_h5_keras_version: str | None = None self._checked_config_module_references: set[tuple[int, str, str]] = set() + self.max_hdf5_json_attribute_bytes = self._normalize_positive_int_config( + self.config.get("max_hdf5_json_attribute_bytes"), + self._MAX_HDF5_JSON_ATTRIBUTE_BYTES, + ) self._remaining_serialized_config_nodes = self._MAX_SERIALIZED_CONFIG_NODES self._serialized_config_limit_reported = False @@ -388,18 +587,12 @@ def scan(self, path: str) -> ScanResult: self._remaining_serialized_config_nodes = self._MAX_SERIALIZED_CONFIG_NODES self._serialized_config_limit_reported = False - # Check if path is valid - path_check_result = self._check_path(path) - if path_check_result: - return path_check_result - - size_check = self._check_size_limit(path) - if size_check: - return size_check + result = self._create_scan_result_after_preflight(path) + if not result.success: + return result # Check if h5py is installed if not HAS_H5PY: - result = self._create_result() reason = "keras_h5_h5py_unavailable" result.metadata["file_size"] = self.get_file_size(path) self._mark_inconclusive_scan_result(result, reason) @@ -420,12 +613,17 @@ def scan(self, path: str) -> ScanResult: self._finish_scan_result(result) return result - result = self._create_result() file_size = self.get_file_size(path) result.metadata["file_size"] = file_size - - # Add file integrity check for compliance - self.add_file_integrity_check(path, result) + whole_file_hash_skipped = file_size > DEFAULT_MAX_FILE_READ_SIZE + self._add_file_backed_hdf5_inspection_check( + path, + result, + file_size, + whole_file_hash_skipped=whole_file_hash_skipped, + ) + if not whole_file_hash_skipped: + self.add_file_integrity_check(path, result) try: # Store the file path for use in issue locations @@ -434,7 +632,18 @@ def scan(self, path: str) -> ScanResult: with h5py.File(path, "r") as f: result.bytes_scanned = file_size raw_keras_version: str | None = None - keras_version_attr = f.attrs.get("keras_version") + keras_version_attr = self._read_bounded_hdf5_attribute( + f.attrs, + "keras_version", + result, + max_bytes=self._MAX_HDF5_REFERENCE_TEXT_CHARS, + fail_closed=False, + ) + if ( + keras_version_attr is self._HDF5_ATTRIBUTE_MISSING + or keras_version_attr is self._HDF5_ATTRIBUTE_READ_SKIPPED + ): + keras_version_attr = None if isinstance(keras_version_attr, bytes): keras_version_attr = keras_version_attr.decode("utf-8", errors="ignore") if isinstance(keras_version_attr, str) and keras_version_attr.strip(): @@ -442,14 +651,23 @@ def scan(self, path: str) -> ScanResult: result.metadata["keras_version"] = redact_evidence_string(raw_keras_version) self._current_h5_keras_version = raw_keras_version + model_config_attr = self._read_bounded_hdf5_attribute( + f.attrs, + "model_config", + result, + max_bytes=self.max_hdf5_json_attribute_bytes, + fail_closed=True, + ) + has_model_config = model_config_attr is not self._HDF5_ATTRIBUTE_MISSING + # CVE-2026-1669 applies to weight loading too. Inspect full # Keras files and weights-like HDF5 layouts while leaving # unrelated generic HDF5 artifacts quiet. - if "model_config" in f.attrs or self._has_weights_like_hdf5_layout(f, path): + if has_model_config or self._has_weights_like_hdf5_layout(f, path): self._check_hdf5_external_references(f, result, path) # Check if this is a Keras model file - if "model_config" not in f.attrs: + if not has_model_config: # Check if this might be a TensorFlow SavedModel H5 file instead # Look for common TensorFlow H5 structure patterns is_tensorflow_h5 = any( @@ -481,7 +699,10 @@ def scan(self, path: str) -> ScanResult: return result # Parse model config - model_config = self._load_json_attribute(f.attrs["model_config"], result, "model_config") + if model_config_attr is self._HDF5_ATTRIBUTE_READ_SKIPPED: + model_config = self._JSON_ATTRIBUTE_PARSE_FAILED + else: + model_config = self._load_json_attribute(model_config_attr, result, "model_config") # Scan model configuration if model_config is self._JSON_ATTRIBUTE_PARSE_FAILED: @@ -500,9 +721,22 @@ def scan(self, path: str) -> ScanResult: ) # Check for custom objects in the model - if "custom_objects" in f.attrs: - custom_objects_attr = f.attrs["custom_objects"] - custom_objects_list = list(custom_objects_attr) if custom_objects_attr is not None else [] + custom_objects_attr = self._read_bounded_hdf5_attribute( + f.attrs, + "custom_objects", + result, + max_bytes=self.max_hdf5_json_attribute_bytes, + fail_closed=False, + ) + if custom_objects_attr is not self._HDF5_ATTRIBUTE_MISSING: + custom_objects_truncated = custom_objects_attr is self._HDF5_ATTRIBUTE_READ_SKIPPED + if custom_objects_attr is None or custom_objects_truncated: + custom_objects_list = [] + else: + try: + custom_objects_list = list(custom_objects_attr) + except TypeError: + custom_objects_list = [custom_objects_attr] result.add_check( name="Custom Objects Security Check", passed=False, @@ -510,12 +744,25 @@ def scan(self, path: str) -> ScanResult: severity=IssueSeverity.INFO, location=f"{self.current_file_path} (model_config)", rule_code="S302", - details={"custom_objects": redact_evidence_value(custom_objects_list, max_string_chars=200)}, + details={ + "custom_objects": redact_evidence_value(custom_objects_list, max_string_chars=200), + "custom_objects_truncated": custom_objects_truncated, + }, ) # Check for custom metrics and custom loss - if "training_config" in f.attrs: - training_config = self._load_json_attribute(f.attrs["training_config"], result, "training_config") + training_config_attr = self._read_bounded_hdf5_attribute( + f.attrs, + "training_config", + result, + max_bytes=self.max_hdf5_json_attribute_bytes, + fail_closed=True, + ) + if training_config_attr is not self._HDF5_ATTRIBUTE_MISSING: + if training_config_attr is self._HDF5_ATTRIBUTE_READ_SKIPPED: + training_config = self._JSON_ATTRIBUTE_PARSE_FAILED + else: + training_config = self._load_json_attribute(training_config_attr, result, "training_config") if training_config is not self._JSON_ATTRIBUTE_PARSE_FAILED: self._scan_training_config(training_config, result) @@ -589,8 +836,439 @@ def _finish_scan_result(cls, result: ScanResult) -> None: result.finish(success=not result.has_errors) + def _add_file_backed_hdf5_inspection_check( + self, + path: str, + result: ScanResult, + file_size: int, + *, + whole_file_hash_skipped: bool, + ) -> None: + """Record that Keras H5 security inspection does not require whole-file materialization.""" + result.metadata["file_backed_scan"] = True + result.add_check( + name="Keras H5 File-Backed Inspection", + passed=True, + message="Keras H5 inspection uses file-backed HDF5 metadata traversal", + location=path, + details={ + "file_size": file_size, + "file_backed": True, + "whole_file_materialized": False, + "whole_file_hash_skipped": whole_file_hash_skipped, + "max_hdf5_link_visits": self._MAX_HDF5_LINK_VISITS, + "max_hdf5_json_attribute_bytes": self.max_hdf5_json_attribute_bytes, + }, + ) + + @staticmethod + def _json_attribute_size(attr_value: Any) -> int | None: + """Best-effort byte size for a loaded JSON-like HDF5 attribute.""" + if isinstance(attr_value, str): + return len(attr_value.encode("utf-8", errors="ignore")) + if isinstance(attr_value, bytes | bytearray | memoryview): + return len(attr_value) + + nbytes = getattr(attr_value, "nbytes", None) + if isinstance(nbytes, int): + return nbytes + + return None + + @staticmethod + def _hdf5_attribute_is_variable_string(attr_id: Any) -> bool: + try: + attr_type = attr_id.get_type() + is_variable_str = getattr(attr_type, "is_variable_str", None) + return bool(is_variable_str()) if callable(is_variable_str) else False + except Exception: + return False + + @staticmethod + def _read_hdf5_variable_string_attribute( + attr_id: Any, + *, + max_bytes: int, + max_items: int, + max_item_chars: int | None = None, + ) -> tuple[str | list[str] | None, bool, int | None]: + import numpy as np + + try: + space = attr_id.get_space() + point_count = int(space.get_simple_extent_npoints()) + except Exception: + return None, True, None + if point_count <= 0 or point_count > max_items: + return None, True, None + + per_item_limit = max(max_bytes // point_count, 1) + if max_item_chars is not None: + per_item_limit = min(max_item_chars, per_item_limit) + per_item_bytes = per_item_limit + 1 + if point_count * per_item_bytes > max_bytes + max_items: + return None, True, point_count * per_item_bytes + + try: + dims = tuple(int(dimension) for dimension in space.get_simple_extent_dims()) + except Exception: + dims = () + buffer_shape = dims if dims else () + memory_type = h5py.h5t.C_S1.copy() + memory_type.set_size(per_item_bytes) + buffer = np.zeros(buffer_shape, dtype=f"S{per_item_bytes}") + try: + attr_id.read(buffer, memory_type) + except Exception: + return None, True, None + + raw_items = [buffer.item()] if buffer_shape == () else buffer.reshape(-1).tolist() + values: list[str] = [] + total_bytes = 0 + truncated = False + for raw_item in raw_items: + raw_bytes = bytes(raw_item) if isinstance(raw_item, bytes) else str(raw_item).encode("utf-8", "ignore") + total_bytes += len(raw_bytes) + if len(raw_bytes) >= per_item_bytes or total_bytes > max_bytes: + truncated = True + values.append(raw_bytes.decode("utf-8", errors="ignore")) + values = [value for value in values if value] + if len(values) > max_items: + values = values[:max_items] + truncated = True + if not values: + return "", truncated, total_bytes + return (values[0] if len(values) == 1 else values), truncated, total_bytes + + def _mark_hdf5_attribute_size_limit( + self, + result: ScanResult, + attr_name: str, + *, + attr_size: int, + max_bytes: int, + fail_closed: bool, + ) -> None: + reason = f"keras_h5_{attr_name}_size_limit_exceeded" + if fail_closed: + self._mark_inconclusive_scan_result(result, reason) + result.add_check( + name="Keras H5 Config Size Limit" + if attr_name in {"model_config", "training_config"} + else "Keras H5 Attribute Size Limit", + passed=False, + message=f"Keras H5 {attr_name} exceeds bounded parse budget", + severity=IssueSeverity.INFO, + location=self.current_file_path, + details={ + "attribute": attr_name, + "attribute_bytes": attr_size, + "max_attribute_bytes": max_bytes, + "analysis_incomplete": fail_closed, + "scan_outcome_reason": reason, + }, + rule_code="S902", + ) + + def _mark_hdf5_attribute_native_limit( + self, + result: ScanResult, + attr_name: str, + *, + fail_closed: bool, + reason: str | None = None, + ) -> None: + """Record that native HDF5 attribute handling exceeded the isolated worker budget.""" + scan_reason = f"keras_h5_{attr_name}_native_attribute_limit_exceeded" + if fail_closed: + self._mark_inconclusive_scan_result(result, scan_reason) + result.add_check( + name="Keras H5 Attribute Native Memory Limit", + passed=False, + message=f"Keras H5 {attr_name} metadata could not be inspected within the native memory budget", + severity=IssueSeverity.INFO, + location=self.current_file_path, + details={ + "attribute": attr_name, + "max_worker_memory_bytes": self._HDF5_ATTRIBUTE_WORKER_MEMORY_BYTES, + "analysis_incomplete": fail_closed, + "scan_outcome_reason": scan_reason, + "reason": reason or "worker_failed", + }, + rule_code="S902", + ) + + @staticmethod + def _decode_hdf5_attribute_worker_value(value: Any) -> Any: + if isinstance(value, dict) and set(value) == {"__bytes__"}: + with suppress(Exception): + import base64 + + return base64.b64decode(value["__bytes__"]) + return b"" + if isinstance(value, list): + return [KerasH5Scanner._decode_hdf5_attribute_worker_value(item) for item in value] + return value + + @classmethod + def _hdf5_attribute_context(cls, attrs: Any) -> tuple[str, str] | None: + """Return file and object paths for an AttributeManager without opening attrs.""" + try: + from h5py import h5f, h5i + + object_id = attrs._id + file_path = os.fsdecode(h5f.get_name(object_id)) + object_path = os.fsdecode(h5i.get_name(object_id)) or "/" + except Exception: + return None + if not file_path: + return None + return file_path, object_path + + @classmethod + def _should_use_hdf5_attribute_worker(cls, attrs: Any) -> bool: + context = cls._hdf5_attribute_context(attrs) + if context is None: + return False + file_path, _object_path = context + with suppress(OSError): + return os.path.getsize(file_path) > cls._HDF5_ATTRIBUTE_WORKER_FILE_SIZE_THRESHOLD + return False + + @classmethod + def _read_hdf5_attribute_in_worker( + cls, + attrs: Any, + attr_name: str, + *, + mode: str, + max_bytes: int, + max_items: int, + ) -> dict[str, Any]: + return cls._read_hdf5_attributes_in_worker( + [ + { + "attrs": attrs, + "attr_name": attr_name, + "mode": mode, + "max_bytes": max_bytes, + "max_items": max_items, + } + ] + )[0] + + @classmethod + def _read_hdf5_attributes_in_worker(cls, requests: list[dict[str, Any]]) -> list[dict[str, Any]]: + if not requests: + return [] + + error_results = [{"status": "error", "reason": "missing_context"} for _request in requests] + contexts = [cls._hdf5_attribute_context(request["attrs"]) for request in requests] + if any(context is None for context in contexts): + return error_results + + file_paths = {context[0] for context in contexts if context is not None} + if len(file_paths) != 1: + return [{"status": "error", "reason": "mixed_file_batch"} for _request in requests] + + file_path = file_paths.pop() + attribute_requests = [ + { + "object_path": context[1] if context is not None else "/", + "attr_name": request["attr_name"], + "mode": request["mode"], + "max_bytes": request["max_bytes"], + "max_items": request["max_items"], + } + for request, context in zip(requests, contexts, strict=True) + ] + worker_result = cls._run_hdf5_attribute_worker( + { + "file_path": file_path, + "attributes": attribute_requests, + "max_text_chars": cls._MAX_HDF5_REFERENCE_TEXT_CHARS, + "memory_limit_bytes": cls._HDF5_ATTRIBUTE_WORKER_MEMORY_BYTES, + } + ) + if worker_result.get("status") == "batch" and isinstance(worker_result.get("results"), list): + results = worker_result["results"] + if len(results) == len(requests): + return results + return [worker_result] * len(requests) + + @classmethod + def _run_hdf5_attribute_worker(cls, request: dict[str, Any]) -> dict[str, Any]: + env = os.environ.copy() + for env_name in ("OPENBLAS_NUM_THREADS", "OMP_NUM_THREADS", "MKL_NUM_THREADS", "NUMEXPR_NUM_THREADS"): + env.setdefault(env_name, "1") + try: + completed = subprocess.run( + [sys.executable, "-c", _HDF5_ATTRIBUTE_WORKER_CODE], + input=json.dumps(request), + text=True, + capture_output=True, + timeout=cls._HDF5_ATTRIBUTE_WORKER_TIMEOUT_SECONDS, + env=env, + check=False, + ) + except subprocess.TimeoutExpired: + return {"status": "error", "reason": "worker_timeout"} + except Exception as exc: + return {"status": "error", "reason": type(exc).__name__} + + stdout = completed.stdout.strip() + if completed.returncode != 0 or not stdout: + return { + "status": "error", + "reason": "worker_failed", + "returncode": completed.returncode, + "stderr": completed.stderr[-200:], + } + try: + return json.JSONDecoder().decode(stdout.splitlines()[-1]) + except json.JSONDecodeError: + return {"status": "error", "reason": "worker_output_invalid"} + + @classmethod + def _hdf5_attribute_exists(cls, attrs: Any, attr_name: str) -> bool: + if cls._should_use_hdf5_attribute_worker(attrs): + worker_result = cls._read_hdf5_attribute_in_worker( + attrs, + attr_name, + mode="raw", + max_bytes=1, + max_items=1, + ) + return worker_result.get("status") != "missing" + return attr_name in attrs + + def _read_bounded_hdf5_attribute( + self, + attrs: Any, + attr_name: str, + result: ScanResult, + *, + max_bytes: int, + fail_closed: bool, + ) -> Any: + """Read a small HDF5 attribute without materializing unbounded variable-length values.""" + if self._should_use_hdf5_attribute_worker(attrs): + worker_result = self._read_hdf5_attribute_in_worker( + attrs, + attr_name, + mode="raw", + max_bytes=max_bytes, + max_items=self._MAX_HDF5_LINK_VISITS, + ) + worker_status = worker_result.get("status") + if worker_status == "missing": + return self._HDF5_ATTRIBUTE_MISSING + if worker_status == "value": + return self._decode_hdf5_attribute_worker_value(worker_result.get("value")) + if worker_status == "skipped": + attr_size = worker_result.get("storage_size") + if not isinstance(attr_size, int): + attr_size = max_bytes + 1 + self._mark_hdf5_attribute_size_limit( + result, + attr_name, + attr_size=attr_size, + max_bytes=max_bytes, + fail_closed=fail_closed, + ) + return self._HDF5_ATTRIBUTE_READ_SKIPPED + + if fail_closed: + self._mark_hdf5_attribute_native_limit( + result, + attr_name, + fail_closed=fail_closed, + reason=str(worker_result.get("reason") or worker_result.get("error_type") or "worker_failed"), + ) + return self._HDF5_ATTRIBUTE_READ_SKIPPED + + if attr_name not in attrs: + return self._HDF5_ATTRIBUTE_MISSING + + attr_id = attrs.get_id(attr_name) + if self._hdf5_attribute_is_variable_string(attr_id): + attr_value, truncated, attr_size = self._read_hdf5_variable_string_attribute( + attr_id, + max_bytes=max_bytes, + max_items=self._MAX_HDF5_LINK_VISITS, + ) + if truncated: + self._mark_hdf5_attribute_size_limit( + result, + attr_name, + attr_size=attr_size if attr_size is not None else max_bytes + 1, + max_bytes=max_bytes, + fail_closed=fail_closed, + ) + return self._HDF5_ATTRIBUTE_READ_SKIPPED + return attr_value + + with suppress(Exception): + storage_size = int(attr_id.get_storage_size()) + if storage_size > max_bytes: + self._mark_hdf5_attribute_size_limit( + result, + attr_name, + attr_size=storage_size, + max_bytes=max_bytes, + fail_closed=fail_closed, + ) + return self._HDF5_ATTRIBUTE_READ_SKIPPED + + attr_value = attrs[attr_name] + attr_size = self._json_attribute_size(attr_value) + if attr_size is not None and attr_size > max_bytes: + self._mark_hdf5_attribute_size_limit( + result, + attr_name, + attr_size=attr_size, + max_bytes=max_bytes, + fail_closed=fail_closed, + ) + return self._HDF5_ATTRIBUTE_READ_SKIPPED + + return attr_value + + def _load_json_hdf5_attribute(self, attrs: Any, attr_name: str, result: ScanResult) -> Any: + attr_value = self._read_bounded_hdf5_attribute( + attrs, + attr_name, + result, + max_bytes=self.max_hdf5_json_attribute_bytes, + fail_closed=True, + ) + if attr_value is self._HDF5_ATTRIBUTE_MISSING or attr_value is self._HDF5_ATTRIBUTE_READ_SKIPPED: + return self._JSON_ATTRIBUTE_PARSE_FAILED + return self._load_json_attribute(attr_value, result, attr_name) + def _load_json_attribute(self, attr_value: Any, result: ScanResult, attr_name: str) -> Any: """Load a Keras JSON attribute, marking the scan incomplete on malformed metadata.""" + attr_size = self._json_attribute_size(attr_value) + if attr_size is not None and attr_size > self.max_hdf5_json_attribute_bytes: + reason = f"keras_h5_{attr_name}_size_limit_exceeded" + self._mark_inconclusive_scan_result(result, reason) + result.add_check( + name="Keras H5 Config Size Limit", + passed=False, + message=f"Keras H5 {attr_name} exceeds bounded parse budget", + severity=IssueSeverity.INFO, + location=self.current_file_path, + details={ + "attribute": attr_name, + "attribute_bytes": attr_size, + "max_attribute_bytes": self.max_hdf5_json_attribute_bytes, + "analysis_incomplete": True, + "scan_outcome_reason": reason, + }, + rule_code="S902", + ) + return self._JSON_ATTRIBUTE_PARSE_FAILED + try: if isinstance(attr_value, bytes): attr_value = attr_value.decode("utf-8") @@ -623,7 +1301,9 @@ def _has_weights_like_hdf5_layout(cls, h5_file: Any, _path: str) -> bool: @classmethod def _has_legacy_weights_layout(cls, h5_file: Any) -> bool: - layer_names = cls._decode_hdf5_names(h5_file.attrs.get("layer_names")) + layer_names, layer_names_truncated = cls._read_bounded_hdf5_name_attribute(h5_file.attrs, "layer_names") + if layer_names_truncated: + return True if not layer_names: return False @@ -632,13 +1312,13 @@ def _has_legacy_weights_layout(cls, h5_file: Any) -> bool: return True link = h5_file.get(layer_name, getlink=True) - if isinstance(link, h5py.ExternalLink): + if isinstance(link, (h5py.ExternalLink, h5py.SoftLink)): return True if not isinstance(link, h5py.HardLink): continue layer = h5_file.get(layer_name, getlink=False) - if isinstance(layer, h5py.Group) and "weight_names" in layer.attrs: + if isinstance(layer, h5py.Group) and cls._hdf5_attribute_exists(layer.attrs, "weight_names"): return True return False @@ -649,13 +1329,28 @@ def _has_keras3_weights_layout(cls, h5_file: Any) -> bool: layers_link = h5_file.get("layers", getlink=True) if isinstance(layers_link, h5py.ExternalLink): return True - if not isinstance(layers_link, h5py.HardLink): + if isinstance(layers_link, h5py.SoftLink): + target_path, target_link, incomplete = cls._resolve_hdf5_soft_link(h5_file, "layers", layers_link) + if incomplete or target_path is None or target_link is None: + return False + if isinstance(target_link, h5py.ExternalLink): + return True + if not isinstance(target_link, h5py.HardLink): + return False + layers = h5_file.get(target_path, getlink=False) + layer_source_prefix = target_path + elif isinstance(layers_link, h5py.HardLink): + layers = h5_file.get("layers", getlink=False) + layer_source_prefix = "layers" + else: return False - layers = h5_file.get("layers", getlink=False) if not isinstance(layers, h5py.Group): return False + if cls._has_keras3_root_vars_layout(h5_file): + return True + for index, layer_name in enumerate(layers): if index >= cls._MAX_HDF5_LAYOUT_PROBE_ITEMS: return True @@ -663,15 +1358,37 @@ def _has_keras3_weights_layout(cls, h5_file: Any) -> bool: layer_link = layers.get(layer_name, getlink=True) if isinstance(layer_link, h5py.ExternalLink): return True - if not isinstance(layer_link, h5py.HardLink): + if isinstance(layer_link, h5py.SoftLink): + source_name = f"{layer_source_prefix}/{layer_name}" if layer_source_prefix else str(layer_name) + target_path, target_link, incomplete = cls._resolve_hdf5_soft_link(h5_file, source_name, layer_link) + if incomplete or target_path is None or target_link is None: + continue + if isinstance(target_link, h5py.ExternalLink): + return True + if not isinstance(target_link, h5py.HardLink): + continue + layer = h5_file.get(target_path, getlink=False) + elif isinstance(layer_link, h5py.HardLink): + layer = layers.get(layer_name, getlink=False) + else: continue - - layer = layers.get(layer_name, getlink=False) if isinstance(layer, h5py.Group) and cls._has_group_or_external_link(layer, "vars"): return True return False + @classmethod + def _has_keras3_root_vars_layout(cls, h5_file: Any) -> bool: + if cls._has_group_or_external_link(h5_file, "vars"): + return True + + optimizer_link = h5_file.get("optimizer", getlink=True) + if not isinstance(optimizer_link, h5py.HardLink): + return False + + optimizer = h5_file.get("optimizer", getlink=False) + return isinstance(optimizer, h5py.Group) and cls._has_group_or_external_link(optimizer, "vars") + @classmethod def _hdf5_weight_scan_roots(cls, h5_file: Any) -> tuple[list[str], bool]: """Return loader-consumed HDF5 roots without following external links.""" @@ -699,8 +1416,8 @@ def add_root(path: str) -> bool: root_set.add(path) return True - def add_weight_names(group: Any, prefix: str) -> bool: - for weight_name in cls._decode_hdf5_names(group.attrs.get("weight_names")): + def add_weight_name_roots(group: Any, prefix: str, weight_names: list[str]) -> bool: + for weight_name in weight_names: if not consume_name_budget(): return False lookup_name = weight_name or "." @@ -714,31 +1431,133 @@ def add_weight_names(group: Any, prefix: str) -> bool: return False return True + def add_weight_names(group: Any, prefix: str) -> bool: + nonlocal roots_truncated + weight_names, weight_names_truncated = cls._read_bounded_hdf5_name_attribute( + group.attrs, + "weight_names", + ) + if weight_names_truncated: + roots_truncated = True + return False + return add_weight_name_roots(group, prefix, weight_names) + def add_legacy_group_roots(group: Any, prefix: str) -> None: nonlocal roots_truncated if not add_weight_names(group, prefix): return - for layer_name in cls._decode_hdf5_names(group.attrs.get("layer_names")): + layer_names, layer_names_truncated = cls._read_bounded_hdf5_name_attribute(group.attrs, "layer_names") + if layer_names_truncated: + roots_truncated = True + return + layer_groups: list[tuple[Any, str]] = [] + for layer_name in layer_names: if not consume_name_budget(): return layer_link = group.get(layer_name, getlink=True) if layer_link is None: continue layer_path = f"{prefix}/{layer_name}" if prefix else layer_name - if isinstance(layer_link, h5py.ExternalLink): + if isinstance(layer_link, (h5py.ExternalLink, h5py.SoftLink)): if not add_root(layer_path): return continue if not isinstance(layer_link, h5py.HardLink): continue layer_group = group.get(layer_name, getlink=False) - if isinstance(layer_group, h5py.Group) and not add_weight_names(layer_group, layer_path): + if isinstance(layer_group, h5py.Group): + layer_groups.append((layer_group, layer_path)) + + if not layer_groups: + return + + weight_name_results = cls._read_bounded_hdf5_name_attributes( + [(layer_group.attrs, "weight_names") for layer_group, _layer_path in layer_groups] + ) + for (layer_group, layer_path), (weight_names, weight_names_truncated) in zip( + layer_groups, + weight_name_results, + strict=True, + ): + if weight_names_truncated: + roots_truncated = True return + if not add_weight_name_roots(layer_group, layer_path, weight_names): + return + + def add_keras3_saveable_var_roots(*, excluded_prefixes: tuple[str, ...] = ()) -> None: + nonlocal inspected_name_count, roots_truncated + remaining_link_visits = max(cls._MAX_HDF5_LINK_VISITS - inspected_name_count, 0) + if remaining_link_visits == 0: + roots_truncated = True + return + normalized_excluded_prefixes = tuple(prefix.strip("/") for prefix in excluded_prefixes if prefix.strip("/")) + + def add_vars_root(name: str, _link: Any) -> None: + if any(name == prefix or name.startswith(f"{prefix}/") for prefix in normalized_excluded_prefixes): + return + if name == "vars" or name.endswith("/vars"): + add_root(name) + + visited_link_count, links_truncated = cls._visit_hdf5_links( + h5_file, + add_vars_root, + max_links=remaining_link_visits, + ) + inspected_name_count += visited_link_count + if links_truncated: + roots_truncated = True - if "model_config" in h5_file.attrs: + def add_keras3_layer_roots(layers_group: Any, *, layer_prefix: str) -> None: + for layer_name in layers_group: + if not consume_name_budget(): + break + layer_path = f"{layer_prefix}/{layer_name}" if layer_prefix else str(layer_name) + layer_link = layers_group.get(layer_name, getlink=True) + if isinstance(layer_link, (h5py.ExternalLink, h5py.SoftLink)): + if not add_root(layer_path): + break + continue + if not isinstance(layer_link, h5py.HardLink): + continue + + layer = layers_group.get(layer_name, getlink=False) + if ( + isinstance(layer, h5py.Group) + and layer.get("vars", getlink=True) is not None + and not add_root(f"{layer_path}/vars") + ): + break + + def add_keras3_roots(*, require_layers: bool) -> None: + layers_link = h5_file.get("layers", getlink=True) + if isinstance(layers_link, (h5py.ExternalLink, h5py.SoftLink)): + add_root("layers") + excluded_prefixes: tuple[str, ...] = () + if isinstance(layers_link, h5py.SoftLink): + target_path, target_link, incomplete = cls._resolve_hdf5_soft_link(h5_file, "layers", layers_link) + if not incomplete and target_path is not None and isinstance(target_link, h5py.HardLink): + excluded_prefixes = (target_path,) + add_keras3_saveable_var_roots(excluded_prefixes=excluded_prefixes) + return + if not isinstance(layers_link, h5py.HardLink): + if not require_layers: + add_keras3_saveable_var_roots() + return + + layers = h5_file.get("layers", getlink=False) + if not isinstance(layers, h5py.Group): + if not require_layers: + add_keras3_saveable_var_roots() + return + + add_keras3_layer_roots(layers, layer_prefix="layers") + add_keras3_saveable_var_roots() + + if cls._hdf5_attribute_exists(h5_file.attrs, "model_config"): for root_name in cls._KERAS_WEIGHT_ROOT_GROUPS: root_link = h5_file.get(root_name, getlink=True) - if isinstance(root_link, h5py.ExternalLink): + if isinstance(root_link, (h5py.ExternalLink, h5py.SoftLink)): if not add_root(root_name): break continue @@ -749,50 +1568,31 @@ def add_legacy_group_roots(group: Any, prefix: str) -> None: add_legacy_group_roots(root_group, root_name) if roots_truncated: break + if not roots_truncated: + add_keras3_roots(require_layers=False) return roots, roots_truncated - layer_names = cls._decode_hdf5_names(h5_file.attrs.get("layer_names")) + layer_names, layer_names_truncated = cls._read_bounded_hdf5_name_attribute(h5_file.attrs, "layer_names") + if layer_names_truncated: + return roots, True if layer_names: add_legacy_group_roots(h5_file, "") return roots, roots_truncated - layers_link = h5_file.get("layers", getlink=True) - if isinstance(layers_link, h5py.ExternalLink): - add_root("layers") - return roots, roots_truncated - if not isinstance(layers_link, h5py.HardLink): - return roots, False - - layers = h5_file.get("layers", getlink=False) - if not isinstance(layers, h5py.Group): - return roots, False - - for layer_name in layers: - if not consume_name_budget(): - break - layer_path = f"layers/{layer_name}" - layer_link = layers.get(layer_name, getlink=True) - if isinstance(layer_link, h5py.ExternalLink): - if not add_root(layer_path): - break - continue - if not isinstance(layer_link, h5py.HardLink): - continue - - layer = layers.get(layer_name, getlink=False) - if ( - isinstance(layer, h5py.Group) - and layer.get("vars", getlink=True) is not None - and not add_root(f"{layer_path}/vars") - ): - break - + add_keras3_roots(require_layers=True) return roots, roots_truncated + @staticmethod + def _is_same_file_virtual_source(filename: Any) -> bool: + try: + return os.fsdecode(filename) in {"", "."} + except TypeError: + return False + @staticmethod def _has_group_or_external_link(group: Any, name: str) -> bool: link = group.get(name, getlink=True) - if isinstance(link, h5py.ExternalLink): + if isinstance(link, (h5py.ExternalLink, h5py.SoftLink)): return True if not isinstance(link, h5py.HardLink): return False @@ -822,6 +1622,172 @@ def _decode_hdf5_names(value: Any) -> list[str]: names.append(item) return [name for name in names if name] + @classmethod + def _read_bounded_hdf5_name_attribute(cls, attrs: Any, attr_name: str) -> tuple[list[str], bool]: + """Read Keras HDF5 name-list attributes only when their encoded size is bounded.""" + if cls._should_use_hdf5_attribute_worker(attrs): + worker_result = cls._read_hdf5_attribute_in_worker( + attrs, + attr_name, + mode="names", + max_bytes=cls._MAX_HDF5_NAME_ATTRIBUTE_BYTES, + max_items=cls._MAX_HDF5_LINK_VISITS, + ) + return cls._decode_hdf5_name_attribute_worker_result(worker_result) + + if attr_name not in attrs: + return [], False + + try: + attr_id = attrs.get_id(attr_name) + except Exception: + return [], True + + attr_point_count: int | None = None + with suppress(Exception): + attr_space = attr_id.get_space() + attr_point_count = int(attr_space.get_simple_extent_npoints()) + if attr_point_count > cls._MAX_HDF5_LINK_VISITS: + return [], True + + if cls._hdf5_attribute_is_variable_string(attr_id): + return cls._read_hdf5_variable_string_name_attribute( + attr_id, + max_bytes=cls._MAX_HDF5_NAME_ATTRIBUTE_BYTES, + point_count=attr_point_count, + ) + + with suppress(Exception): + if int(attr_id.get_storage_size()) > cls._MAX_HDF5_NAME_ATTRIBUTE_BYTES: + return [], True + + try: + attr_value = attrs[attr_name] + except Exception: + return [], True + + attr_size = cls._json_attribute_size(attr_value) + if attr_size is not None and attr_size > cls._MAX_HDF5_NAME_ATTRIBUTE_BYTES: + return [], True + + names = cls._decode_hdf5_names(attr_value) + if len(names) > cls._MAX_HDF5_LINK_VISITS: + return names[: cls._MAX_HDF5_LINK_VISITS], True + return names, False + + @classmethod + def _decode_hdf5_name_attribute_worker_result(cls, worker_result: dict[str, Any]) -> tuple[list[str], bool]: + worker_status = worker_result.get("status") + if worker_status == "missing": + return [], False + if worker_status == "value": + names = cls._decode_hdf5_names(worker_result.get("value")) + truncated = bool(worker_result.get("truncated")) + if len(names) > cls._MAX_HDF5_LINK_VISITS: + return names[: cls._MAX_HDF5_LINK_VISITS], True + return names, truncated + return [], True + + @classmethod + def _read_bounded_hdf5_name_attributes(cls, requests: list[tuple[Any, str]]) -> list[tuple[list[str], bool]]: + """Read several name attributes with one isolated worker when they share a large HDF5 file.""" + if not requests: + return [] + if all(cls._should_use_hdf5_attribute_worker(attrs) for attrs, _attr_name in requests): + worker_results = cls._read_hdf5_attributes_in_worker( + [ + { + "attrs": attrs, + "attr_name": attr_name, + "mode": "names", + "max_bytes": cls._MAX_HDF5_NAME_ATTRIBUTE_BYTES, + "max_items": cls._MAX_HDF5_LINK_VISITS, + } + for attrs, attr_name in requests + ] + ) + return [cls._decode_hdf5_name_attribute_worker_result(worker_result) for worker_result in worker_results] + return [cls._read_bounded_hdf5_name_attribute(attrs, attr_name) for attrs, attr_name in requests] + + @classmethod + def _read_hdf5_variable_string_name_attribute( + cls, + attr_id: Any, + *, + max_bytes: int, + point_count: int | None, + ) -> tuple[list[str], bool]: + if point_count is None: + with suppress(Exception): + point_count = int(attr_id.get_space().get_simple_extent_npoints()) + if point_count == 0: + return [], False + if point_count is None or point_count < 0 or point_count > cls._MAX_HDF5_LINK_VISITS: + return [], True + + attr_value, truncated, _attr_size = cls._read_hdf5_variable_string_attribute( + attr_id, + max_bytes=max_bytes, + max_items=cls._MAX_HDF5_LINK_VISITS, + max_item_chars=cls._MAX_HDF5_REFERENCE_TEXT_CHARS, + ) + if attr_value is None: + return [], True + names = cls._decode_hdf5_names(attr_value) + if len(names) > cls._MAX_HDF5_LINK_VISITS: + return names[: cls._MAX_HDF5_LINK_VISITS], True + return names, truncated + + @staticmethod + def _normalize_hdf5_soft_link_path(source_name: str, target_path: str) -> str | None: + if not target_path: + return None + + base_parts: list[str] = [] + if not target_path.startswith("/"): + base_parts = [part for part in source_name.split("/")[:-1] if part] + + parts: list[str] = [] + for part in [*base_parts, *target_path.split("/")]: + if part in {"", "."}: + continue + if part == "..": + return None + parts.append(part) + + return "/".join(parts) + + @classmethod + def _resolve_hdf5_soft_link( + cls, + h5_file: Any, + source_name: str, + soft_link: Any, + ) -> tuple[str | None, Any | None, bool]: + """Resolve internal SoftLinks to a final link without following ExternalLinks.""" + target_path = cls._normalize_hdf5_soft_link_path(source_name, getattr(soft_link, "path", "")) + if target_path is None: + return None, None, True + + visited_paths: set[str] = set() + for _depth in range(cls._MAX_HDF5_SOFT_LINK_RESOLUTION_DEPTH): + if target_path in visited_paths: + return target_path, None, True + visited_paths.add(target_path) + + target_link = h5_file.get(target_path, getlink=True) + if target_link is None: + return target_path, None, True + if not isinstance(target_link, h5py.SoftLink): + return target_path, target_link, False + + next_target_path = cls._normalize_hdf5_soft_link_path(target_path, getattr(target_link, "path", "")) + if next_target_path is None: + return target_path, None, True + target_path = next_target_path + + return target_path, None, True + def _check_hdf5_external_references( self, h5_file: Any, @@ -837,10 +1803,119 @@ def _check_hdf5_external_references( findings: list[dict[str, Any]] = [] external_reference_count = 0 external_storage_segments_truncated = False + virtual_dataset_sources_truncated = False + visited_virtual_source_count = 0 + soft_link_resolution_incomplete = False weight_roots, weight_roots_truncated = self._hdf5_weight_scan_roots(h5_file) + visited_link_count = 0 + link_visits_truncated = False + pending_soft_group_roots: list[tuple[str, str, Any]] = [] - def visit(name: str, link: Any) -> None: + def record_external_storage(name: str, obj: Any) -> None: nonlocal external_reference_count, external_storage_segments_truncated + storage_properties = obj.id.get_create_plist() + external_storage_segment_count = storage_properties.get_external_count() + if external_storage_segment_count <= 0: + return + + external_reference_count += 1 + if len(findings) >= self._MAX_HDF5_EXTERNAL_REFERENCE_REPORTS: + return + segments = [ + { + **self._hdf5_external_storage_filename_details(filename), + "offset": int(offset), + "size": int(size), + } + for filename, offset, size in ( + storage_properties.get_external(index) + for index in range( + min( + external_storage_segment_count, + self._MAX_HDF5_EXTERNAL_STORAGE_SEGMENT_REPORTS, + ) + ) + ) + ] + hdf5_path, hdf5_path_truncated = self._bounded_hdf5_reference_text(f"/{name}".replace("//", "/")) + external_storage_finding: dict[str, Any] = { + "kind": "external_storage", + "hdf5_path": hdf5_path, + "segments": segments, + } + if hdf5_path_truncated: + external_storage_finding["hdf5_path_truncated"] = True + if external_storage_segment_count > len(segments): + external_storage_segments_truncated = True + external_storage_finding["segment_count"] = external_storage_segment_count + external_storage_finding["segments_truncated"] = True + findings.append(external_storage_finding) + + def record_virtual_dataset_sources(name: str, obj: Any) -> None: + nonlocal external_reference_count, virtual_dataset_sources_truncated, visited_virtual_source_count + storage_properties = obj.id.get_create_plist() + if storage_properties.get_layout() != h5py.h5d.VIRTUAL: + return + virtual_source_count = storage_properties.get_virtual_count() + if virtual_source_count <= 0: + return + + sources: list[dict[str, Any]] = [] + external_source_count = 0 + remaining_source_inspections = max( + self._MAX_HDF5_VIRTUAL_SOURCE_INSPECTIONS - visited_virtual_source_count, + 0, + ) + inspected_source_count = min(virtual_source_count, remaining_source_inspections) + for index in range(inspected_source_count): + visited_virtual_source_count += 1 + raw_filename = storage_properties.get_virtual_filename(index) + if self._is_same_file_virtual_source(raw_filename): + continue + external_source_count += 1 + if len(sources) >= self._MAX_HDF5_VIRTUAL_SOURCE_REPORTS: + continue + filename, filename_truncated = self._bounded_hdf5_reference_text(raw_filename) + dataset_name, dataset_name_truncated = self._bounded_hdf5_reference_text( + storage_properties.get_virtual_dsetname(index) + ) + source_details: dict[str, Any] = { + "filename": filename, + "path": dataset_name, + } + if filename_truncated: + source_details["filename_truncated"] = True + if dataset_name_truncated: + source_details["path_truncated"] = True + sources.append(source_details) + + sources_truncated = virtual_source_count > inspected_source_count or external_source_count > len(sources) + if sources_truncated: + virtual_dataset_sources_truncated = True + if not sources: + if sources_truncated: + external_reference_count += 1 + return + + external_reference_count += 1 + if len(findings) >= self._MAX_HDF5_EXTERNAL_REFERENCE_REPORTS: + return + hdf5_path, hdf5_path_truncated = self._bounded_hdf5_reference_text(f"/{name}".replace("//", "/")) + virtual_dataset_finding: dict[str, Any] = { + "kind": "virtual_dataset", + "hdf5_path": hdf5_path, + "sources": sources, + } + if hdf5_path_truncated: + virtual_dataset_finding["hdf5_path_truncated"] = True + if sources_truncated: + virtual_dataset_finding["source_count"] = virtual_source_count + virtual_dataset_finding["sources_truncated"] = True + findings.append(virtual_dataset_finding) + + def visit(name: str, link: Any, *, obj: Any | None = None, source_name: str | None = None) -> None: + nonlocal external_reference_count, soft_link_resolution_incomplete + resolution_source_name = source_name or name if isinstance(link, h5py.ExternalLink): external_reference_count += 1 if len(findings) < self._MAX_HDF5_EXTERNAL_REFERENCE_REPORTS: @@ -862,64 +1937,99 @@ def visit(name: str, link: Any) -> None: findings.append(external_link_finding) return + if isinstance(link, h5py.SoftLink): + soft_target_path, target_link, incomplete = self._resolve_hdf5_soft_link( + h5_file, + resolution_source_name, + link, + ) + if incomplete or soft_target_path is None or target_link is None: + soft_link_resolution_incomplete = True + return + if isinstance(target_link, h5py.ExternalLink): + visit(name, target_link) + return + if not isinstance(target_link, h5py.HardLink): + return + target_obj = h5_file.get(soft_target_path, getlink=False) + if isinstance(target_obj, h5py.Dataset): + record_external_storage(name, target_obj) + record_virtual_dataset_sources(name, target_obj) + elif isinstance(target_obj, h5py.Group): + pending_soft_group_roots.append((name, soft_target_path, target_obj)) + return + if not isinstance(link, h5py.HardLink): return - obj = h5_file.get(name, getlink=False) + if obj is None: + obj = h5_file.get(name, getlink=False) if isinstance(obj, h5py.Dataset): - storage_properties = obj.id.get_create_plist() - external_storage_segment_count = storage_properties.get_external_count() - if external_storage_segment_count > 0: - external_reference_count += 1 - if len(findings) >= self._MAX_HDF5_EXTERNAL_REFERENCE_REPORTS: - return - segments = [ - { - **self._hdf5_external_storage_filename_details(filename), - "offset": int(offset), - "size": int(size), - } - for filename, offset, size in ( - storage_properties.get_external(index) - for index in range( - min( - external_storage_segment_count, - self._MAX_HDF5_EXTERNAL_STORAGE_SEGMENT_REPORTS, - ) - ) - ) - ] - hdf5_path, hdf5_path_truncated = self._bounded_hdf5_reference_text(f"/{name}".replace("//", "/")) - external_storage_finding: dict[str, Any] = { - "kind": "external_storage", - "hdf5_path": hdf5_path, - "segments": segments, - } - if hdf5_path_truncated: - external_storage_finding["hdf5_path_truncated"] = True - if external_storage_segment_count > len(segments): - external_storage_segments_truncated = True - external_storage_finding["segment_count"] = external_storage_segment_count - external_storage_finding["segments_truncated"] = True - findings.append(external_storage_finding) + record_external_storage(name, obj) + record_virtual_dataset_sources(name, obj) + + def drain_pending_soft_group_roots() -> None: + nonlocal visited_link_count, link_visits_truncated + while pending_soft_group_roots and not link_visits_truncated: + if visited_link_count >= self._MAX_HDF5_LINK_VISITS: + link_visits_truncated = True + return + alias_prefix, source_prefix, soft_group = pending_soft_group_roots.pop(0) + remaining_link_visits = self._MAX_HDF5_LINK_VISITS - visited_link_count + + def visit_soft_group( + name: str, + link: Any, + *, + prefix: str = alias_prefix, + resolved_prefix: str = source_prefix, + ) -> None: + visit(f"{prefix}/{name}", link, source_name=f"{resolved_prefix}/{name}") + + soft_group_visited, soft_group_truncated = self._visit_hdf5_links( + soft_group, + visit_soft_group, + max_links=remaining_link_visits, + ) + visited_link_count += soft_group_visited + if soft_group_truncated: + link_visits_truncated = True - visited_link_count = 0 - link_visits_truncated = False for root_path in weight_roots: root_link = h5_file.get(root_path, getlink=True) if root_link is None: continue + resolved_root_link = root_link if isinstance(root_link, h5py.ExternalLink): visited_link_count += 1 visit(root_path, root_link) continue - if not isinstance(root_link, h5py.HardLink): + if isinstance(root_link, h5py.SoftLink): + visited_link_count += 1 + target_path, target_link, incomplete = self._resolve_hdf5_soft_link(h5_file, root_path, root_link) + if incomplete or target_path is None or target_link is None: + soft_link_resolution_incomplete = True + continue + if isinstance(target_link, h5py.ExternalLink): + visit(root_path, target_link) + continue + if not isinstance(target_link, h5py.HardLink): + continue + root_obj = h5_file.get(target_path, getlink=False) + if isinstance(root_obj, h5py.Dataset): + visit(root_path, target_link, obj=root_obj) + continue + resolved_root_path = target_path + resolved_root_link = target_link + else: + resolved_root_path = root_path + if not isinstance(resolved_root_link, h5py.HardLink): continue - root_obj = h5_file.get(root_path, getlink=False) + root_obj = h5_file.get(resolved_root_path, getlink=False) if isinstance(root_obj, h5py.Dataset): visited_link_count += 1 - visit(root_path, root_link) + visit(root_path, resolved_root_link, obj=root_obj) continue if not isinstance(root_obj, h5py.Group): continue @@ -929,8 +2039,14 @@ def visit(name: str, link: Any) -> None: link_visits_truncated = True break - def visit_root(name: str, link: Any, *, prefix: str = root_path) -> None: - visit(f"{prefix}/{name}", link) + def visit_root( + name: str, + link: Any, + *, + prefix: str = root_path, + source_prefix: str = resolved_root_path, + ) -> None: + visit(f"{prefix}/{name}", link, source_name=f"{source_prefix}/{name}") root_visited, root_truncated = self._visit_hdf5_links( root_obj, @@ -941,6 +2057,9 @@ def visit_root(name: str, link: Any, *, prefix: str = root_path) -> None: if root_truncated: link_visits_truncated = True break + drain_pending_soft_group_roots() + if link_visits_truncated: + break external_references_truncated = external_reference_count > len(findings) if ( @@ -948,6 +2067,8 @@ def visit_root(name: str, link: Any, *, prefix: str = root_path) -> None: or link_visits_truncated or external_references_truncated or external_storage_segments_truncated + or virtual_dataset_sources_truncated + or soft_link_resolution_incomplete ): reason = "keras_h5_external_reference_analysis_limit_exceeded" self._mark_inconclusive_scan_result(result, reason) @@ -968,6 +2089,10 @@ def visit_root(name: str, link: Any, *, prefix: str = root_path) -> None: "reported_external_reference_count": len(findings), "external_references_truncated": external_references_truncated, "external_storage_segments_truncated": external_storage_segments_truncated, + "visited_virtual_source_count": visited_virtual_source_count, + "max_virtual_source_inspections": self._MAX_HDF5_VIRTUAL_SOURCE_INSPECTIONS, + "virtual_dataset_sources_truncated": virtual_dataset_sources_truncated, + "soft_link_resolution_incomplete": soft_link_resolution_incomplete, }, rule_code="S902", ) @@ -981,17 +2106,18 @@ def visit_root(name: str, link: Any, *, prefix: str = root_path) -> None: "cvss": 8.1, "cwe": "CWE-200, CWE-73", "description": ( - "HDF5 external storage or ExternalLink entries can cause Keras weight loading to read arbitrary " - "host files into model tensors." + "HDF5 external storage, ExternalLink, or Virtual Dataset entries can cause Keras weight loading to " + "read arbitrary host files into model tensors." ), "remediation": "Upgrade to Keras >= 3.12.1 or >= 3.13.2 and reject weights using HDF5 external references.", "external_references": findings, "affected_versions": "Keras >= 3.0.0, < 3.12.1 and >= 3.13.0, < 3.13.2", } - if external_references_truncated or external_storage_segments_truncated: + if external_references_truncated or external_storage_segments_truncated or virtual_dataset_sources_truncated: details["external_reference_count"] = external_reference_count details["external_references_truncated"] = external_references_truncated details["external_storage_segments_truncated"] = external_storage_segments_truncated + details["virtual_dataset_sources_truncated"] = virtual_dataset_sources_truncated display_keras_version = redact_evidence_string(keras_version) if isinstance(keras_version, str) else None vuln_status = self._is_vulnerable_to_cve_2026_1669(keras_version) if isinstance(keras_version, str) else None diff --git a/modelaudit/utils/file/large_file_handler.py b/modelaudit/utils/file/large_file_handler.py index ab7c88c63..101bad5f7 100644 --- a/modelaudit/utils/file/large_file_handler.py +++ b/modelaudit/utils/file/large_file_handler.py @@ -14,6 +14,7 @@ from ..helpers.cache_decorator import ( add_optional_dependency_availability_to_version_context, + should_bypass_cache_for_file_backed_hdf5, should_bypass_cache_for_safetensors_header_limit, should_bypass_cache_for_unavailable_hdf5_analysis, should_bypass_cache_for_zip_entry_preflight, @@ -251,6 +252,10 @@ def scan_large_file( logger.debug(f"Bypassing large-file cache for bounded SafeTensors header failure: {file_path}") return scanner.scan(file_path) # type: ignore[no-any-return] + if should_bypass_cache_for_file_backed_hdf5(file_path): + logger.debug(f"Bypassing large-file cache for file-backed HDF5 inspection: {file_path}") + return _scan_large_file_internal(file_path, scanner, progress_callback, timeout) + # If caching is disabled, proceed with direct scan if not cache_enabled: return _scan_large_file_internal(file_path, scanner, progress_callback, timeout) diff --git a/modelaudit/utils/helpers/cache_decorator.py b/modelaudit/utils/helpers/cache_decorator.py index d7a76d76f..931f25fe1 100644 --- a/modelaudit/utils/helpers/cache_decorator.py +++ b/modelaudit/utils/helpers/cache_decorator.py @@ -168,6 +168,17 @@ def should_bypass_cache_for_unavailable_hdf5_analysis(file_path: str) -> bool: return True +def should_bypass_cache_for_file_backed_hdf5(file_path: str) -> bool: + """Bypass content-hash cache probes for HDF5 files scanned through bounded file-backed inspection.""" + try: + from ...scanners.base import DEFAULT_MAX_FILE_READ_SIZE + + file_size = os.path.getsize(file_path) + except OSError: + return False + return file_size > DEFAULT_MAX_FILE_READ_SIZE and find_hdf5_signature_offset(file_path) is not None + + def should_bypass_cache_for_zip_entry_preflight(file_path: str, config: dict[str, Any]) -> bool: """Avoid cache probes that materialize an over-limit or inconsistent ZIP directory.""" try: @@ -426,6 +437,10 @@ def wrapper(*args, **kwargs): logger.debug(f"Bypassing cache because HDF5 analysis is unavailable: {file_path}") return func(*args, **kwargs) + if should_bypass_cache_for_file_backed_hdf5(file_path): + logger.debug(f"Bypassing cache for file-backed HDF5 inspection: {file_path}") + return func(*args, **kwargs) + if not cache_config.should_cache_file(file_stat.st_size, file_ext): logger.debug(f"File {file_path} not suitable for caching, calling function directly") return func(*args, **kwargs) diff --git a/tests/scanners/test_keras_h5_scanner.py b/tests/scanners/test_keras_h5_scanner.py index 864092445..a2f01537b 100644 --- a/tests/scanners/test_keras_h5_scanner.py +++ b/tests/scanners/test_keras_h5_scanner.py @@ -1,10 +1,12 @@ import base64 import json import marshal +import os import subprocess import sys import textwrap import zipfile +from collections.abc import Callable from pathlib import Path from typing import Any @@ -15,16 +17,19 @@ import h5py +import modelaudit.core as core_module from modelaudit.cache import get_cache_manager, reset_cache_manager from modelaudit.cache.optimized_config import build_cache_version_context -from modelaudit.core import determine_exit_code, scan_model_directory_or_file from modelaudit.integrations.sarif_formatter import format_sarif_output from modelaudit.scanners import keras_h5_scanner as keras_h5_scanner_module from modelaudit.scanners import keras_utils -from modelaudit.scanners.base import INCONCLUSIVE_SCAN_OUTCOME, CheckStatus, IssueSeverity +from modelaudit.scanners.base import DEFAULT_MAX_FILE_READ_SIZE, INCONCLUSIVE_SCAN_OUTCOME, CheckStatus, IssueSeverity from modelaudit.scanners.keras_h5_scanner import KerasH5Scanner from modelaudit.utils.file.hdf5 import HDF5_MAGIC, find_hdf5_signature_offset, hdf5_metadata_checksum -from modelaudit.utils.helpers.cache_decorator import should_bypass_cache_for_missing_h5py +from modelaudit.utils.helpers.cache_decorator import ( + should_bypass_cache_for_file_backed_hdf5, + should_bypass_cache_for_missing_h5py, +) ASSETS_DIR = Path(__file__).parent.parent / "assets" / "samples" / "keras" @@ -128,6 +133,18 @@ def create_raw_config_h5_file( return h5_path +def inflate_h5_file_to_size(path: Path, minimum_size: int = DEFAULT_MAX_FILE_READ_SIZE + 4096) -> None: + """Make an HDF5 fixture appear large using sparse trailing padding.""" + with path.open("ab") as handle: + handle.truncate(minimum_size) + + +def assert_not_rejected_by_read_cap(result: Any) -> None: + reasons = result.metadata.get("scan_outcome_reasons", []) + assert "max_file_read_size_exceeded" not in reasons + assert not any(check.name == "File Size Limit" and check.status == CheckStatus.FAILED for check in result.checks) + + def create_h5_with_external_link( tmp_path: Path, *, @@ -384,178 +401,1474 @@ def test_keras_h5_non_string_layer_class_fails_closed_without_abort(tmp_path: Pa }, ) - result = KerasH5Scanner().scan(str(model_path)) + result = KerasH5Scanner().scan(str(model_path)) + + type_checks = [check for check in result.checks if check.name == "Layer Class Type Validation"] + assert len(type_checks) == 1 + assert type_checks[0].severity == IssueSeverity.WARNING + assert result.metadata["scan_outcome"] == INCONCLUSIVE_SCAN_OUTCOME + assert "keras_h5_layer_class_invalid_type" in result.metadata["scan_outcome_reasons"] + assert result.metadata["layer_counts"][""] == 1 + assert raw_secret not in result.to_json() + + +def test_keras_h5_non_string_model_class_preserves_nested_cve_detection(tmp_path: Path) -> None: + """Malformed root metadata must not suppress scanning of nested layers.""" + raw_secret = "sk-proj-CAND061H5MODELCLASSSECRET000000000000" + model_path = create_custom_h5_file( + tmp_path, + { + "class_name": {"api_key": raw_secret}, + "config": { + "layers": [ + {"class_name": "Lambda", "config": {"function": "lambda x: x"}}, + ] + }, + }, + keras_version="3.10.0", + ) + + result = KerasH5Scanner().scan(str(model_path)) + + type_checks = [check for check in result.checks if check.name == "Model Class Type Validation"] + cve_issues = [issue for issue in result.issues if issue.details.get("cve_id") == "CVE-2025-9905"] + assert len(type_checks) == 1 + assert len(cve_issues) == 1 + assert cve_issues[0].severity == IssueSeverity.CRITICAL + assert result.metadata["model_class"] == "" + assert "keras_h5_model_class_invalid_type" in result.metadata["scan_outcome_reasons"] + assert raw_secret not in result.to_json() + + +@pytest.mark.parametrize( + "fixture_factory", + [create_h5_with_external_link, create_h5_with_external_storage], +) +def test_keras_h5_scanner_flags_external_references_despite_fixed_file_version( + tmp_path: Path, + fixture_factory: Any, +) -> None: + """Standalone H5 files cannot use artifact-controlled keras_version to suppress external refs.""" + model_path = fixture_factory(tmp_path, keras_version="3.13.2") + + scanner = KerasH5Scanner() + result = scanner.scan(str(model_path)) + + cve_issues = [issue for issue in result.issues if issue.details.get("cve_id") == "CVE-2026-1669"] + assert len(cve_issues) == 1 + assert cve_issues[0].severity == IssueSeverity.WARNING + assert cve_issues[0].details["keras_version"] == "3.13.2" + assert cve_issues[0].details["parse_status"] == "untrusted_artifact_version" + assert cve_issues[0].details["version_source"] == "hdf5_file_attribute" + assert not any( + check.name == "HDF5 External Weight Reference Version Check" and check.status == CheckStatus.PASSED + for check in result.checks + ) + + +def test_keras_h5_scanner_fixed_metadata_without_external_refs_stays_quiet(tmp_path: Path) -> None: + """Fixed-looking metadata alone should not produce external-reference noise.""" + model_path = create_custom_h5_file( + tmp_path, + { + "class_name": "Sequential", + "config": { + "name": "sequential", + "layers": [{"class_name": "Dense", "config": {"units": 1}}], + }, + }, + keras_version="3.13.2", + file_name="fixed_no_external_refs.h5", + ) + + result = KerasH5Scanner().scan(str(model_path)) + + assert not any(issue.details.get("cve_id") == "CVE-2026-1669" for issue in result.issues) + assert not any(check.name.startswith("HDF5 External Weight Reference") for check in result.checks) + + +def test_keras_h5_metadata_redacts_model_controlled_identifiers(tmp_path: Path) -> None: + raw_secret = "sk-proj-KERASH5METADATASECRET1234567890" + model_path = tmp_path / "metadata_redaction.h5" + model_config = { + "class_name": f"Model_{raw_secret}", + "config": {"layers": [{"class_name": f"Layer_{raw_secret}", "config": {}}]}, + } + + with h5py.File(model_path, "w") as h5_file: + h5_file.attrs["model_config"] = json.dumps(model_config) + h5_file.attrs["keras_version"] = f"3.10.0+{raw_secret}" + h5_file.create_group(f"group_{raw_secret}") + weights = h5_file.create_group("model_weights") + weights.create_dataset(f"kernel_{raw_secret}", data=[1.0]) + + metadata = KerasH5Scanner().extract_metadata(str(model_path)) + serialized_metadata = json.dumps(metadata, default=str) + + assert raw_secret not in serialized_metadata + assert metadata["has_model_config"] is True + assert metadata["has_model_weights"] is True + assert metadata["total_parameters"] == 1 + assert metadata["model_class"] == "Model_" + assert metadata["keras_version"] == "3.10.0+" + assert metadata["layer_types"] == ["Layer_"] + assert metadata["parameter_details"] == [{"name": "kernel_", "shape": [1], "dtype": "float64", "size": 1}] + assert "group_" in metadata["h5_keys"] + + +def test_keras_h5_metadata_redacts_extraction_error( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + raw_secret = "ATTACKER_CONTROLLED_KERAS_H5_METADATA_FAILURE" + model_path = create_mock_h5_file(tmp_path) + + def fail_h5py_open(*_args: Any, **_kwargs: Any) -> None: + raise RuntimeError(raw_secret) + + monkeypatch.setattr(keras_h5_scanner_module.h5py, "File", fail_h5py_open) + + metadata = KerasH5Scanner().extract_metadata(str(model_path)) + + assert metadata["extraction_error"] == "" + assert raw_secret not in json.dumps(metadata, default=str) + + +@pytest.mark.parametrize("keras_version", ["3.13.x", "2.12.0-gpu", "3.13.2rc1junk", "3.13.2+"]) +def test_keras_h5_scanner_unparseable_external_reference_versions_mark_unknown_risk( + tmp_path: Path, + keras_version: str, +) -> None: + model_path = create_h5_with_external_link(tmp_path, keras_version=keras_version) + + result = KerasH5Scanner().scan(str(model_path)) + + cve_issues = [issue for issue in result.issues if issue.details.get("cve_id") == "CVE-2026-1669"] + assert len(cve_issues) == 1 + assert cve_issues[0].severity == IssueSeverity.WARNING + assert cve_issues[0].details["keras_version"] == keras_version + assert cve_issues[0].details["parse_status"] == "unknown" + assert any("is non-canonical" in issue.message for issue in cve_issues) + + assert not any( + check.name == "HDF5 External Weight Reference Version Check" and check.status == CheckStatus.PASSED + for check in result.checks + ) + + +def test_keras_h5_scanner_benign_model_has_no_warning_noise(tmp_path: Path) -> None: + """Benign H5 models should not produce warning or critical noise.""" + model_path = create_custom_h5_file( + tmp_path, + { + "class_name": "Sequential", + "config": { + "name": "sequential", + "layers": [{"class_name": "Dense", "config": {"units": 1}}], + }, + }, + ) + + scanner = KerasH5Scanner() + result = scanner.scan(str(model_path)) + + assert not any(issue.severity in (IssueSeverity.WARNING, IssueSeverity.CRITICAL) for issue in result.issues) + + +def test_large_benign_keras_h5_scans_file_backed_without_default_read_cap( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Valid large HDF5 Keras files should reach h5py-backed inspection.""" + model_path = create_custom_h5_file( + tmp_path, + { + "class_name": "Sequential", + "config": {"layers": [{"class_name": "Dense", "config": {"units": 1}}]}, + }, + file_name="large_benign.h5", + ) + inflate_h5_file_to_size(model_path) + + def fail_hash(_self: KerasH5Scanner, _path: str) -> dict[str, str | None]: + pytest.fail("Keras H5 scanning must not hash/read the whole file") + + monkeypatch.setattr(KerasH5Scanner, "calculate_file_hashes", fail_hash) + monkeypatch.setattr( + core_module, + "_calculate_file_hash", + lambda _path: pytest.fail("Core must not hash large HDF5 before Keras H5 dispatch"), + ) + + result = KerasH5Scanner().scan(str(model_path)) + audit_result = core_module.scan_model_directory_or_file(str(model_path), cache_enabled=False) + + assert model_path.stat().st_size > DEFAULT_MAX_FILE_READ_SIZE + assert result.success is True + assert audit_result.success is True + assert audit_result.content_hash is None + assert result.metadata["file_backed_scan"] is True + assert_not_rejected_by_read_cap(result) + assert any(check.name == "Keras H5 File-Backed Inspection" for check in result.checks) + assert not any(check.name == "File Integrity Hash" for check in result.checks) + assert not any(issue.severity in (IssueSeverity.WARNING, IssueSeverity.CRITICAL) for issue in result.issues) + + +def test_large_keras_h5_directory_scan_defers_core_hashing( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + model_dir = tmp_path / "models" + model_dir.mkdir() + model_path = create_custom_h5_file( + model_dir, + { + "class_name": "Sequential", + "config": {"layers": [{"class_name": "Dense", "config": {"units": 1}}]}, + }, + file_name="large_directory_model.h5", + ) + inflate_h5_file_to_size(model_path) + + monkeypatch.setattr( + core_module, + "_calculate_file_hash", + lambda _path: pytest.fail("Directory scan must not hash large HDF5 before Keras H5 dispatch"), + ) + monkeypatch.setattr( + KerasH5Scanner, + "calculate_file_hashes", + lambda _self, _path: pytest.fail("Keras H5 scanner must not hash large HDF5"), + ) + + audit_result = core_module.scan_model_directory_or_file(str(model_dir), cache_enabled=False) + metadata = audit_result.file_metadata[str(model_path)] + + assert audit_result.success is True + assert audit_result.files_scanned == 1 + assert audit_result.content_hash is None + assert "keras_h5" in audit_result.scanner_names + assert "max_file_read_size_exceeded" not in (getattr(metadata, "model_extra", {}) or {}).get( + "scan_outcome_reasons", + [], + ) + + +def test_large_keras_h5_streaming_scan_defers_core_hashing( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + model_path = create_custom_h5_file( + tmp_path, + { + "class_name": "Sequential", + "config": {"layers": [{"class_name": "Dense", "config": {"units": 1}}]}, + }, + file_name="large_streamed_model.h5", + ) + inflate_h5_file_to_size(model_path) + + monkeypatch.setattr( + "modelaudit.utils.helpers.file_hash.compute_sha256_hash", + lambda _path: pytest.fail("Streaming scan must not hash large HDF5 before Keras H5 dispatch"), + ) + monkeypatch.setattr( + KerasH5Scanner, + "calculate_file_hashes", + lambda _self, _path: pytest.fail("Keras H5 scanner must not hash large HDF5"), + ) + + audit_result = core_module.scan_model_streaming( + file_generator=iter([(model_path, True)]), + timeout=30, + delete_after_scan=False, + cache_enabled=False, + ) + metadata = audit_result.file_metadata[str(model_path)] + + assert audit_result.success is True + assert audit_result.files_scanned == 1 + assert audit_result.content_hash is None + assert "keras_h5" in audit_result.scanner_names + assert "max_file_read_size_exceeded" not in metadata.get("scan_outcome_reasons", []) + + +def test_large_malicious_keras_h5_still_detects_lambda_payload(tmp_path: Path) -> None: + model_path = create_custom_h5_file( + tmp_path, + { + "class_name": "Sequential", + "config": { + "layers": [ + { + "class_name": "Lambda", + "config": {"function": "lambda x: __import__('os').system('id')"}, + } + ] + }, + }, + keras_version="3.11.2", + file_name="large_lambda.h5", + ) + inflate_h5_file_to_size(model_path) + + result = KerasH5Scanner().scan(str(model_path)) + audit_result = core_module.scan_model_directory_or_file(str(model_path), cache_enabled=False) + + assert_not_rejected_by_read_cap(result) + assert any( + check.name == "Lambda Layer Code Analysis" and check.status == CheckStatus.FAILED for check in result.checks + ) + assert any( + issue.details.get("cve_id") == "CVE-2025-9905" and issue.severity == IssueSeverity.CRITICAL + for issue in result.issues + ) + assert core_module.determine_exit_code(audit_result) == 1 + + +def test_large_malformed_keras_h5_fails_closed_without_size_limit(tmp_path: Path) -> None: + model_path = create_raw_config_h5_file( + tmp_path, + model_config_attr="{", + file_name="large_malformed_config.h5", + ) + inflate_h5_file_to_size(model_path) + + result = KerasH5Scanner().scan(str(model_path)) + audit_result = core_module.scan_model_directory_or_file(str(model_path), cache_enabled=False) + + assert result.success is False + assert result.metadata["scan_outcome"] == INCONCLUSIVE_SCAN_OUTCOME + assert "keras_h5_model_config_parse_failed" in result.metadata["scan_outcome_reasons"] + assert_not_rejected_by_read_cap(result) + assert any(check.name == "Keras H5 Config Parse" and check.status == CheckStatus.FAILED for check in result.checks) + assert not any(issue.severity in (IssueSeverity.WARNING, IssueSeverity.CRITICAL) for issue in result.issues) + assert core_module.determine_exit_code(audit_result) == 2 + + +def test_large_hdf5_external_link_still_detected_without_target_resolution(tmp_path: Path) -> None: + model_path = create_custom_h5_file( + tmp_path, + { + "class_name": "Sequential", + "config": {"layers": [{"class_name": "Dense", "config": {"units": 1}}]}, + }, + keras_version="3.13.1", + file_name="large_external_link.h5", + ) + with h5py.File(model_path, "a") as f: + weights_group = f.require_group("model_weights") + weights_group.attrs["layer_names"] = [b"dense"] + dense = weights_group.create_group("dense") + dense.attrs["weight_names"] = [b"linked_kernel"] + dense["linked_kernel"] = h5py.ExternalLink("missing_external_source.h5", "/payload") + inflate_h5_file_to_size(model_path) + + result = KerasH5Scanner().scan(str(model_path)) + + assert_not_rejected_by_read_cap(result) + cve_issues = [issue for issue in result.issues if issue.details.get("cve_id") == "CVE-2026-1669"] + assert len(cve_issues) == 1 + assert cve_issues[0].details["external_references"] == [ + { + "kind": "ExternalLink", + "hdf5_path": "/model_weights/dense/linked_kernel", + "filename": "missing_external_source.h5", + "path": "/payload", + }, + ] + + +def test_large_hdf5_soft_link_to_external_link_still_detected(tmp_path: Path) -> None: + model_path = create_custom_h5_file( + tmp_path, + { + "class_name": "Sequential", + "config": {"layers": [{"class_name": "Dense", "config": {"units": 1}}]}, + }, + keras_version="3.13.1", + file_name="large_soft_external_link.h5", + ) + with h5py.File(model_path, "a") as f: + dense = f.require_group("model_weights").create_group("dense") + dense.attrs["weight_names"] = [b"soft_alias"] + f["model_weights"].attrs["layer_names"] = [b"dense"] + dense["external_payload"] = h5py.ExternalLink("missing_external_source.h5", "/payload") + dense["soft_alias"] = h5py.SoftLink("/model_weights/dense/external_payload") + inflate_h5_file_to_size(model_path) + + result = KerasH5Scanner().scan(str(model_path)) + + assert_not_rejected_by_read_cap(result) + cve_issues = [issue for issue in result.issues if issue.details.get("cve_id") == "CVE-2026-1669"] + assert len(cve_issues) == 1 + assert cve_issues[0].details["external_references"] == [ + { + "kind": "ExternalLink", + "hdf5_path": "/model_weights/dense/soft_alias", + "filename": "missing_external_source.h5", + "path": "/payload", + }, + ] + + +def test_large_hdf5_soft_link_to_external_storage_still_detected(tmp_path: Path) -> None: + raw_storage = tmp_path / "weights.raw" + raw_storage.write_bytes(b"\x00" * 8) + model_path = create_custom_h5_file( + tmp_path, + { + "class_name": "Sequential", + "config": {"layers": [{"class_name": "Dense", "config": {"units": 1}}]}, + }, + keras_version="3.13.1", + file_name="large_soft_external_storage.h5", + ) + with h5py.File(model_path, "a") as f: + dense = f.require_group("model_weights").create_group("dense") + dense.attrs["weight_names"] = [b"soft_alias"] + f["model_weights"].attrs["layer_names"] = [b"dense"] + dense.create_dataset( + "external_kernel", + shape=(2,), + dtype="float32", + external=[(raw_storage.name, 0, 8)], + ) + dense["soft_alias"] = h5py.SoftLink("/model_weights/dense/external_kernel") + inflate_h5_file_to_size(model_path) + + result = KerasH5Scanner().scan(str(model_path)) + + assert_not_rejected_by_read_cap(result) + cve_issues = [issue for issue in result.issues if issue.details.get("cve_id") == "CVE-2026-1669"] + assert len(cve_issues) == 1 + assert cve_issues[0].details["external_references"] == [ + { + "kind": "external_storage", + "hdf5_path": "/model_weights/dense/soft_alias", + "segments": [{"filename": "weights.raw", "offset": 0, "size": 8}], + }, + ] + + +def test_large_hdf5_virtual_dataset_source_still_detected(tmp_path: Path) -> None: + virtual_source = tmp_path / "virtual_source.h5" + with h5py.File(virtual_source, "w") as f: + f.create_dataset("payload", data=[1.0, 2.0]) + model_path = create_custom_h5_file( + tmp_path, + { + "class_name": "Sequential", + "config": {"layers": [{"class_name": "Dense", "config": {"units": 1}}]}, + }, + keras_version="3.13.1", + file_name="large_virtual_dataset.h5", + ) + with h5py.File(model_path, "a") as f: + dense = f.require_group("model_weights").create_group("dense") + dense.attrs["weight_names"] = [b"virtual_kernel"] + f["model_weights"].attrs["layer_names"] = [b"dense"] + layout = h5py.VirtualLayout(shape=(2,), dtype="float64") + layout[:] = h5py.VirtualSource(virtual_source.name, "/payload", shape=(2,)) + dense.create_virtual_dataset("virtual_kernel", layout) + inflate_h5_file_to_size(model_path) + + result = KerasH5Scanner().scan(str(model_path)) + + assert_not_rejected_by_read_cap(result) + cve_issues = [issue for issue in result.issues if issue.details.get("cve_id") == "CVE-2026-1669"] + assert len(cve_issues) == 1 + assert cve_issues[0].details["external_references"] == [ + { + "kind": "virtual_dataset", + "hdf5_path": "/model_weights/dense/virtual_kernel", + "sources": [{"filename": "virtual_source.h5", "path": "/payload"}], + }, + ] + + +def test_large_file_backed_hdf5_bypasses_cache_content_hash( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + from modelaudit.utils.file.large_file_handler import SMALL_FILE_THRESHOLD + from modelaudit.utils.helpers.secure_hasher import SecureFileHasher + + model_path = create_custom_h5_file( + tmp_path, + { + "class_name": "Sequential", + "config": {"layers": [{"class_name": "Dense", "config": {"units": 1}}]}, + }, + keras_version="3.13.2", + file_name="huge_cache_bypass.h5", + ) + inflate_h5_file_to_size(model_path, SMALL_FILE_THRESHOLD + 4096) + assert should_bypass_cache_for_file_backed_hdf5(str(model_path)) is True + + def fail_if_cache_hashes_hdf5(self: SecureFileHasher, path: str) -> str: + if path == str(model_path): + pytest.fail("large file-backed HDF5 was content-hashed for cache lookup") + return "a" * 64 + + monkeypatch.setattr(SecureFileHasher, "hash_file", fail_if_cache_hashes_hdf5) + monkeypatch.setattr( + SecureFileHasher, + "hash_file_with_stat", + lambda self, path, _stat: fail_if_cache_hashes_hdf5(self, path), + ) + + reset_cache_manager() + try: + audit_result = core_module.scan_model_directory_or_file( + str(model_path), + cache_enabled=True, + cache_dir=str(tmp_path / "cache"), + min_cache_file_size=0, + max_cache_file_size=SMALL_FILE_THRESHOLD * 2, + content_hash_threshold=1, + ) + finally: + reset_cache_manager() + + assert audit_result.files_scanned == 1 + assert "keras_h5" in audit_result.scanner_names + assert core_module.determine_exit_code(audit_result) == 0 + metadata = audit_result.file_metadata[str(model_path)] + assert "max_file_read_size_exceeded" not in metadata.get("scan_outcome_reasons", []) + + +def test_keras_h5_virtual_dataset_external_source_after_report_cap_still_detected(tmp_path: Path) -> None: + late_source = tmp_path / "late_virtual_source.h5" + with h5py.File(late_source, "w") as f: + f.create_dataset("payload", data=[1.0]) + model_path = create_custom_h5_file( + tmp_path, + { + "class_name": "Sequential", + "config": {"layers": [{"class_name": "Dense", "config": {"units": 1}}]}, + }, + keras_version="3.13.1", + file_name="late_virtual_dataset.h5", + ) + with h5py.File(model_path, "a") as f: + same_file_count = KerasH5Scanner._MAX_HDF5_VIRTUAL_SOURCE_REPORTS + f.create_dataset("internal_payload", data=[float(index) for index in range(same_file_count)]) + dense = f.require_group("model_weights").create_group("dense") + dense.attrs["weight_names"] = [b"virtual_kernel"] + f["model_weights"].attrs["layer_names"] = [b"dense"] + layout = h5py.VirtualLayout(shape=(same_file_count + 1,), dtype="float64") + same_file_source = h5py.VirtualSource(".", "/internal_payload", shape=(same_file_count,)) + for index in range(same_file_count): + layout[index] = same_file_source[index] + external_source = h5py.VirtualSource(late_source.name, "/payload", shape=(1,)) + layout[same_file_count] = external_source[0] + dense.create_virtual_dataset("virtual_kernel", layout) + + result = KerasH5Scanner().scan(str(model_path)) + + assert result.success is True + cve_issues = [issue for issue in result.issues if issue.details.get("cve_id") == "CVE-2026-1669"] + assert len(cve_issues) == 1 + assert cve_issues[0].details["external_references"] == [ + { + "kind": "virtual_dataset", + "hdf5_path": "/model_weights/dense/virtual_kernel", + "sources": [{"filename": "late_virtual_source.h5", "path": "/payload"}], + }, + ] + + +def test_keras_h5_virtual_dataset_source_inspection_limit_fails_closed( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + model_path = create_custom_h5_file( + tmp_path, + { + "class_name": "Sequential", + "config": {"layers": [{"class_name": "Dense", "config": {"units": 1}}]}, + }, + keras_version="3.13.1", + file_name="virtual_dataset_inspection_limit.h5", + ) + with h5py.File(model_path, "a") as f: + f.create_dataset("internal_payload", data=[1.0, 2.0, 3.0]) + dense = f.require_group("model_weights").create_group("dense") + dense.attrs["weight_names"] = [b"virtual_kernel"] + f["model_weights"].attrs["layer_names"] = [b"dense"] + layout = h5py.VirtualLayout(shape=(3,), dtype="float64") + same_file_source = h5py.VirtualSource(".", "/internal_payload", shape=(3,)) + for index in range(3): + layout[index] = same_file_source[index] + dense.create_virtual_dataset("virtual_kernel", layout) + + monkeypatch.setattr(KerasH5Scanner, "_MAX_HDF5_VIRTUAL_SOURCE_INSPECTIONS", 2) + + result = KerasH5Scanner().scan(str(model_path)) + + reason = "keras_h5_external_reference_analysis_limit_exceeded" + assert result.success is False + assert result.metadata["scan_outcome"] == INCONCLUSIVE_SCAN_OUTCOME + assert reason in result.metadata["scan_outcome_reasons"] + assert not any(issue.details.get("cve_id") == "CVE-2026-1669" for issue in result.issues) + limit_checks = [check for check in result.checks if check.name == "HDF5 External Reference Analysis Limit"] + assert len(limit_checks) == 1 + assert limit_checks[0].status == CheckStatus.FAILED + assert limit_checks[0].details["virtual_dataset_sources_truncated"] is True + + +def test_keras_h5_virtual_dataset_external_source_before_scan_wide_budget_still_detected( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + external_source = tmp_path / "early_virtual_source.h5" + with h5py.File(external_source, "w") as f: + f.create_dataset("payload", data=[1.0, 2.0]) + model_path = create_custom_h5_file( + tmp_path, + { + "class_name": "Sequential", + "config": {"layers": [{"class_name": "Dense", "config": {"units": 1}}]}, + }, + keras_version="3.13.1", + file_name="virtual_dataset_external_before_budget.h5", + ) + with h5py.File(model_path, "a") as f: + f.create_dataset("internal_payload", data=[1.0, 2.0]) + dense = f.require_group("model_weights").create_group("dense") + dense.attrs["weight_names"] = [b"virtual_kernel"] + f["model_weights"].attrs["layer_names"] = [b"dense"] + layout = h5py.VirtualLayout(shape=(2,), dtype="float64") + layout[0] = h5py.VirtualSource(external_source.name, "/payload", shape=(2,))[0] + layout[1] = h5py.VirtualSource(".", "/internal_payload", shape=(2,))[1] + dense.create_virtual_dataset("virtual_kernel", layout) + + monkeypatch.setattr(KerasH5Scanner, "_MAX_HDF5_VIRTUAL_SOURCE_INSPECTIONS", 1) + + result = KerasH5Scanner().scan(str(model_path)) + + cve_issues = [issue for issue in result.issues if issue.details.get("cve_id") == "CVE-2026-1669"] + assert len(cve_issues) == 1 + assert cve_issues[0].details["external_references"] == [ + { + "kind": "virtual_dataset", + "hdf5_path": "/model_weights/dense/virtual_kernel", + "sources": [{"filename": "early_virtual_source.h5", "path": "/payload"}], + "source_count": 2, + "sources_truncated": True, + }, + ] + assert cve_issues[0].details["virtual_dataset_sources_truncated"] is True + limit_checks = [check for check in result.checks if check.name == "HDF5 External Reference Analysis Limit"] + assert len(limit_checks) == 1 + assert limit_checks[0].details["visited_virtual_source_count"] == 1 + assert limit_checks[0].details["max_virtual_source_inspections"] == 1 + + +def test_keras_h5_virtual_dataset_source_inspection_budget_is_scan_wide( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + model_path = create_custom_h5_file( + tmp_path, + { + "class_name": "Sequential", + "config": {"layers": [{"class_name": "Dense", "config": {"units": 1}}]}, + }, + keras_version="3.13.1", + file_name="virtual_dataset_scan_wide_budget.h5", + ) + with h5py.File(model_path, "a") as f: + f.create_dataset("internal_payload", data=[1.0, 2.0]) + dense = f.require_group("model_weights").create_group("dense") + dense.attrs["weight_names"] = [b"virtual_a", b"virtual_b"] + f["model_weights"].attrs["layer_names"] = [b"dense"] + same_file_source = h5py.VirtualSource(".", "/internal_payload", shape=(2,)) + for dataset_name in ("virtual_a", "virtual_b"): + layout = h5py.VirtualLayout(shape=(2,), dtype="float64") + for index in range(2): + layout[index] = same_file_source[index] + dense.create_virtual_dataset(dataset_name, layout) + + monkeypatch.setattr(KerasH5Scanner, "_MAX_HDF5_VIRTUAL_SOURCE_INSPECTIONS", 3) + + result = KerasH5Scanner().scan(str(model_path)) + + reason = "keras_h5_external_reference_analysis_limit_exceeded" + assert result.success is False + assert result.metadata["scan_outcome"] == INCONCLUSIVE_SCAN_OUTCOME + assert reason in result.metadata["scan_outcome_reasons"] + assert not any(issue.details.get("cve_id") == "CVE-2026-1669" for issue in result.issues) + limit_checks = [check for check in result.checks if check.name == "HDF5 External Reference Analysis Limit"] + assert len(limit_checks) == 1 + assert limit_checks[0].status == CheckStatus.FAILED + assert limit_checks[0].details["visited_virtual_source_count"] == 3 + assert limit_checks[0].details["max_virtual_source_inspections"] == 3 + assert limit_checks[0].details["virtual_dataset_sources_truncated"] is True + + +def test_keras_h5_same_file_virtual_dataset_source_stays_clean(tmp_path: Path) -> None: + model_path = create_custom_h5_file( + tmp_path, + { + "class_name": "Sequential", + "config": {"layers": [{"class_name": "Dense", "config": {"units": 1}}]}, + }, + keras_version="3.13.1", + file_name="same_file_virtual_dataset.h5", + ) + with h5py.File(model_path, "a") as f: + f.create_dataset("internal_payload", data=[1.0, 2.0]) + dense = f.require_group("model_weights").create_group("dense") + dense.attrs["weight_names"] = [b"virtual_kernel"] + f["model_weights"].attrs["layer_names"] = [b"dense"] + layout = h5py.VirtualLayout(shape=(2,), dtype="float64") + layout[:] = h5py.VirtualSource(".", "/internal_payload", shape=(2,)) + dense.create_virtual_dataset("virtual_kernel", layout) + + result = KerasH5Scanner().scan(str(model_path)) + + assert result.success is True + assert not any(issue.details.get("cve_id") == "CVE-2026-1669" for issue in result.issues) + assert not any(issue.severity in (IssueSeverity.WARNING, IssueSeverity.CRITICAL) for issue in result.issues) + + +def test_keras_h5_scanner_flags_model_config_keras3_layer_vars_external_link(tmp_path: Path) -> None: + model_path = create_custom_h5_file( + tmp_path, + { + "class_name": "Sequential", + "config": {"layers": [{"class_name": "Dense", "config": {"units": 1}}]}, + }, + keras_version="3.13.1", + file_name="keras3_layer_vars_with_model_config.h5", + ) + with h5py.File(model_path, "a") as f: + legacy_dense = f.require_group("model_weights").create_group("dense") + legacy_dense.attrs["weight_names"] = [b"legacy_kernel"] + f["model_weights"].attrs["layer_names"] = [b"dense"] + legacy_dense["legacy_kernel"] = h5py.ExternalLink("missing_legacy_source.h5", "/payload") + f.create_group("layers").create_group("dense").create_group("vars")["0"] = h5py.ExternalLink( + "missing_keras3_source.h5", + "/payload", + ) + + result = KerasH5Scanner().scan(str(model_path)) + + cve_issues = [issue for issue in result.issues if issue.details.get("cve_id") == "CVE-2026-1669"] + assert len(cve_issues) == 1 + assert cve_issues[0].details["external_references"] == [ + { + "kind": "ExternalLink", + "hdf5_path": "/model_weights/dense/legacy_kernel", + "filename": "missing_legacy_source.h5", + "path": "/payload", + }, + { + "kind": "ExternalLink", + "hdf5_path": "/layers/dense/vars/0", + "filename": "missing_keras3_source.h5", + "path": "/payload", + }, + ] + + +def test_keras_h5_scanner_allows_model_config_keras3_same_file_virtual_dataset(tmp_path: Path) -> None: + model_path = create_custom_h5_file( + tmp_path, + { + "class_name": "Sequential", + "config": {"layers": [{"class_name": "Dense", "config": {"units": 1}}]}, + }, + keras_version="3.13.1", + file_name="keras3_same_file_vds_with_model_config.h5", + ) + with h5py.File(model_path, "a") as f: + f.create_dataset("internal_payload", data=[1.0, 2.0]) + vars_group = f.create_group("layers").create_group("dense").create_group("vars") + layout = h5py.VirtualLayout(shape=(2,), dtype="float64") + layout[:] = h5py.VirtualSource(".", "/internal_payload", shape=(2,)) + vars_group.create_virtual_dataset("0", layout) + + result = KerasH5Scanner().scan(str(model_path)) + + assert result.success is True + assert not any(issue.details.get("cve_id") == "CVE-2026-1669" for issue in result.issues) + assert not any(issue.severity in (IssueSeverity.WARNING, IssueSeverity.CRITICAL) for issue in result.issues) + + +@pytest.mark.parametrize("root_path", ["vars", "optimizer/vars"]) +def test_keras_h5_scanner_flags_model_config_keras3_root_vars_external_link( + tmp_path: Path, + root_path: str, +) -> None: + model_path = create_custom_h5_file( + tmp_path, + { + "class_name": "Sequential", + "config": {"layers": [{"class_name": "Dense", "config": {"units": 1}}]}, + }, + keras_version="3.13.1", + file_name="keras3_root_vars_with_model_config.h5", + ) + with h5py.File(model_path, "a") as f: + vars_group = f.require_group(root_path) + vars_group["0"] = h5py.ExternalLink("missing_keras3_root_source.h5", "/payload") + + result = KerasH5Scanner().scan(str(model_path)) + + cve_issues = [issue for issue in result.issues if issue.details.get("cve_id") == "CVE-2026-1669"] + assert len(cve_issues) == 1 + assert cve_issues[0].details["external_references"] == [ + { + "kind": "ExternalLink", + "hdf5_path": f"/{root_path}/0", + "filename": "missing_keras3_root_source.h5", + "path": "/payload", + }, + ] + + +@pytest.mark.parametrize("root_path", ["vars", "optimizer/vars"]) +def test_keras_h5_scanner_flags_keras3_root_vars_external_link(tmp_path: Path, root_path: str) -> None: + weights_path = tmp_path / "keras3_root_vars.weights.h5" + with h5py.File(weights_path, "w") as f: + f.create_group("layers").create_group("dense") + vars_group = f.require_group(root_path) + vars_group["0"] = h5py.ExternalLink("missing_external_source.h5", "/payload") + + result = KerasH5Scanner().scan(str(weights_path)) + + cve_issues = [issue for issue in result.issues if issue.details.get("cve_id") == "CVE-2026-1669"] + assert len(cve_issues) == 1 + assert cve_issues[0].details["external_references"] == [ + { + "kind": "ExternalLink", + "hdf5_path": f"/{root_path}/0", + "filename": "missing_external_source.h5", + "path": "/payload", + }, + ] + + +@pytest.mark.parametrize("root_path", ["vars", "optimizer/vars"]) +def test_keras_h5_scanner_flags_keras3_root_vars_external_storage(tmp_path: Path, root_path: str) -> None: + raw_storage = tmp_path / "root_weights.raw" + raw_storage.write_bytes(b"\x00" * 8) + weights_path = tmp_path / "keras3_root_external_storage.weights.h5" + with h5py.File(weights_path, "w") as f: + f.create_group("layers").create_group("dense") + vars_group = f.require_group(root_path) + vars_group.create_dataset( + "0", + shape=(2,), + dtype="float32", + external=[(raw_storage.name, 0, 8)], + ) + + result = KerasH5Scanner().scan(str(weights_path)) + + cve_issues = [issue for issue in result.issues if issue.details.get("cve_id") == "CVE-2026-1669"] + assert len(cve_issues) == 1 + assert cve_issues[0].details["external_references"] == [ + { + "kind": "external_storage", + "hdf5_path": f"/{root_path}/0", + "segments": [{"filename": "root_weights.raw", "offset": 0, "size": 8}], + }, + ] + + +@pytest.mark.parametrize("root_path", ["vars", "optimizer/vars"]) +def test_keras_h5_scanner_flags_keras3_root_vars_virtual_dataset_source(tmp_path: Path, root_path: str) -> None: + virtual_source = tmp_path / "root_virtual_source.h5" + with h5py.File(virtual_source, "w") as f: + f.create_dataset("payload", data=[1.0, 2.0]) + weights_path = tmp_path / "keras3_root_virtual_vars.weights.h5" + with h5py.File(weights_path, "w") as f: + f.create_group("layers").create_group("dense") + vars_group = f.require_group(root_path) + layout = h5py.VirtualLayout(shape=(2,), dtype="float64") + layout[:] = h5py.VirtualSource(virtual_source.name, "/payload", shape=(2,)) + vars_group.create_virtual_dataset("0", layout) + + result = KerasH5Scanner().scan(str(weights_path)) + + cve_issues = [issue for issue in result.issues if issue.details.get("cve_id") == "CVE-2026-1669"] + assert len(cve_issues) == 1 + assert cve_issues[0].details["external_references"] == [ + { + "kind": "virtual_dataset", + "hdf5_path": f"/{root_path}/0", + "sources": [{"filename": "root_virtual_source.h5", "path": "/payload"}], + }, + ] + + +def test_keras_h5_scanner_flags_arbitrary_keras3_saveable_vars_external_link(tmp_path: Path) -> None: + external_source = tmp_path / "external.h5" + with h5py.File(external_source, "w") as f: + f.create_dataset("payload", data=[1.0, 2.0]) + weights_path = tmp_path / "keras3_custom_child.weights.h5" + with h5py.File(weights_path, "w") as f: + f.create_group("layers").create_group("dense").create_group("vars").create_dataset("0", data=[1.0]) + f.require_group("custom_parent").require_group("custom_child").require_group("vars")["0"] = h5py.ExternalLink( + external_source.name, + "/payload", + ) + inflate_h5_file_to_size(weights_path, 536_871_936) + + result = KerasH5Scanner().scan(str(weights_path)) + audit_result = core_module.scan_model_directory_or_file(str(weights_path), cache_enabled=False) + + assert weights_path.stat().st_size == 536_871_936 + assert_not_rejected_by_read_cap(result) + cve_issues = [issue for issue in result.issues if issue.details.get("cve_id") == "CVE-2026-1669"] + assert len(cve_issues) == 1 + assert cve_issues[0].details["external_references"] == [ + { + "kind": "ExternalLink", + "hdf5_path": "/custom_parent/custom_child/vars/0", + "filename": "external.h5", + "path": "/payload", + }, + ] + assert core_module.determine_exit_code(audit_result) == 1 + + +def test_keras_h5_scanner_allows_internal_arbitrary_keras3_saveable_vars(tmp_path: Path) -> None: + weights_path = tmp_path / "keras3_internal_custom_child.weights.h5" + with h5py.File(weights_path, "w") as f: + f.create_group("layers").create_group("dense").create_group("vars").create_dataset("0", data=[1.0]) + f.require_group("custom_parent").require_group("custom_child").require_group("vars").create_dataset( + "0", + data=[2.0], + ) + inflate_h5_file_to_size(weights_path, 536_871_936) + + result = KerasH5Scanner().scan(str(weights_path)) + + assert result.success is True + assert_not_rejected_by_read_cap(result) + assert not any(issue.details.get("cve_id") == "CVE-2026-1669" for issue in result.issues) + assert not any(issue.severity in (IssueSeverity.WARNING, IssueSeverity.CRITICAL) for issue in result.issues) + + +def test_large_hdf5_soft_link_cycle_fails_closed_without_size_limit(tmp_path: Path) -> None: + model_path = create_custom_h5_file( + tmp_path, + { + "class_name": "Sequential", + "config": {"layers": [{"class_name": "Dense", "config": {"units": 1}}]}, + }, + file_name="large_soft_cycle.h5", + ) + with h5py.File(model_path, "a") as f: + dense = f.require_group("model_weights").create_group("dense") + dense.attrs["weight_names"] = [b"cycle_a"] + f["model_weights"].attrs["layer_names"] = [b"dense"] + dense["cycle_a"] = h5py.SoftLink("/model_weights/dense/cycle_b") + dense["cycle_b"] = h5py.SoftLink("/model_weights/dense/cycle_a") + inflate_h5_file_to_size(model_path) + + result = KerasH5Scanner().scan(str(model_path)) + + assert result.success is False + assert result.metadata["scan_outcome"] == INCONCLUSIVE_SCAN_OUTCOME + assert "keras_h5_external_reference_analysis_limit_exceeded" in result.metadata["scan_outcome_reasons"] + assert_not_rejected_by_read_cap(result) + assert any( + check.name == "HDF5 External Reference Analysis Limit" + and check.details["soft_link_resolution_incomplete"] is True + for check in result.checks + ) + + +def test_sparse_chunked_compressed_hdf5_dataset_does_not_materialize( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + model_path = create_custom_h5_file( + tmp_path, + { + "class_name": "Sequential", + "config": { + "layers": [ + { + "class_name": "Lambda", + "config": {"function": "lambda x: __import__('os').system('id')"}, + } + ] + }, + }, + keras_version="3.11.2", + file_name="sparse_compressed.h5", + ) + with h5py.File(model_path, "a") as f: + weights_group = f.require_group("model_weights") + weights_group.create_dataset( + "huge_sparse_compressed", + shape=(1024 * 1024 * 1024,), + dtype="float32", + chunks=(1024,), + compression="gzip", + fillvalue=0, + ) + inflate_h5_file_to_size(model_path) + + def fail_dataset_read(_self: Any, _key: Any) -> Any: + raise AssertionError("HDF5 dataset payload was materialized") + + def fail_read_direct(_self: Any, *_args: Any, **_kwargs: Any) -> None: + raise AssertionError("HDF5 dataset payload was materialized") + + monkeypatch.setattr(h5py.Dataset, "__getitem__", fail_dataset_read) + monkeypatch.setattr(h5py.Dataset, "read_direct", fail_read_direct) + + result = KerasH5Scanner().scan(str(model_path)) + + assert_not_rejected_by_read_cap(result) + assert any(issue.details.get("cve_id") == "CVE-2025-9905" for issue in result.issues) + + +def test_keras_h5_oversized_config_attribute_fails_closed_before_json_parse( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + model_path = create_raw_config_h5_file( + tmp_path, + model_config_attr=json.dumps( + { + "class_name": "Sequential", + "config": {"layers": [{"class_name": "Dense", "config": {"padding": "A" * 64}}]}, + } + ), + file_name="oversized_config_attr.h5", + ) + inflate_h5_file_to_size(model_path) + monkeypatch.setattr(KerasH5Scanner, "_MAX_HDF5_JSON_ATTRIBUTE_BYTES", 32) + original_attr_getitem = h5py.AttributeManager.__getitem__ + + def fail_model_config_materialization(self: Any, name: str) -> Any: + if name == "model_config": + raise AssertionError("oversized Keras H5 config should not be materialized") + return original_attr_getitem(self, name) + + def fail_json_loads(_payload: Any) -> Any: + raise AssertionError("oversized Keras H5 config should not be parsed") + + monkeypatch.setattr(h5py.AttributeManager, "__getitem__", fail_model_config_materialization) + monkeypatch.setattr(keras_h5_scanner_module.json, "loads", fail_json_loads) + + result = KerasH5Scanner().scan(str(model_path)) + + assert result.success is False + assert result.metadata["scan_outcome"] == INCONCLUSIVE_SCAN_OUTCOME + assert "keras_h5_model_config_size_limit_exceeded" in result.metadata["scan_outcome_reasons"] + assert_not_rejected_by_read_cap(result) + assert any( + check.name == "Keras H5 Config Size Limit" and check.status == CheckStatus.FAILED for check in result.checks + ) + + +def test_generic_hdf5_dangling_layers_soft_link_stays_clean(tmp_path: Path) -> None: + model_path = tmp_path / "generic_dangling_layers_soft_link.h5" + with h5py.File(model_path, "w") as f: + f["layers"] = h5py.SoftLink("/missing") + + result = KerasH5Scanner().scan(str(model_path)) + + assert result.success is True + assert "scan_outcome" not in result.metadata + assert not any(issue.severity in (IssueSeverity.WARNING, IssueSeverity.CRITICAL) for issue in result.issues) + assert any( + check.name == "Keras Model Format Check" and check.details.get("format") == "generic_h5" + for check in result.checks + ) + + +@pytest.mark.parametrize("attr_name", ["layer_names", "weight_names"]) +def test_keras_h5_oversized_weight_name_attribute_fails_closed_before_materialization( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + attr_name: str, +) -> None: + model_path = tmp_path / f"oversized_{attr_name}.weights.h5" + with h5py.File(model_path, "w") as f: + if attr_name == "layer_names": + f.attrs["layer_names"] = [b"dense", b"A" * 64] + else: + f.attrs["layer_names"] = [b"dense"] + dense = f.create_group("dense") + dense.attrs["weight_names"] = [b"kernel:0", b"A" * 64] + dense.create_dataset("kernel:0", data=[1.0]) + inflate_h5_file_to_size(model_path) + monkeypatch.setattr(KerasH5Scanner, "_MAX_HDF5_NAME_ATTRIBUTE_BYTES", 32) + original_attr_getitem = h5py.AttributeManager.__getitem__ + + def fail_name_attribute_materialization(self: Any, name: str) -> Any: + if name == attr_name: + raise AssertionError(f"oversized Keras H5 {attr_name} should not be materialized") + return original_attr_getitem(self, name) + + monkeypatch.setattr(h5py.AttributeManager, "__getitem__", fail_name_attribute_materialization) + + result = KerasH5Scanner().scan(str(model_path)) + + assert result.success is False + assert result.metadata["scan_outcome"] == INCONCLUSIVE_SCAN_OUTCOME + assert "keras_h5_external_reference_analysis_limit_exceeded" in result.metadata["scan_outcome_reasons"] + assert_not_rejected_by_read_cap(result) + assert any( + check.name == "HDF5 External Reference Analysis Limit" and check.details["weight_roots_truncated"] is True + for check in result.checks + ) + + +def test_large_dense_hdf5_name_attribute_uses_isolated_worker(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + import numpy as np + + model_path = tmp_path / "dense_layer_names.weights.h5" + encoded_attribute_bytes = 16 * 1024 * 1024 + element_size = 32 + with h5py.File(model_path, "w", track_order=True) as f: + f.attrs.create( + "layer_names", + np.full(encoded_attribute_bytes // element_size, b"dense", dtype=f"S{element_size}"), + ) + inflate_h5_file_to_size(model_path, 536_871_936) + + def fail_parent_attribute_access(_self: Any, name: str) -> Any: + raise AssertionError(f"large HDF5 attribute {name!r} was inspected in the parent process") + + monkeypatch.setattr(h5py.AttributeManager, "__contains__", fail_parent_attribute_access) + monkeypatch.setattr(h5py.AttributeManager, "get_id", fail_parent_attribute_access) + monkeypatch.setattr(h5py.AttributeManager, "__getitem__", fail_parent_attribute_access) + + result = KerasH5Scanner().scan(str(model_path)) + + assert result.success is False + assert result.metadata["scan_outcome"] == INCONCLUSIVE_SCAN_OUTCOME + assert "keras_h5_external_reference_analysis_limit_exceeded" in result.metadata["scan_outcome_reasons"] + assert_not_rejected_by_read_cap(result) + assert any( + check.name == "HDF5 External Reference Analysis Limit" and check.details["weight_roots_truncated"] is True + for check in result.checks + ) + + +def test_large_variable_string_model_config_uses_json_budget(tmp_path: Path) -> None: + model_path = tmp_path / "large_variable_model_config.h5" + model_config = { + "class_name": "Sequential", + "config": { + "name": "A" * 5000, + "layers": [ + { + "class_name": "Lambda", + "config": {"function": "lambda x: __import__('os').system('id')"}, + } + ], + }, + } + with h5py.File(model_path, "w") as f: + f.attrs.create( + "model_config", + json.dumps(model_config), + dtype=h5py.string_dtype(encoding="utf-8"), + ) + f.require_group("model_weights") + inflate_h5_file_to_size(model_path) + + result = KerasH5Scanner().scan(str(model_path)) + + assert_not_rejected_by_read_cap(result) + assert "keras_h5_model_config_parse_failed" not in result.metadata.get("scan_outcome_reasons", []) + assert "keras_h5_model_config_size_limit_exceeded" not in result.metadata.get("scan_outcome_reasons", []) + assert any( + check.name == "Lambda Layer Code Analysis" and check.status == CheckStatus.FAILED for check in result.checks + ) + assert any( + issue.message == "Lambda layer contains dangerous Python code" and issue.severity == IssueSeverity.CRITICAL + for issue in result.issues + ) + - type_checks = [check for check in result.checks if check.name == "Layer Class Type Validation"] - assert len(type_checks) == 1 - assert type_checks[0].severity == IssueSeverity.WARNING - assert result.metadata["scan_outcome"] == INCONCLUSIVE_SCAN_OUTCOME - assert "keras_h5_layer_class_invalid_type" in result.metadata["scan_outcome_reasons"] - assert result.metadata["layer_counts"][""] == 1 - assert raw_secret not in result.to_json() +def test_variable_string_vector_custom_objects_does_not_skip_training_config(tmp_path: Path) -> None: + model_path = tmp_path / "vector_custom_objects.h5" + model_config = { + "class_name": "Sequential", + "config": {"layers": [{"class_name": "Dense", "config": {"units": 1}}]}, + } + training_config = { + "loss": {"output_1": "malicious_loss"}, + "metrics": [["accuracy"]], + } + with h5py.File(model_path, "w") as f: + f.attrs["model_config"] = json.dumps(model_config) + f.attrs.create( + "custom_objects", + ["custom_loss", "custom_metric"], + dtype=h5py.string_dtype(encoding="utf-8"), + ) + f.attrs["training_config"] = json.dumps(training_config) + result = KerasH5Scanner().scan(str(model_path)) -def test_keras_h5_non_string_model_class_preserves_nested_cve_detection(tmp_path: Path) -> None: - """Malformed root metadata must not suppress scanning of nested layers.""" - raw_secret = "sk-proj-CAND061H5MODELCLASSSECRET000000000000" - model_path = create_custom_h5_file( - tmp_path, - { - "class_name": {"api_key": raw_secret}, - "config": { - "layers": [ - {"class_name": "Lambda", "config": {"function": "lambda x: x"}}, - ] - }, - }, - keras_version="3.10.0", + assert "keras_h5_scan_failed" not in result.metadata.get("scan_outcome_reasons", []) + assert any( + check.name == "Custom Objects Security Check" + and check.details["custom_objects"] == ["custom_loss", "custom_metric"] + for check in result.checks + ) + assert any( + check.name == "Custom Loss Detection" and check.details.get("identifier") == "malicious_loss" + for check in result.checks ) - result = KerasH5Scanner().scan(str(model_path)) - type_checks = [check for check in result.checks if check.name == "Model Class Type Validation"] - cve_issues = [issue for issue in result.issues if issue.details.get("cve_id") == "CVE-2025-9905"] - assert len(type_checks) == 1 - assert len(cve_issues) == 1 - assert cve_issues[0].severity == IssueSeverity.CRITICAL - assert result.metadata["model_class"] == "" - assert "keras_h5_model_class_invalid_type" in result.metadata["scan_outcome_reasons"] - assert raw_secret not in result.to_json() +def test_empty_variable_string_name_attribute_is_not_truncated(tmp_path: Path) -> None: + import numpy as np + model_path = tmp_path / "empty_variable_names.h5" + with h5py.File(model_path, "w") as f: + f.attrs.create( + "layer_names", + np.array([], dtype=object), + dtype=h5py.string_dtype(encoding="utf-8"), + ) -@pytest.mark.parametrize( - "fixture_factory", - [create_h5_with_external_link, create_h5_with_external_storage], -) -def test_keras_h5_scanner_flags_external_references_despite_fixed_file_version( + with h5py.File(model_path, "r") as f: + names, truncated = KerasH5Scanner._read_bounded_hdf5_name_attribute(f.attrs, "layer_names") + attr_id = f.attrs.get_id("layer_names") + direct_names, direct_truncated = KerasH5Scanner._read_hdf5_variable_string_name_attribute( + attr_id, + max_bytes=KerasH5Scanner._MAX_HDF5_NAME_ATTRIBUTE_BYTES, + point_count=0, + ) + + assert names == [] + assert truncated is False + assert direct_names == [] + assert direct_truncated is False + + +def test_large_empty_variable_string_name_attributes_scan_cleanly( tmp_path: Path, - fixture_factory: Any, + monkeypatch: pytest.MonkeyPatch, ) -> None: - """Standalone H5 files cannot use artifact-controlled keras_version to suppress external refs.""" - model_path = fixture_factory(tmp_path, keras_version="3.13.2") + import numpy as np - scanner = KerasH5Scanner() - result = scanner.scan(str(model_path)) + model_path = tmp_path / "large_empty_variable_names.h5" + model_config = {"class_name": "Sequential", "config": {"layers": []}} + empty_names = np.array([], dtype=object) + with h5py.File(model_path, "w") as f: + f.attrs["model_config"] = json.dumps(model_config) + weights = f.require_group("model_weights") + weights.attrs.create( + "weight_names", + empty_names, + dtype=h5py.string_dtype(encoding="utf-8"), + ) + weights.attrs.create( + "layer_names", + empty_names, + dtype=h5py.string_dtype(encoding="utf-8"), + ) + inflate_h5_file_to_size(model_path) - cve_issues = [issue for issue in result.issues if issue.details.get("cve_id") == "CVE-2026-1669"] - assert len(cve_issues) == 1 - assert cve_issues[0].severity == IssueSeverity.WARNING - assert cve_issues[0].details["keras_version"] == "3.13.2" - assert cve_issues[0].details["parse_status"] == "untrusted_artifact_version" - assert cve_issues[0].details["version_source"] == "hdf5_file_attribute" - assert not any( - check.name == "HDF5 External Weight Reference Version Check" and check.status == CheckStatus.PASSED - for check in result.checks + worker_name_attrs: list[str] = [] + original_batch_reader: Callable[[list[dict[str, Any]]], list[dict[str, Any]]] = ( + KerasH5Scanner._read_hdf5_attributes_in_worker ) + def counting_batch_reader(cls: type[KerasH5Scanner], requests: list[dict[str, Any]]) -> list[dict[str, Any]]: + worker_name_attrs.extend(str(request["attr_name"]) for request in requests if request.get("mode") == "names") + return original_batch_reader(requests) -def test_keras_h5_scanner_fixed_metadata_without_external_refs_stays_quiet(tmp_path: Path) -> None: - """Fixed-looking metadata alone should not produce external-reference noise.""" - model_path = create_custom_h5_file( - tmp_path, - { - "class_name": "Sequential", - "config": { - "name": "sequential", - "layers": [{"class_name": "Dense", "config": {"units": 1}}], - }, - }, - keras_version="3.13.2", - file_name="fixed_no_external_refs.h5", - ) + monkeypatch.setattr(KerasH5Scanner, "_read_hdf5_attributes_in_worker", classmethod(counting_batch_reader)) result = KerasH5Scanner().scan(str(model_path)) - assert not any(issue.details.get("cve_id") == "CVE-2026-1669" for issue in result.issues) - assert not any(check.name.startswith("HDF5 External Weight Reference") for check in result.checks) + assert result.success is True + assert_not_rejected_by_read_cap(result) + assert worker_name_attrs.count("weight_names") >= 1 + assert worker_name_attrs.count("layer_names") >= 1 + assert "keras_h5_external_reference_analysis_limit_exceeded" not in result.metadata.get( + "scan_outcome_reasons", + [], + ) + assert not any( + check.name == "HDF5 External Reference Analysis Limit" and check.details.get("weight_roots_truncated") is True + for check in result.checks + ) -def test_keras_h5_metadata_redacts_model_controlled_identifiers(tmp_path: Path) -> None: - raw_secret = "sk-proj-KERASH5METADATASECRET1234567890" - model_path = tmp_path / "metadata_redaction.h5" - model_config = { - "class_name": f"Model_{raw_secret}", - "config": {"layers": [{"class_name": f"Layer_{raw_secret}", "config": {}}]}, - } +def test_large_legacy_weight_name_attributes_are_batched_in_worker( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + layer_count = 64 + model_path = tmp_path / "many_legacy_layers.weights.h5" + with h5py.File(model_path, "w") as f: + layer_names = [f"layer_{index}".encode() for index in range(layer_count)] + f.attrs["layer_names"] = layer_names + for index in range(layer_count): + layer = f.create_group(f"layer_{index}") + layer.attrs["weight_names"] = [b"kernel:0"] + layer.create_dataset("kernel:0", data=[float(index)]) + inflate_h5_file_to_size(model_path) - with h5py.File(model_path, "w") as h5_file: - h5_file.attrs["model_config"] = json.dumps(model_config) - h5_file.attrs["keras_version"] = f"3.10.0+{raw_secret}" - h5_file.create_group(f"group_{raw_secret}") - weights = h5_file.create_group("model_weights") - weights.create_dataset(f"kernel_{raw_secret}", data=[1.0]) + worker_batch_sizes: list[int] = [] + original_batch_reader: Callable[[list[dict[str, Any]]], list[dict[str, Any]]] = ( + KerasH5Scanner._read_hdf5_attributes_in_worker + ) - metadata = KerasH5Scanner().extract_metadata(str(model_path)) - serialized_metadata = json.dumps(metadata, default=str) + def counting_batch_reader(cls: type[KerasH5Scanner], requests: list[dict[str, Any]]) -> list[dict[str, Any]]: + worker_batch_sizes.append(len(requests)) + return original_batch_reader(requests) - assert raw_secret not in serialized_metadata - assert metadata["has_model_config"] is True - assert metadata["has_model_weights"] is True - assert metadata["total_parameters"] == 1 - assert metadata["model_class"] == "Model_" - assert metadata["keras_version"] == "3.10.0+" - assert metadata["layer_types"] == ["Layer_"] - assert metadata["parameter_details"] == [{"name": "kernel_", "shape": [1], "dtype": "float64", "size": 1}] - assert "group_" in metadata["h5_keys"] + monkeypatch.setattr(KerasH5Scanner, "_read_hdf5_attributes_in_worker", classmethod(counting_batch_reader)) + result = KerasH5Scanner().scan(str(model_path)) -def test_keras_h5_metadata_redacts_extraction_error( - tmp_path: Path, - monkeypatch: pytest.MonkeyPatch, -) -> None: - raw_secret = "ATTACKER_CONTROLLED_KERAS_H5_METADATA_FAILURE" - model_path = create_mock_h5_file(tmp_path) + assert result.success is True + assert_not_rejected_by_read_cap(result) + assert max(worker_batch_sizes) >= layer_count + assert len(worker_batch_sizes) <= 12 - def fail_h5py_open(*_args: Any, **_kwargs: Any) -> None: - raise RuntimeError(raw_secret) - monkeypatch.setattr(keras_h5_scanner_module.h5py, "File", fail_h5py_open) +def _sha256_file(path: Path) -> str: + import hashlib - metadata = KerasH5Scanner().extract_metadata(str(model_path)) + digest = hashlib.sha256() + with path.open("rb") as handle: + for chunk in iter(lambda: handle.read(8 * 1024 * 1024), b""): + digest.update(chunk) + return digest.hexdigest() - assert metadata["extraction_error"] == "" - assert raw_secret not in json.dumps(metadata, default=str) + +def _assert_pinned_hf_h5_metadata( + huggingface_hub: Any, + *, + repo_id: str, + revision: str, + expected_size: int, + expected_blob_id: str, + expected_sha256: str, +) -> None: + info = huggingface_hub.HfApi().model_info(repo_id=repo_id, revision=revision, files_metadata=True) + assert info.sha == revision + tf_model = next(sibling for sibling in info.siblings if sibling.rfilename == "tf_model.h5") + assert tf_model.size == expected_size + assert tf_model.blob_id == expected_blob_id + assert tf_model.lfs is not None + assert tf_model.lfs.sha256 == expected_sha256 -@pytest.mark.parametrize("keras_version", ["3.13.x", "2.12.0-gpu", "3.13.2rc1junk", "3.13.2+"]) -def test_keras_h5_scanner_unparseable_external_reference_versions_mark_unknown_risk( +def _assert_real_hf_h5_reaches_keras_scan( tmp_path: Path, - keras_version: str, + *, + repo_id: str, + revision: str, + expected_size: int, + expected_blob_id: str, + expected_sha256: str, + expected_root_keys: list[str] | None = None, + expected_attrs: set[str] | None = None, + expected_layer_names: list[str] | None = None, ) -> None: - model_path = create_h5_with_external_link(tmp_path, keras_version=keras_version) + if os.environ.get("MODELAUDIT_RUN_REAL_HF_H5") != "1": + pytest.skip("Set MODELAUDIT_RUN_REAL_HF_H5=1 to download and scan pinned HF H5 models") + + huggingface_hub = pytest.importorskip("huggingface_hub") + _assert_pinned_hf_h5_metadata( + huggingface_hub, + repo_id=repo_id, + revision=revision, + expected_size=expected_size, + expected_blob_id=expected_blob_id, + expected_sha256=expected_sha256, + ) + cache_dir = os.environ.get("MODELAUDIT_HF_CACHE_DIR", str(tmp_path / "hf-cache")) + model_path = Path( + huggingface_hub.hf_hub_download( + repo_id=repo_id, + filename="tf_model.h5", + revision=revision, + cache_dir=cache_dir, + ) + ) - result = KerasH5Scanner().scan(str(model_path)) + audit_result = core_module.scan_model_directory_or_file(str(model_path), cache_enabled=False) + metadata = audit_result.file_metadata[str(model_path)] + metadata_extra = getattr(metadata, "model_extra", {}) or {} + + assert model_path.stat().st_size == expected_size + assert metadata.file_size == expected_size + assert _sha256_file(model_path) == expected_sha256 + with h5py.File(model_path, "r") as h5_file: + if expected_root_keys is not None: + assert sorted(h5_file.keys()) == expected_root_keys + if expected_attrs is not None: + assert set(h5_file.attrs.keys()) == expected_attrs + layer_names, layer_names_truncated = KerasH5Scanner._read_bounded_hdf5_name_attribute( + h5_file.attrs, + "layer_names", + ) + assert layer_names_truncated is False + assert layer_names + if expected_layer_names is not None: + assert layer_names == expected_layer_names + for layer_name in layer_names: + assert layer_name in h5_file + assert audit_result.files_scanned == 1 + assert "keras_h5" in audit_result.scanner_names + assert "max_file_read_size_exceeded" not in (metadata_extra.get("scan_outcome_reasons") or []) + assert any( + check.name == "Keras Model Format Check" and check.status == CheckStatus.PASSED for check in audit_result.checks + ) + assert core_module.determine_exit_code(audit_result) == 0 - cve_issues = [issue for issue in result.issues if issue.details.get("cve_id") == "CVE-2026-1669"] - assert len(cve_issues) == 1 - assert cve_issues[0].severity == IssueSeverity.WARNING - assert cve_issues[0].details["keras_version"] == keras_version - assert cve_issues[0].details["parse_status"] == "unknown" - assert any("is non-canonical" in issue.message for issue in cve_issues) - assert not any( - check.name == "HDF5 External Weight Reference Version Check" and check.status == CheckStatus.PASSED - for check in result.checks +@pytest.mark.integration +def test_real_hf_xlm_roberta_large_h5_reaches_keras_scan_without_read_cap(tmp_path: Path) -> None: + _assert_real_hf_h5_reaches_keras_scan( + tmp_path, + repo_id="FacebookAI/xlm-roberta-large", + revision="c23d21b0620b635a76227c604d44e43a9f0ee389", + expected_size=2_240_076_248, + expected_blob_id="c902fe1cef9561c2e78bd7fccc5f83887e844f8b", + expected_sha256="a465c8d459fe83e10db5655221e2e7e7b6df3de2216c524399358d17ac7315ea", + expected_root_keys=["roberta", "top_level_model_weights"], + expected_attrs={"backend", "keras_version", "layer_names"}, + expected_layer_names=["roberta"], ) -def test_keras_h5_scanner_benign_model_has_no_warning_noise(tmp_path: Path) -> None: - """Benign H5 models should not produce warning or critical noise.""" - model_path = create_custom_h5_file( +@pytest.mark.integration +def test_real_hf_esm2_large_h5_reaches_keras_scan_without_read_cap(tmp_path: Path) -> None: + _assert_real_hf_h5_reaches_keras_scan( tmp_path, - { - "class_name": "Sequential", - "config": { - "name": "sequential", - "layers": [{"class_name": "Dense", "config": {"units": 1}}], - }, - }, + repo_id="facebook/esm2_t33_650M_UR50D", + revision="08e4846e537177426273712802403f7ba8261b6c", + expected_size=2_605_109_760, + expected_blob_id="c3271b7e4fc4dbd0f1bd3980c02cc21101c57cbb", + expected_sha256="3110b0ee07a47362ff90dc4d780b12287e06f2a09f56c8e117c4aed089fc96b8", ) - scanner = KerasH5Scanner() - result = scanner.scan(str(model_path)) - assert not any(issue.severity in (IssueSeverity.WARNING, IssueSeverity.CRITICAL) for issue in result.issues) +@pytest.mark.integration +def test_real_hf_whisper_large_v2_h5_reaches_keras_scan_without_read_cap(tmp_path: Path) -> None: + _assert_real_hf_h5_reaches_keras_scan( + tmp_path, + repo_id="openai/whisper-large-v2", + revision="ae4642769ce2ad8fc292556ccea8e901f1530655", + expected_size=6_174_574_896, + expected_blob_id="38414d47073f613961f19565ed6b481e1b9b0f80", + expected_sha256="489f5f36ba6e1959913bb77b30baf85e8b791e1e585dec7d65a2e217bfb8be47", + ) def test_missing_h5py_returns_inconclusive_exit2_without_cache( @@ -704,16 +2017,16 @@ def find_spec(self, fullname, path=None, target=None): sys.meta_path.insert(0, BrokenH5pyFinder()) - from modelaudit.core import determine_exit_code, scan_model_directory_or_file + import modelaudit.core as core_module model_path = sys.argv[1] - result = scan_model_directory_or_file(model_path, cache_enabled=False) + result = core_module.scan_model_directory_or_file(model_path, cache_enabled=False) metadata = result.file_metadata[model_path] print( json.dumps( { "success": result.success, - "exit_code": determine_exit_code(result), + "exit_code": core_module.determine_exit_code(result), "check_names": [check.name for check in result.checks], "scan_outcome_reasons": metadata.get("scan_outcome_reasons", []), } @@ -877,13 +2190,13 @@ def fail_h5py_open(*_args: Any, **_kwargs: Any) -> None: assert raw_secret not in failed_result.to_json() assert cache_manager.get_stats()["total_entries"] == 1 - audit_result = scan_model_directory_or_file( + audit_result = core_module.scan_model_directory_or_file( str(model_path), cache_enabled=True, cache_dir=str(cache_dir), min_cache_file_size=0, ) - assert determine_exit_code(audit_result) == 2 + assert core_module.determine_exit_code(audit_result) == 2 assert "keras_h5_scan_failed" in audit_result.file_metadata[str(model_path)]["scan_outcome_reasons"] finally: reset_cache_manager() @@ -946,25 +2259,25 @@ def _assert_inconclusive_keras_h5_scan( for check in result.checks ) - audit_result = scan_model_directory_or_file(str(model_path)) + audit_result = core_module.scan_model_directory_or_file(str(model_path)) metadata = audit_result.file_metadata[str(model_path)] assert audit_result.has_errors is False assert metadata.get("scan_outcome") == INCONCLUSIVE_SCAN_OUTCOME assert reason in metadata.get("scan_outcome_reasons") - assert determine_exit_code(audit_result) == 2 + assert core_module.determine_exit_code(audit_result) == 2 def _assert_inconclusive_keras_h5_scan_not_cached(model_path: Path, reason: str, cache_dir: Path) -> None: reset_cache_manager() try: - first_result = scan_model_directory_or_file( + first_result = core_module.scan_model_directory_or_file( str(model_path), cache_enabled=True, cache_dir=str(cache_dir), min_cache_file_size=0, ) - second_result = scan_model_directory_or_file( + second_result = core_module.scan_model_directory_or_file( str(model_path), cache_enabled=True, cache_dir=str(cache_dir), @@ -973,7 +2286,7 @@ def _assert_inconclusive_keras_h5_scan_not_cached(model_path: Path, reason: str, for audit_result in (first_result, second_result): metadata = audit_result.file_metadata[str(model_path)] - assert determine_exit_code(audit_result) == 2 + assert core_module.determine_exit_code(audit_result) == 2 assert metadata.get("scan_outcome") == INCONCLUSIVE_SCAN_OUTCOME assert reason in metadata.get("scan_outcome_reasons") assert not any( @@ -1212,12 +2525,12 @@ def test_keras_h5_inconclusive_training_config_preserves_security_exit1(tmp_path ) result = KerasH5Scanner().scan(str(model_path)) - audit_result = scan_model_directory_or_file(str(model_path)) + audit_result = core_module.scan_model_directory_or_file(str(model_path)) assert result.metadata["scan_outcome"] == INCONCLUSIVE_SCAN_OUTCOME assert "keras_h5_training_config_parse_failed" in result.metadata["scan_outcome_reasons"] assert any(issue.severity == IssueSeverity.CRITICAL for issue in result.issues) - assert determine_exit_code(audit_result) == 1 + assert core_module.determine_exit_code(audit_result) == 1 def test_keras_h5_inconclusive_scan_outcome_uncached_rerun_preserves_exit2(tmp_path: Path) -> None: @@ -1230,13 +2543,13 @@ def test_keras_h5_inconclusive_scan_outcome_uncached_rerun_preserves_exit2(tmp_p cache_dir = tmp_path / "cache" reset_cache_manager() - first_result = scan_model_directory_or_file( + first_result = core_module.scan_model_directory_or_file( str(model_path), cache_enabled=True, cache_dir=str(cache_dir), min_cache_file_size=0, ) - second_result = scan_model_directory_or_file( + second_result = core_module.scan_model_directory_or_file( str(model_path), cache_enabled=True, cache_dir=str(cache_dir), @@ -1244,8 +2557,8 @@ def test_keras_h5_inconclusive_scan_outcome_uncached_rerun_preserves_exit2(tmp_p ) metadata = second_result.file_metadata[str(model_path)] - assert determine_exit_code(first_result) == 2 - assert determine_exit_code(second_result) == 2 + assert core_module.determine_exit_code(first_result) == 2 + assert core_module.determine_exit_code(second_result) == 2 assert metadata.get("scan_outcome") == INCONCLUSIVE_SCAN_OUTCOME assert "keras_h5_model_config_invalid_type" in metadata.get("scan_outcome_reasons") assert get_cache_manager(str(cache_dir), enabled=True).get_stats()["total_entries"] == 0 @@ -1317,11 +2630,12 @@ def test_keras_h5_scanner_skips_generic_nested_weight_like_groups(tmp_path: Path assert not any(issue.severity in (IssueSeverity.WARNING, IssueSeverity.CRITICAL) for issue in result.issues) -def test_keras_h5_scanner_skips_generic_root_weight_like_groups(tmp_path: Path) -> None: +@pytest.mark.parametrize("root_path", ["vars", "optimizer/vars"]) +def test_keras_h5_scanner_skips_generic_root_weight_like_groups(tmp_path: Path, root_path: str) -> None: """Generic root vars/weights groups are common outside Keras and should stay quiet.""" generic_path = tmp_path / "generic_root_vars.h5" with h5py.File(generic_path, "w") as f: - f["vars"] = h5py.ExternalLink("external_source.h5", "/payload") + f.require_group(root_path)["0"] = h5py.ExternalLink("external_source.h5", "/payload") result = KerasH5Scanner().scan(str(generic_path)) @@ -1660,6 +2974,137 @@ def test_keras_h5_scanner_flags_keras3_weights_external_link_without_resolving_i ] +@pytest.mark.parametrize("layout", ["legacy", "keras3"]) +def test_keras_h5_scanner_flags_weights_only_soft_linked_external_reference( + tmp_path: Path, + layout: str, +) -> None: + """Weights-only SoftLink aliases must still route into external-reference analysis.""" + weights_path = tmp_path / f"soft_linked_{layout}.weights.h5" + with h5py.File(weights_path, "w") as f: + if layout == "legacy": + f.attrs["layer_names"] = [b"dense_alias"] + dense = f.create_group("real_dense") + dense.attrs["weight_names"] = [b"kernel:0"] + dense["kernel:0"] = h5py.ExternalLink("missing_external_source.h5", "/payload") + f["dense_alias"] = h5py.SoftLink("/real_dense") + expected_path = "/dense_alias/kernel:0" + else: + real_layers = f.create_group("real_layers") + vars_group = real_layers.create_group("dense").create_group("vars") + vars_group["0"] = h5py.ExternalLink("missing_external_source.h5", "/payload") + f["layers"] = h5py.SoftLink("/real_layers") + expected_path = "/layers/dense/vars/0" + + result = KerasH5Scanner().scan(str(weights_path)) + + cve_issues = [issue for issue in result.issues if issue.details.get("cve_id") == "CVE-2026-1669"] + assert len(cve_issues) == 1 + assert cve_issues[0].details["external_references"] == [ + { + "kind": "ExternalLink", + "hdf5_path": expected_path, + "filename": "missing_external_source.h5", + "path": "/payload", + }, + ] + + +@pytest.mark.parametrize("reference_kind", ["ExternalLink", "external_storage", "virtual_dataset"]) +def test_keras_h5_scanner_flags_soft_link_group_nested_external_reference( + tmp_path: Path, + reference_kind: str, +) -> None: + """A loader-consumed SoftLink group alias must not hide nested external HDF5 references.""" + weights_path = tmp_path / "soft_link_group.weights.h5" + with h5py.File(weights_path, "w") as f: + vars_group = f.create_group("layers").create_group("dense").create_group("vars") + resolved_group = f.create_group("resolved_group") + vars_group["0"] = h5py.SoftLink("/resolved_group") + expected_reference: dict[str, Any] + + if reference_kind == "ExternalLink": + resolved_group["payload"] = h5py.ExternalLink("missing_external_source.h5", "/payload") + expected_reference = { + "kind": "ExternalLink", + "hdf5_path": "/layers/dense/vars/0/payload", + "filename": "missing_external_source.h5", + "path": "/payload", + } + elif reference_kind == "external_storage": + raw_storage = tmp_path / "weights.raw" + raw_storage.write_bytes(b"\x00" * 8) + resolved_group.create_dataset( + "payload", + shape=(2,), + dtype="float32", + external=[(raw_storage.name, 0, 8)], + ) + expected_reference = { + "kind": "external_storage", + "hdf5_path": "/layers/dense/vars/0/payload", + "segments": [{"filename": "weights.raw", "offset": 0, "size": 8}], + } + else: + virtual_source = tmp_path / "virtual_source.h5" + with h5py.File(virtual_source, "w") as source_file: + source_file.create_dataset("payload", data=[1.0, 2.0]) + layout = h5py.VirtualLayout(shape=(2,), dtype="float64") + layout[:] = h5py.VirtualSource(virtual_source.name, "/payload", shape=(2,)) + resolved_group.create_virtual_dataset("payload", layout) + expected_reference = { + "kind": "virtual_dataset", + "hdf5_path": "/layers/dense/vars/0/payload", + "sources": [{"filename": "virtual_source.h5", "path": "/payload"}], + } + + result = KerasH5Scanner().scan(str(weights_path)) + + cve_issues = [issue for issue in result.issues if issue.details.get("cve_id") == "CVE-2026-1669"] + assert len(cve_issues) == 1 + assert cve_issues[0].details["external_references"] == [expected_reference] + + +def test_keras_h5_scanner_allows_soft_link_group_with_internal_dataset(tmp_path: Path) -> None: + """Clean internal SoftLink group aliases should stay non-findings.""" + weights_path = tmp_path / "clean_soft_link_group.weights.h5" + with h5py.File(weights_path, "w") as f: + vars_group = f.create_group("layers").create_group("dense").create_group("vars") + resolved_group = f.create_group("resolved_group") + resolved_group.create_dataset("payload", data=[1.0, 2.0]) + vars_group["0"] = h5py.SoftLink("/resolved_group") + + result = KerasH5Scanner().scan(str(weights_path)) + + assert result.success is True + assert not any(issue.details.get("cve_id") == "CVE-2026-1669" for issue in result.issues) + assert not any(issue.severity in (IssueSeverity.WARNING, IssueSeverity.CRITICAL) for issue in result.issues) + + +def test_keras_h5_scanner_soft_link_group_traversal_respects_link_budget( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + weights_path = tmp_path / "budgeted_soft_link_group.weights.h5" + with h5py.File(weights_path, "w") as f: + vars_group = f.create_group("layers").create_group("dense").create_group("vars") + resolved_group = f.create_group("resolved_group") + resolved_group["payload"] = h5py.ExternalLink("missing_external_source.h5", "/payload") + vars_group["0"] = h5py.SoftLink("/resolved_group") + + monkeypatch.setattr(KerasH5Scanner, "_MAX_HDF5_LINK_VISITS", 1) + + result = KerasH5Scanner().scan(str(weights_path)) + + reason = "keras_h5_external_reference_analysis_limit_exceeded" + assert result.success is False + assert result.metadata["scan_outcome"] == INCONCLUSIVE_SCAN_OUTCOME + assert reason in result.metadata["scan_outcome_reasons"] + limit_checks = [check for check in result.checks if check.details.get("scan_outcome_reason") == reason] + assert len(limit_checks) == 1 + assert limit_checks[0].details["link_visits_truncated"] is True + + def test_keras_h5_scanner_legacy_h5py_traversal_flags_dangling_external_link( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, @@ -1690,7 +3135,7 @@ def test_keras_h5_scanner_external_reference_collection_does_not_resolve_soft_li tmp_path: Path, monkeypatch: pytest.MonkeyPatch, ) -> None: - """SoftLink aliases must not be dereferenced while collecting external references.""" + """SoftLink aliases are resolved as links without dereferencing external targets.""" weights_path = tmp_path / "soft_alias.weights.h5" with h5py.File(weights_path, "w") as f: vars_group = f.create_group("layers").create_group("dense").create_group("vars") @@ -1723,6 +3168,12 @@ def guarded_get( "filename": "missing_external_source.h5", "path": "/payload", }, + { + "kind": "ExternalLink", + "hdf5_path": "/layers/dense/vars/soft_alias", + "filename": "missing_external_source.h5", + "path": "/payload", + }, ] @@ -1756,8 +3207,8 @@ def test_keras_h5_scanner_external_reference_traversal_limit_fails_closed( assert limit_checks[0].details["visited_link_count"] == 2 assert limit_checks[0].details["link_visits_truncated"] is True - audit_result = scan_model_directory_or_file(str(weights_path), cache_enabled=False) - assert determine_exit_code(audit_result) == 2 + audit_result = core_module.scan_model_directory_or_file(str(weights_path), cache_enabled=False) + assert core_module.determine_exit_code(audit_result) == 2 _assert_inconclusive_keras_h5_scan_not_cached(weights_path, reason, tmp_path / f"cache-{legacy_h5py}") @@ -1787,8 +3238,8 @@ def test_keras_h5_scanner_external_reference_reports_are_bounded( assert cve_issues[0].details["external_reference_count"] == 3 assert cve_issues[0].details["external_references_truncated"] is True - audit_result = scan_model_directory_or_file(str(model_path), cache_enabled=False) - assert determine_exit_code(audit_result) == 1 + audit_result = core_module.scan_model_directory_or_file(str(model_path), cache_enabled=False) + assert core_module.determine_exit_code(audit_result) == 1 def test_keras_h5_scanner_external_storage_segment_reports_are_bounded( @@ -2624,7 +4075,7 @@ def test_lambda_code_details_omit_sensitive_previews_in_json_and_sarif(tmp_path: for check in scanner_result.checks ) - audit_result = scan_model_directory_or_file(str(model_path), cache_enabled=False) + audit_result = core_module.scan_model_directory_or_file(str(model_path), cache_enabled=False) json_output = audit_result.model_dump_json(indent=2, exclude_none=True) sarif_output = format_sarif_output(audit_result, [str(model_path)]) @@ -2737,7 +4188,7 @@ def test_lambda_nested_metadata_omits_artifact_controlled_keys_and_fake_wrapped_ assert suspicious_checks[0].details.get("context") == "Lambda" assert not any(check.name == "Custom Layer Class Detection" for check in scanner_result.checks) - audit_result = scan_model_directory_or_file(str(model_path), cache_enabled=False) + audit_result = core_module.scan_model_directory_or_file(str(model_path), cache_enabled=False) json_output = audit_result.model_dump_json(indent=2, exclude_none=True) sarif_output = format_sarif_output(audit_result, [str(model_path)]) @@ -2805,8 +4256,8 @@ def test_lambda_scalar_function_metadata_fails_closed_without_echoing_value(tmp_ assert "function" not in malformed_checks[0].details assert result.success is True - audit_result = scan_model_directory_or_file(str(model_path), cache_enabled=False) - assert determine_exit_code(audit_result) == 1 + audit_result = core_module.scan_model_directory_or_file(str(model_path), cache_enabled=False) + assert core_module.determine_exit_code(audit_result) == 1 def test_lambda_null_function_placeholder_is_not_treated_as_malformed(tmp_path: Path) -> None: @@ -5696,8 +7147,8 @@ def test_nested_module_after_node_limit_is_inconclusive( assert not any(issue.details.get("cve_id") == "CVE-2025-1550" for issue in result.issues) assert result.success is False assert result.metadata["scan_outcome"] == INCONCLUSIVE_SCAN_OUTCOME - audit_result = scan_model_directory_or_file(str(model_path), scanner_config={}) - assert determine_exit_code(audit_result) == 2 + audit_result = core_module.scan_model_directory_or_file(str(model_path), scanner_config={}) + assert core_module.determine_exit_code(audit_result) == 2 def test_nested_non_lambda_serialized_function_is_critical(self, tmp_path: Path) -> None: model_path = create_custom_h5_file( diff --git a/tests/test_core.py b/tests/test_core.py index 2cb47d5f2..86660d2e8 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -2190,6 +2190,103 @@ def fail_owner_scan(scanner: TensorFlowSavedModelScanner, owner_path: str) -> Sc assert determine_exit_code(result) == 2 +@pytest.mark.skipif( + not _DESCRIPTOR_BOUND_DIRECTORY_OWNER_PATH_AVAILABLE, + reason="descriptor-bound directory owner path is unavailable", +) +def test_savedmodel_owner_allows_large_file_backed_hdf5_child( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + _require_tf_protos() + from modelaudit.scanners.keras_h5_scanner import KerasH5Scanner + from modelaudit.scanners.tf_savedmodel_scanner import TensorFlowSavedModelScanner + + model_dir = tmp_path / "saved-model" + model_dir.mkdir() + _write_safe_savedmodel(model_dir / "saved_model.pb") + variables_dir = model_dir / "variables" + variables_dir.mkdir() + hdf5_path = variables_dir / "large-benign.json" + _write_large_benign_keras_hdf5(hdf5_path) + + hdf5_scans: list[Path] = [] + owner_calls: list[Path] = [] + original_hdf5_scan = KerasH5Scanner.scan + original_owner_scan = TensorFlowSavedModelScanner.scan + original_hash = core_module._calculate_file_hash + + def record_hdf5_scan(scanner: KerasH5Scanner, path: str) -> ScanResult: + if Path(path).resolve() == hdf5_path.resolve(): + hdf5_scans.append(Path(path).resolve()) + return original_hdf5_scan(scanner, path) + + def record_owner_scan(scanner: TensorFlowSavedModelScanner, owner_path: str) -> ScanResult: + if Path(owner_path).is_dir(): + owner_calls.append(Path(owner_path).resolve()) + return original_owner_scan(scanner, owner_path) + + def reject_large_hdf5_hash(path: str, *, deadline: float | None = None) -> str: + if Path(path).resolve() == hdf5_path.resolve(): + pytest.fail("large file-backed HDF5 child must not be whole-file hashed") + return original_hash(path, deadline=deadline) + + monkeypatch.setattr(KerasH5Scanner, "scan", record_hdf5_scan) + monkeypatch.setattr(TensorFlowSavedModelScanner, "scan", record_owner_scan) + monkeypatch.setattr(core_module, "_calculate_file_hash", reject_large_hdf5_hash) + + result = scan_model_directory_or_file(str(model_dir), cache_enabled=False) + + owner_metadata = result.file_metadata[str(model_dir)] + hdf5_metadata = result.file_metadata[str(hdf5_path)] + assert owner_calls == [model_dir.resolve()] + assert hdf5_scans == [hdf5_path.resolve()] + assert owner_metadata["directory_owner_scan"] is True + assert "directory_owner_snapshot_incomplete" not in owner_metadata.get("scan_outcome_reasons", []) + assert hdf5_metadata["content_hash"].startswith("unhashable_file_backed_hdf5_") + assert hdf5_metadata["file_backed_scan"] is True + assert result.content_hash is None + assert "tf_savedmodel" in result.scanner_names + assert "keras_h5" in result.scanner_names + assert determine_exit_code(result) == 0 + + +def test_savedmodel_owner_rejects_large_file_backed_hdf5_without_descriptor_binding( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + _require_tf_protos() + from modelaudit.scanners.tf_savedmodel_scanner import TensorFlowSavedModelScanner + + model_dir = tmp_path / "saved-model" + model_dir.mkdir() + _write_safe_savedmodel(model_dir / "saved_model.pb") + variables_dir = model_dir / "variables" + variables_dir.mkdir() + hdf5_path = variables_dir / "large-benign.json" + _write_large_benign_keras_hdf5(hdf5_path) + _force_staged_directory_owner_scan(monkeypatch) + + original_owner_scan = TensorFlowSavedModelScanner.scan + + def reject_directory_owner_scan(scanner: TensorFlowSavedModelScanner, owner_path: str) -> ScanResult: + if Path(owner_path).is_dir(): + raise AssertionError("deferred HDF5 owner sources require descriptor-bound dispatch") + return original_owner_scan(scanner, owner_path) + + monkeypatch.setattr(TensorFlowSavedModelScanner, "scan", reject_directory_owner_scan) + + result = scan_model_directory_or_file(str(model_dir), cache_enabled=False) + + owner_metadata = result.file_metadata[str(model_dir)] + assert owner_metadata["directory_owner_scan"] is False + assert "directory_owner_snapshot_incomplete" in owner_metadata["scan_outcome_reasons"] + assert owner_metadata["operational_error"] is True + assert result.file_metadata[str(hdf5_path)]["content_hash"].startswith("unhashable_file_backed_hdf5_") + assert result.content_hash is None + assert determine_exit_code(result) == 2 + + def test_mixed_savedmodel_and_orbax_root_preserves_both_security_scans(tmp_path: Path) -> None: _require_tf_protos() model_dir = tmp_path / "mixed-model" @@ -2451,6 +2548,31 @@ def _write_safe_savedmodel(path: Path) -> None: path.write_bytes(_build_collection_only_tf_savedmodel(value=b"documentation: https://example.invalid/runtime")) +def _write_large_benign_keras_hdf5(path: Path) -> None: + h5py = pytest.importorskip("h5py") + with h5py.File(path, "w") as h5_file: + h5_file.attrs["model_config"] = json.dumps( + {"class_name": "Sequential", "config": {"layers": [{"class_name": "Dense", "config": {"units": 1}}]}}, + ) + h5_file.attrs["keras_version"] = "3.13.2" + with path.open("ab") as handle: + handle.truncate(core_module.DEFAULT_MAX_FILE_READ_SIZE + 4096) + + +def _write_large_valid_userblock_keras_hdf5(path: Path) -> int: + h5py = pytest.importorskip("h5py") + userblock_size = 16 * 1024 * 1024 + assert userblock_size > HDF5_SIGNATURE_SCAN_MAX_BYTES + with h5py.File(path, "w", userblock_size=userblock_size) as h5_file: + h5_file.attrs["model_config"] = json.dumps( + {"class_name": "Sequential", "config": {"layers": [{"class_name": "Dense", "config": {"units": 1}}]}}, + ) + h5_file.attrs["keras_version"] = "3.13.2" + with path.open("ab") as handle: + handle.truncate(core_module.DEFAULT_MAX_FILE_READ_SIZE + 4096) + return userblock_size + + def _write_safe_r_serialized(path: Path) -> None: path.write_bytes(b"X\nsafe\nmodel\nweights") @@ -8557,6 +8679,76 @@ def fail_h5py_open(*_args: Any, **_kwargs: Any) -> None: assert determine_exit_code(audit_result) == 2 +def test_scan_file_defers_hash_for_large_valid_hdf5_userblock_beyond_signature_probe( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + model_path = tmp_path / "large-userblock-model.h5" + userblock_size = _write_large_valid_userblock_keras_hdf5(model_path) + cache_dir = tmp_path / "cache" + config = { + "cache_enabled": True, + "cache_dir": str(cache_dir), + "min_cache_file_size": 0, + "max_cache_file_size": core_module.DEFAULT_MAX_FILE_READ_SIZE * 2, + } + + def fail_if_cache_hashes_large_hdf5(_self: SecureFileHasher, path: str) -> str: + if Path(path).resolve() == model_path.resolve(): + pytest.fail("large valid HDF5 userblock file was whole-file hashed for cache lookup") + return "0" * 64 + + monkeypatch.setattr(SecureFileHasher, "hash_file", fail_if_cache_hashes_large_hdf5) + monkeypatch.setattr( + SecureFileHasher, + "hash_file_with_stat", + lambda self, path, _stat: fail_if_cache_hashes_large_hdf5(self, path), + ) + + assert find_hdf5_signature_offset(str(model_path)) == userblock_size + + reset_cache_manager() + try: + result = scan_file(str(model_path), config=config) + cache_entries = get_cache_manager(str(cache_dir), enabled=True).get_stats()["total_entries"] + finally: + reset_cache_manager() + + assert result.scanner_name == "keras_h5" + assert result.success is False + assert "hdf5_userblock_zip_probe_incomplete" in result.metadata["scan_outcome_reasons"] + assert result.metadata["file_backed_scan"] is True + assert cache_entries == 0 + + +def test_directory_scan_defers_hash_for_large_valid_hdf5_userblock_beyond_signature_probe( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + model_path = tmp_path / "large-userblock-model.h5" + userblock_size = _write_large_valid_userblock_keras_hdf5(model_path) + original_hash = core_module._calculate_file_hash + + def fail_if_directory_hashes_large_hdf5(path: str, *, deadline: float | None = None) -> str: + if Path(path).resolve() == model_path.resolve(): + pytest.fail("large valid HDF5 userblock file was whole-file hashed during directory scan") + return original_hash(path, deadline=deadline) + + monkeypatch.setattr(core_module, "_calculate_file_hash", fail_if_directory_hashes_large_hdf5) + + assert find_hdf5_signature_offset(str(model_path)) == userblock_size + + audit_result = scan_model_directory_or_file(str(tmp_path), cache_enabled=False) + metadata = audit_result.file_metadata[str(model_path)] + + assert "keras_h5" in audit_result.scanner_names + assert metadata["content_hash"].startswith("unhashable_file_backed_hdf5_") + assert metadata["file_backed_scan"] is True + assert "hdf5_userblock_zip_probe_incomplete" in metadata["scan_outcome_reasons"] + assert audit_result.content_hash is None + assert determine_exit_code(audit_result) == 2 + + def test_scan_directory_preserves_hdf5_userblock_under_skipped_suffix( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, diff --git a/tests/test_streaming_scan.py b/tests/test_streaming_scan.py index bb581c0c0..188f90ca2 100644 --- a/tests/test_streaming_scan.py +++ b/tests/test_streaming_scan.py @@ -1,5 +1,6 @@ """Tests for streaming scan-and-delete functionality.""" +import json import logging import os import pickle @@ -30,9 +31,10 @@ from modelaudit.integrations.sarif_formatter import format_sarif_output from modelaudit.models import FileMetadataModel, LicenseInfoModel, create_initial_audit_result from modelaudit.scanners import safetensors_scanner -from modelaudit.scanners.base import CheckStatus, Issue, IssueSeverity, ScanResult +from modelaudit.scanners.base import DEFAULT_MAX_FILE_READ_SIZE, CheckStatus, Issue, IssueSeverity, ScanResult from modelaudit.utils.file import detection as file_detection from modelaudit.utils.file.detection import SAFETENSORS_ROUTING_HEADER_PARSE_BYTES +from modelaudit.utils.file.hdf5 import HDF5_SIGNATURE_SCAN_MAX_BYTES, find_hdf5_signature_offset from modelaudit.utils.helpers.file_hash import compute_sha256_hash from modelaudit.utils.helpers.file_iterator import iterate_files_streaming from modelaudit.utils.helpers.secure_hasher import compute_aggregate_hash @@ -83,6 +85,20 @@ def write_hf_cachedir_tag(path: Path) -> None: ) +def write_large_valid_userblock_keras_hdf5(path: Path) -> int: + h5py = pytest.importorskip("h5py") + userblock_size = 16 * 1024 * 1024 + assert userblock_size > HDF5_SIGNATURE_SCAN_MAX_BYTES + with h5py.File(path, "w", userblock_size=userblock_size) as h5_file: + h5_file.attrs["model_config"] = json.dumps( + {"class_name": "Sequential", "config": {"layers": [{"class_name": "Dense", "config": {"units": 1}}]}}, + ) + h5_file.attrs["keras_version"] = "3.13.2" + with path.open("ab") as handle: + handle.truncate(DEFAULT_MAX_FILE_READ_SIZE + 4096) + return userblock_size + + def create_mock_location_scan_result( resolved_path: Path, *, @@ -2826,6 +2842,36 @@ def test_scan_model_streaming_oversized_renamed_safetensors_fails_before_hashing assert any(check.name == "Header Size Limit" for check in result.checks) +def test_scan_model_streaming_defers_hash_for_large_valid_hdf5_userblock_beyond_signature_probe( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + payload = tmp_path / "large-userblock-model.h5" + userblock_size = write_large_valid_userblock_keras_hdf5(payload) + assert find_hdf5_signature_offset(str(payload)) == userblock_size + + monkeypatch.setattr( + "modelaudit.utils.helpers.file_hash.compute_sha256_hash", + lambda _path: pytest.fail("streaming large file-backed HDF5 artifact must not be whole-file hashed"), + ) + + result = scan_model_streaming( + file_generator=iter([(payload, True)]), + timeout=30, + delete_after_scan=False, + cache_enabled=False, + ) + + metadata = result.file_metadata[str(payload)] + assert result.files_scanned == 1 + assert result.content_hash is None + assert "keras_h5" in result.scanner_names + assert metadata["file_backed_scan"] is True + assert metadata["file_hashes"] is None + assert "hdf5_userblock_zip_probe_incomplete" in metadata["scan_outcome_reasons"] + assert determine_exit_code(result) == 2 + + def test_scan_model_streaming_does_not_hash_files_over_max_file_size( tmp_path: Path, monkeypatch: pytest.MonkeyPatch,