add todo ring attention
This commit is contained in:
parent
46af5b0425
commit
987a7c5c99
@ -15,7 +15,8 @@ class RingAttentionFunc(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, q, k, v, sm_scale, is_causal):
|
||||
comm = ContextComms("comm")
|
||||
#NOTE: Find a better way to save these tensors without cloning
|
||||
#TODO(fmom): add flash attention
|
||||
#TODO(fmom): Find a better way to save these tensors without cloning
|
||||
k_og = k.clone()
|
||||
v_og = v.clone()
|
||||
out, lse = None, None
|
||||
|
||||
Loading…
Reference in New Issue
Block a user