Switch Product Quantization VQ to mean() when n_centers = 1 #2250

Open
lowener wants to merge 3 commits into NVIDIA:main from lowener:26.08-pq-mean

lowener (Contributor) commented Jun 22, 2026

When running a PQ preprocessing operation, the VQ option with a single center (n_centers = 1) can be used to mean-center the dataset.
This proposed change makes that mean-centering operation faster by calling raft::stats::mean directly instead of running the expectation-maximization steps.
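
To illustrate why a single VQ center amounts to mean-centering, here is a minimal host-side sketch. It is not the code in this PR: it uses `std::vector` instead of RAFT device views, the helper names `column_mean` and `subtract_center` are hypothetical, and it assumes the VQ stage's role is to subtract the assigned center from each row before PQ.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Column-wise mean of a row-major [n_rows x dim] matrix.
// With a single cluster, every row is assigned to that cluster, so this is
// what a k-means update step would produce for the one center.
std::vector<float> column_mean(const std::vector<float>& data, std::size_t n_rows, std::size_t dim)
{
  assert(n_rows > 0 && data.size() == n_rows * dim);
  std::vector<double> acc(dim, 0.0);  // accumulate in double to limit rounding error
  for (std::size_t i = 0; i < n_rows; ++i)
    for (std::size_t j = 0; j < dim; ++j)
      acc[j] += data[i * dim + j];
  std::vector<float> mean(dim);
  for (std::size_t j = 0; j < dim; ++j)
    mean[j] = static_cast<float>(acc[j] / static_cast<double>(n_rows));
  return mean;
}

// Residuals after subtracting the single center from every row
// (hypothetical helper; the residuals are what a later PQ stage would see).
std::vector<float> subtract_center(const std::vector<float>& data, const std::vector<float>& center)
{
  const std::size_t dim = center.size();
  assert(dim > 0 && data.size() % dim == 0);
  std::vector<float> residuals(data.size());
  for (std::size_t idx = 0; idx < data.size(); ++idx)
    residuals[idx] = data[idx] - center[idx % dim];
  return residuals;
}
```

Calling `column_mean` on the output of `subtract_center` should give a vector that is zero up to rounding, which is the mean-centering property described above.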

Signed-off-by: Mickael Ide <mide@nvidia.com>
lowener requested a review from a team as a code owner June 22, 2026 15:04
lowener added the improvement (Improves an existing functionality), non-breaking (Introduces a non-breaking change), and C++ labels Jun 22, 2026
coderabbitai (Bot) commented Jun 22, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 84910738-679d-4aae-9293-797f503afad0

📥 Commits

Reviewing files that changed from the base of the PR and between e091067 and 9b5bd2a.

📒 Files selected for processing (1)
  • cpp/src/neighbors/detail/vpq_dataset.cuh
🚧 Files skipped from review as they are similar to previous changes (1)
  • cpp/src/neighbors/detail/vpq_dataset.cuh

📝 Walkthrough

Summary by CodeRabbit

  • Refactor
    • Improved vector quantization (VQ) center training by adding a faster single-center option (computed directly) and refining the multi-center path to use balanced k-means only when needed.
    • Keeps standard multi-center behavior intact while reducing unnecessary computation, improving training efficiency and execution speed.

Walkthrough

In train_vq, a new header dependency is added for raft::stats::mean. The kmeans_in_type alias is moved earlier. Explicit matrix views for vq_centers and the VQ trainset are constructed. A conditional branch handles vq_n_centers == 1 by computing the single center via raft::stats::mean; the general case continues to call cuvs::cluster::kmeans::fit with L2Expanded and params.kmeans_n_iters.
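
The branch structure described above can be sketched on the host as follows. This is an illustration under stated assumptions, not the contents of vpq_dataset.cuh: `train_centers` and `update_centers` are hypothetical names, plain Lloyd iterations seeded with the first k rows stand in for `cuvs::cluster::kmeans::fit`, and `std::vector` stands in for the matrix views.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

using Matrix = std::vector<float>;  // row-major [rows x dim]

// Recompute each center as the mean of the rows currently labelled with it.
void update_centers(const Matrix& data, const std::vector<std::size_t>& labels,
                    std::size_t n_rows, std::size_t dim, std::size_t k, Matrix& centers)
{
  std::vector<double> sums(k * dim, 0.0);
  std::vector<std::size_t> counts(k, 0);
  for (std::size_t i = 0; i < n_rows; ++i) {
    ++counts[labels[i]];
    for (std::size_t j = 0; j < dim; ++j)
      sums[labels[i] * dim + j] += data[i * dim + j];
  }
  for (std::size_t c = 0; c < k; ++c) {
    if (counts[c] == 0) continue;  // an empty cluster keeps its previous center
    for (std::size_t j = 0; j < dim; ++j)
      centers[c * dim + j] = static_cast<float>(sums[c * dim + j] / static_cast<double>(counts[c]));
  }
}

// Hypothetical stand-in for the center-training step.
Matrix train_centers(const Matrix& data, std::size_t n_rows, std::size_t dim,
                     std::size_t k, int n_iters)
{
  assert(k >= 1 && k <= n_rows && data.size() == n_rows * dim);
  Matrix centers(k * dim, 0.0f);
  std::vector<std::size_t> labels(n_rows, 0);  // initially every row is in cluster 0

  if (k == 1) {
    // Fast path: with one cluster, a single update step is the column-wise mean.
    update_centers(data, labels, n_rows, dim, k, centers);
    return centers;
  }

  // General path: Lloyd iterations, seeded with the first k rows (an assumption).
  std::copy(data.begin(), data.begin() + static_cast<std::ptrdiff_t>(k * dim), centers.begin());
  for (int iter = 0; iter < n_iters; ++iter) {
    for (std::size_t i = 0; i < n_rows; ++i) {
      float best = std::numeric_limits<float>::max();
      for (std::size_t c = 0; c < k; ++c) {
        float dist = 0.0f;  // squared L2 distance from row i to center c
        for (std::size_t j = 0; j < dim; ++j) {
          const float diff = data[i * dim + j] - centers[c * dim + j];
          dist += diff * diff;
        }
        if (dist < best) { best = dist; labels[i] = c; }
      }
    }
    update_centers(data, labels, n_rows, dim, k, centers);
  }
  return centers;
}
```

The point of the sketch is that the `k == 1` branch skips the assignment loop entirely, since there is only one center a row could be assigned to.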

Changes

VQ center training refactor

| Layer / File(s) | Summary |
|---|---|
| Single-center fast path and k-means wiring (cpp/src/neighbors/detail/vpq_dataset.cuh) | Header include for raft::stats::mean is added. kmeans_in_type is defined at the top of train_vq; dedicated views for vq_centers and the trainset are built; a vq_n_centers == 1 branch calls raft::stats::mean to compute the center, while the else branch calls cuvs::cluster::kmeans::fit with L2Expanded metric and params.kmeans_n_iters. |

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~5 minutes

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title accurately describes the main optimization: switching PQ's VQ to use mean() when n_centers equals 1, which directly matches the primary change. |
| Description check | ✅ Passed | The description explains the purpose of the change (faster mean-centering using direct mean call instead of EM steps) and is directly related to the code modifications. |
| Docstring Coverage | ✅ Passed | No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

cjnolet (Contributor) commented Jun 23, 2026

> This proposed change makes that mean-centering operation faster by calling raft::stats::mean directly

Thanks for the PR @lowener. Can you please provide some benchmarks here to demonstrate the performance difference for this change?
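
As a rough idea of the shape such a comparison could take, here is a host-only timing sketch. It does not measure the GPU code paths in this PR and implies nothing about the actual speedup: `time_ms`, `mean_pass`, and `iterated_single_center` are hypothetical names, and the iterated function is only a stand-in for an iterative trainer restricted to one center.

```cpp
#include <cassert>
#include <chrono>
#include <cstddef>
#include <vector>

// Wall-clock time of a callable, in milliseconds.
template <typename Fn>
double time_ms(Fn&& fn)
{
  const auto start = std::chrono::steady_clock::now();
  fn();
  const auto stop = std::chrono::steady_clock::now();
  return std::chrono::duration<double, std::milli>(stop - start).count();
}

// One pass over a row-major [n_rows x dim] matrix that returns the column-wise mean.
std::vector<float> mean_pass(const std::vector<float>& data, std::size_t n_rows, std::size_t dim)
{
  assert(n_rows > 0 && data.size() == n_rows * dim);
  std::vector<double> acc(dim, 0.0);
  for (std::size_t i = 0; i < n_rows; ++i)
    for (std::size_t j = 0; j < dim; ++j)
      acc[j] += data[i * dim + j];
  std::vector<float> mean(dim);
  for (std::size_t j = 0; j < dim; ++j)
    mean[j] = static_cast<float>(acc[j] / static_cast<double>(n_rows));
  return mean;
}

// Stand-in for an iterative trainer with one center: it repeats the same update
// n_iters times, although the result cannot change after the first pass.
std::vector<float> iterated_single_center(const std::vector<float>& data, std::size_t n_rows,
                                          std::size_t dim, int n_iters)
{
  std::vector<float> center(dim, 0.0f);
  for (int iter = 0; iter < n_iters; ++iter)
    center = mean_pass(data, n_rows, dim);
  return center;
}
```

Timing `mean_pass` against `iterated_single_center` with `time_ms` on the same input, and checking that both return the same center, gives a host-side sanity check; real numbers for this PR would need to come from the GPU implementation itself.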
