| | |
| | |
| | |
| | |
| | |
| |
|
| | #include <torch/torch.h> |
| |
|
| | #import <Foundation/Foundation.h> |
| | #import <Metal/Metal.h> |
| |
|
#include <algorithm>
#include <iostream>
#include <limits>
#include <sstream>
#include <unordered_map>
| |
|
| | #ifdef EMBEDDED_METALLIB_HEADER |
| | #include EMBEDDED_METALLIB_HEADER |
| | #endif |
| |
|
| | |
| | |
| | |
| |
|
// Reinterprets the raw data pointer of a tensor's storage as the Metal buffer
// backing it. The storage offset is NOT applied here; callers add it when
// binding the buffer.
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor& t) {
  auto rawStorage = t.storage().data();
  id<MTLBuffer> backing = __builtin_bit_cast(id<MTLBuffer>, rawStorage);
  return backing;
}
| |
|
namespace {

// Lazily loaded Metal library holding the BnB kernels. It stays nil until
// get_library() succeeds, so a failed load is retried on the next call.
// NOTE(review): neither this global nor the pipeline cache in get_pipeline()
// is protected by a lock; confirm first use cannot race across threads.
static id<MTLLibrary> library = nil;

// Returns the Metal library containing the BnB kernels, creating it on first
// use on the system default Metal device.
// - With EMBEDDED_METALLIB_HEADER defined, the library is built from the
//   embedded metallib.
// - Otherwise, the main bundle's default library is loaded.
// Logs to stderr and returns nil when no device or library is available.
id<MTLLibrary> get_library() {
  if (library != nil)
    return library;

  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
  if (device == nil) {
    std::cerr << "Failed to create system default Metal device" << std::endl;
    return nil;
  }
  NSError* error = nil;

#ifdef EMBEDDED_METALLIB_HEADER
  library = EMBEDDED_METALLIB_NAMESPACE::createLibrary(device, &error);
  if (library == nil) {
    std::cerr << "Failed to create Metal library from embedded header"
              << std::endl;
    if (error)
      std::cerr << "Error: " << [[error localizedDescription] UTF8String]
                << std::endl;
  }
#else
  // Loads the same library as -newDefaultLibrary, but this variant fills in
  // `error`, so the failure reason can actually be reported.
  library = [device newDefaultLibraryWithBundle:[NSBundle mainBundle]
                                          error:&error];
  if (library == nil) {
    std::cerr << "Failed to load Metal library" << std::endl;
    if (error)
      std::cerr << "Error: " << [[error localizedDescription] UTF8String]
                << std::endl;
  }
#endif
  return library;
}

// Returns a compute pipeline for the kernel function called `name`, caching
// successful results by name. Returns nil (after logging) when the library is
// unavailable, the function does not exist, or pipeline creation fails.
// Failures are not cached.
id<MTLComputePipelineState> get_pipeline(const std::string& name) {
  static std::unordered_map<std::string, id<MTLComputePipelineState>> cache;
  auto it = cache.find(name);
  if (it != cache.end())
    return it->second;

  id<MTLLibrary> lib = get_library();
  if (!lib)
    return nil;

  id<MTLFunction> func =
      [lib newFunctionWithName:[NSString stringWithUTF8String:name.c_str()]];
  if (!func) {
    std::cerr << "Kernel not found: " << name << std::endl;
    return nil;
  }

  // Build the pipeline on the device that owns the library, rather than
  // creating another device handle: the function and the pipeline state must
  // belong to the same device.
  NSError* error = nil;
  id<MTLComputePipelineState> state =
      [lib.device newComputePipelineStateWithFunction:func error:&error];
  if (!state) {
    std::cerr << "Failed to create pipeline for " << name << std::endl;
    if (error)
      std::cerr << "Error: " << [[error localizedDescription] UTF8String]
                << std::endl;
    return nil;
  }
  cache[name] = state;
  return state;
}

// Maps a torch dtype to the type token embedded in BnB kernel names.
// Throws std::runtime_error for any dtype other than float32/float16/bfloat16.
std::string type_str(torch::ScalarType type) {
  switch (type) {
    case torch::kFloat32:
      return "float";
    case torch::kFloat16:
      return "half";
    case torch::kBFloat16:
      return "bfloat16_t";
    default:
      throw std::runtime_error("Unsupported dtype for BnB MPS kernels");
  }
}

// Binds the Metal buffer backing `t` at argument `index`, starting at the
// tensor's first element (storage_offset converted to bytes). Strides are not
// passed to the kernel, so `t` is expected to be contiguous.
void set_tensor(
    id<MTLComputeCommandEncoder> enc,
    const torch::Tensor& t,
    int index) {
  [enc setBuffer:getMTLBufferStorage(t)
          offset:t.storage_offset() * t.element_size()
         atIndex:index];
}

} // namespace
| |
|
| | |
| | |
| | |
| |
|
// Blockwise 4-bit quantization (FP4 or NF4) on the MPS device.
//
// Parameters:
//   input      - MPS tensor of dtype float32, float16 or bfloat16. It is read
//                as a flat array of numel() elements; a non-contiguous view is
//                first copied to contiguous memory.
//   blocksize  - elements per quantization block: 64, 128, 256 or 512.
//   quant_type - 1 for FP4, 2 for NF4.
//
// Returns (packed, absmax):
//   packed - 1-D uint8 tensor of ceil(numel / 2) bytes holding the codes.
//   absmax - 1-D float32 tensor with ceil(numel / blocksize) entries.
//
// Throws c10::Error (TORCH_CHECK) for a non-MPS input, an unsupported
// blocksize or quant_type, more than INT_MAX elements, or a missing kernel.
// Throws std::runtime_error (from type_str) for an unsupported dtype.
std::tuple<at::Tensor, at::Tensor> bnb_quantize_4bit(
    at::Tensor input,
    int64_t blocksize,
    int64_t quant_type) {
  TORCH_CHECK(input.is_mps(), "Input must be on MPS device");
  TORCH_CHECK(
      blocksize == 64 || blocksize == 128 || blocksize == 256 || blocksize == 512,
      "Only blocksize 64, 128, 256, and 512 are supported");
  TORCH_CHECK(
      quant_type == 1 || quant_type == 2,
      "quant_type must be 1 (FP4) or 2 (NF4)");
  // The element count is handed to the kernel as a 32-bit int.
  TORCH_CHECK(
      input.numel() <= std::numeric_limits<int>::max(),
      "Input has too many elements for BnB MPS kernels: ",
      input.numel());

  // The kernel receives only a buffer and a byte offset (no strides), so it
  // can only read contiguous data. This is a no-op for contiguous tensors.
  input = input.contiguous();

  // Sizes are derived in 64-bit so the rounding math cannot overflow near
  // INT_MAX.
  const int64_t numel = input.numel();
  const int64_t num_blocks = (numel + blocksize - 1) / blocksize;
  const int64_t packed_size = (numel + 1) / 2;
  const int n = static_cast<int>(numel);

  auto absmax =
      torch::empty({num_blocks}, input.options().dtype(torch::kFloat32));
  auto packed =
      torch::empty({packed_size}, input.options().dtype(torch::kUInt8));

  // Building the kernel name also rejects unsupported dtypes.
  std::stringstream ss;
  ss << "bnb_quantize_blockwise_" << type_str(input.scalar_type()) << "_bs_"
     << blocksize << "_qt_" << quant_type;

  // Nothing to quantize: skip the zero-sized dispatch.
  if (numel == 0)
    return std::make_tuple(packed, absmax);

  auto pipeline = get_pipeline(ss.str());
  TORCH_CHECK(pipeline, "Kernel not found: ", ss.str());

  @autoreleasepool {
    dispatch_sync(torch::mps::get_dispatch_queue(), ^{
      @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer =
            torch::mps::get_command_buffer();
        TORCH_CHECK(commandBuffer, "Failed to get MPS command buffer");

        id<MTLComputeCommandEncoder> encoder =
            [commandBuffer computeCommandEncoder];
        TORCH_CHECK(encoder, "Failed to create compute encoder");

        [encoder setComputePipelineState:pipeline];

        int idx = 0;
        set_tensor(encoder, input, idx++);
        set_tensor(encoder, absmax, idx++);
        set_tensor(encoder, packed, idx++);
        [encoder setBytes:&n length:sizeof(int) atIndex:idx++];

        // One thread per quantization block, grouped by the SIMD width.
        NSUInteger threads_per_tg = pipeline.threadExecutionWidth;
        MTLSize grid = MTLSizeMake(static_cast<NSUInteger>(num_blocks), 1, 1);
        MTLSize tg = MTLSizeMake(threads_per_tg, 1, 1);
        [encoder dispatchThreads:grid threadsPerThreadgroup:tg];
        [encoder endEncoding];

        torch::mps::commit();
      }
    });
  }

  return std::make_tuple(packed, absmax);
}
| |
|
| | |
| | |
| | |
| |
|
// Blockwise 4-bit dequantization (FP4 or NF4) on the MPS device.
//
// Parameters:
//   packed       - MPS tensor holding the 4-bit codes. It must span at least
//                  ceil(numel / 2) bytes, matching bnb_quantize_4bit's output.
//   absmax       - MPS tensor with at least ceil(numel / blocksize) entries.
//                  NOTE(review): presumably float32, as produced by
//                  bnb_quantize_4bit; the dtype is not checked here.
//   blocksize    - elements per quantization block: 64, 128, 256 or 512.
//   quant_type   - 1 for FP4, 2 for NF4.
//   numel        - number of values to reconstruct (0 .. INT_MAX).
//   output_dtype - float32, float16 or bfloat16.
//
// Returns a 1-D tensor of `numel` elements of `output_dtype` on packed's
// device.
//
// Throws c10::Error (TORCH_CHECK) on invalid arguments, undersized inputs or
// a missing kernel. Throws std::runtime_error (from type_str) for an
// unsupported output_dtype.
at::Tensor bnb_dequantize_4bit(
    at::Tensor packed,
    at::Tensor absmax,
    int64_t blocksize,
    int64_t quant_type,
    int64_t numel,
    torch::ScalarType output_dtype) {
  TORCH_CHECK(packed.is_mps(), "packed must be on MPS device");
  TORCH_CHECK(absmax.is_mps(), "absmax must be on MPS device");
  TORCH_CHECK(
      blocksize == 64 || blocksize == 128 || blocksize == 256 || blocksize == 512,
      "Only blocksize 64, 128, 256, and 512 are supported");
  TORCH_CHECK(
      quant_type == 1 || quant_type == 2,
      "quant_type must be 1 (FP4) or 2 (NF4)");
  // The element count is handed to the kernel as a 32-bit int.
  TORCH_CHECK(
      numel >= 0 && numel <= std::numeric_limits<int>::max(),
      "numel must be in [0, INT_MAX] for BnB MPS kernels, got ",
      numel);

  const int64_t num_blocks = (numel + blocksize - 1) / blocksize;

  // The kernel reads by element index, so undersized inputs would cause
  // out-of-bounds GPU reads instead of a clean error.
  TORCH_CHECK(
      packed.numel() * packed.element_size() >= (numel + 1) / 2,
      "packed is too small for numel=",
      numel);
  TORCH_CHECK(
      absmax.numel() >= num_blocks,
      "absmax must have at least ",
      num_blocks,
      " entries, got ",
      absmax.numel());

  // Buffers are bound without strides, so materialize contiguous layouts.
  packed = packed.contiguous();
  absmax = absmax.contiguous();

  const int n = static_cast<int>(numel);

  auto output = torch::empty({numel}, packed.options().dtype(output_dtype));

  // Building the kernel name also rejects unsupported dtypes.
  std::stringstream ss;
  ss << "bnb_dequantize_blockwise_" << type_str(output_dtype) << "_bs_"
     << blocksize << "_qt_" << quant_type;

  // Nothing to reconstruct: skip the zero-sized dispatch.
  if (numel == 0)
    return output;

  auto pipeline = get_pipeline(ss.str());
  TORCH_CHECK(pipeline, "Kernel not found: ", ss.str());

  @autoreleasepool {
    dispatch_sync(torch::mps::get_dispatch_queue(), ^{
      @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer =
            torch::mps::get_command_buffer();
        TORCH_CHECK(commandBuffer, "Failed to get MPS command buffer");

        id<MTLComputeCommandEncoder> encoder =
            [commandBuffer computeCommandEncoder];
        TORCH_CHECK(encoder, "Failed to create compute encoder");

        [encoder setComputePipelineState:pipeline];

        int idx = 0;
        set_tensor(encoder, packed, idx++);
        set_tensor(encoder, absmax, idx++);
        set_tensor(encoder, output, idx++);
        [encoder setBytes:&n length:sizeof(int) atIndex:idx++];

        // One threadgroup per quantization block. The group size aims for one
        // thread per packed byte (ceil(blocksize / 2)), clamped to the
        // pipeline limit and raised to at least the SIMD width.
        NSUInteger max_tg = pipeline.maxTotalThreadsPerThreadgroup;
        NSUInteger desired = static_cast<NSUInteger>((blocksize + 1) / 2);
        NSUInteger tg_size =
            std::min(max_tg, std::max(static_cast<NSUInteger>(1), desired));
        if (tg_size < pipeline.threadExecutionWidth) {
          tg_size = std::min(pipeline.threadExecutionWidth, max_tg);
        }

        MTLSize grid =
            MTLSizeMake(tg_size * static_cast<NSUInteger>(num_blocks), 1, 1);
        MTLSize tg = MTLSizeMake(tg_size, 1, 1);
        [encoder dispatchThreads:grid threadsPerThreadgroup:tg];
        [encoder endEncoding];

        torch::mps::commit();
      }
    });
  }

  return output;
}
| |
|
| | |
| | |
| | |
| | |
| |
|
// Matrix-vector product against a 4-bit quantized weight on the MPS device,
// using the `bnb_qmv_<dtype>_bs_<blocksize>_qt_<quant_type>` kernel.
//
// Parameters:
//   x               - MPS activation tensor whose last dimension is K.
//   w, absmax       - MPS tensors holding the quantized weight and its
//                     per-block scales. Their layout is defined by the kernel.
//   blocksize       - quantization block size baked into the kernel name.
//   quant_type      - 1 (FP4) or 2 (NF4), baked into the kernel name.
//   output_features - N, the size of the output's last dimension.
//
// Returns y with x's shape, except the last dimension is N, zero-initialized
// before the kernel runs.
//
// NOTE(review): the launch uses a single threadgroup along x (grid x = 1), so
// this appears to cover only one row of x. Confirm callers pass a single row.
at::Tensor bnb_gemv_4bit(
    at::Tensor x,
    at::Tensor w,
    at::Tensor absmax,
    int64_t blocksize,
    int64_t quant_type,
    int64_t output_features) {
  TORCH_CHECK(
      x.is_mps() && w.is_mps() && absmax.is_mps(),
      "All tensors must be on MPS device");
  // K and N are handed to the kernel as 32-bit ints.
  TORCH_CHECK(
      output_features >= 0 &&
          output_features <= std::numeric_limits<int>::max(),
      "output_features must be in [0, INT_MAX], got ",
      output_features);
  TORCH_CHECK(
      x.size(-1) <= std::numeric_limits<int>::max(),
      "Input feature dimension too large for BnB MPS kernels: ",
      x.size(-1));

  // Buffers are bound without strides, so materialize contiguous layouts.
  // These are no-ops for already-contiguous tensors.
  x = x.contiguous();
  w = w.contiguous();
  absmax = absmax.contiguous();

  const int K = static_cast<int>(x.size(-1));
  const int N = static_cast<int>(output_features);

  auto out_sizes = x.sizes().vec();
  out_sizes.back() = N;
  // Zero-initialized; this is already the correct result when K == 0.
  auto y = torch::zeros(out_sizes, x.options());

  // Building the kernel name also rejects unsupported dtypes.
  std::stringstream ss;
  ss << "bnb_qmv_" << type_str(x.scalar_type()) << "_bs_" << blocksize
     << "_qt_" << quant_type;

  // Empty output or empty reduction: nothing to compute, and N == 0 would
  // otherwise launch zero threadgroups.
  if (N == 0 || K == 0)
    return y;

  auto pipeline = get_pipeline(ss.str());
  TORCH_CHECK(pipeline, "Kernel not found: ", ss.str());

  @autoreleasepool {
    dispatch_sync(torch::mps::get_dispatch_queue(), ^{
      @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer =
            torch::mps::get_command_buffer();
        TORCH_CHECK(commandBuffer, "Failed to get MPS command buffer");

        id<MTLComputeCommandEncoder> encoder =
            [commandBuffer computeCommandEncoder];
        TORCH_CHECK(encoder, "Failed to create compute encoder");

        [encoder setComputePipelineState:pipeline];

        int idx = 0;
        set_tensor(encoder, w, idx++);
        set_tensor(encoder, absmax, idx++);
        set_tensor(encoder, x, idx++);
        set_tensor(encoder, y, idx++);
        [encoder setBytes:&K length:sizeof(int) atIndex:idx++];
        [encoder setBytes:&N length:sizeof(int) atIndex:idx++];

        // Launch geometry must match the bnb_qmv kernel: 8 output rows per
        // threadgroup of 64 threads (presumably 2 SIMD groups of 32).
        const int64_t rows_per_tg = 8;
        const int64_t grid_y = (static_cast<int64_t>(N) + rows_per_tg - 1) /
            rows_per_tg;

        [encoder dispatchThreadgroups:MTLSizeMake(
                                          1, static_cast<NSUInteger>(grid_y), 1)
                threadsPerThreadgroup:MTLSizeMake(32 * 2, 1, 1)];
        [encoder endEncoding];

        torch::mps::commit();
      }
    });
  }

  return y;
}
| |
|
| | |
| | |
| | |
| | |
| |
|
// Matrix-matrix product against a 4-bit quantized weight on the MPS device,
// using the `bnb_qmm_t_<dtype>_bs_<blocksize>_qt_<quant_type>` kernel.
//
// Parameters:
//   x               - MPS activation tensor of rank >= 2 with last dimension
//                     K. All leading dimensions are flattened into the M rows
//                     the kernel processes, so batched inputs such as
//                     (batch, seq, K) are fully covered.
//   w, absmax       - MPS tensors holding the quantized weight and its
//                     per-block scales. Their layout is defined by the kernel.
//   blocksize       - quantization block size baked into the kernel name.
//   quant_type      - 1 (FP4) or 2 (NF4), baked into the kernel name.
//   output_features - N, the size of the output's last dimension.
//
// Returns y with x's shape, except the last dimension is N, zero-initialized
// before the kernel runs.
// NOTE(review): presumably y = x @ dequant(w)^T (the "_t" suffix); confirm
// against the kernel source.
at::Tensor bnb_gemm_4bit(
    at::Tensor x,
    at::Tensor w,
    at::Tensor absmax,
    int64_t blocksize,
    int64_t quant_type,
    int64_t output_features) {
  TORCH_CHECK(
      x.is_mps() && w.is_mps() && absmax.is_mps(),
      "All tensors must be on MPS device");
  TORCH_CHECK(x.dim() >= 2, "Input must be at least 2D for GEMM");

  // Rows = product of all leading dimensions. Using only size(-2) would leave
  // every batch slice after the first uncomputed for inputs of rank >= 3.
  int64_t rows = 1;
  for (int64_t d = 0; d < x.dim() - 1; ++d)
    rows *= x.size(d);

  // M, N and K are handed to the kernel as 32-bit ints.
  constexpr int64_t kIntMax = std::numeric_limits<int>::max();
  TORCH_CHECK(
      output_features >= 0 && output_features <= kIntMax,
      "output_features must be in [0, INT_MAX], got ",
      output_features);
  TORCH_CHECK(
      x.size(-1) <= kIntMax && rows <= kIntMax,
      "Input too large for BnB MPS kernels: rows=",
      rows,
      ", K=",
      x.size(-1));

  // Buffers are bound without strides, and flattening the leading dims into
  // rows relies on a row-major layout. These are no-ops when already
  // contiguous.
  x = x.contiguous();
  w = w.contiguous();
  absmax = absmax.contiguous();

  const int K = static_cast<int>(x.size(-1));
  const int M = static_cast<int>(rows);
  const int N = static_cast<int>(output_features);

  auto out_sizes = x.sizes().vec();
  out_sizes.back() = N;
  // Zero-initialized; this is already the correct result when K == 0.
  auto y = torch::zeros(out_sizes, x.options());

  // Building the kernel name also rejects unsupported dtypes.
  std::stringstream ss;
  ss << "bnb_qmm_t_" << type_str(x.scalar_type()) << "_bs_" << blocksize
     << "_qt_" << quant_type;

  // Empty output or empty reduction: nothing to compute, and it avoids
  // launching zero threadgroups.
  if (M == 0 || N == 0 || K == 0)
    return y;

  auto pipeline = get_pipeline(ss.str());
  TORCH_CHECK(pipeline, "Kernel not found: ", ss.str());

  @autoreleasepool {
    dispatch_sync(torch::mps::get_dispatch_queue(), ^{
      @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer =
            torch::mps::get_command_buffer();
        TORCH_CHECK(commandBuffer, "Failed to get MPS command buffer");

        id<MTLComputeCommandEncoder> encoder =
            [commandBuffer computeCommandEncoder];
        TORCH_CHECK(encoder, "Failed to create compute encoder");

        [encoder setComputePipelineState:pipeline];

        int idx = 0;
        set_tensor(encoder, w, idx++);
        set_tensor(encoder, absmax, idx++);
        set_tensor(encoder, x, idx++);
        set_tensor(encoder, y, idx++);
        [encoder setBytes:&K length:sizeof(int) atIndex:idx++];
        [encoder setBytes:&N length:sizeof(int) atIndex:idx++];
        [encoder setBytes:&M length:sizeof(int) atIndex:idx++];

        // One threadgroup per 32x32 output tile (x over N, y over M), with
        // 128 threads each. This must match the bnb_qmm_t kernel. The math is
        // 64-bit so it cannot overflow near INT_MAX.
        const int64_t grid_x = (static_cast<int64_t>(N) + 31) / 32;
        const int64_t grid_y = (static_cast<int64_t>(M) + 31) / 32;

        [encoder dispatchThreadgroups:MTLSizeMake(
                                          static_cast<NSUInteger>(grid_x),
                                          static_cast<NSUInteger>(grid_y),
                                          1)
                threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
        [encoder endEncoding];

        torch::mps::commit();
      }
    });
  }

  return y;
}
| |
|