made Tensor Parallel API compliant
This commit is contained in:
parent
abd1edf9f9
commit
2b2781a374
@ -10,7 +10,9 @@ import src.distributed.process_group_manager as pgm
|
||||
|
||||
class ContextParallel(nn.Module):
|
||||
def __init__(self, model, config):
|
||||
super().__init__(model, config)
|
||||
super().__init__()
|
||||
|
||||
self.model = model
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, Attention) and not isinstance(module, RingAttention):
|
||||
|
||||
@ -4,6 +4,10 @@ import torch.nn as nn
|
||||
|
||||
class TensorParallel(nn.Module):
|
||||
def __init__(self, model, init_method = init.xavier_normal_):
|
||||
super().__init__()
|
||||
|
||||
self.model = model
|
||||
|
||||
module_linear_name_stype_mapping_list = [
|
||||
("attention", "q_proj", "column"),
|
||||
("attention", "k_proj", "column"),
|
||||
@ -47,4 +51,9 @@ class TensorParallel(nn.Module):
|
||||
init_method=self.init_method
|
||||
)
|
||||
setattr(module, linear_proj_name, new_linear_layer)
|
||||
|
||||
|
||||
def __getattr__(self, name):
|
||||
try:
|
||||
return super().__getattr__(name)
|
||||
except AttributeError:
|
||||
return getattr(self.model, name)
|
||||
Loading…
Reference in New Issue
Block a user