Skip to content

Commit

Permalink
feat(llvm): try (and fail) to workaround bad modular addition codegen…
Browse files Browse the repository at this point in the history
… with inline function.
  • Loading branch information
mratsim committed Aug 12, 2024
1 parent 0354d5b commit 08b8671
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 18 deletions.
15 changes: 12 additions & 3 deletions constantine/math_compiler/impl_fields_sat.nim
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@ proc finalSubMayOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M, car
## also overflow the limbs (a 2^256 order of magnitude modulus stored in n words of total max size 2^256)

let name = "_finalsub_mayo_u" & $fd.w & "x" & $fd.numWords
asy.llvmInternalFnDef(name, SectionName, asy.void_t, toTypes([r, a, M, carry])):
asy.llvmInternalFnDef(
name, SectionName,
asy.void_t, toTypes([r, a, M, carry]),
{kHot, kInline}):

let (rr, aa, MM, carry) = llvmParams

Expand Down Expand Up @@ -126,7 +129,10 @@ proc finalSubNoOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Valu
## (say using 255 bits for the modulus out of 256 available in words)

let name = "_finalsub_noo_u" & $fd.w & "x" & $fd.numWords
asy.llvmInternalFnDef(name, SectionName, asy.void_t, toTypes([r, a, M])):
asy.llvmInternalFnDef(
name, SectionName,
asy.void_t, toTypes([r, a, M]),
{kHot, kInline}):

let (rr, aa, MM) = llvmParams

Expand Down Expand Up @@ -156,7 +162,10 @@ proc modadd*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRef) =
let red = if fd.spareBits >= 1: "noo"
else: "mayo"
let name = "_modadd_" & red & "_u" & $fd.w & "x" & $fd.numWords
asy.llvmInternalFnDef(name, SectionName, asy.void_t, toTypes([r, a, b, M])):
asy.llvmInternalFnDef(
name, SectionName,
asy.void_t, toTypes([r, a, b, M]),
{kHot}):

let (r, aa, bb, M) = llvmParams

Expand Down
46 changes: 37 additions & 9 deletions constantine/math_compiler/ir.nim
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ import
# ############################################################

type
AttrKind* = enum
  # Function attributes that can be requested when defining a function.
  # Each member indexes the per-assembler `attrs` cache, which is filled
  # with the LLVM attribute named in the trailing comment.
  #
  # Other important properties like
  # - norecurse
  # - memory side-effects memory(argmem: readwrite)
  # can be deduced.
  kHot,           # "hot"
  kInline,        # "inlinehint"
  kAlwaysInline,  # "alwaysinline"
  kNoInline       # "noinline"

Assembler_LLVM* = ref object
ctx*: ContextRef
module*: ModuleRef
Expand All @@ -31,7 +41,9 @@ type
# It doesn't seem possible to retrieve a function type
# from its value, so we store them here.
# If we store the type we might as well store the impl
fns: Table[string, tuple[ty: TypeRef, impl: ValueRef]]
# and we store whether it's internal to apply the fastcc calling convention
fns: Table[string, tuple[ty: TypeRef, impl: ValueRef, internal: bool]]
attrs: array[AttrKind, AttributeRef]

# Convenience
void_t*: TypeRef
Expand Down Expand Up @@ -96,6 +108,11 @@ proc new*(T: type Assembler_LLVM, backend: Backend, moduleName: cstring): Assemb

result.configure(backend)

result.attrs[kHot] = result.ctx.createAttr("hot")
result.attrs[kInline] = result.ctx.createAttr("inlinehint")
result.attrs[kAlwaysInline] = result.ctx.createAttr("alwaysinline")
result.attrs[kNoInline] = result.ctx.createAttr("noinline")

# ############################################################
#
# Syntax Sugar
Expand Down Expand Up @@ -333,7 +350,7 @@ proc tagCudaKernel(asy: Assembler_LLVM, fn: ValueRef) =

proc setPublic(asy: Assembler_LLVM, fn: ValueRef) =
case asy.backend
of bkAmdGpu: fn.setCallingConvention(AMDGPU_KERNEL)
of bkAmdGpu: fn.setFnCallConv(AMDGPU_KERNEL)
of bkNvidiaPtx: asy.tagCudaKernel(fn)
else: discard

Expand Down Expand Up @@ -432,16 +449,22 @@ macro unpackParams[N: static int](
let fn = `br`.getCurrentFunction()
fn.getParam(uint32 `i`)

proc addAttributes(asy: Assembler_LLVM, fn: ValueRef, attrs: set[AttrKind]) =
  ## Attach every attribute kind listed in `attrs` to the function `fn`,
  ## using the `AttributeRef`s cached in `asy.attrs`.
  ## All attributes are added at `kAttrFnIndex`, i.e. on the function itself
  ## rather than on its return value or one of its parameters.
  for attr in attrs:
    fn.addAttribute(kAttrFnIndex, asy.attrs[attr])

  # NOTE(review): this trailing call adds the "hot" attribute regardless of
  # `attrs` - it is applied even when the caller passes `{}` and a second
  # time when `kHot` is already in the set. Confirm this is intended and
  # not an experiment leftover.
  fn.addAttribute(kAttrFnIndex, asy.attrs[kHot])

template llvmFnDef[N: static int](
asy: Assembler_LLVM,
name, sectionName: string,
returnType: TypeRef,
paramTypes: array[N, TypeRef],
internal: bool,
attrs: set[AttrKind],
body: untyped) =
## This setups common prologue to implement a function in LLVM
## Function parameters are available with the `llvmParams` magic variable

let paramsTys = asy.wrapTypesForFnCall(paramTypes)

var fn = asy.module.getFunction(cstring name)
Expand All @@ -451,7 +474,7 @@ template llvmFnDef[N: static int](
let fnTy = function_t(returnType, paramsTys.wrapped)
fn = asy.module.addFunction(cstring name, fnTy)

asy.fns[name] = (fnTy, fn)
asy.fns[name] = (fnTy, fn, internal)

let blck = asy.ctx.appendBasicBlock(fn)
asy.br.positionAtEnd(blck)
Expand All @@ -465,11 +488,12 @@ template llvmFnDef[N: static int](
body

if internal:
fn.setCallingConvention(Fast)
fn.setFnCallConv(Fast)
fn.setLinkage(linkInternal)
else:
asy.setPublic(fn)
fn.setSection(sectionName)
asy.addAttributes(fn, attrs)

asy.br.positionAtEnd(savedLoc)

Expand All @@ -478,26 +502,30 @@ template llvmInternalFnDef*[N: static int](
name, sectionName: string,
returnType: TypeRef,
paramTypes: array[N, TypeRef],
attrs: set[AttrKind] = {},
body: untyped) =
llvmFnDef(asy, name, sectionName, returnType, paramTypes, internal = true, body)
llvmFnDef(asy, name, sectionName, returnType, paramTypes, internal = true, attrs, body)

template llvmPublicFnDef*[N: static int](
asy: Assembler_LLVM,
name, sectionName: string,
returnType: TypeRef,
paramTypes: array[N, TypeRef],
body: untyped) =
llvmFnDef(asy, name, sectionName, returnType, paramTypes, internal = false, body)
llvmFnDef(asy, name, sectionName, returnType, paramTypes, internal = false, {}, body)

proc callFn*(
asy: Assembler_LLVM,
name: string,
params: openArray[ValueRef]): ValueRef {.discardable.} =

if asy.fns[name].ty.getReturnType().getTypeKind() == tkVoid:
asy.br.call2(asy.fns[name].ty, asy.fns[name].impl, params)
result = asy.br.call2(asy.fns[name].ty, asy.fns[name].impl, params)
else:
asy.br.call2(asy.fns[name].ty, asy.fns[name].impl, params, cstring(name))
result = asy.br.call2(asy.fns[name].ty, asy.fns[name].impl, params, cstring(name))

if asy.fns[name].internal:
result.setInstrCallConv(Fast)

# ############################################################
#
Expand Down
24 changes: 22 additions & 2 deletions constantine/platforms/abis/llvm_abi.nim
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ type
TypeRef* = distinct pointer
ValueRef* = distinct pointer
MetadataRef = distinct pointer
AttributeRef* = distinct pointer
LLVMstring = distinct cstring
ErrorMessageString = distinct cstring
## A string with a buffer owned by LLVM
Expand Down Expand Up @@ -123,16 +124,19 @@ proc verify(module: ModuleRef, failureAction: VerifierFailureAction, msg: var LL

{.push used.}
# X86: bindings to the LLVM-C initializers for the asm printer, asm parser,
# target, target info and target MC components.
proc initializeX86AsmPrinter() {.importc: "LLVMInitializeX86AsmPrinter".}
proc initializeX86AsmParser() {.importc: "LLVMInitializeX86AsmParser".}
proc initializeX86Target() {.importc: "LLVMInitializeX86Target".}
proc initializeX86TargetInfo() {.importc: "LLVMInitializeX86TargetInfo".}
proc initializeX86TargetMC() {.importc: "LLVMInitializeX86TargetMC".}

# NVPTX (Nvidia): same set of component initializers.
proc initializeNVPTXAsmPrinter() {.importc: "LLVMInitializeNVPTXAsmPrinter".}
proc initializeNVPTXAsmParser() {.importc: "LLVMInitializeNVPTXAsmParser".}
proc initializeNVPTXTarget() {.importc: "LLVMInitializeNVPTXTarget".}
proc initializeNVPTXTargetInfo() {.importc: "LLVMInitializeNVPTXTargetInfo".}
proc initializeNVPTXTargetMC() {.importc: "LLVMInitializeNVPTXTargetMC".}

# AMDGPU: same set of component initializers.
proc initializeAMDGPUAsmPrinter() {.importc: "LLVMInitializeAMDGPUAsmPrinter".}
proc initializeAMDGPUAsmParser() {.importc: "LLVMInitializeAMDGPUAsmParser".}
proc initializeAMDGPUTarget() {.importc: "LLVMInitializeAMDGPUTarget".}
proc initializeAMDGPUTargetInfo() {.importc: "LLVMInitializeAMDGPUTargetInfo".}
proc initializeAMDGPUTargetMC() {.importc: "LLVMInitializeAMDGPUTargetMC".}
Expand Down Expand Up @@ -612,8 +616,24 @@ proc countParamTypes*(functionTy: TypeRef): uint32 {.importc: "LLVMCountParamTyp

proc getCalledFunctionType*(fn: ValueRef): TypeRef {.importc: "LLVMGetCalledFunctionType".}

proc getCallingConvention*(function: ValueRef): CallingConvention {.importc: "LLVMGetFunctionCallConv".}
proc setCallingConvention*(function: ValueRef, cc: CallingConvention) {.importc: "LLVMSetFunctionCallConv".}
proc getFnCallConv*(function: ValueRef): CallingConvention {.importc: "LLVMGetFunctionCallConv".}
proc setFnCallConv*(function: ValueRef, cc: CallingConvention) {.importc: "LLVMSetFunctionCallConv".}

# Calling convention carried by an instruction value `instr`, as opposed to
# the one set on a function definition.
# Bindings for `LLVMGetInstructionCallConv` / `LLVMSetInstructionCallConv`.
proc getInstrCallConv*(instr: ValueRef): CallingConvention {.importc: "LLVMGetInstructionCallConv".}
proc setInstrCallConv*(instr: ValueRef, cc: CallingConvention) {.importc: "LLVMSetInstructionCallConv".}

type
  AttributeIndex* {.size: sizeof(cint).} = enum
    ## Position an attribute is attached to.
    ## Attribute index is either -1 for the function
    ## 0 for the return value
    ## or 1..n for each function parameter
    ##
    ## Only the two fixed positions are named here; the enum is sized
    ## like a C `int` so it can be passed directly to the C API.
    kAttrFnIndex = -1
    kAttrRetIndex = 0

# Binding for `LLVMGetEnumAttributeKindForName`: maps an attribute name to its numeric kind id.
# NOTE(review): presumably relies on `openArray[char]` being passed to C as a
# (pointer, length) pair to match the C (name, length) parameters - confirm.
proc toAttrId*(name: openArray[char]): cuint {.importc: "LLVMGetEnumAttributeKindForName".}
# Binding for `LLVMCreateEnumAttribute`: creates an enum attribute of kind
# `attr_id` with value `val` in context `ctx`.
# The C prototype takes the kind id as `unsigned` (the `cuint` returned by
# `toAttrId`) and only the value as `uint64_t`, so `attr_id` is declared
# `cuint` to match the C ABI instead of relying on register-passing luck.
proc toAttr*(ctx: ContextRef, attr_id: cuint, val = 0'u64): AttributeRef {.importc: "LLVMCreateEnumAttribute".}
# Bindings for `LLVMAddAttributeAtIndex`: attach `attr` to `fn` at position `index`.
# The `cint` overload takes a raw index; the `AttributeIndex` overload takes
# the named positions (`kAttrFnIndex`, `kAttrRetIndex`).
proc addAttribute*(fn: ValueRef, index: cint, attr: AttributeRef) {.importc: "LLVMAddAttributeAtIndex".}
proc addAttribute*(fn: ValueRef, index: AttributeIndex, attr: AttributeRef) {.importc: "LLVMAddAttributeAtIndex".}

# ############################################################
#
Expand Down
7 changes: 6 additions & 1 deletion constantine/platforms/llvm/llvm.nim
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,12 @@ proc array_t*(elemType: TypeRef, elemCount: SomeInteger): TypeRef {.inline.}=
proc function_t*(returnType: TypeRef, paramTypes: openArray[TypeRef]): TypeRef {.inline.} =
  ## Convenience overload: forwards to the three-argument `function_t`
  ## with `isVarArg` fixed to false.
  let notVariadic = LlvmBool(false)
  result = function_t(returnType, paramTypes, isVarArg = notVariadic)

# Functions
# ------------------------------------------------------------

proc createAttr*(ctx: ContextRef, name: openArray[char]): AttributeRef =
  ## Resolve the attribute kind id for `name` with `toAttrId`, then create
  ## the corresponding attribute in `ctx` with `toAttr`, leaving its value
  ## argument at the default.
  let kindId = toAttrId(name)
  result = toAttr(ctx, kindId)

# Values
# ------------------------------------------------------------

Expand All @@ -189,4 +195,3 @@ proc getName*(v: ValueRef): string =

proc constInt*(ty: TypeRef, n: SomeInteger, signExtend = false): ValueRef {.inline.} =
  ## Convenience overload accepting any integer `n` and a Nim `bool`.
  ## Both are converted (`culonglong`, `LlvmBool`) and forwarded to the
  ## `constInt` overload that takes those types.
  let
    rawValue = culonglong(n)
    extendFlag = LlvmBool(signExtend)
  result = constInt(ty, rawValue, extendFlag)

7 changes: 4 additions & 3 deletions research/codegen/x86_poc.nim
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,17 @@ proc t_field_add() =
target = toTarget(triple),
triple = triple,
cpu = "",
features = "adx,bmi2", # TODO check the proper way to pass options
level = CodeGenLevelAggressive,
features = "", # "adx,bmi2", # TODO check the proper way to pass options
level = CodeGenLevelDefault,
reloc = RelocDefault,
codeModel = CodeModelDefault
)

let pbo = createPassBuilderOptions()
pbo.setMergeFunctions()
let err = asy.module.runPasses(
"default<O3>,function-attrs,memcpyopt,sroa,mem2reg,gvn,dse,instcombine,inline,adce",
"default<O2>",
# "default<O2>,memcpyopt,sroa,mem2reg,function-attrs,inline,gvn,dse,aggressive-instcombine,adce",
machine,
pbo
)
Expand Down

0 comments on commit 08b8671

Please sign in to comment.