//
// torch_ext/csrc/softmax.ptx
//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-34097967
// Cuda compilation tools, release 12.4, V12.4.131
// Based on NVVM 7.0.1
//
.version 8.4
.target sm_52
.address_size 64
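// This module contains five kernels implementing a multi-pass softmax over a
// float array, presumably launched by the host side of the extension in the
// order findMax -> computeExp -> computeSum -> computeSoftmax:
//   findMax        - per-block max reduction into sharedMax (one partial max per block)
//   computeExp     - elementwise exp(x[i] - maxVal)
//   block_softmax  - empty stub (body is a bare ret)
//   computeSum     - per-block add reduction into sharedSum (one partial sum per block)
//   computeSoftmax - elementwise in-place division by sumVal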
// .globl _Z7findMaxPKfPfi
.extern .shared .align 16 .b8 sharedMax[];
.extern .shared .align 16 .b8 sharedSum[];
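// findMax(const float* input, float* output, int n):
// each thread loads input[i] for its global index i (or -inf when i >= n)
// into sharedMax[threadIdx.x], the block reduces sharedMax with max.f32,
// and thread 0 writes the block maximum to output[blockIdx.x].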
.visible .entry _Z7findMaxPKfPfi(
.param .u64 _Z7findMaxPKfPfi_param_0,
.param .u64 _Z7findMaxPKfPfi_param_1,
.param .u32 _Z7findMaxPKfPfi_param_2
)
{
.reg .pred %p<6>;
.reg .f32 %f<9>;
.reg .b32 %r<15>;
.reg .b64 %rd<9>;
ld.param.u64 %rd1, [_Z7findMaxPKfPfi_param_0];
ld.param.u64 %rd2, [_Z7findMaxPKfPfi_param_1];
ld.param.u32 %r9, [_Z7findMaxPKfPfi_param_2];
mov.u32 %r1, %ntid.x;
mov.u32 %r2, %ctaid.x;
mov.u32 %r3, %tid.x;
mad.lo.s32 %r4, %r2, %r1, %r3;
setp.ge.u32 %p1, %r4, %r9;
mov.f32 %f8, 0fFF800000;
@%p1 bra $L__BB0_2;
cvta.to.global.u64 %rd3, %rd1;
mul.wide.u32 %rd4, %r4, 4;
add.s64 %rd5, %rd3, %rd4;
ld.global.f32 %f8, [%rd5];
$L__BB0_2:
shl.b32 %r10, %r3, 2;
mov.u32 %r11, sharedMax;
add.s32 %r5, %r11, %r10;
st.shared.f32 [%r5], %f8;
bar.sync 0;
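// Tree reduction: the stride in %r14 starts at blockDim.x / 2 and halves each
// iteration; threads with tid < stride take max(sharedMax[tid],
// sharedMax[tid + stride]), with a bar.sync between steps.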
shr.u32 %r14, %r1, 1;
setp.eq.s32 %p2, %r14, 0;
@%p2 bra $L__BB0_6;
$L__BB0_3:
setp.ge.u32 %p3, %r3, %r14;
@%p3 bra $L__BB0_5;
ld.shared.f32 %f4, [%r5];
shl.b32 %r12, %r14, 2;
add.s32 %r13, %r5, %r12;
ld.shared.f32 %f5, [%r13];
max.f32 %f6, %f4, %f5;
st.shared.f32 [%r5], %f6;
$L__BB0_5:
bar.sync 0;
shr.u32 %r14, %r14, 1;
setp.ne.s32 %p4, %r14, 0;
@%p4 bra $L__BB0_3;
$L__BB0_6:
setp.ne.s32 %p5, %r3, 0;
@%p5 bra $L__BB0_8;
ld.shared.f32 %f7, [sharedMax];
cvta.to.global.u64 %rd6, %rd2;
mul.wide.u32 %rd7, %r2, 4;
add.s64 %rd8, %rd6, %rd7;
st.global.f32 [%rd8], %f7;
$L__BB0_8:
ret;
}
// .globl _Z10computeExpPKffPfi
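// computeExp(const float* input, float maxVal, float* output, int n):
// elementwise output[i] = expf(input[i] - maxVal) for i < n, with expf
// expanded inline by the compiler (see the annotated sequence below).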
.visible .entry _Z10computeExpPKffPfi(
.param .u64 _Z10computeExpPKffPfi_param_0,
.param .f32 _Z10computeExpPKffPfi_param_1,
.param .u64 _Z10computeExpPKffPfi_param_2,
.param .u32 _Z10computeExpPKffPfi_param_3
)
{
.reg .pred %p<2>;
.reg .f32 %f<20>;
.reg .b32 %r<8>;
.reg .b64 %rd<8>;
ld.param.u64 %rd1, [_Z10computeExpPKffPfi_param_0];
ld.param.f32 %f1, [_Z10computeExpPKffPfi_param_1];
ld.param.u64 %rd2, [_Z10computeExpPKffPfi_param_2];
ld.param.u32 %r2, [_Z10computeExpPKffPfi_param_3];
mov.u32 %r3, %ctaid.x;
mov.u32 %r4, %ntid.x;
mov.u32 %r5, %tid.x;
mad.lo.s32 %r1, %r3, %r4, %r5;
setp.ge.u32 %p1, %r1, %r2;
@%p1 bra $L__BB1_2;
cvta.to.global.u64 %rd3, %rd1;
mul.wide.u32 %rd4, %r1, 4;
add.s64 %rd5, %rd3, %rd4;
ld.global.f32 %f2, [%rd5];
sub.f32 %f3, %f2, %f1;
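// Inline fast-exp expansion of expf(%f3): the argument is scaled by log2(e)
// and split into an integer part k and a remainder; ex2.approx.ftz evaluates
// 2^remainder, and 2^k is rebuilt by shifting the integer into the float
// exponent field (shl.b32 ... 23) before the final multiply.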
mov.f32 %f4, 0f3F000000;
mov.f32 %f5, 0f3BBB989D;
fma.rn.f32 %f6, %f3, %f5, %f4;
cvt.sat.f32.f32 %f7, %f6;
mov.f32 %f8, 0f4B400001;
mov.f32 %f9, 0f437C0000;
fma.rm.f32 %f10, %f7, %f9, %f8;
add.f32 %f11, %f10, 0fCB40007F;
neg.f32 %f12, %f11;
mov.f32 %f13, 0f3FB8AA3B;
fma.rn.f32 %f14, %f3, %f13, %f12;
mov.f32 %f15, 0f32A57060;
fma.rn.f32 %f16, %f3, %f15, %f14;
mov.b32 %r6, %f10;
shl.b32 %r7, %r6, 23;
mov.b32 %f17, %r7;
ex2.approx.ftz.f32 %f18, %f16;
mul.f32 %f19, %f18, %f17;
cvta.to.global.u64 %rd6, %rd2;
add.s64 %rd7, %rd6, %rd4;
st.global.f32 [%rd7], %f19;
$L__BB1_2:
ret;
}
// .globl _Z13block_softmaxPKf
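// block_softmax(const float*): empty stub; only a ret was emitted, so the
// body was either never implemented or optimized away entirely.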
.visible .entry _Z13block_softmaxPKf(
.param .u64 _Z13block_softmaxPKf_param_0
)
{
ret;
}
// .globl _Z10computeSumPKfPfi
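// computeSum(const float* input, float* output, int n):
// same block-reduction pattern as findMax, but with add.f32 in place of
// max.f32 and 0.0f as the identity for out-of-range threads; thread 0
// writes the block partial sum to output[blockIdx.x].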
.visible .entry _Z10computeSumPKfPfi(
.param .u64 _Z10computeSumPKfPfi_param_0,
.param .u64 _Z10computeSumPKfPfi_param_1,
.param .u32 _Z10computeSumPKfPfi_param_2
)
{
.reg .pred %p<6>;
.reg .f32 %f<9>;
.reg .b32 %r<15>;
.reg .b64 %rd<9>;
ld.param.u64 %rd1, [_Z10computeSumPKfPfi_param_0];
ld.param.u64 %rd2, [_Z10computeSumPKfPfi_param_1];
ld.param.u32 %r9, [_Z10computeSumPKfPfi_param_2];
mov.u32 %r1, %ntid.x;
mov.u32 %r2, %ctaid.x;
mov.u32 %r3, %tid.x;
mad.lo.s32 %r4, %r2, %r1, %r3;
setp.ge.u32 %p1, %r4, %r9;
mov.f32 %f8, 0f00000000;
@%p1 bra $L__BB3_2;
cvta.to.global.u64 %rd3, %rd1;
mul.wide.u32 %rd4, %r4, 4;
add.s64 %rd5, %rd3, %rd4;
ld.global.f32 %f8, [%rd5];
$L__BB3_2:
shl.b32 %r10, %r3, 2;
mov.u32 %r11, sharedSum;
add.s32 %r5, %r11, %r10;
st.shared.f32 [%r5], %f8;
bar.sync 0;
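// Tree reduction over sharedSum, mirroring the findMax loop above.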
shr.u32 %r14, %r1, 1;
setp.eq.s32 %p2, %r14, 0;
@%p2 bra $L__BB3_6;
$L__BB3_3:
setp.ge.u32 %p3, %r3, %r14;
@%p3 bra $L__BB3_5;
shl.b32 %r12, %r14, 2;
add.s32 %r13, %r5, %r12;
ld.shared.f32 %f4, [%r5];
ld.shared.f32 %f5, [%r13];
add.f32 %f6, %f5, %f4;
st.shared.f32 [%r5], %f6;
$L__BB3_5:
bar.sync 0;
shr.u32 %r14, %r14, 1;
setp.ne.s32 %p4, %r14, 0;
@%p4 bra $L__BB3_3;
$L__BB3_6:
setp.ne.s32 %p5, %r3, 0;
@%p5 bra $L__BB3_8;
ld.shared.f32 %f7, [sharedSum];
cvta.to.global.u64 %rd6, %rd2;
mul.wide.u32 %rd7, %r2, 4;
add.s64 %rd8, %rd6, %rd7;
st.global.f32 [%rd8], %f7;
$L__BB3_8:
ret;
}
// .globl _Z14computeSoftmaxPffi
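// computeSoftmax(float* data, float sumVal, int n):
// in-place normalization, data[i] = data[i] / sumVal for i < n.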
.visible .entry _Z14computeSoftmaxPffi(
.param .u64 _Z14computeSoftmaxPffi_param_0,
.param .f32 _Z14computeSoftmaxPffi_param_1,
.param .u32 _Z14computeSoftmaxPffi_param_2
)
{
.reg .pred %p<2>;
.reg .f32 %f<4>;
.reg .b32 %r<6>;
.reg .b64 %rd<5>;
ld.param.u64 %rd1, [_Z14computeSoftmaxPffi_param_0];
ld.param.f32 %f1, [_Z14computeSoftmaxPffi_param_1];
ld.param.u32 %r2, [_Z14computeSoftmaxPffi_param_2];
mov.u32 %r3, %ctaid.x;
mov.u32 %r4, %ntid.x;
mov.u32 %r5, %tid.x;
mad.lo.s32 %r1, %r3, %r4, %r5;
setp.ge.u32 %p1, %r1, %r2;
@%p1 bra $L__BB4_2;
cvta.to.global.u64 %rd2, %rd1;
mul.wide.u32 %rd3, %r1, 4;
add.s64 %rd4, %rd2, %rd3;
ld.global.f32 %f2, [%rd4];
div.rn.f32 %f3, %f2, %f1;
st.global.f32 [%rd4], %f3;
$L__BB4_2:
ret;
}
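//
// A plausible CUDA C++ source sketch for the kernels above, reconstructed
// from the mangled names and the generated code. Parameter names, the
// shared-memory declarations, and the use of expf()/fmaxf() are assumptions;
// only the signatures and the overall structure are implied by this PTX.
//
// #include <math.h>
//
// extern __shared__ float sharedMax[];
// extern __shared__ float sharedSum[];
//
// __global__ void findMax(const float* input, float* output, int n) {
//     int i = blockIdx.x * blockDim.x + threadIdx.x;
//     // Out-of-range threads contribute -inf, the identity for max.
//     sharedMax[threadIdx.x] = (i < n) ? input[i] : -INFINITY;
//     __syncthreads();
//     for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
//         if (threadIdx.x < stride)
//             sharedMax[threadIdx.x] =
//                 fmaxf(sharedMax[threadIdx.x], sharedMax[threadIdx.x + stride]);
//         __syncthreads();
//     }
//     if (threadIdx.x == 0) output[blockIdx.x] = sharedMax[0];
// }
//
// __global__ void computeExp(const float* input, float maxVal,
//                            float* output, int n) {
//     int i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < n) output[i] = expf(input[i] - maxVal);
// }
//
// __global__ void block_softmax(const float*) {
//     // Empty in this build.
// }
//
// __global__ void computeSum(const float* input, float* output, int n) {
//     int i = blockIdx.x * blockDim.x + threadIdx.x;
//     // Out-of-range threads contribute 0, the identity for addition.
//     sharedSum[threadIdx.x] = (i < n) ? input[i] : 0.0f;
//     __syncthreads();
//     for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
//         if (threadIdx.x < stride)
//             sharedSum[threadIdx.x] += sharedSum[threadIdx.x + stride];
//         __syncthreads();
//     }
//     if (threadIdx.x == 0) output[blockIdx.x] = sharedSum[0];
// }
//
// __global__ void computeSoftmax(float* data, float sumVal, int n) {
//     int i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < n) data[i] /= sumVal;
// }
//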