Merge QKV into one linear layer (#15)
parent 2c5cd0defe
commit 1f01a18d39
@@ -33,22 +33,21 @@ class LlamaMLP(nn.Module):
         hidden_act: str,
     ):
         super().__init__()
-        # TODO: Merge the gate and down linear layers.
-        self.gate_proj = ColumnParallelLinear(hidden_size, intermediate_size,
-                                              bias=False, gather_output=False,
-                                              perform_initialization=False)
+        self.gate_up_proj = ColumnParallelLinear(hidden_size, 2 * intermediate_size,
+                                                 bias=False, gather_output=False,
+                                                 perform_initialization=False)
         self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
                                            bias=False, input_is_parallel=True,
                                            perform_initialization=False)
-        self.up_proj = ColumnParallelLinear(hidden_size, intermediate_size,
-                                            bias=False, gather_output=False,
-                                            perform_initialization=False)
         assert hidden_act == 'silu'
         self.act_fn = nn.SiLU()
 
     def forward(self, x):
-        gate, _ = self.gate_proj(x)
-        up, _ = self.up_proj(x)
+        gate_up, _ = self.gate_up_proj(x)
+        gate_up = gate_up.reshape(gate_up.shape[:-1] + (2, -1))
+        gate, up = torch.split(gate_up, 1, dim=-2)
+        gate = gate.squeeze(dim=-2).contiguous()
+        up = up.squeeze(dim=-2).contiguous()
         x = self.act_fn(gate) * up
         x, _ = self.down_proj(x)
         return x
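Note (not part of the diff): a minimal standalone sketch of why the new forward recovers the same gate and up activations, assuming the fused gate_up_proj weight is the row-wise concatenation [gate_proj.weight; up_proj.weight]. Plain nn.Linear layers stand in for the parallel layers, and all sizes are made up for illustration.

# Sketch only: fused gate/up projection split the way LlamaMLP.forward now does it.
import torch

hidden_size, intermediate_size, batch = 16, 32, 4
gate = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
up = torch.nn.Linear(hidden_size, intermediate_size, bias=False)

# Fused layer: its weight rows are [gate rows; up rows], giving 2 * intermediate_size outputs.
gate_up = torch.nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
with torch.no_grad():
    gate_up.weight.copy_(torch.cat([gate.weight, up.weight], dim=0))

x = torch.randn(batch, hidden_size)
fused = gate_up(x).reshape(batch, 2, intermediate_size)  # same reshape as the diff, (2, -1)
g, u = torch.split(fused, 1, dim=-2)
g = g.squeeze(dim=-2)
u = u.squeeze(dim=-2)

assert torch.allclose(g, gate(x), atol=1e-6)
assert torch.allclose(u, up(x), atol=1e-6)

The same ordering assumption is what the gate_up_proj branch in load_weights below enforces by concatenating the gate_proj and up_proj checkpoint slices along dim 0.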
@@ -70,24 +69,9 @@ class LlamaAttention(nn.Module):
         self.head_dim = hidden_size // self.total_num_heads
         self.scaling = self.head_dim ** -0.5
 
-        # TODO: Merge the QKV linear layers.
-        self.q_proj = ColumnParallelLinear(
-            hidden_size,
-            self.total_num_heads * self.head_dim,
-            bias=False,
-            gather_output=False,
-            perform_initialization=False,
-        )
-        self.k_proj = ColumnParallelLinear(
-            hidden_size,
-            self.total_num_heads * self.head_dim,
-            bias=False,
-            gather_output=False,
-            perform_initialization=False,
-        )
-        self.v_proj = ColumnParallelLinear(
-            hidden_size,
-            self.total_num_heads * self.head_dim,
+        self.qkv_proj = ColumnParallelLinear(
+            hidden_size,
+            3 * self.total_num_heads * self.head_dim,
             bias=False,
             gather_output=False,
             perform_initialization=False,
@@ -109,9 +93,12 @@ class LlamaAttention(nn.Module):
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
-        q, _ = self.q_proj(hidden_states)
-        k, _ = self.k_proj(hidden_states)
-        v, _ = self.v_proj(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
+        qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
+        q, k, v = torch.split(qkv, 1, dim=-2)
+        q = q.squeeze(dim=-2).contiguous()
+        k = k.squeeze(dim=-2).contiguous()
+        v = v.squeeze(dim=-2).contiguous()
         k_cache, v_cache = kv_cache
         attn_output = self.attn(
             positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
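Note (not part of the diff): torch.split over the inserted size-3 dimension returns strided views of the fused qkv tensor, which is presumably why each of q, k, v gets .contiguous() before being handed to the attention op. A quick standalone check with made-up shapes:

# Sketch only: the split/squeeze results are non-contiguous views until .contiguous().
import torch

num_tokens, num_heads, head_dim = 5, 8, 64
qkv = torch.randn(num_tokens, 3 * num_heads * head_dim)
qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))   # (num_tokens, 3, num_heads * head_dim)
q, k, v = torch.split(qkv, 1, dim=-2)         # three (num_tokens, 1, ...) views into qkv
q = q.squeeze(dim=-2)
print(q.is_contiguous())                      # False: still a strided view of qkv's storage
q = q.contiguous()                            # copies into densely packed storage
print(q.is_contiguous())                      # True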
@@ -230,8 +217,7 @@ class LlamaForCausalLM(nn.Module):
         return next_tokens
 
     _column_parallel_weights = ["embed_tokens.weight", "lm_head.weight",
-                                "q_proj.weight", "k_proj.weight",
-                                "v_proj.weight", "gate_proj.weight",
+                                "qkv_proj.weight", "gate_proj.weight",
                                 "up_proj.weight"]
     _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
 
@@ -239,23 +225,42 @@ class LlamaForCausalLM(nn.Module):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, param in state_dict.items():
-            loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path,
-                                                                  name)))
-            for p in self._column_parallel_weights:
-                if p in name:
-                    shard_size = param.shape[0]
-                    loaded_weight = loaded_weight[
-                        shard_size * tensor_model_parallel_rank
-                        :shard_size * (tensor_model_parallel_rank + 1)]
-                    break
-            for p in self._row_parallel_weights:
-                if p in name:
-                    shard_size = param.shape[1]
-                    loaded_weight = loaded_weight[
-                        :,
-                        shard_size * tensor_model_parallel_rank
-                        :shard_size * (tensor_model_parallel_rank + 1)]
-                    break
+            if "qkv_proj" in name or "gate_up_proj" in name:
+                if "qkv_proj" in name:
+                    original_name = "qkv_proj"
+                    weight_names = ["q_proj", "k_proj", "v_proj"]
+                    shard_size = param.shape[0] // 3
+                else:
+                    original_name = "gate_up_proj"
+                    weight_names = ["gate_proj", "up_proj"]
+                    shard_size = param.shape[0] // 2
+                weights_to_concat = []
+                for weight_name in weight_names:
+                    weight = np.load(os.path.join(
+                        weights_path, name.replace(original_name, weight_name)))
+                    weights_to_concat.append(weight[
+                        shard_size * tensor_model_parallel_rank
+                        :shard_size * (tensor_model_parallel_rank + 1)])
+                loaded_weight = torch.from_numpy(
+                    np.concatenate(weights_to_concat, axis=0))
+            else:
+                loaded_weight = torch.from_numpy(
+                    np.load(os.path.join(weights_path, name)))
+                for p in self._column_parallel_weights:
+                    if p in name:
+                        shard_size = param.shape[0]
+                        loaded_weight = loaded_weight[
+                            shard_size * tensor_model_parallel_rank
+                            :shard_size * (tensor_model_parallel_rank + 1)]
+                        break
+                for p in self._row_parallel_weights:
+                    if p in name:
+                        shard_size = param.shape[1]
+                        loaded_weight = loaded_weight[
+                            :,
+                            shard_size * tensor_model_parallel_rank
+                            :shard_size * (tensor_model_parallel_rank + 1)]
+                        break
 
             assert param.shape == loaded_weight.shape
             param.data.copy_(loaded_weight)
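Note (not part of the diff): a minimal numpy sketch of the shard-and-concatenate step the new load_weights performs for qkv_proj under tensor parallelism. The arrays below stand in for the q_proj/k_proj/v_proj checkpoints that the real code np.loads, and all sizes are made up.

# Sketch only: build one rank's shard of the fused qkv weight from the unfused checkpoints.
import numpy as np

hidden_size, num_ranks, rank = 16, 2, 1          # made-up sizes, looking at rank 1
q_w = np.random.randn(hidden_size, hidden_size)  # stand-in for q_proj.weight
k_w = np.random.randn(hidden_size, hidden_size)  # stand-in for k_proj.weight
v_w = np.random.randn(hidden_size, hidden_size)  # stand-in for v_proj.weight

# The rank-local qkv_proj.weight has 3 * (hidden_size // num_ranks) rows,
# so shard_size here equals param.shape[0] // 3 in the diff.
shard_size = (3 * hidden_size) // num_ranks // 3
weights_to_concat = [
    w[shard_size * rank:shard_size * (rank + 1)] for w in (q_w, k_w, v_w)
]
qkv_shard = np.concatenate(weights_to_concat, axis=0)
assert qkv_shard.shape == (3 * hidden_size // num_ranks, hidden_size)

Concatenating each rank's q, k, v slices in that order keeps the fused weight's row layout consistent with the (3, -1) reshape in the new forward pass.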
@@ -53,16 +53,9 @@ class OPTAttention(nn.Module):
         self.head_dim = embed_dim // total_num_heads
         self.scaling = self.head_dim ** -0.5
 
-        # TODO(woosuk): Fuse the three linear layers into one QKV linear layer.
-        self.k_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
-                                           gather_output=False,
-                                           perform_initialization=False)
-        self.v_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
-                                           gather_output=False,
-                                           perform_initialization=False)
-        self.q_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
-                                           gather_output=False,
-                                           perform_initialization=False)
+        self.qkv_proj = ColumnParallelLinear(embed_dim, 3 * embed_dim, bias=bias,
+                                             gather_output=False,
+                                             perform_initialization=False)
         self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
                                           input_is_parallel=True,
                                           perform_initialization=False)
@@ -75,16 +68,18 @@ class OPTAttention(nn.Module):
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
-        q, _ = self.q_proj(hidden_states)
-        k, _ = self.k_proj(hidden_states)
-        v, _ = self.v_proj(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
+        qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
+        q, k, v = torch.split(qkv, 1, dim=-2)
+        q = q.squeeze(dim=-2).contiguous()
+        k = k.squeeze(dim=-2).contiguous()
+        v = v.squeeze(dim=-2).contiguous()
         key_cache, value_cache = kv_cache
         attn_output = self.attn(
             q, k, v, key_cache, value_cache, input_metadata, cache_event)
         output, _ = self.out_proj(attn_output)
         return output
 
 
 class OPTDecoderLayer(nn.Module):
 
     def __init__(self, config: OPTConfig):
@@ -262,11 +257,7 @@ class OPTForCausalLM(nn.Module):
             self.lm_head_weight, hidden_states, input_metadata)
         return next_tokens
 
-    _column_parallel_weights = ["embed_tokens.weight",
-                                "q_proj.weight", "k_proj.weight",
-                                "v_proj.weight", "fc1.weight"]
-    _column_parallel_biases = ["q_proj.bias", "k_proj.bias",
-                               "v_proj.bias", "fc1.bias"]
+    _column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"]
     _row_parallel_weights = ["out_proj.weight", "fc2.weight"]
 
     def load_weights(self, weights_path: str):
@@ -275,24 +266,35 @@ class OPTForCausalLM(nn.Module):
         for name, param in state_dict.items():
             if "lm_head_weight" in name:
                 continue
-            loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path,
-                                                                  name)))
-            for p in (self._column_parallel_weights
-                      + self._column_parallel_biases):
-                if p in name:
-                    shard_size = param.shape[0]
-                    loaded_weight = loaded_weight[
-                        shard_size * tensor_model_parallel_rank
-                        :shard_size * (tensor_model_parallel_rank + 1)]
-                    break
-            for p in self._row_parallel_weights:
-                if p in name:
-                    shard_size = param.shape[1]
-                    loaded_weight = loaded_weight[
-                        :,
-                        shard_size * tensor_model_parallel_rank
-                        :shard_size * (tensor_model_parallel_rank + 1)]
-                    break
+            if "qkv_proj" in name:
+                shard_size = param.shape[0] // 3
+                weights_to_concat = []
+                for weight_name in ["q_proj", "k_proj", "v_proj"]:
+                    weight = np.load(os.path.join(
+                        weights_path, name.replace("qkv_proj", weight_name)))
+                    weights_to_concat.append(weight[
+                        shard_size * tensor_model_parallel_rank
+                        :shard_size * (tensor_model_parallel_rank + 1)])
+                loaded_weight = torch.from_numpy(
+                    np.concatenate(weights_to_concat, axis=0))
+            else:
+                loaded_weight = torch.from_numpy(
+                    np.load(os.path.join(weights_path, name)))
+                for p in self._column_parallel_weights:
+                    if p in name:
+                        shard_size = param.shape[0]
+                        loaded_weight = loaded_weight[
+                            shard_size * tensor_model_parallel_rank
+                            :shard_size * (tensor_model_parallel_rank + 1)]
+                        break
+                for p in self._row_parallel_weights:
+                    if p in name:
+                        shard_size = param.shape[1]
+                        loaded_weight = loaded_weight[
+                            :,
+                            shard_size * tensor_model_parallel_rank
+                            :shard_size * (tensor_model_parallel_rank + 1)]
+                        break
 
             assert param.shape == loaded_weight.shape
             param.data.copy_(loaded_weight)
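Note (not part of the diff): since the OPT branch tests `"qkv_proj" in name`, it handles both qkv_proj.weight and qkv_proj.bias, and the same dim-0 shard-and-concatenate applies to the 1-D bias. A small sketch with made-up sizes:

# Sketch only: assembling one rank's shard of the fused qkv bias.
import numpy as np

embed_dim, num_ranks, rank = 8, 2, 0
q_b = np.random.randn(embed_dim)               # stand-in for q_proj.bias
k_b = np.random.randn(embed_dim)               # stand-in for k_proj.bias
v_b = np.random.randn(embed_dim)               # stand-in for v_proj.bias

shard_size = (3 * embed_dim) // num_ranks // 3  # == param.shape[0] // 3 for the bias as well
qkv_bias_shard = np.concatenate(
    [b[shard_size * rank:shard_size * (rank + 1)] for b in (q_b, k_b, v_b)], axis=0)
assert qkv_bias_shard.shape == (3 * embed_dim // num_ranks,)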