torch_ext/training/pp.py

2025-03-28 22:19:03 +08:00
# coding=utf-8
import torch
import torch.nn as nn


class PipelineParallel(object):
    """Placeholder for pipeline-parallel training; no behavior is implemented yet."""

    def __init__(self):
        pass
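
The stub above names pipeline parallelism but does not implement anything yet. Purely as an illustration, and not as the author's intended design, here is a pure-Python sketch of a GPipe-style "fill-drain" forward schedule: the batch is split into micro-batches, and stage `s` processes micro-batch `m` at clock tick `s + m`, so different stages work on different micro-batches at the same tick. The names `gpipe_forward_schedule` and `run_pipeline` are hypothetical. Stages are plain Python callables rather than `torch.nn.Module` partitions on separate devices, and everything runs sequentially in one process. The sketch therefore shows the ordering of work, not real overlap across GPUs.

```python
from typing import Callable, List, Sequence, Tuple


def gpipe_forward_schedule(num_stages: int, num_microbatches: int) -> List[List[Tuple[int, int]]]:
    """Return, for each clock tick, the (stage, micro-batch) pairs that run.

    Stage s can only start micro-batch m after stage s - 1 has finished it,
    so in the ideal fill-drain schedule (s, m) runs at tick s + m.
    """
    num_ticks = num_stages + num_microbatches - 1
    schedule: List[List[Tuple[int, int]]] = [[] for _ in range(max(num_ticks, 0))]
    for s in range(num_stages):
        for m in range(num_microbatches):
            schedule[s + m].append((s, m))
    return schedule


def run_pipeline(stages: Sequence[Callable[[list], list]], batch: list, num_microbatches: int) -> list:
    """Hypothetical helper: split `batch` into micro-batches and push them
    through `stages` in fill-drain order (sequentially, single process)."""
    size = max(1, -(-len(batch) // num_microbatches))  # ceiling division
    # activations[m] holds the current activation of micro-batch m.
    activations = [batch[i:i + size] for i in range(0, len(batch), size)]
    for tick in gpipe_forward_schedule(len(stages), len(activations)):
        for s, m in tick:
            activations[m] = stages[s](activations[m])
    # Concatenate the micro-batch outputs back into one batch, in order.
    return [x for mb in activations for x in mb]


if __name__ == "__main__":
    stages = [
        lambda xs: [x + 1 for x in xs],  # stage 0
        lambda xs: [x * 2 for x in xs],  # stage 1
        lambda xs: [x - 3 for x in xs],  # stage 2
    ]
    print(run_pipeline(stages, list(range(8)), 4))  # [-1, 1, 3, 5, 7, 9, 11, 13]
```

With S stages and M micro-batches the forward pass takes M + S - 1 ticks, and each stage is idle for S - 1 of them, so the idle ("bubble") fraction is (S - 1)/(M + S - 1). Raising M shrinks the bubble, which is why pipeline schemes of this kind split the batch into many micro-batches.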