Skip to content

Commit

Permalink
metal lowbit kernels: executorch ops
Browse files Browse the repository at this point in the history
Differential Revision: D65957345

Pull Request resolved: #1322
  • Loading branch information
manuelcandales authored Dec 13, 2024
1 parent 31234db commit ebc4303
Show file tree
Hide file tree
Showing 14 changed files with 364 additions and 147 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@
*/
#ifdef USE_ATEN
using namespace at::native::mps;
using at::native::mps::MetalShaderLibrary;
#else
#include <torchao/experimental/kernels/mps/src/OperationUtils.h>
#include <torchao/experimental/kernels/mps/src/MetalShaderLibrary.h>
#endif
static MetalShaderLibrary metal_lowbit_quantized_lib(R"METAL_LOWBIT(
Expand Down
64 changes: 64 additions & 0 deletions torchao/experimental/kernels/mps/src/MetalShaderLibrary.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <torchao/experimental/kernels/mps/src/common.h>

// Compiles a Metal shader source string into an MTLLibrary when constructed
// and creates compute pipeline states for functions of that library by name.
// Every failure is reported by throwing std::runtime_error.
//
// NOTE(review): `newFunctionWithName:`, `newLibraryWithSource:...` and
// `[MTLCompileOptions new]` follow the Cocoa "new" ownership convention. If
// this header is ever compiled without ARC, `func` and `options` below are
// never released -- confirm the build mode of every translation unit that
// includes this file.
class MetalShaderLibrary {
 public:
  // Stores a copy of `src` and compiles it immediately on the device returned
  // by get_metal_device(). Throws std::runtime_error if compilation fails.
  MetalShaderLibrary(const std::string& src) : shaderSource(src) {
    lib = compileLibraryFromSource(shaderSource);
  }
  MetalShaderLibrary(const MetalShaderLibrary&) = delete;
  MetalShaderLibrary(MetalShaderLibrary&&) = delete;

  // Looks up the function named `fname` in the compiled library and builds a
  // compute pipeline state for it.
  // Throws std::runtime_error if the function does not exist or the pipeline
  // state cannot be constructed.
  id<MTLComputePipelineState> getPipelineStateForFunc(
      const std::string& fname) {
    id<MTLFunction> func = loadFunc(fname);

    NSError* error = nil;
    id<MTLDevice> device = get_metal_device();
    auto cpl = [device newComputePipelineStateWithFunction:func error:&error];
    if (cpl == nil) {
      throw std::runtime_error(
          "Failed to construct pipeline state: " + describeError(error));
    }
    return cpl;
  }

 private:
  std::string shaderSource;
  id<MTLLibrary> lib = nil;

  // Converts an NSError into a std::string for exception messages.
  // Messaging a nil NSError yields a null UTF8String, and constructing a
  // std::string from a null pointer is undefined behaviour, so that case is
  // mapped to a fixed placeholder instead.
  static std::string describeError(NSError* error) {
    const char* text = error.description.UTF8String;
    return text != nullptr ? std::string(text) : std::string("unknown error");
  }

  // Returns the function named `func_name` from the compiled library.
  // Throws std::runtime_error when the library has no such function.
  id<MTLFunction> loadFunc(const std::string& func_name) const {
    id<MTLFunction> func = [lib
        newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]];
    if (func == nil) {
      throw std::runtime_error("Can't get function:" + func_name);
    }
    return func;
  }

  // Compiles `source` with the Metal language version pinned to 3.1.
  // Throws std::runtime_error carrying the compiler diagnostics on failure.
  id<MTLLibrary> compileLibraryFromSource(
      const std::string& source) {
    NSError* error = nil;
    MTLCompileOptions* options = [MTLCompileOptions new];
    [options setLanguageVersion:MTLLanguageVersion3_1];
    NSString* kernel_source = [NSString stringWithUTF8String:source.c_str()];
    id<MTLDevice> device = get_metal_device();
    id<MTLLibrary> library = [device newLibraryWithSource:kernel_source
                                                  options:options
                                                    error:&error];
    if (library == nil) {
      throw std::runtime_error("Failed to compile: " + describeError(error));
    }
    return library;
  }
};
101 changes: 2 additions & 99 deletions torchao/experimental/kernels/mps/src/OperationUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,101 +6,12 @@

#pragma once

#include <iostream>
#include <stdexcept>

// Writes `str` to stderr and then throws it as a std::runtime_error.
// This function never returns normally; [[noreturn]] makes that visible to
// the compiler so code after a call is known to be unreachable.
[[noreturn]] static void throw_exception(const std::string& str) {
  std::cerr << str << std::endl;
  throw std::runtime_error(str);
}

// Invokes `block` directly on the calling thread. The `queue` argument is
// accepted but not used. An exception escaping `block` is captured and then
// rethrown to the caller once the try/catch has been left.
inline void dispatch_block(
    [[maybe_unused]] id<MTLCommandQueue> queue,
    void (^block)()) {
  std::optional<std::exception_ptr> captured;
  try {
    block();
  } catch (...) {
    captured = std::current_exception();
  }
  if (!captured.has_value()) {
    return;
  }
  std::rethrow_exception(captured.value());
}

// Returns the first entry of the array produced by MTLCopyAllDevices().
// When the array is empty, the failure is reported through throw_exception
// with the message "Metal is not supported".
//
// NOTE(review): the array is autoreleased into the pool opened here and the
// returned element is not retained by this function -- confirm the device
// stays alive after the pool drains.
inline id<MTLDevice> getMetalDevice() {
  @autoreleasepool {
    NSArray* allDevices = [MTLCopyAllDevices() autorelease];
    const bool noDevices = ([allDevices count] == 0);
    if (noDevices) {
      throw_exception("Metal is not supported");
    }
    return [allDevices objectAtIndex:0];
  }
}

// Device handle obtained once during static initialization; `static` gives
// every translation unit that includes this header its own copy.
static id<MTLDevice> MTL_DEVICE = getMetalDevice();

// Compiles the Metal shader text in `source` on `device`, with the Metal
// language version pinned to 3.1, and returns the resulting library.
// On failure the compiler diagnostics are reported through throw_exception.
static id<MTLLibrary> compileLibraryFromSource(
    id<MTLDevice> device,
    const std::string& source) {
  NSError* error = nil;
  MTLCompileOptions* options = [MTLCompileOptions new];
  [options setLanguageVersion:MTLLanguageVersion3_1];
  NSString* kernel_source = [NSString stringWithUTF8String:source.c_str()];
  id<MTLLibrary> library = [device newLibraryWithSource:kernel_source
                                                options:options
                                                  error:&error];
  if (library == nil) {
    // `error` can still be nil here; messaging nil yields a null UTF8String
    // and std::string(nullptr) is undefined behaviour, so guard against it.
    const char* reason = error.description.UTF8String;
    throw_exception(
        "Failed to compile: " +
        std::string(reason != nullptr ? reason : "unknown error"));
  }
  return library;
}

// Compiles a Metal shader source string into an MTLLibrary when constructed
// and creates compute pipeline states for functions of that library by name.
// Failures are reported through throw_exception.
class MetalShaderLibrary {
 public:
  // Stores a copy of `src` and compiles it immediately on `device`.
  MetalShaderLibrary(const std::string& src) : shaderSource(src) {
    lib = compileLibraryFromSource(device, shaderSource);
  }
  MetalShaderLibrary(const MetalShaderLibrary&) = delete;
  MetalShaderLibrary(MetalShaderLibrary&&) = delete;

  // Looks up the function named `fname` in the compiled library and builds a
  // compute pipeline state for it.
  id<MTLComputePipelineState> getPipelineStateForFunc(
      const std::string& fname) {
    return get_compute_pipeline_state(load_func(fname));
  }

 private:
  std::string shaderSource;
  id<MTLDevice> device = MTL_DEVICE;
  id<MTLLibrary> lib = nil;

  // Returns the function named `func_name` from the compiled library, or
  // reports "Can't get function:<name>" through throw_exception.
  id<MTLFunction> load_func(const std::string& func_name) const {
    id<MTLFunction> func = [lib
        newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]];
    if (func == nil) {
      throw_exception("Can't get function:" + func_name);
    }
    return func;
  }

  // Builds a compute pipeline state for `func` on `device`, or reports the
  // failure through throw_exception.
  id<MTLComputePipelineState> get_compute_pipeline_state(
      id<MTLFunction> func) const {
    NSError* error = nil;
    auto cpl = [device newComputePipelineStateWithFunction:func error:&error];
    if (cpl == nil) {
      // `error` can still be nil here; messaging nil yields a null UTF8String
      // and std::string(nullptr) is undefined behaviour, so guard against it.
      const char* reason = error.description.UTF8String;
      throw_exception(
          "Failed to construct pipeline state: " +
          std::string(reason != nullptr ? reason : "unknown error"));
    }
    return cpl;
  }
};
id<MTLDevice> getMetalDevice();

class MPSStream {
public:
MPSStream() {
_commandQueue = [MTL_DEVICE newCommandQueue];
_commandQueue = [getMetalDevice() newCommandQueue];
}

~MPSStream() {
Expand Down Expand Up @@ -136,14 +47,6 @@ class MPSStream {
id<MTLComputeCommandEncoder> _commandEncoder = nil;
};

// Ends encoding on the stream's command encoder, commits the stream's command
// buffer, and blocks the calling thread until that buffer has completed.
// Both handles are fetched from the stream before any of them is messaged.
inline void finalize_block(MPSStream* mpsStream) {
id<MTLCommandEncoder> encoder = mpsStream->commandEncoder();
id<MTLCommandBuffer> cmdBuffer = mpsStream->commandBuffer();
[encoder endEncoding];
[cmdBuffer commit];
// Synchronous wait: returns only after the committed buffer has finished.
[cmdBuffer waitUntilCompleted];
}

// Heap-allocates a new MPSStream on every call and returns the raw pointer.
// Nothing in this function releases the allocation.
// NOTE(review): presumably the caller is expected to own and delete the
// stream -- confirm against call sites.
inline MPSStream* getCurrentMPSStream() {
  MPSStream* freshStream = new MPSStream();
  return freshStream;
}
20 changes: 20 additions & 0 deletions torchao/experimental/kernels/mps/src/OperationUtils.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#include <Metal/Metal.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <stdexcept>

// Returns the process-wide Metal device: the first entry reported by
// MTLCopyAllDevices().
//
// The device list is enumerated only on the first successful call; the result
// is cached in a function-local static and returned directly afterwards.
// Throws std::runtime_error("Metal is not supported") when no device is
// reported. A throwing initializer leaves the static uninitialized, so a
// later call enumerates again.
id<MTLDevice> getMetalDevice() {
  static id<MTLDevice> MTL_DEVICE = []() -> id<MTLDevice> {
    @autoreleasepool {
      NSArray* devices = [MTLCopyAllDevices() autorelease];
      if (devices.count == 0) {
        throw std::runtime_error("Metal is not supported");
      }
      // The array is released when this pool drains, so take an owning
      // reference for the cached handle (this file uses manual
      // retain/release, as the explicit `autorelease` above shows).
      return [devices[0] retain];
    }
  }();
  return MTL_DEVICE;
}
51 changes: 51 additions & 0 deletions torchao/experimental/kernels/mps/src/common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#ifdef USE_ATEN
#include <ATen/native/mps/OperationUtils.h>
using namespace at::native::mps;
#elif defined(USE_EXECUTORCH)
#include <executorch/backends/apple/mps/runtime/MPSStream.h>
using namespace executorch::backends::mps::delegate;
#else
#include <torchao/experimental/kernels/mps/src/OperationUtils.h>
#endif

// Runs `block` using the mechanism selected at compile time:
//   - USE_ATEN:       dispatch_sync_with_rethrow on the stream's queue
//   - USE_EXECUTORCH: dispatch_sync on the stream's queue
//   - neither:        a direct call on the current thread; the stream is
//                     not used
inline void dispatch_block(
    MPSStream* mpsStream,
    void (^block)()) {
#ifdef USE_ATEN
  dispatch_sync_with_rethrow(mpsStream->queue(), block);
#elif defined(USE_EXECUTORCH)
  dispatch_sync(mpsStream->queue(), block);
#else
  // Standalone build: no queue involved, just invoke the block.
  static_cast<void>(mpsStream);
  block();
#endif
}

// Waits for the work encoded on `mpsStream` to finish, depending on the
// build configuration:
//   - USE_ATEN:       does nothing.
//   - USE_EXECUTORCH: calls synchronize(SyncType::COMMIT_AND_WAIT) on the
//                     stream and checks, via ET_CHECK, that it returned
//                     Error::Ok.
//   - neither:        ends encoding, commits the command buffer, and blocks
//                     until the buffer has completed.
inline void optionally_wait_for_command_completion(MPSStream* mpsStream) {
#if defined(USE_ATEN)
// Intentionally empty: no wait is performed in the ATen build.
#elif defined(USE_EXECUTORCH)
ET_CHECK(mpsStream->synchronize(SyncType::COMMIT_AND_WAIT) == executorch::runtime::Error::Ok);
#else
// Both handles are fetched from the stream before either is messaged.
id<MTLCommandEncoder> encoder = mpsStream->commandEncoder();
id<MTLCommandBuffer> cmdBuffer = mpsStream->commandBuffer();
[encoder endEncoding];
[cmdBuffer commit];
// Synchronous wait: returns only after the committed buffer has finished.
[cmdBuffer waitUntilCompleted];
#endif
}

// Returns the Metal device for the active build configuration: the
// standalone getMetalDevice() when neither USE_ATEN nor USE_EXECUTORCH is
// defined, otherwise the device held by MPSDevice::getInstance().
inline id<MTLDevice> get_metal_device() {
#if !defined(USE_ATEN) && !defined(USE_EXECUTORCH)
  return getMetalDevice();
#else
  return MPSDevice::getInstance()->device();
#endif
}
21 changes: 4 additions & 17 deletions torchao/experimental/kernels/mps/src/lowbit.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,11 @@
#include <Metal/Metal.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>

#include <torchao/experimental/kernels/mps/src/common.h>
#include <torchao/experimental/kernels/mps/src/dispatch.h>
#include <torchao/experimental/kernels/mps/src/metal_shader_lib.h>
#include <torchao/experimental/kernels/mps/src/metal_shader_lib.h> // metal_lowbit_quantized_lib
#include <torchao/experimental/kernels/mps/src/packing.h>

#include <cassert>
#include <fstream>
#include <sstream>

#ifdef USE_ATEN
#include <ATen/native/mps/OperationUtils.h>
using namespace at::native::mps;
inline void finalize_block(MPSStream* mpsStream) {}
void (*dispatch_block)(dispatch_queue_t, dispatch_block_t) =
dispatch_sync_with_rethrow;
#else
#include <torchao/experimental/kernels/mps/src/OperationUtils.h>
#endif

namespace torchao::kernels::mps::lowbit {
namespace {

Expand Down Expand Up @@ -103,7 +90,7 @@ inline void linear_lowbit_quant_weights_mps_impl(
0};

MPSStream* mpsStream = getCurrentMPSStream();
dispatch_block(mpsStream->queue(), ^() {
dispatch_block(mpsStream, ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
id<MTLComputePipelineState> cpl =
Expand All @@ -119,7 +106,7 @@ inline void linear_lowbit_quant_weights_mps_impl(
length:sizeof(uint32_t) * sizes.size()
atIndex:5];
dispatch_fn(computeEncoder, maxThreadsPerGroup, M, N, K);
finalize_block(mpsStream);
optionally_wait_for_command_completion(mpsStream);
}
});
}
Expand Down
4 changes: 2 additions & 2 deletions torchao/experimental/kernels/mps/test/Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
all: test_lowbit

test_lowbit: test_lowbit.mm
clang++ -I${TORCHAO_ROOT} -O3 -std=c++17 -Wall -Wextra -o $@ $< -framework Metal -framework Foundation
test_lowbit: test_lowbit.mm ../src/OperationUtils.mm
clang++ -I${TORCHAO_ROOT} -O3 -std=c++17 -Wall -Wextra -o $@ $^ -framework Metal -framework Foundation

run: test_lowbit
./test_lowbit
4 changes: 2 additions & 2 deletions torchao/experimental/kernels/mps/test/test_lowbit.mm
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
id<MTLBuffer> rc = [device newBufferWithLength:length
options:MTLResourceStorageModeShared];
if (rc == nil) {
throw_exception(
throw std::runtime_error(
"Can't allocate " + std::to_string(length) + " bytes on GPU");
}
return rc;
Expand Down Expand Up @@ -80,7 +80,7 @@ void reference_linear_lowbit_quant_weights_cpu(
: M(m), K(k), N(n), qGroupSize(group_size) {}

void init() {
allocBuffers(MTL_DEVICE);
allocBuffers(getMetalDevice());

T* a_ptr = reinterpret_cast<T*>([buf_A contents]);
uint8_t* w_ptr = reinterpret_cast<uint8_t*>([buf_W contents]);
Expand Down
32 changes: 26 additions & 6 deletions torchao/experimental/ops/mps/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,13 @@ endif()
find_package(Torch REQUIRED)

# Generate metal_shader_lib.h by running gen_metal_shader_lib.py
set(METAL_SHADERS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/metal)
set(GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/codegen/gen_metal_shader_lib.py)
set(GENERATED_METAL_SHADER_LIB ${CMAKE_INSTALL_PREFIX}/include/torchao/experimental/kernels/mps/src/metal_shader_lib.h)
add_custom_command(
OUTPUT ${GENERATED_METAL_SHADER_LIB}
COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/codegen/gen_metal_shader_lib.py ${GENERATED_METAL_SHADER_LIB}
COMMAND python ${GEN_SCRIPT} ${GENERATED_METAL_SHADER_LIB}
DEPENDS ${METAL_SHADERS_DIR}/*.metal ${GEN_SCRIPT}
COMMENT "Generating metal_shader_lib.h using gen_metal_shader_lib.py"
)
add_custom_target(generated_metal_shader_lib ALL DEPENDS ${GENERATED_METAL_SHADER_LIB})
Expand All @@ -41,7 +44,7 @@ message(STATUS "TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")

include_directories(${TORCHAO_INCLUDE_DIRS})
include_directories(${CMAKE_INSTALL_PREFIX}/include)
add_library(torchao_ops_mps_linear_fp_act_xbit_weight_aten SHARED aten/register.mm)
add_library(torchao_ops_mps_linear_fp_act_xbit_weight_aten OBJECT linear_fp_act_xbit_weight_aten.mm)
add_dependencies(torchao_ops_mps_linear_fp_act_xbit_weight_aten generated_metal_shader_lib)

target_include_directories(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE "${TORCH_INCLUDE_DIRS}")
Expand All @@ -53,8 +56,25 @@ find_library(METAL_LIB Metal)
find_library(FOUNDATION_LIB Foundation)
target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE ${METAL_LIB} ${FOUNDATION_LIB})

install(
TARGETS torchao_ops_mps_linear_fp_act_xbit_weight_aten
EXPORT _targets
DESTINATION lib
add_library(torchao_ops_mps_aten SHARED)
target_link_libraries(torchao_ops_mps_aten PRIVATE
torchao_ops_mps_linear_fp_act_xbit_weight_aten
)
install(TARGETS torchao_ops_mps_aten DESTINATION lib)

if(TORCHAO_BUILD_EXECUTORCH_OPS)
include_directories(${CMAKE_INSTALL_PREFIX}/../..)
include_directories(${CMAKE_INSTALL_PREFIX}/schema/include)
include_directories(${CMAKE_INSTALL_PREFIX}/../third-party/flatbuffers/include)
add_library(torchao_ops_mps_linear_fp_act_xbit_weight_executorch OBJECT linear_fp_act_xbit_weight_executorch.mm)
add_dependencies(torchao_ops_mps_linear_fp_act_xbit_weight_executorch generated_metal_shader_lib)
target_compile_definitions(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE USE_EXECUTORCH=1)
target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE executorch executorch_core mpsdelegate)
target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE ${METAL_LIB} ${FOUNDATION_LIB})

add_library(torchao_ops_mps_executorch STATIC)
target_link_libraries(torchao_ops_mps_executorch PRIVATE
torchao_ops_mps_linear_fp_act_xbit_weight_executorch
)
install(TARGETS torchao_ops_mps_executorch DESTINATION lib)
endif()
Loading

0 comments on commit ebc4303

Please sign in to comment.