Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Training ops kernels: Speeding up the Llama-based MoE architectures #6734

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
12 changes: 12 additions & 0 deletions deepspeed/tops/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

#############################
######## Training Ops #######
#############################

from .moe_gating import *
from .swiglu import *
from .rope import *
62 changes: 62 additions & 0 deletions deepspeed/tops/includes/moe_gating.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#pragma once

#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include "moe_gating.cuh"

void gate_scatter(torch::Tensor& moe_input,
torch::Tensor& expert_count_cumsums,
torch::Tensor& mapped_slots,
torch::Tensor& activations,
torch::Tensor& expert_counts,
torch::Tensor& mapped_expert_counts,
torch::Tensor& scores,
torch::Tensor& assignments,
torch::Tensor& offsets,
torch::Tensor& backup_offsets,
int top_k,
int capacity,
bool use_rts);

void gate_fwd(torch::Tensor& moe_input,
torch::Tensor& expert_count_cumsums,
torch::Tensor& mapped_slots,
torch::Tensor& activations,
torch::Tensor& expert_counts,
torch::Tensor& mapped_expert_counts,
torch::Tensor& scores,
torch::Tensor& assignments,
torch::Tensor& offsets,
torch::Tensor& backup_offsets,
torch::Tensor& logits,
torch::Tensor& logits_out,
int top_k,
int capacity,
bool use_rts);

void gate_bwd(torch::Tensor& moe_input_grad,
torch::Tensor& scores_grad,
torch::Tensor& activations_grad,
torch::Tensor& logits_grad,
torch::Tensor& logits,
torch::Tensor& assignments,
torch::Tensor& offsets,
torch::Tensor& mapped_slots,
int top_k,
int capacity,
bool use_rts);


void gather_fwd(torch::Tensor& layer_output,
torch::Tensor& moe_output,
torch::Tensor& scores,
torch::Tensor& mapped_slots,
int top_k);

void gather_bwd(torch::Tensor& layer_output_grad,
torch::Tensor& scores_grad,
torch::Tensor& moe_output_grad,
torch::Tensor& moe_output,
torch::Tensor& scores,
torch::Tensor& mapped_slots,
int top_k);
Loading