Skip to content

Commit

Permalink
fix(utils): update with unsafe address option
Browse files Browse the repository at this point in the history
  • Loading branch information
atticusofsparta committed Nov 22, 2024
1 parent 58f069f commit 8833fe6
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 21 deletions.
20 changes: 20 additions & 0 deletions \
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
fix(utils): update id validation utils

# Please enter the commit message for your changes. Lines starting
# with '#' will be ignored, and an empty message aborts the commit.
#
# Date: Thu Nov 21 09:40:06 2024 -0600
#
# On branch PE-7163-update-validate-utils
# Your branch is up to date with 'origin/PE-7163-update-validate-utils'.
#
# Changes to be committed:
# modified: CHANGELOG.md
# modified: spec/utils_spec.lua
# modified: src/common/balances.lua
# modified: src/common/controllers.lua
# modified: src/common/initialize.lua
# modified: src/common/main.lua
# modified: src/common/records.lua
# modified: src/common/utils.lua
#
6 changes: 3 additions & 3 deletions spec/utils_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,17 @@ end)

describe("utils.isValidAOAddress", function()
it("should throw an error for invalid Arweave IDs", function()
local invalid = utils.isValidAOAddress("invalid-arweave-id-123")
local invalid = utils.isValidAOAddress("invalid-arweave-id-123", false)
assert.is_false(invalid)
end)

it("should not throw an error for a valid Arweave ID", function()
local valid = pcall(utils.isValidAOAddress, "0E7Ai_rEQ326_vLtgB81XHViFsLlcwQNqlT9ap24uQI")
local valid = pcall(utils.isValidAOAddress, "0E7Ai_rEQ326_vLtgB81XHViFsLlcwQNqlT9ap24uQI", false)
assert.is_true(valid)
end)

it("should validate eth address", function()
assert.is_true(utils.isValidAOAddress(testEthAddress))
assert.is_true(utils.isValidAOAddress(testEthAddress, false))
end)
end)

Expand Down
10 changes: 6 additions & 4 deletions src/common/balances.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ local balances = {}

--- Transfers the ANT to a specified wallet.
---@param to string - The wallet address to transfer the balance to.
---@param allowUnsafeAddresses boolean Whether to allow unsafe addresses
---@return table<string, integer>
function balances.transfer(to)
assert(utils.isValidAOAddress(to), "Invalid AO Address")
function balances.transfer(to, allowUnsafeAddresses)
assert(utils.isValidAOAddress(to, allowUnsafeAddresses), "Invalid AO Address")
Balances = { [to] = 1 }
--luacheck: ignore Owner Controllers
Owner = to
Expand All @@ -20,9 +21,10 @@ end

--- Retrieves the balance of a specified wallet.
---@param address string - The wallet address to retrieve the balance from.
---@param allowUnsafeAddresses boolean Whether to allow unsafe addresses
---@return integer - Returns the balance of the specified wallet.
function balances.balance(address)
assert(utils.isValidAOAddress(address), "Invalid AO Address")
function balances.balance(address, allowUnsafeAddresses)
assert(utils.isValidAOAddress(address, allowUnsafeAddresses), "Invalid AO Address")
local balance = Balances[address] or 0
return balance
end
Expand Down
10 changes: 6 additions & 4 deletions src/common/controllers.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ local controllers = {}

--- Set a controller.
---@param controller string The controller to set.
---@param allowUnsafeAddresses boolean Whether to allow unsafe addresses
---@return string[]
function controllers.setController(controller)
assert(utils.isValidAOAddress(controller), "Invalid AO Address")
function controllers.setController(controller, allowUnsafeAddresses)
assert(utils.isValidAOAddress(controller, allowUnsafeAddresses), "Invalid AO Address")

for _, c in ipairs(Controllers) do
assert(c ~= controller, "Controller already exists")
Expand All @@ -18,9 +19,10 @@ end

--- Remove a controller.
---@param controller string The controller to remove.
---@param allowUnsafeAddresses boolean Whether to allow unsafe addresses
---@return string[]
function controllers.removeController(controller)
assert(utils.isValidAOAddress(controller), "Invalid AO Address")
function controllers.removeController(controller, allowUnsafeAddresses)
assert(utils.isValidAOAddress(controller, allowUnsafeAddresses), "Invalid AO Address")
local controllerExists = false

for i, v in ipairs(Controllers) do
Expand Down
11 changes: 6 additions & 5 deletions src/common/main.lua
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,15 @@ function ant.init()
createActionHandler(TokenSpecActionMap.Transfer, function(msg)
local recipient = msg.Tags.Recipient
utils.validateOwner(msg.From)
balances.transfer(recipient)
balances.transfer(recipient, msg.Tags["Allow-Unsafe-Addresses"])
if not msg.Cast then
ao.send(notices.debit(msg))
ao.send(notices.credit(msg))
end
end)

createActionHandler(TokenSpecActionMap.Balance, function(msg)
local balRes = balances.balance(msg.Tags.Recipient or msg.From)
local balRes = balances.balance(msg.Tags.Recipient or msg.From, msg.Tags["Allow-Unsafe-Addresses"])

ao.send({
Target = msg.From,
Expand Down Expand Up @@ -148,12 +148,12 @@ function ant.init()

createActionHandler(ActionMap.AddController, function(msg)
utils.assertHasPermission(msg.From)
return controllers.setController(msg.Tags.Controller)
return controllers.setController(msg.Tags.Controller, msg.Tags["Allow-Unsafe-Addresses"])
end)

createActionHandler(ActionMap.RemoveController, function(msg)
utils.assertHasPermission(msg.From)
return controllers.removeController(msg.Tags.Controller)
return controllers.removeController(msg.Tags.Controller, msg.Tags["Allow-Unsafe-Addresses"])
end)

createActionHandler(ActionMap.Controllers, function()
Expand Down Expand Up @@ -276,8 +276,9 @@ function ant.init()
createActionHandler(ActionMap.ApproveName, function(msg)
--- NOTE: this could be modified to allow specific users/controllers to create claims
utils.validateOwner(msg.From)

assert(utils.isValidArweaveAddress(msg.Tags["IO-Process-Id"]), "Invalid Arweave ID")
assert(utils.isValidAOAddress(msg.Tags.Recipient), "Invalid AO Address")
assert(utils.isValidAOAddress(msg.Tags.Recipient, msg.Tags["Allow-Unsafe-Addresses"]), "Invalid AO Address")

assert(msg.Tags.Name, "Name is required")

Expand Down
75 changes: 70 additions & 5 deletions src/common/utils.lua
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,63 @@ function utils.isValidEthAddress(address)
return type(address) == "string" and #address == 42 and string.match(address, "^0x[%x]+$") ~= nil
end

function utils.isValidUnsafeAddress(address)
if not address then
return false
end
local match = string.match(address, "^[%w_-]+$")
return match ~= nil and #address >= 1 and #address <= 128
end

--- Checks if an address is a valid AO address
--- @param url string|nil The address to check
--- @return boolean isValidAOAddress - whether the address is a valid AO address
function utils.isValidAOAddress(url)
return url and (utils.isValidArweaveAddress(url) or utils.isValidEthAddress(url)) or false
--- @param address string|nil The address to check
--- @param allowUnsafe boolean Whether to allow unsafe addresses, defaults to false
--- @return boolean isValidAddress - whether the address is valid, depending on the allowUnsafe flag
function utils.isValidAOAddress(address, allowUnsafe)
allowUnsafe = allowUnsafe or false -- default to false, only allow unsafe addresses if explicitly set
if not address then
return false
end
if allowUnsafe then
return utils.isValidUnsafeAddress(address)
end
return utils.isValidArweaveAddress(address) or utils.isValidEthAddress(address)
end

--- Converts an address to EIP-55 checksum format
--- Assumes address has been validated as a valid Ethereum address (see utils.isValidEthAddress)
--- Reference: https://eips.ethereum.org/EIPS/eip-55
--- @param address string The address to convert
--- @return string formattedAddress - the EIP-55 checksum formatted address
function utils.formatEIP55Address(address)
local hex = string.lower(string.sub(address, 3))

local hash = crypto.digest.keccak256(hex)
local hashHex = hash.asHex()

local checksumAddress = "0x"

for i = 1, #hashHex do
local hexChar = string.sub(hashHex, i, i)
local hexCharValue = tonumber(hexChar, 16)
local char = string.sub(hex, i, i)
if hexCharValue > 7 then
char = string.upper(char)
end
checksumAddress = checksumAddress .. char
end

return checksumAddress
end

--- Formats an address to EIP-55 checksum format if it is a valid Ethereum address
--- @param address string The address to format
--- @return string formattedAddress - the EIP-55 checksum formatted address
function utils.formatAddress(address)
if utils.isValidEthAddress(address) then
return utils.formatEIP55Address(address)
end
return address
end

---@param ttl integer
Expand Down Expand Up @@ -215,7 +267,20 @@ function utils.createHandler(tagName, tagValue, handler, position)
utils.camelCase(tagValue),
Handlers.utils.continue(Handlers.utils.hasMatchingTag(tagName, tagValue)),
function(msg)
-- sometimes the message id is not present on dryrun
-- handling for eth EIP-55 format, returns address if is not eth address
msg.From = utils.formatAddress(msg.From)
local knownAddressTags = {
"Recipient",
"Controller",
}
for _, tName in ipairs(knownAddressTags) do
-- Format all incoming addresses
msg.Tags[tName] = msg.Tags[tName] and utils.formatAddress(msg.Tags[tName]) or nil
-- aos assigns tag values to the base message level as well
msg[tName] = msg[tName] and utils.formatAddress(msg[tName]) or nil
end

-- sometimes the message id is not present on dryrun so we add a stub string to prevent issues with concat string
print("Handling Action [" .. msg.Id or "no-msg-id" .. "]: " .. tagValue)
local prevOwner = tostring(Owner)
local prevControllers = utils.deepCopy(Controllers)
Expand Down
2 changes: 2 additions & 0 deletions test/registry.test.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ describe('Registry Updates', async () => {
From: STUB_ADDRESS,
});

console.dir(result, { depth: null });

const message = result.Messages[1]?.Tags.find(
(tag) => tag.name === 'Action' && tag.value === 'Credit-Notice',
);
Expand Down

0 comments on commit 8833fe6

Please sign in to comment.