[LLaMa] Fix last norm layer to use RMSNorm instead of LayerNorm
This commit is contained in:
parent ad113948a6
commit a9a4b4e4f2
@@ -377,13 +377,18 @@ class GPTModel(GPTPreTrainedModel):
         else:
             # Set prenorm=False here since we don't need the residual
             if not self.parallel_block:
-                hidden_states = dropout_add_layer_norm(
+                fused_add_norm_fn = (dropout_add_rms_norm if isinstance(self.ln_f, RMSNorm)
+                                     else dropout_add_layer_norm)
+                hidden_states = fused_add_norm_fn(
                     hidden_states, residual, self.ln_f.weight, self.ln_f.bias,
                     self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False,
                     residual_in_fp32=self.residual_in_fp32
                 )
             else:
-                hidden_states, _ = dropout_add_layer_norm_parallel_residual(
+                fused_add_norm_fn = (dropout_add_rms_norm_parallel_residual
+                                     if isinstance(self.ln_f, RMSNorm)
+                                     else dropout_add_layer_norm_parallel_residual)
+                hidden_states, _ = fused_add_norm_fn(
                     hidden_states, hidden_states2, residual, self.ln_f.weight, self.ln_f.bias,
                     None, None, self.drop_f.p if self.training else 0.0, self.ln_f.eps,
                     prenorm=False, residual_in_fp32=self.residual_in_fp32
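For reference, the two fused paths differ only in the normalization formula: LLaMA's final norm is an RMSNorm, which drops the mean-centering and bias term of LayerNorm. Below is a minimal pure-PyTorch sketch of the difference (illustration only; the helper names layer_norm_ref and rms_norm_ref are not from this repo, and the code above calls the fused dropout_add_* kernels instead).

import torch

def layer_norm_ref(x, weight, bias, eps=1e-5):
    # LayerNorm: center by the mean, scale by the standard deviation,
    # then apply an elementwise affine transform with weight and bias.
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    return (x - mean) / torch.sqrt(var + eps) * weight + bias

def rms_norm_ref(x, weight, eps=1e-6):
    # RMSNorm (used by LLaMA): no mean subtraction and no bias, only
    # scaling by the root-mean-square of the activations.
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms * weight

Dispatching on isinstance(self.ln_f, RMSNorm), as the hunk above does, keeps a single code path for both GPT-style (LayerNorm) and LLaMA-style (RMSNorm) final norms.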
@@ -176,13 +176,13 @@ def test_llama_parallel(model_name, world_size):
     print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
     print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
     print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
-    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
+    assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()

     print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
     print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
     print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
     print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
-    assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()
+    assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()


 @pytest.mark.parametrize('model_name', ["7B"])
@@ -267,11 +267,10 @@ def test_llama_generation(model_name):
     del model

     hf_error = (logits_hf - logits_ref).abs().max().item()
-    # For some reason logits_parallel is off by quite a bit more than 2x
-    assert (logits_parallel - logits_ref).abs().max().item() < 8 * hf_error
+    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error

     print(f'HF fp16 logits max diff: {hf_error}')
-    print(f'Logits max diff: {(logits - logits_parallel).abs().max().item() }')
-    assert (logits - logits_parallel).abs().max().item() < 2 * hf_error
-    print(f'Logits CG max diff: {(logits_cg - logits_parallel).abs().max().item() }')
+    print(f'Logits max diff: {(logits - logits_ref).abs().max().item() }')
+    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
+    print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }')
     assert torch.equal(logits_cg, logits)
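With this fix the tests apply the same tolerance pattern throughout: the tested outputs must stay within 2x the Hugging Face fp16 error, measured against the fp32 reference. A small helper sketching that check (check_against_reference is a hypothetical name, not part of the test file):

import torch

def check_against_reference(out, out_hf, out_ref, multiplier=2.0):
    # Max error of the implementation under test vs. the fp32 reference
    # must stay within `multiplier` times the Hugging Face fp16 error.
    impl_err = (out - out_ref).abs().max().item()
    hf_err = (out_hf - out_ref).abs().max().item()
    print(f'Impl max diff: {impl_err}')
    print(f'HF fp16 max diff: {hf_err}')
    assert impl_err < multiplier * hf_err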