14 lines
300 B
Python
14 lines
300 B
Python
# coding=utf-8
|
|
import torch
|
|
import torch_cuda_ext.core as core
|
|
import torch.nn.functional as F
|
|
|
|
# Parameters passed to both RMSNorm implementations exercised below.
eps = 0.01
gamma = 1.0

# Random (100, 1024) input: sampled on the CPU, cast to fp16, then moved to
# the GPU.
states = (
    torch.randn(100, 1024)
    .to(torch.float16)
    .cuda()
)

# Reference result from PyTorch's built-in RMSNorm over the trailing
# 1024-wide dimension; no weight is passed, so no learned scale is applied.
res_states = F.rms_norm(states, normalized_shape=[1024], eps=eps)
print(res_states)

# Same input through the custom extension.
# NOTE(review): the return value is discarded and `states` is printed right
# after, which suggests the kernel overwrites `states` in place -- confirm
# against the extension's implementation.
core.rms_norm(states, eps, gamma)
print(states)
|