| | |
| |
|
| | #include <metal_simdgroup> |
| | #include <metal_stdlib> |
| |
|
/// K-loop variant that always uses the unchecked loader path.
///
/// Every step runs the same fixed sequence: threadgroup-memory barrier,
/// `load_unsafe()` on both loaders, a second threadgroup-memory barrier,
/// one `mma_op.mma(As, Bs)` call, then `next()` on both loaders.
///
/// @param As            Threadgroup pointer forwarded to `mma_op.mma`.
/// @param Bs            Threadgroup pointer forwarded to `mma_op.mma`.
/// @param mma_op        Object whose `mma(As, Bs)` is called once per step.
/// @param loader_a      Object providing `load_unsafe()` and `next()`.
/// @param loader_b      Object providing `load_unsafe()` and `next()`.
/// @param k_iterations  Number of steps to run; a value <= 0 runs none.
template <
    typename T,
    typename mma_t,
    typename loader_a_t,
    typename loader_b_t>
METAL_FUNC void gemm_loop_aligned(
    threadgroup T* As,
    threadgroup T* Bs,
    thread mma_t& mma_op,
    thread loader_a_t& loader_a,
    thread loader_b_t& loader_b,
    const int k_iterations) {
  int step = 0;
  while (step < k_iterations) {
    // Barrier ahead of the loads.
    // NOTE(review): presumably the loaders write into As/Bs, so this keeps
    // them from overwriting data the previous step's mma still reads —
    // the loader implementation is not visible here; confirm.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Unchecked loads: this variant never takes the load_safe() path.
    loader_a.load_unsafe();
    loader_b.load_unsafe();

    // Barrier between the loads and the multiply-accumulate.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    mma_op.mma(As, Bs);

    // Advance both loaders for the following step.
    loader_a.next();
    loader_b.next();

    ++step;
  }
}
| |
|
/// K-loop variant that picks between checked and unchecked loads per operand.
///
/// The sequence per step matches the aligned loop (barrier, load A, load B,
/// barrier, `mma_op.mma(As, Bs)`, `next()` on both loaders); only the load
/// calls differ, selected by the compile-time flags below.
///
/// @tparam rows_aligned  true: `loader_a.load_unsafe()`;
///                       false: `loader_a.load_safe(short2(tgp_bk, tgp_bm))`.
/// @tparam cols_aligned  true: `loader_b.load_unsafe()`;
///                       false: `loader_b.load_safe(...)` with the extent
///                       chosen by `transpose`.
/// @tparam transpose     Selects the B extent for the checked load:
///                       `(tgp_bk, tgp_bn)` when true, `(tgp_bn, tgp_bk)`
///                       when false. It does not affect the A extent.
/// @param As             Threadgroup pointer forwarded to `mma_op.mma`.
/// @param Bs             Threadgroup pointer forwarded to `mma_op.mma`.
/// @param mma_op         Object whose `mma(As, Bs)` is called once per step.
/// @param loader_a       Object providing `load_safe`, `load_unsafe`, `next`.
/// @param loader_b       Object providing `load_safe`, `load_unsafe`, `next`.
/// @param k_iterations   Number of steps to run; a value <= 0 runs none.
/// @param tgp_bm         Extent value used in the A checked load.
/// @param tgp_bn         Extent value used in the B checked load.
/// @param tgp_bk         Extent value used in both checked loads.
template <
    bool rows_aligned,
    bool cols_aligned,
    bool transpose,
    typename T,
    typename mma_t,
    typename loader_a_t,
    typename loader_b_t>
METAL_FUNC void gemm_loop_unaligned(
    threadgroup T* As,
    threadgroup T* Bs,
    thread mma_t& mma_op,
    thread loader_a_t& loader_a,
    thread loader_b_t& loader_b,
    const int k_iterations,
    const short tgp_bm,
    const short tgp_bn,
    const short tgp_bk) {
  // The checked-load extents depend only on the arguments, so build them
  // once instead of on every step.
  const short2 extent_a = short2(tgp_bk, tgp_bm);
  const short2 extent_b =
      transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk);

  for (int remaining = k_iterations; remaining > 0; --remaining) {
    // Barrier ahead of the loads.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Operand A: checked load only when rows are not flagged as aligned.
    if (!rows_aligned) {
      loader_a.load_safe(extent_a);
    } else {
      loader_a.load_unsafe();
    }

    // Operand B: checked load only when columns are not flagged as aligned.
    if (!cols_aligned) {
      loader_b.load_safe(extent_b);
    } else {
      loader_b.load_unsafe();
    }

    // Barrier between the loads and the multiply-accumulate.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    mma_op.mma(As, Bs);

    // Advance both loaders for the following step.
    loader_a.next();
    loader_b.next();
  }
}
| |
|
/// Single trailing step: checked loads followed by one multiply-accumulate.
///
/// Unlike the loop helpers, this issues no barrier before the loads and
/// does not call `next()` on the loaders afterwards.
/// NOTE(review): callers presumably synchronise before calling this if
/// As/Bs may still be in use — not verifiable from this function alone.
///
/// @param As        Threadgroup pointer forwarded to `mma_op.mma`.
/// @param Bs        Threadgroup pointer forwarded to `mma_op.mma`.
/// @param mma_op    Object whose `mma(As, Bs)` is called exactly once.
/// @param loader_a  Object providing `load_safe(short2)`.
/// @param loader_b  Object providing `load_safe(short2)`.
/// @param tile_a    Extent passed to `loader_a.load_safe`.
/// @param tile_b    Extent passed to `loader_b.load_safe`.
template <
    typename T,
    typename mma_t,
    typename loader_a_t,
    typename loader_b_t>
METAL_FUNC void gemm_loop_finalize(
    threadgroup T* As,
    threadgroup T* Bs,
    thread mma_t& mma_op,
    thread loader_a_t& loader_a,
    thread loader_b_t& loader_b,
    const short2 tile_a,
    const short2 tile_b) {
  // Checked loads using the caller-supplied extents.
  loader_a.load_safe(tile_a);
  loader_b.load_safe(tile_b);

  // Barrier between the loads and the multiply-accumulate.
  threadgroup_barrier(mem_flags::mem_threadgroup);

  mma_op.mma(As, Bs);
}