From d68c80f0f777a42a880e36bd4e1968ee7265a1a9 Mon Sep 17 00:00:00 2001 From: Zhihao Deng Date: Mon, 22 Jun 2026 14:49:14 -0400 Subject: [PATCH 1/2] Add optional libxsmm fast path for strided ToT micro-GEMMs (-DTA_LIBXSMM=ON) Fetch+build libxsmm from source (no system install assumed) and route the small strided tensor-of-tensors GEMMs (ce+e, ce+ce, scale) through its JIT, falling back to vendor BLAS for shapes max(M,N,K)>64. Runtime toggle TA_LIBXSMM=0. --- CMakeLists.txt | 7 ++ external/libxsmm.cmake | 129 +++++++++++++++++++++++++++ external/versions.cmake | 8 ++ src/CMakeLists.txt | 5 ++ src/TiledArray/math/libxsmm_gemm.cpp | 108 ++++++++++++++++++++++ src/TiledArray/math/libxsmm_gemm.h | 37 ++++++++ src/TiledArray/tensor/arena_einsum.h | 102 ++++++++++++++------- src/TiledArray/tensor/tensor.cpp | 16 ++++ src/TiledArray/tensor/tensor.h | 64 ++++++++++--- 9 files changed, 432 insertions(+), 44 deletions(-) create mode 100644 external/libxsmm.cmake create mode 100644 src/TiledArray/math/libxsmm_gemm.cpp create mode 100644 src/TiledArray/math/libxsmm_gemm.h diff --git a/CMakeLists.txt b/CMakeLists.txt index c3f01d9e74..4a36bb2645 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -154,6 +154,10 @@ if (TA_STRIDED_DGEMM_COUNT) add_compile_definitions(TA_STRIDED_DGEMM_COUNT) endif() +option(TA_LIBXSMM + "Fetch+build libxsmm from source and route small strided ToT micro-GEMMs (ce+e, ce+ce, scale) through its JIT fast path; falls back to vendor BLAS for shapes max(M,N,K)>64. Runtime toggle: TA_LIBXSMM=0" + OFF) + option(TA_EXPERT "TiledArray Expert mode: disables automatically downloading or building dependencies" OFF) redefaultable_option(TA_WERROR "Treat compiler warnings as errors when compiling TiledArray's own translation units (does not propagate to consumers of installed TiledArray targets)" OFF) @@ -366,6 +370,9 @@ include(${PROJECT_SOURCE_DIR}/cmake/modules/FindOrFetchBTAS.cmake) if(TA_SCALAPACK) include(external/scalapackpp.cmake) endif() +if(TA_LIBXSMM) + include(external/libxsmm.cmake) +endif() # other optional deps: # 2. TTG diff --git a/external/libxsmm.cmake b/external/libxsmm.cmake new file mode 100644 index 0000000000..97cbab432e --- /dev/null +++ b/external/libxsmm.cmake @@ -0,0 +1,129 @@ +## +## Fetch + build libxsmm from source, and expose it as the TiledArray_LIBXSMM +## INTERFACE target. Enabled by -DTA_LIBXSMM=ON. +## +## libxsmm provides a JIT small-GEMM fast path for the strided tensor-of-tensors +## micro-GEMMs (ce+e, ce+ce, scale). Unlike most TA deps, libxsmm's canonical +## build is a GNU Makefile (not CMake), so this uses ExternalProject_Add with a +## custom `make ... install` build command rather than CMAKE_ARGS. +## +## No system install is assumed: if libxsmm is not found via LIBXSMM_INSTALL_DIR +## (an optional hint), it is cloned and built from source under +## ${FETCHCONTENT_BASE_DIR}. There is intentionally NO TA_LIBXSMM_ROOT knob. +## + +# Optional: reuse a pre-built libxsmm ONLY if the user explicitly points at one +# via -DLIBXSMM_INSTALL_DIR=... We deliberately do NOT search default system +# paths (no NO_DEFAULT_PATH omission): a stray /usr/local install must never be +# picked up silently. Absent an explicit hint, libxsmm is always fetched+built. +# clear any stale value cached by a prior configure (e.g. before this guard) +unset(_LIBXSMM_INSTALL_DIR CACHE) +set(_LIBXSMM_PREBUILT _LIBXSMM_PREBUILT-NOTFOUND) +if (DEFINED LIBXSMM_INSTALL_DIR) + find_path(_LIBXSMM_PREBUILT NAMES include/libxsmm.h lib/libxsmm.a + HINTS ${LIBXSMM_INSTALL_DIR} NO_DEFAULT_PATH) +endif () + +if (_LIBXSMM_PREBUILT) + + set(_LIBXSMM_INSTALL_DIR ${_LIBXSMM_PREBUILT}) + message(STATUS "libxsmm found at ${_LIBXSMM_INSTALL_DIR}") + +elseif (TA_EXPERT) + + message("** libxsmm was not found") + message(STATUS "** Downloading and building libxsmm is explicitly disabled in EXPERT mode") + message(FATAL_ERROR "** Either provide a pre-built libxsmm via -DLIBXSMM_INSTALL_DIR=... or disable -DTA_LIBXSMM=OFF") + +else () + + include(ExternalProject) + + # libxsmm is a C library; make sure CMAKE_C_COMPILER is configured + enable_language(C) + + set(EXTERNAL_SOURCE_DIR ${FETCHCONTENT_BASE_DIR}/libxsmm-src) + set(_LIBXSMM_INSTALL_DIR ${FETCHCONTENT_BASE_DIR}/libxsmm-install) + + if (NOT LIBXSMM_URL) + set(LIBXSMM_URL https://github.com/libxsmm/libxsmm.git) + endif (NOT LIBXSMM_URL) + if (NOT LIBXSMM_TAG) + set(LIBXSMM_TAG ${TA_TRACKED_LIBXSMM_TAG}) + endif (NOT LIBXSMM_TAG) + + message("** Will clone libxsmm from ${LIBXSMM_URL}") + + # Compiler for libxsmm's sub-make. libxsmm builds with -target + # -apple-macos, which makes a bare CommandLineTools clang stop + # auto-injecting the macOS SDK sysroot, so it cannot find system headers + # (pthread.h) or libSystem at link time (and CMAKE_OSX_SYSROOT is often + # empty). On Apple, build libxsmm with the /usr/bin/{cc,c++} xcrun shims, + # which always resolve the active SDK for both compile and link; the + # resulting libxsmm.a is C-ABI-compatible with the rest of TiledArray. + # Elsewhere, honor the project's configured compilers. + if (APPLE) + set(_libxsmm_cc /usr/bin/cc) + set(_libxsmm_cxx /usr/bin/c++) + else () + set(_libxsmm_cc ${CMAKE_C_COMPILER}) + set(_libxsmm_cxx ${CMAKE_CXX_COMPILER}) + endif () + + # libxsmm Make knobs: + # STATIC=1 build libxsmm.a (we link the archive into tiledarray) + # FORTRAN=0 skip the Fortran interface (no gfortran needed) + # BLAS=0 do not wrap an external BLAS (we only use the JIT SMM path, + # and TA already links its own BLAS); avoids a second BLAS dep + # PREFIX=... install headers+lib into our private prefix + set(LIBXSMM_BUILD_BYPRODUCTS "${_LIBXSMM_INSTALL_DIR}/lib/libxsmm.a") + message(STATUS "custom target libxsmm is expected to build these byproducts: ${LIBXSMM_BUILD_BYPRODUCTS}") + + ExternalProject_Add(libxsmm + PREFIX ${FETCHCONTENT_BASE_DIR} + STAMP_DIR ${FETCHCONTENT_BASE_DIR}/libxsmm-ep-artifacts + TMP_DIR ${FETCHCONTENT_BASE_DIR}/libxsmm-ep-artifacts # in case CMAKE_INSTALL_PREFIX is not writable + #--Download step-------------- + DOWNLOAD_DIR ${EXTERNAL_SOURCE_DIR} + GIT_REPOSITORY ${LIBXSMM_URL} + GIT_TAG ${LIBXSMM_TAG} + #--Configure step------------- (none: libxsmm uses a plain Makefile) + SOURCE_DIR ${EXTERNAL_SOURCE_DIR} + UPDATE_DISCONNECTED 1 + BUILD_IN_SOURCE 1 + CONFIGURE_COMMAND "" + #--Build step----------------- build + install in one make invocation + BUILD_COMMAND make -j6 STATIC=1 FORTRAN=0 BLAS=0 + CC=${_libxsmm_cc} CXX=${_libxsmm_cxx} AR=${CMAKE_AR} + PREFIX=${_LIBXSMM_INSTALL_DIR} install + BUILD_BYPRODUCTS ${LIBXSMM_BUILD_BYPRODUCTS} + #--Install step--------------- (done by BUILD_COMMAND above) + INSTALL_COMMAND ${CMAKE_COMMAND} -E echo "libxsmm installed to ${_LIBXSMM_INSTALL_DIR}" + #--Custom targets------------- + STEP_TARGETS build + ) + + # the include dir must exist at configure time so the INTERFACE target's + # BUILD_INTERFACE include path validates (it is populated at build time) + execute_process(COMMAND ${CMAKE_COMMAND} -E make_directory "${_LIBXSMM_INSTALL_DIR}/include") + + # build libxsmm before any TiledArray translation unit links + add_dependencies(External-tiledarray libxsmm) + +endif (_LIBXSMM_INSTALL_DIR) + +# Synthetic target carrying the include dir, the static archive, and the gating +# define. PUBLIC propagation (via _TILEDARRAY_DEPENDENCIES) makes +# TILEDARRAY_HAS_LIBXSMM + the include path visible to consumers (e.g. MPQC). +add_library(TiledArray_LIBXSMM INTERFACE) +set_target_properties(TiledArray_LIBXSMM + PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES + "$;$" + INTERFACE_LINK_LIBRARIES + "${_LIBXSMM_INSTALL_DIR}/lib/libxsmm.a;${CMAKE_DL_LIBS}" + INTERFACE_COMPILE_DEFINITIONS + "TILEDARRAY_HAS_LIBXSMM" + ) + +install(TARGETS TiledArray_LIBXSMM EXPORT tiledarray COMPONENT tiledarray) diff --git a/external/versions.cmake b/external/versions.cmake index f5565d299c..2811c8a6c4 100644 --- a/external/versions.cmake +++ b/external/versions.cmake @@ -23,6 +23,14 @@ set(TA_TRACKED_BTAS_PREVIOUS_TAG 7e64fbad97c76f316f313f4c8ed3fca5445da15f) set(TA_TRACKED_LIBRETT_TAG 6eed30d4dd2a5aa58840fe895dcffd80be7fbece) set(TA_TRACKED_LIBRETT_PREVIOUS_TAG 354e0ccee54aeb2f191c3ce2c617ebf437e49d83) +# libxsmm: pin a recent `main` commit, NOT the 1.17 release tag (2021). The +# modern dispatch API used by TiledArray/math/libxsmm_gemm.h +# (libxsmm_dispatch_gemm/libxsmm_create_gemm_shape/libxsmm_gemm_param) does not +# exist in 1.17, and 1.17 predates most of libxsmm's AArch64/Apple-Silicon JIT +# work. This SHA is the version validated on Apple M2 (reports as "1.17-3808"). +set(TA_TRACKED_LIBXSMM_TAG c14cbc6f8bc7964f8c5190a3a16b8cace03e5889) +set(TA_TRACKED_LIBXSMM_PREVIOUS_TAG c14cbc6f8bc7964f8c5190a3a16b8cace03e5889) + set(TA_TRACKED_UMPIRE-CXX-ALLOCATOR_TAG 1ba7f5f0aa99438826dd1c6bc1cd396080b9d608) set(TA_TRACKED_UMPIRE-CXX-ALLOCATOR_PREVIOUS_TAG 0f8144f19897766d0f117f7353221d4e3b8b1178) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 64bc4e2eda..6e1f97c251 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -135,6 +135,7 @@ TiledArray/external/madness.h TiledArray/host/env.h TiledArray/math/blas.h TiledArray/math/gemm_helper.h +TiledArray/math/libxsmm_gemm.h TiledArray/math/outer.h TiledArray/math/parallel_gemm.h TiledArray/math/partial_reduce.h @@ -215,6 +216,7 @@ set(TILEDARRAY_SOURCE_FILES TiledArray/einsum/index.cpp TiledArray/expressions/permopt.cpp TiledArray/host/env.cpp + TiledArray/math/libxsmm_gemm.cpp TiledArray/math/linalg/basic.cpp TiledArray/math/linalg/rank-local.cpp TiledArray/tensor/print.cpp @@ -303,6 +305,9 @@ endif(TILEDARRAY_HAS_CUDA OR TILEDARRAY_HAS_HIP) if( TARGET TiledArray_SCALAPACK ) list(APPEND _TILEDARRAY_DEPENDENCIES TiledArray_SCALAPACK) endif() +if( TARGET TiledArray_LIBXSMM ) + list(APPEND _TILEDARRAY_DEPENDENCIES TiledArray_LIBXSMM) +endif() list(APPEND _TILEDARRAY_DEPENDENCIES "${LAPACK_LIBRARIES}") if( TARGET ttg-parsec ) diff --git a/src/TiledArray/math/libxsmm_gemm.cpp b/src/TiledArray/math/libxsmm_gemm.cpp new file mode 100644 index 0000000000..f2f34bbe8d --- /dev/null +++ b/src/TiledArray/math/libxsmm_gemm.cpp @@ -0,0 +1,108 @@ +/// \file libxsmm_gemm.cpp +/// The ONLY translation unit that includes . Isolating the libxsmm +/// include here keeps its macros (libxsmm_macros.h) from leaking into any +/// TiledArray header. When TILEDARRAY_HAS_LIBXSMM is undefined, libxsmm_gemm_le64 +/// is still defined here as a `return false` stub so callers link unconditionally. + +#include "TiledArray/math/libxsmm_gemm.h" + +#ifdef TILEDARRAY_HAS_LIBXSMM +#include +#include +#include +#include +#endif + +namespace TiledArray::detail { + +/// max(M,N,K) cutoff above which we keep using the vendor BLAS. +static constexpr std::int64_t libxsmm_gemm_max_dim = 64; + +#ifdef TILEDARRAY_HAS_LIBXSMM +/// Runtime master switch for the libxsmm fast path. Even in a libxsmm-enabled +/// build, exporting `TA_LIBXSMM=0` (also accepts `off`/`OFF`/`false`/`no`) +/// routes EVERY strided micro-GEMM back through the vendor BLAS path, i.e. +/// libxsmm_gemm_le64() returns false for all shapes. Unset or any other value +/// => libxsmm ON. Read from the environment once, on first use, and cached. +static bool libxsmm_runtime_enabled() { + static const bool enabled = [] { + const char* v = std::getenv("TA_LIBXSMM"); + if (v == nullptr || *v == '\0') return true; // default ON when compiled in + return !(std::strcmp(v, "0") == 0 || std::strcmp(v, "off") == 0 || + std::strcmp(v, "OFF") == 0 || std::strcmp(v, "false") == 0 || + std::strcmp(v, "no") == 0); + }(); + return enabled; +} +#endif + +bool libxsmm_gemm_le64(bool trans_a, bool trans_b, std::int64_t m, + std::int64_t n, std::int64_t k, double alpha, + const double* a, std::int64_t lda, const double* b, + std::int64_t ldb, double beta, double* c, + std::int64_t ldc) { +#ifdef TILEDARRAY_HAS_LIBXSMM + // Runtime master switch: TA_LIBXSMM=0 sends everything back to vendor BLAS. + if (!libxsmm_runtime_enabled()) return false; + // libxsmm only for small shapes; max(M,N,K) <= 64. + if (m > libxsmm_gemm_max_dim || n > libxsmm_gemm_max_dim || + k > libxsmm_gemm_max_dim) + return false; + // libxsmm SMM has no alpha and only beta in {0,1} (LIBXSMM_GEMM_NO_BYPASS). + if (alpha != 1.0) return false; + if (beta != 0.0 && beta != 1.0) return false; + + static std::once_flag init_flag; + std::call_once(init_flag, [] { + // libxsmm's own verbose dispatch/JIT stats are part of the profiling + // result, so fold them into TA_PROFILE: when profiling is on and the user + // has NOT pinned LIBXSMM_VERBOSE explicitly, enable libxsmm verbosity here, + // BEFORE libxsmm_init() parses the environment. TA_PROFILE>=1 -> a concise + // exit summary (version + registry "gemm=" kernel count); TA_PROFILE>=2 + // -> verbose per-kernel JIT events. Must run before libxsmm_init(). + if (std::getenv("LIBXSMM_VERBOSE") == nullptr) { + const char* p = std::getenv("TA_PROFILE"); + const int lvl = (p != nullptr) ? std::atoi(p) : 0; + if (lvl >= 2) + setenv("LIBXSMM_VERBOSE", "3", /*overwrite=*/0); + else if (lvl >= 1) + setenv("LIBXSMM_VERBOSE", "2", /*overwrite=*/0); + } + libxsmm_init(); + }); + + // Mirror blas::gemm's row-major -> col-major mapping: it realizes the result + // as a column-major GEMM (op_b, op_a, n, m, k) with operands (b, a) swapped. + // libxsmm is column-major, so: A'=b (ld=ldb), B'=a (ld=lda), dims (n, m, k), + // TRANS_A from op_b, TRANS_B from op_a. + const libxsmm_bitfield flags = + (trans_b ? LIBXSMM_GEMM_FLAG_TRANS_A : 0) | + (trans_a ? LIBXSMM_GEMM_FLAG_TRANS_B : 0) | + (beta == 0.0 ? LIBXSMM_GEMM_FLAG_BETA_0 : 0); + + const libxsmm_gemm_shape shape = libxsmm_create_gemm_shape( + static_cast(n), static_cast(m), + static_cast(k), static_cast(ldb), + static_cast(lda), static_cast(ldc), + LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, + LIBXSMM_DATATYPE_F64); + + const libxsmm_gemmfunction kernel = libxsmm_dispatch_gemm( + shape, flags, static_cast(LIBXSMM_GEMM_PREFETCH_NONE)); + if (kernel == nullptr) return false; // shape not JIT-able -> fall back + + libxsmm_gemm_param param; + std::memset(¶m, 0, sizeof param); + param.a.primary = const_cast(b); // A' = b + param.b.primary = const_cast(a); // B' = a + param.c.primary = c; + kernel(¶m); + return true; +#else + (void)trans_a; (void)trans_b; (void)m; (void)n; (void)k; (void)alpha; + (void)a; (void)lda; (void)b; (void)ldb; (void)beta; (void)c; (void)ldc; + return false; +#endif +} + +} // namespace TiledArray::detail diff --git a/src/TiledArray/math/libxsmm_gemm.h b/src/TiledArray/math/libxsmm_gemm.h new file mode 100644 index 0000000000..188c20b76b --- /dev/null +++ b/src/TiledArray/math/libxsmm_gemm.h @@ -0,0 +1,37 @@ +#ifndef TILEDARRAY_MATH_LIBXSMM_GEMM_H__INCLUDED +#define TILEDARRAY_MATH_LIBXSMM_GEMM_H__INCLUDED + +/// \file libxsmm_gemm.h +/// Optional libxsmm fast path for small strided tensor-of-tensors micro-GEMMs. +/// This header is DECLARATION-ONLY: the implementation lives in libxsmm_gemm.cpp, +/// which is the single translation unit that includes . Keeping the +/// libxsmm include out of this header is deliberate -- pulls in +/// libxsmm_macros.h, whose macros otherwise leak into TiledArray headers that +/// transitively include this one (e.g. arena_einsum.h -> ... -> math/vector_op.h, +/// breaking detail::is_scalar_v). Callers just see a plain function. + +#include + +// N.B. namespace TiledArray::detail (NOT TiledArray::math::detail): introducing +// a TiledArray::math::detail namespace would hijack unqualified `detail::` name +// lookup inside TiledArray::math headers (e.g. vector_op.h's detail::is_scalar_v, +// which lives in TiledArray::detail), breaking their compilation. +namespace TiledArray::detail { + +/// Computes C(m x n) [+]= alpha * op_a(A) . op_b(B) in **row-major** layout with +/// leading dims lda/ldb/ldc, i.e. exactly TiledArray::math::blas::gemm (double) +/// semantics. \p trans_a / \p trans_b are the transpose flags (true == Trans). +/// +/// \return true iff libxsmm performed the GEMM. Returns false (caller must fall +/// back to blas::gemm) when: built without libxsmm (TILEDARRAY_HAS_LIBXSMM +/// undefined), the runtime switch `TA_LIBXSMM=0` is set, max(m,n,k) > 64, +/// alpha != 1, beta not in {0,1}, or libxsmm could not JIT this shape. +bool libxsmm_gemm_le64(bool trans_a, bool trans_b, std::int64_t m, + std::int64_t n, std::int64_t k, double alpha, + const double* a, std::int64_t lda, const double* b, + std::int64_t ldb, double beta, double* c, + std::int64_t ldc); + +} // namespace TiledArray::detail + +#endif // TILEDARRAY_MATH_LIBXSMM_GEMM_H__INCLUDED diff --git a/src/TiledArray/tensor/arena_einsum.h b/src/TiledArray/tensor/arena_einsum.h index f49b259c4b..ca7f526b50 100644 --- a/src/TiledArray/tensor/arena_einsum.h +++ b/src/TiledArray/tensor/arena_einsum.h @@ -5,6 +5,7 @@ #include "TiledArray/error.h" #include "TiledArray/math/gemm_helper.h" +#include "TiledArray/math/libxsmm_gemm.h" #include "TiledArray/permutation.h" #include "TiledArray/tensor/arena.h" #include "TiledArray/tensor/arena_kernels.h" @@ -1247,14 +1248,27 @@ void arena_strided_dgemm_ce_e(ResultOuter& C, const LeftOuter& L, { ScopedShapedGemmTimer _gt(g_gemm_ns_ce_e, g_gemm_calls_ce_e, g_ce_e_shapes, P, Q, K); - blas::gemm(blas::Transpose, blas::NoTranspose, - /*M=*/static_cast(P), - /*N=*/static_cast(Q), - /*K=*/static_cast(K), factor, - /*A=*/l0.data(), /*lda=*/static_cast(ldA), - /*B=*/r0.data(), /*ldb=*/static_cast(ldB), - /*beta=*/1.0, - /*C=*/Cc.data(), /*ldc=*/static_cast(Q)); + // libxsmm fast path when max(M,N,K)<=64 (and alpha==1); else the + // vendor-BLAS fallback. The ScopedShapedGemmTimer above wraps EITHER + // path, so wall/shape capture is backend-agnostic. + if (!TiledArray::detail::libxsmm_gemm_le64( + /*trans_a=*/true, /*trans_b=*/false, + /*M=*/static_cast(P), + /*N=*/static_cast(Q), + /*K=*/static_cast(K), factor, + /*A=*/l0.data(), /*lda=*/static_cast(ldA), + /*B=*/r0.data(), /*ldb=*/static_cast(ldB), + /*beta=*/1.0, + /*C=*/Cc.data(), /*ldc=*/static_cast(Q))) { + blas::gemm(blas::Transpose, blas::NoTranspose, + /*M=*/static_cast(P), + /*N=*/static_cast(Q), + /*K=*/static_cast(K), factor, + /*A=*/l0.data(), /*lda=*/static_cast(ldA), + /*B=*/r0.data(), /*ldb=*/static_cast(ldB), + /*beta=*/1.0, + /*C=*/Cc.data(), /*ldc=*/static_cast(Q)); + } } #ifdef TA_STRIDED_DGEMM_COUNT g_strided_dgemm_ce_e_calls.fetch_add(1, std::memory_order_relaxed); @@ -1539,17 +1553,31 @@ void arena_strided_dgemm_ce_ce_right(ResultOuter& C, const LeftOuter& L, { ScopedShapedGemmTimer _gt(g_gemm_ns_ce_ce, g_gemm_calls_ce_ce, g_ce_ce_shapes, Mseg, P, Q); - blas::gemm( - blas::NoTranspose, - left_inner_transposed ? blas::NoTranspose : blas::Transpose, - /*M=*/static_cast(Mseg), - /*N=*/static_cast(P), - /*K=*/static_cast(Q), factor, - /*A=*/rstart, /*lda=*/static_cast(ldR), - /*B=*/Lk, - /*ldb=*/static_cast(left_inner_transposed ? P : Q), - /*beta=*/1.0, - /*C=*/cstart, /*ldc=*/static_cast(ldC)); + // libxsmm fast path when max(M,N,K)<=64 (and alpha==1); else the + // vendor-BLAS fallback. The ScopedShapedGemmTimer wraps EITHER path. + if (!TiledArray::detail::libxsmm_gemm_le64( + /*trans_a=*/false, + /*trans_b=*/!left_inner_transposed, + /*M=*/static_cast(Mseg), + /*N=*/static_cast(P), + /*K=*/static_cast(Q), factor, + /*A=*/rstart, /*lda=*/static_cast(ldR), + /*B=*/Lk, + /*ldb=*/static_cast(left_inner_transposed ? P : Q), + /*beta=*/1.0, + /*C=*/cstart, /*ldc=*/static_cast(ldC))) { + blas::gemm( + blas::NoTranspose, + left_inner_transposed ? blas::NoTranspose : blas::Transpose, + /*M=*/static_cast(Mseg), + /*N=*/static_cast(P), + /*K=*/static_cast(Q), factor, + /*A=*/rstart, /*lda=*/static_cast(ldR), + /*B=*/Lk, + /*ldb=*/static_cast(left_inner_transposed ? P : Q), + /*beta=*/1.0, + /*C=*/cstart, /*ldc=*/static_cast(ldC)); + } } #ifdef TA_STRIDED_DGEMM_COUNT g_strided_dgemm_ce_ce_right_calls.fetch_add(1, @@ -1775,17 +1803,31 @@ void arena_strided_dgemm_ce_ce_left(ResultOuter& C, const LeftOuter& L, { ScopedShapedGemmTimer _gt(g_gemm_ns_ce_ce, g_gemm_calls_ce_ce, g_ce_ce_shapes, Mseg, P, Q); - blas::gemm( - blas::NoTranspose, - right_inner_transposed ? blas::Transpose : blas::NoTranspose, - /*M=*/static_cast(Mseg), - /*N=*/static_cast(P), - /*K=*/static_cast(Q), factor, - /*A=*/lstart, /*lda=*/static_cast(ldA), - /*B=*/Rk, - /*ldb=*/static_cast(right_inner_transposed ? Q : P), - /*beta=*/1.0, - /*C=*/cstart, /*ldc=*/static_cast(ldC)); + // libxsmm fast path when max(M,N,K)<=64 (and alpha==1); else the + // vendor-BLAS fallback. The ScopedShapedGemmTimer wraps EITHER path. + if (!TiledArray::detail::libxsmm_gemm_le64( + /*trans_a=*/false, + /*trans_b=*/right_inner_transposed, + /*M=*/static_cast(Mseg), + /*N=*/static_cast(P), + /*K=*/static_cast(Q), factor, + /*A=*/lstart, /*lda=*/static_cast(ldA), + /*B=*/Rk, + /*ldb=*/static_cast(right_inner_transposed ? Q : P), + /*beta=*/1.0, + /*C=*/cstart, /*ldc=*/static_cast(ldC))) { + blas::gemm( + blas::NoTranspose, + right_inner_transposed ? blas::Transpose : blas::NoTranspose, + /*M=*/static_cast(Mseg), + /*N=*/static_cast(P), + /*K=*/static_cast(Q), factor, + /*A=*/lstart, /*lda=*/static_cast(ldA), + /*B=*/Rk, + /*ldb=*/static_cast(right_inner_transposed ? Q : P), + /*beta=*/1.0, + /*C=*/cstart, /*ldc=*/static_cast(ldC)); + } } #ifdef TA_STRIDED_DGEMM_COUNT g_strided_dgemm_ce_ce_left_calls.fetch_add(1, diff --git a/src/TiledArray/tensor/tensor.cpp b/src/TiledArray/tensor/tensor.cpp index 553f08d708..0d391b34c3 100644 --- a/src/TiledArray/tensor/tensor.cpp +++ b/src/TiledArray/tensor/tensor.cpp @@ -36,3 +36,19 @@ template class Tensor>; template class Tensor>; } // namespace TiledArray + +// --------------------------------------------------------------------------- +// libxsmm fast-path wrapper for the scale strided GEMM (declared in tensor.h). +// Kept out of tensor.h so the scale call sites need no libxsmm types; forwards +// to detail::libxsmm_gemm_le64 (declared in the lean libxsmm_gemm.h, +// defined in libxsmm_gemm.cpp -- the only TU that includes ). +#include "TiledArray/math/libxsmm_gemm.h" + +namespace TiledArray::detail { +bool scale_libxsmm_dgemm(bool trans_a, bool trans_b, long m, long n, long k, + const double* a, long lda, const double* b, long ldb, + double beta, double* c, long ldc) { + return TiledArray::detail::libxsmm_gemm_le64( + trans_a, trans_b, m, n, k, /*alpha=*/1.0, a, lda, b, ldb, beta, c, ldc); +} +} // namespace TiledArray::detail diff --git a/src/TiledArray/tensor/tensor.h b/src/TiledArray/tensor/tensor.h index 998b764ba2..3cd236646d 100644 --- a/src/TiledArray/tensor/tensor.h +++ b/src/TiledArray/tensor/tensor.h @@ -137,6 +137,14 @@ struct ScaleRegimeCounters { }; inline ScaleRegimeCounters g_scale[2]; // [0]=tot_x_t, [1]=t_x_tot +/// Non-inline wrapper around detail::libxsmm_gemm_le64 (double), DEFINED +/// in tensor.cpp. Lets the scale path use the libxsmm fast path WITHOUT pulling +/// into the heavily-included tensor.h. Returns true iff libxsmm ran +/// the GEMM; false => the caller must fall back to blas::gemm. +bool scale_libxsmm_dgemm(bool trans_a, bool trans_b, long m, long n, long k, + const double* a, long lda, const double* b, long ldb, + double beta, double* c, long ldc); + /// Manual (non-scoped) phase clock for regions that set locals used later, so a /// timed scope can't wrap them. No-op unless TA_GEMM_TIMING is set. Mirrors the /// arena_einsum.h phase_start/phase_stop pattern. @@ -3422,13 +3430,27 @@ class Tensor { } const integer Ai = static_cast(A); detail::ScopedScaleTimer _scale_gt(detail::g_scale[0].gemm_ns); - TiledArray::math::blas::gemm( - TiledArray::math::blas::Transpose, - TiledArray::math::blas::NoTranspose, - /*M=*/N, /*N=*/Ai, /*K=*/K, Real(1), - /*A=*/right_data, /*lda=*/N, - /*B=*/lc0[0].data(), /*ldb=*/ldb, Real(1), - /*C=*/rc0[0].data(), /*ldc=*/ldc); + // libxsmm fast path when max(M,N,K)<=64 (alpha==1, beta in + // {0,1}); else vendor BLAS. The timer above wraps EITHER path. + // double only. + bool _xsmm = false; + if constexpr (std::is_same_v) { + _xsmm = detail::scale_libxsmm_dgemm( + /*trans_a=*/true, /*trans_b=*/false, + /*m=*/N, /*n=*/Ai, /*k=*/K, + /*a=*/right_data, /*lda=*/N, + /*b=*/lc0[0].data(), /*ldb=*/ldb, /*beta=*/1.0, + /*c=*/rc0[0].data(), /*ldc=*/ldc); + } + if (!_xsmm) { + TiledArray::math::blas::gemm( + TiledArray::math::blas::Transpose, + TiledArray::math::blas::NoTranspose, + /*M=*/N, /*N=*/Ai, /*K=*/K, Real(1), + /*A=*/right_data, /*lda=*/N, + /*B=*/lc0[0].data(), /*ldb=*/ldb, Real(1), + /*C=*/rc0[0].data(), /*ldc=*/ldc); + } } else { // per-cell AXPY fallback for this row if (detail::scale_gemm_timing_enabled()) { // classify fallback reason (re-scan; observation only, does @@ -3586,13 +3608,27 @@ class Tensor { } const integer Ai = static_cast(A); detail::ScopedScaleTimer _scale_gt(detail::g_scale[1].gemm_ns); - TiledArray::math::blas::gemm( - TiledArray::math::blas::NoTranspose, - TiledArray::math::blas::NoTranspose, - /*M=*/M, /*N=*/Ai, /*K=*/K, Real(1), - /*A=*/left_data, /*lda=*/K, - /*B=*/right_data[n].data(), /*ldb=*/ldb, Real(1), - /*C=*/this_data[n].data(), /*ldc=*/ldc); + // libxsmm fast path when max(M,N,K)<=64 (alpha==1, beta in + // {0,1}); else vendor BLAS. The timer above wraps EITHER path. + // double only. + bool _xsmm = false; + if constexpr (std::is_same_v) { + _xsmm = detail::scale_libxsmm_dgemm( + /*trans_a=*/false, /*trans_b=*/false, + /*m=*/M, /*n=*/Ai, /*k=*/K, + /*a=*/left_data, /*lda=*/K, + /*b=*/right_data[n].data(), /*ldb=*/ldb, /*beta=*/1.0, + /*c=*/this_data[n].data(), /*ldc=*/ldc); + } + if (!_xsmm) { + TiledArray::math::blas::gemm( + TiledArray::math::blas::NoTranspose, + TiledArray::math::blas::NoTranspose, + /*M=*/M, /*N=*/Ai, /*K=*/K, Real(1), + /*A=*/left_data, /*lda=*/K, + /*B=*/right_data[n].data(), /*ldb=*/ldb, Real(1), + /*C=*/this_data[n].data(), /*ldc=*/ldc); + } } else { // per-cell AXPY fallback for this column if (detail::scale_gemm_timing_enabled()) { // classify fallback reason (re-scan; observation only) + From 0f34462c23745660fe3c48b1cfae7923b7d041aa Mon Sep 17 00:00:00 2001 From: Zhihao Deng Date: Mon, 22 Jun 2026 16:42:25 -0400 Subject: [PATCH 2/2] PR agent fix: libxsmm install/export, ld guard, TA_LIBXSMM=ON test gate - Install libxsmm.a+headers into TA's prefix; split TiledArray_LIBXSMM into BUILD/INSTALL interfaces so the exported config has no build-tree leak. - Guard 64->32-bit narrowing of lda/ldb/ldc in libxsmm_gemm_le64. - Make the libxsmm sub-make parallelism configurable (LIBXSMM_BUILD_NJOBS). - Add a TA_LIBXSMM=1 CTest gate + a direct scale_libxsmm_dgemm numerical test. --- external/libxsmm.cmake | 35 ++++++++++++++++--- src/TiledArray/math/libxsmm_gemm.cpp | 6 ++++ tests/CMakeLists.txt | 12 +++++++ tests/arena_strided_dgemm.cpp | 52 ++++++++++++++++++++++++++++ 4 files changed, 100 insertions(+), 5 deletions(-) diff --git a/external/libxsmm.cmake b/external/libxsmm.cmake index 97cbab432e..77c28f55a7 100644 --- a/external/libxsmm.cmake +++ b/external/libxsmm.cmake @@ -70,6 +70,10 @@ else () set(_libxsmm_cxx ${CMAKE_CXX_COMPILER}) endif () + # parallelism for the libxsmm sub-make (overridable; default mirrors the + # project's typical build budget rather than hardcoding into the command) + set(LIBXSMM_BUILD_NJOBS 6 CACHE STRING "Parallel jobs for the libxsmm sub-make") + # libxsmm Make knobs: # STATIC=1 build libxsmm.a (we link the archive into tiledarray) # FORTRAN=0 skip the Fortran interface (no gfortran needed) @@ -93,7 +97,7 @@ else () BUILD_IN_SOURCE 1 CONFIGURE_COMMAND "" #--Build step----------------- build + install in one make invocation - BUILD_COMMAND make -j6 STATIC=1 FORTRAN=0 BLAS=0 + BUILD_COMMAND make -j${LIBXSMM_BUILD_NJOBS} STATIC=1 FORTRAN=0 BLAS=0 CC=${_libxsmm_cc} CXX=${_libxsmm_cxx} AR=${CMAKE_AR} PREFIX=${_LIBXSMM_INSTALL_DIR} install BUILD_BYPRODUCTS ${LIBXSMM_BUILD_BYPRODUCTS} @@ -112,16 +116,37 @@ else () endif (_LIBXSMM_INSTALL_DIR) +# Fold libxsmm's static lib + headers into TiledArray's OWN install prefix, so +# the exported TiledArray config is self-contained and does not reference TA's +# build tree (which a downstream find_package(TiledArray) consumer like MPQC +# would otherwise link against -- and which breaks once the build tree is +# wiped/relocated). Done for both the fetched and the prebuilt cases so the TA +# install is identical either way. +install(FILES "${_LIBXSMM_INSTALL_DIR}/lib/libxsmm.a" + DESTINATION "${TILEDARRAY_INSTALL_LIBDIR}" COMPONENT tiledarray) +install(DIRECTORY "${_LIBXSMM_INSTALL_DIR}/include/" + DESTINATION "${TILEDARRAY_INSTALL_INCLUDEDIR}" COMPONENT tiledarray) + # Synthetic target carrying the include dir, the static archive, and the gating -# define. PUBLIC propagation (via _TILEDARRAY_DEPENDENCIES) makes -# TILEDARRAY_HAS_LIBXSMM + the include path visible to consumers (e.g. MPQC). +# define. PUBLIC propagation (via _TILEDARRAY_DEPENDENCIES) makes the libxsmm.a +# link requirement reach consumers (libtiledarray is a static archive, so its +# undefined libxsmm symbols are resolved at the consumer's final link), plus +# TILEDARRAY_HAS_LIBXSMM + the include path. The link/include paths are split +# into BUILD_INTERFACE (TA's build tree) and INSTALL_INTERFACE (the installed +# copy above), so the exported config never references the build tree. The +# include INSTALL_INTERFACE is relative (CMake prepends the import prefix); the +# link library INSTALL_INTERFACE must be an absolute path to the installed +# archive -- CMake does NOT prepend the import prefix to INTERFACE_LINK_LIBRARIES +# entries, so a relative path there would be resolved against the consumer's cwd +# and fail to link. Absolute CMAKE_INSTALL_PREFIX is leak-free (the install tree +# is the stable final location, unlike the build tree). add_library(TiledArray_LIBXSMM INTERFACE) set_target_properties(TiledArray_LIBXSMM PROPERTIES INTERFACE_INCLUDE_DIRECTORIES - "$;$" + "$;$" INTERFACE_LINK_LIBRARIES - "${_LIBXSMM_INSTALL_DIR}/lib/libxsmm.a;${CMAKE_DL_LIBS}" + "$;$;${CMAKE_DL_LIBS}" INTERFACE_COMPILE_DEFINITIONS "TILEDARRAY_HAS_LIBXSMM" ) diff --git a/src/TiledArray/math/libxsmm_gemm.cpp b/src/TiledArray/math/libxsmm_gemm.cpp index f2f34bbe8d..00705c580e 100644 --- a/src/TiledArray/math/libxsmm_gemm.cpp +++ b/src/TiledArray/math/libxsmm_gemm.cpp @@ -9,6 +9,7 @@ #ifdef TILEDARRAY_HAS_LIBXSMM #include #include +#include #include #include #endif @@ -51,6 +52,11 @@ bool libxsmm_gemm_le64(bool trans_a, bool trans_b, std::int64_t m, // libxsmm SMM has no alpha and only beta in {0,1} (LIBXSMM_GEMM_NO_BYPASS). if (alpha != 1.0) return false; if (beta != 0.0 && beta != 1.0) return false; + // libxsmm_blasint is 32-bit; refuse leading dims that would narrow silently. + // (M,N,K are already <=64; lda/ldb/ldc are strides and unbounded in general.) + constexpr std::int64_t bi_max = + static_cast(std::numeric_limits::max()); + if (lda > bi_max || ldb > bi_max || ldc > bi_max) return false; static std::once_flag init_flag; std::call_once(init_flag, [] { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index b6c4f2bb34..f5221b836f 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -192,6 +192,18 @@ else() ) endif() +# When libxsmm is enabled, also run the strided-GEMM suites with the libxsmm +# fast path forced ON (TA_LIBXSMM=1), so its numerics are gated against the +# in-suite references (the default-OFF run above does not exercise libxsmm). +if (TA_LIBXSMM) + add_test(NAME tiledarray/unit/run-libxsmm + COMMAND $ --log_level=unit_scope + --run_test=arena_strided_dgemm_suite:arena_einsum_unit_suite) + set_tests_properties(tiledarray/unit/run-libxsmm + PROPERTIES FIXTURES_REQUIRED TA_UNIT_TESTS_EXEC + ENVIRONMENT "${TA_UNIT_TESTS_ENVIRONMENT};TA_LIBXSMM=1") +endif() + if (NOT TARGET test-cases-tiledarray) add_custom_target_subproject(tiledarray test-cases) endif() diff --git a/tests/arena_strided_dgemm.cpp b/tests/arena_strided_dgemm.cpp index bd99b12f43..ee716a951b 100644 --- a/tests/arena_strided_dgemm.cpp +++ b/tests/arena_strided_dgemm.cpp @@ -2109,4 +2109,56 @@ BOOST_AUTO_TEST_CASE(ce_ce_seg_killswitch_matches_left) { } } +// Directly validate the scale-path libxsmm entry point +// (TiledArray::detail::scale_libxsmm_dgemm) for BOTH scale regimes' transpose +// configs -- tot_x_t maps to (trans_a=true, trans_b=false) and t_x_tot maps to +// (trans_a=false, trans_b=false). The ce+e / ce+ce arena tests above already +// exercise libxsmm_gemm_le64 for these same transpose patterns, but the scale +// wrapper itself (tensor.cpp) is otherwise only reachable through Tensor::gemm; +// this gates its numerics standalone. When libxsmm is not active (built without +// TILEDARRAY_HAS_LIBXSMM, or TA_LIBXSMM=0) the wrapper returns false and writes +// nothing, so the numeric check is keyed on the returned flag. +BOOST_AUTO_TEST_CASE(scale_path_libxsmm_matches_reference) { + // Row-major C(m x n) = beta*C + op_a(A)(m x k) . op_b(B)(k x n), alpha=1, + // matching scale_libxsmm_dgemm / blas::gemm semantics. + auto ref_gemm = [](bool ta, bool tb, long m, long n, long k, + const std::vector& A, long lda, + const std::vector& B, long ldb, double beta, + std::vector& C, long ldc) { + for (long i = 0; i < m; ++i) + for (long j = 0; j < n; ++j) { + double acc = 0.0; + for (long p = 0; p < k; ++p) { + const double av = ta ? A[p * lda + i] : A[i * lda + p]; + const double bv = tb ? B[j * ldb + p] : B[p * ldb + j]; + acc += av * bv; + } + C[i * ldc + j] = beta * C[i * ldc + j] + acc; + } + }; + struct Cfg { bool ta, tb; long m, n, k; }; + for (const Cfg cfg : {Cfg{true, false, 5, 7, 4}, // tot_x_t shape + Cfg{false, false, 6, 3, 8}}) { // t_x_tot shape + const long m = cfg.m, n = cfg.n, k = cfg.k; + const long lda = cfg.ta ? m : k; // A is (k x m) if trans_a else (m x k) + const long ldb = cfg.tb ? k : n; // B is (n x k) if trans_b else (k x n) + const long ldc = n; + std::vector A(static_cast(m) * k), + B(static_cast(k) * n), + C(static_cast(m) * n, 0.3); + for (std::size_t i = 0; i < A.size(); ++i) A[i] = 0.1 * double(i) + 0.5; + for (std::size_t i = 0; i < B.size(); ++i) B[i] = 0.2 * double(i) - 0.3; + std::vector Cref = C; + ref_gemm(cfg.ta, cfg.tb, m, n, k, A, lda, B, ldb, /*beta=*/1.0, Cref, ldc); + const bool ran = TA::detail::scale_libxsmm_dgemm( + cfg.ta, cfg.tb, m, n, k, A.data(), lda, B.data(), ldb, /*beta=*/1.0, + C.data(), ldc); + if (ran) { + for (long i = 0; i < m * n; ++i) + BOOST_CHECK_CLOSE(C[static_cast(i)], + Cref[static_cast(i)], 1e-10); + } + } +} + BOOST_AUTO_TEST_SUITE_END()