diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 30c2de3..662f3a4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -60,8 +60,5 @@ jobs: html=$(find -L result/share/doc -type d -name html | head -1) echo "HADDOCK_DIR=$html" >> "$GITHUB_ENV" - - name: Deploy to GitHub Pages - uses: peaceiris/actions-gh-pages@v4 - with: - github_token: ${{ secrets.GITHUB_TOKEN }} - publish_dir: ${{ env.HADDOCK_DIR }} + - name: Build and run tests + run: nix develop --command bash -c 'cabal install && cabal test' diff --git a/.gitignore b/.gitignore index aee1772..d36b981 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ result/ cabal.project.local tags /.stack-work/ +/.ghc.environment* diff --git a/README.md b/README.md index 7e2e104..8511202 100644 --- a/README.md +++ b/README.md @@ -53,25 +53,25 @@ cd arrayfire-haskell To build and run all tests in response to file changes ```bash -nix-shell --run test-runner +nix develop --command cabal test ``` To perform interactive development w/ `ghcid` ```bash -nix-shell --run ghcid +nix develop --command cabal repl ``` To interactively evaluate code in the `repl` ```bash -nix-shell --run repl +nix develop --command cabal repl ``` To produce the haddocks and open them in a browser ```bash -nix-shell --run docs +nix develop --command cabal haddock ``` @@ -84,9 +84,12 @@ import qualified ArrayFire as A import Control.Exception (catch) main :: IO () -main = print newArray `catch` (\(e :: A.AFException) -> print e) - where - newArray = A.matrix @Double (2,2) [ [1..], [1..] ] * A.matrix @Double (2,2) [ [2..], [2..] ] +main = withArrayFire $ do + print newArray `catch` (\(e :: A.AFException) -> print e) + where + newArray = + A.matrix @Double (2,2) [ [1..], [1..] ] * + A.matrix @Double (2,2) [ [2..], [2..] ] {-| diff --git a/arrayfire.cabal b/arrayfire.cabal index d7474af..8bf85c4 100644 --- a/arrayfire.cabal +++ b/arrayfire.cabal @@ -1,6 +1,6 @@ cabal-version: 3.0 name: arrayfire -version: 0.7.1.0 +version: 0.8.0.0 synopsis: Haskell bindings to the ArrayFire general-purpose GPU library homepage: https://github.com/arrayfire/arrayfire-haskell license: BSD-3-Clause @@ -41,6 +41,8 @@ library ArrayFire.Backend ArrayFire.BLAS ArrayFire.Data + ArrayFire.Exception + ArrayFire.Internal.Defines ArrayFire.Device ArrayFire.Features ArrayFire.Graphics @@ -56,7 +58,6 @@ library ArrayFire.Vision other-modules: ArrayFire.FFI - ArrayFire.Exception ArrayFire.Orphans ArrayFire.Internal.Algorithm ArrayFire.Internal.Arith @@ -64,7 +65,6 @@ library ArrayFire.Internal.Backend ArrayFire.Internal.BLAS ArrayFire.Internal.Data - ArrayFire.Internal.Defines ArrayFire.Internal.Device ArrayFire.Internal.Exception ArrayFire.Internal.Features @@ -87,6 +87,7 @@ library af c-sources: cbits/wrapper.c + cbits/eigsh.c build-depends: base < 5, deepseq, filepath, vector hs-source-dirs: @@ -156,6 +157,7 @@ test-suite test HUnit, QuickCheck, quickcheck-classes, + semirings, vector, call-stack >=0.4 && <0.5 if !flag(disable-build-tool-depends) @@ -172,11 +174,13 @@ test-suite test ArrayFire.BackendSpec ArrayFire.DataSpec ArrayFire.DeviceSpec + ArrayFire.ExceptionSpec ArrayFire.FeaturesSpec ArrayFire.GraphicsSpec ArrayFire.ImageSpec ArrayFire.IndexSpec ArrayFire.LAPACKSpec + ArrayFire.NumericalSpec ArrayFire.RandomSpec ArrayFire.SignalSpec ArrayFire.SparseSpec diff --git a/cbits/eigsh.c b/cbits/eigsh.c new file mode 100644 index 0000000..2ecc77c --- /dev/null +++ b/cbits/eigsh.c @@ -0,0 +1,418 @@ +/* + * cbits/eigsh.c + * + * Symmetric eigendecomposition with backend dispatch: + * + * CUDA — cusolverDnDsyevd / cusolverDnSsyevd via dlopen/dlsym. + * Binds cuSOLVER to ArrayFire's CUDA stream (best-effort). + * Uses AF pinned memory for devInfo so convergence failures + * (devInfo != 0) are detected and trigger the CPU fallback. + * Falls back to the CPU path when cuSOLVER is unavailable. + * + * CPU / OpenCL — Classical Jacobi eigenvalue algorithm on the host. + * af_get_data_ptr copies the matrix to host memory; the Jacobi + * sweeps diagonalise it in place; af_create_array puts the + * results back. Handles degenerate eigenvalues correctly and + * needs no external library. + * + * No link-time dependency on the CUDA toolkit or libafcuda. + */ + +#define _GNU_SOURCE +#include "arrayfire.h" +#include +#include +#include +#include +#include + +/* ── column-major element access ── */ +#define ELEM(a, r, c, n) ((a)[(r) + (size_t)(c) * (n)]) + +/* ══════════════════════════════════════════════════════════════════════════ + * Jacobi eigenvalue algorithm (host, column-major, real symmetric). + * + * On entry a[n*n] — symmetric matrix. + * On exit a[n*n] — eigenvectors as columns. + * evals[n] — eigenvalues in the order Jacobi produced them + * (NOT yet sorted). + * Returns 0 on success, 1 if malloc fails. + * ══════════════════════════════════════════════════════════════════════════*/ + +static int jacobi_d(int n, double *a, double *evals) +{ + double *v = malloc((size_t)n * n * sizeof(double)); + if (!v) return 1; + + memset(v, 0, (size_t)n * n * sizeof(double)); + for (int i = 0; i < n; i++) ELEM(v, i, i, n) = 1.0; + + /* Up to 50 full sweeps; typical convergence is << 10 for moderate n. */ + for (int sweep = 0; sweep < 50 * n; sweep++) { + /* Locate largest off-diagonal element */ + int p = 0, q = 1; + double max_off = 0.0; + for (int c = 1; c < n; c++) { + for (int r = 0; r < c; r++) { + double val = fabs(ELEM(a, r, c, n)); + if (val > max_off) { max_off = val; p = r; q = c; } + } + } + if (max_off < 1e-14) break; + + double apq = ELEM(a, p, q, n); + double tau = (ELEM(a, q, q, n) - ELEM(a, p, p, n)) / (2.0 * apq); + double sign = (tau >= 0.0) ? 1.0 : -1.0; + double t = sign / (fabs(tau) + sqrt(1.0 + tau * tau)); + double cs = 1.0 / sqrt(1.0 + t * t); + double sn = t * cs; + + /* Rotate A */ + ELEM(a, p, p, n) -= t * apq; + ELEM(a, q, q, n) += t * apq; + ELEM(a, p, q, n) = ELEM(a, q, p, n) = 0.0; + for (int r = 0; r < n; r++) { + if (r == p || r == q) continue; + double arp = ELEM(a, r, p, n), arq = ELEM(a, r, q, n); + ELEM(a, r, p, n) = ELEM(a, p, r, n) = cs * arp - sn * arq; + ELEM(a, r, q, n) = ELEM(a, q, r, n) = cs * arq + sn * arp; + } + /* Accumulate rotation in V */ + for (int r = 0; r < n; r++) { + double vrp = ELEM(v, r, p, n), vrq = ELEM(v, r, q, n); + ELEM(v, r, p, n) = cs * vrp - sn * vrq; + ELEM(v, r, q, n) = cs * vrq + sn * vrp; + } + } + + for (int i = 0; i < n; i++) evals[i] = ELEM(a, i, i, n); + memcpy(a, v, (size_t)n * n * sizeof(double)); + free(v); + return 0; +} + +static int jacobi_f(int n, float *a, float *evals) +{ + float *v = malloc((size_t)n * n * sizeof(float)); + if (!v) return 1; + + memset(v, 0, (size_t)n * n * sizeof(float)); + for (int i = 0; i < n; i++) ELEM(v, i, i, n) = 1.0f; + + for (int sweep = 0; sweep < 50 * n; sweep++) { + int p = 0, q = 1; + float max_off = 0.0f; + for (int c = 1; c < n; c++) { + for (int r = 0; r < c; r++) { + float val = fabsf(ELEM(a, r, c, n)); + if (val > max_off) { max_off = val; p = r; q = c; } + } + } + if (max_off < 1e-6f) break; + + float apq = ELEM(a, p, q, n); + float tau = (ELEM(a, q, q, n) - ELEM(a, p, p, n)) / (2.0f * apq); + float sign = (tau >= 0.0f) ? 1.0f : -1.0f; + float t = sign / (fabsf(tau) + sqrtf(1.0f + tau * tau)); + float cs = 1.0f / sqrtf(1.0f + t * t); + float sn = t * cs; + + ELEM(a, p, p, n) -= t * apq; + ELEM(a, q, q, n) += t * apq; + ELEM(a, p, q, n) = ELEM(a, q, p, n) = 0.0f; + for (int r = 0; r < n; r++) { + if (r == p || r == q) continue; + float arp = ELEM(a, r, p, n), arq = ELEM(a, r, q, n); + ELEM(a, r, p, n) = ELEM(a, p, r, n) = cs * arp - sn * arq; + ELEM(a, r, q, n) = ELEM(a, q, r, n) = cs * arq + sn * arp; + } + for (int r = 0; r < n; r++) { + float vrp = ELEM(v, r, p, n), vrq = ELEM(v, r, q, n); + ELEM(v, r, p, n) = cs * vrp - sn * vrq; + ELEM(v, r, q, n) = cs * vrq + sn * vrp; + } + } + + for (int i = 0; i < n; i++) evals[i] = ELEM(a, i, i, n); + memcpy(a, v, (size_t)n * n * sizeof(float)); + free(v); + return 0; +} + +/* Selection sort on eigenvalues, mirroring the column swaps in evecs. */ +static void sort_eigs_d(int n, double *evals, double *evecs) +{ + for (int i = 0; i < n - 1; i++) { + int min_j = i; + for (int j = i + 1; j < n; j++) + if (evals[j] < evals[min_j]) min_j = j; + if (min_j == i) continue; + double tmp = evals[i]; evals[i] = evals[min_j]; evals[min_j] = tmp; + for (int r = 0; r < n; r++) { + double tv = evecs[r + (size_t)i * n]; + evecs[r + (size_t)i * n] = evecs[r + (size_t)min_j * n]; + evecs[r + (size_t)min_j * n] = tv; + } + } +} + +static void sort_eigs_f(int n, float *evals, float *evecs) +{ + for (int i = 0; i < n - 1; i++) { + int min_j = i; + for (int j = i + 1; j < n; j++) + if (evals[j] < evals[min_j]) min_j = j; + if (min_j == i) continue; + float tmp = evals[i]; evals[i] = evals[min_j]; evals[min_j] = tmp; + for (int r = 0; r < n; r++) { + float tv = evecs[r + (size_t)i * n]; + evecs[r + (size_t)i * n] = evecs[r + (size_t)min_j * n]; + evecs[r + (size_t)min_j * n] = tv; + } + } +} + +/* ══════════════════════════════════════════════════════════════════════════ + * CPU / OpenCL fallback: copy to host, Jacobi, copy back. + * ══════════════════════════════════════════════════════════════════════════*/ +static af_err eigsh_cpu(af_array *evals_out, af_array *evecs_out, + const af_array input) +{ + af_dtype dtype; + af_err err; + if ((err = af_get_type(&dtype, input)) != AF_SUCCESS) return err; + + dim_t d0, d1, d2, d3; + if ((err = af_get_dims(&d0, &d1, &d2, &d3, input)) != AF_SUCCESS) return err; + int n = (int)d0; + + size_t elem_size = (dtype == f64) ? sizeof(double) : sizeof(float); + + void *A = malloc((size_t)n * n * elem_size); + if (!A) return AF_ERR_NO_MEM; + void *W = malloc((size_t)n * elem_size); + if (!W) { free(A); return AF_ERR_NO_MEM; } + + if ((err = af_get_data_ptr(A, input)) != AF_SUCCESS) { + free(A); free(W); return err; + } + + int ret = (dtype == f64) ? jacobi_d(n, (double *)A, (double *)W) + : jacobi_f(n, (float *)A, (float *)W); + if (ret != 0) { free(A); free(W); return AF_ERR_NO_MEM; } + + if (dtype == f64) sort_eigs_d(n, (double *)W, (double *)A); + else sort_eigs_f(n, (float *)W, (float *)A); + + dim_t dims_eval = (dim_t)n; + dim_t dims_evec[2] = { (dim_t)n, (dim_t)n }; + af_array evals = NULL, evecs = NULL; + if ((err = af_create_array(&evals, W, 1, &dims_eval, dtype)) != AF_SUCCESS) + goto cleanup; + if ((err = af_create_array(&evecs, A, 2, dims_evec, dtype)) != AF_SUCCESS) { + af_release_array(evals); + goto cleanup; + } + free(A); free(W); + *evals_out = evals; + *evecs_out = evecs; + return AF_SUCCESS; + +cleanup: + free(A); free(W); + return err; +} + +/* ══════════════════════════════════════════════════════════════════════════ + * cuSOLVER GPU path (CUDA only). + * ══════════════════════════════════════════════════════════════════════════*/ + +/* ── minimal cuSOLVER types (avoids CUDA toolkit headers) ── */ +typedef void *cusolverDnHandle_t; +typedef void *af_cuda_stream_t; +typedef int cusolverStatus_t; + +#define CUSOLVER_STATUS_SUCCESS 0 +#define CUBLAS_FILL_MODE_LOWER 0 +#define CUSOLVER_EIG_MODE_VECTOR 1 + +typedef cusolverStatus_t (*pfn_Create) (cusolverDnHandle_t *); +typedef cusolverStatus_t (*pfn_SetStream) (cusolverDnHandle_t, af_cuda_stream_t); +typedef cusolverStatus_t (*pfn_DsyevdBuf)(cusolverDnHandle_t, int, int, int, + const double *, int, const double *, int *); +typedef cusolverStatus_t (*pfn_Dsyevd) (cusolverDnHandle_t, int, int, int, + double *, int, double *, double *, int, int *); +typedef cusolverStatus_t (*pfn_SsyevdBuf)(cusolverDnHandle_t, int, int, int, + const float *, int, const float *, int *); +typedef cusolverStatus_t (*pfn_Ssyevd) (cusolverDnHandle_t, int, int, int, + float *, int, float *, float *, int, int *); +typedef af_err (*pfn_GetStream) (af_cuda_stream_t *, int); + +static cusolverDnHandle_t g_handle = NULL; +static pfn_Create fn_Create = NULL; +static pfn_SetStream fn_SetStr = NULL; +static pfn_DsyevdBuf fn_DsyBuf = NULL; +static pfn_Dsyevd fn_Dsyevd = NULL; +static pfn_SsyevdBuf fn_SsyBuf = NULL; +static pfn_Ssyevd fn_Ssyevd = NULL; +static int g_init = 0; + +static af_err load_and_init(void) +{ + /* Try versioned sonames (CUDA 11 then 12) then the unversioned symlink. */ + void *lib = dlopen("libcusolver.so.11", RTLD_NOW | RTLD_NOLOAD); + if (!lib) lib = dlopen("libcusolver.so.11", RTLD_NOW | RTLD_GLOBAL); + if (!lib) lib = dlopen("libcusolver.so.12", RTLD_NOW | RTLD_GLOBAL); + if (!lib) lib = dlopen("libcusolver.so", RTLD_NOW | RTLD_GLOBAL); + if (!lib) return AF_ERR_RUNTIME; + + fn_Create = (pfn_Create) dlsym(lib, "cusolverDnCreate"); + fn_SetStr = (pfn_SetStream) dlsym(lib, "cusolverDnSetStream"); + fn_DsyBuf = (pfn_DsyevdBuf) dlsym(lib, "cusolverDnDsyevd_bufferSize"); + fn_Dsyevd = (pfn_Dsyevd) dlsym(lib, "cusolverDnDsyevd"); + fn_SsyBuf = (pfn_SsyevdBuf) dlsym(lib, "cusolverDnSsyevd_bufferSize"); + fn_Ssyevd = (pfn_Ssyevd) dlsym(lib, "cusolverDnSsyevd"); + + if (!fn_Create || !fn_SetStr || !fn_DsyBuf || !fn_Dsyevd || + !fn_SsyBuf || !fn_Ssyevd) + return AF_ERR_RUNTIME; + + if (fn_Create(&g_handle) != CUSOLVER_STATUS_SUCCESS) + return AF_ERR_INTERNAL; + + /* Bind cuSOLVER to AF's CUDA stream so calls are ordered with AF ops. */ + pfn_GetStream fn_GetStr = + (pfn_GetStream) dlsym(RTLD_DEFAULT, "afcu_get_stream"); + if (fn_GetStr) { + af_cuda_stream_t stream = NULL; + if (fn_GetStr(&stream, 0) == AF_SUCCESS && stream) + fn_SetStr(g_handle, stream); + } + return AF_SUCCESS; +} + +static af_err ensure_init(void) +{ + if (g_init) return g_handle ? AF_SUCCESS : AF_ERR_RUNTIME; + g_init = 1; + return load_and_init(); +} + +/* + * run_syevd — call cuSOLVER in-place; overwrites d_A with eigenvectors. + * + * devInfo is placed in AF pinned host memory so it is readable from the + * host after af_sync without a separate cudaMemcpy. Passing pinned host + * memory to cuSOLVER is valid under CUDA UVA (CUDA 4.0+ / CC 2.0+). + * Returns AF_ERR_INTERNAL if the solver signals non-convergence (devInfo != 0). + */ +static af_err run_syevd(int is_double, int n, void *d_A, void *d_W) +{ + int lwork; + cusolverStatus_t st; + + if (is_double) { + st = fn_DsyBuf(g_handle, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_LOWER, + n, (const double *)d_A, n, (const double *)d_W, &lwork); + } else { + st = fn_SsyBuf(g_handle, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_LOWER, + n, (const float *)d_A, n, (const float *)d_W, &lwork); + } + if (st != CUSOLVER_STATUS_SUCCESS) return AF_ERR_INTERNAL; + + dim_t wsz = (dim_t)lwork * (is_double ? sizeof(double) : sizeof(float)); + void *d_work = NULL; + af_err err; + if ((err = af_alloc_device_v2(&d_work, wsz)) != AF_SUCCESS) return err; + + /* Pinned host memory — accessible from device via UVA. */ + int *h_info = NULL; + if ((err = af_alloc_pinned((void **)&h_info, sizeof(int))) != AF_SUCCESS) { + af_free_device_v2(d_work); + return err; + } + *h_info = 0; + + if (is_double) { + st = fn_Dsyevd(g_handle, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_LOWER, + n, (double *)d_A, n, (double *)d_W, + (double *)d_work, lwork, h_info); + } else { + st = fn_Ssyevd(g_handle, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_LOWER, + n, (float *)d_A, n, (float *)d_W, + (float *)d_work, lwork, h_info); + } + af_free_device_v2(d_work); + + if (st != CUSOLVER_STATUS_SUCCESS) { + af_free_pinned(h_info); + return AF_ERR_INTERNAL; + } + + /* Sync so the cuSOLVER kernel's write to h_info is visible on the host. */ + int cur_dev = 0; + af_get_device(&cur_dev); + af_sync(cur_dev); + + int devInfo = *h_info; + af_free_pinned(h_info); + return (devInfo == 0) ? AF_SUCCESS : AF_ERR_INTERNAL; +} + +/* ── public entry point ── */ +af_err af_eigsh(af_array *evals_out, af_array *evecs_out, const af_array input) +{ + af_err err; + + af_dtype dtype; + if ((err = af_get_type(&dtype, input)) != AF_SUCCESS) return err; + if (dtype != f64 && dtype != f32) return AF_ERR_TYPE; + + af_backend backend; + if ((err = af_get_active_backend(&backend)) != AF_SUCCESS) return err; + + if (backend != AF_BACKEND_CUDA) + return eigsh_cpu(evals_out, evecs_out, input); + + if (ensure_init() != AF_SUCCESS) + return eigsh_cpu(evals_out, evecs_out, input); + + dim_t d0, d1, d2, d3; + if ((err = af_get_dims(&d0, &d1, &d2, &d3, input)) != AF_SUCCESS) return err; + int n = (int)d0; + + af_array evecs; + if ((err = af_copy_array(&evecs, input)) != AF_SUCCESS) return err; + + af_array evals; + dim_t n_dim = (dim_t)n; + if ((err = af_constant(&evals, 0.0, 1, &n_dim, dtype)) != AF_SUCCESS) { + af_release_array(evecs); + return err; + } + + void *d_A = NULL, *d_W = NULL; + if ((err = af_get_device_ptr(&d_A, evecs)) != AF_SUCCESS) { + af_release_array(evecs); af_release_array(evals); + return err; + } + if ((err = af_get_device_ptr(&d_W, evals)) != AF_SUCCESS) { + af_unlock_array(evecs); + af_release_array(evecs); af_release_array(evals); + return err; + } + + err = run_syevd(dtype == f64, n, d_A, d_W); + + af_unlock_array(evecs); + af_unlock_array(evals); + + if (err != AF_SUCCESS) { + af_release_array(evecs); af_release_array(evals); + return eigsh_cpu(evals_out, evecs_out, input); + } + + *evals_out = evals; + *evecs_out = evecs; + return AF_SUCCESS; +} diff --git a/cbits/wrapper.c b/cbits/wrapper.c index 1b101a6..43e8bc8 100644 --- a/cbits/wrapper.c +++ b/cbits/wrapper.c @@ -1,5 +1,4 @@ #include "arrayfire.h" -#include af_err af_random_engine_set_type_(af_random_engine engine, const af_random_engine_type rtype) { return af_random_engine_set_type(&engine, rtype); } @@ -7,35 +6,14 @@ af_err af_random_engine_set_seed_(af_random_engine engine, const unsigned long l return af_random_engine_set_seed(&engine, seed); } -void test_bool () { - double * data = malloc (sizeof (int) * 5); - data[0] = 2; - data[1] = 2; - data[2] = 2; - data[3] = 2; - data[4] = 2; - data[5] = 2; - dim_t * dims = malloc(sizeof(dim_t) * 4); - dims[0] = 5; - dims[1] = 1; - dims[2] = 1; - dims[3] = 1; - af_array arrin; - af_create_array(&arrin, data, 1, dims, f64); - printf("printing input array\n"); - af_print_array(arrin); - af_array arrout; - af_product(&arrout, arrin, 0); - printf("printing output array\n"); - af_print_array(arrout); -} +static volatile int af_shutting_down = 0; -void test_window () { - af_window window; - af_create_window(&window, 100, 100, "foo"); - af_show(window); +void af_notify_shutdown(void) { + af_shutting_down = 1; } -void zeroOutArray (af_array * arr) { - (*arr) = 0; +/* Safe finalizer: no-ops on null handles and after af_notify_shutdown(). */ +void af_release_array_safe(af_array arr) { + if (!af_shutting_down && arr) + af_release_array(arr); } diff --git a/exe/Main.hs b/exe/Main.hs index 80f26ca..f498aa9 100644 --- a/exe/Main.hs +++ b/exe/Main.hs @@ -9,93 +9,8 @@ import Control.Concurrent import Control.Exception import Prelude hiding (sum, product) --- import GHC.RTS -foreign import ccall safe "test_bool" - testBool :: IO () - -foreign import ccall safe "test_window" - testWindow :: IO () - -main' :: IO () -main' = print newArray `catch` (\(e :: AFException) -> print e) +main :: IO () +main = print newArray `catch` (\(e :: AFException) -> print e) where newArray = matrix @Double (2,2) [ [1..], [1..] ] * matrix @Double (2,2) [ [2..], [2..] ] - -main :: IO () -main = do - main' - -- testWindow - -- ks <- randn @Double [100,100] - -- saveArray "key" ks "array.txt" False - -- !ks' <- readArrayKey "array.txt" "key" - -- print ks' - --- info >> putStrLn "ok" >> afInit --- -- Info things --- print =<< getSizeOf (Proxy @ Double) --- print =<< getVersion --- print =<< getRevision --- -- getInfo --- -- print =<< errorToString afErrNoMem --- putStrLn =<< getInfoString --- print =<< getDeviceCount --- print =<< getDevice - --- -- Create and print an array --- -- arr1 <- constant 1 1 1 f64 --- -- arr2 <- constant 2 1 1 f64 --- -- r <- addArray arr1 arr2 True --- -- printArray r - --- -- print =<< isLAPACKAvailable --- -- print =<< getAvailableBackends --- -- print =<< getActiveBackend --- -- print =<< getAvailableBackends - --- -- array <- constant @'(10,10) 200 --- -- putStrLn "backend id" --- -- print (getBackendID array) --- -- putStrLn "device id" --- -- print (getDeviceID array) - --- -- array <- randu @'(9,9,9) @Double --- -- printArray array -- printArray (mean array 0) - --- -- printArray (add array 1) - --- -- putStrLn "got eeem" --- -- print =<< getDataPtr x - --- -- x <- constant 10 1 1 f64 --- -- printArray =<< mean x 0 - --- -- print =<< isLAPACKAvailable - --- a <- randu @'(3,3) @Float --- b <- randu @'(3,3) @Float --- printArray ((a `matmul` b) None None) --- `catch` (\(e :: AFException) -> do --- putStrLn "got one" --- print e) - - putStrLn "create window" - window <- createWindow 200 200 "hey" - putStrLn "set visibility" - setVisibility window True - putStrLn "show window" - showWindow window - threadDelay (secs 10) - --- -- print =<< getActiveBackend --- -- print =<< getDeviceCount --- -- print =<< getDevice --- -- putStrLn "info" --- -- getInfo --- -- putStrLn "info string" --- -- putStrLn =<< getInfoString --- -- print =<< getVersion - - -secs :: Int -> Int -secs = (*1000000) diff --git a/flake.nix b/flake.nix index c8c3c30..4d9c869 100644 --- a/flake.nix +++ b/flake.nix @@ -198,7 +198,7 @@ ps.shellFor { packages = ps: if hasArrayfire then [ ps.arrayfire ] else [ ]; withHoogle = true; - buildInputs = with pkgs; (if isLinux then [ ocl-icd ] else [ darwin.apple_sdk.frameworks.Security ]); + buildInputs = with pkgs; (if isLinux then [ ocl-icd ] else [ ]); nativeBuildInputs = with pkgs; with ps; [ # Building and testing cabal-install diff --git a/include/algorithm.h b/include/algorithm.h index 8894a73..c36f8d3 100644 --- a/include/algorithm.h +++ b/include/algorithm.h @@ -34,3 +34,12 @@ af_err af_sort_by_key(af_array *out_keys, af_array *out_values, const af_array k af_err af_set_unique(af_array *out, const af_array in, const bool is_sorted); af_err af_set_union(af_array *out, const af_array first, const af_array second, const bool is_unique); af_err af_set_intersect(af_array *out, const af_array first, const af_array second, const bool is_unique); +af_err af_sum_by_key(af_array *keys_out, af_array *vals_out, const af_array keys, const af_array vals, const int dim); +af_err af_sum_by_key_nan(af_array *keys_out, af_array *vals_out, const af_array keys, const af_array vals, const int dim, const double nanval); +af_err af_product_by_key(af_array *keys_out, af_array *vals_out, const af_array keys, const af_array vals, const int dim); +af_err af_product_by_key_nan(af_array *keys_out, af_array *vals_out, const af_array keys, const af_array vals, const int dim, const double nanval); +af_err af_min_by_key(af_array *keys_out, af_array *vals_out, const af_array keys, const af_array vals, const int dim); +af_err af_max_by_key(af_array *keys_out, af_array *vals_out, const af_array keys, const af_array vals, const int dim); +af_err af_all_true_by_key(af_array *keys_out, af_array *vals_out, const af_array keys, const af_array vals, const int dim); +af_err af_any_true_by_key(af_array *keys_out, af_array *vals_out, const af_array keys, const af_array vals, const int dim); +af_err af_count_by_key(af_array *keys_out, af_array *vals_out, const af_array keys, const af_array vals, const int dim); diff --git a/include/blas.h b/include/blas.h index d872069..e70bdba 100644 --- a/include/blas.h +++ b/include/blas.h @@ -5,3 +5,4 @@ af_err af_dot(af_array *out, const af_array lhs, const af_array rhs, const af_ma af_err af_dot_all(double *real, double *imag, const af_array lhs, const af_array rhs, const af_mat_prop optLhs, const af_mat_prop optRhs); af_err af_transpose(af_array *out, af_array in, const bool conjugate); af_err af_transpose_inplace(af_array in, const bool conjugate); +af_err af_gemm(af_array *out, const af_mat_prop optLhs, const af_mat_prop optRhs, const void *alpha, const af_array lhs, const af_array rhs, const void *beta); diff --git a/include/statistics.h b/include/statistics.h index c3ddd9b..59f37bc 100644 --- a/include/statistics.h +++ b/include/statistics.h @@ -15,3 +15,4 @@ af_err af_stdev_all(double *real, double *imag, const af_array in); af_err af_median_all(double *realVal, double *imagVal, const af_array in); af_err af_corrcoef(double *realVal, double *imagVal, const af_array X, const af_array Y); af_err af_topk(af_array *values, af_array *indices, const af_array in, const int k, const int dim, const af_topk_function order); +af_err af_meanvar(af_array *mean, af_array *var, const af_array in, const af_array weights, const af_var_bias bias, const dim_t dim); diff --git a/src/ArrayFire.hs b/src/ArrayFire.hs index f5cf814..02a1141 100644 --- a/src/ArrayFire.hs +++ b/src/ArrayFire.hs @@ -1,7 +1,7 @@ -------------------------------------------------------------------------------- -- | -- Module : ArrayFire --- Copyright : David Johnson (c) 2019-2020 +-- Copyright : David Johnson (c) 2019-2026 -- License : BSD3 -- Maintainer : David Johnson -- Stability : Experimental @@ -298,9 +298,14 @@ import Data.Word -- -- $conversion --- Any 'Array' can be exported into Haskell using `toVector'. This will create a Storable vector suitable for use in other C programs. +-- Any 'Array' can be exported into Haskell using 'toVector'. This will create a 'Storable' vector suitable for use in other C programs. 'fromVector' can be used -- -- >>> vector :: Vector Double <- toVector <$> randu @Double [10,10] +-- >>> let array :: Array Double = fromVector @Double [10,10] vector +-- +-- >>> original <- randu @Double [10,10] +-- >>> original == fromVector [10,10] (toVector og :: Vector Double) +-- >>> True -- -- $serialization @@ -328,7 +333,7 @@ import Data.Word -- $device -- The ArrayFire API is able to see which devices are present, and will by default use the GPU if available. -- --- >>> afInfo +-- >>> info -- ArrayFire v3.6.4 (OpenCL, 64-bit Mac OSX, build 1b8030c5) -- [0] APPLE: AMD Radeon Pro 555X Compute Engine, 4096 MB <-- brackets [] signify device being used. -- -1- APPLE: Intel(R) UHD Graphics 630, 1536 MB diff --git a/src/ArrayFire/Algorithm.hs b/src/ArrayFire/Algorithm.hs index b7fccba..1f2bca0 100644 --- a/src/ArrayFire/Algorithm.hs +++ b/src/ArrayFire/Algorithm.hs @@ -4,7 +4,7 @@ -------------------------------------------------------------------------------- -- | -- Module : ArrayFire.Algorithm --- Copyright : David Johnson (c) 2019-2020 +-- Copyright : David Johnson (c) 2019-2026 -- License : BSD 3 -- Maintainer : David Johnson -- Stability : Experimental @@ -26,6 +26,9 @@ -------------------------------------------------------------------------------- module ArrayFire.Algorithm where +import Data.Word (Word32) +import Foreign.C.Types (CBool) + import ArrayFire.FFI import ArrayFire.Internal.Algorithm import ArrayFire.Internal.Types @@ -63,7 +66,7 @@ sum x (fromIntegral -> n) = (x `op1` (\p a -> af_sum p a n)) -- | Sum all of the elements in 'Array' along the specified dimension, using a default value for NaN -- --- >>> A.sumNaN (A.vector @Double 10 [1..]) 0 0.0 +-- >>> let nan = 0/0 in A.sumNaN (A.vector @Double 10 (nan : [1..])) 0 10.0 -- ArrayFire Array -- [1 1 1 1] -- 55.0000 @@ -97,7 +100,7 @@ product x (fromIntegral -> n) = (x `op1` (\p a -> af_product p a n)) -- | Product all of the elements in 'Array' along the specified dimension, using a default value for NaN -- --- >>> A.productNaN (A.vector @Double 10 [1..]) 0 0.0 +-- >>> let nan = 0/0 in A.productNaN (A.vector @Double 10 (nan : [1..])) 0 2.0 -- ArrayFire Array -- [1 1 1 1] -- 3628800.0000 @@ -147,18 +150,18 @@ max x (fromIntegral -> n) = x `op1` (\p a -> af_max p a n) -- | Find if all elements in an 'Array' are 'True' along a dimension -- --- >>> A.allTrue (A.vector @CBool 10 (repeat 0)) 0 +-- >>> A.allTrue (A.vector @CBool 10 (repeat 1)) 0 -- ArrayFire Array -- [1 1 1 1] --- 0 +-- 1 allTrue - :: forall a. AFType a + :: AFType a => Array a -- ^ Array input -> Int -- ^ Dimension along which to see if all elements are True - -> Array a - -- ^ Will contain the maximum of all values in the input array along dim + -> Array CBool + -- ^ Will contain 1 where all elements along dim are true, 0 otherwise allTrue x (fromIntegral -> n) = x `op1` (\p a -> af_all_true p a n) @@ -169,13 +172,13 @@ allTrue x (fromIntegral -> n) = -- [1 1 1 1] -- 0 anyTrue - :: forall a . AFType a + :: AFType a => Array a -- ^ Array input -> Int - -- ^ Dimension along which to see if all elements are True - -> Array a - -- ^ Returns if all elements are true + -- ^ Dimension along which to see if any elements are True + -> Array CBool + -- ^ Will contain 1 where any element along dim is true, 0 otherwise anyTrue x (fromIntegral -> n) = (x `op1` (\p a -> af_any_true p a n)) @@ -193,119 +196,119 @@ count -- ^ Dimension along which to count -> Array Int -- ^ Count of all elements along dimension -count x (fromIntegral -> n) = x `op1d` (\p a -> af_count p a n) +count x (fromIntegral -> n) = x `op1` (\p a -> af_count p a n) -- | Sum all elements in an 'Array' along all dimensions -- -- >>> A.sumAll (A.vector @Double 10 [1..]) -- (55.0,0.0) sumAll - :: AFType a + :: forall a . AFResult a => Array a -- ^ Input array - -> (Double, Double) + -> Scalar a -- ^ imaginary and real part -sumAll = (`infoFromArray2` af_sum_all) +sumAll = toAFResult @a . (`infoFromArray2` af_sum_all) -- | Sum all elements in an 'Array' along all dimensions, using a default value for NaN -- --- >>> A.sumNaNAll (A.vector @Double 10 [1..]) 0.0 +-- >>> let nan = 0/0 in A.sumNaNAll (A.vector @Double 10 (nan : [1..])) 0.0 -- (55.0,0.0) sumNaNAll - :: (AFType a, Fractional a) + :: forall a . (AFResult a, Fractional a) => Array a -- ^ Input array -> Double -- ^ NaN substitute - -> (Double, Double) + -> Scalar a -- ^ imaginary and real part -sumNaNAll a d = infoFromArray2 a (\p g x -> af_sum_nan_all p g x d) +sumNaNAll a d = toAFResult @a $ infoFromArray2 a (\p g x -> af_sum_nan_all p g x d) -- | Product all elements in an 'Array' along all dimensions, using a default value for NaN -- -- >>> A.productAll (A.vector @Double 10 [1..]) -- (3628800.0,0.0) productAll - :: AFType a + :: forall a . AFResult a => Array a -- ^ Input array - -> (Double, Double) + -> Scalar a -- ^ imaginary and real part -productAll = (`infoFromArray2` af_product_all) +productAll = toAFResult @a . (`infoFromArray2` af_product_all) -- | Product all elements in an 'Array' along all dimensions, using a default value for NaN -- -- >>> A.productNaNAll (A.vector @Double 10 [1..]) 1.0 -- (3628800.0,0.0) productNaNAll - :: (AFType a, Fractional a) + :: forall a . (AFResult a, Fractional a) => Array a -- ^ Input array -> Double -- ^ NaN substitute - -> (Double, Double) + -> Scalar a -- ^ imaginary and real part -productNaNAll a d = infoFromArray2 a (\p x y -> af_product_nan_all p x y d) +productNaNAll a d = toAFResult @a $ infoFromArray2 a (\p x y -> af_product_nan_all p x y d) -- | Take the minimum across all elements along all dimensions in 'Array' -- -- >>> A.minAll (A.vector @Double 10 [1..]) -- (1.0,0.0) minAll - :: AFType a + :: forall a . AFResult a => Array a -- ^ Input array - -> (Double, Double) + -> Scalar a -- ^ imaginary and real part -minAll = (`infoFromArray2` af_min_all) +minAll = toAFResult @a . (`infoFromArray2` af_min_all) -- | Take the maximum across all elements along all dimensions in 'Array' -- -- >>> A.maxAll (A.vector @Double 10 [1..]) -- (10.0,0.0) maxAll - :: AFType a + :: forall a . AFResult a => Array a -- ^ Input array - -> (Double, Double) + -> Scalar a -- ^ imaginary and real part -maxAll = (`infoFromArray2` af_max_all) +maxAll = toAFResult @a . (`infoFromArray2` af_max_all) -- | Decide if all elements along all dimensions in 'Array' are True -- -- >>> A.allTrueAll (A.vector @CBool 10 (repeat 1)) -- (1.0, 0.0) allTrueAll - :: AFType a + :: forall a . AFResult a => Array a -- ^ Input array - -> (Double, Double) + -> Scalar a -- ^ imaginary and real part -allTrueAll = (`infoFromArray2` af_all_true_all) +allTrueAll = toAFResult @a . (`infoFromArray2` af_all_true_all) -- | Decide if any elements along all dimensions in 'Array' are True -- -- >>> A.anyTrueAll $ A.vector @CBool 10 (repeat 0) -- (0.0,0.0) anyTrueAll - :: AFType a + :: forall a . AFResult a => Array a -- ^ Input array - -> (Double, Double) + -> Scalar a -- ^ imaginary and real part -anyTrueAll = (`infoFromArray2` af_any_true_all) +anyTrueAll = toAFResult @a . (`infoFromArray2` af_any_true_all) -- | Count all elements along all dimensions in 'Array' -- -- >>> A.countAll (A.matrix @Double (100,100) (replicate 100 [1..])) -- (10000.0,0.0) countAll - :: AFType a + :: forall a . AFResult a => Array a -- ^ Input array - -> (Double, Double) + -> Scalar a -- ^ imaginary and real part -countAll = (`infoFromArray2` af_count_all) +countAll = toAFResult @a . (`infoFromArray2` af_count_all) -- | Find the minimum element along a specified dimension in 'Array' -- @@ -323,7 +326,7 @@ imin -- ^ Input array -> Int -- ^ The dimension along which the minimum value is extracted - -> (Array a, Array a) + -> (Array a, Array Word32) -- ^ will contain the minimum of all values along dim, will also contain the location of minimum of all values in in along dim imin a (fromIntegral -> n) = op2p a (\x y z -> af_imin x y z n) @@ -343,7 +346,7 @@ imax -- ^ Input array -> Int -- ^ The dimension along which the minimum value is extracted - -> (Array a, Array a) + -> (Array a, Array Word32) -- ^ will contain the maximum of all values in in along dim, will also contain the location of maximum of all values in in along dim imax a (fromIntegral -> n) = op2p a (\x y z -> af_imax x y z n) @@ -352,28 +355,28 @@ imax a (fromIntegral -> n) = op2p a (\x y z -> af_imax x y z n) -- >>> A.iminAll (A.vector @Double 10 [1..]) -- (1.0,0.0,0) iminAll - :: AFType a + :: forall a . AFResult a => Array a -- ^ Input array - -> (Double, Double, Int) + -> (Scalar a, Int) -- ^ will contain the real part of minimum value of all elements in input in, also will contain the imaginary part of minimum value of all elements in input in, will contain the location of minimum of all values in iminAll a = do let (x,y,fromIntegral -> z) = a `infoFromArray3` af_imin_all - (x,y,z) + (toAFResult @a (x,y), z) -- | Find the maximum element along all dimensions in 'Array' -- -- >>> A.imaxAll (A.vector @Double 10 [1..]) -- (10.0,0.0,9) imaxAll - :: AFType a + :: forall a . AFResult a => Array a -- ^ Input array - -> (Double, Double, Int) + -> (Scalar a, Int) -- ^ will contain the real part of maximum value of all elements in input in, also will contain the imaginary part of maximum value of all elements in input in, will contain the location of maximum of all values in imaxAll a = do let (x,y,fromIntegral -> z) = a `infoFromArray3` af_imax_all - (x,y,z) + (toAFResult @a (x,y), z) -- | Calculate sum of 'Array' across specified dimension -- @@ -471,8 +474,8 @@ where' :: AFType a => Array a -- ^ Is the input array. - -> Array a - -- ^ will contain indices where input array is non-zero + -> Array Word32 + -- ^ Indices where input array is non-zero where' = (`op1` af_where) -- | First order numerical difference along specified dimension. @@ -513,7 +516,7 @@ diff2 a (fromIntegral -> n) = a `op1` (\p x -> af_diff2 p x n) -- | Sort an Array along a specified dimension, specifying ordering of results (ascending / descending) -- --- >>> A.sort (A.vector @Double 4 [ 2,4,3,1 ]) 0 True +-- >>> A.sort (A.vector @Double 4 [ 2,4,3,1 ]) 0 Asc -- ArrayFire Array -- [4 1 1 1] -- 1.0000 @@ -521,7 +524,7 @@ diff2 a (fromIntegral -> n) = a `op1` (\p x -> af_diff2 p x n) -- 3.0000 -- 4.0000 -- --- >>> A.sort (A.vector @Double 4 [ 2,4,3,1 ]) 0 False +-- >>> A.sort (A.vector @Double 4 [ 2,4,3,1 ]) 0 Desc -- ArrayFire Array -- [4 1 1 1] -- 4.0000 @@ -534,7 +537,7 @@ sort -- ^ Input array -> Int -- ^ Dimension along `sort` is performed - -> Bool + -> Order -- ^ Return results in ascending order -> Array a -- ^ Will contain sorted input @@ -543,7 +546,7 @@ sort a (fromIntegral -> n) (fromIntegral . fromEnum -> b) = -- | Sort an 'Array' along a specified dimension, specifying ordering of results (ascending / descending), returns indices of sorted results -- --- >>> A.sortIndex (A.vector @Double 4 [3,2,1,4]) 0 True +-- >>> A.sortIndex (A.vector @Double 4 [3,2,1,4]) 0 Asc -- (ArrayFire Array -- [4 1 1 1] -- 1.0000 @@ -563,13 +566,18 @@ sortIndex -- ^ Input array -> Int -- ^ Dimension along `sortIndex` is performed - -> Bool + -> Order -- ^ Return results in ascending order - -> (Array a, Array a) + -> (Array a, Array Word32) -- ^ Contains the sorted, contains indices for original input sortIndex a (fromIntegral -> n) (fromIntegral . fromEnum -> b) = a `op2p` (\p1 p2 p3 -> af_sort_index p1 p2 p3 n b) + +-- | Data type for expressing sort order +data Order = Desc | Asc + deriving (Enum, Show, Eq) + -- | Sort an 'Array' along a specified dimension by keys, specifying ordering of results (ascending / descending) -- -- >>> A.sortByKey (A.vector @Double 4 [2,1,4,3]) (A.vector @Double 4 [10,9,8,7]) 0 True @@ -594,7 +602,7 @@ sortByKey -- ^ Values input array -> Int -- ^ Dimension along which to perform the operation - -> Bool + -> Order -- ^ Return results in ascending order -> (Array a, Array a) sortByKey a1 a2 (fromIntegral -> n) (fromIntegral . fromEnum -> b) = @@ -657,3 +665,143 @@ setIntersect -- ^ Intersection of first and second array setIntersect a1 a2 (fromIntegral . fromEnum -> b) = op2 a1 a2 (\x y z -> af_set_intersect x y z b) + +-- | Sum values in 'Array' grouped by keys along a dimension. +-- +-- Each contiguous run of equal keys in @keys@ produces one output element. +-- Returns @(keys_out, vals_out)@. +-- +-- >>> sumByKey (vector @Int 5 [1,1,2,2,2]) (vector @Double 5 [10,20,1,2,3]) 0 +-- (ArrayFire Array +-- [2 1 1 1] +-- 1 2, +-- ArrayFire Array +-- [2 1 1 1] +-- 30.0000 6.0000) +sumByKey + :: AFType a + => Array Int + -- ^ Keys array (contiguous equal keys form a group) + -> Array a + -- ^ Values array + -> Int + -- ^ Dimension along which to reduce + -> (Array Int, Array a) + -- ^ (reduced keys, reduced values) +sumByKey keys vals (fromIntegral -> dim) = + op2p2kv keys vals (\ko vo k v -> af_sum_by_key ko vo k v dim) + +-- | 'sumByKey' replacing NaN values with a substitute before summing. +sumByKeyNaN + :: AFType a + => Array Int + -- ^ Keys array + -> Array a + -- ^ Values array + -> Int + -- ^ Dimension + -> Double + -- ^ Substitute for NaN values + -> (Array Int, Array a) + -- ^ (reduced keys, reduced values) +sumByKeyNaN keys vals (fromIntegral -> dim) nanval = + op2p2kv keys vals (\ko vo k v -> af_sum_by_key_nan ko vo k v dim nanval) + +-- | Product of values in 'Array' grouped by keys along a dimension. +productByKey + :: AFType a + => Array Int + -- ^ Keys array + -> Array a + -- ^ Values array + -> Int + -- ^ Dimension + -> (Array Int, Array a) +productByKey keys vals (fromIntegral -> dim) = + op2p2kv keys vals (\ko vo k v -> af_product_by_key ko vo k v dim) + +-- | 'productByKey' replacing NaN values with a substitute before multiplying. +productByKeyNaN + :: AFType a + => Array Int + -- ^ Keys array + -> Array a + -- ^ Values array + -> Int + -- ^ Dimension + -> Double + -- ^ Substitute for NaN values + -> (Array Int, Array a) +productByKeyNaN keys vals (fromIntegral -> dim) nanval = + op2p2kv keys vals (\ko vo k v -> af_product_by_key_nan ko vo k v dim nanval) + +-- | Minimum of values in 'Array' grouped by keys along a dimension. +minByKey + :: AFType a + => Array Int + -- ^ Keys array + -> Array a + -- ^ Values array + -> Int + -- ^ Dimension + -> (Array Int, Array a) +minByKey keys vals (fromIntegral -> dim) = + op2p2kv keys vals (\ko vo k v -> af_min_by_key ko vo k v dim) + +-- | Maximum of values in 'Array' grouped by keys along a dimension. +maxByKey + :: AFType a + => Array Int + -- ^ Keys array + -> Array a + -- ^ Values array + -> Int + -- ^ Dimension + -> (Array Int, Array a) +maxByKey keys vals (fromIntegral -> dim) = + op2p2kv keys vals (\ko vo k v -> af_max_by_key ko vo k v dim) + +-- | True if all values are true within each key group. +-- +-- The value output is always boolean (@b8@) regardless of the input value type. +allTrueByKey + :: AFType a + => Array Int + -- ^ Keys array + -> Array a + -- ^ Values array (treated as boolean) + -> Int + -- ^ Dimension + -> (Array Int, Array CBool) +allTrueByKey keys vals (fromIntegral -> dim) = + op2p2kv keys vals (\ko vo k v -> af_all_true_by_key ko vo k v dim) + +-- | True if any value is true within each key group. +-- +-- The value output is always boolean (@b8@) regardless of the input value type. +anyTrueByKey + :: AFType a + => Array Int + -- ^ Keys array + -> Array a + -- ^ Values array (treated as boolean) + -> Int + -- ^ Dimension + -> (Array Int, Array CBool) +anyTrueByKey keys vals (fromIntegral -> dim) = + op2p2kv keys vals (\ko vo k v -> af_any_true_by_key ko vo k v dim) + +-- | Count non-zero values within each key group. +-- +-- The value output is always @u32@ regardless of the input value type. +countByKey + :: AFType a + => Array Int + -- ^ Keys array + -> Array a + -- ^ Values array + -> Int + -- ^ Dimension + -> (Array Int, Array Word32) +countByKey keys vals (fromIntegral -> dim) = + op2p2kv keys vals (\ko vo k v -> af_count_by_key ko vo k v dim) diff --git a/src/ArrayFire/Arith.hs b/src/ArrayFire/Arith.hs index ec2cc25..6e689d4 100644 --- a/src/ArrayFire/Arith.hs +++ b/src/ArrayFire/Arith.hs @@ -5,7 +5,7 @@ -------------------------------------------------------------------------------- -- | -- Module : ArrayFire.Arith --- Copyright : David Johnson (c) 2019-2020 +-- Copyright : David Johnson (c) 2019-2026 -- License : BSD 3 -- Maintainer : David Johnson -- Stability : Experimental @@ -28,7 +28,7 @@ -------------------------------------------------------------------------------- module ArrayFire.Arith where -import Prelude (Bool(..), ($), (.), flip, fromEnum, fromIntegral, Real, RealFrac) +import Prelude (Bool(..), ($), (.), flip, fromEnum, fromIntegral, Real, RealFloat) import Data.Coerce import Data.Proxy @@ -512,7 +512,7 @@ not -- ^ Input 'Array' -> Array CBool -- ^ Result of 'not' on an 'Array' -not = flip op1d af_not +not = flip op1 af_not -- | Bitwise and the values in one 'Array' against another 'Array' -- @@ -526,10 +526,10 @@ bitAnd -- ^ First input -> Array a -- ^ Second input - -> Array CBool + -> Array a -- ^ Result of bitwise and bitAnd x y = - x `op2bool` y $ \arr arr1 arr2 -> + x `op2` y $ \arr arr1 arr2 -> af_bitand arr arr1 arr2 1 -- | Bitwise and the values in one 'Array' against another 'Array' @@ -546,10 +546,10 @@ bitAndBatched -- ^ Second input -> Bool -- ^ Use batch - -> Array CBool + -> Array a -- ^ Result of bitwise and bitAndBatched x y (fromIntegral . fromEnum -> batch) = do - x `op2bool` y $ \arr arr1 arr2 -> + x `op2` y $ \arr arr1 arr2 -> af_bitand arr arr1 arr2 batch -- | Bitwise or the values in one 'Array' against another 'Array' @@ -564,10 +564,10 @@ bitOr -- ^ First input -> Array a -- ^ Second input - -> Array CBool - -- ^ Result of bit or + -> Array a + -- ^ Result of bitwise or bitOr x y = do - x `op2bool` y $ \arr arr1 arr2 -> + x `op2` y $ \arr arr1 arr2 -> af_bitor arr arr1 arr2 1 -- | Bitwise or the values in one 'Array' against another 'Array' @@ -584,10 +584,10 @@ bitOrBatched -- ^ Second input -> Bool -- ^ Use batch - -> Array CBool - -- ^ Result of bit or + -> Array a + -- ^ Result of bitwise or bitOrBatched x y (fromIntegral . fromEnum -> batch) = do - x `op2bool` y $ \arr arr1 arr2 -> + x `op2` y $ \arr arr1 arr2 -> af_bitor arr arr1 arr2 batch -- | Bitwise xor the values in one 'Array' against another 'Array' @@ -602,10 +602,10 @@ bitXor -- ^ First input -> Array a -- ^ Second input - -> Array CBool - -- ^ Result of bit xor + -> Array a + -- ^ Result of bitwise xor bitXor x y = do - x `op2bool` y $ \arr arr1 arr2 -> + x `op2` y $ \arr arr1 arr2 -> af_bitxor arr arr1 arr2 1 -- | Bitwise xor the values in one 'Array' against another 'Array' @@ -622,10 +622,10 @@ bitXorBatched -- ^ Second input -> Bool -- ^ Use batch - -> Array CBool - -- ^ Result of bit xor + -> Array a + -- ^ Result of bitwise xor bitXorBatched x y (fromIntegral . fromEnum -> batch) = do - x `op2bool` y $ \arr arr1 arr2 -> + x `op2` y $ \arr arr1 arr2 -> af_bitxor arr arr1 arr2 batch -- | Left bit shift the values in one 'Array' against another 'Array' @@ -640,10 +640,10 @@ bitShiftL -- ^ First input -> Array a -- ^ Second input - -> Array CBool + -> Array a -- ^ Result of bit shift left bitShiftL x y = - x `op2bool` y $ \arr arr1 arr2 -> + x `op2` y $ \arr arr1 arr2 -> af_bitshiftl arr arr1 arr2 1 -- | Left bit shift the values in one 'Array' against another 'Array' @@ -660,10 +660,10 @@ bitShiftLBatched -- ^ Second input -> Bool -- ^ Use batch - -> Array CBool + -> Array a -- ^ Result of bit shift left bitShiftLBatched x y (fromIntegral . fromEnum -> batch) = do - x `op2bool` y $ \arr arr1 arr2 -> + x `op2` y $ \arr arr1 arr2 -> af_bitshiftl arr arr1 arr2 batch -- | Right bit shift the values in one 'Array' against another 'Array' @@ -678,10 +678,10 @@ bitShiftR -- ^ First input -> Array a -- ^ Second input - -> Array CBool + -> Array a -- ^ Result of bit shift right bitShiftR x y = - x `op2bool` y $ \arr arr1 arr2 -> + x `op2` y $ \arr arr1 arr2 -> af_bitshiftr arr arr1 arr2 1 -- | Right bit shift the values in one 'Array' against another 'Array' @@ -698,10 +698,10 @@ bitShiftRBatched -- ^ Second input -> Bool -- ^ Use batch - -> Array CBool - -- ^ Result of bit shift left + -> Array a + -- ^ Result of bit shift right bitShiftRBatched x y (fromIntegral . fromEnum -> batch) = do - x `op2bool` y $ \arr arr1 arr2 -> + x `op2` y $ \arr arr1 arr2 -> af_bitshiftr arr arr1 arr2 batch -- | Cast one 'Array' into another @@ -717,7 +717,7 @@ cast -> Array b -- ^ Result of cast cast afArr = - coerce $ afArr `op1` (\x y -> af_cast x y dtyp) + coerce $ afArr `op1` (\x y -> ArrayFire.Internal.Arith.af_cast x y dtyp) where dtyp = afType (Proxy @b) @@ -1299,7 +1299,8 @@ atan2Batched x y (fromIntegral . fromEnum -> batch) = do x `op2` y $ \arr arr1 arr2 -> af_atan2 arr arr1 arr2 batch --- | Take the cplx2 of all values in an 'Array' +-- | Construct a complex 'Array' from two real 'Array's, taking the first as the +-- real part and the second as the imaginary part. -- -- >>> A.cplx2 (A.vector @Int 10 [1..]) (A.vector @Int 10 [1..]) -- ArrayFire Array @@ -1315,18 +1316,19 @@ atan2Batched x y (fromIntegral . fromEnum -> batch) = do -- (9.0000,9.0000) -- (10.0000,10.0000) cplx2 - :: AFType a + :: (RealFloat a, AFType a, AFType (Complex a)) => Array a - -- ^ First input + -- ^ First input (real part) -> Array a - -- ^ Second input - -> Array a - -- ^ Result of cplx2 + -- ^ Second input (imaginary part) + -> Array (Complex a) + -- ^ Complex result with the inputs as real and imaginary parts cplx2 x y = x `op2` y $ \arr arr1 arr2 -> af_cplx2 arr arr1 arr2 1 --- | Take the cplx2Batched of all values in an 'Array' +-- | Construct a complex 'Array' from two real 'Array's (real and imaginary +-- parts), with explicit control over batched broadcasting of the inputs. -- -- >>> A.cplx2Batched (A.vector @Int 10 [1..]) (A.vector @Int 10 [1..]) True -- ArrayFire Array @@ -1342,15 +1344,15 @@ cplx2 x y = -- (9.0000,9.0000) -- (10.0000,10.0000) cplx2Batched - :: AFType a + :: (RealFloat a, AFType a, AFType (Complex a)) => Array a - -- ^ First input + -- ^ First input (real part) -> Array a - -- ^ Second input + -- ^ Second input (imaginary part) -> Bool - -- ^ Use batch - -> Array a - -- ^ Result of cplx2 + -- ^ Whether to enable batched broadcasting of the inputs + -> Array (Complex a) + -- ^ Complex result with the inputs as real and imaginary parts cplx2Batched x y (fromIntegral . fromEnum -> batch) = do x `op2` y $ \arr arr1 arr2 -> af_cplx2 arr arr1 arr2 batch @@ -1371,11 +1373,11 @@ cplx2Batched x y (fromIntegral . fromEnum -> batch) = do -- (9.0000,0.0000) -- (10.0000,0.0000) cplx - :: AFType a + :: (RealFloat a, AFType a, AFType (Complex a)) => Array a -- ^ Input array - -> Array a - -- ^ Result of calling 'atan' + -> Array (Complex a) + -- ^ Complex array with input as real part and zero imaginary part cplx = flip op1 af_cplx -- | Execute real @@ -1385,12 +1387,12 @@ cplx = flip op1 af_cplx -- [1 1 1 1] -- 10.0000 real - :: (AFType a, AFType (Complex b), RealFrac a, RealFrac b) - => Array (Complex b) + :: (RealFloat a, AFType a, AFType (Complex a)) + => Array (Complex a) -- ^ Input array -> Array a - -- ^ Result of calling 'real' -real = flip op1d af_real + -- ^ Real part of each element +real = flip op1 af_real -- | Execute imag -- @@ -1399,12 +1401,12 @@ real = flip op1d af_real -- [1 1 1 1] -- 11.0000 imag - :: (AFType a, AFType (Complex b), RealFrac a, RealFrac b) - => Array (Complex b) + :: (RealFloat a, AFType a, AFType (Complex a)) + => Array (Complex a) -- ^ Input array -> Array a - -- ^ Result of calling 'imag' -imag = flip op1d af_imag + -- ^ Imaginary part of each element +imag = flip op1 af_imag -- | Execute conjg -- @@ -1567,32 +1569,24 @@ atanh -- ^ Result of calling 'tanh' atanh = flip op1 af_atanh --- | Execute root +-- | Execute root: compute the nth root of each element. +-- @root base n@ computes @base^(1\/n)@. -- --- >>> A.root (A.vector @Double 10 [1..]) (A.vector @Double 10 [1..]) +-- >>> A.root (A.scalar @Double 1 8) (A.scalar @Double 1 3) -- ArrayFire Array --- [10 1 1 1] --- 1.0000 --- 1.4142 --- 1.4422 --- 1.4142 --- 1.3797 --- 1.3480 --- 1.3205 --- 1.2968 --- 1.2765 --- 1.2589 +-- [1 1 1 1] +-- 2.0000 root :: AFType a => Array a - -- ^ First input + -- ^ The input data (base) -> Array a - -- ^ Second input + -- ^ The root degree (n) -> Array a - -- ^ Result of root + -- ^ Result: base^(1\/n) root x y = x `op2` y $ \arr arr1 arr2 -> - af_root arr arr1 arr2 1 + af_root arr arr2 arr1 1 -- | Execute rootBatched -- @@ -1619,9 +1613,9 @@ rootBatched -- ^ Use batch -> Array a -- ^ Result of root -rootBatched x y (fromIntegral . fromEnum -> batch) = do +rootBatched x y (fromIntegral . fromEnum -> batch) = x `op2` y $ \arr arr1 arr2 -> - af_root arr arr1 arr2 batch + af_root arr arr2 arr1 batch -- | Execute pow -- @@ -1911,19 +1905,19 @@ log2 = flip op1 af_log2 -- | Execute sqrt -- --- >>> A.sqrt (A.vector @Int 10 [1..]) +-- >>> A.sqrt (A.vector @Int 10 [ x * x | x <- [ 1 .. 10 ]]) -- ArrayFire Array -- [10 1 1 1] -- 1.0000 --- 1.4142 --- 1.7321 -- 2.0000 --- 2.2361 --- 2.4495 --- 2.6458 --- 2.8284 -- 3.0000 --- 3.1623 +-- 4.0000 +-- 5.0000 +-- 6.0000 +-- 7.0000 +-- 8.0000 +-- 9.0000 +-- 10.0000 sqrt :: AFType a => Array a @@ -1934,19 +1928,19 @@ sqrt = flip op1 af_sqrt -- | Execute cbrt -- --- >>> A.cbrt (A.vector @Int 10 [1..]) +-- >>> A.cbrt (A.vector @Int 10 [ x * x * x | x <- [ 1 .. 10 ]]) -- ArrayFire Array -- [10 1 1 1] -- 1.0000 --- 1.2599 --- 1.4422 --- 1.5874 --- 1.7100 --- 1.8171 --- 1.9129 -- 2.0000 --- 2.0801 --- 2.1544 +-- 3.0000 +-- 4.0000 +-- 5.0000 +-- 6.0000 +-- 7.0000 +-- 8.0000 +-- 9.0000 +-- 10.0000 cbrt :: AFType a => Array a @@ -2043,7 +2037,7 @@ isZero :: AFType a => Array a -- ^ Input array - -> Array a + -> Array CBool -- ^ Result of calling 'isZero' isZero = (`op1` af_iszero) @@ -2066,7 +2060,7 @@ isInf :: (Real a, AFType a) => Array a -- ^ Input array - -> Array a + -> Array CBool -- ^ will contain 1's where input is Inf or -Inf, and 0 otherwise. isInf = (`op1` af_isinf) @@ -2086,9 +2080,9 @@ isInf = (`op1` af_isinf) -- 1 -- 1 isNaN - :: forall a. (AFType a, Real a) + :: (AFType a, Real a) => Array a -- ^ Input array - -> Array a + -> Array CBool -- ^ Will contain 1's where input is NaN, and 0 otherwise. isNaN = (`op1` af_isnan) diff --git a/src/ArrayFire/Array.hs b/src/ArrayFire/Array.hs index b0abc01..a618a47 100644 --- a/src/ArrayFire/Array.hs +++ b/src/ArrayFire/Array.hs @@ -1,3 +1,4 @@ +-------------------------------------------------------------------------------- {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE PolyKinds #-} @@ -10,7 +11,7 @@ -------------------------------------------------------------------------------- -- | -- Module : ArrayFire.Array --- Copyright : David Johnson (c) 2019-2020 +-- Copyright : David Johnson (c) 2019-2026 -- License : BSD 3 -- Maintainer : David Johnson -- Stability : Experimental @@ -82,6 +83,20 @@ scalar x = mkArray [1] [x] vector :: AFType a => Int -> [a] -> Array a vector n = mkArray [n] . take n +-- | Construct an 'Array' from a flat list with explicit dimensions. +-- +-- Dimensions are in column-major order (first dim varies fastest). +-- Prefer 'fromVector' when data is already in a 'Data.Vector.Storable.Vector' +-- to avoid the intermediate list allocation. +-- +-- >>> fromList [2,3] [1..6 :: Double] +-- ArrayFire Array +-- [2 3 1 1] +-- 1.0000 3.0000 5.0000 +-- 2.0000 4.0000 6.0000 +fromList :: AFType a => [Int] -> [a] -> Array a +fromList = mkArray + -- | Smart constructor for creating a matrix 'Array' -- -- >>> A.matrix @Double (3,2) [[1,2,3],[4,5,6]] @@ -95,8 +110,8 @@ matrix :: AFType a => (Int,Int) -> [[a]] -> Array a matrix (x,y) = mkArray [x,y] . concat - . take y . fmap (take x) + . take y -- | Smart constructor for creating a cubic 'Array' -- @@ -116,9 +131,9 @@ cube (x,y,z) = mkArray [x,y,z] . concat . fmap concat - . take z . fmap (take y) . (fmap . fmap . take) x + . take z -- | Smart constructor for creating a tensor 'Array' -- @@ -140,16 +155,16 @@ cube (x,y,z) -- 2.0000 2.0000 -- 2.0000 2.0000 -- @ -tensor :: AFType a => (Int, Int,Int,Int) -> [[[[a]]]] -> Array a +tensor :: AFType a => (Int,Int,Int,Int) -> [[[[a]]]] -> Array a tensor (w,x,y,z) = mkArray [w,x,y,z] . concat . fmap concat . (fmap . fmap) concat - . take z - . (fmap . take) y - . (fmap . fmap . take) x . (fmap . fmap . fmap . take) w + . (fmap . fmap . take) x + . (fmap . take) y + . take z -- | Internal function for 'Array' construction -- @@ -177,28 +192,75 @@ mkArray -- ^ Returned array {-# NOINLINE mkArray #-} mkArray dims xs = - unsafePerformIO $ do - when (Prelude.length (take size xs) < size) $ do - let msg = "Invalid elements provided. " - <> "Expected " - <> show size - <> " elements received " - <> show (Prelude.length xs) - throwIO (AFException SizeError 203 msg) - dataPtr <- castPtr <$> newArray (Prelude.take size xs) + unsafePerformIO . mask_ $ do let ndims = fromIntegral (Prelude.length dims) - alloca $ \arrayPtr -> do - zeroOutArray arrayPtr + calloca $ \arrayPtr -> do dimsPtr <- newArray (DimT . fromIntegral <$> dims) - throwAFError =<< af_create_array arrayPtr dataPtr ndims dimsPtr dType - free dataPtr >> free dimsPtr + if size == 0 + then onException + (do throwAFError =<< af_create_handle arrayPtr ndims dimsPtr dType + free dimsPtr) + (free dimsPtr) + else do + when (Prelude.length (Prelude.take size xs) < size) $ do + free dimsPtr + let msg = "Invalid elements provided. " + <> "Expected " + <> show size + <> " elements received " + <> show (Prelude.length xs) + throwIO (AFException SizeError 203 msg) + dataPtr <- castPtr <$> newArray (Prelude.take size xs) + onException + (do throwAFError =<< af_create_array arrayPtr dataPtr ndims dimsPtr dType + free dataPtr >> free dimsPtr) + (free dataPtr >> free dimsPtr) arr <- peek arrayPtr Array <$> newForeignPtr af_release_array_finalizer arr where size = Prelude.product dims dType = afType (Proxy @array) --- af_err af_create_handle(af_array *arr, const unsigned ndims, const dim_t * const dims, const af_dtype type); +-- | Constructs an 'Array' from a 'Storable' 'Vector', avoiding the intermediate list allocation of 'mkArray'. +-- +-- The vector's contiguous buffer is handed straight to @af_create_array@, which +-- copies it into the 'Array' (and uploads to device memory on GPU backends), so +-- no intermediate Haskell list is built. +-- Throws 'AFException' if the vector length does not match the product of the given dimensions. +-- +-- >>> fromVector @Double [3] (Data.Vector.Storable.fromList [1,2,3]) +-- ArrayFire Array +-- [3 1 1 1] +-- 1.0000 +-- 2.0000 +-- 3.0000 +fromVector + :: forall a + . AFType a + => [Int] + -- ^ Dimensions + -> Vector a + -- ^ Source storable vector + -> Array a +{-# NOINLINE fromVector #-} +fromVector dims vec = + unsafePerformIO . mask_ $ do + let size = Prelude.product dims + ndims = fromIntegral (Prelude.length dims) + dType = afType (Proxy @a) + when (V.length vec /= size) $ + throwIO $ AFException SizeError 203 $ + "fromVector: dimension product " <> show size <> + " does not match vector length " <> show (V.length vec) + calloca $ \arrayPtr -> do + dimsPtr <- newArray (DimT . fromIntegral <$> dims) + onException + (V.unsafeWith vec $ \ptr -> do + throwAFError =<< af_create_array arrayPtr (castPtr ptr) ndims dimsPtr dType + free dimsPtr) + (free dimsPtr) + arr <- peek arrayPtr + Array <$> newForeignPtr af_release_array_finalizer arr -- | Copies an 'Array' to a new 'Array' -- @@ -213,8 +275,6 @@ copyArray -> Array a -- ^ Newly copied 'Array' copyArray = (`op1` af_copy_array) --- af_err af_write_array(af_array arr, const void *data, const size_t bytes, af_source src); --- af_err af_get_data_ptr(void *data, const af_array arr); -- | Retains an 'Array', increases reference count -- @@ -233,7 +293,7 @@ retainArray = -- | Retrieves 'Array' reference count -- -- >>> initialArray = scalar @Double 10 --- >>> retainedArray = retain initialArray +-- >>> retainedArray = retainArray initialArray -- >>> getDataRefCount retainedArray -- 2 -- @@ -246,8 +306,17 @@ getDataRefCount getDataRefCount = fromIntegral . (`infoFromArray` af_get_data_ref_count) --- af_err af_eval(af_array in); --- af_err af_eval_multiple(const int num, af_array *arrays); +-- | Force evaluation of a lazily-deferred 'Array', flushing any pending +-- computation in the JIT queue and returning the same array. +-- +-- >>> eval (vector @Double 10 [1..]) +-- ArrayFire Array +-- ... +-- +eval :: AFType a => Array a -> Array a +eval arr@(Array fptr) = unsafePerformIO . mask_ $ + withForeignPtr fptr (throwAFError <=< af_eval) >> pure arr +{-# NOINLINE eval #-} -- | Should manual evaluation occur -- @@ -479,11 +548,12 @@ isSparse a = toEnum . fromIntegral $ (a `infoFromArray` af_is_sparse) -- >>> toVector (vector @Double 10 [1..]) -- [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0] toVector :: forall a . AFType a => Array a -> Vector a -toVector arr@(Array fptr) = do +{-# NOINLINE toVector #-} +toVector arr@(Array fptr) = unsafePerformIO . mask_ . withForeignPtr fptr $ \arrPtr -> do let len = getElements arr size = len * getSizeOf (Proxy @a) - ptr <- mallocBytes (len * size) + ptr <- mallocBytes size throwAFError =<< af_get_data_ptr (castPtr ptr) arrPtr newFptr <- newForeignPtr finalizerFree ptr pure $ unsafeFromForeignPtr0 newFptr len @@ -500,6 +570,7 @@ toList = V.toList . toVector -- >>> getScalar (scalar @Double 22.0) :: Double -- 22.0 getScalar :: forall a b . (Storable a, AFType b) => Array b -> a +{-# NOINLINE getScalar #-} getScalar (Array fptr) = unsafePerformIO . mask_ . withForeignPtr fptr $ \arrPtr -> do alloca $ \ptr -> do diff --git a/src/ArrayFire/BLAS.hs b/src/ArrayFire/BLAS.hs index 321980a..8deb283 100644 --- a/src/ArrayFire/BLAS.hs +++ b/src/ArrayFire/BLAS.hs @@ -1,8 +1,10 @@ +-------------------------------------------------------------------------------- +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ViewPatterns #-} -------------------------------------------------------------------------------- -- | -- Module : ArrayFire.BLAS --- Copyright : David Johnson (c) 2019-2020 +-- Copyright : David Johnson (c) 2019-2026 -- License : BSD3 -- Maintainer : David Johnson -- Stability : Experimental @@ -31,8 +33,16 @@ -------------------------------------------------------------------------------- module ArrayFire.BLAS where +import Control.Exception (mask_) import Data.Complex +import Foreign.ForeignPtr (newForeignPtr, withForeignPtr) +import Foreign.Marshal.Alloc (alloca) +import Foreign.Marshal.Utils (fillBytes) +import Foreign.Ptr (Ptr, castPtr) +import Foreign.Storable (peek, poke, sizeOf) +import System.IO.Unsafe (unsafePerformIO) +import ArrayFire.Exception import ArrayFire.FFI import ArrayFire.Internal.BLAS import ArrayFire.Internal.Types @@ -69,6 +79,16 @@ matmul matmul arr1 arr2 prop1 prop2 = do op2 arr1 arr2 (\p a b -> af_matmul p a b (toMatProp prop1) (toMatProp prop2)) +-- | Plain matrix multiplication — shorthand for @'matmul' a b 'None' 'None'@. +-- +-- >>> mm (matrix @Double (2,2) [[1,0],[0,1]]) (matrix @Double (2,2) [[3,4],[5,6]]) +-- ArrayFire Array +-- [2 2 1 1] +-- 3.0000 5.0000 +-- 4.0000 6.0000 +mm :: AFType a => Array a -> Array a -> Array a +mm a b = matmul a b None None + -- | Scalar dot product between two vectors. Also referred to as the inner product. -- -- >>> dot (vector @Double 10 [1..]) (vector @Double 10 [1..]) None None @@ -139,6 +159,16 @@ transpose transpose arr1 (fromIntegral . fromEnum -> b) = arr1 `op1` (\x y -> af_transpose x y b) +-- | Real (non-conjugate) transpose — shorthand for @'transpose' a False@. +-- +-- >>> tr (matrix @Double (2,3) [[1,2],[3,4],[5,6]]) +-- ArrayFire Array +-- [3 2 1 1] +-- 1.0000 3.0000 5.0000 +-- 2.0000 4.0000 6.0000 +tr :: AFType a => Array a -> Array a +tr a = transpose a False + -- | Transposes a matrix. -- -- * Warning: This function mutates an array in-place, all subsequent references will be changed. Use carefully. @@ -167,3 +197,40 @@ transposeInPlace -> IO () transposeInPlace arr (fromIntegral . fromEnum -> b) = arr `inPlace` (`af_transpose_inplace` b) + +-- | General Matrix Multiply: C = alpha * op(A) * op(B) +-- +-- More general than 'matmul': supports per-element scaling and optional +-- transposition via 'MatProp'. +-- +-- >>> gemm None None 1.0 (matrix @Double (2,2) [[1,0],[0,1]]) (matrix @Double (2,2) [[3,4],[5,6]]) +-- ArrayFire Array +-- [2 2 1 1] +-- 3.0000 5.0000 +-- 4.0000 6.0000 +gemm + :: forall a . AFType a + => MatProp + -- ^ Transformation applied to A ('None', 'Trans', or 'CTrans') + -> MatProp + -- ^ Transformation applied to B ('None', 'Trans', or 'CTrans') + -> a + -- ^ Scalar alpha + -> Array a + -- ^ Matrix A + -> Array a + -- ^ Matrix B + -> Array a + -- ^ Result C = alpha * op(A) * op(B) +gemm opA opB alpha (Array fptrA) (Array fptrB) = + unsafePerformIO . mask_ $ + withForeignPtr fptrA $ \ptrA -> + withForeignPtr fptrB $ \ptrB -> + calloca $ \pOut -> + alloca $ \pAlpha -> + alloca $ \(pBeta :: Ptr a) -> do + poke pAlpha alpha + fillBytes pBeta 0 (sizeOf alpha) + throwAFError =<< af_gemm pOut (toMatProp opA) (toMatProp opB) (castPtr pAlpha) ptrA ptrB (castPtr pBeta) + Array <$> (newForeignPtr af_release_array_finalizer =<< peek pOut) +{-# NOINLINE gemm #-} diff --git a/src/ArrayFire/Backend.hs b/src/ArrayFire/Backend.hs index 7b9b14f..1abdc47 100644 --- a/src/ArrayFire/Backend.hs +++ b/src/ArrayFire/Backend.hs @@ -1,7 +1,7 @@ -------------------------------------------------------------------------------- -- | -- Module : ArrayFire.Backend --- Copyright : David Johnson (c) 2019-2020 +-- Copyright : David Johnson (c) 2019-2026 -- License : BSD 3 -- Maintainer : David Johnson -- Stability : Experimental diff --git a/src/ArrayFire/Data.hs b/src/ArrayFire/Data.hs index 8bcfe54..3201988 100644 --- a/src/ArrayFire/Data.hs +++ b/src/ArrayFire/Data.hs @@ -10,7 +10,7 @@ -------------------------------------------------------------------------------- -- | -- Module : ArrayFire.Data --- Copyright : David Johnson (c) 2019-2020 +-- Copyright : David Johnson (c) 2019-2026 -- License : BSD 3 -- Maintainer : David Johnson -- Stability : Experimental @@ -30,6 +30,7 @@ module ArrayFire.Data where import Control.Exception +import Control.Monad (when) import Data.Complex import Data.Int import Data.Proxy @@ -42,13 +43,37 @@ import Foreign.Storable import System.IO.Unsafe import Unsafe.Coerce +import Data.Bits + import ArrayFire.Exception import ArrayFire.FFI +import ArrayFire.Internal.Array (af_get_dims) import ArrayFire.Internal.Data import ArrayFire.Internal.Defines import ArrayFire.Internal.Types import ArrayFire.Arith +-- | Bitwise complement of every element in an 'Array' +-- +-- >>> A.bitNot (A.scalar @Int32 0) +-- ArrayFire Array +-- [1 1 1 1] +-- -1 +bitNot + :: (AFType a, Bits a) + => Array a + -> Array a +bitNot arr = arr `bitXor` ones + where + (d0, d1, d2, d3) = arr `infoFromArray4` af_get_dims + ones = constant + [ fromIntegral d0 + , fromIntegral d1 + , fromIntegral d2 + , fromIntegral d3 + ] + (complement zeroBits) + -- | Creates an 'Array' from a scalar value from given dimensions -- -- >>> constant @Double [2,2] 2.0 @@ -63,6 +88,7 @@ constant -> a -- ^ Scalar value -> Array a +{-# NOINLINE constant #-} constant dims val = case dtyp of x | x == c64 -> @@ -101,8 +127,7 @@ constant dims val = -> Array Double constant' dims' val' = unsafePerformIO . mask_ $ do - ptr <- alloca $ \ptrPtr -> do - zeroOutArray ptrPtr + ptr <- calloca $ \ptrPtr -> do withArray (fromIntegral <$> dims') $ \dimArray -> do throwAFError =<< af_constant ptrPtr val' n dimArray typ peek ptrPtr @@ -128,8 +153,7 @@ constant dims val = -- ^ Scalar val'ue -> Array (Complex arr) constantComplex dims' ((realToFrac -> x) :+ (realToFrac -> y)) = unsafePerformIO . mask_ $ do - ptr <- alloca $ \ptrPtr -> do - zeroOutArray ptrPtr + ptr <- calloca $ \ptrPtr -> do withArray (fromIntegral <$> dims') $ \dimArray -> do throwAFError =<< af_constant_complex ptrPtr x y n dimArray typ peek ptrPtr @@ -154,8 +178,7 @@ constant dims val = -- ^ Scalar val'ue -> Array Int constantLong dims' val' = unsafePerformIO . mask_ $ do - ptr <- alloca $ \ptrPtr -> do - zeroOutArray ptrPtr + ptr <- calloca $ \ptrPtr -> do withArray (fromIntegral <$> dims') $ \dimArray -> do throwAFError =<< af_constant_long ptrPtr (fromIntegral val') n dimArray peek ptrPtr @@ -177,8 +200,7 @@ constant dims val = -> Word64 -> Array Word64 constantULong dims' val' = unsafePerformIO . mask_ $ do - ptr <- alloca $ \ptrPtr -> do - zeroOutArray ptrPtr + ptr <- calloca $ \ptrPtr -> do withArray (fromIntegral <$> dims') $ \dimArray -> do throwAFError =<< af_constant_ulong ptrPtr (fromIntegral val') n dimArray peek ptrPtr @@ -191,7 +213,7 @@ constant dims val = -- | Creates a range of values in an Array -- --- >>> range @Double [10] (-1) +-- >>> arange @Double [10] (-1) -- ArrayFire Array -- [10 1 1 1] -- 0.0000 @@ -204,14 +226,15 @@ constant dims val = -- 7.0000 -- 8.0000 -- 9.0000 -range +arange :: forall a . AFType a => [Int] -> Int -> Array a -range dims (fromIntegral -> k) = unsafePerformIO $ do - ptr <- alloca $ \ptrPtr -> mask_ $ do +{-# NOINLINE arange #-} +arange dims (fromIntegral -> k) = unsafePerformIO . mask_ $ do + ptr <- alloca $ \ptrPtr -> do withArray (fromIntegral <$> dims) $ \dimArray -> do throwAFError =<< af_range ptrPtr n dimArray k typ peek ptrPtr @@ -252,11 +275,11 @@ iota -- ^ is array containing the number of repetitions of the unit dimensions -> Array a -- ^ is the generated array -iota dims tdims = unsafePerformIO $ do +{-# NOINLINE iota #-} +iota dims tdims = unsafePerformIO . mask_ $ do let dims' = take 4 (dims ++ repeat 1) tdims' = take 4 (tdims ++ repeat 1) - ptr <- alloca $ \ptrPtr -> mask_ $ do - zeroOutArray ptrPtr + ptr <- calloca $ \ptrPtr -> do withArray (fromIntegral <$> dims') $ \dimArray -> withArray (fromIntegral <$> tdims') $ \tdimArray -> do throwAFError =<< af_iota ptrPtr 4 dimArray 4 tdimArray typ @@ -280,10 +303,16 @@ identity => [Int] -- ^ Dimensions -> Array a +{-# NOINLINE identity #-} identity dims = unsafePerformIO . mask_ $ do + when (length dims > 4) $ + throwIO AFException + { afExceptionType = ArgError + , afExceptionCode = 202 + , afExceptionMsg = "identity: ndims must be <= 4" + } let dims' = take 4 (dims ++ repeat 1) - ptr <- alloca $ \ptrPtr -> mask_ $ do - zeroOutArray ptrPtr + ptr <- calloca $ \ptrPtr -> mask_ $ do withArray (fromIntegral <$> dims') $ \dimArray -> do throwAFError =<< af_identity ptrPtr n dimArray typ peek ptrPtr @@ -303,7 +332,7 @@ identity dims = unsafePerformIO . mask_ $ do -- 1.0000 0.0000 -- 0.0000 2.0000 diagCreate - :: AFType (a :: *) + :: AFType a => Array a -- ^ is the input array which is the diagonal -> Int @@ -320,7 +349,7 @@ diagCreate x (fromIntegral -> n) = -- 1.0000 -- 4.0000 diagExtract - :: AFType (a :: *) + :: AFType a => Array a -> Int -> Array a @@ -339,27 +368,29 @@ diagExtract x (fromIntegral -> n) = -- join :: Int - -> Array (a :: *) + -> Array a -> Array a -> Array a join (fromIntegral -> n) arr1 arr2 = op2 arr1 arr2 (\p a b -> af_join p n a b) -- | Join many Arrays together along a specified dimension -- --- *FIX ME* --- --- >>> joinMany 0 [1,2,3] +-- >>> joinMany 0 [vector @Int 3 [1..], vector @Int 3 [1..]] -- ArrayFire Array --- [3 1 1 1] --- 1.0000 2.0000 3.0000 --- +-- [6 1 1 1] +-- 1 +-- 2 +-- 3 +-- 1 +-- 2 +-- 3 joinMany :: Int -> [Array a] -> Array a +{-# NOINLINE joinMany #-} joinMany (fromIntegral -> n) (fmap (\(Array fp) -> fp) -> arrays) = unsafePerformIO . mask_ $ do - newPtr <- alloca $ \aPtr -> do - zeroOutArray aPtr + newPtr <- calloca $ \aPtr -> do (throwAFError =<<) $ withManyForeignPtr arrays $ \(fromIntegral -> nArrays) fPtrsPtr -> af_join_many aPtr n nArrays fPtrsPtr @@ -367,6 +398,10 @@ joinMany (fromIntegral -> n) (fmap (\(Array fp) -> fp) -> arrays) = unsafePerfor Array <$> newForeignPtr af_release_array_finalizer newPtr +-- | Marshals a list of 'ForeignPtr' into a temporary, contiguous C array of +-- raw pointers, keeping every 'ForeignPtr' alive for the duration of the +-- action. The continuation receives the number of pointers and a pointer to +-- the array. withManyForeignPtr :: [ForeignPtr a] -> (Int -> Ptr (Ptr a) -> IO b) -> IO b withManyForeignPtr fptrs action = go [] fptrs where @@ -385,7 +420,7 @@ withManyForeignPtr fptrs action = go [] fptrs -- 22.0000 22.0000 22.0000 22.0000 22.0000 -- tile - :: Array (a :: *) + :: Array a -> [Int] -> Array a tile a (take 4 . (++repeat 1) -> [x,y,z,w]) = @@ -406,12 +441,15 @@ tile _ _ = error "impossible" -- 22.0000 22.0000 22.0000 22.0000 22.0000 -- reorder - :: Array (a :: *) + :: Array a -> [Int] -> Array a -reorder a (take 4 . (++ repeat 0) -> [x,y,z,w]) = - a `op1` (\p k -> af_reorder p k (fromIntegral x) (fromIntegral y) (fromIntegral z) (fromIntegral w)) -reorder _ _ = error "impossible" +reorder a dims = + let base = take 4 dims + padding = filter (`notElem` base) [0..3] + in case take 4 (base ++ padding) of + [x,y,z,w] -> a `op1` (\p k -> af_reorder p k (fromIntegral x) (fromIntegral y) (fromIntegral z) (fromIntegral w)) + _ -> error "impossible" -- | Shift elements in an Array along a specified dimension (elements will wrap). -- @@ -424,7 +462,7 @@ reorder _ _ = error "impossible" -- 2.0000 -- shift - :: Array (a :: *) + :: Array a -> Int -> Int -> Int @@ -441,14 +479,13 @@ shift a (fromIntegral -> x) (fromIntegral -> y) (fromIntegral -> z) (fromIntegra -- 1.0000 2.0000 3.0000 -- moddims - :: forall a - . Array (a :: *) + :: Array a -> [Int] -> Array a +{-# NOINLINE moddims #-} moddims (Array fptr) dims = unsafePerformIO . mask_ . withForeignPtr fptr $ \ptr -> do - newPtr <- alloca $ \aPtr -> do - zeroOutArray aPtr + newPtr <- calloca $ \aPtr -> do withArray (fromIntegral <$> dims) $ \dimsPtr -> do throwAFError =<< af_moddims aPtr ptr n dimsPtr peek aPtr diff --git a/src/ArrayFire/Device.hs b/src/ArrayFire/Device.hs index 29a9e63..1d2a979 100644 --- a/src/ArrayFire/Device.hs +++ b/src/ArrayFire/Device.hs @@ -2,7 +2,7 @@ -------------------------------------------------------------------------------- -- | -- Module : ArrayFire.Device --- Copyright : David Johnson (c) 2019-2020 +-- Copyright : David Johnson (c) 2019-2026 -- License : BSD3 -- Maintainer : David Johnson -- Stability : Experimental @@ -18,10 +18,27 @@ -------------------------------------------------------------------------------- module ArrayFire.Device where +import Control.Exception (finally) import Foreign.C.String import ArrayFire.Internal.Device import ArrayFire.FFI +foreign import ccall unsafe "af_notify_shutdown" + afNotifyShutdown :: IO () + +-- | Bracket for ArrayFire usage. Wrap your @main@ (or top-level IO action) +-- with this to ensure the safe-finalizer shutdown flag is set before GHC's +-- finalizer thread runs, preventing a "double free or corruption" abort when +-- GC-managed array handles outlive ArrayFire's C++ allocator teardown. +-- +-- @ +-- main :: IO () +-- main = withArrayFire $ do +-- ... +-- @ +withArrayFire :: IO a -> IO a +withArrayFire action = action `finally` afNotifyShutdown + -- | Retrieve info from ArrayFire API -- -- @ @@ -46,8 +63,6 @@ afInit = afCall af_init getInfoString :: IO String getInfoString = peekCString =<< afCall1 (flip af_info_string 1) --- af_err af_device_info(char* d_name, char* d_platform, char *d_toolkit, char* d_compute); - -- | Retrieves count of devices -- -- >>> getDeviceCount @@ -55,7 +70,6 @@ getInfoString = peekCString =<< afCall1 (flip af_info_string 1) getDeviceCount :: IO Int getDeviceCount = fromIntegral <$> afCall1 af_get_device_count --- af_err af_get_dbl_support(bool* available, const int device); -- | Sets a device by 'Int' -- -- >>> setDevice 0 @@ -69,22 +83,3 @@ setDevice (fromIntegral -> x) = afCall (af_set_device x) -- 0 getDevice :: IO Int getDevice = fromIntegral <$> afCall1 af_get_device - --- af_err af_sync(const int device); --- af_err af_alloc_device(void **ptr, const dim_t bytes); --- af_err af_free_device(void *ptr); --- af_err af_alloc_pinned(void **ptr, const dim_t bytes); --- af_err af_free_pinned(void *ptr); --- af_err af_alloc_host(void **ptr, const dim_t bytes); --- af_err af_free_host(void *ptr); --- af_err af_device_array(af_array *arr, const void *data, const unsigned ndims, const dim_t * const dims, const af_dtype type); --- af_err af_device_mem_info(size_t *alloc_bytes, size_t *alloc_buffers, size_t *lock_bytes, size_t *lock_buffers); --- af_err af_print_mem_info(const char *msg, const int device_id); --- af_err af_device_gc(); --- af_err af_set_mem_step_size(const size_t step_bytes); --- af_err af_get_mem_step_size(size_t *step_bytes); --- af_err af_lock_device_ptr(const af_array arr); --- af_err af_unlock_device_ptr(const af_array arr); --- af_err af_lock_array(const af_array arr); --- af_err af_is_locked_array(bool *res, const af_array arr); --- af_err af_get_device_ptr(void **ptr, const af_array arr); diff --git a/src/ArrayFire/Exception.hs b/src/ArrayFire/Exception.hs index bc8a12d..b647760 100644 --- a/src/ArrayFire/Exception.hs +++ b/src/ArrayFire/Exception.hs @@ -3,7 +3,7 @@ -------------------------------------------------------------------------------- -- | -- Module : ArrayFire.Exception --- Copyright : David Johnson (c) 2019-2020 +-- Copyright : David Johnson (c) 2019-2026 -- License : BSD 3 -- Maintainer : David Johnson -- Stability : Experimental @@ -114,5 +114,6 @@ foreign import ccall unsafe "&af_release_random_engine" foreign import ccall unsafe "&af_destroy_window" af_release_window_finalizer :: FunPtr (AFWindow -> IO ()) -foreign import ccall unsafe "&af_release_array" +foreign import ccall unsafe "&af_release_array_safe" af_release_array_finalizer :: FunPtr (AFArray -> IO ()) + diff --git a/src/ArrayFire/FFI.hs b/src/ArrayFire/FFI.hs index e776ace..254cdc6 100644 --- a/src/ArrayFire/FFI.hs +++ b/src/ArrayFire/FFI.hs @@ -4,12 +4,18 @@ -------------------------------------------------------------------------------- -- | -- Module : ArrayFire.FFI --- Copyright : David Johnson (c) 2019-2020 +-- Copyright : David Johnson (c) 2019-2026 -- License : BSD 3 -- Maintainer : David Johnson -- Stability : Experimental -- Portability : GHC -- +-- Internal marshalling combinators that bridge the high-level API modules and +-- the raw @ArrayFire.Internal.*@ FFI bindings. Each combinator unwraps the +-- managed handles ('Array', 'Window', 'Features', 'RandomEngine'), allocates +-- the output pointers, invokes the supplied C function, checks the returned +-- 'AFErr' with 'throwAFError', and attaches the appropriate finalizer to any +-- newly-created handle. These helpers are not part of the public API. -------------------------------------------------------------------------------- module ArrayFire.FFI where @@ -28,8 +34,24 @@ import Foreign.Storable import Foreign.Ptr import Foreign.C import Foreign.Marshal.Alloc +import Foreign.Marshal.Utils (fillBytes) import System.IO.Unsafe +-- | Like 'alloca' but zero-initialises the memory before handing the pointer +-- to the continuation. Prevents uninitialized stack garbage from leaking into +-- output scalars when the C function does not write the imaginary-part pointer +-- for real-valued arrays (e.g. af_mean_all_weighted). +calloca :: forall a b. Storable a => (Ptr a -> IO b) -> IO b +calloca f = alloca $ \p -> fillBytes p 0 (sizeOf (undefined :: a)) >> f p + +foreign import ccall unsafe "af_cast" + af_cast :: Ptr AFArray -> AFArray -> AFDtype -> IO AFErr + +foreign import ccall unsafe "af_release_array" + af_release_array_ffi :: AFArray -> IO AFErr + +-- | Applies a C function that takes three input 'Array's and produces a single +-- output 'Array'. op3 :: Array b -> Array a @@ -38,17 +60,19 @@ op3 -> Array a {-# NOINLINE op3 #-} op3 (Array fptr1) (Array fptr2) (Array fptr3) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do withForeignPtr fptr1 $ \ptr1 -> withForeignPtr fptr2 $ \ptr2 -> do withForeignPtr fptr3 $ \ptr3 -> do ptr <- - alloca $ \ptrInput -> do + calloca $ \ptrInput -> do throwAFError =<< op ptrInput ptr1 ptr2 ptr3 peek ptrInput fptr <- newForeignPtr af_release_array_finalizer ptr pure (Array fptr) +-- | Like 'op3', but specialised to two 'Int32' index 'Array's alongside the +-- primary input. op3Int :: Array a -> Array Int32 @@ -57,17 +81,19 @@ op3Int -> Array a {-# NOINLINE op3Int #-} op3Int (Array fptr1) (Array fptr2) (Array fptr3) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do withForeignPtr fptr1 $ \ptr1 -> withForeignPtr fptr2 $ \ptr2 -> do withForeignPtr fptr3 $ \ptr3 -> do ptr <- - alloca $ \ptrInput -> do + calloca $ \ptrInput -> do throwAFError =<< op ptrInput ptr1 ptr2 ptr3 peek ptrInput fptr <- newForeignPtr af_release_array_finalizer ptr pure (Array fptr) +-- | Applies a C function that takes two input 'Array's and produces a single +-- output 'Array'. op2 :: Array b -> Array a @@ -75,16 +101,18 @@ op2 -> Array c {-# NOINLINE op2 #-} op2 (Array fptr1) (Array fptr2) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do withForeignPtr fptr1 $ \ptr1 -> withForeignPtr fptr2 $ \ptr2 -> do ptr <- - alloca $ \ptrInput -> do + calloca $ \ptrInput -> do throwAFError =<< op ptrInput ptr1 ptr2 peek ptrInput fptr <- newForeignPtr af_release_array_finalizer ptr pure (Array fptr) +-- | Like 'op2', but for comparison operations whose output 'Array' holds +-- boolean ('CBool') values. op2bool :: Array b -> Array a @@ -92,44 +120,48 @@ op2bool -> Array CBool {-# NOINLINE op2bool #-} op2bool (Array fptr1) (Array fptr2) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do withForeignPtr fptr1 $ \ptr1 -> withForeignPtr fptr2 $ \ptr2 -> do ptr <- - alloca $ \ptrInput -> do + calloca $ \ptrInput -> do throwAFError =<< op ptrInput ptr1 ptr2 peek ptrInput fptr <- newForeignPtr af_release_array_finalizer ptr pure (Array fptr) +-- | Applies a C function that takes one input 'Array' and produces a pair of +-- output 'Array's. op2p :: Array a -> (Ptr AFArray -> Ptr AFArray -> AFArray -> IO AFErr) - -> (Array a, Array a) + -> (Array a, Array b) {-# NOINLINE op2p #-} op2p (Array fptr1) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do (x,y) <- withForeignPtr fptr1 $ \ptr1 -> do - alloca $ \ptrInput1 -> do - alloca $ \ptrInput2 -> do + calloca $ \ptrInput1 -> + calloca $ \ptrInput2 -> do throwAFError =<< op ptrInput1 ptrInput2 ptr1 (,) <$> peek ptrInput1 <*> peek ptrInput2 fptrA <- newForeignPtr af_release_array_finalizer x fptrB <- newForeignPtr af_release_array_finalizer y pure (Array fptrA, Array fptrB) +-- | Applies a C function that takes one input 'Array' and produces a triple of +-- output 'Array's (e.g. an SVD or LU decomposition). op3p :: Array a -> (Ptr AFArray -> Ptr AFArray -> Ptr AFArray -> AFArray -> IO AFErr) -> (Array a, Array a, Array a) {-# NOINLINE op3p #-} op3p (Array fptr1) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do (x,y,z) <- withForeignPtr fptr1 $ \ptr1 -> do - alloca $ \ptrInput1 -> do - alloca $ \ptrInput2 -> do - alloca $ \ptrInput3 -> do + calloca $ \ptrInput1 -> + calloca $ \ptrInput2 -> + calloca $ \ptrInput3 -> do throwAFError =<< op ptrInput1 ptrInput2 ptrInput3 ptr1 (,,) <$> peek ptrInput1 <*> peek ptrInput2 <*> peek ptrInput3 fptrA <- newForeignPtr af_release_array_finalizer x @@ -137,6 +169,8 @@ op3p (Array fptr1) op = fptrC <- newForeignPtr af_release_array_finalizer z pure (Array fptrA, Array fptrB, Array fptrC) +-- | Like 'op3p', but the C function also writes back a single 'Storable' +-- scalar in addition to the three output 'Array's. op3p1 :: Storable b => Array a @@ -144,11 +178,11 @@ op3p1 -> (Array a, Array a, Array a, b) {-# NOINLINE op3p1 #-} op3p1 (Array fptr1) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do (x,y,z,g) <- withForeignPtr fptr1 $ \ptr1 -> do - alloca $ \ptrInput1 -> do - alloca $ \ptrInput2 -> do - alloca $ \ptrInput3 -> do + calloca $ \ptrInput1 -> + calloca $ \ptrInput2 -> + calloca $ \ptrInput3 -> alloca $ \ptrInput4 -> do throwAFError =<< op ptrInput1 ptrInput2 ptrInput3 ptrInput4 ptr1 (,,,) <$> peek ptrInput1 @@ -160,6 +194,8 @@ op3p1 (Array fptr1) op = fptrC <- newForeignPtr af_release_array_finalizer z pure (Array fptrA, Array fptrB, Array fptrC, g) +-- | Applies a C function that takes two input 'Array's and produces a pair of +-- output 'Array's. op2p2 :: Array a -> Array a @@ -167,18 +203,58 @@ op2p2 -> (Array a, Array a) {-# NOINLINE op2p2 #-} op2p2 (Array fptr1) (Array fptr2) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do (x,y) <- withForeignPtr fptr1 $ \ptr1 -> do withForeignPtr fptr2 $ \ptr2 -> do - alloca $ \ptrInput1 -> do - alloca $ \ptrInput2 -> do + calloca $ \ptrInput1 -> + calloca $ \ptrInput2 -> do throwAFError =<< op ptrInput1 ptrInput2 ptr1 ptr2 (,) <$> peek ptrInput1 <*> peek ptrInput2 fptrA <- newForeignPtr af_release_array_finalizer x fptrB <- newForeignPtr af_release_array_finalizer y pure (Array fptrA, Array fptrB) +-- | Key/value variant of 'op2p2' used by sort-by-key operations. The input key +-- 'Array' is cast down to @s32@ before the C call (ArrayFire requires 32-bit +-- keys) and the resulting key 'Array' is cast back up to @s64@, releasing the +-- intermediate handles along the way. +op2p2kv + :: Array Int + -> Array a + -> (Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> IO AFErr) + -> (Array Int, Array b) +{-# NOINLINE op2p2kv #-} +op2p2kv (Array fptr1) (Array fptr2) op = + unsafePerformIO . mask_ $ do + (x, y) <- + withForeignPtr fptr1 $ \ptr1 -> + withForeignPtr fptr2 $ \ptr2 -> do + castedKey <- calloca $ \p -> do + throwAFError =<< af_cast p ptr1 s32 + peek p + calloca $ \ptrOutput1 -> + calloca $ \ptrOutput2 -> do + onException + (throwAFError =<< op ptrOutput1 ptrOutput2 castedKey ptr2) + (af_release_array_ffi castedKey) + _ <- af_release_array_ffi castedKey + outKey <- peek ptrOutput1 + outVal <- peek ptrOutput2 + finalKey <- calloca $ \p -> do + onException + (throwAFError =<< af_cast p outKey s64) + (af_release_array_ffi outKey >> af_release_array_ffi outVal) + peek p + _ <- af_release_array_ffi outKey + pure (finalKey, outVal) + fptrA <- newForeignPtr af_release_array_finalizer x + fptrB <- newForeignPtr af_release_array_finalizer y + pure (Array fptrA, Array fptrB) + +-- | Runs a C function that constructs a fresh 'Array' (taking no input +-- 'Array'), returning the result in 'IO'. The output pointer is zeroed before +-- the call so the finalizer is safe even if construction fails. createArray' :: (Ptr AFArray -> IO AFErr) -> IO (Array a) @@ -186,13 +262,15 @@ createArray' createArray' op = mask_ $ do ptr <- - alloca $ \ptrInput -> do - zeroOutArray ptrInput + calloca $ \ptrInput -> do throwAFError =<< op ptrInput peek ptrInput fptr <- newForeignPtr af_release_array_finalizer ptr pure (Array fptr) +-- | Pure counterpart of 'createArray'' for constructing an 'Array' from a C +-- function that takes no input 'Array'. The effect is hidden behind +-- 'unsafePerformIO'. createArray :: (Ptr AFArray -> IO AFErr) -> Array a @@ -200,25 +278,28 @@ createArray createArray op = unsafePerformIO . mask_ $ do ptr <- - alloca $ \ptrInput -> do - zeroOutArray ptrInput + calloca $ \ptrInput -> do throwAFError =<< op ptrInput peek ptrInput fptr <- newForeignPtr af_release_array_finalizer ptr pure (Array fptr) +-- | Runs a C function that constructs a 'Window' handle, attaching the +-- window-release finalizer to the result. createWindow' :: (Ptr AFWindow -> IO AFErr) -> IO Window createWindow' op = mask_ $ do ptr <- - alloca $ \ptrInput -> do + calloca $ \ptrInput -> do throwAFError =<< op ptrInput peek ptrInput fptr <- newForeignPtr af_release_window_finalizer ptr pure (Window fptr) +-- | Runs a C function against an existing 'Window' for its side effects, +-- returning unit. opw :: Window -> (AFWindow -> IO AFErr) @@ -226,6 +307,8 @@ opw opw (Window fptr) op = mask_ . withForeignPtr fptr $ (throwAFError <=< op) +-- | Runs a C function against an existing 'Window' that writes back a single +-- 'Storable' value, returning it. opw1 :: Storable a => Window @@ -238,37 +321,25 @@ opw1 (Window fptr) op throwAFError =<< op p ptr peek p -op1d - :: Array a - -> (Ptr AFArray -> AFArray -> IO AFErr) - -> Array b -{-# NOINLINE op1d #-} -op1d (Array fptr1) op = - unsafePerformIO $ do - withForeignPtr fptr1 $ \ptr1 -> do - ptr <- - alloca $ \ptrInput -> do - throwAFError =<< op ptrInput ptr1 - peek ptrInput - fptr <- newForeignPtr af_release_array_finalizer ptr - pure (Array fptr) - - +-- | Applies a C function that takes a single input 'Array' and produces a +-- single output 'Array'. op1 :: Array a -> (Ptr AFArray -> AFArray -> IO AFErr) - -> Array a + -> Array b {-# NOINLINE op1 #-} op1 (Array fptr1) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do withForeignPtr fptr1 $ \ptr1 -> do ptr <- - alloca $ \ptrInput -> do + calloca $ \ptrInput -> do throwAFError =<< op ptrInput ptr1 peek ptrInput fptr <- newForeignPtr af_release_array_finalizer ptr pure (Array fptr) +-- | Applies a C function that takes a single input 'Features' and produces a +-- new 'Features' handle. op1f :: Features -> (Ptr AFFeatures -> AFFeatures -> IO AFErr) @@ -278,12 +349,14 @@ op1f (Features x) op = unsafePerformIO . mask_ $ do withForeignPtr x $ \ptr1 -> do ptr <- - alloca $ \ptrInput -> do + calloca $ \ptrInput -> do throwAFError =<< op ptrInput ptr1 peek ptrInput fptr <- newForeignPtr af_release_features ptr pure (Features fptr) +-- | Applies a C function that takes a single input 'RandomEngine' and produces +-- a new 'RandomEngine' handle, returned in 'IO'. op1re :: RandomEngine -> (Ptr AFRandomEngine -> AFRandomEngine -> IO AFErr) @@ -291,12 +364,15 @@ op1re op1re (RandomEngine x) op = mask_ $ withForeignPtr x $ \ptr1 -> do ptr <- - alloca $ \ptrInput -> do + calloca $ \ptrInput -> do throwAFError =<< op ptrInput ptr1 peek ptrInput fptr <- newForeignPtr af_release_random_engine_finalizer ptr pure (RandomEngine fptr) +-- | Applies a C function that takes a single input 'Array' and produces both a +-- 'Storable' scalar and an output 'Array' (e.g. an operation returning a value +-- and its location). op1b :: Storable b => Array a @@ -304,21 +380,26 @@ op1b -> (b, Array a) {-# NOINLINE op1b #-} op1b (Array fptr1) op = - unsafePerformIO $ + unsafePerformIO . mask_ $ withForeignPtr fptr1 $ \ptr1 -> do (y,x) <- - alloca $ \ptrInput1 -> do + calloca $ \ptrInput1 -> alloca $ \ptrInput2 -> do throwAFError =<< op ptrInput1 ptrInput2 ptr1 (,) <$> peek ptrInput1 <*> peek ptrInput2 fptr <- newForeignPtr af_release_array_finalizer y pure (x, Array fptr) +-- | Runs an 'AFErr'-returning C action purely for its side effects, throwing +-- on a non-success status. afCall :: IO AFErr -> IO () afCall = mask_ . (throwAFError =<<) +-- | Loads an image from the given file path into a new 'Array'. The 'Bool' +-- flag selects whether the image is loaded in colour, and is marshalled to the +-- 'CBool' expected by the C function. loadAFImage :: String -> Bool @@ -326,32 +407,38 @@ loadAFImage -> IO (Array a) loadAFImage s (fromIntegral . fromEnum -> b) op = mask_ $ withCString s $ \cstr -> do - p <- alloca $ \ptr -> do + p <- calloca $ \ptr -> do throwAFError =<< op ptr cstr b peek ptr fptr <- newForeignPtr af_release_array_finalizer p pure (Array fptr) +-- | Loads an image from the given file path into a new 'Array' in its native +-- format, without any colour-space conversion. loadAFImageNative :: String -> (Ptr AFArray -> CString -> IO AFErr) -> IO (Array a) loadAFImageNative s op = mask_ $ withCString s $ \cstr -> do - p <- alloca $ \ptr -> do + p <- calloca $ \ptr -> do throwAFError =<< op ptr cstr peek ptr fptr <- newForeignPtr af_release_array_finalizer p pure (Array fptr) +-- | Runs a C function that mutates an 'Array' in place, returning unit. inPlace :: Array a -> (AFArray -> IO AFErr) -> IO () inPlace (Array fptr) op = mask_ . withForeignPtr fptr $ (throwAFError <=< op) +-- | Runs a C function that mutates a 'RandomEngine' in place, returning unit. inPlaceEng :: RandomEngine -> (AFRandomEngine -> IO AFErr) -> IO () inPlaceEng (RandomEngine fptr) op = mask_ . withForeignPtr fptr $ (throwAFError <=< op) +-- | Runs a C function that writes back a single 'Storable' value through an +-- output pointer, returning that value in 'IO'. afCall1 :: Storable a => (Ptr a -> IO AFErr) @@ -361,6 +448,8 @@ afCall1 op = throwAFError =<< op ptrInput peek ptrInput +-- | Pure counterpart of 'afCall1' for reading back a single 'Storable' value. +-- The effect is hidden behind 'unsafePerformIO'. afCall1' :: Storable a => (Ptr a -> IO AFErr) @@ -382,13 +471,15 @@ featuresToArray featuresToArray (Features fptr1) op = unsafePerformIO . mask_ $ do withForeignPtr fptr1 $ \ptr1 -> do - alloca $ \ptrInput -> do + calloca $ \ptrInput -> do throwAFError =<< op ptrInput ptr1 - alloca $ \retainedArray -> do + calloca $ \retainedArray -> do throwAFError =<< af_retain_array retainedArray =<< peek ptrInput fptr <- newForeignPtr af_release_array_finalizer =<< peek retainedArray pure (Array fptr) +-- | Reads back a single 'Storable' scalar describing a 'Features' handle (for +-- example its feature count), hiding the effect behind 'unsafePerformIO'. infoFromFeatures :: Storable a => Features @@ -396,12 +487,14 @@ infoFromFeatures -> a {-# NOINLINE infoFromFeatures #-} infoFromFeatures (Features fptr1) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do withForeignPtr fptr1 $ \ptr1 -> do alloca $ \ptrInput -> do throwAFError =<< op ptrInput ptr1 peek ptrInput +-- | Reads back a single 'Storable' scalar describing a 'RandomEngine' (for +-- example its seed or type), returning it in 'IO'. infoFromRandomEngine :: Storable a => RandomEngine @@ -414,6 +507,7 @@ infoFromRandomEngine (RandomEngine fptr1) op = throwAFError =<< op ptrInput ptr1 peek ptrInput +-- | Saves an 'Array' to the given file path using the supplied C function. afSaveImage :: Array b -> String @@ -424,6 +518,8 @@ afSaveImage (Array fptr1) str op = withForeignPtr fptr1 $ throwAFError <=< op cstr +-- | Reads back a single 'Storable' scalar describing an 'Array' (for example a +-- dimension or count), hiding the effect behind 'unsafePerformIO'. infoFromArray :: Storable a => Array b @@ -431,59 +527,67 @@ infoFromArray -> a {-# NOINLINE infoFromArray #-} infoFromArray (Array fptr1) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do withForeignPtr fptr1 $ \ptr1 -> do alloca $ \ptrInput -> do throwAFError =<< op ptrInput ptr1 peek ptrInput +-- | Like 'infoFromArray', but reads back a pair of 'Storable' scalars from a +-- single input 'Array'. infoFromArray2 - :: (Storable a, Storable b) + :: forall a b arr. (Storable a, Storable b) => Array arr -> (Ptr a -> Ptr b -> AFArray -> IO AFErr) -> (a,b) {-# NOINLINE infoFromArray2 #-} infoFromArray2 (Array fptr1) op = - unsafePerformIO $ do - withForeignPtr fptr1 $ \ptr1 -> do - alloca $ \ptrInput1 -> do - alloca $ \ptrInput2 -> do + unsafePerformIO . mask_ $ do + withForeignPtr fptr1 $ \ptr1 -> + calloca $ \ptrInput1 -> + calloca $ \ptrInput2 -> do throwAFError =<< op ptrInput1 ptrInput2 ptr1 (,) <$> peek ptrInput1 <*> peek ptrInput2 +-- | Like 'infoFromArray2', but reads back a pair of 'Storable' scalars derived +-- from two input 'Array's. infoFromArray22 - :: (Storable a, Storable b) + :: forall a b arr. (Storable a, Storable b) => Array arr -> Array arr -> (Ptr a -> Ptr b -> AFArray -> AFArray -> IO AFErr) -> (a,b) {-# NOINLINE infoFromArray22 #-} infoFromArray22 (Array fptr1) (Array fptr2) op = - unsafePerformIO $ do - withForeignPtr fptr1 $ \ptr1 -> do - withForeignPtr fptr2 $ \ptr2 -> do - alloca $ \ptrInput1 -> do - alloca $ \ptrInput2 -> do - throwAFError =<< op ptrInput1 ptrInput2 ptr1 ptr2 - (,) <$> peek ptrInput1 <*> peek ptrInput2 + unsafePerformIO . mask_ $ do + withForeignPtr fptr1 $ \ptr1 -> + withForeignPtr fptr2 $ \ptr2 -> + calloca $ \ptrInput1 -> + calloca $ \ptrInput2 -> do + throwAFError =<< op ptrInput1 ptrInput2 ptr1 ptr2 + (,) <$> peek ptrInput1 <*> peek ptrInput2 +-- | Like 'infoFromArray', but reads back three 'Storable' scalars from a +-- single input 'Array'. infoFromArray3 - :: (Storable a, Storable b, Storable c) + :: forall a b c arr. (Storable a, Storable b, Storable c) => Array arr -> (Ptr a -> Ptr b -> Ptr c -> AFArray -> IO AFErr) -> (a,b,c) {-# NOINLINE infoFromArray3 #-} infoFromArray3 (Array fptr1) op = - unsafePerformIO $ - withForeignPtr fptr1 $ \ptr1 -> do - alloca $ \ptrInput1 -> do - alloca $ \ptrInput2 -> do - alloca $ \ptrInput3 -> do + unsafePerformIO . mask_ $ + withForeignPtr fptr1 $ \ptr1 -> + calloca $ \ptrInput1 -> + calloca $ \ptrInput2 -> + calloca $ \ptrInput3 -> do throwAFError =<< op ptrInput1 ptrInput2 ptrInput3 ptr1 (,,) <$> peek ptrInput1 <*> peek ptrInput2 <*> peek ptrInput3 +-- | Like 'infoFromArray', but reads back four 'Storable' scalars from a single +-- input 'Array' (for example all four dimensions). infoFromArray4 :: (Storable a, Storable b, Storable c, Storable d) => Array arr @@ -491,7 +595,7 @@ infoFromArray4 -> (a,b,c,d) {-# NOINLINE infoFromArray4 #-} infoFromArray4 (Array fptr1) op = - unsafePerformIO $ + unsafePerformIO . mask_ $ withForeignPtr fptr1 $ \ptr1 -> alloca $ \ptrInput1 -> alloca $ \ptrInput2 -> @@ -503,5 +607,3 @@ infoFromArray4 (Array fptr1) op = <*> peek ptrInput3 <*> peek ptrInput4 -foreign import ccall unsafe "zeroOutArray" - zeroOutArray :: Ptr AFArray -> IO () diff --git a/src/ArrayFire/Features.hs b/src/ArrayFire/Features.hs index a84f58d..7e6cf34 100644 --- a/src/ArrayFire/Features.hs +++ b/src/ArrayFire/Features.hs @@ -2,7 +2,7 @@ -------------------------------------------------------------------------------- -- | -- Module : ArrayFire.Features --- Copyright : David Johnson (c) 2019-2020 +-- Copyright : David Johnson (c) 2019-2026 -- License : BSD 3 -- Maintainer : David Johnson -- Stability : Experimental @@ -17,6 +17,7 @@ -------------------------------------------------------------------------------- module ArrayFire.Features where +import Control.Exception (mask_) import Foreign.Marshal import Foreign.Storable import Foreign.ForeignPtr @@ -34,8 +35,9 @@ import ArrayFire.Exception createFeatures :: Int -> Features +{-# NOINLINE createFeatures #-} createFeatures (fromIntegral -> n) = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do ptr <- alloca $ \ptrInput -> do throwAFError =<< ptrInput `af_create_features` n diff --git a/src/ArrayFire/Graphics.hs b/src/ArrayFire/Graphics.hs index e657625..12cb55f 100644 --- a/src/ArrayFire/Graphics.hs +++ b/src/ArrayFire/Graphics.hs @@ -2,7 +2,7 @@ -------------------------------------------------------------------------------- -- | -- Module : ArrayFire.Graphics --- Copyright : David Johnson (c) 2019-2020 +-- Copyright : David Johnson (c) 2019-2026 -- License : BSD 3 -- Maintainer : David Johnson -- Stability : Experimental @@ -492,13 +492,13 @@ drawVectorField2d -> Array a -- ^ is an 'Array' with the x-axis points -> Array a - -- ^ is the window handle + -- ^ is an 'Array' with the y-axis points -> Array a - -- ^ is the window handle + -- ^ is an 'Array' with the x-axis directions -> Array a - -- ^ is the window handle + -- ^ is an 'Array' with the y-axis directions -> Cell - -- ^ is the window handle + -- ^ is structure 'Cell' that has the properties that are used for the current rendering. -> IO () drawVectorField2d (Window w) (Array fptr1) (Array fptr2) (Array fptr3) (Array fptr4) cell = mask_ $ do diff --git a/src/ArrayFire/Image.hs b/src/ArrayFire/Image.hs index 9ae11d8..d8937cb 100644 --- a/src/ArrayFire/Image.hs +++ b/src/ArrayFire/Image.hs @@ -4,7 +4,7 @@ -------------------------------------------------------------------------------- -- | -- Module : ArrayFire.Image --- Copyright : David Johnson (c) 2019-2020 +-- Copyright : David Johnson (c) 2019-2026 -- License : BSD 3 -- Maintainer : David Johnson -- Stability : Experimental @@ -25,7 +25,6 @@ import Data.Word import ArrayFire.Internal.Types import ArrayFire.Internal.Image import ArrayFire.FFI -import ArrayFire.Arith -- | Calculates the gradient of an image -- @@ -260,7 +259,7 @@ histogram -> Array Word32 -- ^ (type u32) is the histogram for input array in histogram a (fromIntegral -> b) c d = - cast (a `op1` (\ptr x -> af_histogram ptr x b c d)) + a `op1` (\ptr x -> af_histogram ptr x b c d) -- | Dilation(morphological operator) for images. -- diff --git a/src/ArrayFire/Index.hs b/src/ArrayFire/Index.hs index ae1eaa4..96f88df 100644 --- a/src/ArrayFire/Index.hs +++ b/src/ArrayFire/Index.hs @@ -1,7 +1,7 @@ -------------------------------------------------------------------------------- -- | -- Module : ArrayFire.Index --- Copyright : David Johnson (c) 2019-2020 +-- Copyright : David Johnson (c) 2019-2026 -- License : BSD 3 -- Maintainer : David Johnson -- Stability : Experimental @@ -10,6 +10,7 @@ -- Functions for indexing into an 'Array' -- -------------------------------------------------------------------------------- +{-# LANGUAGE FlexibleInstances #-} module ArrayFire.Index where import ArrayFire.Internal.Index @@ -29,6 +30,7 @@ index -> [Seq] -- ^ 'Seq' to use for indexing -> Array a +{-# NOINLINE index #-} index (Array fptr) seqs = unsafePerformIO . mask_ . withForeignPtr fptr $ \ptr -> do alloca $ \aptr -> @@ -41,65 +43,156 @@ index (Array fptr) seqs = n = fromIntegral (length seqs) -- | Lookup an Array by keys along a specified dimension -lookup - :: Array a +lookup + :: Array a -- ^ Input Array - -> Array Int + -> Array Int -- ^ Indices - -> Int + -> Int -- ^ Dimension -> Array a lookup a b n = op2 a b $ \p x y -> af_lookup p x y (fromIntegral n) --- | A special value representing the entire axis of an 'Array'. -span :: Seq -span = Seq 1 1 0 -- From include/af/seq.h - -- Hard-coded here because FFI cannot import static const values. - --- af_err af_assign_seq( af_array *out, const af_array lhs, const unsigned ndims, const af_seq* const indices, const af_array rhs); --- | Calculates 'mean' of 'Array' along user-specified dimension. +-- | Assign values into an 'Array' range defined by 'Seq' indices -- -- @ --- >>> print $ mean 0 ( vector @Int 10 [1..] ) +-- >>> let a = vector \@Double 5 [1..] +-- >>> assignSeq a [Seq 1 3 1] (vector \@Double 3 [0,0,0]) -- @ +assignSeq + :: Array a + -- ^ Destination array + -> [Seq] + -- ^ Indices defining the range to assign into + -> Array a + -- ^ Source array + -> Array a + -- ^ Result with values written at the specified indices +{-# NOINLINE assignSeq #-} +assignSeq (Array fptr) seqs (Array rhsFptr) = + unsafePerformIO . mask_ $ + withForeignPtr fptr $ \ptr -> + withForeignPtr rhsFptr $ \rhsPtr -> + withArray (toAFSeq <$> seqs) $ \sptr -> + alloca $ \aptr -> do + throwAFError =<< af_assign_seq aptr ptr n sptr rhsPtr + Array <$> (newForeignPtr af_release_array_finalizer =<< peek aptr) + where + n = fromIntegral (length seqs) + +-- | Index into an 'Array' using generalized 'Index' values (arrays or sequences) +-- -- @ --- ArrayFire Array --- [1 1 1 1] --- 5.5000 +-- >>> let a = matrix \@Double (3,3) [[1..],[1..],[1..]] +-- >>> indexGen a [seqIdx (Seq 0 1 1) False, seqIdx (Seq 0 1 1) False] -- @ --- assignSeq :: Array a -> Int -> [Seq] -> Array a -> Array a --- assignSeq = error "Not implemneted" +indexGen + :: Array a + -- ^ Input array + -> [Index] + -- ^ List of 'Index' values (one per dimension) + -> Array a + -- ^ Indexed result +{-# NOINLINE indexGen #-} +indexGen (Array fptr) indices = + unsafePerformIO . mask_ $ + withForeignPtr fptr $ \ptr -> do + afIndices <- traverse toAFIndex indices + withArray afIndices $ \iptr -> + alloca $ \aptr -> do + throwAFError =<< af_index_gen aptr ptr (fromIntegral n) iptr + mapM_ touchIdxFPtr indices + Array <$> (newForeignPtr af_release_array_finalizer =<< peek aptr) + where + n = length indices + touchIdxFPtr (ArrIndex _ (Array p)) = touchForeignPtr p + touchIdxFPtr _ = pure () --- af_err af_index_gen( af_array *out, const af_array in, const dim_t ndims, const af_index_t* indices); --- | Calculates 'mean' of 'Array' along user-specified dimension. +-- | Assign values into an 'Array' using generalized 'Index' values -- -- @ --- >>> print $ mean 0 ( vector @Int 10 [1..] ) +-- >>> let a = matrix \@Double (3,3) [[1..],[1..],[1..]] +-- >>> let b = matrix \@Double (2,2) [[0,0],[0,0]] +-- >>> assignGen a [seqIdx (Seq 0 1 1) False, seqIdx (Seq 0 1 1) False] b -- @ +assignGen + :: Array a + -- ^ Destination array + -> [Index] + -- ^ List of 'Index' values defining the range to assign into + -> Array a + -- ^ Source array + -> Array a + -- ^ Result with values written at the specified indices +{-# NOINLINE assignGen #-} +assignGen (Array fptr) indices (Array rhsFptr) = + unsafePerformIO . mask_ $ + withForeignPtr fptr $ \ptr -> + withForeignPtr rhsFptr $ \rhsPtr -> do + afIndices <- traverse toAFIndex indices + withArray afIndices $ \iptr -> + alloca $ \aptr -> do + throwAFError =<< af_assign_gen aptr ptr (fromIntegral n) iptr rhsPtr + mapM_ touchIdxFPtr indices + Array <$> (newForeignPtr af_release_array_finalizer =<< peek aptr) + where + n = length indices + touchIdxFPtr (ArrIndex _ (Array p)) = touchForeignPtr p + touchIdxFPtr _ = pure () + +-- | A special 'Seq' value representing the entire axis of an 'Array'. +-- Hard-coded from include\/af\/seq.h because FFI cannot import static const values. +afSpan :: Seq +afSpan = Seq 1 1 0 + +-- | Select the full extent of a dimension. Use in tuple indices where you want all elements along an axis. +-- -- @ --- ArrayFire Array --- [1 1 1 1] --- 5.5000 +-- arr ! (range 0 2, full, at 1) -- @ --- indexGen :: Array a -> Int -> [Index a] -> Array a -> Array a --- indexGen = error "Not implemneted" +full :: Index +full = SeqIndex False afSpan + +-- | Convert index expressions to a list of 'Index'. +-- Supports a single 'Index' or tuples of up to four 'Index' values +-- (matching ArrayFire's maximum of 4 dimensions). +class ToIndexList a where + toIndexList :: a -> [Index] + +instance ToIndexList Index where + toIndexList x = [x] --- af_err af_assingn_gen( af_array *out, const af_array lhs, const dim_t ndims, const af_index_t* indices, const af_array rhs); --- | Calculates 'mean' of 'Array' along user-specified dimension. +instance ToIndexList (Index, Index) where + toIndexList (a, b) = [a, b] + +instance ToIndexList (Index, Index, Index) where + toIndexList (a, b, c) = [a, b, c] + +instance ToIndexList (Index, Index, Index, Index) where + toIndexList (a, b, c, d) = [a, b, c, d] + +-- | Lift a 'Seq' to an 'Index' for use in tuple-based indexing. +idx :: Seq -> Index +idx s = SeqIndex False s + +-- | Index an 'Array'. Accepts a single 'Index' or a tuple of up to four. -- -- @ --- >>> print $ mean 0 ( vector @Int 10 [1..] ) +-- arr ! at 0 -- 1D: element 0 +-- arr ! range 1 3 -- 1D: rows 1-3 +-- arr ! (range 0 2, at 1) -- 2D +-- arr ! (range 0 2, full, at 1) -- 3D, full second axis -- @ +(!) :: ToIndexList ix => Array a -> ix -> Array a +a ! ix = indexGen a (toIndexList ix) +infixl 9 ! + +-- | Assign into a range of an 'Array'. Lens-style: use with '(&)'. +-- -- @ --- ArrayFire Array --- [1 1 1 1] --- 5.5000 +-- arr & range 1 3 .~ src +-- arr & (range 0 1, at 2) .~ src -- @ --- assignGen :: Array a -> Int -> [Index a] -> Array a -> Array a --- assignGen = error "Not implemneted" - --- af_err af_create_indexers(af_index_t** indexers); --- af_err af_set_array_indexer(af_index_t* indexer, const af_array idx, const dim_t dim); --- af_err af_set_seq_indexer(af_index_t* indexer, const af_seq* idx, const dim_t dim, const bool is_batch); --- af_err af_set_seq_param_indexer(af_index_t* indexer, const double begin, const double end, const double step, const dim_t dim, const bool is_batch); --- af_err af_release_indexers(af_index_t* indexers); +(.~) :: ToIndexList ix => ix -> Array a -> Array a -> Array a +(ix .~ rhs) arr = assignGen arr (toIndexList ix) rhs +infixr 4 .~ diff --git a/src/ArrayFire/Internal/Algorithm.hsc b/src/ArrayFire/Internal/Algorithm.hsc index c683a0d..7c20814 100644 --- a/src/ArrayFire/Internal/Algorithm.hsc +++ b/src/ArrayFire/Internal/Algorithm.hsc @@ -75,3 +75,21 @@ foreign import ccall unsafe "af_set_union" af_set_union :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr foreign import ccall unsafe "af_set_intersect" af_set_intersect :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr +foreign import ccall unsafe "af_sum_by_key" + af_sum_by_key :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CInt -> IO AFErr +foreign import ccall unsafe "af_sum_by_key_nan" + af_sum_by_key_nan :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CInt -> Double -> IO AFErr +foreign import ccall unsafe "af_product_by_key" + af_product_by_key :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CInt -> IO AFErr +foreign import ccall unsafe "af_product_by_key_nan" + af_product_by_key_nan :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CInt -> Double -> IO AFErr +foreign import ccall unsafe "af_min_by_key" + af_min_by_key :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CInt -> IO AFErr +foreign import ccall unsafe "af_max_by_key" + af_max_by_key :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CInt -> IO AFErr +foreign import ccall unsafe "af_all_true_by_key" + af_all_true_by_key :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CInt -> IO AFErr +foreign import ccall unsafe "af_any_true_by_key" + af_any_true_by_key :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CInt -> IO AFErr +foreign import ccall unsafe "af_count_by_key" + af_count_by_key :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CInt -> IO AFErr diff --git a/src/ArrayFire/Internal/BLAS.hsc b/src/ArrayFire/Internal/BLAS.hsc index b3b1788..f75beb2 100644 --- a/src/ArrayFire/Internal/BLAS.hsc +++ b/src/ArrayFire/Internal/BLAS.hsc @@ -17,3 +17,5 @@ foreign import ccall unsafe "af_transpose" af_transpose :: Ptr AFArray -> AFArray -> CBool -> IO AFErr foreign import ccall unsafe "af_transpose_inplace" af_transpose_inplace :: AFArray -> CBool -> IO AFErr +foreign import ccall unsafe "af_gemm" + af_gemm :: Ptr AFArray -> AFMatProp -> AFMatProp -> Ptr () -> AFArray -> AFArray -> Ptr () -> IO AFErr diff --git a/src/ArrayFire/Internal/Defines.hsc b/src/ArrayFire/Internal/Defines.hsc index 9de5f06..2cbdd5e 100644 --- a/src/ArrayFire/Internal/Defines.hsc +++ b/src/ArrayFire/Internal/Defines.hsc @@ -253,7 +253,7 @@ newtype AFBackend = AFBackend CInt #{enum AFBackend, AFBackend , afBackendDefault = AF_BACKEND_DEFAULT - , afBackendCpu = AF_BACKEND_DEFAULT + , afBackendCpu = AF_BACKEND_CPU , afBackendCuda = AF_BACKEND_CUDA , afBackendOpencl = AF_BACKEND_OPENCL } @@ -381,14 +381,14 @@ newtype AFInverseDeconvAlgo = AFInverseDeconvAlgo CInt afInverseDeconvDefault = AF_INVERSE_DECONV_DEFAULT } --- newtype AFVarBias = AFVarBias Int --- deriving (Ord, Show, Eq) +newtype AFVarBias = AFVarBias CInt + deriving (Ord, Show, Eq, Storable) --- #{enum AFVarBias, AFVarBias --- , afVarianceDefault = AF_VARIANCE_DEFAULT --- , afVarianceSample = AF_VARIANCE_SAMPLE --- , afVariancePopulation = AF_VARIANCE_POPULATION --- } +#{enum AFVarBias, AFVarBias + , afVarianceDefault = AF_VARIANCE_DEFAULT + , afVarianceSample = AF_VARIANCE_SAMPLE + , afVariancePopulation = AF_VARIANCE_POPULATION + } newtype DimT = DimT CLLong deriving (Show, Eq, Storable, Num, Integral, Real, Enum, Ord) diff --git a/src/ArrayFire/Internal/LAPACK.hsc b/src/ArrayFire/Internal/LAPACK.hsc index 52ca518..2b0797c 100644 --- a/src/ArrayFire/Internal/LAPACK.hsc +++ b/src/ArrayFire/Internal/LAPACK.hsc @@ -29,6 +29,8 @@ foreign import ccall unsafe "af_solve_lu" af_solve_lu :: Ptr AFArray -> AFArray -> AFArray -> AFArray -> AFMatProp -> IO AFErr foreign import ccall unsafe "af_inverse" af_inverse :: Ptr AFArray -> AFArray -> AFMatProp -> IO AFErr +foreign import ccall unsafe "af_pinverse" + af_pinverse :: Ptr AFArray -> AFArray -> Double -> AFMatProp -> IO AFErr foreign import ccall unsafe "af_rank" af_rank :: Ptr CUInt -> AFArray -> Double -> IO AFErr foreign import ccall unsafe "af_det" @@ -37,3 +39,5 @@ foreign import ccall unsafe "af_norm" af_norm :: Ptr Double -> AFArray -> AFNormType -> Double -> Double -> IO AFErr foreign import ccall unsafe "af_is_lapack_available" af_is_lapack_available :: Ptr CBool -> IO AFErr +foreign import ccall unsafe "af_eigsh" + af_eigsh :: Ptr AFArray -> Ptr AFArray -> AFArray -> IO AFErr diff --git a/src/ArrayFire/Internal/Statistics.hsc b/src/ArrayFire/Internal/Statistics.hsc index 744e7b1..1decabc 100644 --- a/src/ArrayFire/Internal/Statistics.hsc +++ b/src/ArrayFire/Internal/Statistics.hsc @@ -36,3 +36,5 @@ foreign import ccall unsafe "af_corrcoef" af_corrcoef :: Ptr Double -> Ptr Double -> AFArray -> AFArray -> IO AFErr foreign import ccall unsafe "af_topk" af_topk :: Ptr AFArray -> Ptr AFArray -> AFArray -> CInt -> CInt -> AFTopkFunction -> IO AFErr +foreign import ccall unsafe "af_meanvar" + af_meanvar :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> AFVarBias -> DimT -> IO AFErr diff --git a/src/ArrayFire/Internal/Types.hsc b/src/ArrayFire/Internal/Types.hsc index 3198d79..1f8b58e 100644 --- a/src/ArrayFire/Internal/Types.hsc +++ b/src/ArrayFire/Internal/Types.hsc @@ -1,7 +1,9 @@ -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE ViewPatterns #-} -{-# LANGUAGE RecordWildCards #-} -{-# LANGUAGE CPP #-} +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE CPP #-} module ArrayFire.Internal.Types where #include "af/seq.h" @@ -17,6 +19,7 @@ import Data.Word import Foreign.C.String import Foreign.C.Types import Foreign.ForeignPtr +import Foreign.ForeignPtr.Unsafe (unsafeForeignPtrToPtr) import Foreign.Storable import GHC.Int @@ -55,8 +58,8 @@ instance Storable AFIndex where afIsBatch <- #{peek af_index_t, isBatch} ptr afIdx <- if afIsSeq - then Left <$> #{peek af_index_t, idx.arr} ptr - else Right <$> #{peek af_index_t, idx.seq} ptr + then Right <$> #{peek af_index_t, idx.seq} ptr + else Left <$> #{peek af_index_t, idx.arr} ptr pure AFIndex{..} poke ptr AFIndex{..} = do case afIdx of @@ -163,12 +166,81 @@ instance AFType Word64 where instance AFType Word where afType Proxy = u64 +-- | Maps an ArrayFire element type to the scalar type returned by whole-array +-- reductions (e.g. 'meanAll', 'det'). Real and integral element types yield +-- 'Double'; complex element types yield 'Complex Double'. +class AFType a => AFResult a where + type Scalar a + -- | Convert the raw @(real, imag)@ pair returned by the C API to the + -- appropriate Haskell scalar. + toAFResult :: (Double, Double) -> Scalar a + +instance AFResult Double where + type Scalar Double = Double + toAFResult (r, _) = r + +instance AFResult Float where + type Scalar Float = Double + toAFResult (r, _) = r + +instance AFResult (Complex Double) where + type Scalar (Complex Double) = Complex Double + toAFResult (r, i) = r :+ i + +instance AFResult (Complex Float) where + type Scalar (Complex Float) = Complex Double + toAFResult (r, i) = r :+ i + +instance AFResult CBool where + type Scalar CBool = Double + toAFResult (r, _) = r + +instance AFResult Int32 where + type Scalar Int32 = Double + toAFResult (r, _) = r + +instance AFResult Word32 where + type Scalar Word32 = Double + toAFResult (r, _) = r + +instance AFResult Word8 where + type Scalar Word8 = Double + toAFResult (r, _) = r + +instance AFResult Int64 where + type Scalar Int64 = Double + toAFResult (r, _) = r + +instance AFResult Int where + type Scalar Int = Double + toAFResult (r, _) = r + +instance AFResult Int16 where + type Scalar Int16 = Double + toAFResult (r, _) = r + +instance AFResult Word16 where + type Scalar Word16 = Double + toAFResult (r, _) = r + +instance AFResult Word64 where + type Scalar Word64 = Double + toAFResult (r, _) = r + +instance AFResult Word where + type Scalar Word = Double + toAFResult (r, _) = r + -- | ArrayFire backends data Backend = Default + -- ^ Use the default backend (determined by ArrayFire) | CPU + -- ^ CPU backend (always available) | CUDA + -- ^ NVIDIA CUDA GPU backend | OpenCL + -- ^ OpenCL backend (AMD, Intel, NVIDIA) deriving (Show, Eq, Ord) -- | Low-level to high-level Backend conversion @@ -200,17 +272,29 @@ toBackends _ = [] -- | Matrix properties data MatProp = None + -- ^ No property | Trans + -- ^ Data needs to be transposed | CTrans + -- ^ Data needs to be conjugate transposed | Conj + -- ^ Data needs to be conjugated | Upper + -- ^ Matrix is upper triangular | Lower + -- ^ Matrix is lower triangular | DiagUnit + -- ^ Diagonal contains units; used with triangular solvers | Sym + -- ^ Matrix is symmetric | PosDef + -- ^ Matrix is positive definite | Orthog + -- ^ Matrix is orthogonal | TriDiag + -- ^ Matrix is tri-diagonal | BlockDiag + -- ^ Matrix is block diagonal deriving (Show, Eq, Ord) -- | Low-level to High-level 'MatProp' conversion @@ -248,12 +332,16 @@ toMatProp Orthog = (AFMatProp 2048) toMatProp TriDiag = (AFMatProp 4096) toMatProp BlockDiag = (AFMatProp 8192) --- | Binary operation support +-- | Binary operation support (used with scan-by-key and similar operations) data BinaryOp = Add + -- ^ Addition | Mul + -- ^ Multiplication | Min + -- ^ Minimum | Max + -- ^ Maximum deriving (Show, Eq, Ord) -- | High-level to low-level 'MatProp' conversion @@ -274,9 +362,13 @@ fromBinaryOp x = error ("Invalid Binary Op: " <> show x) -- | Storage type used for Sparse arrays data Storage = Dense + -- ^ Dense storage (not sparse) | CSR + -- ^ Compressed Sparse Row format | CSC + -- ^ Compressed Sparse Column format | COO + -- ^ Coordinate list (COO) format deriving (Show, Eq, Ord, Enum) toStorage :: Storage -> AFStorage @@ -309,15 +401,25 @@ fromRandomEngine Mersenne = (AFRandomEngineType 300) -- | Interpolation type data InterpType = Nearest + -- ^ Nearest-neighbor interpolation | Linear + -- ^ Linear interpolation | Bilinear + -- ^ Bilinear interpolation | Cubic + -- ^ Cubic interpolation | LowerInterp + -- ^ Floor interpolation (rounds down to nearest integer) | LinearCosine + -- ^ Cosine-windowed linear interpolation | BilinearCosine + -- ^ Cosine-windowed bilinear interpolation | Bicubic + -- ^ Bicubic interpolation | CubicSpline + -- ^ Cubic spline interpolation | BicubicSpline + -- ^ Bicubic spline interpolation deriving (Show, Eq, Ord, Enum) toInterpType :: AFInterpType -> InterpType @@ -346,7 +448,7 @@ data Connectivity toConnectivity :: AFConnectivity -> Connectivity toConnectivity (AFConnectivity 4) = Conn4 -toConnectivity (AFConnectivity 8) = Conn4 +toConnectivity (AFConnectivity 8) = Conn8 toConnectivity (AFConnectivity x) = error ("Unknown connectivity option: " <> show x) fromConnectivity :: Connectivity -> AFConnectivity @@ -356,9 +458,13 @@ fromConnectivity Conn8 = AFConnectivity 8 -- | Color Space type data CSpace = Gray + -- ^ Grayscale | RGB + -- ^ Red-Green-Blue | HSV + -- ^ Hue-Saturation-Value | YCBCR + -- ^ Luminance + chroma (blue-difference, red-difference) deriving (Show, Eq, Ord, Enum) toCSpace :: AFCSpace -> CSpace @@ -367,11 +473,14 @@ toCSpace (AFCSpace (fromIntegral -> x)) = toEnum x fromCSpace :: CSpace -> AFCSpace fromCSpace = AFCSpace . fromIntegral . fromEnum --- | YccStd type +-- | YCbCr standard data YccStd = Ycc601 + -- ^ ITU-R BT.601 (standard definition) | Ycc709 + -- ^ ITU-R BT.709 (high definition) | Ycc2020 + -- ^ ITU-R BT.2020 (ultra high definition) deriving (Show, Eq, Ord) toAFYccStd :: AFYccStd -> YccStd @@ -385,13 +494,18 @@ fromAFYccStd Ycc601 = afYcc601 fromAFYccStd Ycc709 = afYcc709 fromAFYccStd Ycc2020 = afYcc2020 --- | Moment types +-- | Image moment types data MomentType = M00 + -- ^ Zeroth-order moment (image area / mass) | M01 + -- ^ First-order moment about x-axis | M10 + -- ^ First-order moment about y-axis | M11 + -- ^ Mixed first-order moment | FirstOrder + -- ^ All first-order moments (M00, M01, M10, M11) deriving (Show, Eq, Ord) toMomentType :: AFMomentType -> MomentType @@ -410,10 +524,12 @@ fromMomentType M10 = afMomentM10 fromMomentType M11 = afMomentM11 fromMomentType FirstOrder = afMomentFirstOrder --- | Canny Theshold type +-- | Threshold mode for Canny edge detection data CannyThreshold = Manual + -- ^ User-supplied low and high threshold values | AutoOtsu + -- ^ Thresholds computed automatically via Otsu's method deriving (Show, Eq, Ord, Enum) toCannyThreshold :: AFCannyThreshold -> CannyThreshold @@ -422,11 +538,14 @@ toCannyThreshold (AFCannyThreshold (fromIntegral -> x)) = toEnum x fromCannyThreshold :: CannyThreshold -> AFCannyThreshold fromCannyThreshold = AFCannyThreshold . fromIntegral . fromEnum --- | Flux function type +-- | Flux function for anisotropic diffusion data FluxFunction = FluxDefault + -- ^ Default flux function (same as 'FluxQuadratic') | FluxQuadratic + -- ^ Quadratic flux function (Perona-Malik) | FluxExponential + -- ^ Exponential flux function (Perona-Malik) deriving (Show, Eq, Ord, Enum) toFluxFunction :: AFFluxFunction -> FluxFunction @@ -435,11 +554,14 @@ toFluxFunction (AFFluxFunction (fromIntegral -> x)) = toEnum x fromFluxFunction :: FluxFunction -> AFFluxFunction fromFluxFunction = AFFluxFunction . fromIntegral . fromEnum --- | Diffusion type +-- | Diffusion equation type for anisotropic smoothing data DiffusionEq = DiffusionDefault + -- ^ Default (same as 'DiffusionGrad') | DiffusionGrad + -- ^ Gradient-based diffusion (Perona-Malik) | DiffusionMCDE + -- ^ Mean curvature diffusion equation deriving (Show, Eq, Ord, Enum) toDiffusionEq :: AFDiffusionEq -> DiffusionEq @@ -448,11 +570,14 @@ toDiffusionEq (AFDiffusionEq (fromIntegral -> x)) = toEnum x fromDiffusionEq :: DiffusionEq -> AFDiffusionEq fromDiffusionEq = AFDiffusionEq . fromIntegral . fromEnum --- | Iterative deconvolution algo type +-- | Iterative deconvolution algorithm data IterativeDeconvAlgo = DeconvDefault + -- ^ Default algorithm (same as 'DeconvLandweber') | DeconvLandweber + -- ^ Landweber iteration (gradient descent on least squares) | DeconvRichardsonLucy + -- ^ Richardson-Lucy algorithm (maximum likelihood for Poisson noise) deriving (Show, Eq, Ord, Enum) toIterativeDeconvAlgo :: AFIterativeDeconvAlgo -> IterativeDeconvAlgo @@ -461,10 +586,12 @@ toIterativeDeconvAlgo (AFIterativeDeconvAlgo (fromIntegral -> x)) = toEnum x fromIterativeDeconvAlgo :: IterativeDeconvAlgo -> AFIterativeDeconvAlgo fromIterativeDeconvAlgo = AFIterativeDeconvAlgo . fromIntegral . fromEnum --- | Inverse deconvolution algo type +-- | Inverse (non-iterative) deconvolution algorithm data InverseDeconvAlgo = InverseDeconvDefault + -- ^ Default algorithm (same as 'InverseDeconvTikhonov') | InverseDeconvTikhonov + -- ^ Tikhonov regularized Wiener filter deriving (Show, Eq, Ord, Enum) toInverseDeconvAlgo :: AFInverseDeconvAlgo -> InverseDeconvAlgo @@ -473,13 +600,17 @@ toInverseDeconvAlgo (AFInverseDeconvAlgo (fromIntegral -> x)) = toEnum x fromInverseDeconvAlgo :: InverseDeconvAlgo -> AFInverseDeconvAlgo fromInverseDeconvAlgo = AFInverseDeconvAlgo . fromIntegral . fromEnum --- | Cell type, used in Graphics module +-- | Cell type, used in Graphics module to describe a subplot position data Cell = Cell { cellRow :: Int + -- ^ Row index of the subplot (0-based) , cellCol :: Int + -- ^ Column index of the subplot (0-based) , cellTitle :: String + -- ^ Title string displayed above the plot , cellColorMap :: ColorMap + -- ^ Color map used for rendering } deriving (Show, Eq) cellToAFCell :: Cell -> IO AFCell @@ -491,19 +622,30 @@ cellToAFCell Cell {..} = , afCellColorMap = fromColorMap cellColorMap } --- | ColorMap type +-- | Color map for rendering data ColorMap = ColorMapDefault + -- ^ Default grayscale color map | ColorMapSpectrum + -- ^ Rainbow spectrum (violet to red) | ColorMapColors + -- ^ Distinct colors | ColorMapRed + -- ^ Red gradient | ColorMapMood + -- ^ Mood color map (cool tones) | ColorMapHeat + -- ^ Heat map (black to red to yellow to white) | ColorMapBlue + -- ^ Blue gradient | ColorMapInferno + -- ^ Perceptually uniform: black-purple-orange-yellow | ColorMapMagma + -- ^ Perceptually uniform: black-purple-pink-white | ColorMapPlasma + -- ^ Perceptually uniform: blue-purple-yellow | ColorMapViridis + -- ^ Perceptually uniform: purple-teal-yellow deriving (Show, Eq, Ord, Enum) fromColorMap :: ColorMap -> AFColorMap @@ -512,16 +654,24 @@ fromColorMap = AFColorMap . fromIntegral . fromEnum toColorMap :: AFColorMap -> ColorMap toColorMap (AFColorMap (fromIntegral -> x)) = toEnum x --- | Marker type +-- | Marker shape for scatter plots data MarkerType = MarkerTypeNone + -- ^ No marker | MarkerTypePoint + -- ^ Single pixel point | MarkerTypeCircle + -- ^ Circle | MarkerTypeSquare + -- ^ Square | MarkerTypeTriangle + -- ^ Triangle | MarkerTypeCross + -- ^ X cross | MarkerTypePlus + -- ^ Plus sign | MarkerTypeStar + -- ^ Star deriving (Show, Eq, Ord, Enum) fromMarkerType :: MarkerType -> AFMarkerType @@ -530,17 +680,26 @@ fromMarkerType = AFMarkerType . fromIntegral . fromEnum toMarkerType :: AFMarkerType -> MarkerType toMarkerType (AFMarkerType (fromIntegral -> x)) = toEnum x --- | Match type +-- | Template matching metric type data MatchType = MatchTypeSAD + -- ^ Sum of Absolute Differences | MatchTypeZSAD + -- ^ Zero-mean Sum of Absolute Differences | MatchTypeLSAD + -- ^ Locally scaled Sum of Absolute Differences | MatchTypeSSD + -- ^ Sum of Squared Differences | MatchTypeZSSD + -- ^ Zero-mean Sum of Squared Differences | MatchTypeLSSD + -- ^ Locally scaled Sum of Squared Differences | MatchTypeNCC + -- ^ Normalized Cross Correlation | MatchTypeZNCC + -- ^ Zero-mean Normalized Cross Correlation | MatchTypeSHD + -- ^ Sum of Hamming Distances deriving (Show, Eq, Ord, Enum) fromMatchType :: MatchType -> AFMatchType @@ -549,11 +708,14 @@ fromMatchType = AFMatchType . fromIntegral . fromEnum toMatchType :: AFMatchType -> MatchType toMatchType (AFMatchType (fromIntegral -> x)) = toEnum x --- | TopK type +-- | Order for @topk@ results data TopK = TopKDefault + -- ^ Default order (same as 'TopKMax') | TopKMin + -- ^ Return the k smallest values | TopKMax + -- ^ Return the k largest values deriving (Show, Eq, Ord, Enum) fromTopK :: TopK -> AFTopkFunction @@ -562,10 +724,25 @@ fromTopK = AFTopkFunction . fromIntegral . fromEnum toTopK :: AFTopkFunction -> TopK toTopK (AFTopkFunction (fromIntegral -> x)) = toEnum x --- | Homography Type +-- | Variance bias correction method +data VarBias + = VarianceDefault + -- ^ Default (same as 'VariancePopulation') + | VarianceSample + -- ^ Sample variance (divides by N-1; Bessel's correction) + | VariancePopulation + -- ^ Population variance (divides by N) + deriving (Show, Eq, Ord, Enum) + +fromVarBias :: VarBias -> AFVarBias +fromVarBias = AFVarBias . fromIntegral . fromEnum + +-- | Homography estimation method data HomographyType = RANSAC + -- ^ Random Sample Consensus — robust to outliers | LMEDS + -- ^ Least Median of Squares — robust to up to 50% outliers deriving (Show, Eq, Ord, Enum) fromHomographyType :: HomographyType -> AFHomographyType @@ -586,26 +763,33 @@ toAFSeq :: Seq -> AFSeq toAFSeq (Seq x y z) = (AFSeq x y z) -- | Index Type -data Index a - = Index - { idx :: Either (Array a) Seq - , isSeq :: !Bool - , isBatch :: !Bool - } +data Index + = SeqIndex Bool Seq + | ArrIndex Bool (Array Int) + +seqIdx :: Seq -> Bool -> Index +seqIdx s batch = SeqIndex batch s + +arrIdx :: Array Int -> Bool -> Index +arrIdx a batch = ArrIndex batch a + +-- | Index a contiguous range [begin..end] with step 1. +range :: Int -> Int -> Index +range b e = SeqIndex False (Seq (fromIntegral b) (fromIntegral e) 1) -seqIdx :: Seq -> Bool -> Index a -seqIdx s = Index (Right s) True +-- | Index a range [begin..end] with an explicit step. +rangeStep :: Int -> Int -> Int -> Index +rangeStep b e s = SeqIndex False (Seq (fromIntegral b) (fromIntegral e) (fromIntegral s)) -arrIdx :: Array a -> Bool -> Index a -arrIdx a = Index (Left a) False +-- | Index a single element. +at :: Int -> Index +at n = let d = fromIntegral n in SeqIndex False (Seq d d 1) -toAFIndex :: Index a -> IO AFIndex -toAFIndex (Index a b c) = do - case a of - Right s -> pure $ AFIndex (Right (toAFSeq s)) b c - Left (Array fptr) -> do - withForeignPtr fptr $ \ptr -> - pure $ AFIndex (Left ptr) b c +toAFIndex :: Index -> IO AFIndex +toAFIndex (SeqIndex batch s) = + pure $ AFIndex (Right (toAFSeq s)) True batch +toAFIndex (ArrIndex batch (Array fptr)) = + pure $ AFIndex (Left (unsafeForeignPtrToPtr fptr)) False batch -- | Type alias for ArrayFire API version @@ -669,20 +853,32 @@ fromConvMode (AFConvMode (fromIntegral -> x)) = toEnum x toConvMode :: ConvMode -> AFConvMode toConvMode = AFConvMode . fromIntegral . fromEnum --- | Array Fire types +-- | ArrayFire element types (mirrors @af_dtype@) data AFDType = F32 + -- ^ 32-bit IEEE 754 float | C32 + -- ^ Complex number of two 32-bit floats | F64 + -- ^ 64-bit IEEE 754 double | C64 + -- ^ Complex number of two 64-bit doubles | B8 + -- ^ 8-bit boolean | S32 + -- ^ 32-bit signed integer | U32 + -- ^ 32-bit unsigned integer | U8 + -- ^ 8-bit unsigned integer | S64 + -- ^ 64-bit signed integer | U64 + -- ^ 64-bit unsigned integer | S16 + -- ^ 16-bit signed integer | U16 + -- ^ 16-bit unsigned integer deriving (Show, Eq, Enum) fromAFType :: AFDtype -> AFDType diff --git a/src/ArrayFire/LAPACK.hs b/src/ArrayFire/LAPACK.hs index d30e98f..12a2138 100644 --- a/src/ArrayFire/LAPACK.hs +++ b/src/ArrayFire/LAPACK.hs @@ -1,8 +1,10 @@ +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE ViewPatterns #-} -------------------------------------------------------------------------------- -- | -- Module : ArrayFire.LAPACK --- Copyright : David Johnson (c) 2019-2020 +-- Copyright : David Johnson (c) 2019-2026 -- License : BSD 3 -- Maintainer : David Johnson -- Stability : Experimental @@ -220,6 +222,25 @@ inverse inverse a m = a `op1` (\x y -> af_inverse x y (toMatProp m)) +-- | Compute the pseudo-inverse (Moore-Penrose) of a matrix. +-- +-- [ArrayFire Docs](http://arrayfire.org/docs/group__lapack__ops__func__pinv.htm) +-- +-- Uses SVD internally. Any singular value below @tol@ is treated as zero. +-- +pinverse + :: AFType a + => Array a + -- ^ input matrix + -> Double + -- ^ tolerance for treating singular values as zero + -> MatProp + -- ^ matrix properties + -> Array a + -- ^ pseudo-inverse of the input +pinverse a tol m = + a `op1` (\x y -> af_pinverse x y tol (toMatProp m)) + -- | Find the rank of the input matrix -- -- [ArrayFire Docs](http://arrayfire.org/docs/group__lapack__factor__func__rank.htm) @@ -244,12 +265,12 @@ rank a b = -- C Interface for finding the determinant of a matrix. -- det - :: AFType a + :: forall a . AFResult a => Array a - -- ^ is input matrix - -> (Double,Double) - -- ^ will contain the real and imaginary part of the determinant of in -det = (`infoFromArray2` af_det) + -- ^ Input matrix + -> Scalar a + -- ^ Determinant ('Double' for real matrices, 'Complex Double' for complex) +det arr = toAFResult @a (arr `infoFromArray2` af_det) -- | Find the norm of the input matrix. -- @@ -272,6 +293,27 @@ norm norm arr (fromNormType -> a) b c = arr `infoFromArray` (\w y -> af_norm w y a b c) +-- | Eigendecomposition of a real symmetric (or complex Hermitian) matrix. +-- +-- On a CUDA backend calls @cusolverDnDsyevd@ (f64) or @cusolverDnSsyevd@ (f32) +-- directly via dlopen — zero CPU\/GPU transfers, correctly ordered with +-- surrounding ArrayFire operations. On CPU or OpenCL backends (or when +-- cuSOLVER is unavailable) falls back to ArrayFire's own SVD with sign +-- recovery, so the function works on all backends. +-- +-- Returns @(eigenvalues, eigenvectors)@: +-- +-- * @eigenvalues@ — length-n vector in /ascending/ order. +-- * @eigenvectors@ — n×n matrix; column @i@ is the eigenvector for @eigenvalues[i]@. +-- +eigSH + :: AFType a + => Array a + -- ^ real symmetric or complex Hermitian n×n matrix (f32 or f64) + -> (Array a, Array a) + -- ^ (eigenvalues vector, eigenvectors matrix) +eigSH mat = mat `op2p` af_eigsh + -- | Is LAPACK available -- -- [ArrayFire Docs](http://arrayfire.org/docs/group__lapack__helper__func__available.htm) diff --git a/src/ArrayFire/Orphans.hs b/src/ArrayFire/Orphans.hs index 0d9383a..b72271c 100644 --- a/src/ArrayFire/Orphans.hs +++ b/src/ArrayFire/Orphans.hs @@ -6,7 +6,7 @@ -------------------------------------------------------------------------------- -- | -- Module : ArrayFire.Orphans --- Copyright : David Johnson (c) 2019-2020 +-- Copyright : David Johnson (c) 2019-2026 -- License : BSD 3 -- Maintainer : David Johnson -- Stability : Experimental @@ -15,7 +15,10 @@ -------------------------------------------------------------------------------- module ArrayFire.Orphans where -import Prelude +import Prelude hiding (pi) +import qualified Prelude + +import Control.DeepSeq (NFData(..)) import qualified ArrayFire.Arith as A import qualified ArrayFire.Array as A @@ -24,18 +27,43 @@ import qualified ArrayFire.Data as A import ArrayFire.Types import ArrayFire.Util +instance NFData (Array a) where + rnf x = x `seq` () + +-- | Structural equality on 'Array': equal shapes and elementwise-equal values. +-- +-- 'A.allTrueAll' reads back a @(real, imaginary)@ pair; for the boolean +-- reduction produced by 'A.eqBatched' the imaginary component is reliably +-- @0@, so comparing the full tuple against @(1.0, 0.0)@ is safe. '/=' is the +-- negation of '==', which keeps the two operators consistent by construction. instance (AFType a, Eq a) => Eq (Array a) where - x == y = A.allTrueAll (A.eqBatched x y False) == (1.0,0.0) - x /= y = A.allTrueAll (A.neqBatched x y False) == (0.0,0.0) + x == y = A.getDims x == A.getDims y + && A.allTrueAll (A.eqBatched (A.eval x) (A.eval y) False) == 1.0 + + x /= y = A.getDims x /= A.getDims y + || A.anyTrueAll (A.neqBatched (A.eval x) (A.eval y) False) /= 0.0 + +-- | Elementwise 'Num' instance for 'Array'. +-- +-- Note that 'signum' implements the real-valued, three-way sign +-- (@x > 0 -> 1@, @x < 0 -> -1@, otherwise @0@). This matches Haskell's +-- 'signum' for integral and real-floating arrays with finite values, but +-- diverges in a few cases: +-- +-- * @NaN@ (for 'Float'\/'Double') yields @0@, whereas Haskell yields @NaN@. +-- * Negative zero @-0.0@ yields @+0.0@, losing the signed zero that +-- Haskell preserves. +-- * For complex arrays (e.g. @'Array' ('Data.Complex.Complex' Double)@) +-- it returns @1@\/@-1@\/@0@ from an order comparison rather than the unit +-- phasor @z / 'abs' z@ that Haskell's 'signum' produces, so the law +-- @'abs' x * 'signum' x == x@ does not hold for complex inputs. instance (Num a, AFType a) => Num (Array a) where x + y = A.add x y x * y = A.mul x y abs = A.abs - signum x = A.sign (-x) - A.sign x - negate arr = do - let (w,x,y,z) = A.getDims arr - A.cast (A.constant @a [w,x,y,z] 0) `A.sub` arr + signum x = A.select (A.gt x 0) 1 (A.select (A.lt x 0) (-1) 0) + negate arr = A.scalar @a (fromInteger (-1)) `A.mul` arr x - y = A.sub x y fromInteger = A.scalar . fromIntegral @@ -47,7 +75,7 @@ instance forall a . (Fractional a, AFType a) => Fractional (Array a) where fromRational n = A.scalar @a (fromRational n) instance forall a . (Ord a, AFType a, Fractional a) => Floating (Array a) where - pi = A.scalar @a 3.14159 + pi = A.scalar @a (realToFrac (Prelude.pi :: Double)) exp = A.exp @a log = A.log @a sqrt = A.sqrt @a diff --git a/src/ArrayFire/Random.hs b/src/ArrayFire/Random.hs index 0f0c31f..b3b59d5 100644 --- a/src/ArrayFire/Random.hs +++ b/src/ArrayFire/Random.hs @@ -11,7 +11,7 @@ -------------------------------------------------------------------------------- -- | -- Module : ArrayFire.Random --- Copyright : David Johnson (c) 2019-2020 +-- Copyright : David Johnson (c) 2019-2026 -- License : BSD3 -- Maintainer : David Johnson -- Stability : Experimental @@ -178,10 +178,14 @@ getDefaultRandomEngine = alloca $ \ptrInput -> do throwAFError =<< af_get_default_random_engine ptrInput peek ptrInput - fptr <- newForeignPtr af_release_random_engine_finalizer ptr + retained <- + alloca $ \ptrRetained -> do + throwAFError =<< af_retain_random_engine ptrRetained ptr + peek ptrRetained + fptr <- newForeignPtr af_release_random_engine_finalizer retained pure (RandomEngine fptr) --- | Set defualt 'RandomEngine' type +-- | Set default 'RandomEngine' type -- -- @ -- >>> setDefaultRandomEngineType Philox @@ -222,11 +226,17 @@ setSeed = afCall . af_set_seed . fromIntegral getSeed :: IO Int getSeed = fromIntegral <$> afCall1 af_get_seed +-- | Internal helper that runs a random-generation FFI call which draws from a +-- given 'RandomEngine'. Builds an 'Array' of the requested dimensions, passing +-- the dimensions, element type and engine through to the supplied C function. randEng :: forall a . AFType a => [Int] + -- ^ Dimensions of the 'Array' to generate -> (Ptr AFArray -> CUInt -> Ptr DimT -> AFDtype -> AFRandomEngine -> IO AFErr) + -- ^ Underlying ArrayFire random function to invoke -> RandomEngine + -- ^ Engine to draw random numbers from -> IO (Array a) randEng dims f (RandomEngine fptr) = mask_ $ withForeignPtr fptr $ \rptr -> do @@ -242,15 +252,18 @@ randEng dims f (RandomEngine fptr) = mask_ $ n = fromIntegral (length dims) typ = afType (Proxy @a) +-- | Internal helper that runs a random-generation FFI call using the default +-- random engine. Builds an 'Array' of the requested dimensions, passing the +-- dimensions and element type through to the supplied C function. rand :: forall a . AFType a => [Int] -- ^ Dimensions -> (Ptr AFArray -> CUInt -> Ptr DimT -> AFDtype -> IO AFErr) + -- ^ Underlying ArrayFire random function to invoke -> IO (Array a) rand dims f = mask_ $ do - ptr <- alloca $ \ptrPtr -> do - zeroOutArray ptrPtr + ptr <- calloca $ \ptrPtr -> do withArray (fromIntegral <$> dims) $ \dimArray -> do throwAFError =<< f ptrPtr n dimArray typ peek ptrPtr diff --git a/src/ArrayFire/Signal.hs b/src/ArrayFire/Signal.hs index 4ddae65..84aa698 100644 --- a/src/ArrayFire/Signal.hs +++ b/src/ArrayFire/Signal.hs @@ -2,7 +2,7 @@ -------------------------------------------------------------------------------- -- | -- Module : ArrayFire.Signal --- Copyright : David Johnson (c) 2019-2020 +-- Copyright : David Johnson (c) 2019-2026 -- License : BSD 3 -- Maintainer : David Johnson -- Stability : Experimental @@ -93,7 +93,7 @@ approx2 -> Array a -- ^ is the array with interpolated values approx2 arr1 arr2 arr3 (fromInterpType -> i1) f = - op3 arr1 arr2 arr3 (\p x y z -> af_approx2 p x y z i1 f) + op3 arr1 arr3 arr2 (\p x y z -> af_approx2 p x y z i1 f) -- DMJ: Where did these functions go? Were they removed? -- http://arrayfire.org/docs/group__approx__mat.htm diff --git a/src/ArrayFire/Sparse.hs b/src/ArrayFire/Sparse.hs index 1b35026..6d7b922 100644 --- a/src/ArrayFire/Sparse.hs +++ b/src/ArrayFire/Sparse.hs @@ -2,7 +2,7 @@ -------------------------------------------------------------------------------- -- | -- Module : ArrayFire.Sparse --- Copyright : David Johnson (c) 2019-2020 +-- Copyright : David Johnson (c) 2019-2026 -- License : BSD3 -- Maintainer : David Johnson -- Stability : Experimental @@ -149,6 +149,10 @@ createSparseArrayFromDense a s = -- 1 -- 1 -- + +-- | Converts a sparse 'Array' from one storage format ('Storage') to another +-- +-- [ArrayFire Docs](http://arrayfire.org/docs/group__sparse__func__convert__to.htm) sparseConvertTo :: (AFType a, Fractional a) => Array a diff --git a/src/ArrayFire/Statistics.hs b/src/ArrayFire/Statistics.hs index 8a3db79..9a1719c 100644 --- a/src/ArrayFire/Statistics.hs +++ b/src/ArrayFire/Statistics.hs @@ -1,9 +1,11 @@ -{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fno-warn-unused-imports #-} -------------------------------------------------------------------------------- -- | -- Module : ArrayFire.Statistics --- Copyright : David Johnson (c) 2019-2020 +-- Copyright : David Johnson (c) 2019-2026 -- License : BSD3 -- Maintainer : David Johnson -- Stability : Experimental @@ -33,6 +35,9 @@ -------------------------------------------------------------------------------- module ArrayFire.Statistics where +import Data.Word (Word32) +import Foreign.Ptr (nullPtr) + import ArrayFire.Array import ArrayFire.FFI import ArrayFire.Internal.Statistics @@ -40,7 +45,7 @@ import ArrayFire.Internal.Types -- | Calculates 'mean' of 'Array' along user-specified dimension. -- --- >>> mean ( vector @Int 10 [1..] ) 0 +-- >>> mean (vector @Int 10 [1..]) 0 -- ArrayFire Array -- [1 1 1 1] -- 5.5000 @@ -78,15 +83,15 @@ meanWeighted x y (fromIntegral -> n) = -- | Calculates /variance/ of 'Array' along user-specified dimension. -- --- >>> var (vector @Double 8 [1..8]) False 0 +-- >>> var (vector @Double 8 [1..8]) Population 0 -- ArrayFire Array -- [1 1 1 1] --- 6.0000 +-- 5.2500 var :: AFType a => Array a -- ^ Input 'Array' - -> Bool + -> VarianceType -- ^ boolean denoting Population variance (false) or Sample Variance (true) -> Int -- ^ The dimension along which the variance is extracted @@ -96,12 +101,16 @@ var arr (fromIntegral . fromEnum -> b) d = arr `op1` (\p x -> af_var p x b (fromIntegral d)) +-- | Data type used to express variance type in the 'var' function +data VarianceType = Population | Sample + deriving (Show, Eq, Enum) + -- | Calculates 'varWeighted' of 'Array' along user-specified dimension. -- --- >>> varWeighted ( vector @Double 10 [1..] ) ( vector @Double 10 [1..] ) 0 +-- >>> varWeighted (vector @Double 10 [1..]) (vector @Double 10 [1..]) 0 -- ArrayFire Array -- [1 1 1 1] --- 6.0000 +-- 1.9091 varWeighted :: AFType a => Array a @@ -156,7 +165,7 @@ cov x y (fromIntegral . fromEnum -> n) = -- | Calculates 'median' of 'Array' along user-specified dimension. -- --- >>> median ( vector @Double 10 [1..] ) 0 +-- >>> median (vector @Double 10 [1..]) 0 -- ArrayFire Array -- [1 1 1 1] -- 5.5000 @@ -175,100 +184,100 @@ median a n = -- | Calculates 'mean' of all elements in an 'Array' -- -- >>> meanAll $ matrix @Double (2,2) [[1,2],[4,5]] --- (3.0,2.232709401e-314) +-- 3.0 meanAll - :: AFType a + :: forall a . AFResult a => Array a -- ^ Input 'Array' - -> (Double, Double) - -- ^ Mean result (real and imaginary part) -meanAll = (`infoFromArray2` af_mean_all) + -> Scalar a + -- ^ Mean of all elements +meanAll arr = toAFResult @a (arr `infoFromArray2` af_mean_all) -- | Calculates weighted mean of all elements in an 'Array' -- -- >>> meanAllWeighted (matrix @Double (2,2) [[1,2],[3,4]]) (matrix @Double (2,2) [[1,2],[3,4]]) --- (3.0,1.400743288453e-312) +-- 2.8181818181818183 meanAllWeighted - :: AFType a + :: forall a . AFResult a => Array a -- ^ Input 'Array' -> Array a -- ^ 'Array' of weights - -> (Double, Double) - -- ^ Weighted mean (real and imaginary part) + -> Scalar a + -- ^ Weighted mean meanAllWeighted a b = - infoFromArray22 a b af_mean_all_weighted + toAFResult @a (infoFromArray22 a b af_mean_all_weighted) -- | Calculates variance of all elements in an 'Array' -- --- >>> varAll (vector @Double 10 (repeat 10)) False --- (0.0,1.4013073623e-312) +-- >>> varAll (vector @Double 10 (repeat 10)) Population +-- 0.0 varAll - :: AFType a + :: forall a . AFResult a => Array a -- ^ Input 'Array' - -> Bool - -- ^ Input 'Array' - -> (Double, Double) - -- ^ Variance (real and imaginary part) + -> VarianceType + -- ^ 'Population' variance (÷N) or 'Sample' variance (÷N-1) + -> Scalar a + -- ^ Variance of all elements varAll a (fromIntegral . fromEnum -> b) = - infoFromArray2 a $ \x y z -> - af_var_all x y z b + toAFResult @a (infoFromArray2 a $ \x y z -> + af_var_all x y z b) -- | Calculates weighted variance of all elements in an 'Array' -- -- >>> varAllWeighted ( vector @Double 10 [1..] ) ( vector @Double 10 [1..] ) --- (6.0,2.1941097984e-314) +-- 6.011479591836735 varAllWeighted - :: AFType a + :: forall a . AFResult a => Array a -- ^ Input 'Array' -> Array a -- ^ 'Array' of weights - -> (Double, Double) - -- ^ Variance weighted result, (real and imaginary part) + -> Scalar a + -- ^ Weighted variance of all elements varAllWeighted a b = - infoFromArray22 a b af_var_all_weighted + toAFResult @a (infoFromArray22 a b af_var_all_weighted) -- | Calculates standard deviation of all elements in an 'Array' -- -- >>> stdevAll (vector @Double 10 (repeat 10)) --- (0.0,2.190573324e-314) +-- 0.0 stdevAll - :: AFType a + :: forall a . AFResult a => Array a -- ^ Input 'Array' - -> (Double, Double) - -- ^ Standard deviation result, (real and imaginary part) -stdevAll = (`infoFromArray2` af_stdev_all) + -> Scalar a + -- ^ Standard deviation of all elements +stdevAll arr = toAFResult @a (arr `infoFromArray2` af_stdev_all) -- | Calculates median of all elements in an 'Array' -- -- >>> medianAll (vector @Double 10 (repeat 10)) --- (10.0,2.1961564713e-314) +-- 10.0 medianAll - :: (AFType a, Fractional a) + :: forall a . AFResult a => Array a -- ^ Input 'Array' - -> (Double, Double) - -- ^ Median result, real and imaginary part -medianAll = (`infoFromArray2` af_median_all) + -> Scalar a + -- ^ Median of all elements +medianAll arr = toAFResult @a (arr `infoFromArray2` af_median_all) -- | This algorithm returns Pearson product-moment correlation coefficient. -- -- -- >>> corrCoef ( vector @Int 10 [1..] ) ( vector @Int 10 [10,9..] ) --- (-1.0,2.1904819737e-314) +-- -1.0 corrCoef - :: AFType a + :: forall a . AFResult a => Array a -- ^ First input 'Array' -> Array a -- ^ Second input 'Array' - -> (Double, Double) - -- ^ Correlation coefficient result, real and imaginary part + -> Scalar a + -- ^ Correlation coefficient corrCoef a b = - infoFromArray22 a b af_corrcoef + toAFResult @a (infoFromArray22 a b af_corrcoef) -- | This function returns the top k values along a given dimension of the input array. -- @@ -303,8 +312,58 @@ topk -- ^ The number of elements to be retrieved along the dim dimension -> TopK -- ^ If descending, the highest values are returned. Otherwise, the lowest values are returned - -> (Array a, Array a) + -> (Array a, Array Word32) -- ^ Returns The values of the top k elements along the dim dimension -- along with the indices of the top k elements along the dim dimension topk a (fromIntegral -> x) (fromTopK -> f) = a `op2p` (\b c d -> af_topk b c d x 0 f) + +-- | Simultaneously compute the mean and variance of an 'Array' along a dimension. +-- +-- More efficient than calling 'mean' and 'var' separately. +-- +-- >>> let (m, v) = meanVar (vector @Double 4 [1,2,3,4]) VariancePopulation 0 +-- >>> m +-- ArrayFire Array +-- [1 1 1 1] +-- 2.5000 +-- >>> v +-- ArrayFire Array +-- [1 1 1 1] +-- 1.2500 +meanVar + :: AFType a + => Array a + -- ^ Input 'Array' + -> VarBias + -- ^ Variance bias correction: 'VariancePopulation' (÷N) or 'VarianceSample' (÷N-1) + -> Int + -- ^ Dimension along which to compute + -> (Array a, Array a) + -- ^ (mean, variance) +meanVar arr bias (fromIntegral -> dim) = + arr `op2p` (\pMean pVar aPtr -> + af_meanvar pMean pVar aPtr nullPtr (fromVarBias bias) dim) + +-- | Simultaneously compute the weighted mean and variance of an 'Array' along a dimension. +-- +-- >>> let (m, v) = meanVarWeighted (vector @Double 4 [1,2,3,4]) (vector @Double 4 [1,1,1,1]) VariancePopulation 0 +-- >>> m +-- ArrayFire Array +-- [1 1 1 1] +-- 2.5000 +meanVarWeighted + :: AFType a + => Array a + -- ^ Input 'Array' + -> Array a + -- ^ Weights 'Array' + -> VarBias + -- ^ Variance bias correction + -> Int + -- ^ Dimension along which to compute + -> (Array a, Array a) + -- ^ (mean, variance) +meanVarWeighted arr weights bias (fromIntegral -> dim) = + op2p2 arr weights $ \pMean pVar aPtr wPtr -> + af_meanvar pMean pVar aPtr wPtr (fromVarBias bias) dim diff --git a/src/ArrayFire/Types.hs b/src/ArrayFire/Types.hs index e63f6c9..d4e87cb 100644 --- a/src/ArrayFire/Types.hs +++ b/src/ArrayFire/Types.hs @@ -14,7 +14,7 @@ -------------------------------------------------------------------------------- -- | -- Module : ArrayFire.Types --- Copyright : David Johnson (c) 2019-2020 +-- Copyright : David Johnson (c) 2019-2026 -- License : BSD3 -- Maintainer : David Johnson -- Stability : Experimental @@ -31,7 +31,9 @@ module ArrayFire.Types , RandomEngine , Features , AFType (..) + , AFResult (..) , TopK (..) + , VarBias (..) , Backend (..) , MatchType (..) , BinaryOp (..) @@ -52,6 +54,11 @@ module ArrayFire.Types , InverseDeconvAlgo (..) , Seq (..) , Index (..) + , seqIdx + , arrIdx + , range + , rangeStep + , at , NormType (..) , ConvMode (..) , ConvDomain (..) diff --git a/src/ArrayFire/Util.hs b/src/ArrayFire/Util.hs index d8ba69b..de64818 100644 --- a/src/ArrayFire/Util.hs +++ b/src/ArrayFire/Util.hs @@ -4,7 +4,7 @@ -------------------------------------------------------------------------------- -- | -- Module : ArrayFire.Util --- Copyright : David Johnson (c) 2019-2020 +-- Copyright : David Johnson (c) 2019-2026 -- License : BSD 3 -- Maintainer : David Johnson -- Stability : Experimental @@ -258,6 +258,7 @@ arrayToString -- ^ If 'True', performs takes the transpose before rendering to 'String' -> String -- ^ 'Array' rendered to 'String' +{-# NOINLINE arrayToString #-} arrayToString expr (Array fptr) (fromIntegral -> prec) (fromIntegral . fromEnum -> trans) = unsafePerformIO . mask_ . withForeignPtr fptr $ \aptr -> withCString expr $ \expCstr -> @@ -279,6 +280,7 @@ getSizeOf -- ^ Witness of Haskell type that mirrors ArrayFire type. -> Int -- ^ Size of ArrayFire type +{-# NOINLINE getSizeOf #-} getSizeOf proxy = unsafePerformIO . mask_ . alloca $ \csize -> do throwAFError =<< af_get_size_of csize (afType proxy) diff --git a/src/ArrayFire/Vision.hs b/src/ArrayFire/Vision.hs index 71f3bd7..53b7dc0 100644 --- a/src/ArrayFire/Vision.hs +++ b/src/ArrayFire/Vision.hs @@ -4,7 +4,7 @@ -------------------------------------------------------------------------------- -- | -- Module : ArrayFire.Vision --- Copyright : David Johnson (c) 2019-2020 +-- Copyright : David Johnson (c) 2019-2026 -- License : BSD 3 -- Maintainer : David Johnson -- Stability : Experimental @@ -50,6 +50,7 @@ fast -- ^ Is the length of the edges in the image to be discarded by FAST (minimum is 3, as the radius of the circle) -> Features -- ^ Struct containing arrays for x and y coordinates and score, while array orientation is set to 0 as FAST does not compute orientation, and size is set to 1 as FAST does not compute multiple scales +{-# NOINLINE fast #-} fast (Array fptr) thr (fromIntegral -> arc) (fromIntegral . fromEnum -> non) ratio (fromIntegral -> edge) = unsafePerformIO . mask_ . withForeignPtr fptr $ \aptr -> do feat <- alloca $ \ptr -> do @@ -78,6 +79,7 @@ harris -> Float -- ^ struct containing arrays for x and y coordinates and score (Harris response), while arrays orientation and size are set to 0 and 1, respectively, because Harris does not compute that information -> Features +{-# NOINLINE harris #-} harris (Array fptr) (fromIntegral -> maxc) minresp sigma (fromIntegral -> bs) thr = unsafePerformIO . mask_ . withForeignPtr fptr $ \aptr -> do feat <- alloca $ \ptr -> do @@ -107,6 +109,7 @@ orb -- ^ blur image with a Gaussian filter with sigma=2 before computing descriptors to increase robustness against noise if true -> (Features, Array a) -- ^ 'Features' struct composed of arrays for x and y coordinates, score, orientation and size of selected features +{-# NOINLINE orb #-} orb (Array fptr) thr (fromIntegral -> feat) scl (fromIntegral -> levels) (fromIntegral . fromEnum -> blur) = unsafePerformIO . mask_ . withForeignPtr fptr $ \inptr -> do (feature, arr) <- @@ -144,6 +147,7 @@ sift -> (Features, Array a) -- ^ Features object composed of arrays for x and y coordinates, score, orientation and size of selected features -- Nx128 array containing extracted descriptors, where N is the number of features found by SIFT +{-# NOINLINE sift #-} sift (Array fptr) (fromIntegral -> a) b c d (fromIntegral . fromEnum -> e) f g = unsafePerformIO . mask_ . withForeignPtr fptr $ \inptr -> do (feat, arr) <- @@ -181,6 +185,7 @@ gloh -> (Features, Array a) -- ^ 'Features' object composed of arrays for x and y coordinates, score, orientation and size of selected features -- ^ Nx272 array containing extracted GLOH descriptors, where N is the number of features found by SIFT +{-# NOINLINE gloh #-} gloh (Array fptr) (fromIntegral -> a) b c d (fromIntegral . fromEnum -> e) f g = unsafePerformIO . mask_ . withForeignPtr fptr $ \inptr -> do (feat, arr) <- @@ -274,6 +279,7 @@ susan -> Int -- ^ indicates how many pixels width area should be skipped for corner detection -> Features +{-# NOINLINE susan #-} susan (Array fptr) (fromIntegral -> a) b c d (fromIntegral -> e) = unsafePerformIO . mask_ . withForeignPtr fptr $ \inptr -> do feat <- @@ -329,6 +335,7 @@ homography -> (Int, Array a) -- ^ is a 3x3 array containing the estimated homography. -- is the number of inliers that the homography was estimated to comprise, in the case that htype is AF_HOMOGRAPHY_RANSAC, a higher inlier_thr value will increase the estimated inliers. Note that if the number of inliers is too low, it is likely that a bad homography will be returned. +{-# NOINLINE homography #-} homography (Array a) (Array b) diff --git a/test/ArrayFire/AlgorithmSpec.hs b/test/ArrayFire/AlgorithmSpec.hs index 6e5b4d6..2d2c879 100644 --- a/test/ArrayFire/AlgorithmSpec.hs +++ b/test/ArrayFire/AlgorithmSpec.hs @@ -1,9 +1,21 @@ -{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} module ArrayFire.AlgorithmSpec where -import qualified ArrayFire as A - +import qualified ArrayFire as A +import qualified Data.List as L import Test.Hspec +import Test.Hspec.ApproxExpect (closeList) +import Test.Hspec.QuickCheck (prop) +import Test.QuickCheck (NonEmptyList (..), (==>)) + +-- | Reference grouping that mirrors ArrayFire's by-key semantics: each +-- contiguous run of equal keys forms one group. +groupByKeyRef :: Eq k => [k] -> [v] -> [(k, [v])] +groupByKeyRef ks vs = + [ (k, map snd grp) + | grp@((k,_):_) <- L.groupBy (\a b -> fst a == fst b) (zip ks vs) + ] spec :: Spec spec = @@ -79,39 +91,397 @@ spec = A.min (A.vector @A.Word32 10 [1..]) 0 `shouldBe` 1 A.min (A.vector @A.Word64 10 [1..]) 0 `shouldBe` 1 A.min (A.vector @Double 10 [1..]) 0 `shouldBe` 1 - A.min (A.vector @(A.Complex Double) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` A.scalar (1 A.:+ 1) - A.min (A.vector @(A.Complex Float) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` A.scalar (1 A.:+ 1) - A.min (A.vector @A.CBool 10 [1..]) 0 `shouldBe` 1 + A.min (A.vector @(A.Complex Double) 3 [3 A.:+ 4, 1 A.:+ 0, 2 A.:+ 2]) 0 `shouldBe` A.scalar (1 A.:+ 0) + A.min (A.vector @(A.Complex Float) 3 [3 A.:+ 4, 1 A.:+ 0, 2 A.:+ 2]) 0 `shouldBe` A.scalar (1 A.:+ 0) A.min (A.vector @A.CBool 10 [1..]) 0 `shouldBe` 1 + it "Should take the maximum element of a vector" $ do + A.max (A.vector @Int 10 [1..]) 0 `shouldBe` 10 + A.max (A.vector @A.Int64 10 [1..]) 0 `shouldBe` 10 + A.max (A.vector @A.Int32 10 [1..]) 0 `shouldBe` 10 + A.max (A.vector @A.Int16 10 [1..]) 0 `shouldBe` 10 + A.max (A.vector @Float 10 [1..]) 0 `shouldBe` 10 + A.max (A.vector @A.Word32 10 [1..]) 0 `shouldBe` 10 + A.max (A.vector @A.Word64 10 [1..]) 0 `shouldBe` 10 + A.max (A.vector @Double 10 [1..]) 0 `shouldBe` 10 + A.max (A.vector @(A.Complex Double) 3 [3 A.:+ 4, 1 A.:+ 0, 2 A.:+ 2]) 0 `shouldBe` A.scalar (3 A.:+ 4) + A.max (A.vector @(A.Complex Float) 3 [3 A.:+ 4, 1 A.:+ 0, 2 A.:+ 2]) 0 `shouldBe` A.scalar (3 A.:+ 4) + A.max (A.vector @A.CBool 5 [0,1,1,0,1]) 0 `shouldBe` 1 it "Should find if all elements are true along dimension" $ do - A.allTrue (A.vector @Double 5 (repeat 12.0)) 0 `shouldBe` 1 - A.allTrue (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` 1 - A.allTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` 0 - A.allTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` 0 + A.allTrue (A.vector @Double 5 (repeat 12.0)) 0 `shouldBe` A.scalar @A.CBool 1 + A.allTrue (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` A.scalar @A.CBool 1 + A.allTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` A.scalar @A.CBool 0 it "Should find if any elements are true along dimension" $ do - A.anyTrue (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` 1 - A.anyTrue (A.vector @Int 5 (repeat 23)) 0 `shouldBe` 1 - A.anyTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` 0 + A.anyTrue (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` A.scalar @A.CBool 1 + A.anyTrue (A.vector @Int 5 (repeat 23)) 0 `shouldBe` A.scalar @A.CBool 1 + A.anyTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` A.scalar @A.CBool 0 it "Should get count of all elements" $ do A.count (A.vector @Int 5 (repeat 1)) 0 `shouldBe` 5 A.count (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` 5 A.count (A.vector @Double 5 (repeat 1)) 0 `shouldBe` 5 A.count (A.vector @Float 5 (repeat 1)) 0 `shouldBe` 5 it "Should get sum all elements" $ do - A.sumAll (A.vector @Int 5 (repeat 2)) `shouldBe` (10,0) - A.sumAll (A.vector @Double 5 (repeat 2)) `shouldBe` (10.0,0) - A.sumAll (A.vector @A.CBool 3800 (repeat 1)) `shouldBe` (3800,0) - A.sumAll (A.vector @(A.Complex Double) 5 (repeat (2 A.:+ 0))) `shouldBe` (10.0,0) - it "Should get sum all elements" $ do - A.sumNaNAll (A.vector @Double 2 [10, acos 2]) 1 `shouldBe` (11.0,0) - it "Should product all elements in an Array" $ do - A.productAll (A.vector @Int 5 (repeat 2)) `shouldBe` (32,0) + A.sumAll (A.vector @Int 5 (repeat 2)) `shouldBe` 10 + A.sumAll (A.vector @Double 5 (repeat 2)) `shouldBe` 10.0 + A.sumAll (A.vector @A.CBool 3800 (repeat 1)) `shouldBe` 3800 + A.sumAll (A.vector @(A.Complex Double) 3 [1 A.:+ 2, 3 A.:+ 4, 5 A.:+ 6]) `shouldBe` 9.0 A.:+ 12.0 + it "Should sum all elements ignoring NaN" $ do + A.sumNaNAll (A.vector @Double 2 [10, acos 2]) 1 `shouldBe` 11.0 it "Should product all elements in an Array" $ do - A.productNaNAll (A.vector @Double 2 [10,acos 2]) 10 `shouldBe` (100,0) + A.productAll (A.vector @Int 5 (repeat 2)) `shouldBe` 32 + it "Should product all elements ignoring NaN" $ do + A.productNaNAll (A.vector @Double 2 [10,acos 2]) 10 `shouldBe` 100 it "Should find minimum value of an Array" $ do - A.minAll (A.vector @Int 5 [0..]) `shouldBe` (0,0) + A.minAll (A.vector @Int 5 [0..]) `shouldBe` 0 it "Should find maximum value of an Array" $ do - A.maxAll (A.vector @Int 5 [0..]) `shouldBe` (4,0) --- it "Should find if all elements are true" $ do --- A.allTrue (A.vector @A.CBool 5 (repeat 0)) `shouldBe` False + A.maxAll (A.vector @Int 5 [0..]) `shouldBe` 4 + it "Should find if all elements are true" $ do + A.allTrueAll (A.vector @A.CBool 5 (repeat 0)) `shouldBe` 0 + it "Should sum values grouped by key" $ do + let keys = A.vector @Int 5 [1,1,2,2,2] + vals = A.vector @Double 5 [10,20,1,2,3] + (ko, vo) = A.sumByKey keys vals 0 + ko `shouldBe` A.vector @Int 2 [1,2] + vo `shouldBe` A.vector @Double 2 [30,6] + it "Should take the product of values grouped by key" $ do + let keys = A.vector @Int 4 [1,1,2,2] + vals = A.vector @Double 4 [2,3,4,5] + (ko, vo) = A.productByKey keys vals 0 + ko `shouldBe` A.vector @Int 2 [1,2] + vo `shouldBe` A.vector @Double 2 [6,20] + it "Should find the minimum value per key group" $ do + let keys = A.vector @Int 4 [1,1,2,2] + vals = A.vector @Double 4 [3,1,5,2] + (ko, vo) = A.minByKey keys vals 0 + ko `shouldBe` A.vector @Int 2 [1,2] + vo `shouldBe` A.vector @Double 2 [1,2] + it "Should find the maximum value per key group" $ do + let keys = A.vector @Int 4 [1,1,2,2] + vals = A.vector @Double 4 [3,1,5,2] + (ko, vo) = A.maxByKey keys vals 0 + ko `shouldBe` A.vector @Int 2 [1,2] + vo `shouldBe` A.vector @Double 2 [3,5] + it "Should count non-zero values per key group" $ do + let keys = A.vector @Int 4 [1,1,2,2] + vals = A.vector @Double 4 [1,0,1,1] + (ko, vo) = A.countByKey keys vals 0 + ko `shouldBe` A.vector @Int 2 [1,2] + vo `shouldBe` A.vector @A.Word32 2 [1,2] + -- Regression: countByKey output is u32, not the input value dtype. + -- Marshalling to the host (toList) would read garbage if vo were typed + -- as the input value type (Double = 8 bytes vs u32 = 4 bytes). + A.toList vo `shouldBe` [1,2] + it "Should check allTrue per key group" $ do + let keys = A.vector @Int 4 [1,1,2,2] + vals = A.vector @A.CBool 4 [1,1,1,0] + (ko, vo) = A.allTrueByKey keys vals 0 + ko `shouldBe` A.vector @Int 2 [1,2] + vo `shouldBe` A.vector @A.CBool 2 [1,0] + A.toList vo `shouldBe` [1,0] + it "Should check anyTrue per key group" $ do + let keys = A.vector @Int 4 [1,1,2,2] + vals = A.vector @A.CBool 4 [0,0,0,1] + (ko, vo) = A.anyTrueByKey keys vals 0 + ko `shouldBe` A.vector @Int 2 [1,2] + vo `shouldBe` A.vector @A.CBool 2 [0,1] + A.toList vo `shouldBe` [0,1] + it "Should sum values grouped by key, substituting NaN with 0" $ do + let keys = A.vector @Int 4 [1,1,2,2] + vals = A.vector @Double 4 [10, (acos 2), 3, 4] + (ko, vo) = A.sumByKeyNaN keys vals 0 0 + ko `shouldBe` A.vector @Int 2 [1,2] + vo `shouldBe` A.vector @Double 2 [10, 7] + it "Should take the product of values grouped by key, substituting NaN with 1" $ do + let keys = A.vector @Int 4 [1,1,2,2] + vals = A.vector @Double 4 [2, (acos 2), 4, 5] + (ko, vo) = A.productByKeyNaN keys vals 0 1 + ko `shouldBe` A.vector @Int 2 [1,2] + vo `shouldBe` A.vector @Double 2 [2, 20] + + describe "accum" $ do + it "computes inclusive cumulative sum along dim 0" $ do + A.accum (A.vector @Double 5 [1,2,3,4,5]) 0 + `shouldBe` A.vector @Double 5 [1,3,6,10,15] + it "computes cumulative sum along dim 1 of a matrix" $ do + A.accum (A.mkArray @Double [2,3] [1,2,3,4,5,6]) 1 + `shouldBe` A.mkArray @Double [2,3] [1,2,4,6,9,12] + + describe "diff1" $ do + it "computes first differences along dim 0" $ do + A.diff1 (A.vector @Double 5 [1,2,4,7,11]) 0 + `shouldBe` A.vector @Double 4 [1,2,3,4] + it "first differences of a constant vector are zero" $ do + A.diff1 (A.vector @Double 4 (repeat 5)) 0 + `shouldBe` A.vector @Double 3 [0,0,0] + + describe "diff2" $ do + it "computes second differences of a quadratic sequence" $ do + A.diff2 (A.vector @Double 5 [0,1,4,9,16]) 0 + `shouldBe` A.vector @Double 3 [2,2,2] + it "second differences of a linear sequence are zero" $ do + A.diff2 (A.vector @Double 5 [1,2,3,4,5]) 0 + `shouldBe` A.vector @Double 3 [0,0,0] + + describe "where'" $ do + it "returns indices of nonzero elements" $ do + A.where' (A.vector @Double 5 [0,1,0,2,0]) + `shouldBe` A.vector @A.Word32 2 [1,3] + it "returns empty array when all elements are zero" $ do + A.getDims (A.where' (A.vector @Double 3 [0,0,0])) + `shouldBe` (0,1,1,1) + + describe "scan" $ do + it "inclusive scan with Add equals accum" $ do + A.scan (A.vector @Double 5 [1..5]) 0 A.Add True + `shouldBe` A.vector @Double 5 [1,3,6,10,15] + it "exclusive scan with Add shifts the prefix sums by one" $ do + A.scan (A.vector @Double 5 [1..5]) 0 A.Add False + `shouldBe` A.vector @Double 5 [0,1,3,6,10] + it "inclusive scan with Mul gives running product" $ do + A.scan (A.vector @Double 4 [1..4]) 0 A.Mul True + `shouldBe` A.vector @Double 4 [1,2,6,24] + + describe "scanByKey" $ do + it "resets prefix sum at each key boundary" $ do + let keys = A.vector @Int 4 [1,1,2,2] + vals = A.vector @Double 4 [1,2,3,4] + A.scanByKey keys vals 0 A.Add True + `shouldBe` A.vector @Double 4 [1,3,3,7] + + describe "sort" $ do + it "sorts ascending" $ do + A.sort (A.vector @Double 5 [3,1,4,1,5]) 0 A.Asc + `shouldBe` A.vector @Double 5 [1,1,3,4,5] + it "sorts descending" $ do + A.sort (A.vector @Double 5 [3,1,4,1,5]) 0 A.Desc + `shouldBe` A.vector @Double 5 [5,4,3,1,1] + + describe "sortIndex" $ do + it "returns sorted values and original indices" $ do + let (vals, idxs) = A.sortIndex (A.vector @Double 4 [3,2,1,4]) 0 A.Asc + vals `shouldBe` A.vector @Double 4 [1,2,3,4] + idxs `shouldBe` A.vector @A.Word32 4 [2,1,0,3] + + describe "sortByKey" $ do + it "sorts values by key order" $ do + let (ks, vs) = A.sortByKey + (A.vector @Double 4 [2,1,4,3]) + (A.vector @Double 4 [10,9,8,7]) + 0 A.Asc + ks `shouldBe` A.vector @Double 4 [1,2,3,4] + vs `shouldBe` A.vector @Double 4 [9,10,7,8] + + describe "setUnique" $ do + it "removes duplicate elements" $ do + A.setUnique (A.vector @Double 4 [1,1,2,2]) True + `shouldBe` A.vector @Double 2 [1,2] + it "returns a single-element array from an all-same vector" $ do + A.setUnique (A.vector @Double 3 [5,5,5]) True + `shouldBe` A.vector @Double 1 [5] + + describe "setUnion" $ do + it "produces the union of two sorted sets" $ do + A.setUnion (A.vector @Double 3 [3,4,5]) (A.vector @Double 3 [1,2,3]) True + `shouldBe` A.vector @Double 5 [1,2,3,4,5] + + describe "setIntersect" $ do + it "produces the intersection of two sorted sets" $ do + A.setIntersect (A.vector @Double 3 [3,4,5]) (A.vector @Double 3 [1,2,3]) True + `shouldBe` A.vector @Double 1 [3] + it "returns empty array for disjoint sets" $ do + A.getDims (A.setIntersect (A.vector @Double 2 [1,2]) (A.vector @Double 2 [3,4]) True) + `shouldBe` (0,1,1,1) + + -- Regression: infoFromArray3 was missing mask_, risking finalizer interference. + -- iminAll and imaxAll are the primary users. + it "iminAll returns correct value and index" $ do + let arr = A.vector @Double 5 [3, 1, 4, 2, 5] + A.iminAll arr `shouldBe` (1.0, 1) + it "imaxAll returns correct value and index" $ do + let arr = A.vector @Double 5 [3, 1, 4, 1, 5] + A.imaxAll arr `shouldBe` (5.0, 4) + + describe "sort (property)" $ do + -- An ascending sort must return exactly the multiset of inputs in + -- non-decreasing order — i.e. agree element-for-element with Data.List. + prop "ascending sort agrees with Data.List.sort" $ \(xs :: [Double]) -> + not (null xs) ==> + A.toList (A.sort (A.vector (length xs) xs) 0 A.Asc) == L.sort xs + + -- A.Descending sort is the reverse ordering. + prop "descending sort is the reverse ordering" $ \(xs :: [Double]) -> + not (null xs) ==> + A.toList (A.sort (A.vector (length xs) xs) 0 A.Desc) == L.sortBy (flip compare) xs + + describe "by-key reductions (property)" $ do + -- These exercise the op2p2kv marshalling (s32 key cast in, s64 cast out) + -- against a pure contiguous-groupBy reference. Keys are squeezed into a + -- small range so random inputs produce real multi-element runs. + -- Note: ArrayFire's by-key C functions require n >= 2; single-element + -- arrays return ArgError at the C level, so we guard length >= 2. + prop "sumByKey matches a contiguous groupBy reference" $ \(pairs :: [(Int, Double)]) -> + length pairs >= 2 ==> + let n = length pairs + keys = map ((`mod` 8) . abs . fst) pairs + vals = map snd pairs + (ko, vo) = A.sumByKey (A.vector @Int n keys) (A.vector @Double n vals) 0 + groups = groupByKeyRef keys vals + in A.toList ko == map fst groups + && closeList (A.toList vo) (map (sum . snd) groups) + + prop "maxByKey matches per-group maxima" $ \(pairs :: [(Int, Double)]) -> + length pairs >= 2 ==> + let n = length pairs + keys = map ((`mod` 8) . abs . fst) pairs + vals = map snd pairs + (ko, vo) = A.maxByKey (A.vector @Int n keys) (A.vector @Double n vals) 0 + groups = groupByKeyRef keys vals + in A.toList ko == map fst groups + && closeList (A.toList vo) (map (maximum . snd) groups) + + -- countByKey output is u32, not the input dtype. Comparing host values + -- (toList) guards against the result being mistyped as the value dtype. + prop "countByKey matches per-group nonzero counts" $ \(pairs :: [(Int, Double)]) -> + length pairs >= 2 ==> + let n = length pairs + keys = map ((`mod` 8) . abs . fst) pairs + vals = map snd pairs + (ko, vo) = A.countByKey (A.vector @Int n keys) (A.vector @Double n vals) 0 + groups = groupByKeyRef keys vals + in A.toList ko == map fst groups + && A.toList vo + == map (fromIntegral . length . filter (/= 0) . snd) groups + + describe "sort (more properties)" $ do + -- Sort is idempotent: sorting a sorted list gives the same list. + prop "sort is idempotent" $ \(xs :: [Double]) -> + not (null xs) ==> + let sorted = A.sort (A.vector (length xs) xs) 0 A.Asc + in A.toList (A.sort sorted 0 A.Asc) == A.toList sorted + + -- Ascending + descending agree on element multisets (reversed). + prop "desc sort is reverse of asc sort" $ \(xs :: [Double]) -> + not (null xs) ==> + A.toList (A.sort (A.vector (length xs) xs) 0 A.Desc) + == reverse (A.toList (A.sort (A.vector (length xs) xs) 0 A.Asc)) + + describe "accum / scan / diff1 properties" $ do + -- accum along dim 0 = inclusive scan with Add. + prop "accum = scan Add inclusive" $ \(xs :: [Double]) -> + not (null xs) ==> + let arr = A.vector (length xs) xs + in closeList + (A.toList (A.accum arr 0)) + (A.toList (A.scan arr 0 A.Add True)) + + -- diff1 is the left-inverse of accum: diff1 (accum xs) recovers xs[1..]. + -- For a length-n vector, accum produces the prefix sums p[i] = sum xs[0..i]. + -- diff1 gives p[i] - p[i-1] = xs[i] for i>=1, so toList (diff1 (accum xs)) + -- equals tail xs. + prop "diff1 (accum xs) = tail xs" $ \(NonEmpty xs) -> + length xs >= 2 ==> + closeList + (A.toList (A.diff1 (A.accum (A.vector (length xs) xs) 0) 0)) + (tail xs) + + describe "set operation properties" $ do + -- setUnion result contains all elements of each input. + prop "setUnion result contains all elements of A" $ \(xs :: [Double]) -> + not (null xs) ==> + let sorted = L.sort (L.nub xs) + n = length sorted + a = A.vector n sorted + b = A.vector 1 [0] + u = A.toList (A.setUnion a b True) + in all (`elem` u) sorted + + -- setIntersect result contains only elements common to both. + prop "setIntersect result is a subset of each input" $ \(xs :: [Double]) (ys :: [Double]) -> + not (null xs) && not (null ys) ==> + let sortedA = L.sort (L.nub xs) + sortedB = L.sort (L.nub ys) + a = A.vector (length sortedA) sortedA + b = A.vector (length sortedB) sortedB + inter = A.toList (A.setIntersect a b True) + in all (`elem` sortedA) inter && all (`elem` sortedB) inter + + describe "by-key reductions (additional coverage)" $ do + prop "minByKey matches per-group minima" $ \(pairs :: [(Int, Double)]) -> + length pairs >= 2 ==> + let n = length pairs + keys = map ((`mod` 8) . abs . fst) pairs + vals = map snd pairs + (ko, vo) = A.minByKey (A.vector @Int n keys) (A.vector @Double n vals) 0 + groups = groupByKeyRef keys vals + in A.toList ko == map fst groups + && closeList (A.toList vo) (map (minimum . snd) groups) + + prop "allTrueByKey matches per-group allTrue" $ \(pairs :: [(Int, Double)]) -> + length pairs >= 2 ==> + let n = length pairs + keys = map ((`mod` 4) . abs . fst) pairs + vals = map (\v -> if v > 0 then 1 else 0 :: Double) (map snd pairs) + (ko, vo) = A.allTrueByKey + (A.vector @Int n keys) + (A.vector @Double n vals) + 0 + groups = groupByKeyRef keys vals + expected = map (fromIntegral . fromEnum . all (> 0) . snd) groups :: [A.CBool] + in A.toList ko == map fst groups + && A.toList @A.CBool vo == expected + + prop "anyTrueByKey matches per-group anyTrue" $ \(pairs :: [(Int, Double)]) -> + length pairs >= 2 ==> + let n = length pairs + keys = map ((`mod` 4) . abs . fst) pairs + vals = map (\v -> if v > 0 then 1 else 0 :: Double) (map snd pairs) + (ko, vo) = A.anyTrueByKey + (A.vector @Int n keys) + (A.vector @Double n vals) + 0 + groups = groupByKeyRef keys vals + expected = map (fromIntegral . fromEnum . any (> 0) . snd) groups :: [A.CBool] + in A.toList ko == map fst groups + && A.toList @A.CBool vo == expected + + describe "allTrueAll" $ do + it "returns (1,0) when all elements are non-zero" $ + A.allTrueAll (A.vector @A.CBool 5 (repeat 1)) `shouldBe` 1.0 + it "returns (0,0) when any element is zero" $ + A.allTrueAll (A.vector @A.CBool 5 [1,1,0,1,1]) `shouldBe` 0.0 + it "all-zero vector returns (0,0)" $ + A.allTrueAll (A.vector @Double 4 (repeat 0)) `shouldBe` 0.0 + + describe "anyTrueAll" $ do + it "returns (1,0) when at least one element is non-zero" $ + A.anyTrueAll (A.vector @A.CBool 5 [0,0,1,0,0]) `shouldBe` 1.0 + it "returns (0,0) when all elements are zero" $ + A.anyTrueAll (A.vector @A.CBool 5 (repeat 0)) `shouldBe` 0.0 + + describe "countAll" $ do + it "counts non-zero elements across the whole array" $ + A.countAll (A.vector @Double 5 [1,0,1,0,1]) `shouldBe` 3.0 + it "returns 0 for all-zero array" $ + A.countAll (A.vector @Double 3 (repeat 0)) `shouldBe` 0.0 + it "counts all elements in an all-nonzero array" $ + A.countAll (A.vector @Int 4 [1,2,3,4]) `shouldBe` 4.0 + + describe "imin" $ do + it "returns minimum value and index along dim 0" $ do + let (val, idx) = A.imin (A.vector @Double 5 [3,1,4,2,5]) 0 + val `shouldBe` A.scalar @Double 1.0 + idx `shouldBe` A.scalar @A.Word32 1 + it "minimum of sorted ascending vector is the first element" $ do + let (val, idx) = A.imin (A.vector @Int 4 [10,20,30,40]) 0 + val `shouldBe` A.scalar @Int 10 + idx `shouldBe` A.scalar @A.Word32 0 + + describe "imax" $ do + it "returns maximum value and index along dim 0" $ do + let (val, idx) = A.imax (A.vector @Double 5 [3,1,4,2,5]) 0 + val `shouldBe` A.scalar @Double 5.0 + idx `shouldBe` A.scalar @A.Word32 4 + it "maximum of sorted ascending vector is the last element" $ do + let (val, idx) = A.imax (A.vector @Int 4 [10,20,30,40]) 0 + val `shouldBe` A.scalar @Int 40 + idx `shouldBe` A.scalar @A.Word32 3 diff --git a/test/ArrayFire/ArithSpec.hs b/test/ArrayFire/ArithSpec.hs index 623726f..366f37f 100644 --- a/test/ArrayFire/ArithSpec.hs +++ b/test/ArrayFire/ArithSpec.hs @@ -4,8 +4,9 @@ module ArrayFire.ArithSpec where -import ArrayFire (AFType, Array, cast, clamp, getType, isInf, isZero, matrix, maxOf, minOf, mkArray, scalar, vector) +import ArrayFire (AFType, Array, cast, clamp, cplx, cplx2, getType, imag, isInf, isZero, matrix, maxOf, minOf, mkArray, real, scalar, vector) import qualified ArrayFire +import Data.Complex (Complex (..)) import Control.Exception (throwIO) import Control.Monad (unless, when) import Foreign.C @@ -14,6 +15,7 @@ import GHC.Stack import Test.HUnit.Lang (FailureReason (..), HUnitFailure (..)) import Test.Hspec import Test.Hspec.QuickCheck +import Test.QuickCheck ((==>)) import Prelude hiding (div) compareWith :: (HasCallStack, Show a) => (a -> a -> Bool) -> a -> a -> Expectation @@ -39,8 +41,10 @@ instance HasEpsilon Double where approxWith :: (Ord a, Num a) => a -> a -> a -> a -> Bool approxWith rtol atol a b = abs (a - b) <= Prelude.max atol (rtol * Prelude.max (abs a) (abs b)) +-- | Relative + absolute tolerance check at machine-epsilon scale. +-- Tolerance = max(4*eps, 2*eps * max(|a|,|b|)). approx :: (Ord a, HasEpsilon a) => a -> a -> Bool -approx a b = approxWith (2 * eps * Prelude.max (abs a) (abs b)) (4 * eps) a b +approx a b = approxWith (2 * eps) (4 * eps) a b shouldBeApprox :: (Ord a, HasEpsilon a, Show a) => a -> a -> Expectation shouldBeApprox = compareWith approx @@ -92,7 +96,9 @@ spec = matrix @Int (2, 2) [[1, 1], [1, 1]] + matrix @Int (2, 2) [[1, 1], [1, 1]] `shouldBe` matrix @Int (2, 2) [[2, 2], [2, 2]] prop "Should take cubed root" $ \(x :: Double) -> - evalf (ArrayFire.cbrt (scalar (x * x * x))) `shouldBeApprox` x + let x3 = x * x * x + in not (isNaN x3 || isInfinite x3) ==> + evalf (ArrayFire.cbrt (scalar x3)) `shouldBeApprox` x it "Should lte Array" $ do 2 `ArrayFire.le` (3 :: Array Double) `shouldBe` 1 @@ -101,13 +107,13 @@ spec = it "Should gt Array" $ do 2 `ArrayFire.gt` (3 :: Array Double) `shouldBe` 0 it "Should lt Array" $ do - 2 `ArrayFire.le` (3 :: Array Double) `shouldBe` 1 + 2 `ArrayFire.lt` (3 :: Array Double) `shouldBe` 1 it "Should eq Array" $ do 3 == (3 :: Array Double) `shouldBe` True it "Should and Array" $ do (mkArray @CBool [1] [0] `ArrayFire.and` mkArray [1] [1]) `shouldBe` mkArray [1] [0] - it "Should and Array" $ do + it "Should and Array (vector)" $ do (mkArray @CBool [2] [0, 0] `ArrayFire.and` mkArray [2] [1, 0]) `shouldBe` mkArray [2] [0, 0] it "Should or Array" $ do @@ -140,15 +146,15 @@ spec = clamp (scalar @Int 2) (scalar @Int 1) (scalar @Int 3) `shouldBe` 2 it "Should check if an array has positive or negative infinities" $ do - isInf (scalar @Double (1 / 0)) `shouldBe` scalar @Double 1 - isInf (scalar @Double 10) `shouldBe` scalar @Double 0 + isInf (scalar @Double (1 / 0)) `shouldBe` scalar @CBool 1 + isInf (scalar @Double 10) `shouldBe` scalar @CBool 0 it "Should check if an array has any NaN values" $ do - ArrayFire.isNaN (scalar @Double (acos 2)) `shouldBe` scalar @Double 1 - ArrayFire.isNaN (scalar @Double 10) `shouldBe` scalar @Double 0 + ArrayFire.isNaN (scalar @Double (acos 2)) `shouldBe` scalar @CBool 1 + ArrayFire.isNaN (scalar @Double 10) `shouldBe` scalar @CBool 0 it "Should check if an array has any Zero values" $ do - isZero (scalar @Double (acos 2)) `shouldBe` scalar @Double 0 - isZero (scalar @Double 0) `shouldBe` scalar @Double 1 - isZero (scalar @Double 1) `shouldBe` scalar @Double 0 + isZero (scalar @Double (acos 2)) `shouldBe` scalar @CBool 0 + isZero (scalar @Double 0) `shouldBe` scalar @CBool 1 + isZero (scalar @Double 1) `shouldBe` scalar @CBool 0 prop "Floating @Float (exp)" $ \(x :: Float) -> exp `shouldMatchBuiltin` exp $ x prop "Floating @Float (log)" $ \(x :: Float) -> log `shouldMatchBuiltin` log $ x @@ -166,3 +172,361 @@ spec = prop "Floating @Float (asinh)" $ \(x :: Float) -> asinh `shouldMatchBuiltin` asinh $ x prop "Floating @Float (acosh)" $ \(x :: Float) -> acosh `shouldMatchBuiltin` acosh $ x prop "Floating @Float (atanh)" $ \(x :: Float) -> atanh `shouldMatchBuiltin` atanh $ x + + describe "erf" $ do + it "erf 0 = 0" $ + evalf (ArrayFire.erf (scalar @Double 0)) `shouldBeApprox` 0 + it "erf 1 ≈ 0.8427" $ + evalf (ArrayFire.erf (scalar @Double 1)) `shouldBeApprox` 0.8427007929497149 + it "erf is odd: erf(-x) = -erf(x)" $ + evalf (ArrayFire.erf (scalar @Double (-1))) `shouldBeApprox` + negate (evalf (ArrayFire.erf (scalar @Double 1))) + + describe "erfc" $ do + it "erfc 0 = 1" $ + evalf (ArrayFire.erfc (scalar @Double 0)) `shouldBeApprox` 1 + it "erf(x) + erfc(x) = 1" $ do + let x = scalar @Double 1.5 + (evalf (ArrayFire.erf x) + evalf (ArrayFire.erfc x)) `shouldBeApprox` 1 + + describe "sigmoid" $ do + it "sigmoid 0 = 0.5" $ + evalf (ArrayFire.sigmoid (scalar @Double 0)) `shouldBeApprox` 0.5 + it "sigmoid(-x) = 1 - sigmoid(x)" $ do + let x = scalar @Double 2.0 + evalf (ArrayFire.sigmoid (negate x)) + `shouldBeApprox` + (1 - evalf (ArrayFire.sigmoid x)) + + describe "expm1" $ do + it "expm1 0 = 0" $ + evalf (ArrayFire.expm1 (scalar @Double 0)) `shouldBeApprox` 0 + it "expm1 1 = e - 1" $ + evalf (ArrayFire.expm1 (scalar @Double 1)) `shouldBeApprox` (exp 1 - 1) + + describe "clamp (vector)" $ do + it "clamps each element to [lo, hi]" $ + clamp (vector @Int 5 [0,1,5,9,10]) + (scalar @Int 2) + (scalar @Int 8) + `shouldBe` vector @Int 5 [2,2,5,8,8] + + describe "signum" $ do + it "positive Int → 1" $ + signum (scalar @Int 5) `shouldBe` scalar @Int 1 + it "negative Int → -1" $ + signum (scalar @Int (-3)) `shouldBe` scalar @Int (-1) + it "zero Int → 0" $ + signum (scalar @Int 0) `shouldBe` scalar @Int 0 + -- unsigned: old sign(-x) - sign(x) wrapped, making signum always 0 + it "positive Word32 → 1 (unsigned negate wraps)" $ + signum (scalar @ArrayFire.Word32 7) `shouldBe` scalar @ArrayFire.Word32 1 + it "zero Word32 → 0" $ + signum (scalar @ArrayFire.Word32 0) `shouldBe` scalar @ArrayFire.Word32 0 + -- IEEE 754: af_sign checks the sign bit, so sign(-0.0) = 1 → old signum(0.0) = 1 + it "negative zero Double → 0 (IEEE 754 -0.0)" $ + evalf (signum (scalar @Double (-0.0))) `shouldBeApprox` 0 + it "positive Double → 1" $ + evalf (signum (scalar @Double 2.5)) `shouldBeApprox` 1 + it "negative Double → -1" $ + evalf (signum (scalar @Double (-2.5))) `shouldBeApprox` (-1) + it "signum vector" $ + signum (vector @Int 3 [-4, 0, 7]) `shouldBe` vector @Int 3 [-1, 0, 1] + + describe "cplx" $ do + it "lifts a real scalar to complex with zero imaginary part" $ + cplx (scalar @Double 5.0) `shouldBe` scalar @(Complex Double) (5.0 :+ 0.0) + it "real . cplx == id on a vector" $ do + let v = vector @Double 4 [1, 2, 3, 4] + (real (cplx v) :: Array Double) `shouldBe` v + it "imag . cplx == 0 on a vector" $ do + let v = vector @Double 4 [1, 2, 3, 4] + ArrayFire.toList (imag (cplx v) :: Array Double) `shouldBe` [0, 0, 0, 0] + + describe "cplx2" $ do + it "combines real and imaginary parts into a complex scalar" $ + cplx2 (scalar @Double 3.0) (scalar @Double 4.0) + `shouldBe` scalar @(Complex Double) (3.0 :+ 4.0) + it "real . cplx2 r i == r" $ do + let r = vector @Double 3 [1, 2, 3] + i = vector @Double 3 [4, 5, 6] + (real (cplx2 r i) :: Array Double) `shouldBe` r + it "imag . cplx2 r i == i" $ do + let r = vector @Double 3 [1, 2, 3] + i = vector @Double 3 [4, 5, 6] + (imag (cplx2 r i) :: Array Double) `shouldBe` i + + describe "real / imag" $ do + it "real extracts the real part of a complex scalar" $ + (real (scalar @(Complex Double) (7.0 :+ 3.0)) :: Array Double) + `shouldBe` scalar @Double 7.0 + it "imag extracts the imaginary part of a complex scalar" $ + (imag (scalar @(Complex Double) (7.0 :+ 3.0)) :: Array Double) + `shouldBe` scalar @Double 3.0 + it "real and imag round-trip via cplx2" $ do + let c = vector @(Complex Double) 3 [1:+2, 3:+4, 5:+6] + cplx2 (real c :: Array Double) (imag c :: Array Double) `shouldBe` c + + describe "factorial" $ do + it "factorial 0 = 1" $ + evalf (ArrayFire.factorial (scalar @Double 0)) `shouldBeApprox` 1 + it "factorial 5 = 120" $ + evalf (ArrayFire.factorial (scalar @Double 5)) `shouldBeApprox` 120 + it "factorial 10 = 3628800" $ + -- factorial is computed via the platform libm gamma function, which is + -- not bit-exact: on macOS it lands ~2.3e-9 off, exceeding the default + -- relative tolerance (~1.6e-9 at this magnitude). Loosen it here. + approxWith 1e-7 1e-7 (evalf (ArrayFire.factorial (scalar @Double 10))) 3628800 + `shouldBe` True + + describe "floor" $ do + it "floor of 1.7 is 1" $ + evalf (ArrayFire.floor (scalar @Double 1.7)) `shouldBeApprox` 1 + it "floor of -1.2 is -2" $ + evalf (ArrayFire.floor (scalar @Double (-1.2))) `shouldBeApprox` (-2) + it "floor of exact integer is unchanged" $ + evalf (ArrayFire.floor (scalar @Double 3.0)) `shouldBeApprox` 3 + + describe "ceil" $ do + it "ceil of 1.2 is 2" $ + evalf (ArrayFire.ceil (scalar @Double 1.2)) `shouldBeApprox` 2 + it "ceil of -1.7 is -1" $ + evalf (ArrayFire.ceil (scalar @Double (-1.7))) `shouldBeApprox` (-1) + it "ceil of exact integer is unchanged" $ + evalf (ArrayFire.ceil (scalar @Double 4.0)) `shouldBeApprox` 4 + + describe "trunc" $ do + it "trunc of 1.9 is 1" $ + evalf (ArrayFire.trunc (scalar @Double 1.9)) `shouldBeApprox` 1 + it "trunc of -1.9 is -1" $ + evalf (ArrayFire.trunc (scalar @Double (-1.9))) `shouldBeApprox` (-1) + it "trunc of exact integer is unchanged" $ + evalf (ArrayFire.trunc (scalar @Double 5.0)) `shouldBeApprox` 5 + + describe "log10" $ do + it "log10 of 100 is 2" $ + evalf (ArrayFire.log10 (scalar @Double 100)) `shouldBeApprox` 2 + it "log10 of 1 is 0" $ + evalf (ArrayFire.log10 (scalar @Double 1)) `shouldBeApprox` 0 + + describe "log2" $ do + it "log2 of 8 is 3" $ + evalf (ArrayFire.log2 (scalar @Double 8)) `shouldBeApprox` 3 + it "log2 of 1 is 0" $ + evalf (ArrayFire.log2 (scalar @Double 1)) `shouldBeApprox` 0 + + describe "log1p" $ do + it "log1p 0 = 0" $ + evalf (ArrayFire.log1p (scalar @Double 0)) `shouldBeApprox` 0 + it "log1p (e-1) = 1" $ + evalf (ArrayFire.log1p (scalar @Double (exp 1 - 1))) `shouldBeApprox` 1 + + describe "pow" $ do + it "2^10 = 1024" $ + ArrayFire.pow (scalar @Int 2) (scalar @Int 10) `shouldBe` scalar @Int 1024 + it "3^3 = 27" $ + ArrayFire.pow (scalar @Int 3) (scalar @Int 3) `shouldBe` scalar @Int 27 + + describe "pow2" $ do + it "pow2 1 = 2" $ + ArrayFire.pow2 (scalar @Int 1) `shouldBe` scalar @Int 2 + it "pow2 4 = 16" $ + ArrayFire.pow2 (scalar @Int 4) `shouldBe` scalar @Int 16 + it "pow2 0 = 1" $ + ArrayFire.pow2 (scalar @Int 0) `shouldBe` scalar @Int 1 + + describe "root" $ do + it "cube root of 8 is 2" $ + evalf (ArrayFire.root (scalar @Double 8) (scalar @Double 3)) `shouldBeApprox` 2 + it "square root of 9 is 3" $ + evalf (ArrayFire.root (scalar @Double 9) (scalar @Double 2)) `shouldBeApprox` 3 + + describe "arg" $ do + it "arg of a positive real scalar is 0" $ + evalf (ArrayFire.arg (scalar @Double 5)) `shouldBeApprox` 0 + it "arg of 0 is 0" $ + evalf (ArrayFire.arg (scalar @Double 0)) `shouldBeApprox` 0 + + describe "atan2" $ do + it "atan2(1,1) = pi/4" $ + evalf (ArrayFire.atan2 (scalar @Double 1) (scalar @Double 1)) + `shouldBeApprox` (pi / 4) + it "atan2(0,1) = 0" $ + evalf (ArrayFire.atan2 (scalar @Double 0) (scalar @Double 1)) + `shouldBeApprox` 0 + + describe "lgamma" $ do + it "lgamma 1 = 0" $ + evalf (ArrayFire.lgamma (scalar @Double 1)) `shouldBeApprox` 0 + it "lgamma 0.5 = log(sqrt(pi))" $ + evalf (ArrayFire.lgamma (scalar @Double 0.5)) `shouldBeApprox` log (sqrt pi) + + describe "tgamma" $ do + it "tgamma 1 = 1" $ + evalf (ArrayFire.tgamma (scalar @Double 1)) `shouldBeApprox` 1 + it "tgamma 5 = 24 (= 4!)" $ + evalf (ArrayFire.tgamma (scalar @Double 5)) `shouldBeApprox` 24 + it "tgamma 0.5 = sqrt(pi)" $ + evalf (ArrayFire.tgamma (scalar @Double 0.5)) `shouldBeApprox` (sqrt pi) + + describe "addBatched" $ do + it "adds two scalars (batch=True)" $ + (scalar @Int 3 `ArrayFire.addBatched` scalar @Int 4) True `shouldBe` scalar @Int 7 + it "adds two scalars (batch=False)" $ + (scalar @Int 10 `ArrayFire.addBatched` scalar @Int 5) False `shouldBe` scalar @Int 15 + + describe "subBatched" $ do + it "subtracts two scalars (batch=True)" $ + (scalar @Int 9 `ArrayFire.subBatched` scalar @Int 4) True `shouldBe` scalar @Int 5 + it "subtracts two scalars (batch=False)" $ + (scalar @Int 10 `ArrayFire.subBatched` scalar @Int 3) False `shouldBe` scalar @Int 7 + + describe "mulBatched" $ do + it "multiplies two scalars (batch=True)" $ + (scalar @Int 3 `ArrayFire.mulBatched` scalar @Int 5) True `shouldBe` scalar @Int 15 + it "multiplies two scalars (batch=False)" $ + (scalar @Int 6 `ArrayFire.mulBatched` scalar @Int 7) False `shouldBe` scalar @Int 42 + + describe "divBatched" $ do + it "divides two scalars (batch=True)" $ + (scalar @Int 12 `ArrayFire.divBatched` scalar @Int 4) True `shouldBe` scalar @Int 3 + it "divides two scalars (batch=False)" $ + (scalar @Int 20 `ArrayFire.divBatched` scalar @Int 5) False `shouldBe` scalar @Int 4 + + describe "eqBatched" $ do + it "equal scalars return 1 (batch=False)" $ + (scalar @Int 5 `ArrayFire.eqBatched` scalar @Int 5) False `shouldBe` scalar @CBool 1 + it "unequal scalars return 0 (batch=False)" $ + (scalar @Int 5 `ArrayFire.eqBatched` scalar @Int 6) False `shouldBe` scalar @CBool 0 + + describe "neqBatched" $ do + it "unequal scalars return 1 (batch=False)" $ + (scalar @Int 5 `ArrayFire.neqBatched` scalar @Int 6) False `shouldBe` scalar @CBool 1 + it "equal scalars return 0 (batch=False)" $ + (scalar @Int 5 `ArrayFire.neqBatched` scalar @Int 5) False `shouldBe` scalar @CBool 0 + + describe "ltBatched" $ do + it "1 < 2 returns 1 (batch=False)" $ + (scalar @Int 1 `ArrayFire.ltBatched` scalar @Int 2) False `shouldBe` scalar @CBool 1 + it "2 < 1 returns 0 (batch=False)" $ + (scalar @Int 2 `ArrayFire.ltBatched` scalar @Int 1) False `shouldBe` scalar @CBool 0 + + describe "leBatched" $ do + it "1 <= 1 returns 1 (batch=False)" $ + (scalar @Int 1 `ArrayFire.leBatched` scalar @Int 1) False `shouldBe` scalar @CBool 1 + it "2 <= 1 returns 0 (batch=False)" $ + (scalar @Int 2 `ArrayFire.leBatched` scalar @Int 1) False `shouldBe` scalar @CBool 0 + + describe "gtBatched" $ do + it "2 > 1 returns 1 (batch=False)" $ + (scalar @Int 2 `ArrayFire.gtBatched` scalar @Int 1) False `shouldBe` scalar @CBool 1 + it "1 > 2 returns 0 (batch=False)" $ + (scalar @Int 1 `ArrayFire.gtBatched` scalar @Int 2) False `shouldBe` scalar @CBool 0 + + describe "geBatched" $ do + it "1 >= 1 returns 1 (batch=False)" $ + (scalar @Int 1 `ArrayFire.geBatched` scalar @Int 1) False `shouldBe` scalar @CBool 1 + it "1 >= 2 returns 0 (batch=False)" $ + (scalar @Int 1 `ArrayFire.geBatched` scalar @Int 2) False `shouldBe` scalar @CBool 0 + + describe "bitAndBatched" $ do + it "bitAndBatched 1 1 = 1 (batch=False)" $ + ArrayFire.bitAndBatched (scalar @Int 1) (scalar @Int 1) False `shouldBe` scalar @Int 1 + it "bitAndBatched 1 0 = 0 (batch=False)" $ + ArrayFire.bitAndBatched (scalar @Int 1) (scalar @Int 0) False `shouldBe` scalar @Int 0 + + describe "bitOrBatched" $ do + it "bitOrBatched 1 0 = 1 (batch=False)" $ + ArrayFire.bitOrBatched (scalar @Int 1) (scalar @Int 0) False `shouldBe` scalar @Int 1 + it "bitOrBatched 0 0 = 0 (batch=False)" $ + ArrayFire.bitOrBatched (scalar @Int 0) (scalar @Int 0) False `shouldBe` scalar @Int 0 + + describe "bitXorBatched" $ do + it "bitXorBatched 1 1 = 0 (batch=False)" $ + ArrayFire.bitXorBatched (scalar @Int 1) (scalar @Int 1) False `shouldBe` scalar @Int 0 + it "bitXorBatched 1 0 = 1 (batch=False)" $ + ArrayFire.bitXorBatched (scalar @Int 1) (scalar @Int 0) False `shouldBe` scalar @Int 1 + + describe "bitShiftL" $ do + it "1 << 3 = 8" $ + ArrayFire.bitShiftL (scalar @Int 1) (scalar @Int 3) `shouldBe` scalar @Int 8 + it "1 << 0 = 1" $ + ArrayFire.bitShiftL (scalar @Int 1) (scalar @Int 0) `shouldBe` scalar @Int 1 + it "3 << 2 = 12" $ + ArrayFire.bitShiftL (scalar @Int 3) (scalar @Int 2) `shouldBe` scalar @Int 12 + + describe "bitShiftR" $ do + it "8 >> 3 = 1" $ + ArrayFire.bitShiftR (scalar @Int 8) (scalar @Int 3) `shouldBe` scalar @Int 1 + it "12 >> 2 = 3" $ + ArrayFire.bitShiftR (scalar @Int 12) (scalar @Int 2) `shouldBe` scalar @Int 3 + it "1 >> 0 = 1" $ + ArrayFire.bitShiftR (scalar @Int 1) (scalar @Int 0) `shouldBe` scalar @Int 1 + + describe "andBatched" $ do + it "1 AND 1 = 1 (batch=False)" $ + ArrayFire.andBatched (scalar @Int 1) (scalar @Int 1) False `shouldBe` scalar @CBool 1 + it "1 AND 0 = 0 (batch=False)" $ + ArrayFire.andBatched (scalar @Int 1) (scalar @Int 0) False `shouldBe` scalar @CBool 0 + + describe "orBatched" $ do + it "1 OR 0 = 1 (batch=False)" $ + ArrayFire.orBatched (scalar @Int 1) (scalar @Int 0) False `shouldBe` scalar @CBool 1 + it "0 OR 0 = 0 (batch=False)" $ + ArrayFire.orBatched (scalar @Int 0) (scalar @Int 0) False `shouldBe` scalar @CBool 0 + + describe "bitShiftLBatched" $ do + it "1 << 3 = 8 (batch=False)" $ + ArrayFire.bitShiftLBatched (scalar @Int 1) (scalar @Int 3) False `shouldBe` scalar @Int 8 + it "3 << 2 = 12 (batch=False)" $ + ArrayFire.bitShiftLBatched (scalar @Int 3) (scalar @Int 2) False `shouldBe` scalar @Int 12 + + describe "bitShiftRBatched" $ do + it "8 >> 3 = 1 (batch=False)" $ + ArrayFire.bitShiftRBatched (scalar @Int 8) (scalar @Int 3) False `shouldBe` scalar @Int 1 + it "12 >> 2 = 3 (batch=False)" $ + ArrayFire.bitShiftRBatched (scalar @Int 12) (scalar @Int 2) False `shouldBe` scalar @Int 3 + + describe "clampBatched" $ do + it "clamp 2 to [1,3] = 2 (batch=False)" $ + ArrayFire.clampBatched (scalar @Int 2) (scalar @Int 1) (scalar @Int 3) False `shouldBe` scalar @Int 2 + it "clamp 0 to [1,3] = 1 (batch=False)" $ + ArrayFire.clampBatched (scalar @Int 0) (scalar @Int 1) (scalar @Int 3) False `shouldBe` scalar @Int 1 + it "clamp 5 to [1,3] = 3 (batch=False)" $ + ArrayFire.clampBatched (scalar @Int 5) (scalar @Int 1) (scalar @Int 3) False `shouldBe` scalar @Int 3 + + describe "remBatched" $ do + it "7 rem 3 = 1 (batch=False)" $ + ArrayFire.remBatched (scalar @Int 7) (scalar @Int 3) False `shouldBe` scalar @Int 1 + it "10 rem 5 = 0 (batch=False)" $ + ArrayFire.remBatched (scalar @Int 10) (scalar @Int 5) False `shouldBe` scalar @Int 0 + + describe "modBatched" $ do + it "7 mod 3 = 1 (batch=False)" $ + ArrayFire.modBatched (scalar @Int 7) (scalar @Int 3) False `shouldBe` scalar @Int 1 + it "9 mod 3 = 0 (batch=False)" $ + ArrayFire.modBatched (scalar @Int 9) (scalar @Int 3) False `shouldBe` scalar @Int 0 + + describe "minOfBatched" $ do + it "min 2 3 = 2 (batch=False)" $ + ArrayFire.minOfBatched (scalar @Int 2) (scalar @Int 3) False `shouldBe` scalar @Int 2 + it "min 5 1 = 1 (batch=False)" $ + ArrayFire.minOfBatched (scalar @Int 5) (scalar @Int 1) False `shouldBe` scalar @Int 1 + + describe "maxOfBatched" $ do + it "max 2 3 = 3 (batch=False)" $ + ArrayFire.maxOfBatched (scalar @Int 2) (scalar @Int 3) False `shouldBe` scalar @Int 3 + it "max 5 1 = 5 (batch=False)" $ + ArrayFire.maxOfBatched (scalar @Int 5) (scalar @Int 1) False `shouldBe` scalar @Int 5 + + describe "rootBatched" $ do + it "cube root of 8 = 2 (batch=False)" $ + evalf (ArrayFire.rootBatched (scalar @Double 8) (scalar @Double 3) False) `shouldBeApprox` 2 + it "square root of 9 = 3 (batch=False)" $ + evalf (ArrayFire.rootBatched (scalar @Double 9) (scalar @Double 2) False) `shouldBeApprox` 3 + + describe "powBatched" $ do + it "2^3 = 8 (batch=False)" $ + ArrayFire.powBatched (scalar @Int 2) (scalar @Int 3) False `shouldBe` scalar @Int 8 + it "5^2 = 25 (batch=False)" $ + ArrayFire.powBatched (scalar @Int 5) (scalar @Int 2) False `shouldBe` scalar @Int 25 diff --git a/test/ArrayFire/ArraySpec.hs b/test/ArrayFire/ArraySpec.hs index 1452a00..ca90429 100644 --- a/test/ArrayFire/ArraySpec.hs +++ b/test/ArrayFire/ArraySpec.hs @@ -4,37 +4,37 @@ module ArrayFire.ArraySpec where import Control.Exception import Data.Complex +import qualified Data.Vector.Storable as V import Data.Word import Foreign.C.Types import GHC.Int import Test.Hspec +import Test.Hspec.QuickCheck (prop) +import Test.QuickCheck ((==>)) -import ArrayFire +import ArrayFire hiding (not) spec :: Spec spec = describe "Array tests" $ do - it "Should perform Array tests" $ do - (1 + 1) `shouldBe` 2 - it "Should fail to create 0 dimension arrays" $ do - let arr = mkArray @Int [0,0,0,0] [1..] - evaluate arr `shouldThrow` anyException - it "Should fail to create 0 length arrays" $ do - let arr = mkArray @Int [0,0,0,1] [] - evaluate arr `shouldThrow` anyException - it "Should fail to create 0 length arrays w/ 0 dimensions" $ do - let arr = mkArray @Int [0,0,0,0] [] - evaluate arr `shouldThrow` anyException + it "Should add two scalar arrays" $ do + (scalar @Int 1 + scalar @Int 1) `shouldBe` scalar @Int 2 + it "Should create a 0 dimension array" $ do + getElements (mkArray @Int [3,0,1,1] []) `shouldBe` 0 + it "Should create a 0 length array" $ do + getElements (mkArray @Int [0,0,0,1] []) `shouldBe` 0 + it "Should create a 0 length array w/ 0 dimensions" $ do + getElements (mkArray @Int [0,0,0,0] []) `shouldBe` 0 it "Should create a column vector" $ do let arr = mkArray @Int [9,1,1,1] (repeat 9) isColumn arr `shouldBe` True it "Should create a row vector" $ do let arr = mkArray @Int [1,9,1,1] (repeat 9) isRow arr `shouldBe` True - it "Should create a vector" $ do + it "Should recognize a column array as a vector" $ do let arr = mkArray @Int [9,1,1,1] (repeat 9) isVector arr `shouldBe` True - it "Should create a vector" $ do + it "Should recognize a row array as a vector" $ do let arr = mkArray @Int [1,9,1,1] (repeat 9) isVector arr `shouldBe` True it "Should copy an array" $ do @@ -47,10 +47,10 @@ spec = it "Should return the number of elements" $ do let arr = mkArray @Int [9,9,1,1] [1..] getElements arr `shouldBe` 81 --- it "Should give an empty array" $ do --- let arr = mkArray @Int [-1,1,1,1] [] --- getElements arr `shouldBe` 0 --- isEmpty arr `shouldBe` True + it "Should give an empty array" $ do + let arr = mkArray @Int [0,1,1,1] [] + getElements arr `shouldBe` 0 + isEmpty arr `shouldBe` True it "Should create a scalar array" $ do let arr = mkArray @Int [1] [1] isScalar arr `shouldBe` True @@ -154,3 +154,85 @@ spec = let arr = mkArray @Word [10] [1..10] toList arr `shouldBe` [1..10] + + -- Regression: toVector previously allocated len*size bytes instead of size, + -- causing quadratic memory use. These round-trips verify correct element count + -- and values at sizes where the bug was most wasteful. + describe "toVector round-trip" $ do + it "preserves all elements for a 1000-element Double array" $ do + let xs = [1..1000] :: [Double] + arr = mkArray @Double [1000] xs + V.toList (toVector arr) `shouldBe` xs + it "preserves all elements for a 500-element Int array" $ do + let xs = [1..500] :: [Int] + arr = mkArray @Int [500] xs + V.toList (toVector arr) `shouldBe` xs + it "length of toVector matches getElements" $ do + let arr = mkArray @Double [7, 13] (repeat 0) + V.length (toVector arr) `shouldBe` getElements arr + + describe "fromVector" $ do + it "round-trips a Double vector" $ do + let xs = V.fromList [1..10 :: Double] + arr = fromVector @Double [10] xs + toVector arr `shouldBe` xs + it "round-trips an Int vector" $ do + let xs = V.fromList [1..100 :: Int] + arr = fromVector @Int [100] xs + toVector arr `shouldBe` xs + it "round-trips a Complex Double vector" $ do + let xs = V.fromList [1 :+ 2, 3 :+ 4 :: Complex Double] + arr = fromVector @(Complex Double) [2] xs + toVector arr `shouldBe` xs + it "produces the same result as mkArray" $ do + let xs = [1..25 :: Double] + arr1 = mkArray @Double [5,5] xs + arr2 = fromVector @Double [5,5] (V.fromList xs) + arr2 `shouldBe` arr1 + it "throws on dimension mismatch" $ do + let xs = V.fromList [1,2,3 :: Double] + evaluate (fromVector @Double [4] xs) `shouldThrow` anyException + -- Round-trip is data-preserving (no arithmetic), so equality is exact. + -- This also guards the toVector allocation fix against host over-reads. + prop "toVector . fromVector == id (Double)" $ \(xs :: [Double]) -> + not (null xs) ==> + let v = V.fromList xs + in V.toList (toVector (fromVector @Double [length xs] v)) == xs + prop "toVector . fromVector == id (Int)" $ \(xs :: [Int]) -> + not (null xs) ==> + let v = V.fromList xs + in V.toList (toVector (fromVector @Int [length xs] v)) == xs + + describe "cube" $ do + it "creates a 2x2x2 cube with correct dims" $ do + let c = cube @Double (2,2,2) + [ [[1,2],[3,4]], [[5,6],[7,8]] ] + getDims c `shouldBe` (2,2,2,1) + it "creates a 2x2x2 cube with correct element count" $ do + let c = cube @Double (2,2,2) + [ [[1,2],[3,4]], [[5,6],[7,8]] ] + getElements c `shouldBe` 8 + it "all-constant cube equals constant array" $ do + let c = cube @Double (2,2,2) + [ [[3,3],[3,3]], [[3,3],[3,3]] ] + c `shouldBe` mkArray @Double [2,2,2] (replicate 8 3) + + describe "tensor" $ do + it "creates a 2x2x2x2 tensor with correct dims" $ do + let t = tensor @Double (2,2,2,2) + [ [ [[1,2],[3,4]], [[5,6],[7,8]] ] + , [ [[1,2],[3,4]], [[5,6],[7,8]] ] + ] + getDims t `shouldBe` (2,2,2,2) + it "creates a 2x2x2x2 tensor with correct element count" $ do + let t = tensor @Double (2,2,2,2) + [ [ [[1,2],[3,4]], [[5,6],[7,8]] ] + , [ [[1,2],[3,4]], [[5,6],[7,8]] ] + ] + getElements t `shouldBe` 16 + it "all-constant tensor equals constant array" $ do + let t = tensor @Double (2,2,2,2) + [ [ [[5,5],[5,5]], [[5,5],[5,5]] ] + , [ [[5,5],[5,5]], [[5,5],[5,5]] ] + ] + t `shouldBe` mkArray @Double [2,2,2,2] (replicate 16 5) diff --git a/test/ArrayFire/BLASSpec.hs b/test/ArrayFire/BLASSpec.hs index 40cbbec..5cd267b 100644 --- a/test/ArrayFire/BLASSpec.hs +++ b/test/ArrayFire/BLASSpec.hs @@ -1,10 +1,35 @@ -{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} module ArrayFire.BLASSpec where -import ArrayFire hiding (not) +import ArrayFire hiding (not, and, abs, max, mm, tr) import Data.Complex import Test.Hspec +import Test.Hspec.ApproxExpect (closeList) +import Test.Hspec.QuickCheck (prop) + +-- | Build a 4x4 'Double' matrix from an arbitrary (possibly short) list, +-- padding with zeros so the shape is always well-defined. +mat4 :: [Double] -> Array Double +mat4 xs = mkArray [4,4] (take 16 (xs ++ repeat 0)) + +-- | Build a length-4 'Double' vector, padding with zeros. +vec4 :: [Double] -> Array Double +vec4 xs = vector 4 (take 4 (xs ++ repeat 0)) + +-- | Plain matrix product with default (None) operands. +mm :: Array Double -> Array Double -> Array Double +mm a b = (a `matmul` b) None None + +-- | Transpose (no conjugation). +tr :: Array Double -> Array Double +tr a = transpose a False + +-- | Scale every element of a 4x4 matrix by a constant. +scaleMat :: Double -> Array Double -> Array Double +scaleMat c a = mkArray [4,4] (map (c *) (toList a)) + spec :: Spec spec = @@ -14,22 +39,109 @@ spec = `shouldBe` matrix @Double (2,2) [[8,8],[8,8]] it "Should dot product two vectors" $ do dot (vector @Double 2 (repeat 2)) (vector @Double 2 (repeat 2)) None None - `shouldBe` - scalar @Double 8 + `shouldBe` scalar @Double 8 it "Should produce scalar dot product between two vectors as a Complex number" $ do dotAll (vector @Double 2 (repeat 2)) (vector @Double 2 (repeat 2)) None None - `shouldBe` - 8.0 :+ 0.0 + `shouldBe` 8.0 :+ 0.0 it "Should take the transpose of a matrix" $ do transpose (matrix @Double (2,2) [[1,1],[2,2]]) False - `shouldBe` - matrix @Double (2,2) [[1,2],[1,2]] + `shouldBe` matrix @Double (2,2) [[1,2],[1,2]] it "Should take the transpose of a matrix in place" $ do + -- transposeInPlace is an IO () that mutates the underlying C buffer. + -- All Haskell references sharing the same ForeignPtr see the result. + -- Do not use the original binding after calling this. let m = matrix @Double (2,2) [[1,1],[2,2]] transposeInPlace m False m `shouldBe` matrix @Double (2,2) [[1,2],[1,2]] + it "Should perform gemm: alpha=1, A*I = A" $ do + let a = matrix @Double (2,2) [[1,2],[3,4]] + b = matrix @Double (2,2) [[1,0],[0,1]] + gemm None None 1.0 a b `shouldBe` a + it "Should perform gemm: alpha=2 scales the result" $ do + -- b col-major: col0=[3,4], col1=[5,6] + -- 2 * I * b = 2b → col0=[6,8], col1=[10,12] + let a = matrix @Double (2,2) [[1,0],[0,1]] + b = matrix @Double (2,2) [[3,4],[5,6]] + gemm None None 2.0 a b `shouldBe` matrix @Double (2,2) [[6,8],[10,12]] + it "Should perform gemm with transposed A" $ do + let a = matrix @Double (2,2) [[1,3],[2,4]] + b = matrix @Double (2,2) [[1,0],[0,1]] + gemm Trans None 1.0 a b `shouldBe` matrix @Double (2,2) [[1,2],[3,4]] + it "Should perform gemm: non-trivial A*B" $ do + -- matrix (2,2) [[c0r0,c0r1],[c1r0,c1r1]] is column-major. + -- A = [[1,3],[2,4]], B = [[5,7],[6,8]] (rows displayed by ArrayFire) + -- A*B col0 = [1*5+3*6, 2*5+4*6] = [23,34] + -- A*B col1 = [1*7+3*8, 2*7+4*8] = [31,46] + let a = matrix @Double (2,2) [[1,2],[3,4]] + b = matrix @Double (2,2) [[5,6],[7,8]] + gemm None None 1.0 a b `shouldBe` matrix @Double (2,2) [[23,34],[31,46]] + + describe "algebraic properties" $ do + -- Transposition only moves data, so double-transpose is exactly the + -- identity (no floating-point rounding involved). + prop "transpose is an involution" $ \(xs :: [Double]) -> + let m = mat4 xs + in toList (transpose (transpose m False) False) == toList m + + -- Multiplying by the identity matrix recovers the original. + prop "A * I = A" $ \(xs :: [Double]) -> + let a = mat4 xs + in closeList (toList ((a `matmul` identity [4,4]) None None)) (toList a) + + -- (A^T B^T)^T = B A : transpose distributes over a product (reversed). + prop "(A^T B^T)^T = B A" $ \(xs :: [Double]) (ys :: [Double]) -> + let a = mat4 xs + b = mat4 ys + lhs = transpose ((transpose a False `matmul` transpose b False) None None) False + rhs = (b `matmul` a) None None + in closeList (toList lhs) (toList rhs) + + -- Matrix multiplication is associative. + prop "(A*B)*C = A*(B*C)" $ \(xs :: [Double]) (ys :: [Double]) (zs :: [Double]) -> + let a = mat4 xs; b = mat4 ys; c = mat4 zs + in closeList (toList (mm (mm a b) c)) (toList (mm a (mm b c))) + + -- Multiplication distributes over addition on the left. + prop "A*(B+C) = A*B + A*C" $ \(xs :: [Double]) (ys :: [Double]) (zs :: [Double]) -> + let a = mat4 xs; b = mat4 ys; c = mat4 zs + in closeList (toList (mm a (b + c))) (toList (mm a b + mm a c)) + + -- Multiplication distributes over addition on the right. + prop "(A+B)*C = A*C + B*C" $ \(xs :: [Double]) (ys :: [Double]) (zs :: [Double]) -> + let a = mat4 xs; b = mat4 ys; c = mat4 zs + in closeList (toList (mm (a + b) c)) (toList (mm a c + mm b c)) + + -- The identity is a left identity too (the existing case is right-sided). + prop "I*A = A" $ \(xs :: [Double]) -> + let a = mat4 xs + in closeList (toList (mm (identity [4,4]) a)) (toList a) + + -- Transpose of a product reverses the order of the factors. + prop "(A*B)^T = B^T * A^T" $ \(xs :: [Double]) (ys :: [Double]) -> + let a = mat4 xs; b = mat4 ys + in closeList (toList (tr (mm a b))) (toList (mm (tr b) (tr a))) + -- Transpose is additive. + prop "(A+B)^T = A^T + B^T" $ \(xs :: [Double]) (ys :: [Double]) -> + let a = mat4 xs; b = mat4 ys + in closeList (toList (tr (a + b))) (toList (tr a + tr b)) + -- Scalar factors pull through a product: (cA)*B = c(A*B). + prop "(cA)*B = c(A*B)" $ \(c :: Double) (xs :: [Double]) (ys :: [Double]) -> + let a = mat4 xs; b = mat4 ys + in closeList (toList (mm (scaleMat c a) b)) (toList (scaleMat c (mm a b))) + -- The zero matrix annihilates under multiplication. + prop "A*0 = 0" $ \(xs :: [Double]) -> + let a = mat4 xs + in all (== 0) (toList (mm a (mat4 []))) + -- gemm with alpha=1 and no transposition agrees with matmul. + prop "gemm None None 1 A B = A*B" $ \(xs :: [Double]) (ys :: [Double]) -> + let a = mat4 xs; b = mat4 ys + in closeList (toList (gemm None None 1.0 a b)) (toList (mm a b)) + -- The dot product of real vectors is symmetric. + prop "dot x y = dot y x" $ \(xs :: [Double]) (ys :: [Double]) -> + let x = vec4 xs; y = vec4 ys + in closeList (toList (dot x y None None)) (toList (dot y x None None)) diff --git a/test/ArrayFire/DataSpec.hs b/test/ArrayFire/DataSpec.hs index fcbd53f..7d185a7 100644 --- a/test/ArrayFire/DataSpec.hs +++ b/test/ArrayFire/DataSpec.hs @@ -2,14 +2,18 @@ {-# LANGUAGE TypeApplications #-} module ArrayFire.DataSpec where -import Control.Exception -import Data.Complex -import Data.Word -import Foreign.C.Types -import GHC.Int -import Test.Hspec +import Control.Exception +import Data.Bits (complement) +import Data.Complex +import Data.Word +import Foreign.C.Types +import GHC.Int +import Prelude hiding (flip) +import Test.Hspec +import Test.Hspec.QuickCheck (prop) +import Test.QuickCheck ((==>)) -import ArrayFire +import ArrayFire hiding (not) spec :: Spec spec = @@ -32,8 +36,146 @@ spec = constant @(Complex Float) [1] (1.0 :+ 1.0) `shouldBe` constant @(Complex Float) [1] (1.0 :+ 1.0) + + describe "arange" $ do + it "generates a sequence along dim 0 for a 1D array" $ do + arange @Double [5] (-1) `shouldBe` vector @Double 5 [0,1,2,3,4] + it "generates a sequence along dim 1 for a 2D array" $ do + arange @Double [3,2] 1 `shouldBe` mkArray @Double [3,2] [0,0,0,1,1,1] + + describe "iota" $ do + it "generates a flat sequence without tiling" $ do + iota @Double [5] [] `shouldBe` vector @Double 5 [0,1,2,3,4] + it "tiles the sequence along dim 0" $ do + iota @Double [3] [2] `shouldBe` vector @Double 6 [0,1,2,0,1,2] + + describe "identity" $ do + it "creates a 2x2 identity matrix" $ do + identity @Double [2,2] + `shouldBe` mkArray @Double [2,2] [1,0,0,1] + it "creates a 3x3 identity matrix" $ do + identity @Double [3,3] + `shouldBe` mkArray @Double [3,3] [1,0,0,0,1,0,0,0,1] + + describe "diagCreate" $ do + it "creates a diagonal matrix from a vector (diag 0)" $ do + diagCreate (vector @Double 3 [1,2,3]) 0 + `shouldBe` mkArray @Double [3,3] [1,0,0,0,2,0,0,0,3] + it "creates a superdiagonal matrix (diag 1)" $ do + diagCreate (vector @Double 2 [5,6]) 1 + `shouldBe` mkArray @Double [3,3] [0,0,0,5,0,0,0,6,0] + + describe "diagExtract" $ do + it "extracts the main diagonal of a square matrix" $ do + diagExtract (mkArray @Double [3,3] [1,0,0,0,2,0,0,0,3]) 0 + `shouldBe` vector @Double 3 [1,2,3] + it "is the inverse of diagCreate on the main diagonal" $ do + let v = vector @Double 4 [1,2,3,4] + diagExtract (diagCreate v 0) 0 `shouldBe` v + + describe "lower" $ do + it "extracts the lower triangular part (unit diagonal)" $ do + let m = mkArray @Double [3,3] [1,2,3,4,5,6,7,8,9] + lower m True + `shouldBe` mkArray @Double [3,3] [1,2,3,0,1,6,0,0,1] + it "extracts the lower triangular part (non-unit diagonal)" $ do + let m = mkArray @Double [3,3] [1,2,3,4,5,6,7,8,9] + lower m False + `shouldBe` mkArray @Double [3,3] [1,2,3,0,5,6,0,0,9] + + describe "upper" $ do + it "extracts the upper triangular part (unit diagonal)" $ do + let m = mkArray @Double [3,3] [1,2,3,4,5,6,7,8,9] + upper m True + `shouldBe` mkArray @Double [3,3] [1,0,0,4,1,0,7,8,1] + it "extracts the upper triangular part (non-unit diagonal)" $ do + let m = mkArray @Double [3,3] [1,2,3,4,5,6,7,8,9] + upper m False + `shouldBe` mkArray @Double [3,3] [1,0,0,4,5,0,7,8,9] + + describe "tile" $ do + it "tiles a scalar into a 3x3 array" $ do + tile (scalar @Int 7) [3,3] + `shouldBe` constant @Int [3,3] 7 + it "tiles a row vector along dim 0" $ do + tile (mkArray @Int [1,3] [1,2,3]) [2,1] + `shouldBe` mkArray @Int [2,3] [1,1,2,2,3,3] + + describe "moddims" $ do + it "reshapes a vector into a matrix" $ do + moddims (vector @Int 6 [1..6]) [2,3] + `shouldBe` mkArray @Int [2,3] [1,2,3,4,5,6] + it "reshapes a matrix back to a vector" $ do + let v = vector @Int 6 [1..6] + moddims (moddims v [2,3]) [6] `shouldBe` v + + describe "flat" $ do + it "flattens a 2x3 matrix to a 6-element vector" $ do + flat (mkArray @Int [2,3] [1,2,3,4,5,6]) + `shouldBe` vector @Int 6 [1,2,3,4,5,6] + + describe "flip" $ do + it "reverses a vector (dim 0)" $ do + flip (vector @Int 4 [1,2,3,4]) 0 + `shouldBe` vector @Int 4 [4,3,2,1] + it "reverses columns of a matrix (dim 1)" $ do + flip (mkArray @Int [2,2] [1,2,3,4]) 1 + `shouldBe` mkArray @Int [2,2] [3,4,1,2] + + describe "shift" $ do + it "shifts a vector by 2 elements (wrapping)" $ do + shift (vector @Double 4 [1,2,3,4]) 2 0 0 0 + `shouldBe` vector @Double 4 [3,4,1,2] + + describe "select" $ do + it "selects elements from two arrays based on a boolean mask" $ do + let cond = vector @CBool 4 [1,0,1,0] + a = vector @Double 4 [10,20,30,40] + b = vector @Double 4 [1,2,3,4] + select cond a b `shouldBe` vector @Double 4 [10,2,30,4] + + describe "selectScalarR" $ do + it "uses scalar for false positions" $ do + let cond = vector @CBool 4 [1,0,1,0] + a = vector @Double 4 [10,20,30,40] + selectScalarR cond a 99 `shouldBe` vector @Double 4 [10,99,30,99] + + describe "selectScalarL" $ do + it "uses scalar for true positions" $ do + let cond = vector @CBool 4 [1,0,1,0] + b = vector @Double 4 [1,2,3,4] + selectScalarL cond 99 b `shouldBe` vector @Double 4 [99,2,99,4] + it "Should join Arrays along the specified dimension" $ do join 0 (constant @Int [1, 3] 1) (constant @Int [1, 3] 2) `shouldBe` mkArray @Int [2, 3] [1, 2, 1, 2, 1, 2] join 1 (constant @Int [1, 2] 1) (constant @Int [1, 2] 2) `shouldBe` mkArray @Int [1, 4] [1, 1, 2, 2] joinMany 0 [constant @Int [1, 3] 1, constant @Int [1, 3] 2] `shouldBe` mkArray @Int [2, 3] [1, 2, 1, 2, 1, 2] joinMany 1 [constant @Int [1, 2] 1, constant @Int [1, 1] 2, constant @Int [1, 3] 3] `shouldBe` mkArray @Int [1, 6] [1, 1, 2, 3, 3, 3] + + describe "bitNot" $ do + it "complements 0 to all-ones (-1 in two's complement) for Int32" $ do + bitNot (scalar @Int32 0) `shouldBe` scalar @Int32 (-1) + it "complements -1 to 0 for Int32" $ do + bitNot (scalar @Int32 (-1)) `shouldBe` scalar @Int32 0 + it "complements 0 to maxBound for Word32" $ do + bitNot (scalar @Word32 0) `shouldBe` scalar @Word32 maxBound + it "bitNot . bitNot == id" $ do + let v = vector @Int32 4 [0, 1, -1, 42] + bitNot (bitNot v) `shouldBe` v + prop "bitNot is an involution (Int32)" $ \(xs :: [Int32]) -> + not (null xs) ==> + toList (bitNot (bitNot (vector @Int32 (length xs) xs))) == xs + prop "bitNot agrees with Data.Bits.complement (Int32)" $ \(xs :: [Int32]) -> + not (null xs) ==> + toList (bitNot (vector @Int32 (length xs) xs)) == map complement xs + + describe "reorder" $ do + it "reorder [0,1] is identity for a 2D matrix" $ do + let m = matrix @Double (3,4) [[1..3],[3..6],[6..9],[9..12]] + reorder m [0,1] `shouldBe` m + it "reorder [1,0] transposes a matrix" $ do + let m = matrix @Double (2,3) [[1,2],[3,4],[5,6]] + getDims (reorder m [1,0]) `shouldBe` (3,2,1,1) + it "reorder [1,0] then [1,0] round-trips" $ do + let m = matrix @Double (3,4) [[1..3],[3..6],[6..9],[9..12]] + reorder (reorder m [1,0]) [1,0] `shouldBe` m diff --git a/test/ArrayFire/DeviceSpec.hs b/test/ArrayFire/DeviceSpec.hs index 3f2eceb..16cb8a6 100644 --- a/test/ArrayFire/DeviceSpec.hs +++ b/test/ArrayFire/DeviceSpec.hs @@ -7,7 +7,7 @@ import Test.Hspec spec :: Spec spec = - describe "Algorithm tests" $ do + describe "Device tests" $ do it "Should show device info" $ do A.info `shouldReturn` () it "Should show device init" $ do @@ -18,4 +18,6 @@ spec = A.getDevice >>= (`shouldSatisfy` (>= 0)) it "Should get and set device" $ do (A.getDevice >>= A.setDevice) `shouldReturn` () + it "Should get device count >= 1" $ do + A.getDeviceCount >>= (`shouldSatisfy` (>= 1)) diff --git a/test/ArrayFire/ExceptionSpec.hs b/test/ArrayFire/ExceptionSpec.hs new file mode 100644 index 0000000..6fb5b17 --- /dev/null +++ b/test/ArrayFire/ExceptionSpec.hs @@ -0,0 +1,47 @@ +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +module ArrayFire.ExceptionSpec where + +import Control.Exception (evaluate, try) +import qualified ArrayFire as A +import ArrayFire.Exception +import ArrayFire.Internal.Defines (AFErr (..)) +import Test.Hspec + +spec :: Spec +spec = describe "Exception spec" $ do + + -- The error-code → constructor table is the heart of the FFI error path; + -- a wrong entry silently mislabels every failure of that kind. + describe "toAFExceptionType" $ do + + it "maps every documented AFErr code to its constructor" $ + map (toAFExceptionType . AFErr) + [101,102,103,201,202,203,204,205,207,208,301,302,303,401,402,501,502,503,998,999] + `shouldBe` + [ NoMemoryError, DriverError, RuntimeError, InvalidArrayError, ArgError + , SizeError, TypeError, DiffTypeError, BatchError, DeviceError + , NotSupportedError, NotConfiguredError, NonFreeError, NoDblError + , NoGfxError, LoadLibError, LoadSymError, BackendMismatchError + , InternalError, UnknownError + ] + + it "maps unrecognized codes to UnhandledError" $ do + toAFExceptionType (AFErr 0) `shouldBe` UnhandledError + toAFExceptionType (AFErr 12345) `shouldBe` UnhandledError + + -- End-to-end: a genuine ArrayFire failure must cross the FFI boundary as a + -- typed 'AFException', not a crash or an opaque error. + describe "library errors surface as AFException" $ + + it "a matmul dimension mismatch throws a typed AFException" $ do + let a = A.mkArray @Double [2,3] [1..6] -- 2x3 + b = A.mkArray @Double [2,2] [1..4] -- 2x2 (inner dims 3 /= 2) + r <- try (evaluate (A.getElements (A.matmul a b A.None A.None))) + :: IO (Either AFException Int) + case r of + Right n -> + expectationFailure ("expected an AFException, but got " ++ show n) + Left (AFException ty code _msg) -> do + ty `shouldSatisfy` (`elem` [SizeError, ArgError]) + code `shouldSatisfy` (> 0) diff --git a/test/ArrayFire/FeaturesSpec.hs b/test/ArrayFire/FeaturesSpec.hs index 0d2405e..277be8a 100644 --- a/test/ArrayFire/FeaturesSpec.hs +++ b/test/ArrayFire/FeaturesSpec.hs @@ -1,13 +1,52 @@ -{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} module ArrayFire.FeaturesSpec where -import ArrayFire hiding (acos) -import Prelude -import Test.Hspec +import qualified ArrayFire as A +import Test.Hspec + +-- | All five per-feature accessor arrays for a 'Features' handle. +accessors :: A.Features -> [A.Array Float] +accessors f = + [ A.getFeaturesXPos f + , A.getFeaturesYPos f + , A.getFeaturesScore f + , A.getFeaturesOrientation f + , A.getFeaturesSize f + ] spec :: Spec -spec = - describe "Feautures tests" $ do - it "Should get features number an array" $ do - let feats = createFeatures 10 - getFeaturesNum feats `shouldBe` 10 +spec = describe "Features spec" $ do + + describe "createFeatures / getFeaturesNum" $ do + it "reports the requested number of features" $ + A.getFeaturesNum (A.createFeatures 10) `shouldBe` 10 + + it "supports an empty feature set" $ + A.getFeaturesNum (A.createFeatures 0) `shouldBe` 0 + + it "supports a large feature set" $ + A.getFeaturesNum (A.createFeatures 1024) `shouldBe` 1024 + + describe "accessor arrays" $ do + it "every accessor array has getFeaturesNum elements" $ do + let feats = A.createFeatures 10 + map A.getElements (accessors feats) `shouldBe` replicate 5 10 + + it "every accessor array is a column vector of length n" $ do + let feats = A.createFeatures 7 + map A.getDims (accessors feats) `shouldBe` replicate 5 (7,1,1,1) + + -- NB: 'createFeatures 0' is a degenerate case — ArrayFire does not + -- allocate the per-feature accessor arrays for an empty set, so reading + -- them back yields uninitialized handles (garbage element counts / dims). + -- We therefore do not assert anything about accessors of an empty set. + + describe "retainFeatures" $ do + it "preserves the feature count" $ do + let feats = A.createFeatures 10 + A.getFeaturesNum (A.retainFeatures feats) `shouldBe` A.getFeaturesNum feats + + it "preserves accessor-array dimensions" $ do + let feats = A.retainFeatures (A.createFeatures 5) + map A.getDims (accessors feats) `shouldBe` replicate 5 (5,1,1,1) diff --git a/test/ArrayFire/GraphicsSpec.hs b/test/ArrayFire/GraphicsSpec.hs index 3e98667..aa26dd8 100644 --- a/test/ArrayFire/GraphicsSpec.hs +++ b/test/ArrayFire/GraphicsSpec.hs @@ -2,17 +2,34 @@ {-# LANGUAGE TypeApplications #-} module ArrayFire.GraphicsSpec where -import Control.Exception -import Data.Complex -import Data.Word -import Foreign.C.Types -import GHC.Int -import Test.Hspec - -import ArrayFire +import ArrayFire (Cell(..), ColorMap(..)) +import Test.Hspec spec :: Spec -spec = - describe "Graphics tests" $ do - it "Should create window" $ do - (1 + 1) `shouldBe` 2 +spec = describe "Graphics spec" $ do + + -- The 'Cell' render-descriptor is a pure record and is always testable, + -- with or without a display. + -- + -- The window operations (createWindow, setTitle, ...) are intentionally + -- not exercised here: they require a live OpenGL/forge context and abort + -- the process with a SIGSEGV on headless machines. A segfault is not a + -- catchable Haskell exception, so there is no safe way to probe them in an + -- automated suite. + describe "Cell" $ do + let cell = Cell 1 2 "chart" ColorMapSpectrum + + it "exposes its fields" $ do + cellRow cell `shouldBe` 1 + cellCol cell `shouldBe` 2 + cellTitle cell `shouldBe` "chart" + cellColorMap cell `shouldBe` ColorMapSpectrum + + it "has a lawful Eq instance" $ do + cell `shouldBe` Cell 1 2 "chart" ColorMapSpectrum + cell `shouldNotBe` Cell 1 2 "chart" ColorMapHeat + + it "carries each ColorMap through a record update" $ + -- ColorMap derives Enum (not Bounded); enumFrom runs to the last ctor + map (cellColorMap . \c -> cell { cellColorMap = c }) [ColorMapDefault ..] + `shouldBe` ([ColorMapDefault ..] :: [ColorMap]) diff --git a/test/ArrayFire/ImageSpec.hs b/test/ArrayFire/ImageSpec.hs index 1824429..00e02ec 100644 --- a/test/ArrayFire/ImageSpec.hs +++ b/test/ArrayFire/ImageSpec.hs @@ -2,17 +2,153 @@ {-# LANGUAGE TypeApplications #-} module ArrayFire.ImageSpec where -import Control.Exception -import Data.Complex -import Data.Word -import Foreign.C.Types -import GHC.Int -import Test.Hspec +import qualified ArrayFire as A +import ArrayFire.Exception (AFException (..), AFExceptionType (..)) +import Control.Exception (bracket, finally, try, throwIO) +import System.Directory (getTemporaryDirectory, removeFile) +import System.IO (openTempFile, hClose) +import Test.Hspec +import Test.Hspec.ApproxExpect -import ArrayFire +-- | A 4×4 single-channel constant image. +gray :: A.Array Float +gray = A.constant @Float [4,4] 1.0 + +-- | A 4×4×3 three-channel (RGB) constant image. +rgb :: A.Array Float +rgb = A.constant @Float [4,4,3] 1.0 spec :: Spec -spec = - describe "Image tests" $ do - it "Should test if Image I/O is available" $ do - isImageIOAvailable `shouldReturn` True +spec = describe "Image spec" $ do + describe "isImageIOAvailable" $ + it "reports whether FreeImage support was compiled in" $ + -- value is build-dependent; we only assert the call succeeds & is Bool + (A.isImageIOAvailable >>= (`shouldSatisfy` (`elem` [True, False]))) + + describe "gaussianKernel" $ do + it "produces a kernel of the requested dimensions" $ + A.getDims (A.gaussianKernel @Float 3 5 0 0) `shouldBe` (3,5,1,1) + + it "is normalized to sum ~1" $ + sum (A.toList (A.gaussianKernel @Float 5 5 0 0)) `shouldBeApprox` (1.0 :: Float) + + it "has only non-negative weights" $ + A.toList (A.gaussianKernel @Float 5 5 0 0) `shouldSatisfy` all (>= 0) + + describe "resize" $ do + it "upsamples to the requested dimensions" $ + A.getDims (A.resize gray 8 8 A.Nearest) `shouldBe` (8,8,1,1) + + it "downsamples to the requested dimensions" $ + A.getDims (A.resize gray 2 2 A.Bilinear) `shouldBe` (2,2,1,1) + + it "preserves a constant image under bilinear resize" $ + A.toList (A.resize gray 8 8 A.Bilinear) `shouldSatisfy` all (`approx` 1.0) + + describe "colorspace conversion" $ do + it "rgb2gray collapses the channel dimension" $ + A.getDims (A.rgb2gray rgb 0.3 0.59 0.11) `shouldBe` (4,4,1,1) + + it "rgb2gray of a constant image yields the weighted intensity" $ + A.toList (A.rgb2gray rgb 0.3 0.59 0.11) `shouldSatisfy` all (`approx` 1.0) + + it "gray2rgb expands to three channels" $ + A.getDims (A.gray2rgb gray 1 1 1) `shouldBe` (4,4,3,1) + + it "rgb2ycbcr / ycbcr2rgb preserve image dimensions" $ do + let ycbcr = A.rgb2ycbcr rgb A.Ycc601 + A.getDims ycbcr `shouldBe` (4,4,3,1) + A.getDims (A.ycbcr2rgb ycbcr A.Ycc601) `shouldBe` (4,4,3,1) + + describe "morphology" $ do + it "dilation with an all-ones mask leaves a constant image unchanged" $ do + let mask = A.constant @Float [3,3] 1.0 + A.toList (A.dilate gray mask) `shouldSatisfy` all (`approx` 1.0) + + it "erosion with an all-ones mask leaves a constant image unchanged" $ do + let mask = A.constant @Float [3,3] 1.0 + A.toList (A.erode gray mask) `shouldSatisfy` all (`approx` 1.0) + + describe "histogram" $ do + it "has one element per requested bin" $ + A.getElements (A.histogram gray 16 0 1) `shouldBe` 16 + + it "produces a u32 array" $ + A.getType (A.histogram gray 16 0 1) `shouldBe` A.U32 + + it "accumulates every pixel across all bins" $ + sum (map fromIntegral (A.toList (A.histogram gray 16 0 1))) + `shouldBe` (16 :: Int) -- 4×4 pixels + + describe "gradient" $ + it "of a constant image is zero in both directions" $ do + let (gx, gy) = A.gradient gray + A.toList gx `shouldSatisfy` all (`approx` 0.0) + A.toList gy `shouldSatisfy` all (`approx` 0.0) + + describe "summed area table (sat)" $ do + it "preserves the image dimensions" $ + A.getDims (A.sat gray) `shouldBe` (4,4,1,1) + + it "bottom-right cell holds the total sum" $ + -- column-major: last element is the integral over the whole image + last (A.toList (A.sat gray)) `shouldBeApprox` (16.0 :: Float) + + describe "moments" $ + it "M00 of a constant image equals its total intensity (area)" $ + A.momentsAll gray A.M00 `shouldBeApprox` (16.0 :: Double) + + describe "Image I/O" $ do + it "saveImage/loadImage round-trips a grayscale image" $ do + avail <- A.isImageIOAvailable + if not avail then pending else do + res <- try $ withTempPng $ \path -> do + A.saveImage gray path + img <- A.loadImage @Float path False + A.getDims img `shouldBe` (4,4,1,1) + A.toList img `shouldSatisfy` all (`approx` 1.0) + case res of + Left (AFException LoadLibError _ _) -> pending + Left e -> throwIO e + Right () -> return () + + it "saveImage/loadImage round-trips a colour image" $ do + avail <- A.isImageIOAvailable + if not avail then pending else do + res <- try $ withTempPng $ \path -> do + A.saveImage rgb path + img <- A.loadImage @Float path True + A.getDims img `shouldBe` (4,4,3,1) + A.toList img `shouldSatisfy` all (`approx` 1.0) + case res of + Left (AFException LoadLibError _ _) -> pending + Left e -> throwIO e + Right () -> return () + + it "saveImageNative/loadImageNative round-trips dims" $ do + avail <- A.isImageIOAvailable + if not avail then pending else do + res <- try $ withTempPng $ \path -> do + A.saveImageNative gray path + img <- A.loadImageNative @Float path + let (r, c, _, _) = A.getDims img + (r, c) `shouldBe` (4, 4) + case res of + Left (AFException LoadLibError _ _) -> pending + Left e -> throwIO e + Right () -> return () + + where + -- relative+absolute tolerance check, returning Bool for use with `all` + approx :: Float -> Float -> Bool + approx x e = abs (x - e) <= 1e-8 + 1e-5 * max (abs x) (abs e) + + withTempPng :: (FilePath -> IO a) -> IO a + withTempPng action = + bracket + (do tmp <- getTemporaryDirectory + (path, h) <- openTempFile tmp "af_test.png" + hClose h + pure path) + removeFile + action diff --git a/test/ArrayFire/IndexSpec.hs b/test/ArrayFire/IndexSpec.hs index d709317..e0ea264 100644 --- a/test/ArrayFire/IndexSpec.hs +++ b/test/ArrayFire/IndexSpec.hs @@ -1,21 +1,149 @@ -{-# LANGUAGE BangPatterns #-} -{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} module ArrayFire.IndexSpec where -import qualified ArrayFire as A -import Control.Exception -import Data.Complex -import Data.Int -import Data.Proxy -import Data.Word -import Foreign.C.Types +import qualified ArrayFire as A +import Data.Function ((&)) import Test.Hspec +import Test.Hspec.QuickCheck (prop) +import Test.QuickCheck (NonEmptyList (..), choose, forAll) spec :: Spec spec = - describe "Index spec" $ do - it "Should index into an array" $ do - let arr = A.vector @Int 10 [1..] - A.index arr [A.Seq 0 4 1] - `shouldBe` - A.vector @Int 5 [1..] + describe "Index" $ do + + describe "index" $ do + it "indexes a sub-range of a vector" $ do + A.index (A.vector @Int 10 [1..]) [A.Seq 0 4 1] + `shouldBe` A.vector @Int 5 [1..] + it "indexes every other element with step=2" $ do + A.index (A.vector @Int 6 [0,1,2,3,4,5]) [A.Seq 0 4 2] + `shouldBe` A.vector @Int 3 [0,2,4] + it "selects the full vector with afSpan" $ do + let arr = A.vector @Int 5 [1..] + A.index arr [A.afSpan] `shouldBe` arr + + describe "afSpan" $ do + it "equals Seq 1 1 0 (the ArrayFire span sentinel)" $ do + A.afSpan `shouldBe` A.Seq 1 1 0 + + describe "lookup" $ do + it "gathers elements by an index array" $ do + let arr = A.vector @Double 5 [10, 20, 30, 40, 50] + ixArr = A.vector @Int 3 [0, 2, 4] + A.lookup arr ixArr 0 + `shouldBe` A.vector @Double 3 [10, 30, 50] + it "allows repeated indices" $ do + let arr = A.vector @Int 5 [10, 20, 30, 40, 50] + ixArr = A.vector @Int 4 [0, 0, 4, 4] + A.lookup arr ixArr 0 + `shouldBe` A.vector @Int 4 [10, 10, 50, 50] + + describe "assignSeq" $ do + it "assigns into a middle slice of a vector" $ do + let arr = A.vector @Double 5 [1..] + src = A.vector @Double 3 [0, 0, 0] + A.assignSeq arr [A.Seq 1 3 1] src + `shouldBe` A.vector @Double 5 [1, 0, 0, 0, 5] + it "assigns a single element" $ do + let arr = A.vector @Double 5 [1..] + src = A.scalar @Double 99 + A.assignSeq arr [A.Seq 2 2 1] src + `shouldBe` A.vector @Double 5 [1, 2, 99, 4, 5] + it "overwrites the full vector via afSpan" $ do + let arr = A.vector @Double 5 [1..] + src = A.vector @Double 5 (repeat 0) + A.assignSeq arr [A.afSpan] src `shouldBe` src + + describe "indexGen" $ do + it "indexes a sub-range of a vector with seqIdx" $ do + let arr = A.vector @Double 5 [10, 20, 30, 40, 50] + A.indexGen arr [A.seqIdx (A.Seq 0 2 1) False] + `shouldBe` A.vector @Double 3 [10, 20, 30] + it "indexes a 2D sub-matrix with two seqIdx" $ do + let arr = A.matrix @Double (3,3) [[1,2,3],[4,5,6],[7,8,9]] + A.indexGen arr [ A.seqIdx (A.Seq 0 1 1) False + , A.seqIdx (A.Seq 0 1 1) False ] + `shouldBe` A.matrix @Double (2,2) [[1,2],[4,5]] + + describe "assignGen" $ do + it "assigns into a vector slice with seqIdx" $ do + let arr = A.vector @Double 5 [1..] + src = A.vector @Double 3 [0, 0, 0] + result = A.assignGen arr [A.seqIdx (A.Seq 1 3 1) False] src + A.indexGen result [A.seqIdx (A.Seq 1 3 1) False] `shouldBe` src + it "assigns into a 2D sub-matrix with two seqIdx" $ do + let arr = A.matrix @Double (3,3) [[1,2,3],[4,5,6],[7,8,9]] + src = A.matrix @Double (2,2) [[0,0],[0,0]] + result = A.assignGen arr [ A.seqIdx (A.Seq 0 1 1) False + , A.seqIdx (A.Seq 0 1 1) False ] src + A.indexGen result [ A.seqIdx (A.Seq 0 1 1) False + , A.seqIdx (A.Seq 0 1 1) False ] + `shouldBe` src + + describe "(!) operator" $ do + it "indexes a 1D sub-range with range" $ do + let arr = A.vector @Double 5 [10, 20, 30, 40, 50] + (arr A.! A.range 0 2) + `shouldBe` A.vector @Double 3 [10, 20, 30] + it "indexes a single element with at" $ do + let arr = A.vector @Double 5 [10, 20, 30, 40, 50] + (arr A.! A.at 2) + `shouldBe` A.scalar @Double 30 + it "indexes a 2D sub-matrix with a tuple" $ do + let arr = A.matrix @Double (3,3) [[1,2,3],[4,5,6],[7,8,9]] + (arr A.! (A.range 0 1, A.range 0 1)) + `shouldBe` A.matrix @Double (2,2) [[1,2],[4,5]] + + describe "(.~) operator" $ do + it "assigns into a 1D slice" $ do + let arr = A.vector @Double 5 [1..] + src = A.vector @Double 3 [0, 0, 0] + result = arr & A.range 1 3 A..~ src + (result A.! A.range 1 3) `shouldBe` src + it "assigns into a 2D sub-matrix" $ do + let arr = A.matrix @Double (3,3) [[1,2,3],[4,5,6],[7,8,9]] + src = A.matrix @Double (2,2) [[0,0],[0,0]] + result = arr & (A.range 0 1, A.range 0 1) A..~ src + (result A.! (A.range 0 1, A.range 0 1)) `shouldBe` src + + describe "rangeStep" $ do + it "selects every other element" $ do + let arr = A.vector @Double 6 [0,1,2,3,4,5] + (arr A.! A.rangeStep 0 4 2) + `shouldBe` A.vector @Double 3 [0,2,4] + + describe "indexing properties" $ do + -- afSpan selects all elements, recovering the original array exactly. + prop "index with afSpan is identity" $ \(NonEmpty xs) -> + let arr = A.vector @Double (length xs) xs + in A.index arr [A.afSpan] == arr + + -- Read-after-write: reading back the slice just written returns the source. + prop "index (assignSeq arr seqs src) seqs = src" $ + forAll (choose (1, 20)) $ \n -> + forAll (choose (0, n-1)) $ \lo -> + forAll (choose (lo, n-1)) $ \hi -> + \(xs :: [Double]) (ys :: [Double]) -> + let arr = A.vector @Double n (take n (xs ++ repeat 0)) + src = A.vector @Double (hi - lo + 1) (take (hi - lo + 1) (ys ++ repeat 0)) + seqs = [A.Seq (fromIntegral lo) (fromIntegral hi) 1] + in A.index (A.assignSeq arr seqs src) seqs == src + + -- lookup with identity permutation [0..n-1] returns the original array. + prop "lookup with identity permutation is identity" $ \(NonEmpty xs) -> + let n = length xs + arr = A.vector @Double n xs + ixArr = A.vector @Int n [0..n-1] + in A.lookup arr ixArr 0 == arr + + -- (.~) write-then-read consistency via the (!) operator. + prop "(.~) then (!) recovers the written slice" $ + forAll (choose (2, 20)) $ \n -> + forAll (choose (0, n-1)) $ \lo -> + forAll (choose (lo, n-1)) $ \hi -> + \(xs :: [Double]) (ys :: [Double]) -> + let arr = A.vector @Double n (take n (xs ++ repeat 0)) + src = A.vector @Double (hi - lo + 1) (take (hi - lo + 1) (ys ++ repeat 0)) + result = arr & A.rangeStep lo hi 1 A..~ src + in (result A.! A.rangeStep lo hi 1) == src diff --git a/test/ArrayFire/LAPACKSpec.hs b/test/ArrayFire/LAPACKSpec.hs index 5c225c7..9d302dd 100644 --- a/test/ArrayFire/LAPACKSpec.hs +++ b/test/ArrayFire/LAPACKSpec.hs @@ -1,45 +1,279 @@ -{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} module ArrayFire.LAPACKSpec where -import qualified ArrayFire as A +import qualified ArrayFire as A +import Data.Complex (realPart, imagPart) import Prelude import Test.Hspec -import Test.Hspec.ApproxExpect +import Test.Hspec.ApproxExpect +import Test.Hspec.QuickCheck (prop) +import Test.QuickCheck (Gen, choose, forAll, vectorOf) + +-- | A 3x3 matrix product with default (None) operands. +mm :: A.Array Double -> A.Array Double -> A.Array Double +mm a b = (a `A.matmul` b) A.None A.None + +-- | Transpose (real, no conjugation). +tr :: A.Array Double -> A.Array Double +tr a = A.transpose a False + +-- | Generate the entries of an @n@x@n@ matrix with modestly sized values so +-- the decompositions stay numerically well-behaved. +genMat :: Int -> Gen [Double] +genMat n = vectorOf (n * n) (choose (-5, 5)) + spec :: Spec spec = describe "LAPACK spec" $ do it "Should have LAPACK available" $ do A.isLAPACKAvailable `shouldBe` True + it "Should perform svd" $ do - let (s,v,d) = A.svd $ A.matrix @Double (4,2) [ [1,2,3,4], [5,6,7,8] ] - A.getDims s `shouldBe` (4,4,1,1) - A.getDims v `shouldBe` (2,1,1,1) - A.getDims d `shouldBe` (2,2,1,1) + let (u,sigma,vt) = A.svd $ A.matrix @Double (4,2) [ [1,2,3,4], [5,6,7,8] ] + A.getDims u `shouldBe` (4,4,1,1) + A.getDims sigma `shouldBe` (2,1,1,1) + A.getDims vt `shouldBe` (2,2,1,1) + it "Should perform svd in place" $ do - let (s,v,d) = A.svdInPlace $ A.matrix @Double (4,2) [ [1,2,3,4], [5,6,7,8] ] - A.getDims s `shouldBe` (4,4,1,1) - A.getDims v `shouldBe` (2,1,1,1) - A.getDims d `shouldBe` (2,2,1,1) + let (u,sigma,vt) = A.svdInPlace $ A.matrix @Double (4,2) [ [1,2,3,4], [5,6,7,8] ] + A.getDims u `shouldBe` (4,4,1,1) + A.getDims sigma `shouldBe` (2,1,1,1) + A.getDims vt `shouldBe` (2,2,1,1) + it "Should perform lu" $ do - let (s,v,d) = A.lu $ A.matrix @Double (2,2) [[3,1],[4,2]] - A.getDims s `shouldBe` (2,2,1,1) - A.getDims v `shouldBe` (2,2,1,1) - A.getDims d `shouldBe` (2,1,1,1) + let (l,u,piv) = A.lu $ A.matrix @Double (2,2) [[3,1],[4,2]] + A.getDims l `shouldBe` (2,2,1,1) + A.getDims u `shouldBe` (2,2,1,1) + A.getDims piv `shouldBe` (2,1,1,1) + it "Should perform qr" $ do - let (s,v,d) = A.lu $ A.matrix @Double (3,3) [[12,6,4],[-51,167,24],[4,-68,-41]] - A.getDims s `shouldBe` (3,3,1,1) - A.getDims v `shouldBe` (3,3,1,1) - A.getDims d `shouldBe` (3,1,1,1) - it "Should get determinant of Double" $ do - let eles = [[3 A.:+ 1, 8 A.:+ 1], [4 A.:+ 1, 6 A.:+ 1]] - (x,y) = A.det (A.matrix @(A.Complex Double) (2,2) eles) - x `shouldBeApprox` (-14) - let (x,y) = A.det $ A.matrix @Double (2,2) [[3,8],[4,6]] - x `shouldBeApprox` (-14) --- it "Should calculate inverse" $ do --- let x = flip A.inverse A.None $ A.matrix @Double (2,2) [[4.0,7.0],[2.0,6.0]] --- x `shouldBe` A.matrix (2,2) [[0.6,-0.7],[-0.2,0.4]] --- it "Should calculate psuedo inverse" $ do --- let x = A.pinverse (A.matrix @Double (2,2) [[4,7],[2,6]]) 1.0 A.None --- x `shouldBe` A.matrix @Double (2,2) [[0.6,-0.2],[-0.7,0.4]] + let (q,r,tau) = A.qr $ A.matrix @Double (3,3) [[12,6,4],[-51,167,24],[4,-68,-41]] + A.getDims q `shouldBe` (3,3,1,1) + A.getDims r `shouldBe` (3,3,1,1) + A.getDims tau `shouldBe` (3,1,1,1) + + it "Should get determinant of a real matrix" $ do + A.det (A.matrix @Double (2,2) [[3,8],[4,6]]) + `shouldBeApprox` (-14) + + it "Should get determinant of a complex matrix" $ do + -- M = | 3+i 4+i | (column-major: col0=[3+i,8+i], col1=[4+i,6+i]) + -- | 8+i 6+i | + -- det = (3+i)(6+i) - (4+i)(8+i) = -14 - 3i + let d = A.det $ A.matrix @(A.Complex Double) (2,2) + [[3 A.:+ 1, 8 A.:+ 1], [4 A.:+ 1, 6 A.:+ 1]] + realPart d `shouldBeApprox` (-14) + imagPart d `shouldBeApprox` (-3) + + it "Should calculate inverse" $ do + -- M = | 4 2 | (column-major: col0=[4,7], col1=[2,6]) + -- | 7 6 | + -- M^-1 = (1/10) * | 6 -2 | = col0=[0.6,-0.7], col1=[-0.2,0.4] + -- | -7 4 | + let result = A.toList $ A.inverse (A.matrix @Double (2,2) [[4.0,7.0],[2.0,6.0]]) A.None + expected = [0.6, -0.7, -0.2, 0.4] + mapM_ (uncurry shouldBeApprox) (zip result expected) + + it "Should find the rank of a matrix" $ do + A.rank (A.matrix @Double (3,3) [[1,2,3],[4,5,6],[7,8,9]]) 1e-5 `shouldBe` 2 + A.rank (A.identity @Double [3,3]) 1e-5 `shouldBe` 3 + + it "Should compute the norm of a vector" $ do + -- || [3, 4] ||_2 = 5 + A.norm (A.vector @Double 2 [3,4]) A.NormVector2 1 1 `shouldBeApprox` 5 + -- || [3, 4] ||_1 = 7 + A.norm (A.vector @Double 2 [3,4]) A.NormVectorOne 1 1 `shouldBeApprox` 7 + -- || [3, 4] ||_inf = 4 + A.norm (A.vector @Double 2 [3,4]) A.NormVectorInf 1 1 `shouldBeApprox` 4 + + it "Should perform cholesky decomposition" $ do + -- A = | 4 2 | (column-major: [4,2,2,3]) + -- | 2 3 | + -- L = | 2 0 | where L*L^T = A + -- | 1 √2 | + let a = A.mkArray @Double [2,2] [4,2,2,3] + (status, l) = A.cholesky a False + status `shouldBe` 0 + let ls = A.toList @Double l + mapM_ (uncurry shouldBeApprox) (zip ls [2, 1, 0, sqrt 2]) + + it "choleskyInplace returns 0 for a symmetric positive definite matrix" $ do + let a = A.mkArray @Double [2,2] [4,2,2,3] + A.choleskyInplace a False `shouldBe` 0 + + it "Should solve Ax=b using solveLU" $ do + -- A = | 2 1 | b = | 5 | => x = | 1 | + -- | 1 3 | | 10| | 3 | + -- Column-major A: [2,1,1,3], b: [5,10] + let a = A.mkArray @Double [2,2] [2,1,1,3] + b = A.vector @Double 2 [5,10] + piv = A.luInPlace a True + x = A.solveLU a piv b A.None + mapM_ (uncurry shouldBeApprox) (zip (A.toList @Double x) [1,3]) + + describe "decomposition reconstruction properties" $ do + -- QR factors multiply back to the original matrix. + prop "QR: Q*R = A" $ forAll (genMat 3) $ \xs -> + let a = A.mkArray @Double [3,3] xs + (q,r,_) = A.qr a + in closeList (A.toList (mm q r)) (A.toList a) + + -- The Q factor is orthogonal: Q^T Q = I. + prop "QR: Q^T Q = I" $ forAll (genMat 3) $ \xs -> + let a = A.mkArray @Double [3,3] xs + (q,_,_) = A.qr a + in closeList (A.toList (mm (tr q) q)) (A.toList (A.identity @Double [3,3])) + + -- SVD factors multiply back to the original: U * diag(S) * V^T = A. + prop "SVD: U diag(S) V^T = A" $ forAll (genMat 3) $ \xs -> + let a = A.mkArray @Double [3,3] xs + (u,s,vt) = A.svd a + sigma = A.diagCreate s 0 + in closeList (A.toList (mm (mm u sigma) vt)) (A.toList a) + + -- Cholesky factor reproduces a symmetric positive-definite matrix: + -- A = B^T B + 3I is SPD, and L*L^T = A. + prop "Cholesky: L*L^T = A (SPD)" $ forAll (genMat 3) $ \xs -> + let b = A.mkArray @Double [3,3] xs + a = mm (tr b) b + A.mkArray @Double [3,3] [3,0,0, 0,3,0, 0,0,3] + (status, l) = A.cholesky a False + in status == 0 && closeList (A.toList (mm l (tr l))) (A.toList a) + + describe "more decomposition properties" $ do + -- Singular values are all non-negative. + prop "SVD: singular values are non-negative" $ forAll (genMat 3) $ \xs -> + let a = A.mkArray @Double [3,3] xs + (_,s,_) = A.svd a + in all (>= -1e-12) (A.toList s) + + -- LU reconstruction: L * U = P * A where P is the pivot permutation. + -- ArrayFire's lu returns (L, U, piv) where piv is a pivot index vector. + -- We verify the simpler invariant that L is unit lower-triangular (diag=1). + prop "LU: L has unit diagonal" $ forAll (genMat 3) $ \xs -> + let a = A.mkArray @Double [3,3] xs + (l,_,_) = A.lu a + diag = [A.toList l !! (i * 3 + i) | i <- [0..2]] + in all (\d -> abs (d - 1.0) < 1e-9) diag + + -- det(A * B) ≈ det(A) * det(B) (multiplicativity of determinant) + prop "det(A*B) = det(A)*det(B)" $ forAll (genMat 3) $ \xs -> + forAll (genMat 3) $ \ys -> + let a = A.mkArray @Double [3,3] xs + b = A.mkArray @Double [3,3] ys + da = A.det a + db = A.det b + dab = A.det (mm a b) + expected = da * db + in abs (dab - expected) < 1e-6 + 1e-4 * abs expected + + -- inverse(inverse(A)) ≈ A for a well-conditioned matrix (B^T B + 3I is SPD). + prop "inverse is its own inverse (SPD input)" $ forAll (genMat 3) $ \xs -> + let b = A.mkArray @Double [3,3] xs + a = mm (tr b) b + A.mkArray @Double [3,3] [3,0,0, 0,3,0, 0,0,3] + ainv = A.inverse a A.None + ainv2 = A.inverse ainv A.None + in closeList (A.toList ainv2) (A.toList a) + + describe "pinverse" $ do + it "pinverse of a full-rank square matrix matches inverse" $ do + -- For an invertible matrix, pinverse should equal inverse. + let a = A.matrix @Double (2,2) [[4.0,7.0],[2.0,6.0]] + pinv = A.toList $ A.pinverse a 1e-6 A.None + inv = A.toList $ A.inverse a A.None + mapM_ (uncurry shouldBeApprox) (zip pinv inv) + + it "pinverse of a tall matrix satisfies pinv(A) * A ≈ I" $ do + -- For a full-column-rank matrix A (m x n, m >= n), pinv(A) * A = I_n. + let a = A.matrix @Double (3,2) [[1,2,3],[4,5,6]] + pinvA = A.pinverse a 1e-9 A.None + prod = mm pinvA a -- (2x3) * (3x2) = 2x2 identity + eye = A.identity @Double [2,2] + closeList (A.toList prod) (A.toList eye) `shouldBe` True + + prop "pinverse: A * pinv(A) * A ≈ A (full-rank square)" $ + forAll (genMat 3) $ \xs -> + let b = A.mkArray @Double [3,3] xs + a = mm (tr b) b + A.mkArray @Double [3,3] [3,0,0, 0,3,0, 0,0,3] + pinvA = A.pinverse a 1e-9 A.None + in closeList (A.toList (mm (mm a pinvA) a)) (A.toList a) + + prop "pinverse: pinv(A) * A * pinv(A) ≈ pinv(A) (full-rank square)" $ + forAll (genMat 3) $ \xs -> + let b = A.mkArray @Double [3,3] xs + a = mm (tr b) b + A.mkArray @Double [3,3] [3,0,0, 0,3,0, 0,0,3] + pinvA = A.pinverse a 1e-9 A.None + in closeList (A.toList (mm (mm pinvA a) pinvA)) (A.toList pinvA) + + describe "eigSH" $ do + -- Works on all backends: CUDA uses cuSOLVER, others use SVD fallback. + + it "returns correct eigenvalues for 2x2 symmetric matrix" $ do + -- A = [[3,1],[1,3]], eigenvalues 2 and 4 (ascending) + let a = A.matrix @Double (2,2) [[3,1],[1,3]] + (evals, _) = A.eigSH a + evList = A.toList evals + length evList `shouldBe` 2 + evList !! 0 `shouldBeApprox` 2.0 + evList !! 1 `shouldBeApprox` 4.0 + + it "returns orthonormal eigenvectors for 2x2 matrix" $ do + let a = A.matrix @Double (2,2) [[3,1],[1,3]] + (_, evecs) = A.eigSH a + vtv = A.toList $ mm (tr evecs) evecs + eye2 = A.toList (A.identity @Double [2,2]) + mapM_ (uncurry shouldBeApprox) (zip vtv eye2) + + it "reconstructs the original 2x2 matrix: V * diag(λ) * V^T = A" $ do + let a = A.matrix @Double (2,2) [[3,1],[1,3]] + (evals, evecs) = A.eigSH a + recon = mm (mm evecs (A.diagCreate evals 0)) (tr evecs) + mapM_ (uncurry shouldBeApprox) (zip (A.toList recon) (A.toList a)) + + it "returns eigenvalues in ascending order for 3x3 matrix" $ do + -- A = [[2,1,0],[1,2,1],[0,1,2]], eigenvalues 2-sqrt(2), 2, 2+sqrt(2) + let a = A.matrix @Double (3,3) [[2,1,0],[1,2,1],[0,1,2]] + (evals, _) = A.eigSH a + evList = A.toList evals + evList !! 0 `shouldBeApprox` (2 - sqrt 2) + evList !! 1 `shouldBeApprox` 2.0 + evList !! 2 `shouldBeApprox` (2 + sqrt 2) + + it "handles matrix with negative eigenvalues" $ do + -- A = [[0,1],[1,0]], eigenvalues -1 and +1 + let a = A.matrix @Double (2,2) [[0,1],[1,0]] + (evals, _) = A.eigSH a + evList = A.toList evals + evList !! 0 `shouldBeApprox` (-1.0) + evList !! 1 `shouldBeApprox` 1.0 + + prop "eigSH: V * diag(λ) * V^T = A (SPD input)" $ + forAll (genMat 3) $ \xs -> + let b = A.mkArray @Double [3,3] xs + a = mm (tr b) b + A.mkArray @Double [3,3] [3,0,0, 0,3,0, 0,0,3] + (evals, evecs) = A.eigSH a + recon = mm (mm evecs (A.diagCreate evals 0)) (tr evecs) + in closeList (A.toList recon) (A.toList a) + + prop "eigSH: V^T * V = I (eigenvectors are orthonormal)" $ + forAll (genMat 3) $ \xs -> + let b = A.mkArray @Double [3,3] xs + a = mm (tr b) b + A.mkArray @Double [3,3] [3,0,0, 0,3,0, 0,0,3] + (_, evecs) = A.eigSH a + in closeList (A.toList (mm (tr evecs) evecs)) + (A.toList (A.identity @Double [3,3])) + + describe "qrInPlace" $ do + it "qrInPlace on a 3x3 matrix returns a tau vector of length 3" $ do + let a = A.matrix @Double (3,3) [[12,6,4],[-51,167,24],[4,-68,-41]] + tau = A.qrInPlace a + A.getDims tau `shouldBe` (3,1,1,1) + it "qrInPlace on a 4x3 matrix returns a tau vector with min(rows,cols) elements" $ do + let a = A.mkArray @Double [4,3] [1..12] + tau = A.qrInPlace a + A.getDims tau `shouldBe` (3,1,1,1) + it "qrInPlace on a square matrix produces a non-empty tau array" $ do + let a = A.mkArray @Double [2,2] [1,2,3,4] + tau = A.qrInPlace a + A.getElements tau `shouldBe` 2 diff --git a/test/ArrayFire/NumericalSpec.hs b/test/ArrayFire/NumericalSpec.hs new file mode 100644 index 0000000..cbe63c0 --- /dev/null +++ b/test/ArrayFire/NumericalSpec.hs @@ -0,0 +1,127 @@ +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +-- | Numerical algorithm tests that exercise broad API surface area. +-- Each test has a known exact answer derived from mathematics, so failures +-- indicate either a bug in the library or a precision regression. +module ArrayFire.NumericalSpec where + +import qualified ArrayFire as A +import Data.Function ((&)) +import Test.Hspec +import Test.Hspec.QuickCheck (prop) +import Test.QuickCheck (NonEmptyList (..)) + +tol :: Double +tol = 1e-4 + +shouldBeApprox :: Double -> Double -> Expectation +shouldBeApprox x y = abs (x - y) < tol `shouldBe` True + +spec :: Spec +spec = describe "Numerical algorithms" $ do + + -- ∫₀^π sin(x) dx = 2 (midpoint rectangle rule) + -- Exercises: arange, sin, sumAll, scalar, *, + + describe "Rectangle-rule integration" $ do + it "approximates integral of sin over [0,pi] = 2" $ do + let n = 10000 :: Int + h = pi / fromIntegral n + is = A.arange @Double [n] (-1) -- [0,1,...,n-1] + xs = (is + A.scalar 0.5) * A.scalar h -- midpoints + result = h * A.sumAll (sin xs) + result `shouldBeApprox` 2.0 + + -- Power iteration on A = [[2,1],[1,2]] + -- Exact dominant eigenvalue = 3, eigenvector = [1,1]/√2 + -- Exercises: matrix, matmul, sumAll, *, /, scalar, sqrt, Haskell iterate + describe "Power iteration" $ do + it "converges to dominant eigenvalue 3 of [[2,1],[1,2]]" $ do + let a = A.matrix @Double (2,2) [[2,1],[1,2]] + v0 = A.matrix @Double (2,1) [[1,1]] + norm2 v = sqrt @Double (A.sumAll (v * v)) + norm v = v / A.scalar (norm2 v) + step v = norm (A.matmul a v A.None A.None) + vFinal = iterate step (norm v0) !! 30 + av = A.matmul a vFinal A.None A.None + -- Rayleigh quotient: v^T A v + lambda = A.sumAll (vFinal * av) + lambda `shouldBeApprox` 3.0 + + -- Geometric series: Σ(k=0..19) 0.5^k = (1 - 0.5^20)/(1 - 0.5) + -- Exercises: arange, (**), sumAll, scalar + describe "Geometric series" $ do + it "sum of 0.5^k for k=0..19 matches closed form" $ do + let n = 20 :: Int + ks = A.arange @Double [n] (-1) + terms = A.scalar 0.5 ** ks + result = A.sumAll terms + expected = (1.0 - 0.5 ^ n) / (1.0 - 0.5) + result `shouldBeApprox` expected + + -- Centered-difference moving average on u = [1..10]: + -- avg_i = (u[i-1] + u[i+1]) / 2 for i = 1..8 + -- For an arithmetic sequence, this equals u[i] exactly. + -- Exercises: vector, (!), range, +, /, scalar + describe "Slice-based centered differences" $ do + it "moving average of arithmetic sequence equals interior values" $ do + let u = A.vector @Double 10 [1..10] + avg = (u A.! A.range 0 7 + u A.! A.range 2 9) / A.scalar 2.0 + avg `shouldBe` u A.! A.range 1 8 + + -- Slice assignment: overwrite interior of a zero vector. + -- Exercises: vector, &, (.~), !, range, toList + describe "Slice assignment" $ do + it "(.~) writes src into interior slice, leaves boundaries unchanged" $ do + let u = A.vector @Double 6 (repeat 0.0) + src = A.vector @Double 4 [1,2,3,4] + result = u & A.range 1 4 A..~ src + A.toList result `shouldBe` [0,1,2,3,4,0] + + -- Sample statistics of [1..100]. + -- mean([1..100]) = 50.5 (exact by Gauss's formula) + -- sum = n * mean must hold exactly. + -- Exercises: vector, meanAll, sumAll + describe "Statistical identities" $ do + it "mean of [1..100] = 50.5" $ do + A.meanAll (A.vector @Double 100 [1..100]) `shouldBeApprox` 50.5 + it "sumAll = n * meanAll" $ do + let arr = A.vector @Double 100 [1..100] + m = A.meanAll arr + s = A.sumAll arr + s `shouldBeApprox` (100 * m) + it "variance of a constant array is 0" $ do + A.varAll (A.vector @Double 50 (repeat 7.0)) A.Population `shouldBeApprox` 0.0 + + -- Sum of first n squares: Σ(k=1..n) k² = n(n+1)(2n+1)/6 + -- Exercises: iota, *, +, scalar, sumAll + describe "Sum of squares" $ do + it "Sigma k^2 for k=1..100 matches closed form n(n+1)(2n+1)/6" $ do + let n = 100 :: Int + ks = A.iota @Double [n] [] + A.scalar 1.0 -- [1,2,...,n] + result = A.sumAll (ks * ks) + expected = fromIntegral (n * (n+1) * (2*n+1)) / 6.0 + result `shouldBeApprox` expected + + -- Parseval's theorem: ||x||² = (1/N)||X||² where X = FFT(x) + -- Uses a complex Dirac delta: |x|² = 1, FFT is a flat spectrum |X[k]|² = 1 each. + -- Exercises: mkArray, fft, conjg, real, sumAll, * + describe "Parseval's theorem" $ do + it "time-domain and frequency-domain energies agree" $ do + let n = 64 :: Int + -- Dirac delta: all energy in first sample + xs = A.mkArray @(A.Complex Double) [n] (1 : repeat 0) + -- time-domain energy: Σ |x[k]|² = 1 + tEnergy = A.sumAll (A.real (xs * A.conjg xs) :: A.Array Double) + -- frequency-domain energy: (1/N) Σ |X[k]|² = (1/N)*N = 1 + xf = A.fft xs 1.0 n + fEnergy = (1.0 / fromIntegral n) * (A.sumAll (A.real (xf * A.conjg xf) :: A.Array Double)) + tEnergy `shouldBeApprox` 1.0 + tEnergy `shouldBeApprox` fEnergy + + describe "sumAll = n * meanAll (property)" $ do + prop "sumAll = n * meanAll for any non-empty list of Double" $ \(NonEmpty xs) -> + let n = length xs + arr = A.vector @Double n xs + s = A.sumAll arr + m = A.meanAll arr + in abs (s - fromIntegral n * m) < 1e-9 + 1e-6 * abs s diff --git a/test/ArrayFire/RandomSpec.hs b/test/ArrayFire/RandomSpec.hs index 926a9cf..eb3bf48 100644 --- a/test/ArrayFire/RandomSpec.hs +++ b/test/ArrayFire/RandomSpec.hs @@ -2,13 +2,13 @@ module ArrayFire.RandomSpec where import ArrayFire -import Control.Monad import Test.Hspec spec :: Spec -spec = - describe "Random engine spec" $ do +spec = describe "Random spec" $ do + + describe "random engine" $ do it "Should create random engine" $ do (`shouldBe` Philox) =<< getRandomEngineType @@ -27,4 +27,124 @@ spec = setSeed 100 (`shouldBe` 100) =<< getSeed + -- Reproducibility is the contract that makes randomness usable in tests and + -- science: a fixed seed must yield a fixed stream. + describe "seed reproducibility" $ do + + it "global setSeed makes randu reproducible" $ do + setSeed 1234 + a1 <- toList <$> randu @Float [256] + setSeed 1234 + a2 <- toList <$> randu @Float [256] + a2 `shouldBe` a1 + + it "global setSeed makes randn reproducible" $ do + setSeed 9876 + a1 <- toList <$> randn @Double [256] + setSeed 9876 + a2 <- toList <$> randn @Double [256] + a2 `shouldBe` a1 + + it "two engines with the same seed + type draw the same stream" $ do + e1 <- createRandomEngine 42 Philox + e2 <- createRandomEngine 42 Philox + a1 <- toList <$> randomUniform @Float [256] e1 + a2 <- toList <$> randomUniform @Float [256] e2 + a2 `shouldBe` a1 + + it "engines with different seeds draw different streams" $ do + e1 <- createRandomEngine 1 Philox + e2 <- createRandomEngine 2 Philox + a1 <- toList <$> randomUniform @Float [256] e1 + a2 <- toList <$> randomUniform @Float [256] e2 + a2 `shouldNotBe` a1 + + describe "distribution shape & range" $ do + + it "randu produces the requested dimensions" $ do + a <- randu @Float [3,4] + getDims a `shouldBe` (3,4,1,1) + + it "randn produces the requested dimensions" $ do + a <- randn @Double [5,2,3] + getDims a `shouldBe` (5,2,3,1) + + it "uniform draws lie in [0,1)" $ do + setSeed 7 + xs <- toList <$> randu @Float [4096] + xs `shouldSatisfy` all (\x -> x >= 0 && x < 1) + + describe "randomNormal" $ do + it "produces the requested dimensions" $ do + e <- getDefaultRandomEngine + a <- randomNormal @Double [3,4] e + getDims a `shouldBe` (3,4,1,1) + it "produces the right number of elements" $ do + e <- getDefaultRandomEngine + a <- randomNormal @Float [5,2] e + getElements a `shouldBe` 10 + + it "two engines with the same seed produce the same normal stream" $ do + e1 <- createRandomEngine 42 Philox + e2 <- createRandomEngine 42 Philox + a1 <- toList <$> randomNormal @Double [256] e1 + a2 <- toList <$> randomNormal @Double [256] e2 + a2 `shouldBe` a1 + it "engines with different seeds produce different normal streams" $ do + e1 <- createRandomEngine 1 Philox + e2 <- createRandomEngine 2 Philox + a1 <- toList <$> randomNormal @Double [256] e1 + a2 <- toList <$> randomNormal @Double [256] e2 + a2 `shouldNotBe` a1 + + describe "randomEngineSetSeed / randomEngineGetSeed" $ do + it "getSeed returns the seed supplied to createRandomEngine" $ do + e <- createRandomEngine 9999 Philox + s <- randomEngineGetSeed e + s `shouldBe` 9999 + it "setAndGet round-trip: seed is updated after randomEngineSetSeed" $ do + e <- createRandomEngine 1 Philox + randomEngineSetSeed e 12345 + s <- randomEngineGetSeed e + s `shouldBe` 12345 + it "different seeds produce different streams after randomEngineSetSeed" $ do + e <- createRandomEngine 1 Philox + randomEngineSetSeed e 100 + a1 <- toList <$> randomUniform @Float [64] e + randomEngineSetSeed e 200 + a2 <- toList <$> randomUniform @Float [64] e + a2 `shouldNotBe` a1 + it "same seed after reset produces the same stream" $ do + e <- createRandomEngine 1 Philox + randomEngineSetSeed e 777 + a1 <- toList <$> randomUniform @Float [64] e + randomEngineSetSeed e 777 + a2 <- toList <$> randomUniform @Float [64] e + a2 `shouldBe` a1 + + describe "retainRandomEngine" $ do + it "retained engine has the same type as the original" $ do + e <- createRandomEngine 42 Philox + e' <- retainRandomEngine e + getRandomEngineType e' `shouldReturn` Philox + it "retained handle shares state with original (both advance the same stream)" $ do + e <- createRandomEngine 42 Philox + e' <- retainRandomEngine e + a1 <- toList <$> randomUniform @Double [4] e + a2 <- toList <$> randomUniform @Double [4] e' + a2 `shouldNotBe` a1 + describe "setDefaultRandomEngineType" $ do + it "default engine type reflects the type that was set" $ do + setDefaultRandomEngineType ThreeFry + e <- getDefaultRandomEngine + getRandomEngineType e `shouldReturn` ThreeFry + it "switching type changes what getDefaultRandomEngine reports" $ do + setDefaultRandomEngineType Philox + e1 <- getDefaultRandomEngine + t1 <- getRandomEngineType e1 + setDefaultRandomEngineType Mersenne + e2 <- getDefaultRandomEngine + t2 <- getRandomEngineType e2 + t1 `shouldBe` Philox + t2 `shouldBe` Mersenne diff --git a/test/ArrayFire/SignalSpec.hs b/test/ArrayFire/SignalSpec.hs index 06b890e..16340b9 100644 --- a/test/ArrayFire/SignalSpec.hs +++ b/test/ArrayFire/SignalSpec.hs @@ -1,20 +1,333 @@ -{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} module ArrayFire.SignalSpec where import qualified ArrayFire as A -import Data.Int -import Data.Word import Data.Complex -import Data.Proxy -import Foreign.C.Types import Test.Hspec +import Test.Hspec.QuickCheck (prop) +import Test.QuickCheck (NonEmptyList (..), choose, forAll, vectorOf) + +-- | Check all elements of two Double arrays are within tolerance. +shouldBeApproxD + :: A.Array Double + -> A.Array Double + -> Expectation +shouldBeApproxD actual expected = + zipWith (\a e -> abs (a - e)) + (A.toList @Double actual) + (A.toList @Double expected) + `shouldSatisfy` all (< 1e-6) + +-- | Check all elements of two Complex Double arrays are within tolerance. +shouldBeApproxC + :: A.Array (Complex Double) + -> A.Array (Complex Double) + -> Expectation +shouldBeApproxC actual expected = + zipWith (\a e -> magnitude (a - e)) + (A.toList @(Complex Double) actual) + (A.toList @(Complex Double) expected) + `shouldSatisfy` all (< 1e-10) spec :: Spec spec = - describe "Signal spec" $ do - it "Should do FFT in place" $ do - A.fftInPlace (A.matrix @(Complex Double) (1,1) [[1 :+ 1]]) 10.2 - `shouldReturn` () - it "Should do FFT" $ do - A.fft (A.matrix @(Complex Float) (1,1) [[1 :+ 1]]) 1 1 - `shouldBe` A.matrix @(Complex Float) (1,1) [[1 :+ 1]] + describe "Signal" $ do + + describe "fft" $ do + it "fftInPlace runs without error" $ do + A.fftInPlace (A.scalar @(Complex Double) (1 :+ 0)) 1.0 + `shouldReturn` () + + it "transform of a Dirac delta is a flat spectrum" $ do + A.fft (A.mkArray @(Complex Double) [4] [1,0,0,0]) 1.0 4 + `shouldBeApproxC` + A.mkArray @(Complex Double) [4] [1,1,1,1] + + it "transform of all-ones concentrates all energy at DC" $ do + A.fft (A.mkArray @(Complex Double) [4] [1,1,1,1]) 1.0 4 + `shouldBeApproxC` + A.mkArray @(Complex Double) [4] [4,0,0,0] + + it "normalization factor scales the output" $ do + A.fft (A.mkArray @(Complex Double) [4] [1,0,0,0]) 2.0 4 + `shouldBeApproxC` + A.mkArray @(Complex Double) [4] [2,2,2,2] + + it "ifft . fft is the identity" $ do + let n = 8 + input = A.mkArray @(Complex Double) [n] (map (:+ 0) [1..8]) + A.ifft (A.fft input 1.0 n) (1.0 / fromIntegral n) n + `shouldBeApproxC` input + + it "fft output_size pads with zeros when larger than input" $ do + -- 4-point FFT of a 2-point signal padded to 4: input [1,1,0,0] + A.fft (A.mkArray @(Complex Double) [2] [1,1]) 1.0 4 + `shouldBeApproxC` + A.fft (A.mkArray @(Complex Double) [4] [1,1,0,0]) 1.0 4 + + describe "fft2" $ do + it "2D transform of a Dirac delta is a flat spectrum" $ do + A.fft2 (A.mkArray @(Complex Double) [4,4] (1 : replicate 15 0)) 1.0 4 4 + `shouldBeApproxC` + A.mkArray @(Complex Double) [4,4] (replicate 16 1) + + it "ifft2 . fft2 is the identity" $ do + let input = A.mkArray @(Complex Double) [4,4] (map (:+ 0) [1..16]) + A.ifft2 (A.fft2 input 1.0 4 4) (1.0 / 16) 4 4 + `shouldBeApproxC` input + + it "2D transform of all-ones concentrates all energy at DC" $ do + A.fft2 (A.mkArray @(Complex Double) [4,4] (replicate 16 1)) 1.0 4 4 + `shouldBeApproxC` + A.mkArray @(Complex Double) [4,4] (16 : replicate 15 0) + + describe "fft2_inplace" $ do + it "runs without error" $ do + A.fft2_inplace (A.mkArray @(Complex Double) [4,4] (map (:+ 0) [1..16])) 1.0 + `shouldReturn` () + + describe "fft3" $ do + it "3D transform of a Dirac delta is a flat spectrum" $ do + A.fft3 (A.mkArray @(Complex Double) [4,4,4] (1 : replicate 63 0)) 1.0 4 4 4 + `shouldBeApproxC` + A.mkArray @(Complex Double) [4,4,4] (replicate 64 1) + + it "ifft3 . fft3 is the identity" $ do + let input = A.mkArray @(Complex Double) [4,4,4] (map (:+ 0) [1..64]) + A.ifft3 (A.fft3 input 1.0 4 4 4) (1.0 / 64) 4 4 4 + `shouldBeApproxC` input + + describe "fft3_inplace" $ do + it "runs without error" $ do + A.fft3_inplace (A.mkArray @(Complex Double) [4,4,4] (map (:+ 0) [1..64])) 1.0 + `shouldReturn` () + + describe "ifft_inplace" $ do + it "runs without error" $ do + A.ifft_inplace (A.mkArray @(Complex Double) [4] (map (:+ 0) [1..4])) 1.0 + `shouldReturn` () + + describe "ifft2_inplace" $ do + it "runs without error" $ do + A.ifft2_inplace (A.mkArray @(Complex Double) [4,4] (map (:+ 0) [1..16])) 1.0 + `shouldReturn` () + + describe "ifft3_inplace" $ do + it "runs without error" $ do + A.ifft3_inplace (A.mkArray @(Complex Double) [4,4,4] (map (:+ 0) [1..64])) 1.0 + `shouldReturn` () + + describe "fftr2c / fftc2r" $ do + it "fftr2c output has (n/2+1) complex elements" $ do + let n = 8 + out = A.fftr2c (A.mkArray @Double [n] [1..8]) 1.0 n + A.getElements out `shouldBe` (n `div` 2 + 1) + + it "fftc2r recovers even-length real signal" $ do + let n = 8 + inp = A.mkArray @Double [n] [1..8] + spec' = A.fftr2c inp 1.0 n + -- norm = 1/n so that r2c * c2r = identity + out = A.fftc2r spec' (1.0 / fromIntegral n) False + out `shouldBeApproxD` inp + + it "fft2r2c output first dim is (n/2+1)" $ do + let n = 8 + out = A.fft2r2c (A.mkArray @Double [n,n] (replicate (n*n) 1.0)) 1.0 n n + (d0, _, _, _) = A.getDims out + d0 `shouldBe` (n `div` 2 + 1) + + it "fft3r2c runs without error" $ do + let n = 4 + out = A.fft3r2c (A.mkArray @Double [n,n,n] (replicate (n*n*n) 1.0)) 1.0 n n n + A.getElements out `shouldSatisfy` (> 0) + + describe "approx1" $ do + it "matches docstring example with Cubic interpolation" $ do + let input = A.vector @Float 3 [10,20,30] + positions = A.vector @Float 5 [0.0, 0.5, 1.0, 1.5, 2.0] + result = A.approx1 input positions A.Cubic 0.0 + zipWith (\a e -> abs (a - e)) + (A.toList @Float result) + (A.toList @Float (A.mkArray @Float [5] [10.0, 13.75, 20.0, 26.25, 30.0])) + `shouldSatisfy` all (< 1e-4) + + it "Nearest interpolation returns nearest sample value" $ do + let input = A.vector @Float 3 [10,20,30] + positions = A.vector @Float 3 [0.0, 1.0, 2.0] + zipWith (\a e -> abs (a - e)) + (A.toList @Float (A.approx1 input positions A.Nearest 0.0)) + (A.toList @Float (A.mkArray @Float [3] [10.0, 20.0, 30.0])) + `shouldSatisfy` all (< 1e-4) + + it "out-of-bounds positions use the fill value" $ do + let input = A.vector @Double 3 [10,20,30] + positions = A.vector @Double 1 [-1.0] + A.approx1 input positions A.Linear 0.0 + `shouldBeApproxD` A.mkArray @Double [1] [0.0] + + describe "approx2" $ do + it "matches docstring example with Cubic interpolation" $ do + let input = A.matrix @Float (3,3) [[1,1,1],[2,2,2],[3,3,3]] + pos1 = A.matrix @Float (2,2) [[0.5,1.5],[0.5,1.5]] + pos2 = A.matrix @Float (2,2) [[0.5,0.5],[1.5,1.5]] + result = A.approx2 input pos1 pos2 A.Cubic 0.0 + zipWith (\a e -> abs (a - e)) + (A.toList @Float result) + (A.toList @Float (A.mkArray @Float [2,2] [1.375, 2.625, 1.375, 2.625])) + `shouldSatisfy` all (< 1e-4) + + describe "convolve1" $ do + it "convolving with unit delta is identity" $ do + let sig = A.mkArray @Double [5] [1,2,3,4,5] + delta = A.mkArray @Double [1] [1] + A.convolve1 sig delta A.ConvDefault A.ConvDomainSpatial + `shouldBeApproxD` sig + + it "ConvExpand output length is signal_len + filter_len - 1" $ do + let sig = A.mkArray @Double [5] [1,2,3,4,5] + flt = A.mkArray @Double [3] [1,0,0] + out = A.convolve1 sig flt A.ConvExpand A.ConvDomainSpatial + A.getElements out `shouldBe` 7 + + it "ConvDomainAuto matches ConvDomainSpatial result" $ do + let sig = A.mkArray @Double [8] [1,2,3,4,5,6,7,8] + flt = A.mkArray @Double [3] [1,2,1] + A.convolve1 sig flt A.ConvDefault A.ConvDomainAuto + `shouldBeApproxD` + A.convolve1 sig flt A.ConvDefault A.ConvDomainSpatial + + describe "convolve2" $ do + it "convolving with unit 2D delta is identity" $ do + let img = A.mkArray @Double [4,4] [1..16] + delta = A.mkArray @Double [1,1] [1] + A.convolve2 img delta A.ConvDefault A.ConvDomainSpatial + `shouldBeApproxD` img + + describe "convolve2Sep" $ do + it "separable convolution matches full 2D convolution with outer-product kernel" $ do + let img = A.mkArray @Double [4,4] [1..16] + colF = A.mkArray @Double [1] [1] + rowF = A.mkArray @Double [1] [1] + A.convolve2Sep colF rowF img A.ConvDefault + `shouldBeApproxD` img + + describe "fftConvolve2" $ do + it "result matches spatial convolve2 for a simple kernel" $ do + let img = A.mkArray @Double [8,8] [1..64] + flt = A.mkArray @Double [3,3] [0,0,0, 0,1,0, 0,0,0] + A.fftConvolve2 img flt A.ConvDefault + `shouldBeApproxD` + A.convolve2 img flt A.ConvDefault A.ConvDomainSpatial + + describe "fir" $ do + it "passthrough filter (b=[1]) returns input unchanged" $ do + let sig = A.mkArray @Double [5] [1,2,3,4,5] + b = A.mkArray @Double [1] [1] + A.fir b sig `shouldBeApproxD` sig + + describe "iir" $ do + it "all-feedforward / no-feedback is equivalent to FIR" $ do + let sig = A.mkArray @Double [5] [1,2,3,4,5] + b = A.mkArray @Double [1] [1] + a = A.mkArray @Double [1] [1] + A.iir b a sig `shouldBeApproxD` sig + + describe "medFilt1" $ do + it "constant signal is unchanged by any kernel" $ do + let sig = A.mkArray @Double [7] (replicate 7 3.0) + A.medFilt1 sig 3 A.PadZero `shouldBeApproxD` sig + + describe "medFilt2" $ do + it "constant image is unchanged by any kernel" $ do + let img = A.mkArray @Double [5,5] (replicate 25 7.0) + A.medFilt2 img 3 3 A.PadSym `shouldBeApproxD` img + + describe "convolve3" $ do + it "convolving with unit 3D delta is identity" $ do + let vol = A.mkArray @Double [4,4,4] [1..64] + delta = A.mkArray @Double [1,1,1] [1] + A.convolve3 vol delta A.ConvDefault A.ConvDomainSpatial + `shouldBeApproxD` vol + + describe "fft2C2r" $ do + it "fft2r2c . fft2C2r is the identity for an even-size 2D signal" $ do + let n = 8 + inp = A.mkArray @Double [n,n] [1..fromIntegral (n*n)] + c2r = A.fft2C2r (A.fft2r2c inp 1.0 n n) (1.0 / fromIntegral (n*n)) False + c2r `shouldBeApproxD` inp + + describe "fft3C2r" $ do + it "fft3r2c . fft3C2r is the identity for an even-size 3D signal" $ do + let n = 4 + inp = A.mkArray @Double [n,n,n] [1..fromIntegral (n*n*n)] + c2r = A.fft3C2r (A.fft3r2c inp 1.0 n n n) (1.0 / fromIntegral (n*n*n)) False + c2r `shouldBeApproxD` inp + + describe "setFFTPlanCacheSize" $ do + it "runs without error" $ do + A.setFFTPlanCacheSize 4 `shouldReturn` () + + describe "FFT properties" $ do + -- ifft . fft = id for arbitrary complex signals of power-of-2 length + prop "ifft . fft = id (arbitrary complex signal)" $ + forAll (choose (1 :: Int, 6)) $ \k -> + forAll (vectorOf (2^k) (choose (-10, 10 :: Double))) $ \xs -> + let n = 2^k + input = A.mkArray @(A.Complex Double) [n] (map (:+ 0) xs) + out = A.ifft (A.fft input 1.0 n) (1.0 / fromIntegral n) n + in zipWith (\a e -> magnitude (a - e)) + (A.toList @(A.Complex Double) out) + (A.toList @(A.Complex Double) input) + `shouldSatisfy` all (< 1e-9) + + -- FFT linearity: fft(a + b) = fft(a) + fft(b) + prop "fft is linear: fft(a+b) = fft(a) + fft(b)" $ + forAll (choose (1 :: Int, 5)) $ \k -> + forAll (vectorOf (2^k) (choose (-5, 5 :: Double))) $ \as_ -> + forAll (vectorOf (2^k) (choose (-5, 5 :: Double))) $ \bs_ -> + let n = 2^k + a = A.mkArray @(A.Complex Double) [n] (map (:+ 0) as_) + b = A.mkArray @(A.Complex Double) [n] (map (:+ 0) bs_) + lhs = A.toList @(A.Complex Double) (A.fft (a + b) 1.0 n) + rhs = zipWith (+) + (A.toList @(A.Complex Double) (A.fft a 1.0 n)) + (A.toList @(A.Complex Double) (A.fft b 1.0 n)) + in zipWith (\l r -> magnitude (l - r)) lhs rhs + `shouldSatisfy` all (< 1e-9) + + -- Parseval's theorem: ||x||^2 = (1/N) * ||X||^2 + prop "Parseval's theorem holds for arbitrary signals" $ + forAll (choose (1 :: Int, 6)) $ \k -> + forAll (vectorOf (2^k) (choose (-10, 10 :: Double))) $ \xs -> + let n = 2^k + input = A.mkArray @(A.Complex Double) [n] (map (:+ 0) xs) + tEnergy = sum (map (\x -> x*x) xs) + xf = A.fft input 1.0 n + fEnergy = (1.0 / fromIntegral n) * + sum (map (\c -> realPart c * realPart c + imagPart c * imagPart c) + (A.toList @(A.Complex Double) xf)) + in abs (tEnergy - fEnergy) < 1e-6 + 1e-6 * abs tEnergy + + -- convolve1 with unit delta is identity for arbitrary signals + prop "convolve1 with unit delta is identity" $ \(NonEmpty xs) -> + let sig = A.mkArray @Double [length xs] xs + delta = A.mkArray @Double [1] [1] + out = A.convolve1 sig delta A.ConvDefault A.ConvDomainSpatial + in zipWith (\a e -> abs (a - e)) + (A.toList @Double out) + (A.toList @Double sig) + `shouldSatisfy` all (< 1e-9) + + -- fftr2c . fftc2r round-trip for arbitrary even-length real signals + prop "fftc2r . fftr2c = id for even-length real signals" $ + forAll (choose (1 :: Int, 5)) $ \k -> + forAll (vectorOf (2^k) (choose (-10, 10 :: Double))) $ \xs -> + let n = 2^k + inp = A.mkArray @Double [n] xs + out = A.fftc2r (A.fftr2c inp 1.0 n) (1.0 / fromIntegral n) False + in zipWith (\a e -> abs (a - e)) + (A.toList @Double out) + xs + `shouldSatisfy` all (< 1e-9) diff --git a/test/ArrayFire/SparseSpec.hs b/test/ArrayFire/SparseSpec.hs index b90c931..f636b85 100644 --- a/test/ArrayFire/SparseSpec.hs +++ b/test/ArrayFire/SparseSpec.hs @@ -1,19 +1,92 @@ {-# LANGUAGE TypeApplications #-} module ArrayFire.SparseSpec where -import qualified ArrayFire as A +import qualified ArrayFire as A import Data.Int -import Data.Word -import Data.Complex -import Data.Proxy -import Foreign.C.Types import Test.Hspec +-- 3×3 diagonal matrix diag(1,2,3), stored column-major: +-- col0=[1,0,0], col1=[0,2,0], col2=[0,0,3] +diag3 :: A.Array Double +diag3 = A.mkArray @Double [3,3] [1,0,0, 0,2,0, 0,0,3] + spec :: Spec spec = - describe "Sparse spec" $ do - it "Should create a sparse array" $ do - (1+1) `shouldBe` 2 - -- A.createSparseArrayFromDense (A.matrix @Double (10,10) [1..]) A.CSR - -- `shouldBe` - -- A.vector @Double 10 [0..] + describe "Sparse" $ do + + describe "createSparseArrayFromDense" $ do + it "NNZ equals number of non-zero elements" $ do + A.sparseGetNNZ (A.createSparseArrayFromDense diag3 A.CSR) `shouldBe` 3 + it "all-zero matrix has NNZ 0" $ do + let zeros = A.mkArray @Double [3,3] (repeat 0) + A.sparseGetNNZ (A.createSparseArrayFromDense zeros A.CSR) `shouldBe` 0 + it "fully-dense matrix has NNZ equal to element count" $ do + let full = A.mkArray @Double [2,2] [1,2,3,4] + A.sparseGetNNZ (A.createSparseArrayFromDense full A.CSR) `shouldBe` 4 + it "storage format is preserved" $ do + A.sparseGetStorage (A.createSparseArrayFromDense diag3 A.CSR) `shouldBe` A.CSR + it "COO storage format is preserved" $ do + A.sparseGetStorage (A.createSparseArrayFromDense diag3 A.COO) `shouldBe` A.COO + + describe "sparseToDense" $ do + it "CSR round-trip preserves all values" $ do + A.sparseToDense (A.createSparseArrayFromDense diag3 A.CSR) `shouldBe` diag3 + it "COO round-trip preserves all values" $ do + A.sparseToDense (A.createSparseArrayFromDense diag3 A.COO) `shouldBe` diag3 + + describe "sparseConvertTo" $ do + it "CSR → COO preserves NNZ" $ do + let coo = A.sparseConvertTo (A.createSparseArrayFromDense diag3 A.CSR) A.COO + A.sparseGetNNZ coo `shouldBe` 3 + it "CSR → COO storage tag changes" $ do + let coo = A.sparseConvertTo (A.createSparseArrayFromDense diag3 A.CSR) A.COO + A.sparseGetStorage coo `shouldBe` A.COO + it "CSR → COO → Dense recovers original matrix" $ do + let coo = A.sparseConvertTo (A.createSparseArrayFromDense diag3 A.CSR) A.COO + A.sparseToDense coo `shouldBe` diag3 + + describe "sparseGetValues" $ do + it "diagonal matrix CSR values are the diagonal entries in row order" $ do + let sp = A.createSparseArrayFromDense diag3 A.CSR + A.sparseGetValues sp `shouldBe` A.vector @Double 3 [1,2,3] + + describe "sparseGetRowIdx / sparseGetColIdx" $ do + -- The underlying arrays are s32; we check length, not raw values. + it "CSR row pointer array has nrows+1 elements" $ do + let sp = A.createSparseArrayFromDense diag3 A.CSR + A.getElements (A.sparseGetRowIdx sp) `shouldBe` 4 + it "CSR column index array has NNZ elements" $ do + let sp = A.createSparseArrayFromDense diag3 A.CSR + A.getElements (A.sparseGetColIdx sp) `shouldBe` 3 + + describe "sparseGetInfo" $ do + it "values component matches sparseGetValues" $ do + let sp = A.createSparseArrayFromDense diag3 A.CSR + (vals, _, _, _) = A.sparseGetInfo sp + vals `shouldBe` A.sparseGetValues sp + it "storage tag matches sparseGetStorage" $ do + let sp = A.createSparseArrayFromDense diag3 A.CSR + (_, _, _, storage) = A.sparseGetInfo sp + storage `shouldBe` A.sparseGetStorage sp + + describe "createSparseArray" $ do + -- Build a 3x3 diagonal sparse matrix directly from COO components: + -- values = [1,2,3], rowIdx = [0,1,2], colIdx = [0,1,2] + it "NNZ equals length of supplied values array" $ do + let vals = A.vector @Double 3 [1,2,3] + rowIdx = A.vector @Int32 3 [0,1,2] + colIdx = A.vector @Int32 3 [0,1,2] + sp = A.createSparseArray 3 3 vals rowIdx colIdx A.COO + A.sparseGetNNZ sp `shouldBe` 3 + it "storage format matches the requested format" $ do + let vals = A.vector @Double 3 [1,2,3] + rowIdx = A.vector @Int32 3 [0,1,2] + colIdx = A.vector @Int32 3 [0,1,2] + sp = A.createSparseArray 3 3 vals rowIdx colIdx A.COO + A.sparseGetStorage sp `shouldBe` A.COO + it "converting to dense recovers the diagonal matrix" $ do + let vals = A.vector @Double 3 [1,2,3] + rowIdx = A.vector @Int32 3 [0,1,2] + colIdx = A.vector @Int32 3 [0,1,2] + sp = A.createSparseArray 3 3 vals rowIdx colIdx A.COO + A.sparseToDense sp `shouldBe` diag3 diff --git a/test/ArrayFire/StatisticsSpec.hs b/test/ArrayFire/StatisticsSpec.hs index c8c6314..f6987bf 100644 --- a/test/ArrayFire/StatisticsSpec.hs +++ b/test/ArrayFire/StatisticsSpec.hs @@ -1,11 +1,16 @@ -{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} module ArrayFire.StatisticsSpec where -import ArrayFire hiding (not) +import Data.Word (Word32) +import ArrayFire hiding (not, abs, isNaN) +import Data.Maybe import Data.Complex import Test.Hspec import Test.Hspec.ApproxExpect +import Test.Hspec.QuickCheck (prop) +import Test.QuickCheck (NonEmptyList (..), (==>)) spec :: Spec spec = @@ -15,17 +20,20 @@ spec = `shouldBe` 5.5 it "Should find the weighted-mean" $ do - meanWeighted (vector @Double 10 [1..]) (vector @Double 10 [1..]) 0 - `shouldBeApprox` - 7.0 + listToMaybe (toList (meanWeighted (vector @Double 10 [1..]) (vector @Double 10 [1..]) 0)) + `shouldBe` + (Just 7.0) it "Should find the variance" $ do - var (vector @Double 8 [1..8]) False 0 + var (vector @Double 8 [1..8]) Population 0 `shouldBe` 5.25 - it "Should find the weighted variance" $ do + it "Should find the weighted variance (equal weights)" $ do varWeighted (vector @Double 8 [1..]) (vector @Double 8 (repeat 1)) 0 `shouldBe` 5.25 + it "Should find the weighted variance (increasing weights)" $ do + head (toList (varWeighted (vector @Double 10 [1..]) (vector @Double 10 [1..]) 0)) + `shouldBeApprox` (21/11 :: Double) it "Should find the standard deviation" $ do stdev (vector @Double 10 (cycle [1,-1])) 0 `shouldBe` @@ -39,34 +47,92 @@ spec = `shouldBe` 5.5 it "Should find the mean of all elements across all dimensions" $ do - fst (meanAll (matrix @Double (2,2) [[10,10],[10,10]])) - `shouldBe` - 10 + meanAll (matrix @Double (2,2) [[10,10],[10,10]]) + `shouldBe` 10 it "Should find the weighted mean of all elements across all dimensions" $ do - fst (meanAllWeighted (matrix @Double (2,2) [[10,10],[10,10]]) (matrix @Double (2,2) [[10,10],[10,10]])) - `shouldBe` - 10 + meanAllWeighted (matrix @Double (2,2) [[10,10],[10,10]]) (matrix @Double (2,2) [[10,10],[10,10]]) + `shouldBe` 10 it "Should find the variance of all elements across all dimensions" $ do - fst (varAll (vector @Double 10 (repeat 10)) False) - `shouldBe` - 0 + varAll (vector @Double 10 (repeat 10)) Population + `shouldBe` 0 it "Should find the weighted variance of all elements across all dimensions" $ do - fst (varAllWeighted (vector @Double 10 (repeat 10)) (vector @Double 10 (repeat 10))) - `shouldBe` - 0 + varAllWeighted (vector @Double 10 (repeat 10)) (vector @Double 10 (repeat 10)) + `shouldBe` 0 it "Should find the stdev of all elements across all dimensions" $ do - fst (stdevAll (vector @Double 10 (repeat 10))) - `shouldBe` - 0 + stdevAll (vector @Double 10 (repeat 10)) + `shouldBe` 0 it "Should find the median of all elements across all dimensions" $ do - fst (medianAll (vector @Double 10 [1..])) - `shouldBe` - 5.5 + medianAll (vector @Double 10 [1..]) + `shouldBe` 5.5 it "Should find the correlation coefficient" $ do - fst (corrCoef (vector @Int 10 [1..] ) ( vector @Int 10 [10,9..] )) - `shouldBe` - (-1.0) + corrCoef (vector @Int 10 [1..]) (vector @Int 10 [10,9..]) + `shouldBe` (-1.0) it "Should find the top k elements" $ do let (vals,indexes) = topk ( vector @Double 10 [1..] ) 3 TopKDefault vals `shouldBe` vector @Double 3 [10,9,8] - indexes `shouldBe` vector @Double 3 [9,8,7] + indexes `shouldBe` vector @Word32 3 [9,8,7] + it "Should compute mean and variance together (population)" $ do + let (m, v) = meanVar (vector @Double 4 [1,2,3,4]) VariancePopulation 0 + m `shouldBe` scalar @Double 2.5 + v `shouldBe` scalar @Double 1.25 + it "Should compute mean and variance together (sample)" $ do + let (m, v) = meanVar (vector @Double 4 [1,2,3,4]) VarianceSample 0 + m `shouldBe` scalar @Double 2.5 + -- sample variance of [1,2,3,4] = 5/3 ≈ 1.6667 + case listToMaybe (toList v) of + Just k -> k `shouldBeApprox` (5.0/3.0) + _ -> error "failure" + it "Should compute weighted mean and variance together" $ do + let uniform = vector @Double 4 (repeat 1.0) + (m, v) = meanVarWeighted (vector @Double 4 [1,2,3,4]) uniform VariancePopulation 0 + m `shouldBe` scalar @Double 2.5 + v `shouldBe` scalar @Double 1.25 + + describe "statistical properties" $ do + -- mean(x + c) = mean(x) + c (translation equivariance) + prop "mean is translation-equivariant" $ \(NonEmpty xs) (c :: Double) -> + let n = length xs + arr = vector @Double n xs + lhs = meanAll (arr + scalar c) + rhs = meanAll arr + c + in abs (lhs - rhs) < 1e-9 + + -- var(x + c) = var(x) (translation invariance) + prop "variance is translation-invariant" $ \(NonEmpty xs) (c :: Double) -> + let n = length xs + arr = vector @Double n xs + lhs = varAll arr Population + rhs = varAll (arr + scalar c) Population + in abs (lhs - rhs) < 1e-6 * (1 + abs lhs) + + -- stdev(x)^2 = var(x, Population) (consistency) + prop "stdev^2 equals population variance" $ \(NonEmpty xs) -> + let n = length xs + arr = vector @Double n xs + sd = stdevAll arr + v = varAll arr Population + in abs (sd * sd - v) < 1e-9 + 1e-6 * abs v + + -- mean(c * x) = c * mean(x) (scale equivariance) + prop "mean scales linearly" $ \(NonEmpty xs) (c :: Double) -> + let n = length xs + arr = vector @Double n xs + lhs = meanAll (scalar c * arr) + rhs = c * meanAll arr + in abs (lhs - rhs) < 1e-9 + 1e-9 * abs rhs + + -- corrCoef(x, y) is in [-1, 1] (Cauchy-Schwarz) + prop "corrCoef is in [-1, 1]" $ \(NonEmpty xs) (ys :: [Double]) -> + let n = length xs + arr1 = vector @Double n xs + arr2 = vector @Double n (take n (ys ++ repeat 0)) + r = corrCoef arr1 arr2 + in not (isNaN r) ==> r >= -1.0 - 1e-9 && r <= 1.0 + 1e-9 + + -- sumAll = n * meanAll (for any non-empty list) + prop "sumAll = n * meanAll" $ \(NonEmpty xs) -> + let n = length xs + arr = vector @Double n xs + s = sumAll arr + m = meanAll arr + in abs (s - fromIntegral n * m) < 1e-9 + 1e-6 * abs s diff --git a/test/ArrayFire/VisionSpec.hs b/test/ArrayFire/VisionSpec.hs index 82bddc1..71978c5 100644 --- a/test/ArrayFire/VisionSpec.hs +++ b/test/ArrayFire/VisionSpec.hs @@ -1,14 +1,267 @@ -{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} module ArrayFire.VisionSpec where -import qualified ArrayFire as A +import qualified ArrayFire as A +import Control.Exception (SomeException, evaluate, try) +import Control.Monad (when) import Test.Hspec +-- | 100×100 constant-intensity Float image. No edges or corners. +-- FAST / Harris / SUSAN must produce 0 features on this image. +flatImg :: A.Array Float +flatImg = A.constant @Float [100, 100] 0.5 + +-- | 100×100 image composed of four 50×50 quadrants with alternating +-- intensities (0.0 / 1.0), creating a strong corner at the centre. +quadrantImg :: A.Array Float +quadrantImg = + let tl = A.constant @Float [50, 50] 0.0 + tr = A.constant @Float [50, 50] 1.0 + bl = A.constant @Float [50, 50] 1.0 + br = A.constant @Float [50, 50] 0.0 + in A.join 0 (A.join 1 tl tr) (A.join 1 bl br) + +xpos, ypos, score, orient, size_ :: A.Features -> A.Array Float +xpos = A.getFeaturesXPos +ypos = A.getFeaturesYPos +score = A.getFeaturesScore +orient = A.getFeaturesOrientation +size_ = A.getFeaturesSize + spec :: Spec -spec = - describe "Vision spec" $ do - it "Should construct Features for fast feature detection" $ do - let arr = A.vector @Int 30000 [1..] - let feats = A.fast arr 1.0 9 False 1.0 3 - (1 + 1) `shouldBe` 2 +spec = describe "Vision spec" $ do + + -- ------------------------------------------------------------------ -- + -- FAST + -- ------------------------------------------------------------------ -- + describe "fast" $ do + it "detects 0 features on a flat image" $ + A.getFeaturesNum (A.fast flatImg 0.05 9 False 1.0 3) `shouldBe` 0 + + it "all accessor arrays are consistent with getFeaturesNum" $ do + let feats = A.fast quadrantImg 0.1 9 False 1.0 3 + n = A.getFeaturesNum feats + A.getElements (xpos feats) `shouldBe` n + A.getElements (ypos feats) `shouldBe` n + A.getElements (score feats) `shouldBe` n + A.getElements (orient feats) `shouldBe` n + A.getElements (size_ feats) `shouldBe` n + + it "detected x-coordinates lie in [0, 100)" $ do + let feats = A.fast quadrantImg 0.1 9 False 1.0 3 + A.toList (xpos feats) `shouldSatisfy` all (\x -> x >= (0 :: Float) && x < 100) + + it "detected y-coordinates lie in [0, 100)" $ do + let feats = A.fast quadrantImg 0.1 9 False 1.0 3 + A.toList (ypos feats) `shouldSatisfy` all (\y -> y >= (0 :: Float) && y < 100) + + it "all feature scores are non-negative" $ do + let feats = A.fast quadrantImg 0.1 9 False 1.0 3 + A.toList (score feats) `shouldSatisfy` all (>= (0 :: Float)) + + -- ------------------------------------------------------------------ -- + -- Harris + -- ------------------------------------------------------------------ -- + describe "harris" $ do + it "detects 0 corners on a flat image" $ + A.getFeaturesNum (A.harris flatImg 500 1e-3 1.0 0 0.04) `shouldBe` 0 + + it "all accessor arrays are consistent with getFeaturesNum" $ do + let feats = A.harris quadrantImg 500 1e-3 1.0 0 0.04 + n = A.getFeaturesNum feats + A.getElements (xpos feats) `shouldBe` n + A.getElements (ypos feats) `shouldBe` n + A.getElements (score feats) `shouldBe` n + + it "detected x-coordinates lie in [0, 100)" $ do + let feats = A.harris quadrantImg 500 1e-3 1.0 0 0.04 + A.toList (xpos feats) `shouldSatisfy` all (\x -> x >= (0 :: Float) && x < 100) + + it "detected y-coordinates lie in [0, 100)" $ do + let feats = A.harris quadrantImg 500 1e-3 1.0 0 0.04 + A.toList (ypos feats) `shouldSatisfy` all (\y -> y >= (0 :: Float) && y < 100) + + -- ------------------------------------------------------------------ -- + -- ORB + -- ------------------------------------------------------------------ -- + describe "orb" $ do + it "descriptor row count equals getFeaturesNum" $ do + let (feats, descs) = A.orb quadrantImg 0.1 500 1.5 4 False + n = A.getFeaturesNum feats + (d0, _, _, _) = A.getDims (descs :: A.Array Float) + d0 `shouldBe` n + + it "all coordinate arrays are consistent with getFeaturesNum" $ do + let (feats, _) = A.orb quadrantImg 0.1 500 1.5 4 False + n = A.getFeaturesNum feats + A.getElements (xpos feats) `shouldBe` n + A.getElements (ypos feats) `shouldBe` n + A.getElements (score feats) `shouldBe` n + A.getElements (orient feats) `shouldBe` n + A.getElements (size_ feats) `shouldBe` n + + -- ------------------------------------------------------------------ -- + -- SUSAN + -- ------------------------------------------------------------------ -- + describe "susan" $ do + it "detects 0 corners on a flat image" $ + A.getFeaturesNum (A.susan flatImg 3 0.1 0.5 0.05 3) `shouldBe` 0 + + it "all accessor arrays are consistent with getFeaturesNum" $ do + let feats = A.susan quadrantImg 3 0.1 0.5 0.05 3 + n = A.getFeaturesNum feats + A.getElements (xpos feats) `shouldBe` n + A.getElements (ypos feats) `shouldBe` n + A.getElements (score feats) `shouldBe` n + + it "detected x-coordinates lie in [0, 100)" $ do + let feats = A.susan quadrantImg 3 0.1 0.5 0.05 3 + A.toList (xpos feats) `shouldSatisfy` all (\x -> x >= (0 :: Float) && x < 100) + + -- ------------------------------------------------------------------ -- + -- Difference of Gaussians + -- ------------------------------------------------------------------ -- + describe "dog" $ do + it "output has the same dimensions as the input image" $ + A.getDims (A.dog flatImg 1 2) `shouldBe` (100, 100, 1, 1) + + it "DoG of a constant image has zero interior values" $ do + -- Border pixels are non-zero due to Gaussian zero-padding; the interior + -- (at least 2 pixels from each edge for kernel radius=2) must be zero. + let result = A.dog (A.constant @Float [20, 20] 0.5) 1 2 + interior = result A.! (A.range 2 17, A.range 2 17) + A.toList @Float interior `shouldSatisfy` all (\v -> abs v < 1e-5) + + it "different radii produce different results on a non-constant image" $ do + let dog12 = A.dog quadrantImg 1 2 + dog13 = A.dog quadrantImg 1 3 + (dog12 == dog13) `shouldBe` False + + -- ------------------------------------------------------------------ -- + -- matchTemplate + -- ------------------------------------------------------------------ -- + describe "matchTemplate" $ do + it "output has the same dimensions as the search image" $ do + let img = A.constant @Float [20, 20] 1.0 + tmpl = A.constant @Float [5, 5] 1.0 + A.getDims (A.matchTemplate img tmpl A.MatchTypeSAD) `shouldBe` (20, 20, 1, 1) + + it "SAD of a zero image against a zero template is zero everywhere" $ do + let img = A.constant @Float [10, 10] 0.0 + tmpl = A.constant @Float [3, 3] 0.0 + result = A.matchTemplate img tmpl A.MatchTypeSAD + A.toList @Float result `shouldSatisfy` all (< 1e-5) + + it "SSD of a zero image against a zero template is zero everywhere" $ do + let img = A.constant @Float [10, 10] 0.0 + tmpl = A.constant @Float [3, 3] 0.0 + result = A.matchTemplate img tmpl A.MatchTypeSSD + A.toList @Float result `shouldSatisfy` all (< 1e-5) + + -- ------------------------------------------------------------------ -- + -- hammingMatcher + -- ------------------------------------------------------------------ -- + describe "hammingMatcher" $ do + it "identical descriptors produce 0 Hamming distances" $ do + -- 4 features, each 4 uint32 components; dim 0 = feature length + let desc = A.mkArray @A.Word32 [4, 4] (replicate 16 0xDEADBEEF) + (_idxs, dists) = A.hammingMatcher desc desc 0 1 + A.toList @A.Word32 dists `shouldBe` replicate 4 0 + + it "result arrays have one entry per query feature (n_dist = 1)" $ do + let query = A.mkArray @A.Word32 [4, 3] (replicate 12 0x00000000) + train = A.mkArray @A.Word32 [4, 5] (replicate 20 0xFFFFFFFF) + (idxs, dists) = A.hammingMatcher query train 0 1 + A.getElements @A.Word32 idxs `shouldBe` 3 + A.getElements @A.Word32 dists `shouldBe` 3 + + it "returned indices are within training-set bounds" $ do + let query = A.mkArray @A.Word32 [4, 3] (replicate 12 0x00000000) + train = A.mkArray @A.Word32 [4, 5] (replicate 20 0x00000000) + (idxs, _dists) = A.hammingMatcher query train 0 1 + A.toList @A.Word32 idxs `shouldSatisfy` all (< 5) + + -- ------------------------------------------------------------------ -- + -- nearestNeighbor + -- ------------------------------------------------------------------ -- + describe "nearestNeighbor" $ do + it "identical descriptors produce 0 SAD distances" $ do + let desc = A.mkArray @Float [4, 4] (replicate 16 1.0) + (_idxs, dists) = A.nearestNeighbor desc desc 0 1 A.MatchTypeSAD + A.toList @Float dists `shouldBe` replicate 4 0.0 + + it "identical descriptors produce 0 SSD distances" $ do + let desc = A.mkArray @Float [4, 4] (replicate 16 1.0) + (_idxs, dists) = A.nearestNeighbor desc desc 0 1 A.MatchTypeSSD + A.toList @Float dists `shouldBe` replicate 4 0.0 + + it "result count matches number of query features" $ do + let query = A.mkArray @Float [4, 3] (replicate 12 0.0) + train = A.mkArray @Float [4, 5] (replicate 20 1.0) + (idxs, dists) = A.nearestNeighbor query train 0 1 A.MatchTypeSAD + A.getElements @Float idxs `shouldBe` 3 + A.getElements @Float dists `shouldBe` 3 + + it "returned indices are within training-set bounds" $ do + let query = A.mkArray @Float [4, 3] (replicate 12 0.0) + train = A.mkArray @Float [4, 5] (replicate 20 1.0) + (idxs, _) = A.nearestNeighbor query train 0 1 A.MatchTypeSAD + A.toList @Float idxs `shouldSatisfy` all (< 5) + + -- ------------------------------------------------------------------ -- + -- homography + -- ------------------------------------------------------------------ -- + describe "homography" $ do + it "returns a 3×3 homography matrix" $ do + -- 4 exact correspondences: unit square → 2× scaled square + let sx = A.vector @Float 4 [0, 1, 0, 1] + sy = A.vector @Float 4 [0, 0, 1, 1] + dx = A.vector @Float 4 [0, 2, 0, 2] + dy = A.vector @Float 4 [0, 0, 2, 2] + (_, h) = A.homography sx sy dx dy A.RANSAC 1.0 1000 + A.getDims h `shouldBe` (3, 3, 1, 1) + + it "inlier count is non-negative" $ do + let sx = A.vector @Float 4 [0, 1, 0, 1] + sy = A.vector @Float 4 [0, 0, 1, 1] + (inliers, _) = A.homography sx sy sx sy A.RANSAC 1.0 1000 + inliers `shouldSatisfy` (>= 0) + + it "identity correspondences yield at least 4 inliers" $ do + let sx = A.vector @Float 4 [0, 1, 0, 1] + sy = A.vector @Float 4 [0, 0, 1, 1] + (inliers, _) = A.homography sx sy sx sy A.RANSAC 10.0 1000 + inliers `shouldSatisfy` (>= 4) + + -- ------------------------------------------------------------------ -- + -- SIFT (may not be compiled into every ArrayFire build) + -- ------------------------------------------------------------------ -- + describe "sift" $ do + it "descriptor row count equals getFeaturesNum; width is 128 when features found" $ do + result <- try $ evaluate $ + A.sift quadrantImg 3 0.04 10.0 1.6 False (1.0 / 256.0) 0.05 + case (result :: Either SomeException (A.Features, A.Array Float)) of + Left _ -> pendingWith "SIFT not available in this ArrayFire build" + Right (feats, descs) -> do + let n = A.getFeaturesNum feats + (d0, d1, _, _) = A.getDims descs + d0 `shouldBe` n + -- AF returns (0,0) when no features are found rather than (0,128), + -- so only assert the column width when at least one feature exists. + when (n > 0) $ d1 `shouldBe` 128 + -- ------------------------------------------------------------------ -- + -- GLOH (may not be compiled into every ArrayFire build) + -- ------------------------------------------------------------------ -- + describe "gloh" $ do + it "descriptor row count equals getFeaturesNum; width is 272 when features found" $ do + result <- try $ evaluate $ + A.gloh quadrantImg 3 0.04 10.0 1.6 False (1.0 / 256.0) 0.05 + case (result :: Either SomeException (A.Features, A.Array Float)) of + Left _ -> pendingWith "GLOH not available in this ArrayFire build" + Right (feats, descs) -> do + let n = A.getFeaturesNum feats + (d0, d1, _, _) = A.getDims descs + d0 `shouldBe` n + when (n > 0) $ d1 `shouldBe` 272 diff --git a/test/Main.hs b/test/Main.hs index c949527..979a97d 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,11 +1,15 @@ -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE GeneralisedNewtypeDeriving #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} module Main where -import Control.Monad - +import Prelude hiding (negate) +import Control.Monad (forM_, unless) +import Data.IORef (IORef, newIORef, readIORef, writeIORef) import Data.Proxy +import Data.Semiring (Semiring (..), Ring (..)) import Spec (spec) +import System.Exit (exitFailure) import Test.Hspec (hspec) import Test.QuickCheck import Test.QuickCheck.Classes @@ -13,32 +17,102 @@ import Test.QuickCheck.Classes import qualified ArrayFire as A import ArrayFire (Array) -import System.IO.Unsafe +import Foreign.C.Types (CBool (..)) +-- Multi-dimensional arrays: used for eqLaws, so the Eq instance is exercised +-- on matrices and tensors, not just scalars. instance (A.AFType a, Arbitrary a) => Arbitrary (Array a) where - arbitrary = pure $ unsafePerformIO (A.randu [2,2]) + arbitrary = do + ndim <- choose (1, 4) + dims <- vectorOf ndim (choose (1, 4)) + elems <- vectorOf (product dims) arbitrary + pure (A.mkArray dims elems) + shrink arr = + [ A.mkArray dims' (take (product dims') (A.toList arr)) + | dims' <- shrunkDims + , product dims' > 0 + ] + where + (d0, d1, d2, d3) = A.getDims arr + ndim = A.getNumDims arr + currentDims = take ndim [d0, d1, d2, d3] + shrunkDims = + [ [if i == j then d - 1 else d | (j, d) <- zip [0..] currentDims] + | i <- [0 .. ndim - 1] + , currentDims !! i > 1 + ] + ++ [take (ndim - 1) currentDims | ndim > 1] + +-- Scalar wrapper for numLaws. +-- Num laws require: (a) binary ops succeed for any two generated values, and +-- (b) `fromInteger 0` compares equal to `0 * x`. Both hold only when all +-- arrays are the same shape. Scalars ([1 1 1 1]) are the minimal fixed shape +-- that makes every Num law well-typed and exact for integer element types. +newtype Scalar a = Scalar (Array a) + deriving (Show, Eq, Num) + +-- Semiring/Ring instances so we can exercise semiringLaws/ringLaws, which +-- check associativity, distributivity and annihilation explicitly (stronger +-- than numLaws). Defined in terms of the derived Num instance; exact for the +-- integral element types these are instantiated at. +instance (A.AFType a, Num a) => Semiring (Scalar a) where + zero = 0 + one = 1 + plus = (+) + times = (*) + fromNatural n = fromInteger (toInteger n) + +instance (A.AFType a, Num a) => Ring (Scalar a) where + negate x = 0 - x + +instance Arbitrary CBool where + arbitrary = CBool <$> arbitrary + +instance (A.AFType a, Arbitrary a) => Arbitrary (Scalar a) where + arbitrary = Scalar . A.scalar <$> arbitrary + shrink (Scalar arr) = Scalar . A.scalar <$> case A.toList arr of + x : _ -> shrink x + [] -> [] + +-- Run a Laws check, print results in the same format as lawsCheck, and mark +-- the IORef False on any failure so we can call exitFailure at the end. +checkLaws :: IORef Bool -> Laws -> IO () +checkLaws ref laws = do + let cls = lawsTypeclass laws + forM_ (lawsProperties laws) $ \(name, prop) -> do + putStr $ cls ++ ": " ++ name ++ " " + r <- quickCheckWithResult stdArgs { chatty = False } prop + putStr (output r) + unless (isSuccess r) (writeIORef ref False) main :: IO () -main = do - A.setBackend A.CPU --- checks (Proxy :: Proxy (A.Array (A.Complex Float))) --- checks (Proxy :: Proxy (A.Array (A.Complex Double))) --- checks (Proxy :: Proxy (A.Array Double)) --- checks (Proxy :: Proxy (A.Array Float)) --- checks (Proxy :: Proxy (A.Array Double)) --- checks (Proxy :: Proxy (A.Array A.Int16)) --- checks (Proxy :: Proxy (A.Array A.Int32)) - -- checks (Proxy :: Proxy (A.Array A.CBool)) - -- checks (Proxy :: Proxy (A.Array Word)) - -- checks (Proxy :: Proxy (A.Array A.Word8)) - -- checks (Proxy :: Proxy (A.Array A.Word16)) - -- checks (Proxy :: Proxy (A.Array A.Word32)) --- lawsCheck $ semigroupLaws (Proxy :: Proxy (A.Array Double)) --- lawsCheck $ semigroupLaws (Proxy :: Proxy (A.Array Float)) +main = A.withArrayFire $ do + ref <- newIORef True + let check = checkLaws ref + -- IEEE 754 is not an exact ring; only Eq laws for floating-point arrays. + check (eqLaws (Proxy :: Proxy (Array Double))) + check (eqLaws (Proxy :: Proxy (Array Float))) + -- Complex: Eq only (IEEE 754 + gt/lt undefined for complex numbers). + check (eqLaws (Proxy :: Proxy (Array (A.Complex Double)))) + check (eqLaws (Proxy :: Proxy (Array (A.Complex Float)))) + -- Integral types: exact ring laws via Scalar, Eq laws via multi-dim Array. + intChecks ref (Proxy :: Proxy Int) + intChecks ref (Proxy :: Proxy A.Int16) + intChecks ref (Proxy :: Proxy A.Int32) + intChecks ref (Proxy :: Proxy A.Int64) + intChecks ref (Proxy :: Proxy A.Word8) + intChecks ref (Proxy :: Proxy A.Word16) + intChecks ref (Proxy :: Proxy A.Word32) + intChecks ref (Proxy :: Proxy A.Word64) + intChecks ref (Proxy :: Proxy Word) + intChecks ref (Proxy :: Proxy A.CBool) hspec spec + ok <- readIORef ref + unless ok exitFailure -checks proxy = do - lawsCheck (numLaws proxy) - lawsCheck (eqLaws proxy) - lawsCheck (ordLaws proxy) --- lawsCheck (semigroupLaws proxy) +intChecks :: forall a. (A.AFType a, Arbitrary a, Num a, Eq a) => IORef Bool -> Proxy a -> IO () +intChecks ref _ = do + checkLaws ref (numLaws (Proxy :: Proxy (Scalar a))) + checkLaws ref (semiringLaws (Proxy :: Proxy (Scalar a))) + checkLaws ref (ringLaws (Proxy :: Proxy (Scalar a))) + checkLaws ref (eqLaws (Proxy :: Proxy (Array a))) diff --git a/test/Test/Hspec/ApproxExpect.hs b/test/Test/Hspec/ApproxExpect.hs index 3e9d66b..8ff6c05 100644 --- a/test/Test/Hspec/ApproxExpect.hs +++ b/test/Test/Hspec/ApproxExpect.hs @@ -1,19 +1,30 @@ -{-# LANGUAGE TypeApplications #-} {-# LANGUAGE ScopedTypeVariables #-} module Test.Hspec.ApproxExpect where import Data.CallStack (HasCallStack) - import Test.Hspec (shouldSatisfy, Expectation) infix 1 `shouldBeApprox` -shouldBeApprox :: (HasCallStack, Show a, Fractional a, Eq a) - => a -> a -> Expectation -shouldBeApprox actual tgt - -- This is a hackish way of checking, without requiring a specific - -- type or an 'Ord' instance, whether two floating-point values - -- are only some epsilons apart: when the difference is small enough - -- so scaling it down some more makes it a no-op for addition. - = actual `shouldSatisfy` \x -> (x-tgt) * 1e-4 + tgt == tgt +-- | Element-wise relative + absolute closeness for lists of 'Double'. +-- +-- Tolerances: atol = 1e-9, rtol = 1e-6 (suitable for BLAS/FFT results). +closeList :: [Double] -> [Double] -> Bool +closeList as bs = + length as == length bs && + and (zipWith (\a b -> abs (a - b) <= 1e-9 + 1e-6 * max (abs a) (abs b)) as bs) +-- | Assert two floating-point values are within relative + absolute tolerance. +-- +-- Uses the same formula as numpy.testing.assert_allclose: +-- |a - b| <= atol + rtol * max(|a|, |b|) +-- with rtol = 1e-5 and atol = 1e-8, matching numpy defaults. +shouldBeApprox + :: (HasCallStack, Show a, Ord a, Fractional a) + => a -> a -> Expectation +shouldBeApprox actual expected = + actual `shouldSatisfy` \x -> + abs (x - expected) <= atol + rtol * max (abs x) (abs expected) + where + rtol = 1e-5 + atol = 1e-8