Skip to content

Commit

Permalink
Add Persistable to PlaintextAggregator
Browse files Browse the repository at this point in the history
  • Loading branch information
ryardley committed Dec 27, 2024
1 parent fda0eea commit 52009ef
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 79 deletions.
106 changes: 43 additions & 63 deletions packages/ciphernode/aggregator/src/plaintext_aggregator.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use actix::prelude::*;
use anyhow::Result;
use async_trait::async_trait;
use data::Persistable;
use data::{Checkpoint, FromSnapshotWithParams, Repository, Snapshot};
use enclave_core::{
DecryptionshareCreated, Die, E3id, EnclaveEvent, EventBus, OrderedSet, PlaintextAggregated,
Expand Down Expand Up @@ -50,65 +51,69 @@ struct ComputeAggregate {
pub struct PlaintextAggregator {
fhe: Arc<Fhe>,
bus: Addr<EventBus>,
store: Repository<PlaintextAggregatorState>,
sortition: Addr<Sortition>,
e3_id: E3id,
state: PlaintextAggregatorState,
state: Persistable<PlaintextAggregatorState>,
src_chain_id: u64,
}

pub struct PlaintextAggregatorParams {
pub fhe: Arc<Fhe>,
pub bus: Addr<EventBus>,
pub store: Repository<PlaintextAggregatorState>,
pub sortition: Addr<Sortition>,
pub e3_id: E3id,
pub src_chain_id: u64,
}

impl PlaintextAggregator {
pub fn new(params: PlaintextAggregatorParams, state: PlaintextAggregatorState) -> Self {
pub fn new(
params: PlaintextAggregatorParams,
state: Persistable<PlaintextAggregatorState>,
) -> Self {
PlaintextAggregator {
fhe: params.fhe,
bus: params.bus,
store: params.store,
sortition: params.sortition,
e3_id: params.e3_id,
src_chain_id: params.src_chain_id,
state,
}
}

pub fn add_share(&mut self, share: Vec<u8>) -> Result<PlaintextAggregatorState> {
let PlaintextAggregatorState::Collecting {
threshold_m,
shares,
ciphertext_output,
..
} = &mut self.state
else {
return Err(anyhow::anyhow!("Can only add share in Collecting state"));
};

shares.insert(share);
if shares.len() == *threshold_m {
return Ok(PlaintextAggregatorState::Computing {
shares: shares.clone(),
ciphertext_output: ciphertext_output.to_vec(),
});
}

Ok(self.state.clone())
pub fn add_share(&mut self, share: Vec<u8>) -> Result<()> {
self.state.try_mutate(|mut state| {
let PlaintextAggregatorState::Collecting {
threshold_m,
shares,
ciphertext_output,
..
} = &mut state
else {
return Err(anyhow::anyhow!("Can only add share in Collecting state"));
};

shares.insert(share);

if shares.len() == *threshold_m {
return Ok(PlaintextAggregatorState::Computing {
shares: shares.clone(),
ciphertext_output: ciphertext_output.to_vec(),
});
}

Ok(state)
})
}

pub fn set_decryption(&mut self, decrypted: Vec<u8>) -> Result<PlaintextAggregatorState> {
let PlaintextAggregatorState::Computing { shares, .. } = &mut self.state else {
return Ok(self.state.clone());
};
pub fn set_decryption(&mut self, decrypted: Vec<u8>) -> Result<()> {
self.state.try_mutate(|mut state| {
let PlaintextAggregatorState::Computing { shares, .. } = &mut state else {
return Ok(state.clone());
};
let shares = shares.to_owned();

let shares = shares.to_owned();

Ok(PlaintextAggregatorState::Complete { decrypted, shares })
Ok(PlaintextAggregatorState::Complete { decrypted, shares })
})
}
}

Expand All @@ -131,9 +136,9 @@ impl Handler<DecryptionshareCreated> for PlaintextAggregator {
type Result = ResponseActFuture<Self, Result<()>>;

fn handle(&mut self, event: DecryptionshareCreated, _: &mut Self::Context) -> Self::Result {
let PlaintextAggregatorState::Collecting {
let Some(PlaintextAggregatorState::Collecting {
threshold_m, seed, ..
} = self.state
}) = self.state.get()
else {
error!(state=?self.state, "Aggregator has been closed for collecting.");
return Box::pin(fut::ready(Ok(())));
Expand Down Expand Up @@ -165,14 +170,13 @@ impl Handler<DecryptionshareCreated> for PlaintextAggregator {
}

// add the keyshare and
act.state = act.add_share(decryption_share)?;
act.checkpoint();
act.add_share(decryption_share)?;

// Check the state and if it has changed to the computing
if let PlaintextAggregatorState::Computing {
if let Some(PlaintextAggregatorState::Computing {
shares,
ciphertext_output,
} = &act.state
}) = &act.state.get()
{
ctx.notify(ComputeAggregate {
shares: shares.clone(),
Expand All @@ -195,8 +199,7 @@ impl Handler<ComputeAggregate> for PlaintextAggregator {
})?;

// Update the local state
self.state = self.set_decryption(decrypted_output.clone())?;
self.checkpoint();
self.set_decryption(decrypted_output.clone())?;

// Dispatch the PlaintextAggregated event
let event = EnclaveEvent::from(PlaintextAggregated {
Expand All @@ -217,26 +220,3 @@ impl Handler<Die> for PlaintextAggregator {
ctx.stop()
}
}

impl Snapshot for PlaintextAggregator {
type Snapshot = PlaintextAggregatorState;

fn snapshot(&self) -> Result<Self::Snapshot> {
Ok(self.state.clone())
}
}

#[async_trait]
impl FromSnapshotWithParams for PlaintextAggregator {
type Params = PlaintextAggregatorParams;

async fn from_snapshot(params: Self::Params, snapshot: Self::Snapshot) -> Result<Self> {
Ok(PlaintextAggregator::new(params, snapshot))
}
}

impl Checkpoint for PlaintextAggregator {
fn repository(&self) -> &Repository<PlaintextAggregatorState> {
&self.store
}
}
32 changes: 16 additions & 16 deletions packages/ciphernode/router/src/hooks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use aggregator::{
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use cipher::Cipher;
use data::{FromSnapshotWithParams, Snapshot};
use data::{AutoPersist, FromSnapshotWithParams, Snapshot};
use enclave_core::{BusError, E3Requested, EnclaveErrorType, EnclaveEvent, EventBus};
use fhe::{Fhe, SharedRng};
use keyshare::{Keyshare, KeyshareParams};
Expand Down Expand Up @@ -228,22 +228,23 @@ impl E3Feature for PlaintextAggregatorFeature {
};

let e3_id = data.e3_id.clone();

let _ = ctx.set_plaintext(
let repo = ctx.repositories().plaintext(&e3_id);
let sync_state = repo.send(Some(PlaintextAggregatorState::init(
meta.threshold_m,
meta.seed,
data.ciphertext_output.clone(),
)));

ctx.set_plaintext(
PlaintextAggregator::new(
PlaintextAggregatorParams {
fhe: fhe.clone(),
bus: self.bus.clone(),
store: ctx.repositories().plaintext(&e3_id),
sortition: self.sortition.clone(),
e3_id,
e3_id: e3_id.clone(),
src_chain_id: meta.src_chain_id,
},
PlaintextAggregatorState::init(
meta.threshold_m,
meta.seed,
data.ciphertext_output.clone(),
),
sync_state,
)
.start(),
);
Expand All @@ -259,10 +260,11 @@ impl E3Feature for PlaintextAggregatorFeature {
return Ok(());
}

let store = ctx.repositories().plaintext(&snapshot.e3_id);
let repo = ctx.repositories().plaintext(&snapshot.e3_id);
let sync_state = repo.load().await?;

// No Snapshot returned from the store -> bail
let Some(snap) = store.read().await? else {
if !sync_state.has() {
return Ok(());
};

Expand All @@ -283,18 +285,16 @@ impl E3Feature for PlaintextAggregatorFeature {
return Ok(());
};

let value = PlaintextAggregator::from_snapshot(
let value = PlaintextAggregator::new(
PlaintextAggregatorParams {
fhe: fhe.clone(),
bus: self.bus.clone(),
store,
sortition: self.sortition.clone(),
e3_id: ctx.e3_id.clone(),
src_chain_id: meta.src_chain_id,
},
snap,
sync_state,
)
.await?
.start();

// send to context
Expand Down

0 comments on commit 52009ef

Please sign in to comment.