This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

use clang-format 14.0
luoyu-intel committed Dec 20, 2023
1 parent 9311ddc commit c7ee53a
Showing 3 changed files with 109 additions and 68 deletions.
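This commit is a pure clang-format pass: no functional changes, only whitespace and line breaking. Three patterns account for every hunk below: multiplications inside the xbyak helper macros lose the spaces around the * operator, multi-line "if (...) XBYAK_THROW(...)" guards collapse onto a single line, and one-line ne_* wrapper functions expand to three lines. Presumably the diff was produced by running clang-format 14 in-place over the three files, along the lines of "clang-format-14 -i neural_speed/core/ne_layers.c"; that invocation is an assumption, since the repository's .clang-format configuration is not touched by this commit.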
bestla/bestla/xbyak/xbyak.h (4 changes: 2 additions & 2 deletions)
@@ -45,7 +45,7 @@
 #endif
 
 #ifdef __GNUC__
-#define XBYAK_GNUC_PREREQ(major, minor) ((__GNUC__) * 100 + (__GNUC_MINOR__) >= (major) * 100 + (minor))
+#define XBYAK_GNUC_PREREQ(major, minor) ((__GNUC__)*100 + (__GNUC_MINOR__) >= (major)*100 + (minor))
 #else
 #define XBYAK_GNUC_PREREQ(major, minor) 0
 #endif
@@ -191,7 +191,7 @@ typedef uint8_t uint8;
 #endif
 #endif
 #ifndef MIE_PACK // for shufps
-#define MIE_PACK(x, y, z, w) ((x) * 64 + (y) * 16 + (z) * 4 + (w))
+#define MIE_PACK(x, y, z, w) ((x)*64 + (y)*16 + (z)*4 + (w))
 #endif
 
 enum {
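A note on the two xbyak.h hunks: clang-format cannot always tell whether a parenthesized token such as (x) is a value or a cast, so it formats a * that follows a closing parenthesis conservatively, as a possible pointer or dereference, without surrounding spaces. That is the likely reason version 14 rewrites (x) * 64 as (x)*64 here; the macros expand identically either way.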
bestla/bestla/xbyak/xbyak_mnemonic.h (65 changes: 26 additions & 39 deletions)
@@ -2017,14 +2017,12 @@ void vmovlps(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) {
   XBYAK_THROW(ERR_BAD_COMBINATION) opAVX_X_X_XM(x, op1, op2, T_0F | T_EVEX | T_EW0 | T_N8, 0x12);
 }
 void vmovmskpd(const Reg& r, const Xmm& x) {
-  if (!r.isBit(i32e))
-    XBYAK_THROW(ERR_BAD_COMBINATION)
-    opAVX_X_X_XM(x.isXMM() ? Xmm(r.getIdx()) : Ymm(r.getIdx()), cvtIdx0(x), x, T_0F | T_66 | T_W0 | T_YMM, 0x50);
+  if (!r.isBit(i32e)) XBYAK_THROW(ERR_BAD_COMBINATION)
+  opAVX_X_X_XM(x.isXMM() ? Xmm(r.getIdx()) : Ymm(r.getIdx()), cvtIdx0(x), x, T_0F | T_66 | T_W0 | T_YMM, 0x50);
 }
 void vmovmskps(const Reg& r, const Xmm& x) {
-  if (!r.isBit(i32e))
-    XBYAK_THROW(ERR_BAD_COMBINATION)
-    opAVX_X_X_XM(x.isXMM() ? Xmm(r.getIdx()) : Ymm(r.getIdx()), cvtIdx0(x), x, T_0F | T_W0 | T_YMM, 0x50);
+  if (!r.isBit(i32e)) XBYAK_THROW(ERR_BAD_COMBINATION)
+  opAVX_X_X_XM(x.isXMM() ? Xmm(r.getIdx()) : Ymm(r.getIdx()), cvtIdx0(x), x, T_0F | T_W0 | T_YMM, 0x50);
 }
 void vmovntdq(const Address& addr, const Xmm& x) { opVex(x, 0, addr, T_0F | T_66 | T_YMM | T_EVEX | T_EW0, 0xE7); }
 void vmovntdqa(const Xmm& x, const Address& addr) { opVex(x, 0, addr, T_0F38 | T_66 | T_YMM | T_EVEX | T_EW0, 0x2A); }
@@ -2403,9 +2401,8 @@ void vpminuw(const Xmm& x1, const Xmm& x2, const Operand& op) {
   opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x3A);
 }
 void vpmovmskb(const Reg32e& r, const Xmm& x) {
-  if (!x.is(Operand::XMM | Operand::YMM))
-    XBYAK_THROW(ERR_BAD_COMBINATION)
-    opVex(x.isYMM() ? Ymm(r.getIdx()) : Xmm(r.getIdx()), 0, x, T_0F | T_66 | T_YMM, 0xD7);
+  if (!x.is(Operand::XMM | Operand::YMM)) XBYAK_THROW(ERR_BAD_COMBINATION)
+  opVex(x.isYMM() ? Ymm(r.getIdx()) : Xmm(r.getIdx()), 0, x, T_0F | T_66 | T_YMM, 0xD7);
 }
 void vpmovsxbd(const Xmm& xm, const Operand& op) {
   opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x21);
@@ -3569,27 +3566,24 @@ void vcvtph2dq(const Xmm& x, const Operand& op) {
   opVex(x, 0, op, T_N8 | T_N_VL | T_66 | T_MAP5 | T_EW0 | T_YMM | T_ER_Y | T_MUST_EVEX | T_B16, 0x5B);
 }
 void vcvtph2pd(const Xmm& x, const Operand& op) {
-  if (!op.isXMM() && !op.isMEM())
-    XBYAK_THROW(ERR_BAD_MEM_SIZE)
-    opVex(x, 0, op, T_N4 | T_N_VL | T_MAP5 | T_EW0 | T_YMM | T_SAE_X | T_MUST_EVEX | T_B16, 0x5A);
+  if (!op.isXMM() && !op.isMEM()) XBYAK_THROW(ERR_BAD_MEM_SIZE)
+  opVex(x, 0, op, T_N4 | T_N_VL | T_MAP5 | T_EW0 | T_YMM | T_SAE_X | T_MUST_EVEX | T_B16, 0x5A);
 }
 void vcvtph2psx(const Xmm& x, const Operand& op) {
   checkCvt1(x, op);
   opVex(x, 0, op, T_N8 | T_N_VL | T_66 | T_MAP6 | T_EW0 | T_YMM | T_SAE_Y | T_MUST_EVEX | T_B16, 0x13);
 }
 void vcvtph2qq(const Xmm& x, const Operand& op) {
-  if (!op.isXMM() && !op.isMEM())
-    XBYAK_THROW(ERR_BAD_MEM_SIZE)
-    opVex(x, 0, op, T_N4 | T_N_VL | T_66 | T_MAP5 | T_EW0 | T_YMM | T_ER_X | T_MUST_EVEX | T_B16, 0x7B);
+  if (!op.isXMM() && !op.isMEM()) XBYAK_THROW(ERR_BAD_MEM_SIZE)
+  opVex(x, 0, op, T_N4 | T_N_VL | T_66 | T_MAP5 | T_EW0 | T_YMM | T_ER_X | T_MUST_EVEX | T_B16, 0x7B);
 }
 void vcvtph2udq(const Xmm& x, const Operand& op) {
   checkCvt1(x, op);
   opVex(x, 0, op, T_N8 | T_N_VL | T_MAP5 | T_EW0 | T_YMM | T_ER_Y | T_MUST_EVEX | T_B16, 0x79);
 }
 void vcvtph2uqq(const Xmm& x, const Operand& op) {
-  if (!op.isXMM() && !op.isMEM())
-    XBYAK_THROW(ERR_BAD_MEM_SIZE)
-    opVex(x, 0, op, T_N4 | T_N_VL | T_66 | T_MAP5 | T_EW0 | T_YMM | T_ER_X | T_MUST_EVEX | T_B16, 0x79);
+  if (!op.isXMM() && !op.isMEM()) XBYAK_THROW(ERR_BAD_MEM_SIZE)
+  opVex(x, 0, op, T_N4 | T_N_VL | T_66 | T_MAP5 | T_EW0 | T_YMM | T_ER_X | T_MUST_EVEX | T_B16, 0x79);
 }
 void vcvtph2uw(const Xmm& x, const Operand& op) {
   opAVX_X_XM_IMM(x, op, T_MAP5 | T_EW0 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B16, 0x7D);
@@ -3643,9 +3637,8 @@ void vcvtsh2usi(const Reg32e& r, const Operand& op) {
   opVex(r, &xm0, op, type, 0x79);
 }
 void vcvtsi2sh(const Xmm& x1, const Xmm& x2, const Operand& op) {
-  if (!(x1.isXMM() && x2.isXMM() && op.isBit(32 | 64)))
-    XBYAK_THROW(ERR_BAD_COMBINATION)
-    int type = (T_F3 | T_MAP5 | T_ER_R | T_MUST_EVEX | T_M_K) | (op.isBit(32) ? (T_EW0 | T_N4) : (T_EW1 | T_N8));
+  if (!(x1.isXMM() && x2.isXMM() && op.isBit(32 | 64))) XBYAK_THROW(ERR_BAD_COMBINATION)
+  int type = (T_F3 | T_MAP5 | T_ER_R | T_MUST_EVEX | T_M_K) | (op.isBit(32) ? (T_EW0 | T_N4) : (T_EW1 | T_N8));
   opVex(x1, &x2, op, type, 0x2A);
 }
 void vcvtss2sh(const Xmm& x1, const Xmm& x2, const Operand& op) {
@@ -3669,18 +3662,16 @@ void vcvttph2dq(const Xmm& x, const Operand& op) {
   opVex(x, 0, op, T_N8 | T_N_VL | T_F3 | T_MAP5 | T_EW0 | T_YMM | T_SAE_Y | T_MUST_EVEX | T_B16, 0x5B);
 }
 void vcvttph2qq(const Xmm& x, const Operand& op) {
-  if (!op.isXMM() && !op.isMEM())
-    XBYAK_THROW(ERR_BAD_MEM_SIZE)
-    opVex(x, 0, op, T_N4 | T_N_VL | T_66 | T_MAP5 | T_EW0 | T_YMM | T_SAE_X | T_MUST_EVEX | T_B16, 0x7A);
+  if (!op.isXMM() && !op.isMEM()) XBYAK_THROW(ERR_BAD_MEM_SIZE)
+  opVex(x, 0, op, T_N4 | T_N_VL | T_66 | T_MAP5 | T_EW0 | T_YMM | T_SAE_X | T_MUST_EVEX | T_B16, 0x7A);
 }
 void vcvttph2udq(const Xmm& x, const Operand& op) {
   checkCvt1(x, op);
   opVex(x, 0, op, T_N8 | T_N_VL | T_MAP5 | T_EW0 | T_YMM | T_SAE_Y | T_MUST_EVEX | T_B16, 0x78);
 }
 void vcvttph2uqq(const Xmm& x, const Operand& op) {
-  if (!op.isXMM() && !op.isMEM())
-    XBYAK_THROW(ERR_BAD_MEM_SIZE)
-    opVex(x, 0, op, T_N4 | T_N_VL | T_66 | T_MAP5 | T_EW0 | T_YMM | T_SAE_X | T_MUST_EVEX | T_B16, 0x78);
+  if (!op.isXMM() && !op.isMEM()) XBYAK_THROW(ERR_BAD_MEM_SIZE)
+  opVex(x, 0, op, T_N4 | T_N_VL | T_66 | T_MAP5 | T_EW0 | T_YMM | T_SAE_X | T_MUST_EVEX | T_B16, 0x78);
 }
 void vcvttph2uw(const Xmm& x, const Operand& op) {
   opAVX_X_XM_IMM(x, op, T_MAP5 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B16, 0x7C);
@@ -3739,9 +3730,8 @@ void vcvtusi2sd(const Xmm& x1, const Xmm& x2, const Operand& op) {
   opCvt3(x1, x2, op, T_F2 | T_0F | T_MUST_EVEX, T_W1 | T_EW1 | T_ER_X | T_N8, T_W0 | T_EW0 | T_N4, 0x7B);
 }
 void vcvtusi2sh(const Xmm& x1, const Xmm& x2, const Operand& op) {
-  if (!(x1.isXMM() && x2.isXMM() && op.isBit(32 | 64)))
-    XBYAK_THROW(ERR_BAD_COMBINATION)
-    int type = (T_F3 | T_MAP5 | T_ER_R | T_MUST_EVEX | T_M_K) | (op.isBit(32) ? (T_EW0 | T_N4) : (T_EW1 | T_N8));
+  if (!(x1.isXMM() && x2.isXMM() && op.isBit(32 | 64))) XBYAK_THROW(ERR_BAD_COMBINATION)
+  int type = (T_F3 | T_MAP5 | T_ER_R | T_MUST_EVEX | T_M_K) | (op.isBit(32) ? (T_EW0 | T_N4) : (T_EW1 | T_N8));
   opVex(x1, &x2, op, type, 0x7B);
 }
 void vcvtusi2ss(const Xmm& x1, const Xmm& x2, const Operand& op) {
@@ -3924,19 +3914,16 @@ void vfnmsub231sh(const Xmm& x1, const Xmm& x2, const Operand& op) {
   opAVX_X_X_XM(x1, x2, op, T_N2 | T_66 | T_MAP6 | T_EW0 | T_ER_X | T_MUST_EVEX, 0xBF);
 }
 void vfpclasspd(const Opmask& k, const Operand& op, uint8_t imm) {
-  if (!op.isBit(128 | 256 | 512))
-    XBYAK_THROW(ERR_BAD_MEM_SIZE)
-    opVex(k.changeBit(op.getBit()), 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_YMM | T_EW1 | T_B64, 0x66, imm);
+  if (!op.isBit(128 | 256 | 512)) XBYAK_THROW(ERR_BAD_MEM_SIZE)
+  opVex(k.changeBit(op.getBit()), 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_YMM | T_EW1 | T_B64, 0x66, imm);
 }
 void vfpclassph(const Opmask& k, const Operand& op, uint8_t imm) {
-  if (!op.isBit(128 | 256 | 512))
-    XBYAK_THROW(ERR_BAD_MEM_SIZE)
-    opVex(k.changeBit(op.getBit()), 0, op, T_0F3A | T_MUST_EVEX | T_YMM | T_EW0 | T_B16, 0x66, imm);
+  if (!op.isBit(128 | 256 | 512)) XBYAK_THROW(ERR_BAD_MEM_SIZE)
+  opVex(k.changeBit(op.getBit()), 0, op, T_0F3A | T_MUST_EVEX | T_YMM | T_EW0 | T_B16, 0x66, imm);
 }
 void vfpclassps(const Opmask& k, const Operand& op, uint8_t imm) {
-  if (!op.isBit(128 | 256 | 512))
-    XBYAK_THROW(ERR_BAD_MEM_SIZE)
-    opVex(k.changeBit(op.getBit()), 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_YMM | T_EW0 | T_B32, 0x66, imm);
+  if (!op.isBit(128 | 256 | 512)) XBYAK_THROW(ERR_BAD_MEM_SIZE)
+  opVex(k.changeBit(op.getBit()), 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_YMM | T_EW0 | T_B32, 0x66, imm);
 }
 void vfpclasssd(const Opmask& k, const Operand& op, uint8_t imm) {
   if (!op.isXMEM())
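The XBYAK_THROW(...) invocations above are written without a trailing semicolon because the macro already expands to a complete block statement; that is also what lets clang-format 14 keep the guard and the throw on one line without changing meaning. A rough sketch of the definition in xbyak.h, offered as an approximation since the exact spelling varies between xbyak versions and build modes:

// Approximate reconstruction of the xbyak.h macros, not a verbatim quote.
#ifdef XBYAK_NO_EXCEPTION
// no-exception builds record the error and return out of the generator
#define XBYAK_THROW(err) { Xbyak::local::SetError(err); return; }
#else
// exception builds simply throw
#define XBYAK_THROW(err) { throw Error(err); }
#endif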
neural_speed/core/ne_layers.c (108 changes: 81 additions & 27 deletions)
@@ -1324,9 +1324,13 @@ struct ne_tensor* ne_debug_op(struct ne_context* ctx, struct ne_tensor* a, ne_de
   return result;
 }
 
-struct ne_tensor* ne_dup(struct ne_context* ctx, struct ne_tensor* a) { return ne_dup_impl(ctx, a, false); }
+struct ne_tensor* ne_dup(struct ne_context* ctx, struct ne_tensor* a) {
+  return ne_dup_impl(ctx, a, false);
+}
 
-struct ne_tensor* ne_dup_inplace(struct ne_context* ctx, struct ne_tensor* a) { return ne_dup_impl(ctx, a, true); }
+struct ne_tensor* ne_dup_inplace(struct ne_context* ctx, struct ne_tensor* a) {
+  return ne_dup_impl(ctx, a, true);
+}
 
 // ne_add
 
@@ -1679,9 +1683,13 @@ struct ne_tensor* ne_sqr_impl(struct ne_context* ctx, struct ne_tensor* a, bool
   return result;
 }
 
-struct ne_tensor* ne_sqr(struct ne_context* ctx, struct ne_tensor* a) { return ne_sqr_impl(ctx, a, false); }
+struct ne_tensor* ne_sqr(struct ne_context* ctx, struct ne_tensor* a) {
+  return ne_sqr_impl(ctx, a, false);
+}
 
-struct ne_tensor* ne_sqr_inplace(struct ne_context* ctx, struct ne_tensor* a) { return ne_sqr_impl(ctx, a, true); }
+struct ne_tensor* ne_sqr_inplace(struct ne_context* ctx, struct ne_tensor* a) {
+  return ne_sqr_impl(ctx, a, true);
+}
 
 // ne_sqrt
 
@@ -1702,9 +1710,13 @@ struct ne_tensor* ne_sqrt_impl(struct ne_context* ctx, struct ne_tensor* a, bool
   return result;
 }
 
-struct ne_tensor* ne_sqrt(struct ne_context* ctx, struct ne_tensor* a) { return ne_sqrt_impl(ctx, a, false); }
+struct ne_tensor* ne_sqrt(struct ne_context* ctx, struct ne_tensor* a) {
+  return ne_sqrt_impl(ctx, a, false);
+}
 
-struct ne_tensor* ne_sqrt_inplace(struct ne_context* ctx, struct ne_tensor* a) { return ne_sqrt_impl(ctx, a, true); }
+struct ne_tensor* ne_sqrt_inplace(struct ne_context* ctx, struct ne_tensor* a) {
+  return ne_sqrt_impl(ctx, a, true);
+}
 
 // ne_log
 
@@ -1725,9 +1737,13 @@ struct ne_tensor* ne_log_impl(struct ne_context* ctx, struct ne_tensor* a, bool
   return result;
 }
 
-struct ne_tensor* ne_log(struct ne_context* ctx, struct ne_tensor* a) { return ne_log_impl(ctx, a, false); }
+struct ne_tensor* ne_log(struct ne_context* ctx, struct ne_tensor* a) {
+  return ne_log_impl(ctx, a, false);
+}
 
-struct ne_tensor* ne_log_inplace(struct ne_context* ctx, struct ne_tensor* a) { return ne_log_impl(ctx, a, true); }
+struct ne_tensor* ne_log_inplace(struct ne_context* ctx, struct ne_tensor* a) {
+  return ne_log_impl(ctx, a, true);
+}
 
 // ne_sum
 
@@ -1837,9 +1853,13 @@ struct ne_tensor* ne_abs_impl(struct ne_context* ctx, struct ne_tensor* a, bool
   return result;
 }
 
-struct ne_tensor* ne_abs(struct ne_context* ctx, struct ne_tensor* a) { return ne_abs_impl(ctx, a, false); }
+struct ne_tensor* ne_abs(struct ne_context* ctx, struct ne_tensor* a) {
+  return ne_abs_impl(ctx, a, false);
+}
 
-struct ne_tensor* ne_abs_inplace(struct ne_context* ctx, struct ne_tensor* a) { return ne_abs_impl(ctx, a, true); }
+struct ne_tensor* ne_abs_inplace(struct ne_context* ctx, struct ne_tensor* a) {
+  return ne_abs_impl(ctx, a, true);
+}
 
 // ne_sgn
 
@@ -1860,9 +1880,13 @@ struct ne_tensor* ne_sgn_impl(struct ne_context* ctx, struct ne_tensor* a, bool
   return result;
 }
 
-struct ne_tensor* ne_sgn(struct ne_context* ctx, struct ne_tensor* a) { return ne_sgn_impl(ctx, a, false); }
+struct ne_tensor* ne_sgn(struct ne_context* ctx, struct ne_tensor* a) {
+  return ne_sgn_impl(ctx, a, false);
+}
 
-struct ne_tensor* ne_sgn_inplace(struct ne_context* ctx, struct ne_tensor* a) { return ne_sgn_impl(ctx, a, true); }
+struct ne_tensor* ne_sgn_inplace(struct ne_context* ctx, struct ne_tensor* a) {
+  return ne_sgn_impl(ctx, a, true);
+}
 
 // ne_neg
 
@@ -1883,9 +1907,13 @@ struct ne_tensor* ne_neg_impl(struct ne_context* ctx, struct ne_tensor* a, bool
   return result;
 }
 
-struct ne_tensor* ne_neg(struct ne_context* ctx, struct ne_tensor* a) { return ne_neg_impl(ctx, a, false); }
+struct ne_tensor* ne_neg(struct ne_context* ctx, struct ne_tensor* a) {
+  return ne_neg_impl(ctx, a, false);
+}
 
-struct ne_tensor* ne_neg_inplace(struct ne_context* ctx, struct ne_tensor* a) { return ne_neg_impl(ctx, a, true); }
+struct ne_tensor* ne_neg_inplace(struct ne_context* ctx, struct ne_tensor* a) {
+  return ne_neg_impl(ctx, a, true);
+}
 
 // ne_step
 
@@ -1906,9 +1934,13 @@ struct ne_tensor* ne_step_impl(struct ne_context* ctx, struct ne_tensor* a, bool
   return result;
 }
 
-struct ne_tensor* ne_step(struct ne_context* ctx, struct ne_tensor* a) { return ne_step_impl(ctx, a, false); }
+struct ne_tensor* ne_step(struct ne_context* ctx, struct ne_tensor* a) {
+  return ne_step_impl(ctx, a, false);
+}
 
-struct ne_tensor* ne_step_inplace(struct ne_context* ctx, struct ne_tensor* a) { return ne_step_impl(ctx, a, true); }
+struct ne_tensor* ne_step_inplace(struct ne_context* ctx, struct ne_tensor* a) {
+  return ne_step_impl(ctx, a, true);
+}
 
 // ne_relu
 
@@ -1929,9 +1961,13 @@ struct ne_tensor* ne_relu_impl(struct ne_context* ctx, struct ne_tensor* a, bool
   return result;
 }
 
-struct ne_tensor* ne_relu(struct ne_context* ctx, struct ne_tensor* a) { return ne_relu_impl(ctx, a, false); }
+struct ne_tensor* ne_relu(struct ne_context* ctx, struct ne_tensor* a) {
+  return ne_relu_impl(ctx, a, false);
+}
 
-struct ne_tensor* ne_relu_inplace(struct ne_context* ctx, struct ne_tensor* a) { return ne_relu_impl(ctx, a, true); }
+struct ne_tensor* ne_relu_inplace(struct ne_context* ctx, struct ne_tensor* a) {
+  return ne_relu_impl(ctx, a, true);
+}
 
 // ne_gelu
 
@@ -1952,9 +1988,13 @@ struct ne_tensor* ne_gelu_impl(struct ne_context* ctx, struct ne_tensor* a, bool
   return result;
 }
 
-struct ne_tensor* ne_gelu(struct ne_context* ctx, struct ne_tensor* a) { return ne_gelu_impl(ctx, a, false); }
+struct ne_tensor* ne_gelu(struct ne_context* ctx, struct ne_tensor* a) {
+  return ne_gelu_impl(ctx, a, false);
+}
 
-struct ne_tensor* ne_gelu_inplace(struct ne_context* ctx, struct ne_tensor* a) { return ne_gelu_impl(ctx, a, true); }
+struct ne_tensor* ne_gelu_inplace(struct ne_context* ctx, struct ne_tensor* a) {
+  return ne_gelu_impl(ctx, a, true);
+}
 
 // ne_silu
 
@@ -1975,9 +2015,13 @@ struct ne_tensor* ne_silu_impl(struct ne_context* ctx, struct ne_tensor* a, bool
   return result;
 }
 
-struct ne_tensor* ne_silu(struct ne_context* ctx, struct ne_tensor* a) { return ne_silu_impl(ctx, a, false); }
+struct ne_tensor* ne_silu(struct ne_context* ctx, struct ne_tensor* a) {
+  return ne_silu_impl(ctx, a, false);
+}
 
-struct ne_tensor* ne_silu_inplace(struct ne_context* ctx, struct ne_tensor* a) { return ne_silu_impl(ctx, a, true); }
+struct ne_tensor* ne_silu_inplace(struct ne_context* ctx, struct ne_tensor* a) {
+  return ne_silu_impl(ctx, a, true);
+}
 
 // ne_silu_back
 
@@ -2019,9 +2063,13 @@ struct ne_tensor* ne_norm_impl(struct ne_context* ctx, struct ne_tensor* a, bool
   return result;
 }
 
-struct ne_tensor* ne_norm(struct ne_context* ctx, struct ne_tensor* a) { return ne_norm_impl(ctx, a, false); }
+struct ne_tensor* ne_norm(struct ne_context* ctx, struct ne_tensor* a) {
+  return ne_norm_impl(ctx, a, false);
+}
 
-struct ne_tensor* ne_norm_inplace(struct ne_context* ctx, struct ne_tensor* a) { return ne_norm_impl(ctx, a, true); }
+struct ne_tensor* ne_norm_inplace(struct ne_context* ctx, struct ne_tensor* a) {
+  return ne_norm_impl(ctx, a, true);
+}
 
 struct ne_tensor* ne_rms_norm_impl(struct ne_context* ctx, struct ne_tensor* a, bool inplace, float eps) {
   bool is_node = false;
@@ -2367,9 +2415,13 @@ struct ne_tensor* ne_cont_impl(struct ne_context* ctx, struct ne_tensor* a, bool
   return result;
 }
 
-struct ne_tensor* ne_cont(struct ne_context* ctx, struct ne_tensor* a) { return ne_cont_impl(ctx, a, false); }
+struct ne_tensor* ne_cont(struct ne_context* ctx, struct ne_tensor* a) {
+  return ne_cont_impl(ctx, a, false);
+}
 
-struct ne_tensor* ne_cont_inplace(struct ne_context* ctx, struct ne_tensor* a) { return ne_cont_impl(ctx, a, true); }
+struct ne_tensor* ne_cont_inplace(struct ne_context* ctx, struct ne_tensor* a) {
+  return ne_cont_impl(ctx, a, true);
+}
 
 // ne_reshape
 
@@ -2916,7 +2968,9 @@ struct ne_tensor* ne_soft_max_impl(struct ne_context* ctx, struct ne_tensor* a,
   return result;
 }
 
-struct ne_tensor* ne_soft_max(struct ne_context* ctx, struct ne_tensor* a) { return ne_soft_max_impl(ctx, a, false); }
+struct ne_tensor* ne_soft_max(struct ne_context* ctx, struct ne_tensor* a) {
+  return ne_soft_max_impl(ctx, a, false);
+}
 
 struct ne_tensor* ne_soft_max_inplace(struct ne_context* ctx, struct ne_tensor* a) {
   return ne_soft_max_impl(ctx, a, true);
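Every ne_layers.c hunk has the same shape: a single-line wrapper over an _impl function carrying an inplace flag is expanded onto three lines, consistent with a setting such as AllowShortFunctionsOnASingleLine: None (an assumption, since the .clang-format file is not part of this commit). For orientation, a minimal sketch of the _impl pattern these wrappers delegate to; the helper names and fields below are hypothetical simplifications in the style of this file's ggml-like API, not quotes from the source:

/* Hypothetical, simplified ne_*_impl: the real functions also handle
   graph construction and gradient bookkeeping. */
static struct ne_tensor* ne_sqr_impl(struct ne_context* ctx, struct ne_tensor* a, bool inplace) {
  /* in-place ops reuse a's storage via a view; otherwise copy the tensor */
  struct ne_tensor* result = inplace ? ne_view_tensor(ctx, a) : ne_dup_tensor(ctx, a);
  result->op = NE_OP_SQR; /* tag the node with its operator */
  result->src0 = a;       /* remember the input for the compute pass */
  return result;
}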
