diff --git a/src/parallel/context_parallel.py b/src/parallel/context_parallel.py
index fa3fc63..035e8a0 100644
--- a/src/parallel/context_parallel.py
+++ b/src/parallel/context_parallel.py
@@ -15,7 +15,8 @@ class RingAttentionFunc(torch.autograd.Function):
     @staticmethod
     def forward(ctx, q, k, v, sm_scale, is_causal):
         comm = ContextComms("comm")
-        #NOTE: Find a better to save these tensors without cloning
+        #TODO(fmom): add flash attention
+        #TODO(fmom): Find a better way to save these tensors without cloning
         k_og = k.clone()
         v_og = v.clone()
         out, lse = None, None
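For context on what the `out, lse = None, None` initialization feeds into: a ring-attention forward typically accumulates each KV block's partial attention result with the standard log-sum-exp merge. The sketch below shows that combine step; it is not code from this file, and `merge_attention_blocks` and the assumed tensor shapes are hypothetical.

```python
import torch

def merge_attention_blocks(out, lse, block_out, block_lse):
    """Fold one KV block's partial attention into the running result.

    Assumed shapes (hypothetical, not taken from this repo):
      out / block_out: [batch, heads, seq_q, head_dim] partial outputs
      lse / block_lse: [batch, heads, seq_q, 1] log-sum-exp of the
                       attention logits, kept with a trailing singleton
                       dim so it broadcasts against the outputs.
    """
    if out is None:  # first block: nothing to merge yet
        return block_out, block_lse
    # Stable log(exp(lse) + exp(block_lse)): the combined normalizer.
    new_lse = torch.logaddexp(lse, block_lse)
    # Rescale both partial outputs into the new normalization; both
    # exponents are <= 0, so neither term can overflow.
    out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
    return out, new_lse
```

In a ring loop this merge would run once per step, after computing local attention against the KV shard just received from the neighboring rank; the final `(out, lse)` then equals full-softmax attention over the entire sequence.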