Merge QKV into one linear layer (#15)

Zhuohan Li authored 2023-04-02 15:23:29 +08:00, committed by GitHub
parent 2c5cd0defe
commit 1f01a18d39
2 changed files with 90 additions and 83 deletions
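
The core of the change: instead of running three separate q_proj/k_proj/v_proj matmuls per attention layer, each layer now performs one fused projection and splits the result into q, k, and v afterwards; the usual motivation is to replace three kernel launches with a single larger GEMM. A minimal standalone sketch of the pattern, using plain nn.Linear rather than the repo's ColumnParallelLinear (sizes and names below are illustrative, not the committed code):

import torch
import torch.nn as nn

hidden_size, num_heads = 4096, 32
head_dim = hidden_size // num_heads

# One fused projection produces q, k, and v in a single matmul.
qkv_proj = nn.Linear(hidden_size, 3 * num_heads * head_dim, bias=False)

x = torch.randn(2, 16, hidden_size)            # (batch, seq_len, hidden)
qkv = qkv_proj(x)                              # (batch, seq_len, 3 * hidden)
qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))    # (batch, seq_len, 3, hidden)
q, k, v = torch.split(qkv, 1, dim=-2)          # three (batch, seq_len, 1, hidden) views
q, k, v = (t.squeeze(dim=-2).contiguous() for t in (q, k, v))
assert q.shape == (2, 16, hidden_size)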


@@ -33,22 +33,21 @@ class LlamaMLP(nn.Module):
         hidden_act: str,
     ):
         super().__init__()
-        # TODO: Merge the gate and down linear layers.
-        self.gate_proj = ColumnParallelLinear(hidden_size, intermediate_size,
-                                              bias=False, gather_output=False,
-                                              perform_initialization=False)
+        self.gate_up_proj = ColumnParallelLinear(hidden_size, 2 * intermediate_size,
+                                                 bias=False, gather_output=False,
+                                                 perform_initialization=False)
         self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
                                            bias=False, input_is_parallel=True,
                                            perform_initialization=False)
-        self.up_proj = ColumnParallelLinear(hidden_size, intermediate_size,
-                                            bias=False, gather_output=False,
-                                            perform_initialization=False)
         assert hidden_act == 'silu'
         self.act_fn = nn.SiLU()
 
     def forward(self, x):
-        gate, _ = self.gate_proj(x)
-        up, _ = self.up_proj(x)
+        gate_up, _ = self.gate_up_proj(x)
+        gate_up = gate_up.reshape(gate_up.shape[:-1] + (2, -1))
+        gate, up = torch.split(gate_up, 1, dim=-2)
+        gate = gate.squeeze(dim=-2).contiguous()
+        up = up.squeeze(dim=-2).contiguous()
         x = self.act_fn(gate) * up
         x, _ = self.down_proj(x)
         return x
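
The MLP gets the same treatment: gate_proj and up_proj are fused into gate_up_proj, and forward() reshapes the fused output to peel the two halves back apart. A rough shape walk-through with plain tensors, assuming LLaMA-7B-like sizes (illustrative only):

import torch
import torch.nn.functional as F

hidden_size, intermediate_size = 4096, 11008

# Stand-in for the gate_up_proj output: last dim is 2 * intermediate_size.
gate_up = torch.randn(2, 16, 2 * intermediate_size)

gate_up = gate_up.reshape(gate_up.shape[:-1] + (2, -1))  # (..., 2, intermediate_size)
gate, up = torch.split(gate_up, 1, dim=-2)               # each (..., 1, intermediate_size)
gate = gate.squeeze(dim=-2).contiguous()
up = up.squeeze(dim=-2).contiguous()

out = F.silu(gate) * up                                  # SwiGLU-style activation
assert out.shape == (2, 16, intermediate_size)

The reshape treats the first half of the fused output as gate and the second half as up, so it relies on the fused weight being the row-wise concatenation of gate_proj and up_proj, which is exactly what the load_weights change further down constructs.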
@@ -70,24 +69,9 @@ class LlamaAttention(nn.Module):
         self.head_dim = hidden_size // self.total_num_heads
         self.scaling = self.head_dim ** -0.5
 
-        # TODO: Merge the QKV linear layers.
-        self.q_proj = ColumnParallelLinear(
+        self.qkv_proj = ColumnParallelLinear(
             hidden_size,
-            self.total_num_heads * self.head_dim,
-            bias=False,
-            gather_output=False,
-            perform_initialization=False,
-        )
-        self.k_proj = ColumnParallelLinear(
-            hidden_size,
-            self.total_num_heads * self.head_dim,
-            bias=False,
-            gather_output=False,
-            perform_initialization=False,
-        )
-        self.v_proj = ColumnParallelLinear(
-            hidden_size,
-            self.total_num_heads * self.head_dim,
+            3 * self.total_num_heads * self.head_dim,
             bias=False,
             gather_output=False,
             perform_initialization=False,
@@ -109,9 +93,12 @@ class LlamaAttention(nn.Module):
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
-        q, _ = self.q_proj(hidden_states)
-        k, _ = self.k_proj(hidden_states)
-        v, _ = self.v_proj(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
+        qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
+        q, k, v = torch.split(qkv, 1, dim=-2)
+        q = q.squeeze(dim=-2).contiguous()
+        k = k.squeeze(dim=-2).contiguous()
+        v = v.squeeze(dim=-2).contiguous()
         k_cache, v_cache = kv_cache
         attn_output = self.attn(
             positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
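
A note on the .contiguous() calls above: torch.split returns views into the fused tensor, so after squeezing the size-1 axis the q, k, and v slices still carry the fused layout's strides; the copies presumably give the downstream attention op dense buffers to work with. A quick check of that behaviour with toy shapes:

import torch

qkv = torch.randn(2, 16, 3, 128)        # (batch, seq_len, 3, per-rank head dims)
q, k, v = torch.split(qkv, 1, dim=-2)   # views into qkv, no copy yet
q = q.squeeze(dim=-2)
print(q.is_contiguous())                # False: strides still step over k and v
q = q.contiguous()                      # dense copy
print(q.is_contiguous())                # True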
@@ -230,8 +217,7 @@ class LlamaForCausalLM(nn.Module):
         return next_tokens
 
     _column_parallel_weights = ["embed_tokens.weight", "lm_head.weight",
-                                "q_proj.weight", "k_proj.weight",
-                                "v_proj.weight", "gate_proj.weight",
+                                "qkv_proj.weight", "gate_proj.weight",
                                 "up_proj.weight"]
     _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
@@ -239,23 +225,42 @@ class LlamaForCausalLM(nn.Module):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, param in state_dict.items():
-            loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path,
-                                                                  name)))
-            for p in self._column_parallel_weights:
-                if p in name:
-                    shard_size = param.shape[0]
-                    loaded_weight = loaded_weight[
-                        shard_size * tensor_model_parallel_rank
-                        :shard_size * (tensor_model_parallel_rank + 1)]
-                    break
-            for p in self._row_parallel_weights:
-                if p in name:
-                    shard_size = param.shape[1]
-                    loaded_weight = loaded_weight[
-                        :,
-                        shard_size * tensor_model_parallel_rank
-                        :shard_size * (tensor_model_parallel_rank + 1)]
-                    break
+            if "qkv_proj" in name or "gate_up_proj" in name:
+                if "qkv_proj" in name:
+                    original_name = "qkv_proj"
+                    weight_names = ["q_proj", "k_proj", "v_proj"]
+                    shard_size = param.shape[0] // 3
+                else:
+                    original_name = "gate_up_proj"
+                    weight_names = ["gate_proj", "up_proj"]
+                    shard_size = param.shape[0] // 2
+                weights_to_concat = []
+                for weight_name in weight_names:
+                    weight = np.load(os.path.join(
+                        weights_path, name.replace(original_name, weight_name)))
+                    weights_to_concat.append(weight[
+                        shard_size * tensor_model_parallel_rank
+                        :shard_size * (tensor_model_parallel_rank + 1)])
+                loaded_weight = torch.from_numpy(
+                    np.concatenate(weights_to_concat, axis=0))
+            else:
+                loaded_weight = torch.from_numpy(
+                    np.load(os.path.join(weights_path, name)))
+                for p in self._column_parallel_weights:
+                    if p in name:
+                        shard_size = param.shape[0]
+                        loaded_weight = loaded_weight[
+                            shard_size * tensor_model_parallel_rank
+                            :shard_size * (tensor_model_parallel_rank + 1)]
+                        break
+                for p in self._row_parallel_weights:
+                    if p in name:
+                        shard_size = param.shape[1]
+                        loaded_weight = loaded_weight[
+                            :,
+                            shard_size * tensor_model_parallel_rank
+                            :shard_size * (tensor_model_parallel_rank + 1)]
+                        break
             assert param.shape == loaded_weight.shape
             param.data.copy_(loaded_weight)
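
The checkpoint on disk still stores q_proj, k_proj, and v_proj (and gate_proj, up_proj) as separate arrays, so the loader slices this rank's shard out of each original weight and concatenates the shards along the output dimension to build the fused parameter. A simplified numpy sketch of that shard-then-concatenate step (toy sizes, hypothetical variable names):

import numpy as np

hidden_size = 8            # toy size
tp_size, tp_rank = 2, 0    # 2-way tensor parallelism, this is rank 0

# Per-projection checkpoint weights, each (hidden_size, hidden_size).
q_w = np.arange(hidden_size * hidden_size, dtype=np.float32).reshape(hidden_size, hidden_size)
k_w = q_w + 100.0
v_w = q_w + 200.0

# The local qkv_proj parameter is (3 * hidden_size // tp_size, hidden_size), so each
# source weight contributes a (hidden_size // tp_size)-row shard; this shard_size is
# what param.shape[0] // 3 computes in the commit.
shard_size = (3 * hidden_size) // tp_size // 3
shards = [w[shard_size * tp_rank:shard_size * (tp_rank + 1)] for w in (q_w, k_w, v_w)]
qkv_w = np.concatenate(shards, axis=0)
assert qkv_w.shape == (3 * hidden_size // tp_size, hidden_size)

Concatenating the shards in q, k, v order is what makes the reshape to (..., 3, -1) in the forward pass line up: slot 0 of the fused output is q, slot 1 is k, slot 2 is v.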


@@ -53,16 +53,9 @@ class OPTAttention(nn.Module):
         self.head_dim = embed_dim // total_num_heads
         self.scaling = self.head_dim ** -0.5
 
-        # TODO(woosuk): Fuse the three linear layers into one QKV linear layer.
-        self.k_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
-                                           gather_output=False,
-                                           perform_initialization=False)
-        self.v_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
-                                           gather_output=False,
-                                           perform_initialization=False)
-        self.q_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
-                                           gather_output=False,
-                                           perform_initialization=False)
+        self.qkv_proj = ColumnParallelLinear(embed_dim, 3 * embed_dim, bias=bias,
+                                             gather_output=False,
+                                             perform_initialization=False)
         self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
                                           input_is_parallel=True,
                                           perform_initialization=False)
@@ -75,16 +68,18 @@ class OPTAttention(nn.Module):
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
-        q, _ = self.q_proj(hidden_states)
-        k, _ = self.k_proj(hidden_states)
-        v, _ = self.v_proj(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
+        qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
+        q, k, v = torch.split(qkv, 1, dim=-2)
+        q = q.squeeze(dim=-2).contiguous()
+        k = k.squeeze(dim=-2).contiguous()
+        v = v.squeeze(dim=-2).contiguous()
         key_cache, value_cache = kv_cache
         attn_output = self.attn(
             q, k, v, key_cache, value_cache, input_metadata, cache_event)
         output, _ = self.out_proj(attn_output)
         return output
 
 
 class OPTDecoderLayer(nn.Module):
     def __init__(self, config: OPTConfig):
@@ -262,11 +257,7 @@ class OPTForCausalLM(nn.Module):
             self.lm_head_weight, hidden_states, input_metadata)
         return next_tokens
 
-    _column_parallel_weights = ["embed_tokens.weight",
-                                "q_proj.weight", "k_proj.weight",
-                                "v_proj.weight", "fc1.weight"]
-    _column_parallel_biases = ["q_proj.bias", "k_proj.bias",
-                               "v_proj.bias", "fc1.bias"]
+    _column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"]
     _row_parallel_weights = ["out_proj.weight", "fc2.weight"]
 
     def load_weights(self, weights_path: str):
@@ -275,24 +266,35 @@ class OPTForCausalLM(nn.Module):
         for name, param in state_dict.items():
             if "lm_head_weight" in name:
                 continue
-            loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path,
-                                                                  name)))
-            for p in (self._column_parallel_weights
-                      + self._column_parallel_biases):
-                if p in name:
-                    shard_size = param.shape[0]
-                    loaded_weight = loaded_weight[
-                        shard_size * tensor_model_parallel_rank
-                        :shard_size * (tensor_model_parallel_rank + 1)]
-                    break
-            for p in self._row_parallel_weights:
-                if p in name:
-                    shard_size = param.shape[1]
-                    loaded_weight = loaded_weight[
-                        :,
-                        shard_size * tensor_model_parallel_rank
-                        :shard_size * (tensor_model_parallel_rank + 1)]
-                    break
+            if "qkv_proj" in name:
+                shard_size = param.shape[0] // 3
+                weights_to_concat = []
+                for weight_name in ["q_proj", "k_proj", "v_proj"]:
+                    weight = np.load(os.path.join(
+                        weights_path, name.replace("qkv_proj", weight_name)))
+                    weights_to_concat.append(weight[
+                        shard_size * tensor_model_parallel_rank
+                        :shard_size * (tensor_model_parallel_rank + 1)])
+                loaded_weight = torch.from_numpy(
+                    np.concatenate(weights_to_concat, axis=0))
+            else:
+                loaded_weight = torch.from_numpy(
+                    np.load(os.path.join(weights_path, name)))
+                for p in self._column_parallel_weights:
+                    if p in name:
+                        shard_size = param.shape[0]
+                        loaded_weight = loaded_weight[
+                            shard_size * tensor_model_parallel_rank
+                            :shard_size * (tensor_model_parallel_rank + 1)]
+                        break
+                for p in self._row_parallel_weights:
+                    if p in name:
+                        shard_size = param.shape[1]
+                        loaded_weight = loaded_weight[
+                            :,
+                            shard_size * tensor_model_parallel_rank
+                            :shard_size * (tensor_model_parallel_rank + 1)]
+                        break
             assert param.shape == loaded_weight.shape
             param.data.copy_(loaded_weight)
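
A closing note on the OPT biases: the separate _column_parallel_biases list disappears because a column-parallel bias is a 1-D tensor whose only axis is the output dimension, so the same param.shape[0] slicing used for weights shards it (fc1.bias simply joins _column_parallel_weights above, and qkv_proj.bias falls into the fused branch, where the param.shape[0] // 3 shard size and the row slicing work for 1-D and 2-D parameters alike). A toy illustration with dummy arrays (illustrative names and sizes):

import numpy as np

out_features, in_features = 8, 16
tp_size, tp_rank = 2, 1    # 2-way tensor parallelism, this is rank 1

weight = np.zeros((out_features, in_features), dtype=np.float32)  # 2-D checkpoint weight
bias = np.arange(out_features, dtype=np.float32)                  # 1-D checkpoint bias

# shard_size plays the role of param.shape[0] of the local, already-split parameter.
shard_size = out_features // tp_size
weight_shard = weight[shard_size * tp_rank:shard_size * (tp_rank + 1)]
bias_shard = bias[shard_size * tp_rank:shard_size * (tp_rank + 1)]
assert weight_shard.shape == (4, 16) and bias_shard.shape == (4,)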