Skip to content

Commit

Permalink
gc cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Janis Erdmanis committed Oct 30, 2024
1 parent 38b8efc commit 4565d42
Show file tree
Hide file tree
Showing 7 changed files with 382 additions and 277 deletions.
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2024 PeaceFounder
Copyright (c) 2024 Janis Erdmanis

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
22 changes: 20 additions & 2 deletions src/OpenSSLGroups.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,28 @@ using Base.GMP

export octet, order, value

const ctx = ccall((:BN_CTX_new, libcrypto), Ptr{Cvoid}, ()) # may need to be set at runtime and etc...

include("utils.jl")
include("context.jl")
include("point.jl")
include("curves.jl")

# The context is a scratchspace and never leaves internal function boundary, hence, using threadid is appropriate
const THREAD_CTXS = ThreadLocal{OpenSSLContext}()

function get_ctx()
@assert haskey(THREAD_CTXS) "Thread context not initialized"
return THREAD_CTXS[].ctx
end

function __init__()
@sync begin
for tid in 1:Threads.nthreads()
Threads.@spawn begin
THREAD_CTXS[] = OpenSSLContext()
end
end
end
end


end # module OpenSSLGroups
56 changes: 56 additions & 0 deletions src/context.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
mutable struct ThreadLocal{T}
values::Vector{Union{Nothing, T}}

ThreadLocal{T}() where T = new(fill(nothing, Base.Threads.nthreads()))
end

# Check if value exists for current thread
function Base.haskey(t::ThreadLocal)
tid = Base.Threads.threadid()
return t.values[tid] !== nothing
end

# Get value for current thread
function Base.getindex(t::ThreadLocal)
tid = Base.Threads.threadid()
val = t.values[tid]
if val === nothing
throw(KeyError("No value for thread $tid"))
end
return val
end

# Set value for current thread
function Base.setindex!(t::ThreadLocal{T}, v::T) where T
tid = Base.Threads.threadid()
t.values[tid] = v
end

# Delete value for current thread
function Base.delete!(t::ThreadLocal)
tid = Base.Threads.threadid()
t.values[tid] = nothing
end


# Thread context structure
mutable struct OpenSSLContext
ctx::Ptr{Nothing}

function OpenSSLContext()
ctx = ccall((:BN_CTX_new, libcrypto), Ptr{Nothing}, ())
if ctx == C_NULL
throw(OpenSSLError("Failed to create BN_CTX"))
end
obj = new(ctx)
finalizer(obj) do x
if x.ctx != C_NULL
@ccall libcrypto.BN_CTX_free(x.ctx::Ptr{Nothing})::Cvoid
x.ctx = C_NULL
end
end
return obj
end
end


41 changes: 32 additions & 9 deletions src/curves.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,45 +87,68 @@
# 934 brainpoolP512t1 brainpoolP512t1
# 1172 SM2 sm2


macro prime_curve(curve_name, struct_name::Symbol)

# Handle both Symbol and String representations
curve_str = if curve_name isa Symbol
String(curve_name)
else
# Remove outer quotation marks if present and convert to string
string(curve_name) #[2:end-1]
string(curve_name)
end

return quote
struct $(esc(struct_name)) <: OpenSSLPrimePoint
mutable struct $(esc(struct_name)) <: OpenSSLPrimePoint
pointer::Ptr{Nothing}
$(esc(struct_name))(x::Ptr{Nothing}) = new(x)
function $(esc(struct_name))(x::Ptr{Nothing})
point = new(x)

finalizer(point) do p
if p.pointer != C_NULL
@ccall libcrypto.EC_POINT_free(p.pointer::Ptr{Nothing})::Cvoid
p.pointer = C_NULL
end
end

return point
end
end

$(esc(:group_pointer))(::Type{$(esc(struct_name))}) = $(esc(:group_pointer))($(esc(:get_curve_nid))($curve_str))
end
end


macro binary_curve(curve_name, struct_name::Symbol)

# Handle both Symbol and String representations
curve_str = if curve_name isa Symbol
String(curve_name)
else
# Remove outer quotation marks if present and convert to string
string(curve_name) #[2:end-1]
string(curve_name)
end

return quote
struct $(esc(struct_name)) <: OpenSSLBinaryPoint
mutable struct $(esc(struct_name)) <: OpenSSLBinaryPoint
pointer::Ptr{Nothing}
$(esc(struct_name))(x::Ptr{Nothing}) = new(x)
function $(esc(struct_name))(x::Ptr{Nothing})
point = new(x)

finalizer(point) do p
if p.pointer != C_NULL
@ccall libcrypto.EC_POINT_free(p.pointer::Ptr{Nothing})::Cvoid
p.pointer = C_NULL
end
end

return point
end
end

$(esc(:group_pointer))(::Type{$(esc(struct_name))}) = $(esc(:group_pointer))($(esc(:get_curve_nid))($curve_str))
end
end


# Prime Curves (secp, prime, brainpoolP)
@prime_curve secp112r1 SecP112r1
@prime_curve secp112r2 SecP112r2
Expand Down
Loading

0 comments on commit 4565d42

Please sign in to comment.