add todo ring attention

This commit is contained in:
ferdinand.mom 2024-10-29 14:08:53 +00:00
parent 46af5b0425
commit 987a7c5c99

View File

@ -15,7 +15,8 @@ class RingAttentionFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, sm_scale, is_causal):
comm = ContextComms("comm")
#NOTE: Find a better to save these tensors without cloning
#TODO(fmom): add flash attention
#TODO(fmom): Find a better way to save these tensors without cloning
k_og = k.clone()
v_og = v.clone()
out, lse = None, None