Skip to content

Commit

Permalink
Change deserialize ABI to be callable before user object creation.
Browse files Browse the repository at this point in the history
  • Loading branch information
Kerilk committed Jul 9, 2024
1 parent 0c0e487 commit 2ec5ba0
Show file tree
Hide file tree
Showing 8 changed files with 75 additions and 44 deletions.
13 changes: 9 additions & 4 deletions bindings/python/cconfigspace/tree_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .base import Object, Error, Result, ccs_rng, ccs_tree, ccs_tree_space, ccs_feature_space, ccs_features, ccs_tree_configuration, Datum, ccs_bool, _ccs_get_function, CEnumeration, _register_vector, _unregister_vector, ccs_retain_object
from .rng import Rng
from .tree import Tree
from .feature_space import FeatureSpace

class TreeSpaceType(CEnumeration):
_members_ = [
Expand Down Expand Up @@ -165,7 +166,7 @@ def __init__(self, handle = None, retain = False, auto_release = True,
ccs_dynamic_tree_space_del_type = ct.CFUNCTYPE(Result, ccs_tree_space)
ccs_dynamic_tree_space_get_child_type = ct.CFUNCTYPE(Result, ccs_tree_space, ccs_tree, ct.c_size_t, ct.POINTER(ccs_tree))
ccs_dynamic_tree_space_serialize_type = ct.CFUNCTYPE(Result, ccs_tree_space, ct.c_size_t, ct.c_void_p, ct.POINTER(ct.c_size_t))
ccs_dynamic_tree_space_deserialize_type = ct.CFUNCTYPE(Result, ccs_tree_space, ct.c_size_t, ct.c_void_p)
ccs_dynamic_tree_space_deserialize_type = ct.CFUNCTYPE(Result, ccs_tree, ccs_feature_space, ct.c_size_t, ct.c_void_p, ct.POINTER(ct.c_void_p))

class DynamicTreeSpaceVector(ct.Structure):
_fields_ = [
Expand Down Expand Up @@ -224,15 +225,19 @@ def serialize_wrapper(ts, state_size, p_state, p_state_size):
serialize_wrapper = 0

if deserialize is not None:
def deserialize_wrapper(ts, state_size, p_state):
def deserialize_wrapper(tree, feature_space, state_size, p_state, p_tree_space_data):
try:
ts = ct.cast(ts, ccs_tree_space)
t = ct.cast(tree, ccs_tree)
p_s = ct.cast(p_state, ct.c_void_p)
p_t = ct.cast(p_tree_space_data, ct.c_void_p)
if p_s.value is None:
state = None
else:
state = ct.cast(p_s, POINTER(c_byte * state_size))
deserialize(TreeSpace.from_handle(ts), state)
tree_space_data = deserialize(Tree.from_handle(t), FeatureSpace.from_handle(feature_space) if feature_space else None, state)
c_tree_space_data = ct.py_object(tree_space_data)
p_t[0] = c_tree_space_data
ct.pythonapi.Py_IncRef(c_tree_space_data)
return Result.SUCCESS
except Exception as e:
return Error.set_error(e)
Expand Down
12 changes: 8 additions & 4 deletions bindings/python/cconfigspace/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def __init__(self, handle = None, retain = False, auto_release = True,
ccs_user_defined_tuner_get_history_type = ct.CFUNCTYPE(Result, ccs_tuner, ccs_features, ct.c_size_t, ct.POINTER(ccs_evaluation), ct.POINTER(ct.c_size_t))
ccs_user_defined_tuner_suggest_type = ct.CFUNCTYPE(Result, ccs_tuner, ccs_features, ct.POINTER(ccs_search_configuration))
ccs_user_defined_tuner_serialize_type = ct.CFUNCTYPE(Result, ccs_tuner, ct.c_size_t, ct.c_void_p, ct.POINTER(ct.c_size_t))
ccs_user_defined_tuner_deserialize_type = ct.CFUNCTYPE(Result, ccs_tuner, ct.c_size_t, ct.POINTER(ccs_evaluation), ct.c_size_t, ct.POINTER(ccs_evaluation), ct.c_size_t, ct.c_void_p)
ccs_user_defined_tuner_deserialize_type = ct.CFUNCTYPE(Result, ccs_objective_space, ct.c_size_t, ct.POINTER(ccs_evaluation), ct.c_size_t, ct.POINTER(ccs_evaluation), ct.c_size_t, ct.c_void_p, ct.POINTER(ct.c_void_p))

class UserDefinedTunerVector(ct.Structure):
_fields_ = [
Expand Down Expand Up @@ -307,12 +307,13 @@ def serialize_wrapper(tun, state_size, p_state, p_state_size):
serialize_wrapper = 0

if deserialize is not None:
def deserialize_wrapper(tun, size_history, p_history, num_optima, p_optima, state_size, p_state):
def deserialize_wrapper(o_space, size_history, p_history, num_optima, p_optima, state_size, p_state, p_tuner_data):
try:
tun = ct.cast(tun, ccs_tuner)
o_space = ct.cast(o_space, ccs_objective_space)
p_h = ct.cast(p_history, ct.c_void_p)
p_o = ct.cast(p_optima, ct.c_void_p)
p_s = ct.cast(p_state, ct.c_void_p)
p_t = ct.cast(p_tuner_data, ct.c_void_p)
if p_h.value is None:
history = []
else:
Expand All @@ -325,7 +326,10 @@ def deserialize_wrapper(tun, size_history, p_history, num_optima, p_optima, stat
state = None
else:
state = ct.cast(p_s, POINTER(c_byte * state_size))
deserialize(Tuner.from_handle(tun), history, optima, state)
tuner_data = deserialize(ObjectiveSpace.from_handle(o_space), history, optima, state)
c_tuner_data = ct.py_object(tuner_data)
p_t[0] = c_tuner_data
ct.pythonapi.Py_IncRef(c_tuner_data)
return Result.SUCCESS
except Exception as e:
return Error.set_error(e)
Expand Down
8 changes: 5 additions & 3 deletions bindings/ruby/lib/cconfigspace/tree_space.rb
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def initialize(handle = nil, retain: false, auto_release: true,
callback :ccs_dynamic_tree_space_del, [:ccs_tree_space_t], :ccs_result_t
callback :ccs_dynamic_tree_space_get_child, [:ccs_tree_space_t, :ccs_tree_t, :size_t, :pointer], :ccs_result_t
callback :ccs_dynamic_tree_space_serialize, [:ccs_tree_space_t, :size_t, :pointer, :pointer], :ccs_result_t
callback :ccs_dynamic_tree_space_deserialize, [:ccs_tree_space_t, :size_t, :pointer], :ccs_result_t
callback :ccs_dynamic_tree_space_deserialize, [:ccs_tree_t, :ccs_feature_space_t, :size_t, :pointer, :pointer], :ccs_result_t

class DynamicTreeSpaceVector < FFI::Struct
layout :del, :ccs_dynamic_tree_space_del,
Expand Down Expand Up @@ -171,10 +171,12 @@ def self.wrap_dynamic_tree_space_callbacks(del, get_child, serialize, deserializ
end
deserializewrapper =
if deserialize
lambda { |ts, state_size, p_state|
lambda { |t, feature_space, state_size, p_state, p_tree_space_data|
begin
state = p_state.null? ? nil : p_state.slice(0, state_size)
deserialize(TreeSpace.from_handle(ts), state)
tree_space_data = deserialize(Tree.from_handle(t), feature_space.null? ? nil : FeatureSpace.from_handle(feature_space), state)
p_tree_space_data.write_value(tree_space_data)
FFI.inc_ref(tree_space_data)
CCSError.to_native(:CCS_RESULT_SUCCESS)
rescue => e
CCS.set_error(e)
Expand Down
8 changes: 5 additions & 3 deletions bindings/ruby/lib/cconfigspace/tuner.rb
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def initialize(handle = nil, retain: false, auto_release: true,
callback :ccs_user_defined_tuner_get_history, [:ccs_tuner_t, :ccs_features_t, :size_t, :pointer, :pointer], :ccs_result_t
callback :ccs_user_defined_tuner_suggest, [:ccs_tuner_t, :ccs_features_t, :pointer], :ccs_result_t
callback :ccs_user_defined_tuner_serialize, [:ccs_tuner_t, :size_t, :pointer, :pointer], :ccs_result_t
callback :ccs_user_defined_tuner_deserialize, [:ccs_tuner_t, :size_t, :pointer, :size_t, :pointer, :size_t, :pointer], :ccs_result_t
callback :ccs_user_defined_tuner_deserialize, [:ccs_objective_space_t, :size_t, :pointer, :size_t, :pointer, :size_t, :pointer, :pointer], :ccs_result_t

class UserDefinedTunerVector < FFI::Struct
layout :del, :ccs_user_defined_tuner_del,
Expand Down Expand Up @@ -244,12 +244,14 @@ def self.wrap_user_defined_tuner_callbacks(del, ask, tell, get_optima, get_histo
end
deserializewrapper =
if deserialize
lambda { |tun, history_size, p_history, num_optima, p_optima, state_size, p_state|
lambda { |o_space, history_size, p_history, num_optima, p_optima, state_size, p_state, p_tuner_data|
begin
history = p_history.null? ? [] : history_size.times.collect { |i| Evaluation::from_handle(p_p_history.get_pointer(i*8)) }
optima = p_optima.null? ? [] : num_optima.times.collect { |i| Evaluation::from_handle(p_optima.get_pointer(i*8)) }
state = p_state.null? ? nil : p_state.slice(0, state_size)
deserialize(Tuner.from_handle(tun), history, optima, state)
tuner_data = deserialize(ObjectiveSpace.from_handle(o_space), history, optima, state)
p_tuner_data.write_value(tuner_data)
FFI.inc_ref(tuner_data)
CCSError.to_native(:CCS_RESULT_SUCCESS)
rescue => e
CCS.set_error(e)
Expand Down
11 changes: 7 additions & 4 deletions include/cconfigspace/tree_space.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,15 @@ struct ccs_dynamic_tree_space_vector_s {

/**
* The tree space deserialization interface, can be NULL. In this case,
* only the tree is deserialized.
* only the tree is deserialized. Must return the tree space data
* to use at initialization
*/
ccs_result_t (*deserialize_state)(
ccs_tree_space_t tree_space,
size_t state_size,
const void *state);
ccs_tree_t tree,
ccs_feature_space_t feature_space,
size_t state_size,
const void *state,
void **tree_space_data_ret);
};

/**
Expand Down
20 changes: 11 additions & 9 deletions include/cconfigspace/tuner.h
Original file line number Diff line number Diff line change
Expand Up @@ -345,17 +345,19 @@ struct ccs_user_defined_tuner_vector_s {
size_t *state_size_ret);

/**
* The tuner deserialization interface, can be NULL, in which case,
* the history will be set through the tell interface
* The tuner deserialization interface, can be NULL, in which
* case, the history will be set through the tell interface. Must
* return the tuner data to use at initialization
*/
ccs_result_t (*deserialize_state)(
ccs_tuner_t tuner,
size_t size_history,
ccs_evaluation_t *history,
size_t num_optima,
ccs_evaluation_t *optima,
size_t state_size,
const void *state);
ccs_objective_space_t objective_space,
size_t size_history,
ccs_evaluation_t *history,
size_t num_optima,
ccs_evaluation_t *optima,
size_t state_size,
const void *state,
void **tuner_data_ret);
};

/**
Expand Down
21 changes: 14 additions & 7 deletions src/tree_space_deserialize.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,18 +98,25 @@ _ccs_deserialize_bin_tree_space_dynamic(
data, version, buffer_size, buffer, opts),
end);
CCS_VALIDATE(_ccs_deserialize_bin_ccs_blob(&blob, buffer_size, buffer));
CCS_VALIDATE_ERR_GOTO(
res,
ccs_create_dynamic_tree_space(
data->name, data->tree, data->feature_space, data->rng,
vector, opts->data, tree_space_ret),
end);

void *tree_space_data;
if (vector->deserialize_state)
CCS_VALIDATE_ERR_GOTO(
res,
vector->deserialize_state(
*tree_space_ret, blob.sz, blob.blob),
data->tree, data->feature_space,
blob.sz, blob.blob, &tree_space_data),
tree_space);
else
tree_space_data = opts->data;

CCS_VALIDATE_ERR_GOTO(
res,
ccs_create_dynamic_tree_space(
data->name, data->tree,
data->feature_space, data->rng,
vector, tree_space_data, tree_space_ret),
end);
goto end;
tree_space:
ccs_release_object(*tree_space_ret);
Expand Down
26 changes: 16 additions & 10 deletions src/tuner_deserialize.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,24 +187,30 @@ _ccs_deserialize_bin_user_defined_tuner(
_ccs_deserialize_bin_ccs_user_defined_tuner_data(
&data, version, buffer_size, buffer, opts),
end);
CCS_VALIDATE_ERR_GOTO(
res,
ccs_create_user_defined_tuner(
data.base_data.common_data.name,
data.base_data.common_data.objective_space, vector,
opts->data, tuner_ret),
end);

void *tuner_data;
if (vector->deserialize_state)
CCS_VALIDATE_ERR_GOTO(
res,
vector->deserialize_state(
*tuner_ret, data.base_data.size_history,
data.base_data.common_data.objective_space,
data.base_data.size_history,
data.base_data.history,
data.base_data.size_optima,
data.base_data.optima, data.blob.sz,
data.blob.blob),
tuner);
data.blob.blob, &tuner_data),
end);
else
tuner_data = opts->data;

CCS_VALIDATE_ERR_GOTO(
res,
ccs_create_user_defined_tuner(
data.base_data.common_data.name,
data.base_data.common_data.objective_space,
vector, tuner_data, tuner_ret),
end);
if (!vector->deserialize_state)
CCS_VALIDATE_ERR_GOTO(
res,
vector->tell(
Expand Down

0 comments on commit 2ec5ba0

Please sign in to comment.