diff --git a/skale/contracts/manager/node_rotation.py b/skale/contracts/manager/node_rotation.py index 02848093..87368f4f 100644 --- a/skale/contracts/manager/node_rotation.py +++ b/skale/contracts/manager/node_rotation.py @@ -35,7 +35,7 @@ @dataclass class Rotation: - node_id: int + leaving_node_id: int new_node_id: int freeze_until: int rotation_counter: int @@ -49,7 +49,7 @@ class NodeRotation(BaseContract): def schains(self): return self.skale.schains - def get_rotation_obj(self, schain_name): + def get_rotation_obj(self, schain_name) -> Rotation: schain_id = self.schains.name_to_id(schain_name) rotation_data = self.contract.functions.getRotation(schain_id).call() return Rotation(*rotation_data) diff --git a/skale/schain_config/rotation_history.py b/skale/schain_config/rotation_history.py index 042d49ad..f5cf82bd 100644 --- a/skale/schain_config/rotation_history.py +++ b/skale/schain_config/rotation_history.py @@ -28,7 +28,7 @@ RotationNodeData = namedtuple('RotationNodeData', ['index', 'node_id', 'public_key']) -def get_previous_schain_groups(skale, schain_name: str) -> dict: +def get_previous_schain_groups(skale, schain_name: str, leaving_node_id=None) -> dict: """ Returns all previous node groups with public keys and finish timestamps. In case of no rotations returns the current state. @@ -41,7 +41,6 @@ def get_previous_schain_groups(skale, schain_name: str) -> dict: current_public_key = skale.key_storage.get_common_public_key(group_id) rotation = skale.node_rotation.get_rotation_obj(schain_name) - rotation_delay = skale.constants_holder.get_rotation_delay() logger.info(f'Rotation data for {schain_name}: {rotation}') @@ -50,7 +49,7 @@ def get_previous_schain_groups(skale, schain_name: str) -> dict: return node_groups _add_previous_schain_rotations_state( - skale, node_groups, rotation, schain_name, previous_public_keys, rotation_delay) + skale, node_groups, rotation, schain_name, previous_public_keys, leaving_node_id) return node_groups @@ -72,6 +71,7 @@ def _add_current_schain_state( current_nodes[node_id] = RotationNodeData(index, node_id, public_key) node_groups[rotation.rotation_counter] = { + 'rotation': None, 'nodes': current_nodes, 'finish_ts': None, 'bls_public_key': _compose_bls_public_key_info(current_public_key) @@ -84,7 +84,7 @@ def _add_previous_schain_rotations_state( rotation: Rotation, schain_name: str, previous_public_keys: list, - rotation_delay: int + leaving_node_id=None ) -> dict: """ Internal function, handles rotations from (rotation_counter - 2) to 0 and adds them to the @@ -126,7 +126,13 @@ def _add_previous_schain_rotations_state( else: bls_public_key, node_finish_ts = None, None + logger.info(f'Adding rotation: {previous_node_id} -> {latest_exited_node_id}') + node_groups[rotation_id] = { + 'rotation': { + 'leaving_node_id': previous_node_id, + 'new_node_id': latest_exited_node_id + }, 'nodes': nodes, 'finish_ts': node_finish_ts, 'bls_public_key': bls_public_key @@ -134,6 +140,10 @@ def _add_previous_schain_rotations_state( del previous_nodes[latest_exited_node_id] + if leaving_node_id and previous_node_id == leaving_node_id: + logger.info(f'Finishing rotation history parsing: {leaving_node_id} found') + break + def _pop_previous_bls_public_key(previous_public_keys): """ @@ -153,3 +163,23 @@ def _compose_bls_public_key_info(bls_public_key: str) -> dict: 'blsPublicKey2': str(bls_public_key[1][0]), 'blsPublicKey3': str(bls_public_key[1][1]) } + + +def get_new_nodes_list(skale: Skale, name: str, node_groups) -> list: + """Returns list of new nodes in for the latest rotation""" + logger.info(f'Getting new nodes list for chain {name}') + rotation = skale.node_rotation.get_rotation_obj(name) + current_group_ids = node_groups[rotation.rotation_counter]['nodes'].keys() + new_nodes = [] + for index in node_groups: + past_rotation = node_groups[index]['rotation'] + if not past_rotation: + continue + if past_rotation['new_node_id'] in current_group_ids: + new_nodes.append(past_rotation['new_node_id']) + else: + logger.info(f'{past_rotation["new_node_id"]} NOT IN {current_group_ids}') + if rotation.leaving_node_id == past_rotation['leaving_node_id']: + break + logger.info(f'New nodes list for chain {name}: {new_nodes}') + return new_nodes diff --git a/skale/utils/contracts_provision/main.py b/skale/utils/contracts_provision/main.py index cff2b519..4b1a3cd9 100644 --- a/skale/utils/contracts_provision/main.py +++ b/skale/utils/contracts_provision/main.py @@ -125,6 +125,14 @@ def add_test4_schain_type(skale) -> TxRes: def cleanup_nodes_schains(skale): + try: + _cleanup_nodes_schains(skale) + except Exception as e: + print(f'Cleanup failed: {e}') + _cleanup_nodes_schains(skale) + + +def _cleanup_nodes_schains(skale): print('Cleanup nodes and schains') for schain_id in skale.schains_internal.get_all_schains_ids(): schain_data = skale.schains.get(schain_id) diff --git a/tests/manager/node_rotation_test.py b/tests/manager/node_rotation_test.py index 51f94935..625449f4 100644 --- a/tests/manager/node_rotation_test.py +++ b/tests/manager/node_rotation_test.py @@ -17,7 +17,7 @@ def test_get_rotation(skale): def test_get_rotation_obj(skale): assert skale.node_rotation.get_rotation_obj(DEFAULT_SCHAIN_NAME) == Rotation( - node_id=0, + leaving_node_id=0, new_node_id=0, freeze_until=0, rotation_counter=0 diff --git a/tests/rotation_history/rotation_history_test.py b/tests/rotation_history/rotation_history_test.py index f2400f85..8c96b384 100644 --- a/tests/rotation_history/rotation_history_test.py +++ b/tests/rotation_history/rotation_history_test.py @@ -6,7 +6,7 @@ add_test4_schain_type, cleanup_nodes_schains, create_schain, add_test2_schain_type ) from skale.utils.contracts_provision import DEFAULT_SCHAIN_NAME -from skale.schain_config.rotation_history import get_previous_schain_groups +from skale.schain_config.rotation_history import get_previous_schain_groups, get_new_nodes_list from tests.rotation_history.utils import set_up_nodes, run_dkg, rotate_node, fail_dkg logger = logging.getLogger(__name__) @@ -188,3 +188,43 @@ def test_rotation_history_failed_dkg(skale): # no finish_ts because it's the current group assert not node_groups[3]['finish_ts'] assert node_groups[3]['bls_public_key'] + + +def test_get_new_nodes_list(skale): + cleanup_nodes_schains(skale) + nodes, skale_instances = set_up_nodes(skale, 4) + add_test4_schain_type(skale) + name = create_schain(skale, random_name=True) + group_index = skale.web3.sha3(text=name) + + run_dkg(nodes, skale_instances, group_index) + + exiting_node_index = 1 + rotate_node(skale, group_index, nodes, skale_instances, exiting_node_index, do_dkg=False) + + failed_node_index = 2 + second_failed_node_index = 3 + test_new_node_ids = fail_dkg( + skale=skale, + nodes=nodes, + skale_instances=skale_instances, + group_index=group_index, + failed_node_index=failed_node_index, + second_failed_node_index=second_failed_node_index + ) + + rotation = skale.node_rotation.get_rotation_obj(name) + node_groups = get_previous_schain_groups(skale, name, rotation.leaving_node_id) + new_nodes = get_new_nodes_list(skale, name, node_groups) + + assert len(new_nodes) == 3 + assert all(x in new_nodes for x in test_new_node_ids) + + exiting_node_index = 3 + rotate_node(skale, group_index, nodes, skale_instances, exiting_node_index) + + rotation = skale.node_rotation.get_rotation_obj(name) + node_groups = get_previous_schain_groups(skale, name, rotation.leaving_node_id) + new_nodes = get_new_nodes_list(skale, name, node_groups) + + assert len(new_nodes) == 1 diff --git a/tests/rotation_history/utils.py b/tests/rotation_history/utils.py index d1945e2b..df7a4b65 100644 --- a/tests/rotation_history/utils.py +++ b/tests/rotation_history/utils.py @@ -144,8 +144,18 @@ def rotate_node(skale, group_index, nodes, skale_instances, exiting_node_index, return nodes, skale_instances -def fail_dkg(skale, nodes, skale_instances, group_index, failed_node_index): +def fail_dkg( + skale, + nodes, + skale_instances, + group_index, + failed_node_index, + second_failed_node_index=None +) -> list: + logger.info('Failing first DKG...') + new_node_ids = [] new_nodes, new_skale_instances = set_up_nodes(skale, 1) + new_node_ids.append(new_nodes[0]['node_id']) send_broadcasts(nodes, skale_instances, group_index, failed_node_index) _skip_evm_time(skale_instances[0].web3, skale.constants_holder.get_dkg_timeout()) @@ -153,10 +163,25 @@ def fail_dkg(skale, nodes, skale_instances, group_index, failed_node_index): nodes[failed_node_index] = new_nodes[0] skale_instances[failed_node_index] = new_skale_instances[0] + + if second_failed_node_index: + logger.info('Failing second DKG...') + new_nodes, new_skale_instances = set_up_nodes(skale, 1) + new_node_ids.append(new_nodes[0]['node_id']) + + send_broadcasts(nodes, skale_instances, group_index, second_failed_node_index) + _skip_evm_time(skale_instances[0].web3, skale.constants_holder.get_dkg_timeout()) + send_complaint(nodes, skale_instances, group_index, second_failed_node_index) + + nodes[second_failed_node_index] = new_nodes[0] + skale_instances[second_failed_node_index] = new_skale_instances[0] + run_dkg(nodes, skale_instances, group_index) + return new_node_ids def run_dkg(nodes, skale_instances, group_index): + logger.info('Running DKG procedure...') send_broadcasts(nodes, skale_instances, group_index) send_alrights(nodes, skale_instances, group_index) _skip_evm_time(skale_instances[0].web3, TEST_ROTATION_DELAY)