From 4565d423aa91afe183a284b817c1042be8ea1caf Mon Sep 17 00:00:00 2001 From: Janis Erdmanis Date: Wed, 30 Oct 2024 19:47:54 +0200 Subject: [PATCH] gc cleanup --- LICENSE | 2 +- src/OpenSSLGroups.jl | 22 ++- src/context.jl | 56 +++++++ src/curves.jl | 41 +++-- src/point.jl | 357 ++++++++++++++++++++++--------------------- src/utils.jl | 179 +++++++++++----------- test/runtests.jl | 2 + 7 files changed, 382 insertions(+), 277 deletions(-) create mode 100644 src/context.jl diff --git a/LICENSE b/LICENSE index 67a745f..c1874f5 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2024 PeaceFounder +Copyright (c) 2024 Janis Erdmanis Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/src/OpenSSLGroups.jl b/src/OpenSSLGroups.jl index 510e393..04db0cc 100644 --- a/src/OpenSSLGroups.jl +++ b/src/OpenSSLGroups.jl @@ -16,10 +16,28 @@ using Base.GMP export octet, order, value -const ctx = ccall((:BN_CTX_new, libcrypto), Ptr{Cvoid}, ()) # may need to be set at runtime and etc... - include("utils.jl") +include("context.jl") include("point.jl") include("curves.jl") +# The context is a scratchspace and never leaves internal function boundary, hence, using threadid is appropriate +const THREAD_CTXS = ThreadLocal{OpenSSLContext}() + +function get_ctx() + @assert haskey(THREAD_CTXS) "Thread context not initialized" + return THREAD_CTXS[].ctx +end + +function __init__() + @sync begin + for tid in 1:Threads.nthreads() + Threads.@spawn begin + THREAD_CTXS[] = OpenSSLContext() + end + end + end +end + + end # module OpenSSLGroups diff --git a/src/context.jl b/src/context.jl new file mode 100644 index 0000000..b9811a2 --- /dev/null +++ b/src/context.jl @@ -0,0 +1,56 @@ +mutable struct ThreadLocal{T} + values::Vector{Union{Nothing, T}} + + ThreadLocal{T}() where T = new(fill(nothing, Base.Threads.nthreads())) +end + +# Check if value exists for current thread +function Base.haskey(t::ThreadLocal) + tid = Base.Threads.threadid() + return t.values[tid] !== nothing +end + +# Get value for current thread +function Base.getindex(t::ThreadLocal) + tid = Base.Threads.threadid() + val = t.values[tid] + if val === nothing + throw(KeyError("No value for thread $tid")) + end + return val +end + +# Set value for current thread +function Base.setindex!(t::ThreadLocal{T}, v::T) where T + tid = Base.Threads.threadid() + t.values[tid] = v +end + +# Delete value for current thread +function Base.delete!(t::ThreadLocal) + tid = Base.Threads.threadid() + t.values[tid] = nothing +end + + +# Thread context structure +mutable struct OpenSSLContext + ctx::Ptr{Nothing} + + function OpenSSLContext() + ctx = ccall((:BN_CTX_new, libcrypto), Ptr{Nothing}, ()) + if ctx == C_NULL + throw(OpenSSLError("Failed to create BN_CTX")) + end + obj = new(ctx) + finalizer(obj) do x + if x.ctx != C_NULL + @ccall libcrypto.BN_CTX_free(x.ctx::Ptr{Nothing})::Cvoid + x.ctx = C_NULL + end + end + return obj + end +end + + diff --git a/src/curves.jl b/src/curves.jl index eafd69b..9dc7de7 100644 --- a/src/curves.jl +++ b/src/curves.jl @@ -87,45 +87,68 @@ # 934 brainpoolP512t1 brainpoolP512t1 # 1172 SM2 sm2 + macro prime_curve(curve_name, struct_name::Symbol) + # Handle both Symbol and String representations curve_str = if curve_name isa Symbol String(curve_name) else - # Remove outer quotation marks if present and convert to string - string(curve_name) #[2:end-1] + string(curve_name) end return quote - struct $(esc(struct_name)) <: OpenSSLPrimePoint + mutable struct $(esc(struct_name)) <: OpenSSLPrimePoint pointer::Ptr{Nothing} - $(esc(struct_name))(x::Ptr{Nothing}) = new(x) + function $(esc(struct_name))(x::Ptr{Nothing}) + point = new(x) + + finalizer(point) do p + if p.pointer != C_NULL + @ccall libcrypto.EC_POINT_free(p.pointer::Ptr{Nothing})::Cvoid + p.pointer = C_NULL + end + end + + return point + end end $(esc(:group_pointer))(::Type{$(esc(struct_name))}) = $(esc(:group_pointer))($(esc(:get_curve_nid))($curve_str)) end end + macro binary_curve(curve_name, struct_name::Symbol) + # Handle both Symbol and String representations curve_str = if curve_name isa Symbol String(curve_name) else - # Remove outer quotation marks if present and convert to string - string(curve_name) #[2:end-1] + string(curve_name) end return quote - struct $(esc(struct_name)) <: OpenSSLBinaryPoint + mutable struct $(esc(struct_name)) <: OpenSSLBinaryPoint pointer::Ptr{Nothing} - $(esc(struct_name))(x::Ptr{Nothing}) = new(x) + function $(esc(struct_name))(x::Ptr{Nothing}) + point = new(x) + + finalizer(point) do p + if p.pointer != C_NULL + @ccall libcrypto.EC_POINT_free(p.pointer::Ptr{Nothing})::Cvoid + p.pointer = C_NULL + end + end + + return point + end end $(esc(:group_pointer))(::Type{$(esc(struct_name))}) = $(esc(:group_pointer))($(esc(:get_curve_nid))($curve_str)) end end - # Prime Curves (secp, prime, brainpoolP) @prime_curve secp112r1 SecP112r1 @prime_curve secp112r2 SecP112r2 diff --git a/src/point.jl b/src/point.jl index 83fc3fa..7f6ad83 100644 --- a/src/point.jl +++ b/src/point.jl @@ -1,255 +1,262 @@ abstract type OpenSSLPoint <: AbstractPoint end function generator(::Type{P}) where P <: OpenSSLPoint - group = group_pointer(P) - - result = ccall((:EC_POINT_new, libcrypto), Ptr{Cvoid}, (Ptr{Cvoid},), group) - - point = ccall((:EC_GROUP_get0_generator, libcrypto), Ptr{Cvoid}, - (Ptr{Cvoid},), group) - - return P(point) # copy is not necessary as the public API is nonmutating + result = @ccall libcrypto.EC_POINT_new(group::Ptr{Cvoid})::Ptr{Cvoid} + point = @ccall libcrypto.EC_GROUP_get0_generator(group::Ptr{Cvoid})::Ptr{Cvoid} + return P(point) end (::Type{P})() where P <: OpenSSLPoint = generator(P) - function Base.:(==)(x::P, y::P) where P <: OpenSSLPoint - + ctx = get_ctx() group = group_pointer(P) - - ret = ccall((:EC_POINT_cmp, libcrypto), Cint, - (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}), - group, pointer(x), pointer(y), ctx) - + ret = @ccall libcrypto.EC_POINT_cmp( + group::Ptr{Cvoid}, + pointer(x)::Ptr{Cvoid}, + pointer(y)::Ptr{Cvoid}, + ctx::Ptr{Cvoid} + )::Cint return ret == 0 end - Base.pointer(p::OpenSSLPoint) = p.pointer function Base.iszero(point::P) where P <: OpenSSLPoint - group = group_pointer(P) - - # Check if point is at infinity - ret = ccall((:EC_POINT_is_at_infinity, libcrypto), Cint, - (Ptr{Cvoid}, Ptr{Cvoid}), - group, pointer(point)) - + ret = @ccall libcrypto.EC_POINT_is_at_infinity( + group::Ptr{Cvoid}, + pointer(point)::Ptr{Cvoid} + )::Cint return ret == 1 end - function Base.zero(::Type{P}) where P <: OpenSSLPoint - group = group_pointer(P) - - result = ccall((:EC_POINT_new, libcrypto), Ptr{Cvoid}, (Ptr{Cvoid},), group) - - # Set point at infinity - ret = ccall((:EC_POINT_set_to_infinity, libcrypto), Cint, - (Ptr{Cvoid}, Ptr{Cvoid}), group, result) + result = @ccall libcrypto.EC_POINT_new(group::Ptr{Cvoid})::Ptr{Cvoid} + ret = @ccall libcrypto.EC_POINT_set_to_infinity( + group::Ptr{Cvoid}, + result::Ptr{Cvoid} + )::Cint if ret != 1 error("Failed to set point at infinity") end - return P(result) end - function (::Type{P})(bytes::Vector{UInt8}) where P <: OpenSSLPoint - + ctx = get_ctx() group = group_pointer(P) - - point = ccall((:EC_POINT_new, OpenSSL_jll.libcrypto), Ptr{Cvoid}, (Ptr{Cvoid},), group) + point = @ccall OpenSSL_jll.libcrypto.EC_POINT_new(group::Ptr{Cvoid})::Ptr{Cvoid} if point == C_NULL error("Failed to create new EC_POINT") end - ret = ccall((:EC_POINT_oct2point, OpenSSL_jll.libcrypto), Cint, - (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{UInt8}, Csize_t, Ptr{Cvoid}), - group, point, bytes, length(bytes), ctx) + ret = @ccall OpenSSL_jll.libcrypto.EC_POINT_oct2point( + group::Ptr{Cvoid}, + point::Ptr{Cvoid}, + bytes::Ptr{UInt8}, + length(bytes)::Csize_t, + ctx::Ptr{Cvoid} + )::Cint if ret != 1 - ccall((:EC_POINT_free, OpenSSL_jll.libcrypto), Cvoid, (Ptr{Cvoid},), point) + @ccall OpenSSL_jll.libcrypto.EC_POINT_free(point::Ptr{Cvoid})::Cvoid error("Failed to initialize point from bytes") end - return P(point) end function octet_legacy(point::P) where P <: OpenSSLPoint - + ctx = get_ctx() group = group_pointer(P) - buffer_size = 200 # Adjust if needed buffer = Vector{UInt8}(undef, buffer_size) - _length = ccall((:EC_POINT_point2oct, libcrypto), Csize_t, - (Ptr{Cvoid}, Ptr{Cvoid}, Cint, Ptr{UInt8}, Csize_t, Ptr{Cvoid}), - group, pointer(point), 4, buffer, buffer_size, ctx) # 4 is POINT_CONVERSION_UNCOMPRESSED + GC.@preserve buffer begin + _length = @ccall libcrypto.EC_POINT_point2oct( + group::Ptr{Cvoid}, + pointer(point)::Ptr{Cvoid}, + 4::Cint, + buffer::Ptr{UInt8}, + buffer_size::Csize_t, + ctx::Ptr{Cvoid} + )::Csize_t + end + if _length == 0 error("Failed to convert result to octet string") end - return buffer[1:_length] end function Base.:+(x::P, y::P) where P <: OpenSSLPoint - + ctx = get_ctx() group = group_pointer(P) - result = ccall((:EC_POINT_new, libcrypto), Ptr{Cvoid}, (Ptr{Cvoid},), group) - - # Perform point addition: result = result + point2 - ret = ccall((:EC_POINT_add, libcrypto), Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}), group, result, pointer(x), pointer(y), ctx) + result = @ccall libcrypto.EC_POINT_new(group::Ptr{Cvoid})::Ptr{Cvoid} + + ret = @ccall libcrypto.EC_POINT_add( + group::Ptr{Cvoid}, + result::Ptr{Cvoid}, + pointer(x)::Ptr{Cvoid}, + pointer(y)::Ptr{Cvoid}, + ctx::Ptr{Cvoid} + )::Cint if ret != 1 error("Failed in point addition") end - return P(result) end - function Base.:*(k::Integer, point::P) where P <: OpenSSLPoint - - scalar = ccall((:BN_new, libcrypto), Ptr{Cvoid}, ()) - + scalar = @ccall libcrypto.BN_new()::Ptr{Cvoid} scalar_hex = string(k, base=16) - ret = ccall((:BN_hex2bn, libcrypto), Cint, (Ptr{Ptr{Cvoid}}, Cstring), Ref(scalar), scalar_hex) + ret = @ccall libcrypto.BN_hex2bn( + Ref(scalar)::Ptr{Ptr{Cvoid}}, + scalar_hex::Cstring + )::Cint if ret == 0 error("Failed to set scalar") end + ctx = get_ctx() group = group_pointer(P) - result = ccall((:EC_POINT_new, libcrypto), Ptr{Cvoid}, (Ptr{Cvoid},), group) - - ret = ccall((:EC_POINT_mul, libcrypto), Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}), group, result, C_NULL, pointer(point), scalar, ctx) + result = @ccall libcrypto.EC_POINT_new(group::Ptr{Cvoid})::Ptr{Cvoid} + + ret = @ccall libcrypto.EC_POINT_mul( + group::Ptr{Cvoid}, + result::Ptr{Cvoid}, + C_NULL::Ptr{Cvoid}, + pointer(point)::Ptr{Cvoid}, + scalar::Ptr{Cvoid}, + ctx::Ptr{Cvoid} + )::Cint if ret != 1 error("Failed in point multiplication") end + openssl_bignum_free(scalar) return P(result) end Base.:*(point::OpenSSLPoint, k::Integer) = k * point -function Base.:-(point::P) where P <: OpenSSLPoint # the substraction then is ensured by AbstractPoint - +function Base.:-(point::P) where P <: OpenSSLPoint + ctx = get_ctx() group = group_pointer(P) - - # Create a temporary point for the inverted point inverted_point = copy(point) - ret = ccall((:EC_POINT_invert, libcrypto), Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}), group, pointer(inverted_point), ctx) + ret = @ccall libcrypto.EC_POINT_invert( + group::Ptr{Cvoid}, + pointer(inverted_point)::Ptr{Cvoid}, + ctx::Ptr{Cvoid} + )::Cint if ret != 1 error("Failed to invert point") end - return P(inverted_point) end function Base.copy(point::P) where P <: OpenSSLPoint - group = group_pointer(P) - result = ccall((:EC_POINT_new, libcrypto), Ptr{Cvoid}, (Ptr{Cvoid},), group) + result = @ccall libcrypto.EC_POINT_new(group::Ptr{Cvoid})::Ptr{Cvoid} - ret = ccall((:EC_POINT_copy, libcrypto), Cint, (Ptr{Cvoid}, Ptr{Cvoid}), result, pointer(point)) + ret = @ccall libcrypto.EC_POINT_copy( + result::Ptr{Cvoid}, + pointer(point)::Ptr{Cvoid} + )::Cint if ret != 1 error("Failed to copy point") end - return P(result) end +function order(::Type{P}) where P <: OpenSSLPoint + ctx = get_ctx() + group = group_pointer(P) + order = @ccall libcrypto.BN_new()::Ptr{Cvoid} -function gx(point::P) where P <: OpenSSLPoint + ret = @ccall libcrypto.EC_GROUP_get_order( + group::Ptr{Cvoid}, + order::Ptr{Cvoid}, + ctx::Ptr{Cvoid} + )::Cint + if ret != 1 + error("Failed to get order") + end + order_bigint = order |> bn2bigint + openssl_bignum_free(order) + return order_bigint +end + +function cofactor(::Type{P}) where P <: OpenSSLPoint + ctx = get_ctx() + group = group_pointer(P) + cofactor = @ccall libcrypto.BN_new()::Ptr{Cvoid} + + ret = @ccall libcrypto.EC_GROUP_get_cofactor( + group::Ptr{Cvoid}, + cofactor::Ptr{Cvoid}, + ctx::Ptr{Cvoid} + )::Cint + + cofactor_bigint = cofactor |> bn2bigint + openssl_bignum_free(cofactor) + return cofactor_bigint +end + +function gx(point::P) where P <: OpenSSLPoint F = field(P) x, y = value(point) - return F(x) end function gy(point::P) where P <: OpenSSLPoint - F = field(P) x, y = value(point) - return F(y) end function (::Type{P})(x::F, y::F) where {P <: OpenSSLPoint, F <: Field} - @check F == field(P) - x_bytes = octet(x) y_bytes = octet(y) - po = UInt8[4, x_bytes..., y_bytes...] - return P(po) end - -function order(::Type{P}) where P <: OpenSSLPoint - - group = group_pointer(P) - - order = ccall((:BN_new, libcrypto), Ptr{Cvoid}, ()) - - # Get the order of the curve - ret = ccall((:EC_GROUP_get_order, libcrypto), Cint, - (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}), - group, order, ctx) - if ret != 1 - error("Failed to get order") - end - - return order |> bn2bigint -end - - -function cofactor(::Type{P}) where P <: OpenSSLPoint - - group = group_pointer(P) - - cofactor = ccall((:BN_new, libcrypto), Ptr{Cvoid}, ()) - ret = ccall((:EC_GROUP_get_cofactor, libcrypto), Cint, - (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}), - group, cofactor, ctx) - - return cofactor |> bn2bigint -end - - ### Binary curve point specializations abstract type OpenSSLBinaryPoint <: OpenSSLPoint end function value(point::P) where P <: OpenSSLBinaryPoint - + ctx = get_ctx() group = group_pointer(P) - - gx = ccall((:BN_new, libcrypto), Ptr{Cvoid}, ()) - gy = ccall((:BN_new, libcrypto), Ptr{Cvoid}, ()) + gx = @ccall libcrypto.BN_new()::Ptr{Cvoid} + gy = @ccall libcrypto.BN_new()::Ptr{Cvoid} - ret = ccall((:EC_POINT_get_affine_coordinates_GF2m, libcrypto), Cint, - (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}), - group, pointer(point), gx, gy, ctx) + ret = @ccall libcrypto.EC_POINT_get_affine_coordinates_GF2m( + group::Ptr{Cvoid}, + pointer(point)::Ptr{Cvoid}, + gx::Ptr{Cvoid}, + gy::Ptr{Cvoid}, + ctx::Ptr{Cvoid} + )::Cint if ret != 1 error("Failed to get generator coordinates") end F = field(P) M = div(bitlength(F), 8, RoundUp) - xf, yf = F(bn2octet(gx, M)), F(bn2octet(gy, M)) - return convert(BitVector, xf), convert(BitVector, yf) # Perhaps I should rather return a tuole of field elements here? + openssl_bignum_free(gx) + openssl_bignum_free(gy) + + return convert(BitVector, xf), convert(BitVector, yf) end - function reducer(bytes::Vector{UInt8}) # Create a BitVector with enough space for all bits bits = BitVector(undef, length(bytes) * 8) @@ -270,40 +277,40 @@ function reducer(bytes::Vector{UInt8}) return bits[N:end] end - function curve_parameters(::Type{P}) where P <: OpenSSLBinaryPoint - + ctx = get_ctx() group = group_pointer(P) - p = ccall((:BN_new, libcrypto), Ptr{Cvoid}, ()) # Field characteristic (prime p) + p = @ccall libcrypto.BN_new()::Ptr{Cvoid} + a = @ccall libcrypto.BN_new()::Ptr{Cvoid} + b = @ccall libcrypto.BN_new()::Ptr{Cvoid} - #m = Ref{Cint}() - a = ccall((:BN_new, libcrypto), Ptr{Cvoid}, ()) # Curve parameter a - b = ccall((:BN_new, libcrypto), Ptr{Cvoid}, ()) # Curve parameter b - - - # For binary curves, we use EC_GROUP_get_curve_GF2m instead of GFp - ret = ccall((:EC_GROUP_get_curve_GF2m, libcrypto), Cint, - (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}), - group, p, a, b, ctx) + ret = @ccall libcrypto.EC_GROUP_get_curve_GF2m( + group::Ptr{Cvoid}, + p::Ptr{Cvoid}, + a::Ptr{Cvoid}, + b::Ptr{Cvoid}, + ctx::Ptr{Cvoid} + )::Cint - hex_str = unsafe_string(ccall((:BN_bn2hex, libcrypto), Ptr{UInt8}, - (Ptr{Cvoid},), p)) + hex_str = unsafe_string(@ccall libcrypto.BN_bn2hex(p::Ptr{Cvoid})::Ptr{UInt8}) m = reducer(hex2bytes(hex_str)) + a_octet = bn2octet(a) + b_octet = bn2octet(b) + + openssl_bignum_free(p) + openssl_bignum_free(a) + openssl_bignum_free(b) - return m, bn2octet(a), bn2octet(b) + return m, a_octet, b_octet end - function field(::Type{P}) where P <: OpenSSLBinaryPoint - reducer, = curve_parameters(P) - return @F2PB{reducer} end - (::Type{P})((x, y)::Tuple{BitVector, BitVector}) where P <: OpenSSLBinaryPoint = P(x, y) function (::Type{P})(x::T, y::T) where {P <: OpenSSLBinaryPoint, T <: Union{BitVector, Vector{UInt8}}} @@ -311,31 +318,23 @@ function (::Type{P})(x::T, y::T) where {P <: OpenSSLBinaryPoint, T <: Union{BitV return P(F(x), F(y)) end - function spec(::Type{P}) where P <: OpenSSLBinaryPoint - _, a_octet, b_octet = curve_parameters(P) - F = field(P) - basis = spec(F) n = order(P) - M = div(bitlength(F), 8, RoundUp) a_bits = convert(BitVector, F(expand(a_octet, M))) b_bits = convert(BitVector, F(expand(b_octet, M))) _cofactor = cofactor(P) |> Int - gx, gy = value(generator(P)) - names = [string(lowercase(string(nameof(P))))] return EC2N(basis, n, a_bits, b_bits, _cofactor, gx, gy; names) end - ### Prime curve point specializations abstract type OpenSSLPrimePoint <: OpenSSLPoint end @@ -352,50 +351,64 @@ modulus(::Type{P}) where P <: OpenSSLPrimePoint = curve_parameters(P) |> first field(::Type{P}) where P <: OpenSSLPrimePoint = FP{static(modulus(P))} function curve_parameters(::Type{P}) where P <: OpenSSLPrimePoint - + ctx = get_ctx() group = group_pointer(P) - p = ccall((:BN_new, libcrypto), Ptr{Cvoid}, ()) # Field characteristic (prime p) - a = ccall((:BN_new, libcrypto), Ptr{Cvoid}, ()) # Curve parameter a - b = ccall((:BN_new, libcrypto), Ptr{Cvoid}, ()) # Curve parameter b + p = @ccall libcrypto.BN_new()::Ptr{Cvoid} + a = @ccall libcrypto.BN_new()::Ptr{Cvoid} + b = @ccall libcrypto.BN_new()::Ptr{Cvoid} - # Get field parameters - ret = ccall((:EC_GROUP_get_curve_GFp, libcrypto), Cint, - (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}), - group, p, a, b, ctx) + ret = @ccall libcrypto.EC_GROUP_get_curve_GFp( + group::Ptr{Cvoid}, + p::Ptr{Cvoid}, + a::Ptr{Cvoid}, + b::Ptr{Cvoid}, + ctx::Ptr{Cvoid} + )::Cint if ret != 1 error("Failed to get curve parameters") end - return bn2bigint(p), bn2bigint(a), bn2bigint(b) -end + p_bigint = bn2bigint(p) + a_bigint = bn2bigint(a) + b_bigint = bn2bigint(b) + openssl_bignum_free(p) + openssl_bignum_free(a) + openssl_bignum_free(b) -function spec(::Type{P}) where P <: OpenSSLPoint - + return p_bigint, a_bigint, b_bigint +end + +function spec(::Type{P}) where P <: OpenSSLPoint p, a, b = curve_parameters(P) n = order(P) - _cofactor = cofactor(P) |> Int - gx, gy = value(generator(P)) - names = [string(lowercase(string(nameof(P))))] - return ECP(p, n, a, b, _cofactor, gx, gy; names) end - function value(point::P) where P <: OpenSSLPrimePoint - + ctx = get_ctx() group = group_pointer(P) - gx = ccall((:BN_new, libcrypto), Ptr{Cvoid}, ()) - gy = ccall((:BN_new, libcrypto), Ptr{Cvoid}, ()) + gx = @ccall libcrypto.BN_new()::Ptr{Cvoid} + gy = @ccall libcrypto.BN_new()::Ptr{Cvoid} + + ret = @ccall libcrypto.EC_POINT_get_affine_coordinates( + group::Ptr{Cvoid}, + pointer(point)::Ptr{Cvoid}, + gx::Ptr{Cvoid}, + gy::Ptr{Cvoid}, + ctx::Ptr{Cvoid} + )::Cint + + gx_bigint = bn2bigint(gx) + gy_bigint = bn2bigint(gy) - ret = ccall((:EC_POINT_get_affine_coordinates, libcrypto), Cint, - (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}), - group, pointer(point), gx, gy, ctx) + openssl_bignum_free(gx) + openssl_bignum_free(gy) - return bn2bigint(gx), bn2bigint(gy) + return gx_bigint, gy_bigint end diff --git a/src/utils.jl b/src/utils.jl index d815eb1..118ec58 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,7 +1,78 @@ -function bn2bigint(bn::Ptr{Cvoid}) +# Common functionality for curve operations +function get_curves_buffer() + # Get the built-in curves count + curves_count = @ccall libcrypto.EC_get_builtin_curves( + C_NULL::Ptr{Cvoid}, + 0::Cint + )::Cint + + # Allocate memory for curve information + # EC_builtin_curve struct is { nid: nid, comment: *c_char } + # Size is typically 16 bytes on 64-bit systems + buffer_size = 16 # sizeof(EC_builtin_curve) + curves_buffer = Vector{UInt8}(undef, curves_count * buffer_size) + + # Get the actual curve information + GC.@preserve curves_buffer begin + ret = @ccall libcrypto.EC_get_builtin_curves( + curves_buffer::Ptr{UInt8}, + curves_count::Cint + )::Cint + end + + return curves_buffer, curves_count, buffer_size +end + +function iterate_curves(callback::Function) + curves_buffer, curves_count, buffer_size = get_curves_buffer() + for i in 0:(curves_count-1) + # Extract NID from the buffer (first field of EC_builtin_curve struct) + nid = unsafe_load(Ptr{Cint}(pointer(curves_buffer) + i * buffer_size)) + + # Get the curve name using NID + sn = @ccall libcrypto.OBJ_nid2sn(nid::Cint)::Ptr{UInt8} + if sn != C_NULL + name = unsafe_string(sn) + + # Get long name (description) + ln = @ccall libcrypto.OBJ_nid2ln(nid::Cint)::Ptr{UInt8} + description = ln != C_NULL ? unsafe_string(ln) : "" + + callback(nid, name, description) + end + end +end + +function list_curves() + println("Total number of supported curves: ", get_curves_buffer()[2]) + println("\nSupported curves:") + println("NID\tCurve Name") + println("-" ^ 50) + + iterate_curves() do nid, name, description + println("$nid\t$name\t$description") + end +end + +function get_curve_nid(name::String) + result = Ref{Int}() + found = false + + iterate_curves() do nid, curve_name, description + if curve_name == name || description == name + result[] = nid + found = true + end + end + + found || error("Curve $name not found") + return result[] +end + +function bn2bigint(bn::Ptr{Cvoid}) # Convert BIGNUM to hex string - hex_str = ccall((:BN_bn2hex, OpenSSL_jll.libcrypto), Ptr{UInt8}, (Ptr{Cvoid},), bn) + hex_str = @ccall OpenSSL_jll.libcrypto.BN_bn2hex(bn::Ptr{Cvoid})::Ptr{UInt8} if hex_str == C_NULL error("Failed to convert BIGNUM to hex") end @@ -13,11 +84,9 @@ function bn2bigint(bn::Ptr{Cvoid}) return parse(BigInt, jl_hex_str, base=16) end - function bn2octet(bn::Ptr{Cvoid}) - # Convert BIGNUM to hex string - hex_str = ccall((:BN_bn2hex, OpenSSL_jll.libcrypto), Ptr{UInt8}, (Ptr{Cvoid},), bn) + hex_str = @ccall OpenSSL_jll.libcrypto.BN_bn2hex(bn::Ptr{Cvoid})::Ptr{UInt8} if hex_str == C_NULL error("Failed to convert BIGNUM to hex") end @@ -29,7 +98,7 @@ function bn2octet(bn::Ptr{Cvoid}) jl_hex_str = "0" * jl_hex_str end - return hex2bytes(jl_hex_str) # May need a reverse to match specification here + return hex2bytes(jl_hex_str) end function expand(x::Vector{UInt8}, n::Int) @@ -37,102 +106,26 @@ function expand(x::Vector{UInt8}, n::Int) end bn2octet(bn::Ptr{Cvoid}, n::Int) = expand(bn2octet(bn), n) -#bn2octet(bn::Ptr{Cvoid}, n::Int) = reverse(expand(reverse(bn2octet(bn)), n)) - -function list_curves() - #libcrypto = OpenSSL_jll.libcrypto - - # Get the built-in curves count - curves_count = ccall((:EC_get_builtin_curves, libcrypto), Cint, - (Ptr{Cvoid}, Cint), C_NULL, 0) - - println("Total number of supported curves: ", curves_count) - - # Allocate memory for curve information - # EC_builtin_curve struct is { nid: nid, comment: *c_char } - # Size is typically 16 bytes on 64-bit systems - buffer_size = 16 # sizeof(EC_builtin_curve) - curves_buffer = Vector{UInt8}(undef, curves_count * buffer_size) - - # Get the actual curve information - ret = ccall((:EC_get_builtin_curves, libcrypto), Cint, - (Ptr{UInt8}, Cint), - curves_buffer, curves_count) - - println("\nSupported curves:") - println("NID\tCurve Name") - println("-" ^ 50) - - # Process each curve - for i in 0:(curves_count-1) - # Extract NID from the buffer (first field of EC_builtin_curve struct) - nid = unsafe_load(Ptr{Cint}(pointer(curves_buffer) + i * buffer_size)) - - # Get the curve name using NID - sn = ccall((:OBJ_nid2sn, libcrypto), Ptr{UInt8}, (Cint,), nid) - if sn != C_NULL - name = unsafe_string(sn) - - # Get long name (description) - ln = ccall((:OBJ_nid2ln, libcrypto), Ptr{UInt8}, (Cint,), nid) - description = "" - if ln != C_NULL - description = unsafe_string(ln) - end - - println("$nid\t$name\t$description") - end - end -end function group_pointer(enum::Int) - # Create a new EC_GROUP object for the secp256k1 curve - group = ccall((:EC_GROUP_new_by_curve_name, libcrypto), Ptr{Cvoid}, (Cint,), enum) + group = @ccall libcrypto.EC_GROUP_new_by_curve_name(enum::Cint)::Ptr{Cvoid} if group == C_NULL error("Failed to create EC_GROUP") end - return group end -function get_curve_nid(name::String) - - # Get the built-in curves count - curves_count = ccall((:EC_get_builtin_curves, libcrypto), Cint, - (Ptr{Cvoid}, Cint), C_NULL, 0) - - - buffer_size = 16 # sizeof(EC_builtin_curve) - curves_buffer = Vector{UInt8}(undef, curves_count * buffer_size) - - # Get the actual curve information - ret = ccall((:EC_get_builtin_curves, libcrypto), Cint, - (Ptr{UInt8}, Cint), - curves_buffer, curves_count) - - # Process each curve - for i in 0:(curves_count-1) - # Extract NID from the buffer (first field of EC_builtin_curve struct) - nid = unsafe_load(Ptr{Cint}(pointer(curves_buffer) + i * buffer_size)) - - # Get the curve name using NID - sn = ccall((:OBJ_nid2sn, libcrypto), Ptr{UInt8}, (Cint,), nid) - if sn != C_NULL - _name = unsafe_string(sn) - - if _name == name - return nid |> Int - end - - # Get long name (description) - ln = ccall((:OBJ_nid2ln, libcrypto), Ptr{UInt8}, (Cint,), nid) - - if ln != C_NULL && name == unsafe_string(ln) - return nid |> Int - end - end +function openssl_point_free(ptr::Ptr{Nothing}) + if ptr != C_NULL + @debug "Freeing OpenSSL point" pointer=ptr + @ccall libcrypto.EC_POINT_free(ptr::Ptr{Nothing})::Cvoid end +end - error("Curve $name not found") +function openssl_bignum_free(ptr::Ptr{Nothing}) + if ptr != C_NULL + @debug "Freeing OpenSSL bignum" pointer=ptr + @ccall libcrypto.BN_free(ptr::Ptr{Cvoid})::Cvoid + end end diff --git a/test/runtests.jl b/test/runtests.jl index 9875327..3959e48 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,7 @@ using SafeTestsets +sleep(1) + @safetestset "Testing point arithmetics and serialization" begin include("basics.jl") end