Merge QKV into one linear layer (#15)
This commit is contained in:
parent
2c5cd0defe
commit
1f01a18d39
@ -33,22 +33,21 @@ class LlamaMLP(nn.Module):
|
||||
hidden_act: str,
|
||||
):
|
||||
super().__init__()
|
||||
# TODO: Merge the gate and down linear layers.
|
||||
self.gate_proj = ColumnParallelLinear(hidden_size, intermediate_size,
|
||||
bias=False, gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.gate_up_proj = ColumnParallelLinear(hidden_size, 2 * intermediate_size,
|
||||
bias=False, gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
|
||||
bias=False, input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.up_proj = ColumnParallelLinear(hidden_size, intermediate_size,
|
||||
bias=False, gather_output=False,
|
||||
perform_initialization=False)
|
||||
assert hidden_act == 'silu'
|
||||
self.act_fn = nn.SiLU()
|
||||
|
||||
def forward(self, x):
|
||||
gate, _ = self.gate_proj(x)
|
||||
up, _ = self.up_proj(x)
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
gate_up = gate_up.reshape(gate_up.shape[:-1] + (2, -1))
|
||||
gate, up = torch.split(gate_up, 1, dim=-2)
|
||||
gate = gate.squeeze(dim=-2).contiguous()
|
||||
up = up.squeeze(dim=-2).contiguous()
|
||||
x = self.act_fn(gate) * up
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
@ -70,24 +69,9 @@ class LlamaAttention(nn.Module):
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
|
||||
# TODO: Merge the QKV linear layers.
|
||||
self.q_proj = ColumnParallelLinear(
|
||||
self.qkv_proj = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
self.total_num_heads * self.head_dim,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.k_proj = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
self.total_num_heads * self.head_dim,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.v_proj = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
self.total_num_heads * self.head_dim,
|
||||
3 * self.total_num_heads * self.head_dim,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
@ -109,9 +93,12 @@ class LlamaAttention(nn.Module):
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
q, _ = self.q_proj(hidden_states)
|
||||
k, _ = self.k_proj(hidden_states)
|
||||
v, _ = self.v_proj(hidden_states)
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
|
||||
q, k, v = torch.split(qkv, 1, dim=-2)
|
||||
q = q.squeeze(dim=-2).contiguous()
|
||||
k = k.squeeze(dim=-2).contiguous()
|
||||
v = v.squeeze(dim=-2).contiguous()
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(
|
||||
positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
|
||||
@ -230,8 +217,7 @@ class LlamaForCausalLM(nn.Module):
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = ["embed_tokens.weight", "lm_head.weight",
|
||||
"q_proj.weight", "k_proj.weight",
|
||||
"v_proj.weight", "gate_proj.weight",
|
||||
"qkv_proj.weight", "gate_proj.weight",
|
||||
"up_proj.weight"]
|
||||
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
|
||||
|
||||
@ -239,23 +225,42 @@ class LlamaForCausalLM(nn.Module):
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
for name, param in state_dict.items():
|
||||
loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path,
|
||||
name)))
|
||||
for p in self._column_parallel_weights:
|
||||
if p in name:
|
||||
shard_size = param.shape[0]
|
||||
loaded_weight = loaded_weight[
|
||||
if "qkv_proj" in name or "gate_up_proj" in name:
|
||||
if "qkv_proj" in name:
|
||||
original_name = "qkv_proj"
|
||||
weight_names = ["q_proj", "k_proj", "v_proj"]
|
||||
shard_size = param.shape[0] // 3
|
||||
else:
|
||||
original_name = "gate_up_proj"
|
||||
weight_names = ["gate_proj", "up_proj"]
|
||||
shard_size = param.shape[0] // 2
|
||||
weights_to_concat = []
|
||||
for weight_name in weight_names:
|
||||
weight = np.load(os.path.join(
|
||||
weights_path, name.replace(original_name, weight_name)))
|
||||
weights_to_concat.append(weight[
|
||||
shard_size * tensor_model_parallel_rank
|
||||
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||
break
|
||||
for p in self._row_parallel_weights:
|
||||
if p in name:
|
||||
shard_size = param.shape[1]
|
||||
loaded_weight = loaded_weight[
|
||||
:,
|
||||
shard_size * tensor_model_parallel_rank
|
||||
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||
break
|
||||
:shard_size * (tensor_model_parallel_rank + 1)])
|
||||
loaded_weight = torch.from_numpy(
|
||||
np.concatenate(weights_to_concat, axis=0))
|
||||
else:
|
||||
loaded_weight = torch.from_numpy(
|
||||
np.load(os.path.join(weights_path, name)))
|
||||
for p in self._column_parallel_weights:
|
||||
if p in name:
|
||||
shard_size = param.shape[0]
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank
|
||||
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||
break
|
||||
for p in self._row_parallel_weights:
|
||||
if p in name:
|
||||
shard_size = param.shape[1]
|
||||
loaded_weight = loaded_weight[
|
||||
:,
|
||||
shard_size * tensor_model_parallel_rank
|
||||
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||
break
|
||||
|
||||
assert param.shape == loaded_weight.shape
|
||||
param.data.copy_(loaded_weight)
|
||||
|
||||
@ -53,16 +53,9 @@ class OPTAttention(nn.Module):
|
||||
self.head_dim = embed_dim // total_num_heads
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
|
||||
# TODO(woosuk): Fuse the three linear layers into one QKV linear layer.
|
||||
self.k_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.v_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.q_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.qkv_proj = ColumnParallelLinear(embed_dim, 3 * embed_dim, bias=bias,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
@ -75,16 +68,18 @@ class OPTAttention(nn.Module):
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
q, _ = self.q_proj(hidden_states)
|
||||
k, _ = self.k_proj(hidden_states)
|
||||
v, _ = self.v_proj(hidden_states)
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
|
||||
q, k, v = torch.split(qkv, 1, dim=-2)
|
||||
q = q.squeeze(dim=-2).contiguous()
|
||||
k = k.squeeze(dim=-2).contiguous()
|
||||
v = v.squeeze(dim=-2).contiguous()
|
||||
key_cache, value_cache = kv_cache
|
||||
attn_output = self.attn(
|
||||
q, k, v, key_cache, value_cache, input_metadata, cache_event)
|
||||
output, _ = self.out_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class OPTDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: OPTConfig):
|
||||
@ -262,11 +257,7 @@ class OPTForCausalLM(nn.Module):
|
||||
self.lm_head_weight, hidden_states, input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = ["embed_tokens.weight",
|
||||
"q_proj.weight", "k_proj.weight",
|
||||
"v_proj.weight", "fc1.weight"]
|
||||
_column_parallel_biases = ["q_proj.bias", "k_proj.bias",
|
||||
"v_proj.bias", "fc1.bias"]
|
||||
_column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"]
|
||||
_row_parallel_weights = ["out_proj.weight", "fc2.weight"]
|
||||
|
||||
def load_weights(self, weights_path: str):
|
||||
@ -275,24 +266,35 @@ class OPTForCausalLM(nn.Module):
|
||||
for name, param in state_dict.items():
|
||||
if "lm_head_weight" in name:
|
||||
continue
|
||||
loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path,
|
||||
name)))
|
||||
for p in (self._column_parallel_weights
|
||||
+ self._column_parallel_biases):
|
||||
if p in name:
|
||||
shard_size = param.shape[0]
|
||||
loaded_weight = loaded_weight[
|
||||
if "qkv_proj" in name:
|
||||
shard_size = param.shape[0] // 3
|
||||
weights_to_concat = []
|
||||
for weight_name in ["q_proj", "k_proj", "v_proj"]:
|
||||
weight = np.load(os.path.join(
|
||||
weights_path, name.replace("qkv_proj", weight_name)))
|
||||
weights_to_concat.append(weight[
|
||||
shard_size * tensor_model_parallel_rank
|
||||
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||
break
|
||||
for p in self._row_parallel_weights:
|
||||
if p in name:
|
||||
shard_size = param.shape[1]
|
||||
loaded_weight = loaded_weight[
|
||||
:,
|
||||
shard_size * tensor_model_parallel_rank
|
||||
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||
break
|
||||
:shard_size * (tensor_model_parallel_rank + 1)])
|
||||
loaded_weight = torch.from_numpy(
|
||||
np.concatenate(weights_to_concat, axis=0))
|
||||
else:
|
||||
loaded_weight = torch.from_numpy(
|
||||
np.load(os.path.join(weights_path, name)))
|
||||
for p in self._column_parallel_weights:
|
||||
if p in name:
|
||||
shard_size = param.shape[0]
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank
|
||||
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||
break
|
||||
for p in self._row_parallel_weights:
|
||||
if p in name:
|
||||
shard_size = param.shape[1]
|
||||
loaded_weight = loaded_weight[
|
||||
:,
|
||||
shard_size * tensor_model_parallel_rank
|
||||
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||
break
|
||||
|
||||
assert param.shape == loaded_weight.shape
|
||||
param.data.copy_(loaded_weight)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user