diff --git a/train.py b/train.py
index c9e99b4..ac8d588 100644
--- a/train.py
+++ b/train.py
@@ -121,6 +121,7 @@ if __name__ == "__main__":
     model_config.num_hidden_layers = config["model"]["num_hidden_layers"]
     model_config.num_attention_heads = config["model"]["num_attention_heads"]
     model_config.num_key_value_heads = config["model"]["num_key_value_heads"]
+    model_config.max_position_embeddings = SEQ_LEN
 
     start_time = time.time()
     model = Llama(config=model_config)