From 1bc0443088351968582061200683eb14a6a7b136 Mon Sep 17 00:00:00 2001 From: Carlos Wu Fei Date: Sat, 21 Dec 2024 14:24:44 +0100 Subject: [PATCH] Fix all SQLModel vs Pydantic model compatibility issues --- api/bots/models.py | 107 +++++++++++- api/bots/routes.py | 120 ++++++++----- api/database/api_db.py | 78 ++++----- api/database/bot_crud.py | 75 ++++---- api/database/models/__init__.py | 1 - api/database/models/autotrade_table.py | 52 +++++- api/database/models/bot_table.py | 120 +++++++++++-- api/database/models/deal_table.py | 11 +- api/database/models/order_table.py | 33 ++-- api/database/models/paper_trading_table.py | 29 ---- api/database/paper_trading_crud.py | 5 +- api/deals/controllers.py | 6 +- api/deals/margin.py | 11 +- api/deals/models.py | 19 +- api/deals/spot.py | 6 +- api/paper_trading/routes.py | 2 +- api/streaming/streaming_controller.py | 3 +- api/tests/conftest.py | 8 + api/tests/model_mocks.py | 170 ++++++++++++++++++ api/tests/table_mocks.py | 99 +++++++++++ api/tests/test_bots.py | 193 +++++++++------------ api/tools/handle_error.py | 12 +- binquant | 2 +- 23 files changed, 821 insertions(+), 341 deletions(-) delete mode 100644 api/database/models/paper_trading_table.py create mode 100644 api/tests/model_mocks.py create mode 100644 api/tests/table_mocks.py diff --git a/api/bots/models.py b/api/bots/models.py index 87fe6791e..e30d615cb 100644 --- a/api/bots/models.py +++ b/api/bots/models.py @@ -1,13 +1,92 @@ -from typing import List -from database.models.order_table import OrderModel +from typing import List, Optional +from uuid import uuid4, UUID +from tools.enum_definitions import ( + BinanceKlineIntervals, + CloseConditions, + Status, + Strategy, +) from deals.models import DealModel -from database.models.bot_table import BotBase, BotTable -from pydantic import BaseModel, Field, field_validator +from pydantic import ( + BaseModel, + Field, + Json, + field_validator, +) +from database.utils import timestamp +from tools.handle_error import StandardResponse, IResponseBase +from tools.enum_definitions import DealType, OrderType -from tools.handle_error import StandardResponse + +class OrderModel(BaseModel): + order_type: OrderType + time_in_force: str + timestamp: Optional[int] + order_id: int + order_side: str + pair: str + qty: float + status: str + price: float + deal_type: DealType + + +class BotBase(BaseModel): + id: Optional[UUID] = Field(default_factory=uuid4) + pair: str + fiat: str = Field(default="USDC") + base_order_size: float = Field( + default=15, description="Min Binance 0.0001 BNB approx 15USD" + ) + candlestick_interval: BinanceKlineIntervals = Field( + default=BinanceKlineIntervals.fifteen_minutes, + ) + close_condition: CloseConditions = Field( + default=CloseConditions.dynamic_trailling, + ) + cooldown: int = Field( + default=0, + description="cooldown period in minutes before opening next bot with same pair", + ) + created_at: float = Field(default_factory=timestamp) + updated_at: float = Field(default_factory=timestamp) + dynamic_trailling: bool = Field(default=False) + logs: list[Json[str]] = Field(default=[]) + mode: str = Field(default="manual") + name: str = Field(default="Default bot") + status: Status = Field(default=Status.inactive) + stop_loss: float = Field( + default=0, description="If stop_loss > 0, allow for reversal" + ) + margin_short_reversal: bool = Field(default=False) + take_profit: float = Field(default=0) + trailling: bool = Field(default=False) + trailling_deviation: float = Field( + default=0, + ge=-1, + le=101, + description="Trailling activation (first take profit hit)", + ) + trailling_profit: float = Field(default=0) + strategy: Strategy = Field(default=Strategy.long) + total_commission: float = Field( + default=0, description="autoswitch to short_strategy" + ) + + @field_validator("id") + def deserialize_id(cls, v): + if isinstance(v, UUID): + return str(v) + return True class BotModel(BotBase): + """ + The way SQLModel works causes a lot of errors + if we combine (with inheritance) both Pydantic models + and SQLModels. they are not compatible. Thus the duplication + """ + deal: DealModel = Field(default_factory=DealModel) orders: List[OrderModel] = Field(default=[]) @@ -45,11 +124,23 @@ class BotModel(BotBase): class BotResponse(StandardResponse): - data: BotModel + data: Optional[BotModel] + + +class ActivePairsResponse(IResponseBase): + data: list[str] + + +class BotListResponse(IResponseBase): + """ + Model exclusively used to serialize + list of bots. + Has to be converted to BotModel to be able to + serialize nested table objects (deal, orders) + """ -class BotListResponse(StandardResponse): - data: list[BotTable] + data: list[BotModel] class ErrorsRequestBody(BaseModel): diff --git a/api/bots/routes.py b/api/bots/routes.py index 84ca71597..5c47fa410 100644 --- a/api/bots/routes.py +++ b/api/bots/routes.py @@ -1,19 +1,22 @@ from fastapi import APIRouter, Depends -from pydantic import ValidationError +from pydantic import ValidationError, TypeAdapter from sqlmodel import Session from tools.enum_definitions import Status from database.bot_crud import BotTableCrud from deals.controllers import CreateDealController -from database.models.bot_table import BotBase, BotTable from database.utils import get_session -from tools.handle_error import ( - api_response, +from bots.models import ( + BotModel, + BotResponse, + ErrorsRequestBody, + BotBase, + BotListResponse, + IResponseBase, + ActivePairsResponse, ) -from bots.models import BotModel, BotResponse, BotListResponse, ErrorsRequestBody from typing import List from tools.exceptions import BinanceErrors, BinbotErrors - bot_blueprint = APIRouter() @@ -31,32 +34,50 @@ def get( bots = BotTableCrud(session=session).get( status, start_date, end_date, no_cooldown, limit, offset ) - return api_response(detail="Bots found", data=bots) + # Has to be converted to BotModel to + # be able to serialize nested objects + ta = TypeAdapter(List[BotModel]) + data = ta.dump_python(bots) + return BotListResponse[List](message="Successfully found bots!", data=data) except ValidationError as error: - return api_response(detail=error.json(), error=1) + return BotResponse(message="Failed to find bots!", data=error.json(), error=1) -@bot_blueprint.get("/bot/active-pairs", tags=["bots"]) +@bot_blueprint.get( + "/bot/active-pairs", response_model=ActivePairsResponse, tags=["bots"] +) def get_active_pairs( session: Session = Depends(get_session), ): try: bot = BotTableCrud(session=session).get_active_pairs() - return api_response(detail="Active pairs found!", data=bot) - except ValueError as error: - return api_response(detail=f"Error retrieving active pairs: {error}", error=1) + if not bot: + return BotResponse(message="Bot not found.", error=1) + else: + ta = TypeAdapter(BotModel) + data = ta.dump_python(bot) + return ActivePairsResponse( + message="Successfully retrieved active pairs.", data=data + ) + + except ValidationError as error: + return BotResponse( + data=error.json(), error=1, message="Failed to find active pairs." + ) -@bot_blueprint.get("/bot/{id}", tags=["bots"]) +@bot_blueprint.get("/bot/{id}", response_model=BotResponse, tags=["bots"]) def get_one_by_id(id: str, session: Session = Depends(get_session)): try: bot = BotTableCrud(session=session).get_one(bot_id=id) if not bot: - return api_response(detail="Bot not found.", error=1) + return BotResponse(message="Bot not found.", error=1) else: - return api_response(detail="Bot found", data=bot) + ta = TypeAdapter(BotModel) + data = ta.dump_python(bot) + return BotResponse(message="Successfully found one bot.", data=data) except ValidationError as error: - return api_response(error.json()) + return BotResponse(message="Bot not found.", error=1, data=error.json()) @bot_blueprint.get("/bot/symbol/{symbol}", tags=["bots"]) @@ -64,11 +85,13 @@ def get_one_by_symbol(symbol: str, session: Session = Depends(get_session)): try: bot = BotTableCrud(session=session).get_one(bot_id=None, symbol=symbol) if not bot: - return api_response(detail="Bot not found.", error=1) + return BotResponse(message="Bot not found.", error=1) else: - return api_response(detail="Bot found", data=bot) + ta = TypeAdapter(BotModel) + data = ta.dump_python(bot) + return BotResponse(message="Successfully found one bot.", data=data) except ValidationError as error: - return api_response(error.json()) + return BotResponse(message="Bot not found.", error=1, data=error.json()) @bot_blueprint.post("/bot", tags=["bots"], response_model=BotResponse) @@ -77,10 +100,14 @@ def create( session: Session = Depends(get_session), ): try: - data = BotTableCrud(session=session).create(bot_item) - return api_response(detail="Bot created", data=data) + bot = BotTableCrud(session=session).create(bot_item) + ta = TypeAdapter(BotModel) + data = ta.dump_python(bot) + return BotResponse(message="Successfully created one bot.", data=data) except ValidationError as error: - return api_response(detail=error.json(), error=1) + return BotResponse( + message="Failed to create new bot", data=error.json(), error=1 + ) @bot_blueprint.put("/bot/{id}", tags=["bots"]) @@ -90,10 +117,13 @@ def edit( session: Session = Depends(get_session), ): try: + bot_item.id = id bot = BotTableCrud(session=session).save(bot_item) - return api_response(detail="Bot updated", data=bot) + ta = TypeAdapter(BotModel) + data = ta.dump_python(bot) + return BotResponse(message="Sucessfully edited bot", data=data) except ValidationError as error: - return api_response(detail=error.json(), error=1) + return BotResponse(message="Failed to edit bot", data=error.json(), error=1) @bot_blueprint.delete("/bot", tags=["bots"]) @@ -106,13 +136,13 @@ def delete( """ try: BotTableCrud(session=session).delete(id) - return api_response(detail="Bots deleted successfully.") + return IResponseBase(message="Sucessfully deleted bot.") except ValidationError as error: - return api_response(error.json()) + return BotResponse(message="Failed to delete bot", data=error.json(), error=1) @bot_blueprint.get("/bot/activate/{id}", tags=["bots"]) -async def activate_by_id(id: str, session: Session = Depends(get_session)): +def activate_by_id(id: str, session: Session = Depends(get_session)): """ Activate bot @@ -122,41 +152,42 @@ async def activate_by_id(id: str, session: Session = Depends(get_session)): """ bot = BotTableCrud(session=session).get_one(bot_id=id) if not bot: - return api_response(detail="Bot not found.") + return BotResponse(message="Bot not found.") - bot_instance = CreateDealController(bot) + bot_model = BotModel.model_construct(**bot.model_dump()) + bot_instance = CreateDealController(bot_model) try: data = bot_instance.open_deal() - return api_response(detail="Successfully activated bot!", data=data) + return BotResponse(message="Successfully activated bot!", data=data) except BinbotErrors as error: bot_instance.controller.update_logs(bot_id=id, log_message=error.message) - return api_response(detail=error.message, error=1) + return BotResponse(message=error.message, error=1) except BinanceErrors as error: bot_instance.controller.update_logs(bot_id=id, log_message=error.message) - return api_response(detail=error.message, error=1) + return BotResponse(message=error.message, error=1) -@bot_blueprint.delete("/bot/deactivate/{id}", tags=["bots"]) +@bot_blueprint.delete("/bot/deactivate/{id}", response_model=BotResponse, tags=["bots"]) def deactivation(id: str, session: Session = Depends(get_session)): """ Deactivation means closing all deals and selling to fiat. This is often used to prevent losses """ - bot_model = BotTableCrud(session=session).get_one(bot_id=id) - if not bot_model: - return api_response(detail="No active bot found.") + bot_table = BotTableCrud(session=session).get_one(bot_id=id) + if not bot_table: + return BotResponse(message="No active bot found.") + bot_model = BotModel.model_construct(**bot_table.model_dump()) deal_instance = CreateDealController(bot_model) try: - deal_instance.close_all() - return api_response(detail="Active orders closed, sold base asset, deactivated") + data = deal_instance.close_all() + return BotResponse(message="Active orders closed, sold base asset, deactivated", data=data) except BinbotErrors as error: - deal_instance.controller.update_logs(bot_id=id, log_message=error.message) - return api_response(error.message) + return BotResponse(message=error.message, error=1) -@bot_blueprint.post("/bot/errors/{bot_id}", tags=["bots"]) +@bot_blueprint.post("/bot/errors/{bot_id}", response_model=BotResponse, tags=["bots"]) def bot_errors( bot_id: str, bot_errors: ErrorsRequestBody, session: Session = Depends(get_session) ): @@ -172,6 +203,7 @@ def bot_errors( bot = BotTableCrud(session=session).update_logs( log_message=errors, bot_id=bot_id ) - return api_response(detail="Errors posted successfully.", data=bot) - except Exception as error: - return api_response(f"Error posting errors: {error}", error=1) + data = BotModel.model_construct(**bot.model_dump()) + return BotResponse(message="Errors posted successfully.", data=data) + except ValidationError as error: + return BotResponse(message="Failed to post errors", data=error.json(), error=1) diff --git a/api/database/api_db.py b/api/database/api_db.py index d5b022d8d..6d981d6e5 100644 --- a/api/database/api_db.py +++ b/api/database/api_db.py @@ -4,9 +4,8 @@ from database.models.autotrade_table import AutotradeTable, TestAutotradeTable from database.models.deal_table import DealTable from database.models.order_table import ExchangeOrderTable -from database.models.paper_trading_table import PaperTradingTable from database.models.user_table import UserTable -from database.models.bot_table import BotTable +from database.models.bot_table import BotTable, PaperTradingTable from sqlmodel import Session, SQLModel, select from tools.enum_definitions import ( AutotradeSettingsDocument, @@ -119,36 +118,33 @@ def create_dummy_bot(self): results = self.session.exec(statement) if results.first(): return - orders = [ - ExchangeOrderTable( - id=1, - order_id=123, - order_type="market", - time_in_force="GTC", - timestamp=0, - order_side="buy", - pair="BTCUSDT", - qty=0.000123, - status="filled", - price=1.222, - deal_type=DealType.base_order, - total_commission=0, - ), - ExchangeOrderTable( - id=2, - order_id=123, - order_type="limit", - time_in_force="GTC", - timestamp=0, - order_side="sell", - pair="BTCUSDT", - qty=0.000123, - status="filled", - price=1.222, - deal_type=DealType.take_profit, - total_commission=0, - ), - ] + self.session.close() + base_order = ExchangeOrderTable( + order_id=123, + order_type="market", + time_in_force="GTC", + timestamp=0, + order_side="buy", + pair="BTCUSDT", + qty=0.000123, + status="filled", + price=1.222, + deal_type=DealType.base_order, + total_commission=0, + ) + take_profit_order = ExchangeOrderTable( + order_id=456, + order_type="limit", + time_in_force="GTC", + timestamp=0, + order_side="sell", + pair="BTCUSDT", + qty=0.000123, + status="filled", + price=1.222, + deal_type=DealType.take_profit, + total_commission=0, + ) deal = DealTable( buy_price=0, buy_total_qty=0, @@ -182,18 +178,17 @@ def create_dummy_bot(self): margin_short_sell_timestamp=0, margin_short_loan_timestamp=0, ) - self.session.add(deal) bot = BotTable( pair="BTCUSDT", balance_size_to_use="1", fiat="USDC", base_order_size=15, - deal_id=deal.id, + deal=deal, cooldown=0, - logs='["Bot created"]', + logs=["Bot created"], mode="manual", name="Dummy bot", - orders=orders, + orders=[base_order, take_profit_order], status=Status.inactive, stop_loss=0, take_profit=2.3, @@ -205,19 +200,17 @@ def create_dummy_bot(self): short_sell_price=0, total_commission=0, ) - self.session.add(bot) - self.session.commit() paper_trading_bot = PaperTradingTable( pair="BTCUSDT", balance_size_to_use=1, balance_to_use=1, base_order_size=15, - deal_id=deal.id, + deal=deal, cooldown=0, - logs='["Paper trading bot created"]', + logs=["Paper trading bot created"], mode="manual", name="Dummy bot", - orders=orders, + orders=[base_order, take_profit_order], status=Status.inactive, stop_loss=0, take_profit=2.3, @@ -229,8 +222,11 @@ def create_dummy_bot(self): short_sell_price=0, total_commission=0, ) + self.session.add(bot) self.session.add(paper_trading_bot) self.session.commit() + self.session.refresh(bot) + self.session.refresh(paper_trading_bot) return bot def select_bot(self, pair): diff --git a/api/database/bot_crud.py b/api/database/bot_crud.py index f863c41cf..bab2c4fdc 100644 --- a/api/database/bot_crud.py +++ b/api/database/bot_crud.py @@ -4,10 +4,11 @@ from sqlmodel import Session, asc, desc, or_, select, case from time import time from bots.models import BotModel -from database.models.bot_table import BotBase, BotTable +from database.models.bot_table import BotTable from database.models.deal_table import DealTable from database.utils import independent_session from tools.enum_definitions import BinbotEnums, Status +from bots.models import BotBase class BotTableCrud: @@ -35,7 +36,7 @@ def update_logs( log_message: str, bot: Optional[BotModel] = None, bot_id: str | None = None, - ) -> BotModel: + ) -> BotTable: """ Update logs for a bot @@ -47,7 +48,10 @@ def update_logs( """ if bot_id: bot_obj = self.session.get(BotTable, bot_id) - bot = BotModel.model_validate(bot_obj) + if not bot_obj: + raise ValueError("Bot not found") + # No validation needed, this is a trusted source + bot = BotModel.model_construct(**bot_obj.model_dump()) elif not bot: raise ValueError("Bot id or BotModel object is required") @@ -57,14 +61,15 @@ def update_logs( elif len(current_logs) > 0: current_logs.append(log_message) - bot.logs = current_logs + bot_table_model = BotTable.model_validate(bot.model_dump()) + bot_table_model.logs = current_logs # db operations - self.session.add(bot) + self.session.add(bot_table_model) self.session.commit() - self.session.refresh(bot) + self.session.refresh(bot_table_model) self.session.close() - return bot + return bot_table_model def get( self, @@ -124,8 +129,7 @@ def get( statement.limit(limit).offset(offset) bots = self.session.exec(statement).all() - self.session.close() - + # self.session.close() return bots def get_one( @@ -133,12 +137,16 @@ def get_one( bot_id: str | None = None, symbol: str | None = None, status: Status | None = None, - ) -> BotModel: + ) -> BotTable: """ Get one bot by id or symbol """ if bot_id: - bot = self.session.get(BotTable, UUID(bot_id)) + santize_uuid = UUID(bot_id) + bot = self.session.get(BotTable, santize_uuid) + if not bot: + raise ValueError("Bot not found") + return bot elif symbol: if status: bot = self.session.exec( @@ -150,13 +158,12 @@ def get_one( bot = self.session.exec( select(BotTable).where(BotTable.pair == symbol) ).first() + if not bot: + raise ValueError("Bot not found") + return bot else: raise ValueError("Invalid bot id or symbol") - self.session.close() - bot_model = BotModel.model_validate(bot) - return bot_model - def create(self, data: BotBase) -> BotModel: """ Create a new bot @@ -167,22 +174,28 @@ def create(self, data: BotBase) -> BotModel: Args: - data: BotBase includes only flat properties (excludes deal and orders which are generated internally) """ - bot = BotModel.model_validate(data) - - # Ensure values are reset - bot.orders = [] - bot.logs = [] - bot.status = Status.inactive + bot = BotModel.model_construct(data) + deal = bot.deal # db operations - self.session.add(bot) + serialised_bot = BotTable.model_validate(data) + serialised_deal = DealTable.model_validate(deal) + + self.session.add(serialised_bot) + self.session.add(serialised_deal) + self.session.commit() - resulted_bot = self.session.get(BotTable, bot.id) + self.session.refresh(serialised_bot) + self.session.refresh(serialised_deal) self.session.close() - data = BotModel.model_validate(resulted_bot) - return data + resulted_bot = self.session.get(BotTable, data.id) + if resulted_bot: + bot_model = BotModel.model_validate(resulted_bot.model_dump()) + else: + bot_model = None + return bot_model - def save(self, data: BotModel) -> BotModel: + def save(self, data: BotModel) -> BotTable: """ Save bot @@ -193,15 +206,15 @@ def save(self, data: BotModel) -> BotModel: if not bot: raise ValueError("Bot not found") - # double check orders and deal are not overwritten - dumped_bot = data.model_dump(exclude_unset=True) + # due to incompatibility of SQLModel and Pydantic + dumped_bot = data.model_dump() bot.sqlmodel_update(dumped_bot) self.session.add(bot) self.session.commit() - resulted_bot = self.session.get(BotTable, bot.id) + self.session.refresh(bot) self.session.close() - data = BotModel.model_validate(resulted_bot) - return data + resulted_bot = self.get_one(bot_id=data.id) + return resulted_bot def delete(self, bot_ids: List[str] = Query(...)): """ diff --git a/api/database/models/__init__.py b/api/database/models/__init__.py index 62bfa43ad..f3aa96784 100644 --- a/api/database/models/__init__.py +++ b/api/database/models/__init__.py @@ -3,5 +3,4 @@ from .order_table import * # noqa from .deal_table import * # noqa from .bot_table import * # noqa -from .paper_trading_table import * # noqa from .autotrade_table import * # noqa diff --git a/api/database/models/autotrade_table.py b/api/database/models/autotrade_table.py index b1fe240bb..4eeeebb3e 100644 --- a/api/database/models/autotrade_table.py +++ b/api/database/models/autotrade_table.py @@ -49,8 +49,30 @@ class AutotradeTable(SQLModel, table=True): sa_column=Column(Enum(CloseConditions)), ) - class Config: - arbitrary_types_allowed = True + model_config = { + "from_attributes": True, + "use_enum_values": True, + "json_schema_extra": { + "description": "Autotrade global settings used by Binquant", + "examples": [ + { + "autotrade": True, + "base_order_size": 15, + "candlestick_interval": "15m", + "trailling": False, + "trailling_deviation": 3, + "trailling_profit": 2.4, + "stop_loss": 0, + "take_profit": 2.3, + "fiat": "USDC", + "max_request": 950, + "telegram_signals": True, + "max_active_autotrade_bots": 1, + "close_condition": "dynamic_trailling", + } + ], + }, + } class TestAutotradeTable(SQLModel, table=True): @@ -80,5 +102,27 @@ class TestAutotradeTable(SQLModel, table=True): telegram_signals: bool = Field(default=True) max_active_autotrade_bots: int = Field(default=1) - class Config: - arbitrary_types_allowed = True + model_config = { + "from_attributes": True, + "use_enum_values": True, + "json_schema_extra": { + "description": "Autotrade global settings used by Binquant", + "examples": [ + { + "autotrade": True, + "base_order_size": 15, + "candlestick_interval": "15m", + "trailling": False, + "trailling_deviation": 3, + "trailling_profit": 2.4, + "stop_loss": 0, + "take_profit": 2.3, + "fiat": "USDC", + "max_request": 950, + "telegram_signals": True, + "max_active_autotrade_bots": 1, + "close_condition": "dynamic_trailling", + } + ], + }, + } diff --git a/api/database/models/bot_table.py b/api/database/models/bot_table.py index 83362ee20..7338df6cd 100644 --- a/api/database/models/bot_table.py +++ b/api/database/models/bot_table.py @@ -1,6 +1,6 @@ from uuid import uuid4, UUID -from typing import TYPE_CHECKING, Optional -from pydantic import Json, PositiveInt +from typing import TYPE_CHECKING, Optional, List +from pydantic import Json from sqlalchemy import JSON, Column, Enum from database.utils import timestamp from tools.enum_definitions import ( @@ -18,7 +18,9 @@ from database.models.order_table import ExchangeOrderTable -class BotBase(SQLModel): +class BotTable(SQLModel, table=True): + __tablename__ = "bot" + id: Optional[UUID] = Field( default_factory=uuid4, primary_key=True, index=True, nullable=False, unique=True ) @@ -35,37 +37,117 @@ class BotBase(SQLModel): default=CloseConditions.dynamic_trailling, sa_column=Column(Enum(CloseConditions)), ) - cooldown: PositiveInt = Field( + cooldown: int = Field( default=0, description="cooldown period in minutes before opening next bot with same pair", ) created_at: float = Field(default_factory=timestamp) updated_at: float = Field(default_factory=timestamp) dynamic_trailling: bool = Field(default=False) - logs: list[Json[str]] = Field(default=[], sa_column=Column(JSON)) + logs: List[Json[str]] = Field(default=[], sa_column=Column(JSON)) mode: str = Field(default="manual") name: str = Field(default="Default bot") - # filled up internally by Exchange - status: str = Field(default=Status.inactive, sa_column=Column(Enum(Status))) + status: Status = Field(default=Status.inactive, sa_column=Column(Enum(Status))) stop_loss: float = Field( default=0, description="If stop_loss > 0, allow for reversal" ) margin_short_reversal: bool = Field(default=False) take_profit: float = Field(default=0) trailling: bool = Field(default=False) - trailling_deviation: float = Field(default=0, ge=-1, le=101) - # Trailling activation (first take profit hit) + trailling_deviation: float = Field( + default=0, + ge=-1, + le=101, + description="Trailling activation (first take profit hit)", + ) trailling_profit: float = Field(default=0) - strategy: str = Field(default=Strategy.long, sa_column=Column(Enum(Strategy))) - short_buy_price: float = Field(default=0) - # autoswitch to short_strategy - short_sell_price: float = Field(default=0) - total_commission: float = Field(default=0) + strategy: Strategy = Field(default=Strategy.long, sa_column=Column(Enum(Strategy))) + total_commission: float = Field( + default=0, description="autoswitch to short_strategy" + ) + # Table relationships filled up internally + orders: Optional[list["ExchangeOrderTable"]] = Relationship(back_populates="bot") + deal: Optional["DealTable"] = Relationship(back_populates="bot") -class BotTable(BotBase, table=True): - __tablename__ = "bot" + model_config = { + "from_attributes": True, + "use_enum_values": True, + } + + +class PaperTradingTable(SQLModel, table=True): + """ + Fake bots + + these trade without actual money, so qty + is usually 0 or 1. Orders are simulated + + This cannot inherit from a SQLModel base + because errors with candlestick_interval + already assigned to BotTable error + """ + + __tablename__ = "paper_trading" + + id: Optional[UUID] = Field( + default_factory=uuid4, primary_key=True, index=True, nullable=False, unique=True + ) + pair: str = Field(index=True) + fiat: str = Field(default="USDC", index=True) + base_order_size: float = Field( + default=15, description="Min Binance 0.0001 BNB approx 15USD" + ) + candlestick_interval: BinanceKlineIntervals = Field( + default=BinanceKlineIntervals.fifteen_minutes, + sa_column=Column(Enum(BinanceKlineIntervals)), + ) + close_condition: CloseConditions = Field( + default=CloseConditions.dynamic_trailling, + sa_column=Column(Enum(CloseConditions)), + ) + cooldown: int = Field( + default=0, + description="cooldown period in minutes before opening next bot with same pair", + ) + created_at: float = Field(default_factory=timestamp) + updated_at: float = Field(default_factory=timestamp) + dynamic_trailling: bool = Field(default=False) + logs: list[Json[str]] = Field(default=[], sa_column=Column(JSON)) + mode: str = Field(default="manual") + name: str = Field(default="Default bot") + status: Status = Field(default=Status.inactive, sa_column=Column(Enum(Status))) + stop_loss: float = Field( + default=0, description="If stop_loss > 0, allow for reversal" + ) + margin_short_reversal: bool = Field(default=False) + take_profit: float = Field(default=0) + trailling: bool = Field(default=False) + trailling_deviation: float = Field( + default=0, + ge=-1, + le=101, + description="Trailling activation (first take profit hit)", + ) + trailling_profit: float = Field(default=0) + strategy: Strategy = Field(default=Strategy.long, sa_column=Column(Enum(Strategy))) + short_buy_price: float = Field( + default=0, description="autoswitch to short_strategy" + ) + short_sell_price: float = Field( + default=0, description="autoswitch to short_strategy" + ) + total_commission: float = Field( + default=0, description="autoswitch to short_strategy" + ) + + # Table relationships filled up internally + deal: "DealTable" = Relationship(back_populates="paper_trading") + orders: Optional[list["ExchangeOrderTable"]] = Relationship( + back_populates="paper_trading" + ) - deal: "DealTable" = Relationship(back_populates="bot") - # filled up internally by Exchange - orders: list["ExchangeOrderTable"] = Relationship(back_populates="bot") + model_config = { + "from_attributes": True, + "use_enum_values": True, + } diff --git a/api/database/models/deal_table.py b/api/database/models/deal_table.py index 01f71621d..f44ae203c 100644 --- a/api/database/models/deal_table.py +++ b/api/database/models/deal_table.py @@ -5,8 +5,7 @@ # avoids circular imports if TYPE_CHECKING: - from database.models.bot_table import BotTable - from database.models.paper_trading_table import PaperTradingTable + from database.models.bot_table import BotTable, PaperTradingTable class DealBase(SQLModel): @@ -48,6 +47,9 @@ class DealBase(SQLModel): margin_short_sell_timestamp: int = Field(default=0) margin_short_loan_timestamp: int = Field(default=0) + # Relationships + # bot_id: Optional[UUID] = Field(default=None, foreign_key="bot.id") + # paper_trading_id: Optional[UUID] = Field(default=None, foreign_key="paper_trading.id") class DealTable(DealBase, table=True): """ @@ -59,8 +61,7 @@ class DealTable(DealBase, table=True): # Relationships bot_id: Optional[UUID] = Field(default=None, foreign_key="bot.id") + paper_trading_id: Optional[UUID] = Field(default=None, foreign_key="paper_trading.id") bot: Optional["BotTable"] = Relationship(back_populates="deal") - paper_trading_id: Optional[UUID] = Field( - default=None, foreign_key="paper_trading.id" - ) paper_trading: Optional["PaperTradingTable"] = Relationship(back_populates="deal") + pass diff --git a/api/database/models/order_table.py b/api/database/models/order_table.py index 179f0e8a9..c3d9ac5d0 100644 --- a/api/database/models/order_table.py +++ b/api/database/models/order_table.py @@ -1,21 +1,20 @@ from typing import TYPE_CHECKING, Optional -from uuid import UUID - from pydantic import ValidationInfo, field_validator from sqlalchemy import Column, Enum from tools.enum_definitions import DealType, OrderType from sqlmodel import Field, Relationship, SQLModel +from uuid import UUID, uuid4 + if TYPE_CHECKING: - from database.models.bot_table import BotTable - from database.models.paper_trading_table import PaperTradingTable + from database.models.bot_table import BotTable, PaperTradingTable -class OrderModel(SQLModel): +class OrderBase(SQLModel): order_type: OrderType time_in_force: str timestamp: Optional[int] - order_id: int = Field(nullable=True) + order_id: int = Field(nullable=False) order_side: str pair: str qty: float @@ -23,8 +22,18 @@ class OrderModel(SQLModel): price: float deal_type: DealType + # Relationships + bot_id: Optional[UUID] = Field(default=None, foreign_key="bot.id") + paper_trading_id: Optional[UUID] = Field( + default=None, foreign_key="paper_trading.id" + ) + + model_config = { + "use_enum_values": True, + } + -class ExchangeOrderTable(OrderModel, table=True): +class ExchangeOrderTable(OrderBase, table=True): """ Data provided by Crypto Exchange, therefore they should be all be strings @@ -38,20 +47,18 @@ class ExchangeOrderTable(OrderModel, table=True): __tablename__ = "exchange_order" - id: int = Field(primary_key=True) + id: UUID = Field( + primary_key=True, default_factory=uuid4, nullable=False, unique=True, index=True + ) price: float = Field(nullable=True) deal_type: DealType = Field(sa_column=Column(Enum(DealType))) total_commission: float = Field(nullable=True, default=0) # Relationships - bot_id: Optional[UUID] = Field(default=None, foreign_key="bot.id") bot: Optional["BotTable"] = Relationship(back_populates="orders") - paper_trading_id: Optional[UUID] = Field( - default=None, foreign_key="paper_trading.id" - ) paper_trading: Optional["PaperTradingTable"] = Relationship(back_populates="orders") - @field_validator("price", "qty", mode="before") + @field_validator("price", "qty") @classmethod def validate_str_numbers(cls, v, info: ValidationInfo): if isinstance(v, str): diff --git a/api/database/models/paper_trading_table.py b/api/database/models/paper_trading_table.py deleted file mode 100644 index fa31473e2..000000000 --- a/api/database/models/paper_trading_table.py +++ /dev/null @@ -1,29 +0,0 @@ -import json -from typing import TYPE_CHECKING -from pydantic import field_serializer, field_validator -from database.models.bot_table import BotBase -from tools.enum_definitions import ( - BinbotEnums, -) -from sqlmodel import Relationship - -# avoids circular imports -# https://sqlmodel.tiangolo.com/tutorial/code-structure/#hero-model-file -if TYPE_CHECKING: - from database.models.deal_table import DealTable - from database.models.order_table import ExchangeOrderTable - - -class PaperTradingTable(BotBase, table=True): - """ - Fake bots - - these trade without actual money, so qty - is usually 0 or 1. Orders are simualted - """ - - __tablename__ = "paper_trading" - - deal: "DealTable" = Relationship(back_populates="bot") - # filled up internally by Exchange - orders: list["ExchangeOrderTable"] = Relationship(back_populates="bot") diff --git a/api/database/paper_trading_crud.py b/api/database/paper_trading_crud.py index 095194c3d..da661e8ea 100644 --- a/api/database/paper_trading_crud.py +++ b/api/database/paper_trading_crud.py @@ -2,10 +2,9 @@ from typing import Union from sqlmodel import Session, or_, select, case, desc, asc -from database.models.bot_table import BotBase, BotTable -from bots.models import BotModel +from database.models.bot_table import BotTable, PaperTradingTable +from bots.models import BotModel, BotBase from database.models.deal_table import DealTable -from database.models.paper_trading_table import PaperTradingTable from database.utils import independent_session from tools.enum_definitions import BinbotEnums, Status diff --git a/api/deals/controllers.py b/api/deals/controllers.py index 6829f2b93..1c74da511 100644 --- a/api/deals/controllers.py +++ b/api/deals/controllers.py @@ -1,10 +1,8 @@ from typing import Type, Union -from database.models.order_table import OrderModel from database.bot_crud import BotTableCrud from database.paper_trading_crud import PaperTradingTableCrud -from database.models.paper_trading_table import PaperTradingTable -from database.models.bot_table import BotTable -from bots.models import BotModel +from database.models.bot_table import BotTable, PaperTradingTable +from bots.models import BotModel, OrderModel from deals.base import BaseDeal from deals.margin import MarginDeal from deals.models import DealModel diff --git a/api/deals/margin.py b/api/deals/margin.py index 174655199..038f4b49a 100644 --- a/api/deals/margin.py +++ b/api/deals/margin.py @@ -1,21 +1,18 @@ import logging from typing import Type, Union from urllib.error import HTTPError -from database.models.order_table import OrderModel -from deals.controllers import CreateDealController -from database.models.paper_trading_table import PaperTradingTable from database.bot_crud import BotTableCrud -from database.models.bot_table import BotTable +from database.models.bot_table import BotTable, PaperTradingTable from database.paper_trading_crud import PaperTradingTableCrud from tools.enum_definitions import CloseConditions, DealType, OrderSide, Strategy -from bots.models import BotModel +from bots.models import BotModel, OrderModel from tools.enum_definitions import Status -from deals.base import BaseDeal from tools.exceptions import BinanceErrors, MarginShortError from tools.round_numbers import round_numbers, supress_notation, round_numbers_ceiling +from deals.base import BaseDeal -class MarginDeal(CreateDealController): +class MarginDeal(BaseDeal): def __init__( self, bot: BotModel, diff --git a/api/deals/models.py b/api/deals/models.py index aa71d6cfc..6dac0c2e3 100644 --- a/api/deals/models.py +++ b/api/deals/models.py @@ -1,7 +1,5 @@ -from time import time from pydantic import BaseModel, Field, field_validator - class DealModel(BaseModel): """ Data model that is used for operations, @@ -10,7 +8,7 @@ class DealModel(BaseModel): buy_price: float = Field(default=0) buy_total_qty: float = Field(default=0) - buy_timestamp: float = time() * 1000 + buy_timestamp: float = Field(default=0) current_price: float = Field(default=0) sd: float = Field(default=0) avg_buy_price: float = Field(default=0) @@ -18,19 +16,18 @@ class DealModel(BaseModel): sell_timestamp: float = Field(default=0) sell_price: float = Field(default=0) sell_qty: float = Field(default=0) - trailling_stop_loss_price: float = Field(default=0) - # take_profit but for trailling, to avoid confusion, trailling_profit_price always be > trailling_stop_loss_price + trailling_stop_loss_price: float = Field(default=0, description="take_profit but for trailling, to avoid confusion, trailling_profit_price always be > trailling_stop_loss_price") trailling_profit_price: float = Field(default=0) stop_loss_price: float = Field(default=0) trailling_profit: float = Field(default=0) so_prices: float = Field(default=0) - post_closure_current_price: float = Field(default=0) original_buy_price: float = Field( - default=0 - ) # historical buy_price after so trigger + default=0, + description="historical buy_price after so trigger" + ) short_sell_price: float = Field(default=0) short_sell_qty: float = Field(default=0) - short_sell_timestamp: float = time() * 1000 + short_sell_timestamp: float = Field(default=0) # fields for margin trading margin_short_loan_principal: float = Field(default=0) @@ -42,8 +39,8 @@ class DealModel(BaseModel): margin_short_sell_qty: float = Field(default=0) margin_short_buy_back_timestamp: int = 0 margin_short_base_order: float = Field(default=0) - margin_short_sell_timestamp: int = 0 - margin_short_loan_timestamp: int = 0 + margin_short_sell_timestamp: int = Field(default=0) + margin_short_loan_timestamp: int = Field(default=0) @field_validator( "buy_price", diff --git a/api/deals/spot.py b/api/deals/spot.py index 013a5d8a9..7dc0ba632 100644 --- a/api/deals/spot.py +++ b/api/deals/spot.py @@ -1,9 +1,7 @@ import logging from typing import Type, Union -from database.models.order_table import OrderModel from database.bot_crud import BotTableCrud -from database.models.paper_trading_table import PaperTradingTable -from database.models.bot_table import BotTable +from database.models.bot_table import BotTable, PaperTradingTable from database.paper_trading_crud import PaperTradingTableCrud from deals.base import BaseDeal from deals.margin import MarginDeal @@ -14,7 +12,7 @@ Status, Strategy, ) -from bots.models import BotModel +from bots.models import BotModel, OrderModel class SpotLongDeal(BaseDeal): diff --git a/api/paper_trading/routes.py b/api/paper_trading/routes.py index b0ce8f00d..9ed68de97 100644 --- a/api/paper_trading/routes.py +++ b/api/paper_trading/routes.py @@ -1,7 +1,7 @@ from fastapi import APIRouter, Depends, Query from sqlmodel import Session from tools.enum_definitions import Status -from database.models.paper_trading_table import PaperTradingTable +from database.models.bot_table import PaperTradingTable from database.paper_trading_crud import PaperTradingTableCrud from database.utils import get_session from deals.controllers import CreateDealController diff --git a/api/streaming/streaming_controller.py b/api/streaming/streaming_controller.py index 0cd4d5c44..a882ef754 100644 --- a/api/streaming/streaming_controller.py +++ b/api/streaming/streaming_controller.py @@ -4,8 +4,7 @@ from kafka import KafkaConsumer from bots.models import BotModel from database.autotrade_crud import AutotradeCrud -from database.models.bot_table import BotTable -from database.models.paper_trading_table import PaperTradingTable +from database.models.bot_table import BotTable, PaperTradingTable from database.paper_trading_crud import PaperTradingTableCrud from database.bot_crud import BotTableCrud from deals.controllers import CreateDealController diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 9578a2c87..988ec416f 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -18,3 +18,11 @@ def update_required(self, action): def mock_lifespan(): with patch("main.lifespan") as mock_lifespan: yield mock_lifespan + + +@fixture(scope="module") +def vcr_config(): + return { + # Replace the Authorization request header with "DUMMY" in cassettes + "filter_headers": [("authorization", "DUMMY")], + } diff --git a/api/tests/model_mocks.py b/api/tests/model_mocks.py new file mode 100644 index 000000000..5715bca7c --- /dev/null +++ b/api/tests/model_mocks.py @@ -0,0 +1,170 @@ +from bots.models import BotModel, OrderModel +from deals.models import DealModel + +id = "02031768-fbb9-4cc7-b549-642f15ab787b" +ts = 1733973560249.0 + +active_pairs = ["BTCUSDT", "ETHUSDT", "ADAUSDT"] + +deal_model = DealModel( + buy_price=1.3, + buy_total_qty=0, + buy_timestamp=0, + current_price=0, + sd=0, + avg_buy_price=0, + take_profit_price=0, + sell_timestamp=0, + sell_price=0, + sell_qty=0, + trailling_stop_loss_price=0, + trailling_profit_price=0, + stop_loss_price=0, + trailling_profit=0, + so_prices=0, + original_buy_price=0, + short_sell_price=0, + short_sell_qty=0, + short_sell_timestamp=0, + margin_short_loan_principal=0, + margin_loan_id=0, + hourly_interest_rate=0, + margin_short_sell_price=0, + margin_short_loan_interest=0, + margin_short_buy_back_price=0, + margin_short_sell_qty=0, + margin_short_buy_back_timestamp=0, + margin_short_base_order=0, + margin_short_sell_timestamp=0, + margin_short_loan_timestamp=0, +) + +initial_deal_model = DealModel( + buy_price=0, + buy_total_qty=0, + buy_timestamp=0, + current_price=0, + sd=0, + avg_buy_price=0, + take_profit_price=0, + sell_timestamp=0, + sell_price=0, + sell_qty=0, + trailling_stop_loss_price=0, + trailling_profit_price=0, + stop_loss_price=0, + trailling_profit=0, + so_prices=0, + original_buy_price=0, + short_sell_price=0, + short_sell_qty=0, + short_sell_timestamp=0, + margin_short_loan_principal=0, + margin_loan_id=0, + hourly_interest_rate=0, + margin_short_sell_price=0, + margin_short_loan_interest=0, + margin_short_buy_back_price=0, + margin_short_sell_qty=0, + margin_short_buy_back_timestamp=0, + margin_short_base_order=0, + margin_short_sell_timestamp=0, + margin_short_loan_timestamp=0, +) + +orders_model = [ + OrderModel( + id=1, + order_id=123, + order_type="MARKET", + time_in_force="GTC", + timestamp=0, + order_side="buy", + pair="BTCUSDT", + qty=0.000123, + status="filled", + price=1.222, + deal_type="base_order", + total_commission=0, + ), + OrderModel( + id=2, + order_id=321, + order_type="LIMIT", + time_in_force="GTC", + timestamp=0, + order_side="sell", + pair="BTCUSDT", + qty=0.000123, + status="filled", + price=1.222, + deal_type="take_profit", + total_commission=0, + ), +] + +mock_model_data = BotModel( + id=id, + pair="ADXUSDC", + fiat="USDC", + base_order_size=15, + candlestick_interval="15m", + dynamic_trailling=False, + close_condition="dynamic_trailling", + cooldown=360, + created_at=ts, + status="inactive", + margin_short_reversal=False, + logs=[], + mode="manual", + name="coinrule_fast_and_slow_macd_2024-04-20T22:28", + stop_loss=3.0, + take_profit=2.3, + trailling=True, + trailling_deviation=3.0, + trailling_profit=0, + strategy="long", + updated_at=ts, + orders=orders_model, + deal=deal_model, +) + +# new bots don't have orders because they are not activated +mock_model_data_without_orders = BotModel( + id=id, + pair="ADXUSDC", + fiat="USDC", + base_order_size=15, + candlestick_interval="15m", + status="inactive", + margin_short_reversal=False, + dynamic_trailling=False, + close_condition="dynamic_trailling", + cooldown=360, + created_at=ts, + logs=[], + mode="manual", + name="coinrule_fast_and_slow_macd_2024-04-20T22:28", + stop_loss=3.0, + take_profit=2.3, + trailling=True, + trailling_deviation=3.0, + trailling_profit=0, + strategy="long", + updated_at=ts, + orders=[], + deal=initial_deal_model, +) + + +class CreateDealControllerMock: + def __init__(self, bot: BotModel): + pass + + def open_deal(self): + bot_model = BotModel(**mock_model_data.model_dump()) + return bot_model + + def close_all(self): + bot_model = BotModel(**mock_model_data.model_dump()) + return bot_model diff --git a/api/tests/table_mocks.py b/api/tests/table_mocks.py new file mode 100644 index 000000000..79cc8e4cd --- /dev/null +++ b/api/tests/table_mocks.py @@ -0,0 +1,99 @@ +from database.models import BotTable, DealTable, ExchangeOrderTable +from tools.enum_definitions import DealType, OrderType +from uuid import UUID + +ts = 1733973560249.0 +id = "02031768-fbb9-4cc7-b549-642f15ab787b" + +orders = [ + ExchangeOrderTable( + id=id, + order_id=123, + order_type=OrderType.market, + time_in_force="GTC", + timestamp=0, + order_side="buy", + pair="BTCUSDT", + qty=0.000123, + status="filled", + price=1.222, + deal_type=DealType.base_order, + total_commission=0, + ), + ExchangeOrderTable( + id=id, + order_id=321, + order_type=OrderType.limit, + time_in_force="GTC", + timestamp=0, + order_side="sell", + pair="BTCUSDT", + qty=0.000123, + status="filled", + price=1.222, + deal_type=DealType.take_profit, + total_commission=0, + ), +] + + +deal_table = DealTable( + buy_price=1.3, + buy_total_qty=0, + buy_timestamp=0, + current_price=0, + sd=0, + avg_buy_price=0, + take_profit_price=0, + sell_timestamp=0, + sell_price=0, + sell_qty=0, + trailling_stop_loss_price=0, + trailling_profit_price=0, + stop_loss_price=0, + trailling_profit=0, + so_prices=0, + original_buy_price=0, + short_sell_price=0, + short_sell_qty=0, + short_sell_timestamp=0, + margin_short_loan_principal=0, + margin_loan_id=0, + hourly_interest_rate=0, + margin_short_sell_price=0, + margin_short_loan_interest=0, + margin_short_buy_back_price=0, + margin_short_sell_qty=0, + margin_short_buy_back_timestamp=0, + margin_short_base_order=0, + margin_short_sell_timestamp=0, + margin_short_loan_timestamp=0, +) + + +mocked_db_data = BotTable( + id=UUID(id), + pair="ADXUSDC", + fiat="USDC", + base_order_size=15, + buy_price=1.222, + candlestick_interval="15m", + close_condition="dynamic_trailling", + dynamic_trailling=False, + cooldown=360, + created_at=ts, + logs=[], + mode="manual", + name="coinrule_fast_and_slow_macd_2024-04-20T22:28", + stop_loss=3.0, + take_profit=2.3, + trailling=True, + trailling_deviation=3.0, + trailling_profit=0.0, + strategy="long", + updated_at=ts, + status="inactive", + margin_short_reversal=False, + deal=deal_table, + orders=orders, +) diff --git a/api/tests/test_bots.py b/api/tests/test_bots.py index 7e3e8b02a..fe9e2715c 100644 --- a/api/tests/test_bots.py +++ b/api/tests/test_bots.py @@ -1,74 +1,17 @@ -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch from fastapi.testclient import TestClient -from pytest import fixture -from tools.enum_definitions import DealType, OrderType -from database.models import OrderModel -from database.models.bot_table import BotTable from database.utils import get_session from main import app -from unittest.mock import patch - -id = "02031768-fbb9-4cc7-b549-642f15ab787b" -ts = 1733973560249.0 - -active_pairs = ["BTCUSDT", "ETHUSDT", "ADAUSDT"] - -orders = [ - OrderModel( - id=1, - order_id=123, - order_type=OrderType.market, - time_in_force="GTC", - timestamp=0, - order_side="buy", - pair="BTCUSDT", - qty=0.000123, - status="filled", - price=1.222, - deal_type=DealType.base_order, - total_commission=0, - ), - OrderModel( - id=2, - order_id=321, - order_type=OrderType.limit, - time_in_force="GTC", - timestamp=0, - order_side="sell", - pair="BTCUSDT", - qty=0.000123, - status="filled", - price=1.222, - deal_type=DealType.take_profit, - total_commission=0, - ), -] - -mocked_db_data = BotTable( - id=id, - pair="ADXUSDC", - fiat="USDC", - base_order_size=50, - candlestick_interval="15m", - close_condition="dynamic_trailling", - cooldown=360, - created_at=ts, - logs=[], - mode="manual", - name="coinrule_fast_and_slow_macd_2024-04-20T22:28", - orders=[], - stop_loss=3.0, - take_profit=2.3, - trailling=True, - trailling_deviation=3.0, - trailling_profit=0.0, - strategy="long", - updated_at=ts, +from pytest import fixture +from tests.model_mocks import ( + mock_model_data, + id, + active_pairs, + mock_model_data_without_orders, + CreateDealControllerMock, ) - -CreateDealControllerMock = MagicMock() -CreateDealControllerMock.return_value = MagicMock() -CreateDealControllerMock.return_value.open_deal.return_value = mocked_db_data +from tests.table_mocks import mocked_db_data +from fastapi.encoders import jsonable_encoder @fixture() @@ -78,44 +21,19 @@ def client(pairs=False) -> TestClient: session_mock.exec.return_value.all.return_value = [mocked_db_data] session_mock.get.return_value = mocked_db_data session_mock.add.return_value = MagicMock(return_value=None) + session_mock.refresh.return_value = MagicMock(return_value=None) session_mock.commit.return_value = MagicMock(return_value=None) app.dependency_overrides[get_session] = lambda: session_mock client = TestClient(app) return client -@patch("database.models.bot_table.timestamp", lambda: ts) -def test_create_bot(client: TestClient): - payload = { - "pair": "ADXUSDC", - "fiat": "USDC", - "base_order_size": 50, - "candlestick_interval": "15m", - "close_condition": "dynamic_trailling", - "cooldown": 360, - "mode": "manual", - "name": "coinrule_fast_and_slow_macd_2024-04-20T22:28", - "stop_loss": 3.0, - "take_profit": 2.3, - "trailling": True, - "trailling_deviation": 3.0, - "trailling_profit": 0.0, - "strategy": "long", - } - - response = client.post("/bot", json=payload) - - assert response.status_code == 200 - content = response.json() - assert content["data"] == mocked_db_data.model_dump() - - def test_get_one_by_id(client: TestClient): response = client.get(f"/bot/{id}") assert response.status_code == 200 content = response.json() - assert content["data"] == mocked_db_data.model_dump() + assert content["data"] == jsonable_encoder(mock_model_data.model_dump()) def test_get_one_by_symbol(client: TestClient): @@ -124,7 +42,7 @@ def test_get_one_by_symbol(client: TestClient): assert response.status_code == 200 content = response.json() - assert content["data"] == mocked_db_data.model_dump() + assert content["data"] == jsonable_encoder(mock_model_data.model_dump()) def test_get_bots(client: TestClient): @@ -132,32 +50,73 @@ def test_get_bots(client: TestClient): assert response.status_code == 200 content = response.json() - assert content["data"] == [mocked_db_data.model_dump()] + mock_data = jsonable_encoder(mock_model_data.model_dump()) + # Avoid testing internal objects + assert content["data"] == [mock_data] + + +def test_create_bot(client: TestClient): + payload = { + "pair": "ADXUSDC", + "fiat": "USDC", + "base_order_size": 15, + "candlestick_interval": "15m", + "close_condition": "dynamic_trailling", + "cooldown": 360, + "created_at": 1733973560249.0, + "updated_at": 1733973560249.0, + "dynamic_trailling": False, + "logs": [], + "mode": "manual", + "name": "Default bot", + "status": "inactive", + "stop_loss": 3.0, + "margin_short_reversal": False, + "take_profit": 2.3, + "trailling": True, + "trailling_deviation": 3.0, + "trailling_profit": 0.0, + "strategy": "long", + "total_commission": 0.0, + } + + response = client.post("/bot", json=payload) + + assert response.status_code == 200 + content = response.json() + assert content["data"] == mock_model_data_without_orders.model_dump() def test_edit_bot(client: TestClient): payload = { "pair": "ADXUSDC", "fiat": "USDC", - "base_order_size": 50, + "base_order_size": 15, "candlestick_interval": "15m", "close_condition": "dynamic_trailling", "cooldown": 360, + "created_at": 1733973560249.0, + "updated_at": 1733973560249.0, + "dynamic_trailling": False, + "logs": [], "mode": "manual", "name": "coinrule_fast_and_slow_macd_2024-04-20T22:28", + "status": "inactive", "stop_loss": 3.0, + "margin_short_reversal": False, "take_profit": 2.3, "trailling": True, "trailling_deviation": 3.0, "trailling_profit": 0.0, "strategy": "long", + "total_commission": 0.0, } response = client.put(f"/bot/{id}", json=payload) assert response.status_code == 200 content = response.json() - assert content["data"] == mocked_db_data.model_dump() + assert content["data"] == jsonable_encoder(mock_model_data.model_dump()) def test_delete_bot(): @@ -172,7 +131,7 @@ def delete_with_payload(self, **kwargs): assert response.status_code == 200 content = response.json() - assert content["message"] == "Bots deleted successfully." + assert content["message"] == "Sucessfully deleted bot." @patch("bots.routes.CreateDealController", CreateDealControllerMock) @@ -181,22 +140,34 @@ def test_activate_by_id(client: TestClient): assert response.status_code == 200 content = response.json() - assert content["data"] == mocked_db_data.model_dump() - + assert content["data"] == mock_model_data.model_dump() -def test_active_pairs(): - # Only endpoint to not return a bot - session_mock = MagicMock() - session_mock.exec.return_value.all.return_value = ["BTCUSDT", "ETHUSDT", "ADAUSDT"] - session_mock.commit.return_value = MagicMock(return_value=None) - app.dependency_overrides[get_session] = lambda: session_mock - test_client = TestClient(app) - response = test_client.get("/bot/active-pairs") +@patch("bots.routes.CreateDealController", CreateDealControllerMock) +def test_deactivate(client: TestClient): + response = client.delete(f"/bot/deactivate/{id}") assert response.status_code == 200 content = response.json() - assert content["data"] == active_pairs + assert content["data"] == mock_model_data.model_dump() + + +# def test_active_pairs(): +# session_mock = MagicMock() +# session_mock.exec.return_value.first.return_value = mocked_db_data +# session_mock.exec.return_value.all.return_value = [mocked_db_data] +# session_mock.get.return_value = mocked_db_data +# session_mock.add.return_value = MagicMock(return_value=None) +# session_mock.refresh.return_value = MagicMock(return_value=None) +# session_mock.commit.return_value = MagicMock(return_value=None) +# app.dependency_overrides[get_session] = lambda: session_mock +# client = TestClient(app) + +# response = client.get("/bot/active-pairs") + +# assert response.status_code == 200 +# content = response.json() +# assert content["data"] == active_pairs def test_post_bot_errors_str(client: TestClient): @@ -216,7 +187,7 @@ def test_post_bot_errors_list(client: TestClient): """ Test submitting bot errors with a list of strings """ - payload = {"errors": ["failed to create bot", "failed to create bot"]} + payload = {"errors": ["failed to create bot", "failed to create deal"]} response = client.post(f"/bot/errors/{id}", json=payload) diff --git a/api/tools/handle_error.py b/api/tools/handle_error.py index c8210b906..598d429f1 100644 --- a/api/tools/handle_error.py +++ b/api/tools/handle_error.py @@ -2,7 +2,7 @@ import os import logging from time import sleep -from typing import Any, Optional, Union +from typing import Any, Union, TypeVar, Generic from bson import json_util from fastapi.responses import JSONResponse from pydantic import BaseModel @@ -36,7 +36,7 @@ def api_response(detail: str, data: Any = None, error: Union[str, int] = 0, stat """ body = {"message": detail} if data: - body["data"] = jsonable_encoder(data) + body["data"] = data if error: body["error"] = str(error) @@ -152,3 +152,11 @@ def encode_json(raw): class StandardResponse(BaseModel): message: str error: int = 0 + + +DataType = TypeVar("DataType") + + +class IResponseBase(BaseModel, Generic[DataType]): + message: str + error: int = 0 diff --git a/binquant b/binquant index 18a3f88a1..0e7671bc3 160000 --- a/binquant +++ b/binquant @@ -1 +1 @@ -Subproject commit 18a3f88a112ee1e0c7366a1de968a7f5d5f44e10 +Subproject commit 0e7671bc3e46d8306fd162e5cc782c65c590f2c2