Skip to content

Commit

Permalink
Use transaction in increment and swap for better atomicity, remove un…
Browse files Browse the repository at this point in the history
…needed exists check, higher level filtering of empty get_all queries, sqlite handle null value before swap

Signed-off-by: Darwin Boersma <[email protected]>
  • Loading branch information
ogghead committed Nov 13, 2024
1 parent ed7299f commit 1cc913c
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 81 deletions.
9 changes: 9 additions & 0 deletions crates/factor-key-value/src/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,9 @@ impl wasi_keyvalue::batch::Host for KeyValueDispatch {
keys: Vec<String>,
) -> std::result::Result<Vec<(String, Option<Vec<u8>>)>, wasi_keyvalue::store::Error> {
let store = self.get_store_wasi(bucket)?;
if keys.is_empty() {
return Ok(vec![]);
}
store.get_many(keys).await.map_err(to_wasi_err)
}

Expand All @@ -293,6 +296,9 @@ impl wasi_keyvalue::batch::Host for KeyValueDispatch {
key_values: Vec<(String, Vec<u8>)>,
) -> std::result::Result<(), wasi_keyvalue::store::Error> {
let store = self.get_store_wasi(bucket)?;
if key_values.is_empty() {
return Ok(());
}
store.set_many(key_values).await.map_err(to_wasi_err)
}

Expand All @@ -303,6 +309,9 @@ impl wasi_keyvalue::batch::Host for KeyValueDispatch {
keys: Vec<String>,
) -> std::result::Result<(), wasi_keyvalue::store::Error> {
let store = self.get_store_wasi(bucket)?;
if keys.is_empty() {
return Ok(());
}
store.delete_many(keys).await.map_err(to_wasi_err)
}
}
Expand Down
10 changes: 6 additions & 4 deletions crates/factor-key-value/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,10 +260,12 @@ impl Store for CachingStore {
}
}

let keys_and_values = self.inner.get_many(not_found).await?;
for (key, value) in keys_and_values {
found.push((key.clone(), value.clone()));
state.cache.put(key, value);
if !not_found.is_empty() {
let keys_and_values = self.inner.get_many(not_found).await?;
for (key, value) in keys_and_values {
found.push((key.clone(), value.clone()));
state.cache.put(key, value);
}
}

Ok(found)
Expand Down
165 changes: 102 additions & 63 deletions crates/key-value-aws/src/store.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use core::str;
use std::{
collections::HashMap,
sync::{Arc, Mutex},
Expand All @@ -13,7 +14,10 @@ use aws_sdk_dynamodb::{
get_item::GetItemOutput,
},
primitives::Blob,
types::{AttributeValue, DeleteRequest, KeysAndAttributes, PutRequest, WriteRequest},
types::{
AttributeValue, DeleteRequest, KeysAndAttributes, PutRequest, TransactWriteItem, Update,
WriteRequest,
},
Client,
};
use spin_core::async_trait;
Expand Down Expand Up @@ -148,7 +152,7 @@ enum CasState {
Versioned(String),
// Existing item without version
Unversioned(Blob),
// Item was null when fetched during `current`
// Item was missing when fetched during `current`, expected to be new
Unset,
// Potentially new item -- `current` was never called to fetch version
Unknown,
Expand Down Expand Up @@ -210,15 +214,13 @@ impl Store for AwsDynamoStore {
}

async fn delete(&self, key: &str) -> Result<(), Error> {
if self.exists(key).await? {
self.client
.delete_item()
.table_name(self.table.as_str())
.key(PK, AttributeValue::S(key.to_string()))
.send()
.await
.map_err(log_error)?;
}
self.client
.delete_item()
.table_name(self.table.as_str())
.key(PK, AttributeValue::S(key.to_string()))
.send()
.await
.map_err(log_error)?;
Ok(())
}

Expand All @@ -241,16 +243,32 @@ impl Store for AwsDynamoStore {
}

async fn get_keys(&self) -> Result<Vec<String>, Error> {
self.get_keys().await
}
let mut primary_keys = Vec::new();

async fn get_many(&self, keys: Vec<String>) -> Result<Vec<(String, Option<Vec<u8>>)>, Error> {
let mut results = Vec::with_capacity(keys.len());
let mut scan_paginator = self
.client
.scan()
.table_name(self.table.as_str())
.projection_expression(PK)
.into_paginator()
.send();

if keys.is_empty() {
return Ok(results);
while let Some(output) = scan_paginator.next().await {
let scan_output = output.map_err(log_error)?;
if let Some(items) = scan_output.items {
for mut item in items {
if let Some(AttributeValue::S(pk)) = item.remove(PK) {
primary_keys.push(pk);
}
}
}
}

Ok(primary_keys)
}

async fn get_many(&self, keys: Vec<String>) -> Result<Vec<(String, Option<Vec<u8>>)>, Error> {
let mut results = Vec::with_capacity(keys.len());
let mut keys_and_attributes_builder = KeysAndAttributes::builder()
.projection_expression(format!("{PK},{VAL}"))
.consistent_read(self.consistent_read);
Expand Down Expand Up @@ -370,26 +388,66 @@ impl Store for AwsDynamoStore {
}

async fn increment(&self, key: String, delta: i64) -> Result<i64, Error> {
let result = self
let GetItemOutput { item, .. } = self
.client
.update_item()
.get_item()
.consistent_read(true)
.table_name(self.table.as_str())
.key(PK, AttributeValue::S(key))
.update_expression("ADD #VAL :delta")
.expression_attribute_names("#VAL", VAL)
.expression_attribute_values(":delta", AttributeValue::N(delta.to_string()))
.return_values(aws_sdk_dynamodb::types::ReturnValue::UpdatedNew)
.key(PK, AttributeValue::S(key.clone()))
.projection_expression(VAL)
.send()
.await
.map_err(log_error)?;

if let Some(updated_attributes) = result.attributes {
if let Some(AttributeValue::N(new_value)) = updated_attributes.get(VAL) {
return Ok(new_value.parse::<i64>().map_err(log_error))?;
}
let old_val = match item {
Some(mut current_item) => match current_item.remove(VAL) {
// We're expecting i64, so technically we could transmute but seems risky...
Some(AttributeValue::B(val)) => Some(
str::from_utf8(&val.into_inner())
.map_err(log_error)?
.parse::<i64>()
.map_err(log_error)?,
),
_ => None,
},
None => None,
};

let new_val = old_val.unwrap_or(0) + delta;

let mut update = Update::builder()
.table_name(self.table.as_str())
.key(PK, AttributeValue::S(key))
.update_expression("SET #VAL = :new_val")
.expression_attribute_names("#VAL", VAL)
.expression_attribute_values(
":new_val",
AttributeValue::B(Blob::new(new_val.to_string().as_bytes())),
);

if let Some(old_val) = old_val {
update = update
.condition_expression("#VAL = :old_val")
.expression_attribute_values(
":old_val",
AttributeValue::B(Blob::new(old_val.to_string().as_bytes())),
)
} else {
update = update.condition_expression("attribute_not_exists (#VAL)")
}

Err(Error::Other("Failed to increment value".into()))
self.client
.transact_write_items()
.transact_items(
TransactWriteItem::builder()
.update(update.build().map_err(log_error)?)
.build(),
)
.send()
.await
.map_err(log_error)?;

Ok(new_val)
}

async fn new_compare_and_swap(
Expand Down Expand Up @@ -454,9 +512,7 @@ impl Cas for CompareAndSwap {
/// `swap` updates the value for the key -- if possible, using the version saved in the `current` function for
/// optimistic concurrency or the previous item value
async fn swap(&self, value: Vec<u8>) -> Result<(), SwapError> {
let mut update_item = self
.client
.update_item()
let mut update = Update::builder()
.table_name(self.table.as_str())
.key(PK, AttributeValue::S(self.key.clone()))
.update_expression("SET #VAL = :val ADD #VER :increment")
Expand All @@ -468,22 +524,32 @@ impl Cas for CompareAndSwap {
let state = self.state.lock().unwrap().clone();
match state {
CasState::Versioned(version) => {
update_item = update_item
update = update
.condition_expression("#VER = :ver")
.expression_attribute_values(":ver", AttributeValue::N(version));
}
CasState::Unversioned(old_val) => {
update_item = update_item
update = update
.condition_expression("#VAL = :old_val")
.expression_attribute_values(":old_val", AttributeValue::B(old_val));
}
CasState::Unset => {
update_item = update_item.condition_expression("attribute_not_exists (#VAL)");
update = update.condition_expression("attribute_not_exists (#VAL)");
}
CasState::Unknown => (),
};

update_item
self.client
.transact_write_items()
.transact_items(
TransactWriteItem::builder()
.update(
update
.build()
.map_err(|e| SwapError::Other(format!("{e:?}")))?,
)
.build(),
)
.send()
.await
.map_err(|e| SwapError::CasFailed(format!("{e:?}")))?;
Expand All @@ -499,30 +565,3 @@ impl Cas for CompareAndSwap {
self.key.clone()
}
}

impl AwsDynamoStore {
async fn get_keys(&self) -> Result<Vec<String>, Error> {
let mut primary_keys = Vec::new();

let mut scan_paginator = self
.client
.scan()
.table_name(self.table.as_str())
.projection_expression(PK)
.into_paginator()
.send();

while let Some(output) = scan_paginator.next().await {
let scan_output = output.map_err(log_error)?;
if let Some(items) = scan_output.items {
for mut item in items {
if let Some(AttributeValue::S(pk)) = item.remove(PK) {
primary_keys.push(pk);
}
}
}
}

Ok(primary_keys)
}
}
43 changes: 29 additions & 14 deletions crates/key-value-spin/src/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,20 +307,35 @@ impl Cas for CompareAndSwap {
async fn swap(&self, value: Vec<u8>) -> Result<(), SwapError> {
task::block_in_place(|| {
let old_value = self.value.lock().unwrap();
let rows_changed = self.connection
.lock()
.unwrap()
.prepare_cached(
"UPDATE spin_key_value SET value=:new_value WHERE store=:name and key=:key and value=:old_value",
)
.map_err(log_cas_error)?
.execute(named_params! {
":name": &self.name,
":key": self.key,
":old_value": old_value.clone().unwrap(),
":new_value": value,
})
.map_err(log_cas_error)?;
let mut conn = self.connection.lock().unwrap();
let rows_changed = match old_value.clone() {
Some(old_val) => {
conn
.prepare_cached(
"UPDATE spin_key_value SET value=:new_value WHERE store=:name and key=:key and value=:old_value")
.map_err(log_cas_error)?
.execute(named_params! {
":name": &self.name,
":key": self.key,
":old_value": old_val,
":new_value": value,
})
.map_err(log_cas_error)?
}
None => {
let tx = conn.transaction().map_err(log_cas_error)?;
let rows = tx
.prepare_cached(
"INSERT INTO spin_key_value (store, key, value) VALUES ($1, $2, $3)
ON CONFLICT(store, key) DO UPDATE SET value=$3",
)
.map_err(log_cas_error)?
.execute(rusqlite::params![&self.name, self.key, value])
.map_err(log_cas_error)?;
tx.commit().map_err(log_cas_error)?;
rows
}
};

// We expect only 1 row to be updated. If 0, we know that the underlying value has changed.
if rows_changed == 1 {
Expand Down

0 comments on commit 1cc913c

Please sign in to comment.