This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit eee9d31

add shm and model enabled with CI case
Signed-off-by: Clark Chin <[email protected]>
ClarkChin08 committed Jan 10, 2024
1 parent f7428ce commit eee9d31
Showing 11 changed files with 618 additions and 143 deletions.
188 changes: 188 additions & 0 deletions .github/workflows/scripts/models/run_tp.sh

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions neural_speed/application/main_run.cpp
@@ -442,6 +442,15 @@ int main(int argc, char** argv) {  // NOLINT
model_reset_timings(ctx);
}

+#ifdef NE_TP_MODEL
+// sync here so that all tensor-parallel nodes enter inference at the same time
+parallel_context* p_ctx = init_parallel_context();
+if (get_tp_size(p_ctx) > 1) {
+barrier(p_ctx);
+}
+
+#endif

while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
// predict
if (embd.size() > 0) {
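The barrier guard added above is the rendezvous pattern this patch relies on: it keeps faster ranks from entering the generation loop before slower ranks finish setup, per the in-code comment. A minimal sketch of how it could be factored into a reusable helper — hypothetical, not part of this commit, assuming only the parallel-context API the patch already uses (init_parallel_context, get_tp_size, barrier):

#ifdef NE_TP_MODEL
// Hypothetical helper (not in the commit): block until every tensor-parallel
// rank reaches this point, so all ranks enter inference together.
static inline void tp_sync(void) {
  parallel_context* p_ctx = init_parallel_context();
  if (get_tp_size(p_ctx) > 1) {
    barrier(p_ctx);  // collective: returns only after all ranks have called it
  }
}
#endif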
39 changes: 31 additions & 8 deletions neural_speed/core/layers/mha_dense.cpp
@@ -1344,8 +1344,19 @@ class MHAStableInterface {
const auto group_heads = p.head_num / p.heads_kv;
const auto sl_diff = p.sl_kv - p.sl_q;

+// TP will need the real rank order of k
+int32_t k_offset = 0;
+int32_t log_head_num = p.head_num;
+#ifdef NE_TP_MODEL
+parallel_context* p_ctx = init_parallel_context();
+int32_t world_size = get_tp_size(p_ctx);
+int32_t rank = get_tp_rank(p_ctx);
+if (world_size > 1) k_offset += rank * p.head_num;
+log_head_num *= world_size;
+#endif

// alibi slope
-const int n_heads_log2_floor = 1 << static_cast<int>(floor(log2(p.head_num)));
+const int n_heads_log2_floor = 1 << static_cast<int>(floor(log2(log_head_num)));
const float m0 = powf(2.0f, -(8.f) / n_heads_log2_floor);
const float m1 = powf(2.0f, -(8.f / 2.0f) / n_heads_log2_floor);
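For reference, the slope assigned to a head now depends on its global index across all tensor-parallel ranks, g = ihn + k_offset, rather than its rank-local index. With n = n_heads_log2_floor, the code above computes

\[
m_0 = 2^{-8/n}, \qquad m_1 = 2^{-4/n}, \qquad
\mathrm{slope}(g) =
\begin{cases}
m_0^{\,g+1}, & g < n,\\
m_1^{\,2(g-n)+1}, & g \ge n,
\end{cases}
\qquad n = 2^{\lfloor \log_2(\mathrm{log\_head\_num}) \rfloor},
\]

so each rank reproduces exactly the ALiBi slopes a single process with log_head_num heads would assign to its slice of the heads.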

@@ -1381,9 +1392,10 @@ class MHAStableInterface {
const int ihkv = ihn / group_heads;
const int m_size = std::min(M_TILE, p.sl_q - i_m);

-const auto alibi_ihn_m = !is_alibi ? 0.f
-: (ihn < n_heads_log2_floor) ? powf(m0, ihn + 1)
-: powf(m1, 2 * (ihn - n_heads_log2_floor) + 1);
+const auto alibi_ihn_m = !is_alibi ? 0.f
+: (ihn + k_offset < n_heads_log2_floor)
+? powf(m0, ihn + k_offset + 1)
+: powf(m1, 2 * (ihn + k_offset - n_heads_log2_floor) + 1);

float s_max[M_TILE]{}; // maximum for each row of the S matrix
std::fill_n(s_max, M_TILE, -INFINITY);
@@ -1718,7 +1730,17 @@ void bestla_fusion_attn_forward_ref(const attn_fwd_args_t<Q_T, K_T, V_T, DST_T>&
const auto ROWPACK = p.V_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK4 ? 4
: p.V_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK2 ? 2
: 0;
-const int n_heads_log2_floor = 1 << static_cast<int>(floor(log2(p.head_num)));
+// TP will need the real rank order of k
+int32_t k_offset = 0;
+int32_t log_head_num = p.head_num;
+#ifdef NE_TP_MODEL
+parallel_context* p_ctx = init_parallel_context();
+int32_t world_size = get_tp_size(p_ctx);
+int32_t rank = get_tp_rank(p_ctx);
+if (world_size > 1) k_offset += rank * p.head_num;
+log_head_num = p.head_num * world_size;
+#endif
+const int n_heads_log2_floor = 1 << static_cast<int>(floor(log2(log_head_num)));
const float m0 = powf(2.0f, -(8.f) / n_heads_log2_floor);
const float m1 = powf(2.0f, -(8.f / 2.0f) / n_heads_log2_floor);

@@ -1737,9 +1759,10 @@ void bestla_fusion_attn_forward_ref(const attn_fwd_args_t<Q_T, K_T, V_T, DST_T>&
const auto unmasked = is_causal ? sl_diff + i + 1 : p.sl_kv;
const auto curr_row = std::unique_ptr<float[]>(new float[unmasked]);

-const auto alibi_ihn_m = !is_alibi ? 0.f
-: (ihn < n_heads_log2_floor) ? powf(m0, ihn + 1)
-: powf(m1, 2 * (ihn - n_heads_log2_floor) + 1);
+const auto alibi_ihn_m = !is_alibi ? 0.f
+: (ihn + k_offset < n_heads_log2_floor)
+? powf(m0, ihn + k_offset + 1)
+: powf(m1, 2 * (ihn + k_offset - n_heads_log2_floor) + 1);

// Q x K
float row_max = -INFINITY;
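A worked example with hypothetical sizes: take world_size = 2 and p.head_num = 16 heads per rank. Then log_head_num = 32, n_heads_log2_floor = 32, and m0 = 2^{-8/32} = 2^{-1/4}. Rank 1 gets k_offset = 16, so its local head ihn = 0 is global head 16 and receives slope m0^{17} = 2^{-17/4} — exactly what head 16 would get in a single-process run with 32 heads. Without the offset, every rank would compute slopes as if its 16 local heads were heads 0..15 of a 16-head model, and tensor-parallel output would diverge from the single-process result.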
143 changes: 53 additions & 90 deletions neural_speed/core/ne_layers.c
@@ -1324,13 +1324,9 @@ struct ne_tensor* ne_debug_op(struct ne_context* ctx, struct ne_tensor* a, ne_de
return result;
}

-struct ne_tensor* ne_dup(struct ne_context* ctx, struct ne_tensor* a) {
-return ne_dup_impl(ctx, a, false);
-}
+struct ne_tensor* ne_dup(struct ne_context* ctx, struct ne_tensor* a) { return ne_dup_impl(ctx, a, false); }

-struct ne_tensor* ne_dup_inplace(struct ne_context* ctx, struct ne_tensor* a) {
-return ne_dup_impl(ctx, a, true);
-}
+struct ne_tensor* ne_dup_inplace(struct ne_context* ctx, struct ne_tensor* a) { return ne_dup_impl(ctx, a, true); }

// ne_add

@@ -1683,13 +1679,9 @@ struct ne_tensor* ne_sqr_impl(struct ne_context* ctx, struct ne_tensor* a, bool
return result;
}

-struct ne_tensor* ne_sqr(struct ne_context* ctx, struct ne_tensor* a) {
-return ne_sqr_impl(ctx, a, false);
-}
+struct ne_tensor* ne_sqr(struct ne_context* ctx, struct ne_tensor* a) { return ne_sqr_impl(ctx, a, false); }

-struct ne_tensor* ne_sqr_inplace(struct ne_context* ctx, struct ne_tensor* a) {
-return ne_sqr_impl(ctx, a, true);
-}
+struct ne_tensor* ne_sqr_inplace(struct ne_context* ctx, struct ne_tensor* a) { return ne_sqr_impl(ctx, a, true); }

// ne_sqrt

@@ -1710,13 +1702,9 @@ struct ne_tensor* ne_sqrt_impl(struct ne_context* ctx, struct ne_tensor* a, bool
return result;
}

-struct ne_tensor* ne_sqrt(struct ne_context* ctx, struct ne_tensor* a) {
-return ne_sqrt_impl(ctx, a, false);
-}
+struct ne_tensor* ne_sqrt(struct ne_context* ctx, struct ne_tensor* a) { return ne_sqrt_impl(ctx, a, false); }

-struct ne_tensor* ne_sqrt_inplace(struct ne_context* ctx, struct ne_tensor* a) {
-return ne_sqrt_impl(ctx, a, true);
-}
+struct ne_tensor* ne_sqrt_inplace(struct ne_context* ctx, struct ne_tensor* a) { return ne_sqrt_impl(ctx, a, true); }

// ne_log

@@ -1737,13 +1725,9 @@ struct ne_tensor* ne_log_impl(struct ne_context* ctx, struct ne_tensor* a, bool
return result;
}

-struct ne_tensor* ne_log(struct ne_context* ctx, struct ne_tensor* a) {
-return ne_log_impl(ctx, a, false);
-}
+struct ne_tensor* ne_log(struct ne_context* ctx, struct ne_tensor* a) { return ne_log_impl(ctx, a, false); }

-struct ne_tensor* ne_log_inplace(struct ne_context* ctx, struct ne_tensor* a) {
-return ne_log_impl(ctx, a, true);
-}
+struct ne_tensor* ne_log_inplace(struct ne_context* ctx, struct ne_tensor* a) { return ne_log_impl(ctx, a, true); }

// ne_sum

@@ -1853,13 +1837,9 @@ struct ne_tensor* ne_abs_impl(struct ne_context* ctx, struct ne_tensor* a, bool
return result;
}

-struct ne_tensor* ne_abs(struct ne_context* ctx, struct ne_tensor* a) {
-return ne_abs_impl(ctx, a, false);
-}
+struct ne_tensor* ne_abs(struct ne_context* ctx, struct ne_tensor* a) { return ne_abs_impl(ctx, a, false); }

-struct ne_tensor* ne_abs_inplace(struct ne_context* ctx, struct ne_tensor* a) {
-return ne_abs_impl(ctx, a, true);
-}
+struct ne_tensor* ne_abs_inplace(struct ne_context* ctx, struct ne_tensor* a) { return ne_abs_impl(ctx, a, true); }

// ne_sgn

@@ -1880,13 +1860,9 @@ struct ne_tensor* ne_sgn_impl(struct ne_context* ctx, struct ne_tensor* a, bool
return result;
}

-struct ne_tensor* ne_sgn(struct ne_context* ctx, struct ne_tensor* a) {
-return ne_sgn_impl(ctx, a, false);
-}
+struct ne_tensor* ne_sgn(struct ne_context* ctx, struct ne_tensor* a) { return ne_sgn_impl(ctx, a, false); }

-struct ne_tensor* ne_sgn_inplace(struct ne_context* ctx, struct ne_tensor* a) {
-return ne_sgn_impl(ctx, a, true);
-}
+struct ne_tensor* ne_sgn_inplace(struct ne_context* ctx, struct ne_tensor* a) { return ne_sgn_impl(ctx, a, true); }

// ne_neg

@@ -1907,13 +1883,9 @@ struct ne_tensor* ne_neg_impl(struct ne_context* ctx, struct ne_tensor* a, bool
return result;
}

-struct ne_tensor* ne_neg(struct ne_context* ctx, struct ne_tensor* a) {
-return ne_neg_impl(ctx, a, false);
-}
+struct ne_tensor* ne_neg(struct ne_context* ctx, struct ne_tensor* a) { return ne_neg_impl(ctx, a, false); }

-struct ne_tensor* ne_neg_inplace(struct ne_context* ctx, struct ne_tensor* a) {
-return ne_neg_impl(ctx, a, true);
-}
+struct ne_tensor* ne_neg_inplace(struct ne_context* ctx, struct ne_tensor* a) { return ne_neg_impl(ctx, a, true); }

// ne_step

@@ -1934,13 +1906,9 @@ struct ne_tensor* ne_step_impl(struct ne_context* ctx, struct ne_tensor* a, bool
return result;
}

-struct ne_tensor* ne_step(struct ne_context* ctx, struct ne_tensor* a) {
-return ne_step_impl(ctx, a, false);
-}
+struct ne_tensor* ne_step(struct ne_context* ctx, struct ne_tensor* a) { return ne_step_impl(ctx, a, false); }

-struct ne_tensor* ne_step_inplace(struct ne_context* ctx, struct ne_tensor* a) {
-return ne_step_impl(ctx, a, true);
-}
+struct ne_tensor* ne_step_inplace(struct ne_context* ctx, struct ne_tensor* a) { return ne_step_impl(ctx, a, true); }

// ne_relu

@@ -1961,13 +1929,9 @@ struct ne_tensor* ne_relu_impl(struct ne_context* ctx, struct ne_tensor* a, bool
return result;
}

-struct ne_tensor* ne_relu(struct ne_context* ctx, struct ne_tensor* a) {
-return ne_relu_impl(ctx, a, false);
-}
+struct ne_tensor* ne_relu(struct ne_context* ctx, struct ne_tensor* a) { return ne_relu_impl(ctx, a, false); }

-struct ne_tensor* ne_relu_inplace(struct ne_context* ctx, struct ne_tensor* a) {
-return ne_relu_impl(ctx, a, true);
-}
+struct ne_tensor* ne_relu_inplace(struct ne_context* ctx, struct ne_tensor* a) { return ne_relu_impl(ctx, a, true); }

// ne_gelu

@@ -1988,13 +1952,9 @@ struct ne_tensor* ne_gelu_impl(struct ne_context* ctx, struct ne_tensor* a, bool
return result;
}

-struct ne_tensor* ne_gelu(struct ne_context* ctx, struct ne_tensor* a) {
-return ne_gelu_impl(ctx, a, false);
-}
+struct ne_tensor* ne_gelu(struct ne_context* ctx, struct ne_tensor* a) { return ne_gelu_impl(ctx, a, false); }

-struct ne_tensor* ne_gelu_inplace(struct ne_context* ctx, struct ne_tensor* a) {
-return ne_gelu_impl(ctx, a, true);
-}
+struct ne_tensor* ne_gelu_inplace(struct ne_context* ctx, struct ne_tensor* a) { return ne_gelu_impl(ctx, a, true); }

// ne_silu

@@ -2015,13 +1975,9 @@ struct ne_tensor* ne_silu_impl(struct ne_context* ctx, struct ne_tensor* a, bool
return result;
}

-struct ne_tensor* ne_silu(struct ne_context* ctx, struct ne_tensor* a) {
-return ne_silu_impl(ctx, a, false);
-}
+struct ne_tensor* ne_silu(struct ne_context* ctx, struct ne_tensor* a) { return ne_silu_impl(ctx, a, false); }

-struct ne_tensor* ne_silu_inplace(struct ne_context* ctx, struct ne_tensor* a) {
-return ne_silu_impl(ctx, a, true);
-}
+struct ne_tensor* ne_silu_inplace(struct ne_context* ctx, struct ne_tensor* a) { return ne_silu_impl(ctx, a, true); }

// ne_silu_back

@@ -2063,13 +2019,9 @@ struct ne_tensor* ne_norm_impl(struct ne_context* ctx, struct ne_tensor* a, bool
return result;
}

-struct ne_tensor* ne_norm(struct ne_context* ctx, struct ne_tensor* a) {
-return ne_norm_impl(ctx, a, false);
-}
+struct ne_tensor* ne_norm(struct ne_context* ctx, struct ne_tensor* a) { return ne_norm_impl(ctx, a, false); }

-struct ne_tensor* ne_norm_inplace(struct ne_context* ctx, struct ne_tensor* a) {
-return ne_norm_impl(ctx, a, true);
-}
+struct ne_tensor* ne_norm_inplace(struct ne_context* ctx, struct ne_tensor* a) { return ne_norm_impl(ctx, a, true); }

struct ne_tensor* ne_rms_norm_impl(struct ne_context* ctx, struct ne_tensor* a, bool inplace, float eps) {
bool is_node = false;
@@ -2415,13 +2367,9 @@ struct ne_tensor* ne_cont_impl(struct ne_context* ctx, struct ne_tensor* a, bool
return result;
}

-struct ne_tensor* ne_cont(struct ne_context* ctx, struct ne_tensor* a) {
-return ne_cont_impl(ctx, a, false);
-}
+struct ne_tensor* ne_cont(struct ne_context* ctx, struct ne_tensor* a) { return ne_cont_impl(ctx, a, false); }

-struct ne_tensor* ne_cont_inplace(struct ne_context* ctx, struct ne_tensor* a) {
-return ne_cont_impl(ctx, a, true);
-}
+struct ne_tensor* ne_cont_inplace(struct ne_context* ctx, struct ne_tensor* a) { return ne_cont_impl(ctx, a, true); }

// ne_reshape

@@ -2968,9 +2916,7 @@ struct ne_tensor* ne_soft_max_impl(struct ne_context* ctx, struct ne_tensor* a,
return result;
}

-struct ne_tensor* ne_soft_max(struct ne_context* ctx, struct ne_tensor* a) {
-return ne_soft_max_impl(ctx, a, false);
-}
+struct ne_tensor* ne_soft_max(struct ne_context* ctx, struct ne_tensor* a) { return ne_soft_max_impl(ctx, a, false); }

struct ne_tensor* ne_soft_max_inplace(struct ne_context* ctx, struct ne_tensor* a) {
return ne_soft_max_impl(ctx, a, true);
@@ -7653,7 +7599,7 @@ static void ne_compute_forward_alibi_f32(const struct ne_compute_params* params,
}

const int n_past = ((int32_t*)src1->data)[0];
-const int n_head = ((int32_t*)src1->data)[1];
+int n_head = ((int32_t*)src1->data)[1];
const float max_bias = ((float*)src1->data)[2];

assert(n_past >= 0);
@@ -7674,6 +7620,15 @@ static void ne_compute_forward_alibi_f32(const struct ne_compute_params* params,
assert(nb0 == sizeof(float));
assert(ne1 + n_past == ne0);
(void)n_past;
+// TP will need the real rank order of k
+int32_t k_offset = 0;
+#ifdef NE_TP_MODEL
+parallel_context* p_ctx = init_parallel_context();
+int32_t world_size = get_tp_size(p_ctx);
+int32_t rank = get_tp_rank(p_ctx);
+if (world_size > 1) k_offset += rank * n_head;
+n_head *= world_size;
+#endif

// add alibi to src0 (KQ_scaled)
const int n_heads_log2_floor = 1 << (int)floor(log2(n_head));
@@ -7691,10 +7646,10 @@ static void ne_compute_forward_alibi_f32(const struct ne_compute_params* params,

float m_k;

-if (k < n_heads_log2_floor) {
-m_k = powf(m0, k + 1);
+if (k + k_offset < n_heads_log2_floor) {
+m_k = powf(m0, k + k_offset + 1);
} else {
-m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
+m_k = powf(m1, 2 * (k + k_offset - n_heads_log2_floor) + 1);
}

pdst[0] = (i - ne0 + 1) * m_k + src[0];
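The same global-index branch now appears three times: in the fused MHA kernel above and in both alibi compute paths (f32 here, f16 below). A minimal sketch of how it could be factored out — a hypothetical helper, not part of this commit, matching the arithmetic of the branch above and requiring only <math.h>:

// Hypothetical helper (not in the commit): ALiBi slope for a head's global
// index k_global, given n_heads_log2_floor = 1 << (int)floor(log2(total_heads)).
static float alibi_slope(int k_global, int n_heads_log2_floor) {
  const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor);  // base for the first n heads
  const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor);  // base past the power-of-two floor
  if (k_global < n_heads_log2_floor) return powf(m0, k_global + 1);
  return powf(m1, 2 * (k_global - n_heads_log2_floor) + 1);
}

Each call site would then reduce to m_k = alibi_slope(k + k_offset, n_heads_log2_floor);.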
@@ -7714,7 +7669,7 @@ static void ne_compute_forward_alibi_f16(const struct ne_compute_params* params,
}

const int n_past = ((int32_t*)src1->data)[0];
-const int n_head = ((int32_t*)src1->data)[1];
+int n_head = ((int32_t*)src1->data)[1];
const float max_bias = ((float*)src1->data)[2];

assert(n_past >= 0);
@@ -7735,7 +7690,15 @@ static void ne_compute_forward_alibi_f16(const struct ne_compute_params* params,
assert(nb0 == sizeof(ne_fp16_t));
assert(ne1 + n_past == ne0);
(void)n_past;

+// TP will need the real rank order of k
+int32_t k_offset = 0;
+#ifdef NE_TP_MODEL
+parallel_context* p_ctx = init_parallel_context();
+int32_t world_size = get_tp_size(p_ctx);
+int32_t rank = get_tp_rank(p_ctx);
+if (world_size > 1) k_offset += rank * n_head;
+n_head *= world_size;
+#endif
// add alibi to src0 (KQ_scaled)
const int n_heads_log2_floor = 1 << (int)floor(log2(n_head));

@@ -7753,10 +7716,10 @@ static void ne_compute_forward_alibi_f16(const struct ne_compute_params* params,

float m_k;

-if (k < n_heads_log2_floor) {
-m_k = powf(m0, k + 1);
+if (k + k_offset < n_heads_log2_floor) {
+m_k = powf(m0, k + k_offset + 1);
} else {
-m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
+m_k = powf(m1, 2 * (k + k_offset - n_heads_log2_floor) + 1);
}

// we return F32
