| # This test generates all variants of wmma intrinsics and verifies that LLVM |
| # generates correct instructions for them. |
| |
| # RUN: python %s > %t.ll |
| # RUN: llc < %t.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 | FileCheck %t.ll |
| |
| from itertools import product |
| from string import Template |
| |
def make_wmma_slice_ty(abcd, itype):
    """Return the list of LLVM element types that make up one wmma slice.

    abcd  -- name of the slice ("a", "b", "c" or "d").
    itype -- element type name; "f16" selects "<2 x half>" elements, any
             other value selects "float".

    The "c" and "d" slices of type f16 have 4 elements; every other
    combination has 8.
    """
    is_f16 = itype == "f16"

    if is_f16:
        element = "<2 x half>"
    else:
        element = "float"

    count = 8
    if abcd in "cd" and is_f16:
        count = 4

    return [element] * count
| |
def make_wmma_ld_ret_ty(abc, itype):
    """Return the LLVM literal struct type string for a slice.

    The result has the form "{T, T, ...}" where the member types come from
    make_wmma_slice_ty(abc, itype).
    """
    members = make_wmma_slice_ty(abc, itype)
    return "{" + ", ".join(members) + "}"
| |
| # returns address space |
def get_aspace(space):
    """Map a state-space suffix (e.g. ".shared") to its address space number.

    Both the empty string and ".generic" map to address space 0.
    Raises KeyError for a name that is not in the table.
    """
    aspace_by_name = {
        "": 0,
        ".generic": 0,
        ".global": 1,
        ".shared": 3,
        ".const": 4,
        ".local": 5,
        ".param": 101,
    }
    return aspace_by_name[space]
| |
def get_pspace(space):
    """Return the "p<N>i8" pointer suffix for the given state space.

    <N> is the address space number reported by get_aspace(space).
    """
    aspace = get_aspace(space)
    return "p%di8" % aspace
| |
# Convenient test patterns.
def _reg_list_check(reg_pattern, count):
    """Build a {{...}} pattern matching `count` comma-separated registers."""
    return "{{" + ", *".join([reg_pattern] * count) + "}}"


check_f16_8 = _reg_list_check("%hh[0-9]+", 8)
check_f16_4 = _reg_list_check("%hh[0-9]+", 4)
check_f32_8 = _reg_list_check("%f[0-9]+", 8)

# Geometries for which tests are generated.
known_geoms = ["m16n16k16", "m8n32k16", "m32n8k16"]
| |
def gen_wmma_load_tests():
    """Print one LLVM IR test (with CHECK lines) per wmma.load variant.

    Variants are the product of geometry, slice (a/b/c), layout, state space,
    stride and element type.  For each variant two functions are printed:
    one that loads through %src directly and an "_o" one that loads through
    %src advanced by 128 (checked as a "+128" address offset).
    """
    test_body = Template("""
declare ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
%v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
ret ${ret_ty} %v0;
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define ${ret_ty} @test_${function}_o(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
%src1 = getelementptr i8, i8 ${as}* %src, i32 128;
%v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1 ${extra_args});
ret ${ret_ty} %v0;
}
""")
    intrinsic_name = Template(
        "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}")
    instruction_name = Template(
        "wmma.load.${abc}.sync.${layout}.${geom}${space}.${itype}")

    variants = product(
        known_geoms,
        "abc",
        ["row", "col"],
        ["", ".shared", ".global"],
        ["", ".stride"],
        ["f16", "f32"])

    for geom, abc, layout, space, stride, itype in variants:
        # f32 is only generated for the "c" slice; "a" and "b" are f16 only.
        if abc != "c" and itype == "f32":
            continue

        fields = {
            "abc": abc,
            "layout": layout,
            "space": space,
            "stride": stride,
            "itype": itype,
            "pspace": get_pspace(space),
            "as": "addrspace(%d)" % get_aspace(space),
            "geom": geom,
        }

        intrinsic = intrinsic_name.substitute(fields)
        fields["intrinsic"] = intrinsic
        fields["function"] = intrinsic.replace(".", "_")
        fields["instruction"] = instruction_name.substitute(fields)
        fields["ret_ty"] = make_wmma_ld_ret_ty(abc, itype)

        # Pick the register-list pattern matching the loaded slice.
        if abc != "c":
            fields["check_result"] = check_f16_8
        elif itype == "f16":
            fields["check_result"] = check_f16_4
        else:
            fields["check_result"] = check_f32_8

        # The .stride variants take one extra i32 operand.
        has_stride = bool(stride)
        fields["extra_args"] = ", i32 %stride" if has_stride else ""
        fields["stride_pattern"] = ", %r{{[0-9]+}}" if has_stride else ""

        print(test_body.substitute(fields))
| |
def make_wmma_slice_args(itype, abcd, prefix="v"):
    """Return a typed LLVM argument list string for one slice.

    Each element type from make_wmma_slice_ty(abcd, itype) becomes
    "<type> %<prefix><index>"; the entries are joined with ", ".
    Note the parameter order (itype first) differs from make_wmma_slice_ty.
    """
    typed_args = []
    for index, elt_ty in enumerate(make_wmma_slice_ty(abcd, itype)):
        typed_args.append("%s %%%s%d" % (elt_ty, prefix, index))
    return ", ".join(typed_args)
| |
def gen_wmma_store_tests():
    """Print one LLVM IR test (with CHECK lines) per wmma.store.d variant.

    Variants are the product of geometry, layout, state space, stride and
    element type; only the "d" slice is stored.  For each variant two
    functions are printed: one that stores through %src directly and an
    "_o" one that stores through %src advanced by 128 (checked as a "+128"
    address offset).
    """
    # The address register is matched with [0-9]+ (one or more digits) so the
    # checks keep working when the register number has several digits.
    store_template = Template("""
declare void @${intrinsic}(i8 ${as}* %src, ${args}${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define void @test_${function}(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9]+}}
; CHECK: {${check_args}}
; CHECK: ${stride_pattern}
call void @${intrinsic}(i8 ${as}* %src, ${args} ${extra_args});
ret void
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define void @test_${function}_o(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9]+}}+128]
; CHECK: {${check_args}}
; CHECK: ${stride_pattern}
%src1 = getelementptr i8, i8 ${as}* %src, i32 128;
call void @${intrinsic}(i8 ${as}* %src1, ${args}${extra_args});
ret void
}
""")
    intrinsic_template = Template(
        "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}")
    instruction_template = Template(
        "wmma.store.${abc}.sync.${layout}.${geom}${space}.${itype}")

    for geom, abc, layout, space, stride, itype in product(
            known_geoms,
            "d",
            ["row", "col"],
            ["", ".shared", ".global"],
            ["", ".stride"],
            ["f16", "f32"]):

        params = {
            "abc": abc,
            "layout": layout,
            "space": space,
            "stride": stride,
            "itype": itype,
            "pspace": get_pspace(space),
            "as": "addrspace(%d)" % get_aspace(space),
            "geom": geom,
        }

        params["intrinsic"] = intrinsic_template.substitute(params)
        params["function"] = params["intrinsic"].replace(".", "_")
        params["instruction"] = instruction_template.substitute(params)
        params["check_args"] = check_f16_4 if itype == "f16" else check_f32_8

        # The .stride variants take one extra i32 operand; either way the
        # pattern ends with the ";" that terminates the instruction.
        if stride:
            params["extra_args"] = ", i32 %stride"
            params["stride_pattern"] = ", %r{{[0-9]+}};"
        else:
            params["extra_args"] = ""
            params["stride_pattern"] = ";"
        params["args"] = make_wmma_slice_args(itype, "d")

        print(store_template.substitute(params))
| |
def gen_wmma_mma_tests():
    """Print one LLVM IR test (with CHECK lines) per wmma.mma variant.

    Variants are the product of geometry, A layout, B layout, C type, D type
    and the optional ".satfinite" suffix.  The "a" and "b" operands are
    always f16; the "c" operand uses ctype and the result uses dtype.
    """
    test_body = Template("""
declare ${ret_ty} @${intrinsic}(
${args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(
${args}) {
; CHECK: ${instruction}
; CHECK-NEXT: ${check_d}
; CHECK-NEXT: ${check_ab}
; CHECK-NEXT: ${check_ab}
; CHECK-NEXT: ${check_c}
%r = call ${ret_ty} @${intrinsic}(
${args});
ret ${ret_ty} %r;
}
""")
    intrinsic_name = Template(
        "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${dtype}.${ctype}${satf}")
    instruction_name = Template(
        "wmma.mma.sync.${alayout}.${blayout}.${geom}.${dtype}.${ctype}${satf}")

    layouts = ["row", "col"]
    elt_types = ["f16", "f32"]
    variants = product(
        known_geoms, layouts, layouts, elt_types, elt_types, [".satfinite", ""])

    for geom, alayout, blayout, ctype, dtype, satf in variants:
        fields = {
            "alayout": alayout,
            "blayout": blayout,
            "ctype": ctype,
            "dtype": dtype,
            "satf": satf,
            "geom": geom,
        }

        intrinsic = intrinsic_name.substitute(fields)
        fields["intrinsic"] = intrinsic
        fields["function"] = intrinsic.replace(".", "_")
        fields["instruction"] = instruction_name.substitute(fields)
        fields["ret_ty"] = make_wmma_ld_ret_ty("d", dtype)

        # Register-list patterns for the a/b, c and d operands.
        fields["check_ab"] = check_f16_8
        if ctype == "f16":
            fields["check_c"] = check_f16_4
        else:
            fields["check_c"] = check_f32_8
        if dtype == "f16":
            fields["check_d"] = check_f16_4
        else:
            fields["check_d"] = check_f32_8

        # Arguments are named after their slice: %a0.., %b0.., %c0..
        operands = [("a", "f16"), ("b", "f16"), ("c", ctype)]
        fields["args"] = ",\n ".join(
            [make_wmma_slice_args(elt_ty, slice_name, prefix=slice_name)
             for slice_name, elt_ty in operands])

        print(test_body.substitute(fields))
| |
def main():
    """Print the load, store and mma test sections, in that order."""
    generators = (
        gen_wmma_load_tests,
        gen_wmma_store_tests,
        gen_wmma_mma_tests,
    )
    for generate in generators:
        generate()


# Called unconditionally: the RUN line executes this file as a script and
# captures its standard output.
main()