From 5d75af2b1f81d6cf0581d45b6fe00cc6df1f9ea1 Mon Sep 17 00:00:00 2001 From: Malte Kliemann Date: Sun, 7 Apr 2024 20:21:50 +0200 Subject: [PATCH] Clean up type conversions --- zrml/neo-swaps/src/lib.rs | 60 +++++++++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 22 deletions(-) diff --git a/zrml/neo-swaps/src/lib.rs b/zrml/neo-swaps/src/lib.rs index 87a16877a..2c94b0f48 100644 --- a/zrml/neo-swaps/src/lib.rs +++ b/zrml/neo-swaps/src/lib.rs @@ -267,6 +267,8 @@ mod pallet { LiquidityTreeError(LiquidityTreeError), /// The relative value of a new LP position is too low. MinRelativeLiquidityThresholdViolated, + /// Narrowing type conversion occurred. + NarrowingConversion, } #[derive(Decode, Encode, Eq, PartialEq, PalletError, RuntimeDebug, TypeInfo)] @@ -311,7 +313,7 @@ mod pallet { /// Depends on the implementation of `CompleteSetOperationsApi` and `ExternalFees`; when /// using the canonical implementations, the runtime complexity is `O(asset_count)`. #[pallet::call_index(0)] - #[pallet::weight(T::WeightInfo::buy(*asset_count as u32))] + #[pallet::weight(T::WeightInfo::buy((*asset_count).saturated_into()))] #[transactional] pub fn buy( origin: OriginFor, @@ -325,7 +327,7 @@ mod pallet { let asset_count_real = T::MarketCommons::market(&market_id)?.outcomes(); ensure!(asset_count == asset_count_real, Error::::IncorrectAssetCount); Self::do_buy(who, market_id, asset_out, amount_in, min_amount_out)?; - Ok(Some(T::WeightInfo::buy(asset_count as u32)).into()) + Ok(Some(T::WeightInfo::buy(asset_count.into())).into()) } /// Sell outcome tokens to the specified market. @@ -355,7 +357,7 @@ mod pallet { /// Depends on the implementation of `CompleteSetOperationsApi` and `ExternalFees`; when /// using the canonical implementations, the runtime complexity is `O(asset_count)`. #[pallet::call_index(1)] - #[pallet::weight(T::WeightInfo::sell(*asset_count as u32))] + #[pallet::weight(T::WeightInfo::sell((*asset_count).saturated_into()))] #[transactional] pub fn sell( origin: OriginFor, @@ -369,7 +371,7 @@ mod pallet { let asset_count_real = T::MarketCommons::market(&market_id)?.outcomes(); ensure!(asset_count == asset_count_real, Error::::IncorrectAssetCount); Self::do_sell(who, market_id, asset_in, amount_in, min_amount_out)?; - Ok(Some(T::WeightInfo::sell(asset_count as u32)).into()) + Ok(Some(T::WeightInfo::sell(asset_count.into())).into()) } /// Join the liquidity pool for the specified market. @@ -396,9 +398,9 @@ mod pallet { /// providers in the pool. #[pallet::call_index(2)] #[pallet::weight( - T::WeightInfo::join_in_place(max_amounts_in.len() as u32) - .max(T::WeightInfo::join_reassigned(max_amounts_in.len() as u32)) - .max(T::WeightInfo::join_leaf(max_amounts_in.len() as u32)) + T::WeightInfo::join_in_place(max_amounts_in.len().saturated_into()) + .max(T::WeightInfo::join_reassigned(max_amounts_in.len().saturated_into())) + .max(T::WeightInfo::join_leaf(max_amounts_in.len().saturated_into())) )] #[transactional] pub fn join( @@ -409,7 +411,11 @@ mod pallet { ) -> DispatchResultWithPostInfo { let who = ensure_signed(origin)?; let asset_count = T::MarketCommons::market(&market_id)?.outcomes(); - ensure!(max_amounts_in.len() == asset_count as usize, Error::::IncorrectVecLen); + let asset_count_usize: usize = asset_count.into(); + // Ensure that the conversion in the weight calculation doesn't saturate. + let _: u32 = + max_amounts_in.len().try_into().map_err(|_| Error::::NarrowingConversion)?; + ensure!(max_amounts_in.len() == asset_count_usize, Error::::IncorrectVecLen); Self::do_join(who, market_id, pool_shares_amount, max_amounts_in) } @@ -446,7 +452,7 @@ mod pallet { /// pool's liquidity tree, or, equivalently, `log_2(m)` where `m` is the number of liquidity /// providers in the pool. #[pallet::call_index(3)] - #[pallet::weight(T::WeightInfo::exit(min_amounts_out.len() as u32))] + #[pallet::weight(T::WeightInfo::exit(min_amounts_out.len().saturated_into()))] #[transactional] pub fn exit( origin: OriginFor, @@ -456,9 +462,12 @@ mod pallet { ) -> DispatchResultWithPostInfo { let who = ensure_signed(origin)?; let asset_count = T::MarketCommons::market(&market_id)?.outcomes(); - ensure!(min_amounts_out.len() == asset_count as usize, Error::::IncorrectVecLen); + let asset_count_u32: u32 = asset_count.into(); + let min_amounts_out_len: u32 = + min_amounts_out.len().try_into().map_err(|_| Error::::NarrowingConversion)?; + ensure!(min_amounts_out_len == asset_count_u32, Error::::IncorrectVecLen); Self::do_exit(who, market_id, pool_shares_amount_out, min_amounts_out)?; - Ok(Some(T::WeightInfo::exit(asset_count as u32)).into()) + Ok(Some(T::WeightInfo::exit(min_amounts_out_len)).into()) } /// Withdraw swap fees from the specified market. @@ -510,7 +519,7 @@ mod pallet { /// /// `O(n)` where `n` is the number of assets in the pool. #[pallet::call_index(5)] - #[pallet::weight(T::WeightInfo::deploy_pool(spot_prices.len() as u32))] + #[pallet::weight(T::WeightInfo::deploy_pool(spot_prices.len().saturated_into()))] #[transactional] pub fn deploy_pool( origin: OriginFor, @@ -521,9 +530,12 @@ mod pallet { ) -> DispatchResultWithPostInfo { let who = ensure_signed(origin)?; let asset_count = T::MarketCommons::market(&market_id)?.outcomes(); - ensure!(spot_prices.len() == asset_count as usize, Error::::IncorrectVecLen); + let asset_count_u32: u32 = asset_count.into(); + let spot_prices_len: u32 = + spot_prices.len().try_into().map_err(|_| Error::::NarrowingConversion)?; + ensure!(spot_prices_len == asset_count_u32, Error::::IncorrectVecLen); Self::do_deploy_pool(who, market_id, amount, spot_prices, swap_fee)?; - Ok(Some(T::WeightInfo::deploy_pool(asset_count as u32)).into()) + Ok(Some(T::WeightInfo::deploy_pool(spot_prices_len)).into()) } } @@ -672,8 +684,10 @@ mod pallet { ensure!(pool_shares_amount != Zero::zero(), Error::::ZeroAmount); let market = T::MarketCommons::market(&market_id)?; ensure!(market.status == MarketStatus::Active, Error::::MarketNotActive); - let asset_count = max_amounts_in.len() as u32; - ensure!(asset_count == market.outcomes() as u32, Error::::IncorrectAssetCount); + let asset_count_u16: u16 = + max_amounts_in.len().try_into().map_err(|_| Error::::NarrowingConversion)?; + let asset_count_u32: u32 = asset_count_u16.into(); + ensure!(asset_count_u16 == market.outcomes(), Error::::IncorrectAssetCount); let benchmark_info = Self::try_mutate_pool(&market_id, |pool| { let ratio = pool_shares_amount.bdiv_ceil(pool.liquidity_shares_manager.total_shares()?)?; @@ -712,9 +726,9 @@ mod pallet { Ok(benchmark_info) })?; let weight = match benchmark_info { - BenchmarkInfo::InPlace => T::WeightInfo::join_in_place(asset_count), - BenchmarkInfo::Reassigned => T::WeightInfo::join_reassigned(asset_count), - BenchmarkInfo::Leaf => T::WeightInfo::join_leaf(asset_count), + BenchmarkInfo::InPlace => T::WeightInfo::join_in_place(asset_count_u32), + BenchmarkInfo::Reassigned => T::WeightInfo::join_reassigned(asset_count_u32), + BenchmarkInfo::Leaf => T::WeightInfo::join_leaf(asset_count_u32), }; Ok((Some(weight)).into()) } @@ -832,9 +846,11 @@ mod pallet { let market = T::MarketCommons::market(&market_id)?; ensure!(market.status == MarketStatus::Active, Error::::MarketNotActive); ensure!(market.scoring_rule == ScoringRule::Lmsr, Error::::InvalidTradingMechanism); - let asset_count = spot_prices.len(); - ensure!(asset_count as u16 == market.outcomes(), Error::::IncorrectVecLen); - ensure!(market.outcomes() as u32 <= MaxAssets::get(), Error::::AssetCountAboveMax); + let asset_count_u16: u16 = + spot_prices.len().try_into().map_err(|_| Error::::NarrowingConversion)?; + let asset_count_u32: u32 = asset_count_u16.into(); + ensure!(asset_count_u16 == market.outcomes(), Error::::IncorrectVecLen); + ensure!(asset_count_u32 <= MaxAssets::get(), Error::::AssetCountAboveMax); ensure!(swap_fee >= MIN_SWAP_FEE.saturated_into(), Error::::SwapFeeBelowMin); ensure!(swap_fee <= T::MaxSwapFee::get(), Error::::SwapFeeAboveMax); ensure!(