torch_ext/test_layernorm.py
2024-12-14 13:34:30 +08:00

14 lines
300 B
Python

# coding=utf-8
import torch
import torch_cuda_ext.core as core
import torch.nn.functional as F
# Compare the custom CUDA extension's rms_norm against PyTorch's reference
# F.rms_norm on a random fp16 input of shape (100, 1024).
eps = float(0.01)
# Scale handed to the extension kernel. The reference call below passes no
# weight, so 1.0 is presumably the matching setting -- TODO confirm against
# torch_cuda_ext.
gamma = float(1)
states = torch.randn(size=(100, 1024)).half().cuda()

# Reference result, normalized over the last dimension (size 1024).
# It is computed before the extension call so it sees the original input.
res_states = F.rms_norm(states, [1024], eps=eps)
print(res_states)

# The return value is discarded and `states` is printed afterwards, so the
# extension presumably normalizes `states` in place -- TODO confirm.
core.rms_norm(states, eps, gamma)
print(states)

# Printing a 100x1024 tensor elides most elements, so eyeballing the two
# dumps cannot reveal a mismatch. Report the largest elementwise gap instead;
# the subtraction is done in fp32 so it adds no extra half-precision rounding.
max_abs_diff = (res_states.float() - states.float()).abs().max().item()
print(f"max abs diff vs F.rms_norm: {max_abs_diff:.6e}")