Skip to content

Commit

Permalink
feat: IncrementBy update fn
Browse files Browse the repository at this point in the history
  • Loading branch information
hopeyen committed Oct 9, 2024
1 parent b807185 commit b5c2386
Showing 1 changed file with 8 additions and 31 deletions.
39 changes: 8 additions & 31 deletions common/aws/dynamodb/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,12 @@ func (c *Client) PutItems(ctx context.Context, tableName string, items []Item) (
}

func (c *Client) UpdateItem(ctx context.Context, tableName string, key Key, item Item) (Item, error) {
err := ensureKeyAttributes(key, item)
if err != nil {
return nil, err
}

update := expression.UpdateBuilder{}
for itemKey, itemValue := range item {
// Ignore primary key updates
if _, ok := key[itemKey]; ok {
continue
}
update = update.Set(expression.Name(itemKey), expression.Value(itemValue))
}

Expand All @@ -158,25 +157,15 @@ func (c *Client) UpdateItem(ctx context.Context, tableName string, key Key, item
return resp.Attributes, err
}

func (c *Client) UpdateItemIncrement(ctx context.Context, tableName string, key Key, item Item) (Item, error) {
err := ensureKeyAttributes(key, item)
func (c *Client) IncrementBy(ctx context.Context, tableName string, key Key, itemKey string, itemValue uint64) (Item, error) {
// ADD numeric values
f, err := strconv.ParseFloat(strconv.FormatUint(itemValue, 10), 64)
if err != nil {
return nil, err
}

update := expression.UpdateBuilder{}
for itemKey, itemValue := range item {
// ADD numeric values
if n, ok := itemValue.(*types.AttributeValueMemberN); ok {
f, _ := strconv.ParseFloat(n.Value, 64)
update = update.Add(expression.Name(itemKey), expression.Value(aws.Float64(f)))

} else {
// For non-numeric values, use SET as before
update = update.Set(expression.Name(itemKey), expression.Value(itemValue))
}
}

update = update.Add(expression.Name(itemKey), expression.Value(aws.Float64(f)))
expr, err := expression.NewBuilder().WithUpdate(update).Build()
if err != nil {
return nil, err
Expand All @@ -191,7 +180,6 @@ func (c *Client) UpdateItemIncrement(ctx context.Context, tableName string, key
ReturnValues: types.ReturnValueUpdatedNew,
})
if err != nil {
fmt.Println("error updating item", err)
return nil, err
}

Expand Down Expand Up @@ -418,14 +406,3 @@ func (c *Client) readItems(ctx context.Context, tableName string, keys []Key) ([

return items, nil
}

func ensureKeyAttributes(key Key, item Item) error {
for itemKey := range item {
if _, ok := key[itemKey]; ok {
// Cannot update the key
return fmt.Errorf("cannot update key %s", itemKey)
}
}

return nil
}

0 comments on commit b5c2386

Please sign in to comment.