Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

- feat(example): support chained NextN heads for server MTP drafting
- feat: update llama.cpp to ggml-org/llama.cpp@92e854ab8
- fix: preserve recurrent/hybrid model state when the full prompt is already cached by @allthatido and @abetlen in #2306

Expand Down
131 changes: 126 additions & 5 deletions examples/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1288,7 +1288,14 @@ def __init__(
raise RuntimeError("failed to create MTP draft context")
ctx_other = llama_cpp_ext.llama_get_ctx_other(self.ctx)
self.is_mem_shared = bool(ctx_other and ctx_other == self.target_ctx)
self.sampled_batch_draft = not self.is_mem_shared
self.n_mtp_layers = max(
1,
int(llama_cpp.llama_model_n_layer_nextn(self.model)),
)
self.chain_heads = self.n_mtp_layers > 1 and not self.is_mem_shared
if self.chain_heads:
self.num_pred_tokens = min(self.num_pred_tokens, self.n_mtp_layers)
self.sampled_batch_draft = not self.is_mem_shared and not self.chain_heads
self.n_batch = int(llama_cpp.llama_n_batch(self.ctx))
mem = llama_cpp.llama_get_memory(self.ctx)
if mem is None:
Expand Down Expand Up @@ -1451,6 +1458,17 @@ def _try_decode_batch(self) -> bool:
return False
return True

def _set_nextn_layer_offset(self, offset: int) -> None:
    """Forward a NextN layer offset to the draft context when chaining heads.

    The call is a no-op unless ``self.chain_heads`` is truthy, so callers
    may invoke it unconditionally.

    Args:
        offset: Value passed straight through to
            ``llama_cpp_ext.llama_set_nextn_layer_offset`` for ``self.ctx``.
            NOTE(review): presumably the index of the MTP head to decode
            with, where 0 is the default head -- confirm against the
            extension's documentation.
    """
    if not self.chain_heads:
        return
    llama_cpp_ext.llama_set_nextn_layer_offset(self.ctx, offset)

def _try_decode_batch_for_mtp_head(self, head: int) -> bool:
    """Run ``_try_decode_batch`` with the NextN layer offset set to ``head``.

    The offset is put back to 0 afterwards, whether the decode returns
    normally or raises.

    Args:
        head: Layer offset to select for the duration of this decode.

    Returns:
        Whatever ``_try_decode_batch`` returns.
    """
    self._set_nextn_layer_offset(head)
    try:
        decoded = self._try_decode_batch()
    finally:
        # Never leave a non-default head selected on the shared context.
        self._set_nextn_layer_offset(0)
    return decoded

def _decode_batch(self) -> None:
n_tokens = int(self.batch.n_tokens)
if n_tokens <= 0:
Expand All @@ -1464,6 +1482,22 @@ def _decode_batch(self) -> None:
self.decode_failures_total += 1
raise RuntimeError(f"MTP draft decode failed with code {result}")

def _decode_batch_for_mtp_heads(
    self,
    start_pos_by_seq: Dict[int, int],
) -> None:
    """Decode the pending batch, once per MTP head when heads are chained.

    Without chained heads this is a single ``_decode_batch`` call. With
    chained heads the same batch is decoded ``self.n_mtp_layers`` times,
    one layer offset per pass; before every pass each sequence's memory is
    removed from its start position onward so the batch positions can be
    decoded again.

    Args:
        start_pos_by_seq: Maps a sequence id to the first position this
            batch added for that sequence.

    Raises:
        Propagates whatever ``_decode_batch`` raises. The layer offset is
        reset to 0 before the exception leaves this method.
    """
    if self.chain_heads:
        try:
            for layer_offset in range(self.n_mtp_layers):
                # Rewind every touched sequence before re-decoding the batch.
                for sequence, rewind_pos in start_pos_by_seq.items():
                    llama_cpp.llama_memory_seq_rm(
                        self.mem, sequence, rewind_pos, -1
                    )
                self._set_nextn_layer_offset(layer_offset)
                self._decode_batch()
        finally:
            self._set_nextn_layer_offset(0)
    else:
        self._decode_batch()

def metric_definitions(
self,
) -> List[Tuple[str, str, str, Union[int, float]]]:
Expand Down Expand Up @@ -1577,7 +1611,8 @@ def _process_rows(
target_rows_by_seq: Dict[int, List[int]],
aligned_by_seq: Dict[int, bool],
) -> None:
added_pos_by_seq: Dict[int, int] = {}
added_start_pos_by_seq: Dict[int, int] = {}
added_end_pos_by_seq: Dict[int, int] = {}
self._clear_batch()
for index in range(start, end):
if int(batch.n_seq_id[index]) != 1:
Expand Down Expand Up @@ -1609,15 +1644,85 @@ def _process_rows(
self._set_batch_embedding_row(slot, self.pending_h[seq_id])
else:
self._set_batch_embedding_row(slot, h_tgt_rows[previous_row])
added_pos_by_seq[seq_id] = pos
added_start_pos_by_seq.setdefault(seq_id, pos)
added_end_pos_by_seq[seq_id] = pos
previous_row_by_seq[seq_id] = index
target_rows_by_seq.setdefault(seq_id, []).append(index)

if int(self.batch.n_tokens) > 0:
self._decode_batch()
for seq_id, pos in added_pos_by_seq.items():
self._decode_batch_for_mtp_heads(added_start_pos_by_seq)
for seq_id, pos in added_end_pos_by_seq.items():
self.context_pos[seq_id] = max(self.context_pos[seq_id], pos + 1)

def _draft_chain_heads(
    self,
    *,
    seq_id: int,
    first_pos: int,
    token: int,
    n_predict: int,
) -> np.ndarray:
    """Draft up to ``n_predict`` tokens by chaining the MTP heads in order.

    Pass ``k`` decodes the growing chain (the seed token plus every token
    drafted so far, each paired with a hidden-state row) using head ``k``
    and samples one more token from the last row. At most
    ``min(n_predict, self.n_mtp_layers)`` passes are run.

    Args:
        seq_id: Sequence to draft for.
        first_pos: Position at which the seed token is placed.
        token: Seed token that starts the chain.
        n_predict: Upper bound on the number of drafted tokens.

    Returns:
        A 1-D ``np.intc`` array of drafted tokens. It is empty when the
        draft context is behind ``first_pos`` (in which case
        ``self.ready[seq_id]`` is also cleared) or when the first pass
        fails to decode or sample.
    """
    # Bring the draft context back to first_pos, or give up if it lags.
    if self.context_pos[seq_id] > first_pos:
        self.truncate(seq_id, first_pos)
    if self.context_pos[seq_id] < first_pos:
        self.ready[seq_id] = False
        return np.array([], dtype=np.intc)

    proposals: List[int] = []
    prefix_tokens = [token]
    prefix_hidden = [self.pending_h[seq_id].copy()]
    self._reset_sampler(seq_id)

    try:
        for head in range(min(n_predict, self.n_mtp_layers)):
            # Each pass re-decodes the whole chain from first_pos.
            llama_cpp.llama_memory_seq_rm(self.mem, seq_id, first_pos, -1)
            self._clear_batch()

            last_offset = len(prefix_tokens) - 1
            for offset, (prefix_token, hidden) in enumerate(
                zip(prefix_tokens, prefix_hidden)
            ):
                slot = int(self.batch.n_tokens)
                self._add_batch_token(
                    token=prefix_token,
                    pos=first_pos + offset,
                    seq_id=seq_id,
                    # Only the final chain entry needs an output row.
                    logits=(offset == last_offset),
                )
                self._set_batch_embedding_row(slot, hidden)

            output_index = int(self.batch.n_tokens) - 1
            if not self._try_decode_batch_for_mtp_head(head):
                break

            chain_end = first_pos + len(prefix_tokens)
            self.context_pos[seq_id] = max(self.context_pos[seq_id], chain_end)

            candidate = self._sample_token(output_index, seq_id=seq_id)
            if candidate is None:
                break
            proposals.append(candidate)
            if len(proposals) >= n_predict:
                # Quota reached; the next hidden state is not needed.
                break

            hidden_ptr = llama_cpp_ext.llama_get_embeddings_nextn_ith(
                self.ctx,
                output_index,
            )
            if not hidden_ptr:
                break
            # Copy now: the pointer refers to memory owned by the context.
            # NOTE(review): presumably overwritten by the next decode -- confirm.
            next_hidden = np.ctypeslib.as_array(
                hidden_ptr,
                shape=(self.n_embd,),
            ).copy()
            prefix_tokens.append(candidate)
            prefix_hidden.append(next_hidden)
    finally:
        # Restore the default head and drop everything drafted speculatively.
        self._set_nextn_layer_offset(0)
        self.truncate(seq_id, first_pos)

    # An empty list converts to the same empty intc array as np.array([]).
    return np.asarray(proposals, dtype=np.intc)

def draft(
self,
input_ids: np.ndarray,
Expand Down Expand Up @@ -1649,6 +1754,13 @@ def draft(

token = int(input_ids[-1])
drafted: List[int] = []
if self.chain_heads:
return self._draft_chain_heads(
seq_id=seq_id,
first_pos=first_pos,
token=token,
n_predict=n_predict,
)
if not self.is_mem_shared and self.context_pos[seq_id] > first_pos:
self.truncate(seq_id, first_pos)
if not self.is_mem_shared and self.context_pos[seq_id] < first_pos:
Expand Down Expand Up @@ -1709,6 +1821,15 @@ def draft_many(
/,
) -> List[np.ndarray]:
results = [np.array([], dtype=np.intc) for _ in requests]
if self.chain_heads:
for result_index, (input_ids, seq_id, max_tokens) in enumerate(requests):
results[result_index] = self.draft(
input_ids,
seq_id=seq_id,
max_tokens=max_tokens,
)
return results

active: List["MTPDraftProvider.DraftManyState"] = []
for result_index, (input_ids, seq_id, max_tokens) in enumerate(requests):
if (
Expand Down
Loading