need to update max_position_embeddings when seq_len is greater (for RoPE)

This commit is contained in:
ferdinand.mom 2024-10-30 15:12:06 +00:00
parent 508d57f948
commit 363dbd5c05


@@ -121,6 +121,7 @@ if __name__ == "__main__":
     model_config.num_hidden_layers = config["model"]["num_hidden_layers"]
     model_config.num_attention_heads = config["model"]["num_attention_heads"]
     model_config.num_key_value_heads = config["model"]["num_key_value_heads"]
+    model_config.max_position_embeddings = SEQ_LEN
     start_time = time.time()
     model = Llama(config=model_config)
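
For context on why the added line matters: RoPE implementations typically precompute a cos/sin table with one row per position, sized by max_position_embeddings, so the table must cover the training sequence length. The sketch below illustrates that precomputation under standard RoPE assumptions; build_rope_cache, the head_dim value, and the SEQ_LEN value here are illustrative, not this repo's actual code.

import torch

def build_rope_cache(max_position_embeddings: int, head_dim: int, base: float = 10000.0):
    # RoPE precomputes one rotation angle per (position, frequency) pair,
    # so the cache has exactly max_position_embeddings rows.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(max_position_embeddings).float()
    angles = torch.outer(positions, inv_freq)  # shape: (max_pos, head_dim // 2)
    return torch.cos(angles), torch.sin(angles)

# If the training sequence length exceeds the cache size, indexing the cos/sin
# tables at positions >= max_position_embeddings fails (or, depending on the
# implementation, silently reuses wrong rows). Hence the commit sets
# model_config.max_position_embeddings = SEQ_LEN before building the model.
SEQ_LEN = 4096  # hypothetical value for illustration
cos, sin = build_rope_cache(max_position_embeddings=SEQ_LEN, head_dim=128)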