folder refactoring + split cp & pp communications near implementation details

parent 1a000975be, commit db926026a6

@@ -1,11 +1,10 @@
# Inspired by https://github.com/zhuzilin/ring-flash-attention
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributed as dist
from typing import Any, Optional, Tuple
-from picotron.distributed.distributed_primtives import ContextComms

import picotron.process_group_manager as pgm
+from picotron.context_parallel.cp_communications import ContextCommunicate

def ring_attention(q, k, v, sm_scale, is_causal):
    return RingAttentionFunc.apply(q, k, v, sm_scale, is_causal)
@@ -14,7 +13,7 @@ class RingAttentionFunc(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, sm_scale, is_causal):
-        comm = ContextComms("comm")
+        comm = ContextCommunicate("comm")
        #TODO(fmom): add flash attention
        #TODO(fmom): Find a better way to save these tensors without cloning
        k_og = k.clone()
@@ -52,8 +51,8 @@ class RingAttentionFunc(torch.autograd.Function):
        sm_scale = ctx.sm_scale
        is_causal = ctx.is_causal

-        kv_comm = ContextComms("kv_comm")
-        d_kv_comm = ContextComms("d_kv_comm")
+        kv_comm = ContextCommunicate("kv_comm")
+        d_kv_comm = ContextCommunicate("d_kv_comm")
        dq, dk, dv = None, None, None
        next_dk, next_dv = None, None
@@ -187,4 +186,4 @@ def update_rope_for_context_parallel(cos, sin):
    assert seq_len % cp_word_size == 0, f"Input sequence length ({seq_len}) must be divisible by cp_world_size ({cp_word_size})"
    size_per_partition = seq_len // cp_word_size
    start_idx, end_idx = cp_rank * size_per_partition, (cp_rank + 1) * size_per_partition
-    return cos[start_idx:end_idx], sin[start_idx:end_idx]
+    return cos[start_idx:end_idx], sin[start_idx:end_idx]
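For intuition, a tiny worked example of the slicing above (the numbers are illustrative, not taken from the diff):

# Illustration only: RoPE table partitioning across context-parallel ranks.
seq_len, cp_world_size, cp_rank = 8, 4, 1
size_per_partition = seq_len // cp_world_size                                           # 2
start_idx, end_idx = cp_rank * size_per_partition, (cp_rank + 1) * size_per_partition   # (2, 4)
# rank 1 keeps cos[2:4] and sin[2:4]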
picotron/context_parallel/cp_communications.py (new file, 54 lines)
@@ -0,0 +1,54 @@
import os
import torch
from torch import distributed as dist
from typing import List

import picotron.process_group_manager as pgm

STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1"

class ContextCommunicate:
    def __init__(self, msg: str = ""):
        global STEP
        global VERBOSE
        self._pending_operations: List[dist.P2POp] = []
        self._active_requests = None
        self.rank = pgm.process_group_manager.cp_rank
        self.world_size = pgm.process_group_manager.cp_world_size
        self.send_rank = pgm.process_group_manager.cp_send_rank
        self.recv_rank = pgm.process_group_manager.cp_recv_rank
        if VERBOSE: print(f"RingComm ({msg}) | initialized | RANK:{self.rank} | "f"WORLD_SIZE:{self.world_size} | SEND_RANK:{self.send_rank} | "f"RECV_RANK:{self.recv_rank}", flush=True)

    def send_recv(self, tensor_to_send, recv_tensor=None):
        if recv_tensor is None:
            result_tensor = torch.zeros_like(tensor_to_send)
        else:
            result_tensor = recv_tensor

        send_operation = dist.P2POp(dist.isend, tensor_to_send, self.send_rank, group=pgm.process_group_manager.cp_group)
        recv_operation = dist.P2POp(dist.irecv, result_tensor, self.recv_rank, group=pgm.process_group_manager.cp_group)

        self._pending_operations.extend([send_operation, recv_operation])

        if VERBOSE:
            print(f"RingComm | send_recv | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:sending | TO:{self.send_rank} | TENSOR:{tensor_to_send}", flush=True)
            print(f"RingComm | send_recv | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:receiving | FROM:{self.recv_rank} | TENSOR:{result_tensor}", flush=True)
        return result_tensor

    def commit(self):
        if self._active_requests is not None: raise RuntimeError("Commit called twice")
        self._active_requests = dist.batch_isend_irecv(self._pending_operations)
        if VERBOSE: print(f"RingComm | commit | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:committed | NUM_OPS:{len(self._pending_operations) // 2}", flush=True)

    def wait(self):
        if self._active_requests is None: raise RuntimeError("Wait called before commit")
        for i, request in enumerate(self._active_requests):
            request.wait()
            if VERBOSE:
                operation_type = "send" if i % 2 == 0 else "receive"
                peer_rank = self.send_rank if operation_type == "send" else self.recv_rank
                print(f"RingComm | wait | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:completed_{operation_type} | "f"{'FROM' if operation_type == 'receive' else 'TO'}:{peer_rank}", flush=True)
        torch.cuda.synchronize()
        self._active_requests = None
        self._pending_operations = []
        if VERBOSE: print(f"RingComm | wait | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:all_operations_completed", flush=True)
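For orientation, a minimal usage sketch of the class above (not part of the diff; `k`, `v` and the surrounding ring loop are illustrative placeholders for one ring step):

# Illustrative only: one ring exchange of key/value shards with ContextCommunicate.
comm = ContextCommunicate("kv_ring")
next_k = comm.send_recv(k)   # queue isend(k) to send_rank and irecv into next_k from recv_rank
next_v = comm.send_recv(v)
comm.commit()                # launch all queued P2P ops as a single batch
# ... compute attention with the local k, v while the exchange is in flight ...
comm.wait()                  # block until the batch completes, then rotate the shards
k, v = next_k, next_v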
@@ -1,9 +1,9 @@
import contextlib
import torch
import torch.distributed as dist
import contextlib
from torch import nn
from torch.autograd import Variable
-from picotron.parallel.data_parallel.bucket import BucketManager

+from picotron.data_parallel.bucket import BucketManager
import picotron.process_group_manager as pgm

class DataParallel(nn.Module):
@@ -1,107 +0,0 @@
import os
import picotron.process_group_manager as pgm
from typing import List, Optional
import torch, torch.distributed as dist

STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1"

def pipeline_communicate(operation, device, dtype, tensor=None, shapes=None):
    global STEP
    global VERBOSE
    if operation == 'recv_forward':
        if pgm.process_group_manager.pp_is_first_stage: return None
        tensor = torch.empty(shapes, requires_grad=True, device=device, dtype=dtype)
        src = pgm.process_group_manager.pp_prev_rank
    elif operation == 'send_forward':
        if pgm.process_group_manager.pp_is_last_stage: return
        dest = pgm.process_group_manager.pp_next_rank
    elif operation == 'recv_backward':
        if pgm.process_group_manager.pp_is_last_stage: return None
        tensor = torch.empty(shapes, requires_grad=True, device=device, dtype=dtype)
        src = pgm.process_group_manager.pp_next_rank
    elif operation == 'send_backward':
        if pgm.process_group_manager.pp_is_first_stage: return
        dest = pgm.process_group_manager.pp_prev_rank
    is_send = operation.startswith('send')
    peer_rank = dest if is_send else src
    op = dist.P2POp(dist.isend if is_send else dist.irecv, tensor, peer_rank)
    if VERBOSE: print(f"{operation} | {'sending' if is_send else 'receiving'} {operation.split('_')[1]} {pgm.process_group_manager.pp_rank} {'→' if is_send else '←'} {peer_rank} | STEP:{STEP} | RANK:{pgm.process_group_manager.pp_rank}", flush=True)
    [req.wait() for req in dist.batch_isend_irecv([op])]
    torch.cuda.synchronize()
    if VERBOSE: STEP += 1
    return tensor if not is_send else None

def bidirectional_pipeline_communicate(operation, send_tensor, recv_shapes, device, dtype):
    global STEP
    global VERBOSE
    is_fwd = (operation == 'send_fwd_recv_bwd')
    if (is_fwd and pgm.process_group_manager.pp_is_last_stage) or (not is_fwd and pgm.process_group_manager.pp_is_first_stage): return None
    peer_rank = pgm.process_group_manager.pp_next_rank if is_fwd else pgm.process_group_manager.pp_prev_rank
    recv_tensor = torch.empty(recv_shapes, requires_grad=True, device=device, dtype=dtype)
    reqs = dist.batch_isend_irecv([dist.P2POp(dist.isend, send_tensor, peer_rank), dist.P2POp(dist.irecv, recv_tensor, peer_rank)])
    if VERBOSE: print(f"{operation} | sending {'next' if is_fwd else 'prev'} {pgm.process_group_manager.pp_rank} -> {peer_rank} | "f"receiving {'next' if is_fwd else 'prev'} {peer_rank} -> {pgm.process_group_manager.pp_rank} | "f"STEP {STEP=} | RANK:{pgm.process_group_manager.pp_rank}", flush=True)
    [req.wait() for req in reqs]
    torch.cuda.synchronize()
    if VERBOSE: STEP += 1
    return recv_tensor

class ContextComms:
    def __init__(self, msg: str = ""):
        global STEP
        global VERBOSE
        self._pending_operations: List[dist.P2POp] = []
        self._active_requests = None
        self.rank = pgm.process_group_manager.cp_rank
        self.world_size = pgm.process_group_manager.cp_world_size
        self.send_rank = pgm.process_group_manager.cp_send_rank
        self.recv_rank = pgm.process_group_manager.cp_recv_rank
        if VERBOSE: print(f"RingComm ({msg}) | initialized | RANK:{self.rank} | "f"WORLD_SIZE:{self.world_size} | SEND_RANK:{self.send_rank} | "f"RECV_RANK:{self.recv_rank}", flush=True)

    def send_recv(self, tensor_to_send, recv_tensor=None):
        if recv_tensor is None:
            result_tensor = torch.zeros_like(tensor_to_send)
        else:
            result_tensor = recv_tensor

        send_operation = dist.P2POp(dist.isend, tensor_to_send, self.send_rank, group=pgm.process_group_manager.cp_group)
        recv_operation = dist.P2POp(dist.irecv, result_tensor, self.recv_rank, group=pgm.process_group_manager.cp_group)

        self._pending_operations.extend([send_operation, recv_operation])

        if VERBOSE:
            print(f"RingComm | send_recv | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:sending | TO:{self.send_rank} | TENSOR:{tensor_to_send}", flush=True)
            print(f"RingComm | send_recv | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:receiving | FROM:{self.recv_rank} | TENSOR:{result_tensor}", flush=True)
        return result_tensor

    def commit(self):
        if self._active_requests is not None: raise RuntimeError("Commit called twice")
        self._active_requests = dist.batch_isend_irecv(self._pending_operations)
        if VERBOSE: print(f"RingComm | commit | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:committed | NUM_OPS:{len(self._pending_operations) // 2}", flush=True)

    def wait(self):
        if self._active_requests is None: raise RuntimeError("Wait called before commit")
        for i, request in enumerate(self._active_requests):
            request.wait()
            if VERBOSE:
                operation_type = "send" if i % 2 == 0 else "receive"
                peer_rank = self.send_rank if operation_type == "send" else self.recv_rank
                print(f"RingComm | wait | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:completed_{operation_type} | "f"{'FROM' if operation_type == 'receive' else 'TO'}:{peer_rank}", flush=True)
        torch.cuda.synchronize()
        self._active_requests = None
        self._pending_operations = []
        if VERBOSE: print(f"RingComm | wait | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:all_operations_completed", flush=True)

def all_reduce_loss_across_dp_cp_ranks(loss, device):
    reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
    # only the last stage of the pipeline parallelism contains the loss
    # we need to average the loss among the data/context parallel group
    if pgm.process_group_manager.pp_is_last_stage:
        dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.cp_dp_group)
    return reduced_loss.item()

def all_reduce_gradients_across_dp_cp_ranks(model):
    for param in model.parameters():
        if param.grad is not None:
            # Average the gradients across all DP & CP ranks
            param.grad /= pgm.process_group_manager.cp_dp_world_size
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.cp_dp_group)
@@ -2,7 +2,7 @@ import os
import torch
import torch.nn as nn
import torch.nn.functional as F
-from picotron.parallel.context_parallel import ring_attention, update_rope_for_context_parallel
+from picotron.context_parallel import context_parallel
from flash_attn.flash_attn_interface import flash_attn_func
from flash_attn.layers.rotary import apply_rotary_emb
from flash_attn.ops.triton.layer_norm import layer_norm_fn
@@ -122,7 +122,7 @@ class Attention(nn.Module):
        if pgm.process_group_manager.cp_world_size > 1:
            # Ring attention for context parallelism
            sm_scale = 1.0 / (q.size(-1) ** 0.5)
-            out = ring_attention(q, k, v, sm_scale, causal).transpose(1, 2) # [batch_size, seq_length, num_heads, head_dim]
+            out = context_parallel.ring_attention(q, k, v, sm_scale, causal).transpose(1, 2) # [batch_size, seq_length, num_heads, head_dim]
        elif os.getenv('FLASH_ATTEN', '1') == '1':
            # flash attention, this is faster!
            out = flash_attention(q, k, v, causal = causal) # [batch_size, seq_length, num_heads, head_dim]
@@ -161,7 +161,7 @@ class DecoderLayer(nn.Module):
        self.cos, self.sin = get_cos_sin(config.max_position_embeddings, head_dim=head_dim , base=config.rope_theta) # [max_position_embeddings, head_dim]

        # For context parallelism, we split the input. We need to get the correct cos and sin for each split
-        self.cos, self.sin = update_rope_for_context_parallel(self.cos, self.sin)
+        self.cos, self.sin = context_parallel.update_rope_for_context_parallel(self.cos, self.sin)

    def forward(self, x, attention_mask = None, position_ids = None):
        #TODO: Use the default position_ids for RoPE during training. If we have time, work on generation
@@ -1,7 +1,9 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import picotron.process_group_manager as pgm
-from picotron.distributed.distributed_primtives import pipeline_communicate, bidirectional_pipeline_communicate
import torch, torch.nn as nn, torch.nn.functional as F
import os
+from picotron.pipeline_parallel.pp_communications import pipeline_communicate, bidirectional_pipeline_communicate

class PipelineParallel(nn.Module):
    def __init__(self, model, config):
picotron/pipeline_parallel/pp_communications.py (new file, 46 lines)
@@ -0,0 +1,46 @@
import os
import torch
import torch.distributed as dist
import picotron.process_group_manager as pgm

STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1"

def pipeline_communicate(operation, device, dtype, tensor=None, shapes=None):
    global STEP
    global VERBOSE
    if operation == 'recv_forward':
        if pgm.process_group_manager.pp_is_first_stage: return None
        tensor = torch.empty(shapes, requires_grad=True, device=device, dtype=dtype)
        src = pgm.process_group_manager.pp_prev_rank
    elif operation == 'send_forward':
        if pgm.process_group_manager.pp_is_last_stage: return
        dest = pgm.process_group_manager.pp_next_rank
    elif operation == 'recv_backward':
        if pgm.process_group_manager.pp_is_last_stage: return None
        tensor = torch.empty(shapes, requires_grad=True, device=device, dtype=dtype)
        src = pgm.process_group_manager.pp_next_rank
    elif operation == 'send_backward':
        if pgm.process_group_manager.pp_is_first_stage: return
        dest = pgm.process_group_manager.pp_prev_rank
    is_send = operation.startswith('send')
    peer_rank = dest if is_send else src
    op = dist.P2POp(dist.isend if is_send else dist.irecv, tensor, peer_rank)
    if VERBOSE: print(f"{operation} | {'sending' if is_send else 'receiving'} {operation.split('_')[1]} {pgm.process_group_manager.pp_rank} {'→' if is_send else '←'} {peer_rank} | STEP:{STEP} | RANK:{pgm.process_group_manager.pp_rank}", flush=True)
    [req.wait() for req in dist.batch_isend_irecv([op])]
    torch.cuda.synchronize()
    if VERBOSE: STEP += 1
    return tensor if not is_send else None

def bidirectional_pipeline_communicate(operation, send_tensor, recv_shapes, device, dtype):
    global STEP
    global VERBOSE
    is_fwd = (operation == 'send_fwd_recv_bwd')
    if (is_fwd and pgm.process_group_manager.pp_is_last_stage) or (not is_fwd and pgm.process_group_manager.pp_is_first_stage): return None
    peer_rank = pgm.process_group_manager.pp_next_rank if is_fwd else pgm.process_group_manager.pp_prev_rank
    recv_tensor = torch.empty(recv_shapes, requires_grad=True, device=device, dtype=dtype)
    reqs = dist.batch_isend_irecv([dist.P2POp(dist.isend, send_tensor, peer_rank), dist.P2POp(dist.irecv, recv_tensor, peer_rank)])
    if VERBOSE: print(f"{operation} | sending {'next' if is_fwd else 'prev'} {pgm.process_group_manager.pp_rank} -> {peer_rank} | "f"receiving {'next' if is_fwd else 'prev'} {peer_rank} -> {pgm.process_group_manager.pp_rank} | "f"STEP {STEP=} | RANK:{pgm.process_group_manager.pp_rank}", flush=True)
    [req.wait() for req in reqs]
    torch.cuda.synchronize()
    if VERBOSE: STEP += 1
    return recv_tensor
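For orientation, a minimal sketch (not part of the diff) of how these helpers are typically driven from a pipeline schedule; `stage_forward`, `stage_backward`, and the shape/dtype values are hypothetical placeholders:

# Illustrative only: one microbatch exchange between neighbouring pipeline stages.
shapes = (micro_batch_size, seq_length, hidden_dim)
input_tensor  = pipeline_communicate('recv_forward', device, dtype, shapes=shapes)   # None on the first stage
output_tensor = stage_forward(input_tensor)
pipeline_communicate('send_forward', device, dtype, tensor=output_tensor)            # no-op on the last stage

output_grad = pipeline_communicate('recv_backward', device, dtype, shapes=shapes)    # None on the last stage
input_grad  = stage_backward(output_tensor, output_grad)
pipeline_communicate('send_backward', device, dtype, tensor=input_grad)              # no-op on the first stage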
@@ -2,7 +2,7 @@
Inspired by Fair Scale/Megatron's Tensor Parallelism implementation
Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
"""
-from picotron.parallel.tensor_parallel.utils import VocabUtility
+from picotron.tensor_parallel.utils import VocabUtility
import torch
import math
import torch.nn.init as init
@@ -10,7 +10,7 @@ import torch.nn.functional as F
from torch.nn.parameter import Parameter
from typing import Callable, Optional
import picotron.process_group_manager as pgm
-from picotron.parallel.tensor_parallel.mappings import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region
+from picotron.tensor_parallel.mappings import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region

def initialize_weight_tensor(weight, vocab_embedding=False):
    """
@@ -2,7 +2,7 @@
Inspired by Fair Scale/Megatron's Tensor Parallelism implementation
Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
"""
-from picotron.parallel.tensor_parallel.utils import split_tensor_along_last_dim
+from picotron.tensor_parallel.utils import split_tensor_along_last_dim
import torch.distributed as dist
import torch
import picotron.process_group_manager as pgm
@@ -1,5 +1,5 @@
from functools import partial
-from picotron.parallel.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding, initialize_weight_tensor
+from picotron.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding, initialize_weight_tensor
import torch.nn.init as init
import torch.nn as nn
train.py (17 changed lines)
@@ -20,16 +20,23 @@ import torch, torch.distributed as dist
from torch.optim import AdamW
from transformers import AutoConfig
import numpy as np
-from picotron.parallel.tensor_parallel.tensor_parallel import TensorParallel
+from picotron.tensor_parallel.tensor_parallel import TensorParallel
import picotron.process_group_manager as pgm
from picotron.utils import set_all_seed, print, to_readable_format, save_checkpoint, load_checkpoint
from picotron.data import MicroBatchDataLoader
from picotron.process_group_manager import setup_process_group_manager
-from picotron.parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
-from picotron.parallel.data_parallel.data_parallel_bucket import DataParallel
-from model import Llama
+from picotron.pipeline_parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
+from picotron.data_parallel.data_parallel_bucket import DataParallel
+from picotron.model import Llama
import wandb
-from picotron.distributed.distributed_primtives import all_reduce_loss_across_dp_cp_ranks

+def all_reduce_loss_across_dp_cp_ranks(loss, device):
+    reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
+    # only the last stage of the pipeline parallelism contains the loss
+    # we need to average the loss among the data/context parallel group
+    if pgm.process_group_manager.pp_is_last_stage:
+        dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.cp_dp_group)
+    return reduced_loss.item()

def train_step(model, data_loader, device):
    acc_loss = 0.0