diff --git a/CHANGELOG.md b/CHANGELOG.md index 925e941d8..2659bac3f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +- feat(example): support chained NextN heads for server MTP drafting - feat: update llama.cpp to ggml-org/llama.cpp@92e854ab8 - fix: preserve recurrent/hybrid model state when the full prompt is already cached by @allthatido and @abetlen in #2306 diff --git a/examples/server/server.py b/examples/server/server.py index 16f8c9f7e..72adc7905 100644 --- a/examples/server/server.py +++ b/examples/server/server.py @@ -1288,7 +1288,14 @@ def __init__( raise RuntimeError("failed to create MTP draft context") ctx_other = llama_cpp_ext.llama_get_ctx_other(self.ctx) self.is_mem_shared = bool(ctx_other and ctx_other == self.target_ctx) - self.sampled_batch_draft = not self.is_mem_shared + self.n_mtp_layers = max( + 1, + int(llama_cpp.llama_model_n_layer_nextn(self.model)), + ) + self.chain_heads = self.n_mtp_layers > 1 and not self.is_mem_shared + if self.chain_heads: + self.num_pred_tokens = min(self.num_pred_tokens, self.n_mtp_layers) + self.sampled_batch_draft = not self.is_mem_shared and not self.chain_heads self.n_batch = int(llama_cpp.llama_n_batch(self.ctx)) mem = llama_cpp.llama_get_memory(self.ctx) if mem is None: @@ -1451,6 +1458,17 @@ def _try_decode_batch(self) -> bool: return False return True + def _set_nextn_layer_offset(self, offset: int) -> None: + if self.chain_heads: + llama_cpp_ext.llama_set_nextn_layer_offset(self.ctx, offset) + + def _try_decode_batch_for_mtp_head(self, head: int) -> bool: + self._set_nextn_layer_offset(head) + try: + return self._try_decode_batch() + finally: + self._set_nextn_layer_offset(0) + def _decode_batch(self) -> None: n_tokens = int(self.batch.n_tokens) if n_tokens <= 0: @@ -1464,6 +1482,22 @@ def _decode_batch(self) -> None: self.decode_failures_total += 1 raise RuntimeError(f"MTP draft decode failed with code {result}") + def _decode_batch_for_mtp_heads( + self, + start_pos_by_seq: Dict[int, int], + ) -> None: + if not self.chain_heads: + self._decode_batch() + return + try: + for head in range(self.n_mtp_layers): + for seq_id, start_pos in start_pos_by_seq.items(): + llama_cpp.llama_memory_seq_rm(self.mem, seq_id, start_pos, -1) + self._set_nextn_layer_offset(head) + self._decode_batch() + finally: + self._set_nextn_layer_offset(0) + def metric_definitions( self, ) -> List[Tuple[str, str, str, Union[int, float]]]: @@ -1577,7 +1611,8 @@ def _process_rows( target_rows_by_seq: Dict[int, List[int]], aligned_by_seq: Dict[int, bool], ) -> None: - added_pos_by_seq: Dict[int, int] = {} + added_start_pos_by_seq: Dict[int, int] = {} + added_end_pos_by_seq: Dict[int, int] = {} self._clear_batch() for index in range(start, end): if int(batch.n_seq_id[index]) != 1: @@ -1609,15 +1644,85 @@ def _process_rows( self._set_batch_embedding_row(slot, self.pending_h[seq_id]) else: self._set_batch_embedding_row(slot, h_tgt_rows[previous_row]) - added_pos_by_seq[seq_id] = pos + added_start_pos_by_seq.setdefault(seq_id, pos) + added_end_pos_by_seq[seq_id] = pos previous_row_by_seq[seq_id] = index target_rows_by_seq.setdefault(seq_id, []).append(index) if int(self.batch.n_tokens) > 0: - self._decode_batch() - for seq_id, pos in added_pos_by_seq.items(): + self._decode_batch_for_mtp_heads(added_start_pos_by_seq) + for seq_id, pos in added_end_pos_by_seq.items(): self.context_pos[seq_id] = max(self.context_pos[seq_id], pos + 1) + def _draft_chain_heads( + self, + *, + seq_id: int, + first_pos: int, + token: int, + n_predict: int, + ) -> np.ndarray: + if self.context_pos[seq_id] > first_pos: + self.truncate(seq_id, first_pos) + if self.context_pos[seq_id] < first_pos: + self.ready[seq_id] = False + return np.array([], dtype=np.intc) + + drafted: List[int] = [] + chain_tokens = [token] + chain_embeddings = [self.pending_h[seq_id].copy()] + self._reset_sampler(seq_id) + + try: + for head in range(min(n_predict, self.n_mtp_layers)): + llama_cpp.llama_memory_seq_rm(self.mem, seq_id, first_pos, -1) + self._clear_batch() + for offset, (chain_token, embedding) in enumerate( + zip(chain_tokens, chain_embeddings) + ): + slot = int(self.batch.n_tokens) + self._add_batch_token( + token=chain_token, + pos=first_pos + offset, + seq_id=seq_id, + logits=offset == len(chain_tokens) - 1, + ) + self._set_batch_embedding_row(slot, embedding) + + output_index = int(self.batch.n_tokens) - 1 + if not self._try_decode_batch_for_mtp_head(head): + break + self.context_pos[seq_id] = max( + self.context_pos[seq_id], + first_pos + len(chain_tokens), + ) + sampled_token = self._sample_token(output_index, seq_id=seq_id) + if sampled_token is None: + break + drafted.append(sampled_token) + if len(drafted) >= n_predict: + break + h_row = llama_cpp_ext.llama_get_embeddings_nextn_ith( + self.ctx, + output_index, + ) + if not h_row: + break + chain_tokens.append(sampled_token) + chain_embeddings.append( + np.ctypeslib.as_array( + h_row, + shape=(self.n_embd,), + ).copy() + ) + finally: + self._set_nextn_layer_offset(0) + self.truncate(seq_id, first_pos) + + if not drafted: + return np.array([], dtype=np.intc) + return np.asarray(drafted, dtype=np.intc) + def draft( self, input_ids: np.ndarray, @@ -1649,6 +1754,13 @@ def draft( token = int(input_ids[-1]) drafted: List[int] = [] + if self.chain_heads: + return self._draft_chain_heads( + seq_id=seq_id, + first_pos=first_pos, + token=token, + n_predict=n_predict, + ) if not self.is_mem_shared and self.context_pos[seq_id] > first_pos: self.truncate(seq_id, first_pos) if not self.is_mem_shared and self.context_pos[seq_id] < first_pos: @@ -1709,6 +1821,15 @@ def draft_many( /, ) -> List[np.ndarray]: results = [np.array([], dtype=np.intc) for _ in requests] + if self.chain_heads: + for result_index, (input_ids, seq_id, max_tokens) in enumerate(requests): + results[result_index] = self.draft( + input_ids, + seq_id=seq_id, + max_tokens=max_tokens, + ) + return results + active: List["MTPDraftProvider.DraftManyState"] = [] for result_index, (input_ids, seq_id, max_tokens) in enumerate(requests): if (