Skip to content

KMeans: Reuse Precomputed Norms for Inertia Computation#2258

Open
tarang-jain wants to merge 6 commits into
NVIDIA:mainfrom
tarang-jain:cluster-cost
Open

KMeans: Reuse Precomputed Norms for Inertia Computation#2258
tarang-jain wants to merge 6 commits into
NVIDIA:mainfrom
tarang-jain:cluster-cost

Conversation

@tarang-jain

@tarang-jain tarang-jain commented Jun 23, 2026

Copy link
Copy Markdown
Contributor

Dataset norms are precomputed and cached during the kmeans training iterations. We need not recompute them during the final inertia computation.
Closes #2057
This also prevents the sync stream for every batch in the OOC setting.
This changes the C++ API in the public header, so it is a breaking change.

@tarang-jain tarang-jain requested a review from a team as a code owner June 23, 2026 22:52
@tarang-jain tarang-jain added improvement Improves an existing functionality non-breaking Introduces a non-breaking change labels Jun 23, 2026
@coderabbitai

coderabbitai Bot commented Jun 23, 2026

Copy link
Copy Markdown

Review Change Stack

📝 Walkthrough

Walkthrough

cluster_cost now accepts optional precomputed X_norm inputs and uses device scalar cost outputs. kmeans_fit accumulates inertia on device across batches, then copies the final value to host once. Neighbor filtering docs add UDF and udf_filter.

Changes

cluster_cost X_norm and device cost handling

Layer / File(s) Summary
cluster_cost API and norm selection
cpp/include/cuvs/cluster/kmeans.hpp, cpp/src/cluster/kmeans.cuh, cpp/src/cluster/kmeans_cluster_cost.cu, c/src/cluster/kmeans.cpp, fern/pages/cpp_api/cpp-api-cluster-kmeans.md
cluster_cost overloads change cost to device scalar views and add optional X_norm inputs. The implementation reuses provided norms or computes them internally, forwards the selected norms to min_cluster_distance, and the wrappers copy the device result back to host where needed. The API docs are updated to match the new signatures and parameters.
kmeans_fit on-device inertia accumulation
cpp/src/cluster/detail/kmeans.cuh, cpp/src/cluster/detail/kmeans_balanced.cuh
kmeans_fit and the balanced inertia path use temporary device scalars for inertia and batch cost, accumulate batch costs on device, then copy the final inertia to host after synchronizing the stream.

neighbors filtering documentation

Layer / File(s) Summary
FilterType and udf_filter docs
fern/pages/cpp_api/cpp-api-neighbors-common.md
FilterType documentation adds UDF, and a new udf_filter section documents its fields and device predicate behavior.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

  • rapidsai/cuvs#2015: Both PRs modify the streamed KMeans fitting path in cpp/src/cluster/detail/kmeans.cuh, including per-batch cost/inertia handling and norm-related inputs.
  • rapidsai/cuvs#2017: Both PRs change cuvs::cluster::kmeans::cluster_cost output handling and related cost accumulation call sites.

Suggested labels

breaking, cpp

Suggested reviewers

  • lowener
  • cjnolet
🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (2 warnings)

Check name Status Explanation Resolution
Out of Scope Changes check ⚠️ Warning The neighbors::filtering UDF documentation changes are unrelated to the KMeans norm/inertia work. Move the neighbors::filtering documentation updates to a separate PR or remove them from this change.
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (3 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly summarizes the main change: reusing precomputed norms for KMeans inertia.
Linked Issues check ✅ Passed The code changes implement issue #2057 by adding optional X_norm support and moving inertia accumulation to the device.
Description check ✅ Passed The description matches the code changes by describing cached norms, reduced syncs, and the public API break.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@cpp/src/cluster/detail/kmeans.cuh`:
- Around line 963-975: The host data path (in the else block after checking
data_on_device) unconditionally copies from h_norm_cache without verifying it
was populated. When max_iter equals zero, the training loop never initializes
h_norm_cache, causing the raft::copy operation to read uninitialized memory and
compute incorrect final inertia. Add a guard condition in the else block to
check if max_iter is zero, and handle this case separately (either by skipping
the norm computation or computing norms on-the-fly) before attempting to copy
from the unpopulated h_norm_cache.

In `@cpp/src/cluster/kmeans.cuh`:
- Around line 404-414: Before using the X_norm parameter's data handle in the
conditional block where X_norm.has_value() is true, add a validation check to
ensure that X_norm->extent(0) equals n_samples. If the sizes do not match, raise
an appropriate error (such as raft::log::warn or throw an exception) to prevent
out-of-bounds device reads. This validation should occur immediately after
confirming X_norm has a value and before calling const_cast to obtain the data
pointer.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 0950519f-ff3a-4b92-8771-aa7c3a2cd615

📥 Commits

Reviewing files that changed from the base of the PR and between 71306ec and 04d00e3.

📒 Files selected for processing (2)
  • cpp/src/cluster/detail/kmeans.cuh
  • cpp/src/cluster/kmeans.cuh

Comment on lines +963 to +975
std::optional<raft::device_vector_view<const DataT, IndexT>> batch_xnorm = std::nullopt;
if (need_compute_norms) {
if constexpr (data_on_device) {
batch_xnorm = raft::make_device_vector_view<const DataT, IndexT>(
L2NormBatch.data_handle() + data_batch.offset(), cur_batch_size);
} else {
raft::copy(L2NormBatch.data_handle(),
h_norm_cache.data_handle() + data_batch.offset(),
cur_batch_size,
stream);
batch_xnorm = raft::make_device_vector_view<const DataT, IndexT>(
L2NormBatch.data_handle(), cur_batch_size);
}

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

HIGH: Guard the uncached host-norm path for zero-iteration fits.

Issue: For host data, this final inertia path always copies from h_norm_cache; if max_iter == 0, the training loop never populated it, so inertia is computed from uninitialized norms.
Why: This returns incorrect final inertia for a valid-looking parameter combination because max_iter is not rejected earlier.

Suggested fix
         if (need_compute_norms) {
           if constexpr (data_on_device) {
             batch_xnorm = raft::make_device_vector_view<const DataT, IndexT>(
               L2NormBatch.data_handle() + data_batch.offset(), cur_batch_size);
           } else {
-            raft::copy(L2NormBatch.data_handle(),
-                       h_norm_cache.data_handle() + data_batch.offset(),
-                       cur_batch_size,
-                       stream);
+            if (norms_cached) {
+              raft::copy(L2NormBatch.data_handle(),
+                         h_norm_cache.data_handle() + data_batch.offset(),
+                         cur_batch_size,
+                         stream);
+            } else {
+              compute_batch_norms(data_batch.data(), cur_batch_size);
+            }
             batch_xnorm = raft::make_device_vector_view<const DataT, IndexT>(
               L2NormBatch.data_handle(), cur_batch_size);
           }
         }
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@cpp/src/cluster/detail/kmeans.cuh` around lines 963 - 975, The host data path
(in the else block after checking data_on_device) unconditionally copies from
h_norm_cache without verifying it was populated. When max_iter equals zero, the
training loop never initializes h_norm_cache, causing the raft::copy operation
to read uninitialized memory and compute incorrect final inertia. Add a guard
condition in the else block to check if max_iter is zero, and handle this case
separately (either by skipping the norm computation or computing norms
on-the-fly) before attempting to copy from the unpopulated h_norm_cache.

Comment thread cpp/src/cluster/kmeans.cuh
@tarang-jain tarang-jain requested review from a team as code owners June 24, 2026 14:26

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@cpp/include/cuvs/cluster/kmeans.hpp`:
- Around line 1535-1536: Update the public Doxygen for the kmeans API so
`X_norm` is documented as squared row norms (`||x||^2`), not plain L2 norms. Fix
the parameter descriptions in the affected overloads of the kmeans public header
(the same `X_norm` docs repeated across the API surface) to clearly state that
callers must pass precomputed squared norms and that the internal norm
computation is skipped when provided.
- Around line 1538-1544: `cluster_cost` in the public kmeans API is
source-breaking because `cost` was changed from a host scalar to a device
scalar, so restore compatibility by keeping the existing host-scalar overloads
as deprecated shims alongside the new `raft::device_scalar_view` overload in
`cluster_cost` and the related overloads at the referenced locations. Update the
overload set in `cuvs::cluster::kmeans` so downstream callers can still compile,
emit deprecation warnings from the old signatures, and add the required
migration/deprecation note in the public API docs for the affected functions.

In `@cpp/src/cluster/detail/kmeans_balanced.cuh`:
- Around line 1142-1147: The final inertia computation in build_hierarchical is
still recomputing row norms instead of reusing the cached dataset_norm. Update
the cluster_cost call in this hierarchical path to forward dataset_norm for the
L2 cases that already precompute it, and keep the existing fallback behavior for
other metrics or paths. Use the build_hierarchical and cluster_cost symbols to
locate the call site and thread the cached norm through the final inertia
calculation.

In `@fern/pages/cpp_api/cpp-api-neighbors-common.md`:
- Around line 328-330: The UDF predicate documentation is missing the exact
device-function signature and leaves the argument semantics ambiguous. Update
the text around the `function_name` reference to explicitly show the full
`__device__` predicate signature, including the argument order and types
(`query_id`, `source_id`, and `filter_data`), and clarify in the same section
that `filter_data` is passed through unchanged and must point to
device-accessible memory when dereferenced. Keep the existing references to
`source_index_t` and CAGRA’s `uint32_t` mapping so users can implement the UDF
correctly.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 0ec9614c-2978-4f81-a575-deba2396b285

📥 Commits

Reviewing files that changed from the base of the PR and between 04d00e3 and 4c98eca.

📒 Files selected for processing (8)
  • c/src/cluster/kmeans.cpp
  • cpp/include/cuvs/cluster/kmeans.hpp
  • cpp/src/cluster/detail/kmeans.cuh
  • cpp/src/cluster/detail/kmeans_balanced.cuh
  • cpp/src/cluster/kmeans.cuh
  • cpp/src/cluster/kmeans_cluster_cost.cu
  • fern/pages/cpp_api/cpp-api-cluster-kmeans.md
  • fern/pages/cpp_api/cpp-api-neighbors-common.md
🚧 Files skipped from review as they are similar to previous changes (1)
  • cpp/src/cluster/detail/kmeans.cuh

Comment on lines +1535 to +1536
* @param[in] X_norm Optional precomputed L2 norms of X rows [n_samples].
* When provided, the internal norm computation is skipped.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

HIGH: X_norm is documented as an L2 norm, but the implementation consumes squared norms.

This parameter is forwarded into the L2NormX path in cpp/src/cluster/kmeans.cuh, whose contract is ||x||^2. Documenting it as a plain “L2 norm” invites callers to pass sqrt(sum(x^2)), which will silently skew the reported cluster cost.

Please make the public Doxygen explicit that X_norm must contain squared row norms. As per coding guidelines, “All public API functions must include complete Doxygen documentation describing parameters, return values, and any side effects.” As per path instructions for cpp/include/cuvs/**/*, “For public C++ API headers, additionally check: Doxygen documentation for all public functions/classes.”

Also applies to: 1559-1560, 1583-1584, 1607-1608

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@cpp/include/cuvs/cluster/kmeans.hpp` around lines 1535 - 1536, Update the
public Doxygen for the kmeans API so `X_norm` is documented as squared row norms
(`||x||^2`), not plain L2 norms. Fix the parameter descriptions in the affected
overloads of the kmeans public header (the same `X_norm` docs repeated across
the API surface) to clearly state that callers must pass precomputed squared
norms and that the internal norm computation is skipped when provided.

Sources: Coding guidelines, Path instructions

Comment on lines 1538 to +1544
void cluster_cost(
const raft::resources& handle,
raft::device_matrix_view<const float, int> X,
raft::device_matrix_view<const float, int> centroids,
raft::host_scalar_view<float> cost,
std::optional<raft::device_vector_view<const float, int>> sample_weight = std::nullopt);
raft::device_scalar_view<float> cost,
std::optional<raft::device_vector_view<const float, int>> sample_weight = std::nullopt,
std::optional<raft::device_vector_view<const float, int>> X_norm = std::nullopt);

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🗄️ Data Integrity & Integration | 🔴 Critical | 🏗️ Heavy lift

CRITICAL: cluster_cost is now source-breaking in the public C++ API.

Switching cost from raft::host_scalar_view to raft::device_scalar_view in cpp/include/cuvs/ removes the old call surface entirely, so existing downstream callers stop compiling with no deprecation path or migration note.

Consider keeping the old host-scalar overloads as deprecated shims for at least one release and documenting the migration. As per coding guidelines, “API changes require deprecation warnings.” As per path instructions for cpp/include/cuvs/**/*, “Breaking changes require deprecation warnings and migration guide updates.”

Also applies to: 1562-1568, 1586-1592, 1610-1616

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@cpp/include/cuvs/cluster/kmeans.hpp` around lines 1538 - 1544, `cluster_cost`
in the public kmeans API is source-breaking because `cost` was changed from a
host scalar to a device scalar, so restore compatibility by keeping the existing
host-scalar overloads as deprecated shims alongside the new
`raft::device_scalar_view` overload in `cluster_cost` and the related overloads
at the referenced locations. Update the overload set in `cuvs::cluster::kmeans`
so downstream callers can still compile, emit deprecation warnings from the old
signatures, and add the required migration/deprecation note in the public API
docs for the affected functions.

Sources: Coding guidelines, Path instructions

Comment on lines +1142 to +1147
auto d_inertia = raft::make_device_scalar<MathT>(handle, MathT{0});
cuvs::cluster::kmeans::cluster_cost(handle, X_view, centroids_view, d_inertia.view());
raft::copy(handle,
raft::make_host_scalar_view<MathT>(inertia),
raft::make_const_mdspan(d_inertia.view()));
raft::resource::sync_stream(handle, stream);

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚀 Performance & Scalability | 🟠 Major | ⚡ Quick win

HIGH: final hierarchical inertia still drops the cached dataset norms.

build_hierarchical() has already materialized dataset_norm above, but this cluster_cost() call does not pass it through, so the final inertia path pays for another full row-norm allocation/recompute over n_rows. That gives back much of the win this PR is targeting on the hierarchical path.

Have you considered forwarding dataset_norm here for the L2-based cases that precompute it?

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@cpp/src/cluster/detail/kmeans_balanced.cuh` around lines 1142 - 1147, The
final inertia computation in build_hierarchical is still recomputing row norms
instead of reusing the cached dataset_norm. Update the cluster_cost call in this
hierarchical path to forward dataset_norm for the L2 cases that already
precompute it, and keep the existing fallback behavior for other metrics or
paths. Use the build_hierarchical and cluster_cost symbols to locate the call
site and thread the cached norm through the final inertia calculation.

Comment on lines +328 to +330
The source must define a device function named by `function_name` with signature:

Return `true` to allow a source vector to appear in the results and `false` to reject it. UDF dereferences it. CAGRA currently provides `source_index_t` as `uint32_t` in the generated JIT fragment.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📐 Maintainability & Code Quality | 🟠 Major | ⚡ Quick win

Document the exact UDF predicate signature where it is referenced.

Line 328 says “with signature:” but the signature itself is missing, and Line 330 is unclear (“UDF dereferences it.”). This makes user UDF implementations error-prone (arg types/order must match exactly).

Suggested doc patch
 The source must define a device function named by `function_name` with signature:
+```cpp
+__device__ bool cuvs_filter_udf(uint32_t query_id, source_index_t source_id, void* filter_data);
+```
 
-Return `true` to allow a source vector to appear in the results and `false` to reject it. UDF dereferences it. CAGRA currently provides `source_index_t` as `uint32_t` in the generated JIT fragment.
+Return `true` to allow a source vector to appear in the results and `false` to reject it. `filter_data` is passed through unchanged and must point to device-accessible memory when dereferenced. CAGRA currently provides `source_index_t` as `uint32_t` in the generated JIT fragment.
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
The source must define a device function named by `function_name` with signature:
Return `true` to allow a source vector to appear in the results and `false` to reject it. UDF dereferences it. CAGRA currently provides `source_index_t` as `uint32_t` in the generated JIT fragment.
The source must define a device function named by `function_name` with signature:
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@fern/pages/cpp_api/cpp-api-neighbors-common.md` around lines 328 - 330, The
UDF predicate documentation is missing the exact device-function signature and
leaves the argument semantics ambiguous. Update the text around the
`function_name` reference to explicitly show the full `__device__` predicate
signature, including the argument order and types (`query_id`, `source_id`, and
`filter_data`), and clarify in the same section that `filter_data` is passed
through unchanged and must point to device-accessible memory when dereferenced.
Keep the existing references to `source_index_t` and CAGRA’s `uint32_t` mapping
so users can implement the UDF correctly.

@tarang-jain tarang-jain added breaking Introduces a breaking change and removed non-breaking Introduces a non-breaking change labels Jun 24, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

breaking Introduces a breaking change improvement Improves an existing functionality

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[FEA] cluster_cost can accept X norm

1 participant