[Bugfix] Command-R Max Model Length (#3727)
This commit is contained in:
parent
f510395bbf
commit
97356f3c7e
@ -765,15 +765,20 @@ def _get_and_verify_max_len(
|
||||
"max_seq_len",
|
||||
# ChatGLM2
|
||||
"seq_length",
|
||||
# Command-R
|
||||
"model_max_length",
|
||||
# Others
|
||||
"max_sequence_length",
|
||||
"max_seq_length",
|
||||
"seq_len",
|
||||
]
|
||||
max_len_key = None
|
||||
for key in possible_keys:
|
||||
max_len_key = getattr(hf_config, key, None)
|
||||
if max_len_key is not None:
|
||||
derived_max_model_len = min(derived_max_model_len, max_len_key)
|
||||
max_len = getattr(hf_config, key, None)
|
||||
if max_len is not None:
|
||||
max_len_key = key if max_len < derived_max_model_len \
|
||||
else max_len_key
|
||||
derived_max_model_len = min(derived_max_model_len, max_len)
|
||||
if derived_max_model_len == float("inf"):
|
||||
if max_model_len is not None:
|
||||
# If max_model_len is specified, we use it.
|
||||
@ -799,10 +804,18 @@ def _get_and_verify_max_len(
|
||||
if max_model_len is None:
|
||||
max_model_len = derived_max_model_len
|
||||
elif max_model_len > derived_max_model_len:
|
||||
raise ValueError(
|
||||
f"User-specified max_model_len ({max_model_len}) is greater than "
|
||||
f"the derived max_model_len ({max_len_key}={derived_max_model_len}"
|
||||
" in model's config.json). This may lead to incorrect model "
|
||||
"outputs or CUDA errors. Make sure the value is correct and "
|
||||
"within the model context size.")
|
||||
# Some models might have a separate key for specifying model_max_length
|
||||
# that will be bigger than derived_max_model_len. We compare user input
|
||||
# with model_max_length and allow this override when it's smaller.
|
||||
model_max_length = getattr(hf_config, "model_max_length", None)
|
||||
if model_max_length is not None and max_model_len <= model_max_length:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
f"User-specified max_model_len ({max_model_len}) is greater "
|
||||
"than the derived max_model_len "
|
||||
f"({max_len_key}={derived_max_model_len} or model_max_length="
|
||||
f"{model_max_length} in model's config.json). This may lead "
|
||||
"to incorrect model outputs or CUDA errors. Make sure the "
|
||||
"value is correct and within the model context size.")
|
||||
return int(max_model_len)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user