diff --git a/crates/factor-key-value/src/host.rs b/crates/factor-key-value/src/host.rs index 60d08b509d..efb473fb1b 100644 --- a/crates/factor-key-value/src/host.rs +++ b/crates/factor-key-value/src/host.rs @@ -283,6 +283,9 @@ impl wasi_keyvalue::batch::Host for KeyValueDispatch { keys: Vec, ) -> std::result::Result>)>, wasi_keyvalue::store::Error> { let store = self.get_store_wasi(bucket)?; + if keys.is_empty() { + return Ok(vec![]); + } store.get_many(keys).await.map_err(to_wasi_err) } @@ -293,6 +296,9 @@ impl wasi_keyvalue::batch::Host for KeyValueDispatch { key_values: Vec<(String, Vec)>, ) -> std::result::Result<(), wasi_keyvalue::store::Error> { let store = self.get_store_wasi(bucket)?; + if key_values.is_empty() { + return Ok(()); + } store.set_many(key_values).await.map_err(to_wasi_err) } @@ -303,6 +309,9 @@ impl wasi_keyvalue::batch::Host for KeyValueDispatch { keys: Vec, ) -> std::result::Result<(), wasi_keyvalue::store::Error> { let store = self.get_store_wasi(bucket)?; + if keys.is_empty() { + return Ok(()); + } store.delete_many(keys).await.map_err(to_wasi_err) } } diff --git a/crates/factor-key-value/src/util.rs b/crates/factor-key-value/src/util.rs index cea72c92b6..82dbb59611 100644 --- a/crates/factor-key-value/src/util.rs +++ b/crates/factor-key-value/src/util.rs @@ -260,10 +260,12 @@ impl Store for CachingStore { } } - let keys_and_values = self.inner.get_many(not_found).await?; - for (key, value) in keys_and_values { - found.push((key.clone(), value.clone())); - state.cache.put(key, value); + if !not_found.is_empty() { + let keys_and_values = self.inner.get_many(not_found).await?; + for (key, value) in keys_and_values { + found.push((key.clone(), value.clone())); + state.cache.put(key, value); + } } Ok(found) diff --git a/crates/key-value-aws/src/store.rs b/crates/key-value-aws/src/store.rs index 96a6e33828..ed38e1e0b2 100644 --- a/crates/key-value-aws/src/store.rs +++ b/crates/key-value-aws/src/store.rs @@ -1,3 +1,4 @@ +use core::str; use std::{ collections::HashMap, sync::{Arc, Mutex}, @@ -13,7 +14,10 @@ use aws_sdk_dynamodb::{ get_item::GetItemOutput, }, primitives::Blob, - types::{AttributeValue, DeleteRequest, KeysAndAttributes, PutRequest, WriteRequest}, + types::{ + AttributeValue, DeleteRequest, KeysAndAttributes, PutRequest, TransactWriteItem, Update, + WriteRequest, + }, Client, }; use spin_core::async_trait; @@ -148,7 +152,7 @@ enum CasState { Versioned(String), // Existing item without version Unversioned(Blob), - // Item was null when fetched during `current` + // Item was missing when fetched during `current`, expected to be new Unset, // Potentially new item -- `current` was never called to fetch version Unknown, @@ -210,15 +214,13 @@ impl Store for AwsDynamoStore { } async fn delete(&self, key: &str) -> Result<(), Error> { - if self.exists(key).await? { - self.client - .delete_item() - .table_name(self.table.as_str()) - .key(PK, AttributeValue::S(key.to_string())) - .send() - .await - .map_err(log_error)?; - } + self.client + .delete_item() + .table_name(self.table.as_str()) + .key(PK, AttributeValue::S(key.to_string())) + .send() + .await + .map_err(log_error)?; Ok(()) } @@ -241,16 +243,32 @@ impl Store for AwsDynamoStore { } async fn get_keys(&self) -> Result, Error> { - self.get_keys().await - } + let mut primary_keys = Vec::new(); - async fn get_many(&self, keys: Vec) -> Result>)>, Error> { - let mut results = Vec::with_capacity(keys.len()); + let mut scan_paginator = self + .client + .scan() + .table_name(self.table.as_str()) + .projection_expression(PK) + .into_paginator() + .send(); - if keys.is_empty() { - return Ok(results); + while let Some(output) = scan_paginator.next().await { + let scan_output = output.map_err(log_error)?; + if let Some(items) = scan_output.items { + for mut item in items { + if let Some(AttributeValue::S(pk)) = item.remove(PK) { + primary_keys.push(pk); + } + } + } } + Ok(primary_keys) + } + + async fn get_many(&self, keys: Vec) -> Result>)>, Error> { + let mut results = Vec::with_capacity(keys.len()); let mut keys_and_attributes_builder = KeysAndAttributes::builder() .projection_expression(format!("{PK},{VAL}")) .consistent_read(self.consistent_read); @@ -370,26 +388,66 @@ impl Store for AwsDynamoStore { } async fn increment(&self, key: String, delta: i64) -> Result { - let result = self + let GetItemOutput { item, .. } = self .client - .update_item() + .get_item() + .consistent_read(true) .table_name(self.table.as_str()) - .key(PK, AttributeValue::S(key)) - .update_expression("ADD #VAL :delta") - .expression_attribute_names("#VAL", VAL) - .expression_attribute_values(":delta", AttributeValue::N(delta.to_string())) - .return_values(aws_sdk_dynamodb::types::ReturnValue::UpdatedNew) + .key(PK, AttributeValue::S(key.clone())) + .projection_expression(VAL) .send() .await .map_err(log_error)?; - if let Some(updated_attributes) = result.attributes { - if let Some(AttributeValue::N(new_value)) = updated_attributes.get(VAL) { - return Ok(new_value.parse::().map_err(log_error))?; - } + let old_val = match item { + Some(mut current_item) => match current_item.remove(VAL) { + // We're expecting i64, so technically we could transmute but seems risky... + Some(AttributeValue::B(val)) => Some( + str::from_utf8(&val.into_inner()) + .map_err(log_error)? + .parse::() + .map_err(log_error)?, + ), + _ => None, + }, + None => None, + }; + + let new_val = old_val.unwrap_or(0) + delta; + + let mut update = Update::builder() + .table_name(self.table.as_str()) + .key(PK, AttributeValue::S(key)) + .update_expression("SET #VAL = :new_val") + .expression_attribute_names("#VAL", VAL) + .expression_attribute_values( + ":new_val", + AttributeValue::B(Blob::new(new_val.to_string().as_bytes())), + ); + + if let Some(old_val) = old_val { + update = update + .condition_expression("#VAL = :old_val") + .expression_attribute_values( + ":old_val", + AttributeValue::B(Blob::new(old_val.to_string().as_bytes())), + ) + } else { + update = update.condition_expression("attribute_not_exists (#VAL)") } - Err(Error::Other("Failed to increment value".into())) + self.client + .transact_write_items() + .transact_items( + TransactWriteItem::builder() + .update(update.build().map_err(log_error)?) + .build(), + ) + .send() + .await + .map_err(log_error)?; + + Ok(new_val) } async fn new_compare_and_swap( @@ -454,9 +512,7 @@ impl Cas for CompareAndSwap { /// `swap` updates the value for the key -- if possible, using the version saved in the `current` function for /// optimistic concurrency or the previous item value async fn swap(&self, value: Vec) -> Result<(), SwapError> { - let mut update_item = self - .client - .update_item() + let mut update = Update::builder() .table_name(self.table.as_str()) .key(PK, AttributeValue::S(self.key.clone())) .update_expression("SET #VAL = :val ADD #VER :increment") @@ -468,22 +524,32 @@ impl Cas for CompareAndSwap { let state = self.state.lock().unwrap().clone(); match state { CasState::Versioned(version) => { - update_item = update_item + update = update .condition_expression("#VER = :ver") .expression_attribute_values(":ver", AttributeValue::N(version)); } CasState::Unversioned(old_val) => { - update_item = update_item + update = update .condition_expression("#VAL = :old_val") .expression_attribute_values(":old_val", AttributeValue::B(old_val)); } CasState::Unset => { - update_item = update_item.condition_expression("attribute_not_exists (#VAL)"); + update = update.condition_expression("attribute_not_exists (#VAL)"); } CasState::Unknown => (), }; - update_item + self.client + .transact_write_items() + .transact_items( + TransactWriteItem::builder() + .update( + update + .build() + .map_err(|e| SwapError::Other(format!("{e:?}")))?, + ) + .build(), + ) .send() .await .map_err(|e| SwapError::CasFailed(format!("{e:?}")))?; @@ -499,30 +565,3 @@ impl Cas for CompareAndSwap { self.key.clone() } } - -impl AwsDynamoStore { - async fn get_keys(&self) -> Result, Error> { - let mut primary_keys = Vec::new(); - - let mut scan_paginator = self - .client - .scan() - .table_name(self.table.as_str()) - .projection_expression(PK) - .into_paginator() - .send(); - - while let Some(output) = scan_paginator.next().await { - let scan_output = output.map_err(log_error)?; - if let Some(items) = scan_output.items { - for mut item in items { - if let Some(AttributeValue::S(pk)) = item.remove(PK) { - primary_keys.push(pk); - } - } - } - } - - Ok(primary_keys) - } -} diff --git a/crates/key-value-spin/src/store.rs b/crates/key-value-spin/src/store.rs index f18b60f7b7..7c3e50101d 100644 --- a/crates/key-value-spin/src/store.rs +++ b/crates/key-value-spin/src/store.rs @@ -307,20 +307,35 @@ impl Cas for CompareAndSwap { async fn swap(&self, value: Vec) -> Result<(), SwapError> { task::block_in_place(|| { let old_value = self.value.lock().unwrap(); - let rows_changed = self.connection - .lock() - .unwrap() - .prepare_cached( - "UPDATE spin_key_value SET value=:new_value WHERE store=:name and key=:key and value=:old_value", - ) - .map_err(log_cas_error)? - .execute(named_params! { - ":name": &self.name, - ":key": self.key, - ":old_value": old_value.clone().unwrap(), - ":new_value": value, - }) - .map_err(log_cas_error)?; + let mut conn = self.connection.lock().unwrap(); + let rows_changed = match old_value.clone() { + Some(old_val) => { + conn + .prepare_cached( + "UPDATE spin_key_value SET value=:new_value WHERE store=:name and key=:key and value=:old_value") + .map_err(log_cas_error)? + .execute(named_params! { + ":name": &self.name, + ":key": self.key, + ":old_value": old_val, + ":new_value": value, + }) + .map_err(log_cas_error)? + } + None => { + let tx = conn.transaction().map_err(log_cas_error)?; + let rows = tx + .prepare_cached( + "INSERT INTO spin_key_value (store, key, value) VALUES ($1, $2, $3) + ON CONFLICT(store, key) DO UPDATE SET value=$3", + ) + .map_err(log_cas_error)? + .execute(rusqlite::params![&self.name, self.key, value]) + .map_err(log_cas_error)?; + tx.commit().map_err(log_cas_error)?; + rows + } + }; // We expect only 1 row to be updated. If 0, we know that the underlying value has changed. if rows_changed == 1 {