diff --git a/packages/ciphernode/aggregator/src/plaintext_aggregator.rs b/packages/ciphernode/aggregator/src/plaintext_aggregator.rs index 7b833b60..5c1411d5 100644 --- a/packages/ciphernode/aggregator/src/plaintext_aggregator.rs +++ b/packages/ciphernode/aggregator/src/plaintext_aggregator.rs @@ -1,6 +1,7 @@ use actix::prelude::*; use anyhow::Result; use async_trait::async_trait; +use data::Persistable; use data::{Checkpoint, FromSnapshotWithParams, Repository, Snapshot}; use enclave_core::{ DecryptionshareCreated, Die, E3id, EnclaveEvent, EventBus, OrderedSet, PlaintextAggregated, @@ -50,28 +51,28 @@ struct ComputeAggregate { pub struct PlaintextAggregator { fhe: Arc, bus: Addr, - store: Repository, sortition: Addr, e3_id: E3id, - state: PlaintextAggregatorState, + state: Persistable, src_chain_id: u64, } pub struct PlaintextAggregatorParams { pub fhe: Arc, pub bus: Addr, - pub store: Repository, pub sortition: Addr, pub e3_id: E3id, pub src_chain_id: u64, } impl PlaintextAggregator { - pub fn new(params: PlaintextAggregatorParams, state: PlaintextAggregatorState) -> Self { + pub fn new( + params: PlaintextAggregatorParams, + state: Persistable, + ) -> Self { PlaintextAggregator { fhe: params.fhe, bus: params.bus, - store: params.store, sortition: params.sortition, e3_id: params.e3_id, src_chain_id: params.src_chain_id, @@ -79,36 +80,40 @@ impl PlaintextAggregator { } } - pub fn add_share(&mut self, share: Vec) -> Result { - let PlaintextAggregatorState::Collecting { - threshold_m, - shares, - ciphertext_output, - .. - } = &mut self.state - else { - return Err(anyhow::anyhow!("Can only add share in Collecting state")); - }; - - shares.insert(share); - if shares.len() == *threshold_m { - return Ok(PlaintextAggregatorState::Computing { - shares: shares.clone(), - ciphertext_output: ciphertext_output.to_vec(), - }); - } - - Ok(self.state.clone()) + pub fn add_share(&mut self, share: Vec) -> Result<()> { + self.state.try_mutate(|mut state| { + let PlaintextAggregatorState::Collecting { + threshold_m, + shares, + ciphertext_output, + .. + } = &mut state + else { + return Err(anyhow::anyhow!("Can only add share in Collecting state")); + }; + + shares.insert(share); + + if shares.len() == *threshold_m { + return Ok(PlaintextAggregatorState::Computing { + shares: shares.clone(), + ciphertext_output: ciphertext_output.to_vec(), + }); + } + + Ok(state) + }) } - pub fn set_decryption(&mut self, decrypted: Vec) -> Result { - let PlaintextAggregatorState::Computing { shares, .. } = &mut self.state else { - return Ok(self.state.clone()); - }; + pub fn set_decryption(&mut self, decrypted: Vec) -> Result<()> { + self.state.try_mutate(|mut state| { + let PlaintextAggregatorState::Computing { shares, .. } = &mut state else { + return Ok(state.clone()); + }; + let shares = shares.to_owned(); - let shares = shares.to_owned(); - - Ok(PlaintextAggregatorState::Complete { decrypted, shares }) + Ok(PlaintextAggregatorState::Complete { decrypted, shares }) + }) } } @@ -131,9 +136,9 @@ impl Handler for PlaintextAggregator { type Result = ResponseActFuture>; fn handle(&mut self, event: DecryptionshareCreated, _: &mut Self::Context) -> Self::Result { - let PlaintextAggregatorState::Collecting { + let Some(PlaintextAggregatorState::Collecting { threshold_m, seed, .. - } = self.state + }) = self.state.get() else { error!(state=?self.state, "Aggregator has been closed for collecting."); return Box::pin(fut::ready(Ok(()))); @@ -165,14 +170,13 @@ impl Handler for PlaintextAggregator { } // add the keyshare and - act.state = act.add_share(decryption_share)?; - act.checkpoint(); + act.add_share(decryption_share)?; // Check the state and if it has changed to the computing - if let PlaintextAggregatorState::Computing { + if let Some(PlaintextAggregatorState::Computing { shares, ciphertext_output, - } = &act.state + }) = &act.state.get() { ctx.notify(ComputeAggregate { shares: shares.clone(), @@ -195,8 +199,7 @@ impl Handler for PlaintextAggregator { })?; // Update the local state - self.state = self.set_decryption(decrypted_output.clone())?; - self.checkpoint(); + self.set_decryption(decrypted_output.clone())?; // Dispatch the PlaintextAggregated event let event = EnclaveEvent::from(PlaintextAggregated { @@ -217,26 +220,3 @@ impl Handler for PlaintextAggregator { ctx.stop() } } - -impl Snapshot for PlaintextAggregator { - type Snapshot = PlaintextAggregatorState; - - fn snapshot(&self) -> Result { - Ok(self.state.clone()) - } -} - -#[async_trait] -impl FromSnapshotWithParams for PlaintextAggregator { - type Params = PlaintextAggregatorParams; - - async fn from_snapshot(params: Self::Params, snapshot: Self::Snapshot) -> Result { - Ok(PlaintextAggregator::new(params, snapshot)) - } -} - -impl Checkpoint for PlaintextAggregator { - fn repository(&self) -> &Repository { - &self.store - } -} diff --git a/packages/ciphernode/router/src/hooks.rs b/packages/ciphernode/router/src/hooks.rs index f0388937..eea1bc35 100644 --- a/packages/ciphernode/router/src/hooks.rs +++ b/packages/ciphernode/router/src/hooks.rs @@ -7,7 +7,7 @@ use aggregator::{ use anyhow::{anyhow, Result}; use async_trait::async_trait; use cipher::Cipher; -use data::{FromSnapshotWithParams, Snapshot}; +use data::{AutoPersist, FromSnapshotWithParams, Snapshot}; use enclave_core::{BusError, E3Requested, EnclaveErrorType, EnclaveEvent, EventBus}; use fhe::{Fhe, SharedRng}; use keyshare::{Keyshare, KeyshareParams}; @@ -228,22 +228,23 @@ impl E3Feature for PlaintextAggregatorFeature { }; let e3_id = data.e3_id.clone(); - - let _ = ctx.set_plaintext( + let repo = ctx.repositories().plaintext(&e3_id); + let sync_state = repo.send(Some(PlaintextAggregatorState::init( + meta.threshold_m, + meta.seed, + data.ciphertext_output.clone(), + ))); + + ctx.set_plaintext( PlaintextAggregator::new( PlaintextAggregatorParams { fhe: fhe.clone(), bus: self.bus.clone(), - store: ctx.repositories().plaintext(&e3_id), sortition: self.sortition.clone(), - e3_id, + e3_id: e3_id.clone(), src_chain_id: meta.src_chain_id, }, - PlaintextAggregatorState::init( - meta.threshold_m, - meta.seed, - data.ciphertext_output.clone(), - ), + sync_state, ) .start(), ); @@ -259,10 +260,11 @@ impl E3Feature for PlaintextAggregatorFeature { return Ok(()); } - let store = ctx.repositories().plaintext(&snapshot.e3_id); + let repo = ctx.repositories().plaintext(&snapshot.e3_id); + let sync_state = repo.load().await?; // No Snapshot returned from the store -> bail - let Some(snap) = store.read().await? else { + if !sync_state.has() { return Ok(()); }; @@ -283,18 +285,16 @@ impl E3Feature for PlaintextAggregatorFeature { return Ok(()); }; - let value = PlaintextAggregator::from_snapshot( + let value = PlaintextAggregator::new( PlaintextAggregatorParams { fhe: fhe.clone(), bus: self.bus.clone(), - store, sortition: self.sortition.clone(), e3_id: ctx.e3_id.clone(), src_chain_id: meta.src_chain_id, }, - snap, + sync_state, ) - .await? .start(); // send to context