diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py
index 5b165b7..0293623 100644
--- a/flash_attn/models/gpt.py
+++ b/flash_attn/models/gpt.py
@@ -377,13 +377,18 @@ class GPTModel(GPTPreTrainedModel):
         else:
             # Set prenorm=False here since we don't need the residual
             if not self.parallel_block:
-                hidden_states = dropout_add_layer_norm(
+                fused_add_norm_fn = (dropout_add_rms_norm if isinstance(self.ln_f, RMSNorm)
+                                     else dropout_add_layer_norm)
+                hidden_states = fused_add_norm_fn(
                     hidden_states, residual, self.ln_f.weight, self.ln_f.bias,
                     self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False,
                     residual_in_fp32=self.residual_in_fp32
                 )
             else:
-                hidden_states, _ = dropout_add_layer_norm_parallel_residual(
+                fused_add_norm_fn = (dropout_add_rms_norm_parallel_residual
+                                     if isinstance(self.ln_f, RMSNorm)
+                                     else dropout_add_layer_norm_parallel_residual)
+                hidden_states, _ = fused_add_norm_fn(
                     hidden_states, hidden_states2, residual, self.ln_f.weight, self.ln_f.bias,
                     None, None, self.drop_f.p if self.training else 0.0, self.ln_f.eps,
                     prenorm=False, residual_in_fp32=self.residual_in_fp32
diff --git a/tests/models/test_llama.py b/tests/models/test_llama.py
index 733ee59..f798122 100644
--- a/tests/models/test_llama.py
+++ b/tests/models/test_llama.py
@@ -176,13 +176,13 @@ def test_llama_parallel(model_name, world_size):
     print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
     print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
     print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
-    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
+    assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
 
     print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
     print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
     print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
     print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
-    assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()
+    assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()
 
 
 @pytest.mark.parametrize('model_name', ["7B"])
@@ -267,11 +267,10 @@ def test_llama_generation(model_name):
     del model
 
     hf_error = (logits_hf - logits_ref).abs().max().item()
-    # For some reason logits_parallel is off by quite a bit more than 2x
-    assert (logits_parallel - logits_ref).abs().max().item() < 8 * hf_error
+    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
 
     print(f'HF fp16 logits max diff: {hf_error}')
-    print(f'Logits max diff: {(logits - logits_parallel).abs().max().item() }')
-    assert (logits - logits_parallel).abs().max().item() < 2 * hf_error
-    print(f'Logits CG max diff: {(logits_cg - logits_parallel).abs().max().item() }')
+    print(f'Logits max diff: {(logits - logits_ref).abs().max().item() }')
+    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
+    print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }')
     assert torch.equal(logits_cg, logits)
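
Note on the gpt.py hunk: the final-norm path now picks the fused add+norm kernel from the type of `self.ln_f`, so LLaMA-style models (RMSNorm instead of LayerNorm) route through the `dropout_add_rms_norm` variants while GPT-style models keep using `dropout_add_layer_norm`. Below is a minimal, self-contained sketch of that dispatch, assuming plain PyTorch reference ops stand in for the fused kernels; `SimpleRMSNorm`, `ref_add_layer_norm`, and `ref_add_rms_norm` are hypothetical helpers for illustration only, not part of flash_attn.

```python
# Hypothetical sketch of the isinstance-based dispatch from the gpt.py hunk,
# with unfused PyTorch stand-ins in place of the fused dropout_add_*_norm kernels.
import torch
import torch.nn as nn


class SimpleRMSNorm(nn.Module):
    """Stand-in for an RMSNorm layer (no mean subtraction, no bias)."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.bias = None
        self.eps = eps

    def forward(self, x):
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.weight


def ref_add_layer_norm(x, residual, weight, bias, eps):
    # Unfused reference: add residual, then LayerNorm.
    x = x + residual
    return torch.nn.functional.layer_norm(x, (x.shape[-1],), weight, bias, eps)


def ref_add_rms_norm(x, residual, weight, eps):
    # Unfused reference: add residual, then RMSNorm.
    x = x + residual
    rms = x.pow(2).mean(dim=-1, keepdim=True).add(eps).rsqrt()
    return x * rms * weight


def final_norm(hidden_states, residual, ln_f):
    """Select the add+norm routine from the type of ln_f, mirroring fused_add_norm_fn."""
    if isinstance(ln_f, SimpleRMSNorm):
        return ref_add_rms_norm(hidden_states, residual, ln_f.weight, ln_f.eps)
    else:
        return ref_add_layer_norm(hidden_states, residual, ln_f.weight, ln_f.bias, ln_f.eps)


if __name__ == "__main__":
    x, res = torch.randn(2, 4, 16), torch.randn(2, 4, 16)
    print(final_norm(x, res, nn.LayerNorm(16)).shape)    # LayerNorm path
    print(final_norm(x, res, SimpleRMSNorm(16)).shape)   # RMSNorm path
```

Keying the selection on the norm's type keeps the surrounding call sites identical for both norm flavors, which is what lets the same forward path serve GPT-style and LLaMA-style checkpoints.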