diff --git a/core/database.py b/core/database.py index 936d089..d9069d8 100644 --- a/core/database.py +++ b/core/database.py @@ -13,6 +13,7 @@ Base = declarative_base() + def get_db(): db = SessionLocal() try: diff --git a/core/schemas/tag.py b/core/schemas/tag.py index 1cefc4d..a3724d8 100644 --- a/core/schemas/tag.py +++ b/core/schemas/tag.py @@ -1,7 +1,7 @@ from pydantic import BaseModel, field_validator -class TagInfo(BaseModel): +class TagInformation(BaseModel): id: int name: str @@ -17,7 +17,7 @@ def not_empty(cls, v): class TagList(BaseModel): - tag_list: list[TagInfo] = [] + tag_list: list[TagInformation] = [] class Tag(BaseModel): @@ -40,3 +40,13 @@ def not_empty(cls, v): if not v or not v.strip(): raise ValueError('빈 값은 허용되지 않습니다.') return v + + +class TagStrList(BaseModel): + tag_list: list[str] = [] + + @field_validator('tag_list') + def tag_limit(cls, v): + if len(v) < 1: + raise ValueError('태그는 1개 이상 추가해야 합니다.') + return v diff --git a/main.py b/main.py index d07e093..233ebf7 100644 --- a/main.py +++ b/main.py @@ -14,6 +14,7 @@ from internal import admin from routers.user import users_router as user +from routers.tag import tag_router as tag app = FastAPI() @@ -22,6 +23,7 @@ # routers app.include_router(user.router) +app.include_router(tag.router) @app.get("/") async def root(): diff --git a/routers/tag/__init__.py b/routers/tag/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/util/tag_crud.py b/routers/tag/tag_crud.py similarity index 67% rename from util/tag_crud.py rename to routers/tag/tag_crud.py index 43251e1..59677ff 100644 --- a/util/tag_crud.py +++ b/routers/tag/tag_crud.py @@ -2,10 +2,10 @@ from sqlalchemy.orm import Session from sqlalchemy import func -from core.models.models import User -from core.schemas.tag import TagInfo, Tag, TagInfoCreate, TagList, TagCreate, AddTag +from core.models.models import Tag, TagInfo +from core.schemas.tag import TagCreate, TagInfoCreate, AddTag, TagList, TagInformation -# 하나의 유저 레시피에 추가 할 수 있는 태긔의 개수 +# 하나의 유저 레시피에 추가 할 수 있는 태그의 개수 TAG_LIMIT = 5 @@ -21,8 +21,19 @@ def create_tag(db: Session, tag_create: TagCreate): db.commit() -def get_existing_tag(db: Session, tag_name: str): - return db.query(TagInfo).filter(TagInfo.name == tag_name).first() +def get_tag_info(db: Session, tag_name: str): + _db_tag = db.query(TagInfo).filter(TagInfo.name == tag_name).first() + return _db_tag + + +def get_tag_info_list(db: Session, tag_list: list[str]): + # 태그 이름 리스트를 받아서 태그 info 리스트를 반환합니다. + # 만약 태그가 존재하지 않는다면 새로 생성합니다. + for tag_name in tag_list: + if not get_tag_info(db, tag_name): + create_tag_info(db, TagInfoCreate(name=tag_name)) + + return [get_tag_info(db, tag_name) for tag_name in tag_list] def get_similar_tag_list(db: Session, limit: int = 5, keyword: str = ''): @@ -43,7 +54,7 @@ def get_similar_tag_list(db: Session, limit: int = 5, keyword: str = ''): def create_tag_info(db: Session, tag_info_create: TagInfoCreate): - if get_existing_tag(db, tag_info_create.name): + if get_tag_info(db, tag_info_create.name): return db_tag_info = TagInfo( @@ -58,12 +69,12 @@ def add_tag(db: Session, tag_list: AddTag): if db.query(func.count(Tag.tag_id)).filter(Tag.recipe_id == tag_list.recipe_id).scalar() >= TAG_LIMIT: return for tag_name in tag_list.tag_list: - if not get_existing_tag(db, tag_name): + if not get_tag_info(db, tag_name): create_tag_info(db, TagInfoCreate(name=tag_name)) - tag_id = get_existing_tag(db, tag_name).id + tag_id = get_tag_info(db, tag_name).id create_tag(db, TagCreate(recipe_id=tag_list.recipe_id, tag_id=tag_id)) def delete_tag(db: Session, tag_id: int, recipe_id: int): db.query(Tag).filter(Tag.tag_id == tag_id, Tag.recipe_id == recipe_id).delete() - db.commit() \ No newline at end of file + db.commit() diff --git a/routers/tag/tag_router.py b/routers/tag/tag_router.py new file mode 100644 index 0000000..909b487 --- /dev/null +++ b/routers/tag/tag_router.py @@ -0,0 +1,33 @@ +from datetime import timedelta, datetime + +from fastapi import APIRouter, Depends, HTTPException +from fastapi.security import OAuth2PasswordRequestForm, OAuth2PasswordBearer +from jose import jwt +from sqlalchemy.orm import Session +from starlette import status +from starlette.config import Config + +import routers.tag.tag_crud as tag_crud +from core.database import get_db +from core.models import models +from core.schemas import tag +from dependencies import get_current_user +from routers.user.user_crud import pwd_context + +config = Config('.env') + +router = APIRouter( + prefix="/api/tag", + tags=["Tag"] +) + + +# Create +@router.post("/tag_list", response_model=tag.TagList) +def create_tag(tag_list: tag.TagStrList, db: Session = Depends(get_db)): + print(tag_crud.get_tag_info_list(db, tag_list.tag_list)) + pass + + #return {"tag_list": tag_crud.get_tag_list(db, tag_list.tag_list).tag_list} + + diff --git a/routers/user/user_crud.py b/routers/user/user_crud.py index ac80734..c410d02 100644 --- a/routers/user/user_crud.py +++ b/routers/user/user_crud.py @@ -20,8 +20,7 @@ def create_user(db: Session, user_create: UserCreate): def get_existing_user(db: Session, user_create: UserCreate): - _db_user = db.query(User).filter(User.stdId == user_create.stdId, User.hide == 0).first() - return + return get_user(db, user_create.stdId) def get_user(db: Session, student_id: str):