Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/asv.conf.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@
"build_cache_size": 2,
"default_benchmark_timeout": 500,
"regressions_thresholds": {
".*": 0.3
".*": 0.2
}
}
5 changes: 5 additions & 0 deletions benchmarks/benchmarks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import psutil

from ._patch_setup import _apply_patches

_MIN_THREADS = 4 # minimum physical cores required for multi-threaded mode


Expand All @@ -19,3 +21,6 @@ def _thread_count():

_THREADS = os.environ.get("MKL_NUM_THREADS", _thread_count())
os.environ["MKL_NUM_THREADS"] = _THREADS

_apply_patches()
del _apply_patches
67 changes: 67 additions & 0 deletions benchmarks/benchmarks/_patch_setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""MKL patch setup — executed once per ASV worker process at import time.

Patches NumPy FFT with the Intel MKL FFT implementation.
Hard-fails with a descriptive RuntimeError if mkl_fft is missing or the
patch does not take effect, so benchmarks never silently run on stock NumPy.
"""
Comment thread
vchamarthi marked this conversation as resolved.

_PATCH_MAP = [
("mkl_fft", "patch_numpy_fft"),
]


def _apply_patches():
import importlib

import numpy as np

patched = {}

for mod_name, patch_fn_name in _PATCH_MAP:
try:
mod = importlib.import_module(mod_name)
except ImportError as exc:
raise RuntimeError(
f"[mkl-patch] Cannot import {mod_name}: {exc}\n"
f" Ensure the conda env contains {mod_name} "
f"from the Intel channel.\n"
" Required channels: "
"https://software.repos.intel.com/python/conda"
) from exc

patch_fn = getattr(mod, patch_fn_name, None)
if patch_fn is None:
raise RuntimeError(
f"[mkl-patch] {mod_name} has no {patch_fn_name}(). "
f"Upgrade {mod_name} to a version that exposes "
"the stock-numpy patch API."
)

try:
patch_fn()
except Exception as exc:
raise RuntimeError(
f"[mkl-patch] {mod_name}.{patch_fn_name}() raised: {exc!r}"
) from exc

is_patched_fn = getattr(mod, "is_patched", None)
if callable(is_patched_fn) and not is_patched_fn():
raise RuntimeError(
f"[mkl-patch] {mod_name}.is_patched() returned False "
"after patching. NumPy may have been imported before "
"patching in a conflicting state."
)

patched[mod_name] = mod

_attr_checks = {
"mkl_fft": lambda: np.fft.fft.__module__,
}
for mod_name in patched:
try:
attr = _attr_checks[mod_name]()
except Exception:
attr = "unknown"
print(f"[mkl-patch] {mod_name}: numpy.fft dispatch -> {attr}")

print("[mkl-patch] ALL OK -- mkl_fft active")
Comment thread
vchamarthi marked this conversation as resolved.
2 changes: 2 additions & 0 deletions benchmarks/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
psutil
scipy
Comment thread
vchamarthi marked this conversation as resolved.