KMeans: Reuse Precomputed Norms for Inertia Computation #2258
tarang-jain wants to merge 6 commits into
Changes:
- cluster_cost X_norm and device cost handling
- neighbors filtering documentation
Estimated code review effort: 3 (Moderate), ~20 minutes
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@cpp/src/cluster/detail/kmeans.cuh`:
- Around line 963-975: The host data path (in the else block after checking
data_on_device) unconditionally copies from h_norm_cache without verifying it
was populated. When max_iter equals zero, the training loop never initializes
h_norm_cache, causing the raft::copy operation to read uninitialized memory and
compute incorrect final inertia. Add a guard condition in the else block to
check if max_iter is zero, and handle this case separately (either by skipping
the norm computation or computing norms on-the-fly) before attempting to copy
from the unpopulated h_norm_cache.
In `@cpp/src/cluster/kmeans.cuh`:
- Around line 404-414: Before using the X_norm parameter's data handle in the
conditional block where X_norm.has_value() is true, add a validation check to
ensure that X_norm->extent(0) equals n_samples. If the sizes do not match, fail
with an error (for example via RAFT_EXPECTS, or by throwing an exception) to
prevent out-of-bounds device reads; a logged warning alone would not stop the
read. This validation should occur immediately after confirming X_norm has a
value and before calling const_cast to obtain the data pointer.
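A minimal host-side sketch of the extent check described in this comment. `VectorView` and `checked_norm_ptr` are hypothetical stand-ins introduced only for illustration (they are not RAFT or cuVS types), and the exception type is an assumption; in RAFT-based code the same condition would more likely be expressed with `RAFT_EXPECTS`.

```cpp
#include <cassert>
#include <optional>
#include <stdexcept>
#include <string>

// Hypothetical host stand-in for raft::device_vector_view<const float, int>.
struct VectorView {
  const float* data;
  int size;
  int extent(int /*dim*/) const { return size; }
  const float* data_handle() const { return data; }
};

// Returns the norm pointer to use, or nullptr when no norms were supplied
// (meaning they must be computed internally). Throws on a size mismatch so a
// too-short buffer is never read past its end.
inline const float* checked_norm_ptr(const std::optional<VectorView>& X_norm, int n_samples)
{
  assert(n_samples >= 0);
  if (!X_norm.has_value()) { return nullptr; }
  if (X_norm->extent(0) != n_samples) {
    throw std::invalid_argument("X_norm has " + std::to_string(X_norm->extent(0)) +
                                " entries but X has " + std::to_string(n_samples) + " rows");
  }
  return X_norm->data_handle();
}
```

The check runs before the pointer is extracted, which matches the ordering the comment asks for.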
📒 Files selected for processing (2)
- cpp/src/cluster/detail/kmeans.cuh
- cpp/src/cluster/kmeans.cuh
std::optional<raft::device_vector_view<const DataT, IndexT>> batch_xnorm = std::nullopt;
if (need_compute_norms) {
  if constexpr (data_on_device) {
    batch_xnorm = raft::make_device_vector_view<const DataT, IndexT>(
      L2NormBatch.data_handle() + data_batch.offset(), cur_batch_size);
  } else {
    raft::copy(L2NormBatch.data_handle(),
               h_norm_cache.data_handle() + data_batch.offset(),
               cur_batch_size,
               stream);
    batch_xnorm = raft::make_device_vector_view<const DataT, IndexT>(
      L2NormBatch.data_handle(), cur_batch_size);
  }
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
HIGH: Guard the uncached host-norm path for zero-iteration fits.
Issue: For host data, this final inertia path always copies from h_norm_cache; if max_iter == 0, the training loop never populated it, so inertia is computed from uninitialized norms.
Why: This returns incorrect final inertia for a valid-looking parameter combination because max_iter is not rejected earlier.
Suggested fix
 if (need_compute_norms) {
   if constexpr (data_on_device) {
     batch_xnorm = raft::make_device_vector_view<const DataT, IndexT>(
       L2NormBatch.data_handle() + data_batch.offset(), cur_batch_size);
   } else {
-    raft::copy(L2NormBatch.data_handle(),
-               h_norm_cache.data_handle() + data_batch.offset(),
-               cur_batch_size,
-               stream);
+    if (norms_cached) {
+      raft::copy(L2NormBatch.data_handle(),
+                 h_norm_cache.data_handle() + data_batch.offset(),
+                 cur_batch_size,
+                 stream);
+    } else {
+      compute_batch_norms(data_batch.data(), cur_batch_size);
+    }
     batch_xnorm = raft::make_device_vector_view<const DataT, IndexT>(
       L2NormBatch.data_handle(), cur_batch_size);
   }
 }
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@cpp/include/cuvs/cluster/kmeans.hpp`:
- Around line 1535-1536: Update the public Doxygen for the kmeans API so
`X_norm` is documented as squared row norms (`||x||^2`), not plain L2 norms. Fix
the parameter descriptions in the affected overloads of the kmeans public header
(the same `X_norm` docs repeated across the API surface) to clearly state that
callers must pass precomputed squared norms and that the internal norm
computation is skipped when provided.
- Around line 1538-1544: `cluster_cost` in the public kmeans API is
source-breaking because `cost` was changed from a host scalar to a device
scalar, so restore compatibility by keeping the existing host-scalar overloads
as deprecated shims alongside the new `raft::device_scalar_view` overload in
`cluster_cost` and the related overloads at the referenced locations. Update the
overload set in `cuvs::cluster::kmeans` so downstream callers can still compile,
emit deprecation warnings from the old signatures, and add the required
migration/deprecation note in the public API docs for the affected functions.
In `@cpp/src/cluster/detail/kmeans_balanced.cuh`:
- Around line 1142-1147: The final inertia computation in build_hierarchical is
still recomputing row norms instead of reusing the cached dataset_norm. Update
the cluster_cost call in this hierarchical path to forward dataset_norm for the
L2 cases that already precompute it, and keep the existing fallback behavior for
other metrics or paths. Use the build_hierarchical and cluster_cost symbols to
locate the call site and thread the cached norm through the final inertia
calculation.
In `@fern/pages/cpp_api/cpp-api-neighbors-common.md`:
- Around line 328-330: The UDF predicate documentation is missing the exact
device-function signature and leaves the argument semantics ambiguous. Update
the text around the `function_name` reference to explicitly show the full
`__device__` predicate signature, including the argument order and types
(`query_id`, `source_id`, and `filter_data`), and clarify in the same section
that `filter_data` is passed through unchanged and must point to
device-accessible memory when dereferenced. Keep the existing references to
`source_index_t` and CAGRA’s `uint32_t` mapping so users can implement the UDF
correctly.
📒 Files selected for processing (8)
- c/src/cluster/kmeans.cpp
- cpp/include/cuvs/cluster/kmeans.hpp
- cpp/src/cluster/detail/kmeans.cuh
- cpp/src/cluster/detail/kmeans_balanced.cuh
- cpp/src/cluster/kmeans.cuh
- cpp/src/cluster/kmeans_cluster_cost.cu
- fern/pages/cpp_api/cpp-api-cluster-kmeans.md
- fern/pages/cpp_api/cpp-api-neighbors-common.md
🚧 Files skipped from review as they are similar to previous changes (1)
- cpp/src/cluster/detail/kmeans.cuh
 * @param[in] X_norm Optional precomputed L2 norms of X rows [n_samples].
 *                   When provided, the internal norm computation is skipped.
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
HIGH: X_norm is documented as an L2 norm, but the implementation consumes squared norms.
This parameter is forwarded into the L2NormX path in cpp/src/cluster/kmeans.cuh, whose contract is ||x||^2. Documenting it as a plain “L2 norm” invites callers to pass sqrt(sum(x^2)), which will silently skew the reported cluster cost.
Please make the public Doxygen explicit that X_norm must contain squared row norms. As per coding guidelines, “All public API functions must include complete Doxygen documentation describing parameters, return values, and any side effects.” As per path instructions for cpp/include/cuvs/**/*, “For public C++ API headers, additionally check: Doxygen documentation for all public functions/classes.”
Also applies to: 1559-1560, 1583-1584, 1607-1608
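To make the size of the skew concrete, here is a short derivation. It assumes the cost is evaluated through the usual expanded form of the squared Euclidean distance, which the comment above implies but does not spell out, and it ignores sample weights and any clamping of negative values. The symbols $J$ (true cost) and $\tilde{J}$ (cost obtained when plain norms are passed) are notation introduced here.

```latex
\begin{align*}
\lVert x_i - c_j \rVert^2
  &= \lVert x_i \rVert^2 - 2\,\langle x_i, c_j \rangle + \lVert c_j \rVert^2, \\
J &= \sum_{i=1}^{n} \min_j \lVert x_i - c_j \rVert^2, \\
\tilde{J}
  &= \sum_{i=1}^{n} \min_j \bigl( \lVert x_i \rVert - 2\,\langle x_i, c_j \rangle + \lVert c_j \rVert^2 \bigr)
   = J + \sum_{i=1}^{n} \bigl( \lVert x_i \rVert - \lVert x_i \rVert^2 \bigr).
\end{align*}
```

Because the row-norm term does not depend on $j$, the nearest centroid is unchanged and only the reported cost shifts. For example, a row with $\lVert x_i \rVert = 3$ contributes 3 where 9 is expected, so that sample alone moves the cost by $-6$.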
Sources: Coding guidelines, Path instructions
 void cluster_cost(
   const raft::resources& handle,
   raft::device_matrix_view<const float, int> X,
   raft::device_matrix_view<const float, int> centroids,
-  raft::host_scalar_view<float> cost,
-  std::optional<raft::device_vector_view<const float, int>> sample_weight = std::nullopt);
+  raft::device_scalar_view<float> cost,
+  std::optional<raft::device_vector_view<const float, int>> sample_weight = std::nullopt,
+  std::optional<raft::device_vector_view<const float, int>> X_norm = std::nullopt);
🗄️ Data Integrity & Integration | 🔴 Critical | 🏗️ Heavy lift
CRITICAL: cluster_cost is now source-breaking in the public C++ API.
Switching cost from raft::host_scalar_view to raft::device_scalar_view in cpp/include/cuvs/ removes the old call surface entirely, so existing downstream callers stop compiling with no deprecation path or migration note.
Consider keeping the old host-scalar overloads as deprecated shims for at least one release and documenting the migration. As per coding guidelines, “API changes require deprecation warnings.” As per path instructions for cpp/include/cuvs/**/*, “Breaking changes require deprecation warnings and migration guide updates.”
Also applies to: 1562-1568, 1586-1592, 1610-1616
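One way the suggested deprecated shim could look, sketched on the host with hypothetical stand-in types. `DeviceScalarView`, `HostScalarView`, and the toy sum-of-squares cost are assumptions made so the example compiles without RAFT; a real shim would allocate a device scalar, call the new overload, copy the result back, and synchronize the stream.

```cpp
#include <cassert>

// Hypothetical stand-ins for raft::device_scalar_view<float> and
// raft::host_scalar_view<float>; both just wrap a pointer here.
struct DeviceScalarView { float* ptr; };
struct HostScalarView { float* ptr; };

// New-style overload: the result is written through the device scalar view.
// The "cost" is a toy sum of squares so the sketch stays self-contained.
inline void cluster_cost(const float* x, int n, DeviceScalarView cost)
{
  assert(n >= 0);
  float acc = 0.0f;
  for (int i = 0; i < n; ++i) { acc += x[i] * x[i]; }
  *cost.ptr = acc;
}

// Deprecated shim that keeps the old host-scalar call shape compiling.
[[deprecated("use the device-scalar overload of cluster_cost")]]
inline void cluster_cost(const float* x, int n, HostScalarView cost)
{
  float staging = 0.0f;  // real code would allocate a device scalar here
  cluster_cost(x, n, DeviceScalarView{&staging});
  *cost.ptr = staging;   // real code would copy device -> host and sync the stream
}
```

Existing callers keep compiling and see a deprecation warning that names the replacement, which is the migration path the comment asks for.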
Sources: Coding guidelines, Path instructions
auto d_inertia = raft::make_device_scalar<MathT>(handle, MathT{0});
cuvs::cluster::kmeans::cluster_cost(handle, X_view, centroids_view, d_inertia.view());
raft::copy(handle,
           raft::make_host_scalar_view<MathT>(inertia),
           raft::make_const_mdspan(d_inertia.view()));
raft::resource::sync_stream(handle, stream);
🚀 Performance & Scalability | 🟠 Major | ⚡ Quick win
HIGH: final hierarchical inertia still drops the cached dataset norms.
build_hierarchical() has already materialized dataset_norm above, but this cluster_cost() call does not pass it through, so the final inertia path pays for another full row-norm allocation/recompute over n_rows. That gives back much of the win this PR is targeting on the hierarchical path.
Have you considered forwarding dataset_norm here for the L2-based cases that precompute it?
The source must define a device function named by `function_name` with signature:

Return `true` to allow a source vector to appear in the results and `false` to reject it. UDF dereferences it. CAGRA currently provides `source_index_t` as `uint32_t` in the generated JIT fragment.
📐 Maintainability & Code Quality | 🟠 Major | ⚡ Quick win
Document the exact UDF predicate signature where it is referenced.
Line 328 says “with signature:” but the signature itself is missing, and Line 330 is unclear (“UDF dereferences it.”). This makes user UDF implementations error-prone (arg types/order must match exactly).
Suggested doc patch
The source must define a device function named by `function_name` with signature:
+```cpp
+__device__ bool cuvs_filter_udf(uint32_t query_id, source_index_t source_id, void* filter_data);
+```
-Return `true` to allow a source vector to appear in the results and `false` to reject it. UDF dereferences it. CAGRA currently provides `source_index_t` as `uint32_t` in the generated JIT fragment.
+Return `true` to allow a source vector to appear in the results and `false` to reject it. `filter_data` is passed through unchanged and must point to device-accessible memory when dereferenced. CAGRA currently provides `source_index_t` as `uint32_t` in the generated JIT fragment.
Dataset norms are precomputed and cached during the kmeans training iterations. We need not recompute them during the final inertia computation.
Closes #2057
This also avoids a stream synchronization for every batch in the out-of-core (OOC) setting.
This changes the C++ API in the public header, so it is a breaking change.
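The idea can be sketched on the host as follows. This is an illustration under stated assumptions, not the PR's implementation: `Matrix`, `squared_row_norms`, and `inertia_from_cached_norms` are hypothetical names, the data lives in plain `std::vector` storage rather than device memory, and the expanded squared-distance form is assumed.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Row-major matrix in a flat vector; a hypothetical host stand-in for device data.
struct Matrix {
  std::vector<float> values;
  std::size_t rows;
  std::size_t cols;
  const float* row(std::size_t i) const { return values.data() + i * cols; }
};

inline float dot(const float* a, const float* b, std::size_t n)
{
  float acc = 0.0f;
  for (std::size_t k = 0; k < n; ++k) { acc += a[k] * b[k]; }
  return acc;
}

// Squared row norms ||row||^2, computed once and then reused.
inline std::vector<float> squared_row_norms(const Matrix& m)
{
  std::vector<float> norms(m.rows);
  for (std::size_t i = 0; i < m.rows; ++i) { norms[i] = dot(m.row(i), m.row(i), m.cols); }
  return norms;
}

// Final inertia over batches, reading ||x||^2 from the cache instead of
// recomputing it for every batch.
inline float inertia_from_cached_norms(const Matrix& X,
                                       const Matrix& centroids,
                                       const std::vector<float>& x_sq_norms,
                                       std::size_t batch_size)
{
  assert(x_sq_norms.size() == X.rows && batch_size > 0 && X.cols == centroids.cols);
  const std::vector<float> c_sq_norms = squared_row_norms(centroids);
  float total = 0.0f;
  for (std::size_t start = 0; start < X.rows; start += batch_size) {
    const std::size_t stop = std::min(start + batch_size, X.rows);
    for (std::size_t i = start; i < stop; ++i) {
      float best = std::numeric_limits<float>::max();
      for (std::size_t j = 0; j < centroids.rows; ++j) {
        const float d =
          x_sq_norms[i] - 2.0f * dot(X.row(i), centroids.row(j), X.cols) + c_sq_norms[j];
        best = std::min(best, d);
      }
      total += std::max(best, 0.0f);  // clamp tiny negatives caused by rounding
    }
  }
  return total;
}
```

A caller would compute `squared_row_norms(X)` once before the training loop and pass the same vector to every iteration and to the final inertia call.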