# coding=utf-8 import torch import torch_cuda_ext.core as core import torch.nn.functional as F eps = float(0.01) gamma = float(1) states = torch.randn(size=(100, 1024)).half().cuda() res_states = F.rms_norm(states, [1024], eps=eps) print(res_states) core.rms_norm(states, eps, gamma) print(states)