// Copyright © 2023-2024 Apple Inc.

#include <metal_stdlib>
#include <metal_simdgroup>

using namespace metal;

template <
    typename T,
    typename mma_t,
    typename loader_a_t,
    typename loader_b_t>
METAL_FUNC void gemm_loop_aligned(
    threadgroup T* As,
    threadgroup T* Bs,
    thread mma_t& mma_op,
    thread loader_a_t& loader_a,
    thread loader_b_t& loader_b,
    const int k_iterations) {
  for (int k = 0; k < k_iterations; k++) {
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Load elements into threadgroup memory
    loader_a.load_unsafe();
    loader_b.load_unsafe();

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Multiply and accumulate threadgroup elements
    mma_op.mma(As, Bs);

    // Prepare for next iteration
    loader_a.next();
    loader_b.next();
  }
}

template <
    bool rows_aligned,
    bool cols_aligned,
    bool transpose,
    typename T,
    typename mma_t,
    typename loader_a_t,
    typename loader_b_t>
METAL_FUNC void gemm_loop_unaligned(
    threadgroup T* As,
    threadgroup T* Bs,
    thread mma_t& mma_op,
    thread loader_a_t& loader_a,
    thread loader_b_t& loader_b,
    const int k_iterations,
    const short tgp_bm,
    const short tgp_bn,
    const short tgp_bk) {
  for (int k = 0; k < k_iterations; k++) {
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Load elements into threadgroup memory, with bounds checks only on
    // the unaligned (ragged) dimensions
    if (rows_aligned) {
      loader_a.load_unsafe();
    } else {
      loader_a.load_safe(short2(tgp_bk, tgp_bm));
    }

    if (cols_aligned) {
      loader_b.load_unsafe();
    } else {
      loader_b.load_safe(
          transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Multiply and accumulate threadgroup elements
    mma_op.mma(As, Bs);

    // Prepare for next iteration
    loader_a.next();
    loader_b.next();
  }
}

template <
    typename T,
    typename mma_t,
    typename loader_a_t,
    typename loader_b_t>
METAL_FUNC void gemm_loop_finalize(
    threadgroup T* As,
    threadgroup T* Bs,
    thread mma_t& mma_op,
    thread loader_a_t& loader_a,
    thread loader_b_t& loader_b,
    const short2 tile_a,
    const short2 tile_b) {
  // Load the final (partial) K block with bounds checks, then run one last
  // accumulation pass
  loader_a.load_safe(tile_a);
  loader_b.load_safe(tile_b);

  threadgroup_barrier(mem_flags::mem_threadgroup);

  mma_op.mma(As, Bs);
}
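
// ---------------------------------------------------------------------------
// Usage sketch (illustrative addition, not from the original file): how the
// three helpers above compose into one GEMM mainloop. The BM/BN/BK tile
// sizes, the transpose_b flag, and the name gemm_mainloop_sketch are
// assumptions; any loader type exposing load_unsafe()/load_safe(short2)/
// next() and any mma type exposing mma(As, Bs) fits. A production kernel
// would dispatch to compile-time specializations per alignment case rather
// than test alignment at runtime as done here for brevity.
// ---------------------------------------------------------------------------
template <
    typename T,
    int BM,
    int BN,
    int BK,
    bool transpose_b,
    typename mma_t,
    typename loader_a_t,
    typename loader_b_t>
METAL_FUNC void gemm_mainloop_sketch(
    threadgroup T* As,
    threadgroup T* Bs,
    thread mma_t& mma_op,
    thread loader_a_t& loader_a,
    thread loader_b_t& loader_b,
    const int K, // full reduction length
    const short tgp_bm, // valid rows in this M tile (== BM when aligned)
    const short tgp_bn) { // valid cols in this N tile (== BN when aligned)
  const int k_iterations = K / BK;
  const short k_remain = short(K - k_iterations * BK);

  if (tgp_bm == BM && tgp_bn == BN) {
    // Interior tile: every load is in bounds, so skip all edge checks.
    gemm_loop_aligned(As, Bs, mma_op, loader_a, loader_b, k_iterations);
  } else {
    // Ragged edge tile: bounds-checked loads along the short rows/columns;
    // each full K block still spans the whole BK extent.
    gemm_loop_unaligned<false, false, transpose_b>(
        As, Bs, mma_op, loader_a, loader_b,
        k_iterations, tgp_bm, tgp_bn, short(BK));
  }

  // Leftover K < BK: clamp the K extent of both tiles for one final
  // bounds-checked load-and-accumulate pass, mirroring the short2
  // conventions used by gemm_loop_unaligned above.
  if (k_remain > 0) {
    const short2 tile_a = short2(k_remain, tgp_bm);
    const short2 tile_b = transpose_b
        ? short2(k_remain, tgp_bn)
        : short2(tgp_bn, k_remain);
    gemm_loop_finalize(As, Bs, mma_op, loader_a, loader_b, tile_a, tile_b);
  }

  // The accumulated tile now lives in mma_op's registers; the surrounding
  // kernel is responsible for storing it back to device memory.
}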