// bitsandbytes MPS Metal kernels - ObjC++ dispatch
// Interfaces between PyTorch MPS tensors and Metal compute kernels.
// Uses the same dispatch pattern as kernels-community/activation, with
// get_command_buffer() moved inside dispatch_sync to avoid race conditions
// during model loading.

#include <torch/extension.h>

#import <Foundation/Foundation.h>
#import <Metal/Metal.h>

#include <algorithm>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>

#ifdef EMBEDDED_METALLIB_HEADER
#include EMBEDDED_METALLIB_HEADER
#endif

// ============================================================================
// Metal helpers
// ============================================================================

// Recovers the id<MTLBuffer> backing an MPS tensor's storage.
// PyTorch stores the Metal buffer as the storage's raw data pointer, so a
// bit_cast (not a C-style cast, which ARC would reject) is required here.
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor& t) {
  return __builtin_bit_cast(id<MTLBuffer>, t.storage().data());
}

namespace {

static id<MTLLibrary> library = nil;

// Lazily loads the Metal library and caches it for the process lifetime.
// Prefers the embedded metallib header when the build provides one; otherwise
// falls back to the bundle's default library. Returns nil on failure (logged).
id<MTLLibrary> get_library() {
  if (library != nil) {
    return library;
  }
  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
#ifdef EMBEDDED_METALLIB_HEADER
  NSError* error = nil;
  library = EMBEDDED_METALLIB_NAMESPACE::createLibrary(device, &error);
  if (library == nil) {
    std::cerr << "Failed to create Metal library from embedded header" << std::endl;
    if (error) {
      std::cerr << "Error: " << [[error localizedDescription] UTF8String] << std::endl;
    }
  }
#else
  // -newDefaultLibrary has no NSError out-parameter, so on failure there is
  // no additional diagnostic to report beyond the nil result.
  library = [device newDefaultLibrary];
  if (library == nil) {
    std::cerr << "Failed to load Metal library" << std::endl;
  }
#endif
  return library;
}

// Returns the compute pipeline state for the named kernel, creating and
// caching it on first use. Returns nil (logged) if the kernel function is
// missing from the library or pipeline creation fails.
id<MTLComputePipelineState> get_pipeline(const std::string& name) {
  static std::unordered_map<std::string, id<MTLComputePipelineState>> cache;
  auto it = cache.find(name);
  if (it != cache.end()) {
    return it->second;
  }

  id<MTLLibrary> lib = get_library();
  if (!lib) {
    return nil;
  }

  id<MTLFunction> func =
      [lib newFunctionWithName:[NSString stringWithUTF8String:name.c_str()]];
  if (!func) {
    std::cerr << "Kernel not found: " << name << std::endl;
    return nil;
  }

  NSError* error = nil;
  // Create the pipeline on the device that owns the library rather than
  // instantiating a second MTLDevice via MTLCreateSystemDefaultDevice().
  id<MTLComputePipelineState> state =
      [[lib device] newComputePipelineStateWithFunction:func error:&error];
  if (!state) {
    std::cerr << "Failed to create pipeline for " << name << std::endl;
    if (error) {
      std::cerr << "Error: " << [[error localizedDescription] UTF8String] << std::endl;
    }
    return nil;
  }

  cache[name] = state;
  return state;
}

// Maps a torch scalar type to the dtype suffix used in Metal kernel names.
// Throws for dtypes the kernels do not implement.
std::string type_str(torch::ScalarType type) {
  switch (type) {
    case torch::kFloat32:
      return "float";
    case torch::kFloat16:
      return "half";
    case torch::kBFloat16:
      return "bfloat16_t";
    default:
      throw std::runtime_error("Unsupported dtype for BnB MPS kernels");
  }
}

// Binds a tensor's MTLBuffer at the given argument index, honoring the
// tensor's storage offset (views/slices do not start at byte 0 of storage).
void set_tensor(id<MTLComputeCommandEncoder> enc, const torch::Tensor& t, int index) {
  [enc setBuffer:getMTLBufferStorage(t)
          offset:t.storage_offset() * t.element_size()
         atIndex:index];
}

}  // namespace

// ============================================================================
// Public API: quantize_4bit
// ============================================================================

// Blockwise 4-bit quantization of `input` on the MPS device.
//
// Returns (packed, absmax):
//   packed  - uint8 tensor of ceil(n/2) bytes, two 4-bit codes per byte.
//   absmax  - float32 tensor with one per-block scale (n / blocksize blocks).
//
// quant_type selects the codebook: 1 = FP4, 2 = NF4.
std::tuple<at::Tensor, at::Tensor> bnb_quantize_4bit(
    at::Tensor input, int64_t blocksize, int64_t quant_type) {
  TORCH_CHECK(input.is_mps(), "Input must be on MPS device");
  TORCH_CHECK(
      blocksize == 64 || blocksize == 128 || blocksize == 256 || blocksize == 512,
      "Only blocksize 64, 128, 256, and 512 are supported");
  TORCH_CHECK(
      quant_type == 1 || quant_type == 2, "quant_type must be 1 (FP4) or 2 (NF4)");

  int n = static_cast<int>(input.numel());
  int num_blocks =
      (n + static_cast<int>(blocksize) - 1) / static_cast<int>(blocksize);
  int packed_size = (n + 1) / 2;

  auto absmax = torch::empty({num_blocks}, input.options().dtype(torch::kFloat32));
  auto packed = torch::empty({packed_size}, input.options().dtype(torch::kUInt8));

  // Kernels are specialized per dtype/blocksize/quant_type and looked up by name.
  std::stringstream ss;
  ss << "bnb_quantize_blockwise_" << type_str(input.scalar_type()) << "_bs_"
     << blocksize << "_qt_" << quant_type;
  auto pipeline = get_pipeline(ss.str());
  TORCH_CHECK(pipeline, "Kernel not found: ", ss.str());

  @autoreleasepool {
    dispatch_sync(torch::mps::get_dispatch_queue(), ^{
      @autoreleasepool {
        // Fetch the command buffer inside dispatch_sync; fetching it earlier
        // races with other threads committing during model loading.
        id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
        TORCH_CHECK(commandBuffer, "Failed to get MPS command buffer");
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
        TORCH_CHECK(encoder, "Failed to create compute encoder");

        [encoder setComputePipelineState:pipeline];
        int idx = 0;
        set_tensor(encoder, input, idx++);
        set_tensor(encoder, absmax, idx++);
        set_tensor(encoder, packed, idx++);
        [encoder setBytes:&n length:sizeof(int) atIndex:idx++];

        // One SIMD-width threadgroup; grid is one thread per quantization block.
        NSUInteger threads_per_tg = pipeline.threadExecutionWidth;
        MTLSize grid = MTLSizeMake(num_blocks, 1, 1);
        MTLSize tg = MTLSizeMake(threads_per_tg, 1, 1);
        [encoder dispatchThreads:grid threadsPerThreadgroup:tg];

        [encoder endEncoding];
        torch::mps::commit();
      }
    });
  }

  return std::make_tuple(packed, absmax);
}

// ============================================================================
// Public API: dequantize_blockwise
// ============================================================================

// Blockwise 4-bit dequantization: expands `packed` (two codes per byte) into
// `numel` elements of `output_dtype`, scaling each block by its absmax entry.
at::Tensor bnb_dequantize_4bit(
    at::Tensor packed,
    at::Tensor absmax,
    int64_t blocksize,
    int64_t quant_type,
    int64_t numel,
    torch::ScalarType output_dtype) {
  TORCH_CHECK(packed.is_mps(), "packed must be on MPS device");
  TORCH_CHECK(absmax.is_mps(), "absmax must be on MPS device");
  TORCH_CHECK(
      blocksize == 64 || blocksize == 128 || blocksize == 256 || blocksize == 512,
      "Only blocksize 64, 128, 256, and 512 are supported");

  int n = static_cast<int>(numel);
  int num_blocks =
      (n + static_cast<int>(blocksize) - 1) / static_cast<int>(blocksize);

  auto output = torch::empty({n}, packed.options().dtype(output_dtype));

  std::stringstream ss;
  ss << "bnb_dequantize_blockwise_" << type_str(output_dtype) << "_bs_"
     << blocksize << "_qt_" << quant_type;
  auto pipeline = get_pipeline(ss.str());
  TORCH_CHECK(pipeline, "Kernel not found: ", ss.str());

  @autoreleasepool {
    dispatch_sync(torch::mps::get_dispatch_queue(), ^{
      @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
        TORCH_CHECK(commandBuffer, "Failed to get MPS command buffer");
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
        TORCH_CHECK(encoder, "Failed to create compute encoder");

        [encoder setComputePipelineState:pipeline];
        int idx = 0;
        set_tensor(encoder, packed, idx++);
        set_tensor(encoder, absmax, idx++);
        set_tensor(encoder, output, idx++);
        [encoder setBytes:&n length:sizeof(int) atIndex:idx++];

        // One thread per packed byte within a block (ceil(blocksize/2)),
        // clamped to the pipeline's limits and rounded up to at least one
        // full SIMD group so no execution width is wasted.
        NSUInteger max_tg = pipeline.maxTotalThreadsPerThreadgroup;
        NSUInteger desired = (blocksize + 1) / 2;
        NSUInteger tg_size =
            std::min(max_tg, std::max(static_cast<NSUInteger>(1), desired));
        if (tg_size < pipeline.threadExecutionWidth) {
          tg_size = std::min(pipeline.threadExecutionWidth, max_tg);
        }
        MTLSize grid = MTLSizeMake(tg_size * num_blocks, 1, 1);
        MTLSize tg = MTLSizeMake(tg_size, 1, 1);
        [encoder dispatchThreads:grid threadsPerThreadgroup:tg];

        [encoder endEncoding];
        torch::mps::commit();
      }
    });
  }

  return output;
}

// ============================================================================
// Public API: GEMV (matrix-vector multiply)
// y = dequant(W) @ x
// ============================================================================

// Fused dequantize + matrix-vector product for 4-bit weights.
// x: (..., K) activation; w: packed 4-bit weights; absmax: per-block scales.
// Returns y with shape (..., output_features).
at::Tensor bnb_gemv_4bit(
    at::Tensor x,
    at::Tensor w,
    at::Tensor absmax,
    int64_t blocksize,
    int64_t quant_type,
    int64_t output_features) {
  TORCH_CHECK(
      x.is_mps() && w.is_mps() && absmax.is_mps(),
      "All tensors must be on MPS device");

  int K = static_cast<int>(x.size(-1));
  int N = static_cast<int>(output_features);

  auto out_sizes = x.sizes().vec();
  out_sizes.back() = N;
  auto y = torch::zeros(out_sizes, x.options());

  std::stringstream ss;
  ss << "bnb_qmv_" << type_str(x.scalar_type()) << "_bs_" << blocksize
     << "_qt_" << quant_type;
  auto pipeline = get_pipeline(ss.str());
  TORCH_CHECK(pipeline, "Kernel not found: ", ss.str());

  @autoreleasepool {
    dispatch_sync(torch::mps::get_dispatch_queue(), ^{
      @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
        TORCH_CHECK(commandBuffer, "Failed to get MPS command buffer");
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
        TORCH_CHECK(encoder, "Failed to create compute encoder");

        [encoder setComputePipelineState:pipeline];
        int idx = 0;
        set_tensor(encoder, w, idx++);
        set_tensor(encoder, absmax, idx++);
        set_tensor(encoder, x, idx++);
        set_tensor(encoder, y, idx++);
        [encoder setBytes:&K length:sizeof(int) atIndex:idx++];
        [encoder setBytes:&N length:sizeof(int) atIndex:idx++];

        // Threadgroup sizing must match the qmv kernel's internal layout:
        // each threadgroup of 64 threads covers 8 output rows.
        int rows_per_tg = 8;
        int grid_y = (N + rows_per_tg - 1) / rows_per_tg;
        [encoder dispatchThreadgroups:MTLSizeMake(1, grid_y, 1)
                threadsPerThreadgroup:MTLSizeMake(32 * 2, 1, 1)];

        [encoder endEncoding];
        torch::mps::commit();
      }
    });
  }

  return y;
}

// ============================================================================
// Public API: GEMM (matrix-matrix multiply with transposed weight)
// Y = X @ dequant(W).T
// ============================================================================

// Fused dequantize + matrix-matrix product for 4-bit weights.
// x: (..., M, K) activation; w: packed 4-bit weights (row-major over N x K).
// Returns y with shape (..., M, output_features).
at::Tensor bnb_gemm_4bit(
    at::Tensor x,
    at::Tensor w,
    at::Tensor absmax,
    int64_t blocksize,
    int64_t quant_type,
    int64_t output_features) {
  TORCH_CHECK(
      x.is_mps() && w.is_mps() && absmax.is_mps(),
      "All tensors must be on MPS device");
  TORCH_CHECK(x.dim() >= 2, "Input must be at least 2D for GEMM");

  int K = static_cast<int>(x.size(-1));
  int M = static_cast<int>(x.size(-2));
  int N = static_cast<int>(output_features);

  auto out_sizes = x.sizes().vec();
  out_sizes.back() = N;
  auto y = torch::zeros(out_sizes, x.options());

  std::stringstream ss;
  ss << "bnb_qmm_t_" << type_str(x.scalar_type()) << "_bs_" << blocksize
     << "_qt_" << quant_type;
  auto pipeline = get_pipeline(ss.str());
  TORCH_CHECK(pipeline, "Kernel not found: ", ss.str());

  @autoreleasepool {
    dispatch_sync(torch::mps::get_dispatch_queue(), ^{
      @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
        TORCH_CHECK(commandBuffer, "Failed to get MPS command buffer");
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
        TORCH_CHECK(encoder, "Failed to create compute encoder");

        [encoder setComputePipelineState:pipeline];
        int idx = 0;
        set_tensor(encoder, w, idx++);
        set_tensor(encoder, absmax, idx++);
        set_tensor(encoder, x, idx++);
        set_tensor(encoder, y, idx++);
        [encoder setBytes:&K length:sizeof(int) atIndex:idx++];
        [encoder setBytes:&N length:sizeof(int) atIndex:idx++];
        [encoder setBytes:&M length:sizeof(int) atIndex:idx++];

        // 32x32 output tile per threadgroup of 128 threads (must match the
        // qmm_t kernel's tiling).
        int grid_x = (N + 31) / 32;
        int grid_y = (M + 31) / 32;
        [encoder dispatchThreadgroups:MTLSizeMake(grid_x, grid_y, 1)
                threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];

        [encoder endEncoding];
        torch::mps::commit();
      }
    });
  }

  return y;
}