ChuxiJ committed
Commit 9ff5c6a · 1 Parent(s): 5f3faee

fix cfg kv block allocate

acestep/llm_inference.py CHANGED
@@ -1957,6 +1957,13 @@ class LLMHandler:
             reset_context()
         except ImportError:
             pass
+        # Also reset the LLM scheduler to release allocated KV cache blocks
+        # This prevents 'deque index out of range' errors from block leaks
+        try:
+            if hasattr(self.llm, 'reset'):
+                self.llm.reset()
+        except Exception:
+            pass  # Ignore errors during cleanup
         # Clear CUDA cache to release any corrupted memory
         if torch.cuda.is_available():
            torch.cuda.empty_cache()
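
Note: the hasattr() guard keeps the handler working whether self.llm is the patched vendored engine below (which gains reset()) or a build without it. A hedged sketch of the behavior this hunk is after; handler, prompts, and params are hypothetical stand-ins, not LLMHandler's actual API:

    # Hypothetical retry flow; only the reset-on-error behavior is from this commit.
    try:
        first = handler.generate(prompts, params)    # raises mid-generation
    except RuntimeError:
        # The except-path above has already run: scheduler reset, CUDA cache
        # cleared, leaked KV blocks returned to the free pool.
        pass

    # A retry now starts from a clean scheduler instead of failing with
    # IndexError: deque index out of range.
    second = handler.generate(prompts, params)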
acestep/third_parts/nano-vllm/nanovllm/engine/llm_engine.py CHANGED
@@ -84,6 +84,24 @@ class LLMEngine:
     def is_finished(self):
         return self.scheduler.is_finished()
 
+    def reset(self):
+        """
+        Reset the scheduler state and release all allocated blocks.
+        This should be called when an exception occurs during generation to prevent
+        KV cache block leaks that can cause 'deque index out of range' errors.
+        """
+        # Deallocate all running sequences
+        while self.scheduler.running:
+            seq = self.scheduler.running.popleft()
+            if seq.block_table:  # Only deallocate if blocks are allocated
+                self.scheduler.block_manager.deallocate(seq)
+
+        # Deallocate all waiting sequences (they might have blocks from preemption)
+        while self.scheduler.waiting:
+            seq = self.scheduler.waiting.popleft()
+            if seq.block_table:
+                self.scheduler.block_manager.deallocate(seq)
+
     def generate(
         self,
         prompts: list[str] | list[list[int]],
@@ -91,6 +109,11 @@
         use_tqdm: bool = True,
         unconditional_prompts: list[str] | list[list[int]] | None = None,
     ) -> list[str]:
+        # Clean up any residual state from previous interrupted generations
+        # This prevents 'deque index out of range' errors from accumulated block leaks
+        if not self.is_finished():
+            self.reset()
+
         if use_tqdm:
             pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
         if not isinstance(sampling_params, list):
@@ -101,24 +124,31 @@
             self.add_request(prompt, sp, uncond_prompt)
         outputs = {}
         prefill_throughput = decode_throughput = 0.
-        while not self.is_finished():
-            t = perf_counter()
-            output, num_tokens = self.step()
-            if use_tqdm:
-                if num_tokens > 0:
-                    prefill_throughput = num_tokens / (perf_counter() - t)
-                else:
-                    decode_throughput = -num_tokens / (perf_counter() - t)
-                pbar.set_postfix({
-                    "Prefill": f"{int(prefill_throughput)}tok/s",
-                    "Decode": f"{int(decode_throughput)}tok/s",
-                })
-            for seq_id, token_ids in output:
-                outputs[seq_id] = token_ids
-                if use_tqdm:
-                    pbar.update(1)
+        try:
+            while not self.is_finished():
+                t = perf_counter()
+                output, num_tokens = self.step()
+                if use_tqdm:
+                    if num_tokens > 0:
+                        prefill_throughput = num_tokens / (perf_counter() - t)
+                    else:
+                        decode_throughput = -num_tokens / (perf_counter() - t)
+                    pbar.set_postfix({
+                        "Prefill": f"{int(prefill_throughput)}tok/s",
+                        "Decode": f"{int(decode_throughput)}tok/s",
+                    })
+                for seq_id, token_ids in output:
+                    outputs[seq_id] = token_ids
+                    if use_tqdm:
+                        pbar.update(1)
+        except Exception:
+            # Clean up on exception to prevent block leaks
+            self.reset()
+            raise
+        finally:
+            if use_tqdm:
+                pbar.close()
+
         outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())]
         outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
-        if use_tqdm:
-            pbar.close()
        return outputs
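
Why the error reads 'deque index out of range': the scheduler's block manager hands out KV-cache blocks from a deque of free block ids (free_block_ids, referenced below in scheduler.py), and indexing an empty deque raises exactly that IndexError. A toy reproduction of the leak reset() clears, with an illustrative pool size rather than the engine's real KV-cache configuration:

    from collections import deque

    free_block_ids = deque(range(4))     # pretend the KV cache has 4 blocks

    def allocate_block() -> int:
        block_id = free_block_ids[0]     # empty deque -> IndexError: deque index out of range
        free_block_ids.popleft()
        return block_id

    # A generation that crashes mid-flight never gives its blocks back:
    leaked = [allocate_block() for _ in range(4)]

    # Without reset(), the next generate() dies on its first allocation:
    try:
        allocate_block()
    except IndexError as exc:
        print(exc)                       # deque index out of range

reset() drains the waiting queue as well as the running one because a preempted sequence can sit in waiting while still holding a block table, and the defensive is_finished() check at the top of generate() catches any leak the exception path missed.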
acestep/third_parts/nano-vllm/nanovllm/engine/scheduler.py CHANGED
@@ -41,8 +41,12 @@ class Scheduler:
 
             # Calculate tokens for both sequences
             total_tokens = (len(seq) - seq.num_cached_tokens) + (len(paired_seq) - paired_seq.num_cached_tokens)
-            can_allocate_both = (self.block_manager.can_allocate(seq) and
-                                 self.block_manager.can_allocate(paired_seq))
+
+            # FIX: Check if we have enough blocks for BOTH sequences combined
+            # The old check was wrong: it checked each sequence independently,
+            # but didn't account for the total blocks needed by both
+            total_blocks_needed = seq.num_blocks + paired_seq.num_blocks
+            can_allocate_both = len(self.block_manager.free_block_ids) >= total_blocks_needed
 
             if num_batched_tokens + total_tokens > self.max_num_batched_tokens or not can_allocate_both:
                break
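
Worked example of the under-count, with illustrative numbers: give the pool 5 free blocks and a CFG pair whose twins need 3 blocks each. The old pairwise check passes (5 >= 3, twice), the first twin's allocation then drains the pool to 2, and the second twin's allocation is where the deque error can surface:

    free_blocks = 5
    seq_blocks = paired_blocks = 3   # each CFG twin needs 3 KV blocks

    # Old check: each sequence tested against the pool independently.
    old_ok = free_blocks >= seq_blocks and free_blocks >= paired_blocks   # True

    # New check: the pair is admitted only if its combined demand fits.
    new_ok = free_blocks >= seq_blocks + paired_blocks                    # False, 6 > 5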
 
@@ -101,9 +105,13 @@
             # Remove paired_seq from temp_running
             temp_running.remove(paired_seq)
 
-            # Check if both can append
-            can_append_both = (self.block_manager.can_append(seq) and
-                               self.block_manager.can_append(paired_seq))
+            # FIX: Check if we have enough blocks for BOTH sequences to append
+            # Each sequence needs 1 block when at block boundary (len % block_size == 1)
+            block_size = self.block_manager.block_size
+            blocks_needed_seq = 1 if len(seq) % block_size == 1 else 0
+            blocks_needed_paired = 1 if len(paired_seq) % block_size == 1 else 0
+            total_blocks_needed = blocks_needed_seq + blocks_needed_paired
+            can_append_both = len(self.block_manager.free_block_ids) >= total_blocks_needed
 
             if not can_append_both:
                # Try preempting other sequences
 
@@ -112,8 +120,8 @@
                     other_seq = temp_running.pop(0)
                     if other_seq != seq and other_seq != paired_seq:
                         self.preempt(other_seq)
-                        can_append_both = (self.block_manager.can_append(seq) and
-                                           self.block_manager.can_append(paired_seq))
+                        # Recalculate with the same correct logic
+                        can_append_both = len(self.block_manager.free_block_ids) >= total_blocks_needed
                         preempted = True
                     else:
                        temp_running.append(other_seq)
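
The per-sequence rule inlined above matches what the removed can_append() calls tested one sequence at a time: during decode a sequence needs one fresh block exactly when len % block_size == 1, i.e. its newest token just started a new block. A small illustration with an arbitrary block_size:

    block_size = 256

    for seq_len in (255, 256, 257, 258):
        blocks_needed = 1 if seq_len % block_size == 1 else 0
        print(seq_len, blocks_needed)   # 255 -> 0, 256 -> 0, 257 -> 1, 258 -> 0

The recalculation after preempt() can reuse total_blocks_needed unchanged: preemption only grows the free pool, while both sequence lengths, and hence their block demand, stay fixed within this scheduling pass.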