diff --git a/CMakeLists.txt b/CMakeLists.txt index 8f85893c47..d9647031ef 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -405,6 +405,18 @@ endif(USE_KML) if (USE_SW) add_compile_definitions(__SW) + # IEEE-compliant FP: required for swFFT/DFTI numerical correctness on Sunway. + include(CheckCXXCompilerFlag) + check_cxx_compiler_flag("-mieee" ABACUS_CXX_SUPPORTS_MIEEE) + if(ABACUS_CXX_SUPPORTS_MIEEE) + add_compile_options(-mieee) + endif() + # CPE-DFTI: accelerate FFT_CPU local 1D sticks FFTs via swFFT xMath DFTI (CPE). + # Needs the isolated libswfft_xmath_iso.a linked below. Disable with -DUSE_SWDFTI=OFF. + option(USE_SWDFTI "Use swFFT CPE DFTI for local 1D FFTs" ON) + if(USE_SWDFTI) + add_compile_definitions(__SWDFTI) + endif() set(SW ON) include_directories(${SW_MATH}/include) include_directories(${SW_FFT}/include) @@ -889,7 +901,17 @@ if(ENABLE_RAPIDJSON) endif() if (USE_SW) - target_link_libraries(${ABACUS_BIN_NAME} ${SW_MATH}/libswfft.a) + # CPE-DFTI engine: only needed when the SWDFTI backend is actually compiled in. + # Link the objcopy-ISOLATED xMath swfft (fftw_* renamed -> swfftpriv_*) so it + # provides DftiInitAthread/Compute* WITHOUT hijacking ABACUS FFTW (raw + # ${SW_MATH}/libswfft.a breaks the density FFT). Build per HOWTO. + if(USE_SWDFTI) + set(_swfft_iso ${CMAKE_CURRENT_SOURCE_DIR}/source/source_base/module_fft/libswfft_xmath_iso.a) + if(NOT EXISTS ${_swfft_iso}) + message(FATAL_ERROR "USE_SWDFTI=ON but ${_swfft_iso} is missing; build it per HOWTO or set -DUSE_SWDFTI=OFF.") + endif() + target_link_libraries(${ABACUS_BIN_NAME} ${_swfft_iso}) + endif() target_link_libraries(${ABACUS_BIN_NAME} ${SW_MATH}/libswscalapack.a) target_link_libraries(${ABACUS_BIN_NAME} ${SW_MATH}/libswlapack.a) target_link_libraries(${ABACUS_BIN_NAME} ${SW_MATH}/libswblas.a) diff --git a/source/source_base/module_fft/fft_bundle.cpp b/source/source_base/module_fft/fft_bundle.cpp index a1292c34e4..e65dbc2047 100644 --- a/source/source_base/module_fft/fft_bundle.cpp +++ b/source/source_base/module_fft/fft_bundle.cpp @@ -11,6 +11,9 @@ #if defined(__ROCM) #include "fft_rocm.h" #endif +#if defined(__SWDFTI) +#include "fft_swdfti.h" // CPE-DFTI CPU backend (Sunway) +#endif #if defined(__DSP) #include "fft_dsp.h" #endif @@ -90,7 +93,11 @@ void FFT_Bundle::initfft(int nx_in, } if (double_flag) { +#if defined(__SWDFTI) + fft_double = make_unique>(this->fft_mode); // CPE-DFTI sticks FFT +#else fft_double = make_unique>(this->fft_mode); +#endif fft_double ->initfft(nx_in, ny_in, nz_in, lixy_in, rixy_in, ns_in, nplane_in, nproc_in, gamma_only_in, xprime_in); } diff --git a/source/source_base/module_fft/fft_cpu.h b/source/source_base/module_fft/fft_cpu.h index ec47768d8e..7f7c0582aa 100644 --- a/source/source_base/module_fft/fft_cpu.h +++ b/source/source_base/module_fft/fft_cpu.h @@ -98,7 +98,7 @@ class FFT_CPU : public FFT_BASE ABACUS_FFT_WEAK void fftxyc2r(std::complex* in, FPTYPE* out) const override; - private: + protected: // exposed so FFT_SWDFTI (CPE DFTI) can reuse plans/dims void clearfft(fftw_plan& plan); void clearfft(fftwf_plan& plan); diff --git a/source/source_base/module_fft/fft_swdfti.cpp b/source/source_base/module_fft/fft_swdfti.cpp new file mode 100644 index 0000000000..2ec6b076d6 --- /dev/null +++ b/source/source_base/module_fft/fft_swdfti.cpp @@ -0,0 +1,122 @@ +#include "fft_swdfti.h" + +#include +#include +#include +extern "C" { +#include "swfft.h" // xMath-SACA swFFT DFTI API (CPE) +} + +namespace ModuleBase +{ + +template <> +void FFT_SWDFTI::setupFFT() +{ + // build all the FFTW plans / buffers exactly as the CPU backend does + FFT_CPU::setupFFT(); + + if (std::getenv("ABACUS_NO_DFTI") != nullptr) { return; } // A/B: keep FFTW + + // thread-safe one-time CPE spawn (setupFFT may be reached from >1 thread) + static std::once_flag dfti_athread_once; + std::call_once(dfti_athread_once, []() { DftiInitAthread(DFTI_SPAWN_QUICK); }); + + // batched 1D-z: ns transforms of length nz, contiguous (stride 1, distance nz), in-place + DFTI_DESCRIPTOR_HANDLE hz = nullptr; + DftiCreateDescriptor(&hz, DFTI_DOUBLE, DFTI_COMPLEX, 1, (DFTI_LONG)this->nz); + DftiSetValue(hz, DFTI_NUMBER_OF_TRANSFORMS, (DFTI_LONG)this->ns); + DftiSetValue(hz, DFTI_INPUT_DISTANCE, (DFTI_LONG)this->nz); + DftiSetValue(hz, DFTI_OUTPUT_DISTANCE, (DFTI_LONG)this->nz); + DftiSetValue(hz, DFTI_PLACEMENT, (DFTI_LONG)DFTI_INPLACE); + DftiCommitDescriptor(hz); + this->dftiz = (void*)hz; + + // strided 1D-x: nx-length, (nplane*ny) transforms, stride npy, distance 1 + // (only the xprime / non-gamma k-point layout). y stays on FFTW. + if (this->xprime && !this->gamma_only) + { + const int npy_ = this->nplane * this->ny; + DFTI_DESCRIPTOR_HANDLE hx = nullptr; + DftiCreateDescriptor(&hx, DFTI_DOUBLE, DFTI_COMPLEX, 1, (DFTI_LONG)this->nx); + DftiSetValue(hx, DFTI_NUMBER_OF_TRANSFORMS, (DFTI_LONG)npy_); + { DFTI_LONG st[2] = {0, (DFTI_LONG)npy_}; DftiSetValue(hx, DFTI_INPUT_STRIDES, st); DftiSetValue(hx, DFTI_OUTPUT_STRIDES, st); } + DftiSetValue(hx, DFTI_INPUT_DISTANCE, (DFTI_LONG)1); + DftiSetValue(hx, DFTI_OUTPUT_DISTANCE, (DFTI_LONG)1); + DftiSetValue(hx, DFTI_PLACEMENT, (DFTI_LONG)DFTI_INPLACE); + DftiCommitDescriptor(hx); + this->dftix = (void*)hx; + } +} + +template <> +void FFT_SWDFTI::cleanFFT() +{ + FFT_CPU::cleanFFT(); + // release the DFTI descriptors before dropping the handles (else they leak) + if (this->dftiz != nullptr) + { + DFTI_DESCRIPTOR_HANDLE hz = (DFTI_DESCRIPTOR_HANDLE)this->dftiz; + DftiFreeDescriptor(&hz); + this->dftiz = nullptr; + } + if (this->dftix != nullptr) + { + DFTI_DESCRIPTOR_HANDLE hx = (DFTI_DESCRIPTOR_HANDLE)this->dftix; + DftiFreeDescriptor(&hx); + this->dftix = nullptr; + } +} + +template <> +void FFT_SWDFTI::fftzfor(std::complex* in, std::complex* out) const +{ + if (this->dftiz == nullptr) { FFT_CPU::fftzfor(in, out); return; } + if (in != out) std::memcpy(out, in, sizeof(std::complex) * (size_t)this->nz * (size_t)this->ns); + DftiComputeForward((DFTI_DESCRIPTOR_HANDLE)this->dftiz, (void*)out); +} + +template <> +void FFT_SWDFTI::fftzbac(std::complex* in, std::complex* out) const +{ + if (this->dftiz == nullptr) { FFT_CPU::fftzbac(in, out); return; } + if (in != out) std::memcpy(out, in, sizeof(std::complex) * (size_t)this->nz * (size_t)this->ns); + DftiComputeBackward((DFTI_DESCRIPTOR_HANDLE)this->dftiz, (void*)out); +} + +template <> +void FFT_SWDFTI::fftxyfor(std::complex* in, std::complex* out) const +{ + const int npy = this->nplane * this->ny; + if (this->xprime && this->dftix != nullptr) + { + if (in != out) std::memcpy(out, in, sizeof(std::complex) * (size_t)this->nx * (size_t)npy); + DftiComputeForward((DFTI_DESCRIPTOR_HANDLE)this->dftix, (void*)out); // x via CPE + for (int i = 0; i < this->lixy + 1; ++i) // y via FFTW + fftw_execute_dft(this->planyfor, (fftw_complex*)&out[i * npy], (fftw_complex*)&out[i * npy]); + for (int i = this->rixy; i < this->nx; ++i) + fftw_execute_dft(this->planyfor, (fftw_complex*)&out[i * npy], (fftw_complex*)&out[i * npy]); + return; + } + FFT_CPU::fftxyfor(in, out); // non-xprime / disabled -> FFTW +} + +template <> +void FFT_SWDFTI::fftxybac(std::complex* in, std::complex* out) const +{ + const int npy = this->nplane * this->ny; + if (this->xprime && this->dftix != nullptr) + { + if (in != out) std::memcpy(out, in, sizeof(std::complex) * (size_t)this->nx * (size_t)npy); + for (int i = 0; i < this->lixy + 1; ++i) // y via FFTW + fftw_execute_dft(this->planybac, (fftw_complex*)&out[i * npy], (fftw_complex*)&out[i * npy]); + for (int i = this->rixy; i < this->nx; ++i) + fftw_execute_dft(this->planybac, (fftw_complex*)&out[i * npy], (fftw_complex*)&out[i * npy]); + DftiComputeBackward((DFTI_DESCRIPTOR_HANDLE)this->dftix, (void*)out); // x via CPE + return; + } + FFT_CPU::fftxybac(in, out); // non-xprime / disabled -> FFTW +} + +template class FFT_SWDFTI; +} // namespace ModuleBase diff --git a/source/source_base/module_fft/fft_swdfti.h b/source/source_base/module_fft/fft_swdfti.h new file mode 100644 index 0000000000..20afe2ca4f --- /dev/null +++ b/source/source_base/module_fft/fft_swdfti.h @@ -0,0 +1,38 @@ +#ifndef FFT_SWDFTI_H +#define FFT_SWDFTI_H +// CPE-accelerated CPU FFT backend: subclasses FFT_CPU and overrides only the +// local 1D sticks FFTs (batched z, strided x) with the Sunway swFFT xMath DFTI +// API (offloaded to the 64 CPEs via DftiInitAthread). Everything else (plan +// setup, 2D-xy y-direction, r2c/c2r, box 3D) is inherited from FFT_CPU/FFTW. +// Compiled only on Sunway (USE_SWDFTI) and selected by the FFT factory in +// FFT_Bundle for device "cpu" -- so FFT_CPU itself stays free of any DFTI macro. +#include + +#include "fft_cpu.h" + +namespace ModuleBase +{ +template +class FFT_SWDFTI : public FFT_CPU +{ + public: + FFT_SWDFTI() {}; + FFT_SWDFTI(const int fft_mode_in) : FFT_CPU(fft_mode_in) {}; + ~FFT_SWDFTI() {}; + + void setupFFT() override; + void cleanFFT() override; + + void fftzfor(std::complex* in, std::complex* out) const override; + void fftzbac(std::complex* in, std::complex* out) const override; + void fftxyfor(std::complex* in, std::complex* out) const override; + void fftxybac(std::complex* in, std::complex* out) const override; + + private: + // swFFT DFTI descriptors: z (batched ns x nz contiguous) and x (strided). + // y stays on FFTW (CPE loses on the small per-slice y-batch). null => FFTW. + void* dftiz = nullptr; + void* dftix = nullptr; +}; +} // namespace ModuleBase +#endif diff --git a/source/source_basis/module_pw/CMakeLists.txt b/source/source_basis/module_pw/CMakeLists.txt index 912772e057..dfc415f925 100644 --- a/source/source_basis/module_pw/CMakeLists.txt +++ b/source/source_basis/module_pw/CMakeLists.txt @@ -20,6 +20,12 @@ if (USE_DSP) pw_transform_k_dsp.cpp) endif() +if (USE_SWDFTI) + list (APPEND FFT_SRC + ../../source_base/module_fft/fft_swdfti.cpp + ) +endif() + list(APPEND objects pw_basis.cpp pw_basis_k.cpp