From 987a7c5c992bd2f6756d09cd890d70988da79479 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Tue, 29 Oct 2024 14:08:53 +0000 Subject: [PATCH] add todo ring attention --- src/parallel/context_parallel.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/parallel/context_parallel.py b/src/parallel/context_parallel.py index fa3fc63..035e8a0 100644 --- a/src/parallel/context_parallel.py +++ b/src/parallel/context_parallel.py @@ -15,7 +15,8 @@ class RingAttentionFunc(torch.autograd.Function): @staticmethod def forward(ctx, q, k, v, sm_scale, is_causal): comm = ContextComms("comm") - #NOTE: Find a better to save these tensors without cloning + #TODO(fmom): add flash attention + #TODO(fmom): Find a better way to save these tensors without cloning k_og = k.clone() v_og = v.clone() out, lse = None, None