[Bugfix][Model] Jamba assertions and no chunked prefill by default for Jamba (#6784)

This commit is contained in:
tomeras91 2024-07-27 06:45:31 +03:00 committed by GitHub
parent 3c3012398e
commit ed94e4f427
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 1 deletion

View File

@@ -754,10 +754,14 @@ class EngineArgs:
     use_sliding_window = (model_config.get_sliding_window()
                           is not None)
     use_spec_decode = self.speculative_model is not None
+    has_seqlen_agnostic_layers = (
+        model_config.contains_seqlen_agnostic_layers(
+            parallel_config))
     if (is_gpu and not use_sliding_window and not use_spec_decode
             and not self.enable_lora
             and not self.enable_prompt_adapter
-            and not self.enable_prefix_caching):
+            and not self.enable_prefix_caching
+            and not has_seqlen_agnostic_layers):
         self.enable_chunked_prefill = True
         logger.warning(
             "Chunked prefill is enabled by default for models with "

View File

@@ -644,6 +644,11 @@ class JambaForCausalLM(nn.Module, HasInnerState):
         lora_config: Optional[LoRAConfig] = None,
         scheduler_config: Optional[SchedulerConfig] = None,
     ) -> None:
+        assert not scheduler_config.chunked_prefill_enabled, \
+            "Jamba currently does not support chunked prefill"
+        assert not cache_config.enable_prefix_caching, \
+            "Jamba currently does not support prefix caching"
         super().__init__()
         self.config = config
         self.scheduler_config = scheduler_config