diff --git a/src/redis_celery/Makefile b/src/redis_celery/Makefile deleted file mode 100644 index d5b9b862..00000000 --- a/src/redis_celery/Makefile +++ /dev/null @@ -1,7 +0,0 @@ -run_server: - uvicorn main:app --reload - -run_client: - streamlit run streamlit_frontend.py --server.fileWatcherType none --browser.gatherUsageStats False - -run_app: run_server run_client diff --git a/src/redis_celery/__init__.py b/src/redis_celery/__init__.py deleted file mode 100644 index f9fa546d..00000000 --- a/src/redis_celery/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .gpt3_api import get_description -from .main import * -from .model import AlbumModel -from .streamlit_frontend import main -from .utils import load_yaml -from .gcp import bigquery, cloud_storage, error -from .worker import * diff --git a/src/redis_celery/config/gcp.json b/src/redis_celery/config/gcp.json deleted file mode 100644 index e69de29b..00000000 diff --git a/src/redis_celery/config/private.yaml b/src/redis_celery/config/private.yaml deleted file mode 100644 index e69de29b..00000000 diff --git a/src/redis_celery/config/public.yaml b/src/redis_celery/config/public.yaml deleted file mode 100644 index e69de29b..00000000 diff --git a/src/redis_celery/config/translation.yaml b/src/redis_celery/config/translation.yaml deleted file mode 100644 index e69de29b..00000000 diff --git a/src/redis_celery/gcp/__init__.py b/src/redis_celery/gcp/__init__.py deleted file mode 100644 index add9f80d..00000000 --- a/src/redis_celery/gcp/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .bigquery import BigQueryLogger -from .cloud_storage import GCSUploader -from .error import ErrorReporter diff --git a/src/redis_celery/gcp/bigquery.py b/src/redis_celery/gcp/bigquery.py deleted file mode 100644 index 51d752fe..00000000 --- a/src/redis_celery/gcp/bigquery.py +++ /dev/null @@ -1,25 +0,0 @@ -# Google -from google.oauth2 import service_account -from google.cloud import bigquery - - -class BigQueryLogger: - def __init__(self, gcp_config): - self.bigquery_config = gcp_config["bigquery"] - - self.project_id = gcp_config["project_id"] - self.dataset_id = self.bigquery_config["dataset_id"] - - credentials = service_account.Credentials.from_service_account_file( - gcp_config["credentials_path"] - ) - - self.client = bigquery.Client(credentials=credentials) - - def log(self, content, id_name): - table_id = f"{self.project_id}.{self.dataset_id}.{self.bigquery_config['table_id'][id_name]}" - - errors = self.client.insert_rows_json(table_id, [content]) - - if errors: - print(f"Encountered errors while inserting rows: {errors}") diff --git a/src/redis_celery/gcp/cloud_storage.py b/src/redis_celery/gcp/cloud_storage.py deleted file mode 100644 index b5e40488..00000000 --- a/src/redis_celery/gcp/cloud_storage.py +++ /dev/null @@ -1,37 +0,0 @@ -# Python built-in modules -import io - -# Google -from google.cloud import storage -from google.oauth2 import service_account - -# datetime -from datetime import datetime, timedelta - - -class GCSUploader: - def __init__(self, gcp_config): - credentials_path = gcp_config["credentials_path"] - self.bucket_name = gcp_config["cloud_storage"]["bucket_name"] - - credentials = service_account.Credentials.from_service_account_file( - credentials_path - ) - self.client = storage.Client(credentials=credentials) - - # Uploads image to GCS and returns the URL - def save_image_to_gcs(self, byte_arr, destination_blob_name): - bucket = self.client.get_bucket(self.bucket_name) - blob = bucket.blob(destination_blob_name) - - file_obj = io.BytesIO(byte_arr) - blob.upload_from_file(file_obj) - - # permanent_url = f"https://storage.cloud.google.com/{self.bucket_name}/{destination_blob_name}" - - user_expiration_time = datetime.now() + timedelta(days=30) - user_url = blob.generate_signed_url(expiration=user_expiration_time) - - print(f"File uploaded to {destination_blob_name}.") - - return user_url diff --git a/src/redis_celery/gcp/error.py b/src/redis_celery/gcp/error.py deleted file mode 100644 index 41f95111..00000000 --- a/src/redis_celery/gcp/error.py +++ /dev/null @@ -1,17 +0,0 @@ -# Google -from google.oauth2 import service_account -from google.cloud import error_reporting - - -class ErrorReporter: - def __init__(self, gcp_config): - credentials_path = gcp_config["credentials_path"] - - credentials = service_account.Credentials.from_service_account_file( - credentials_path - ) - - self.client = error_reporting.Client(credentials=credentials) - - def python_error(self): - self.client.report_exception() diff --git a/src/redis_celery/gpt3_api.py b/src/redis_celery/gpt3_api.py deleted file mode 100644 index 91e6d891..00000000 --- a/src/redis_celery/gpt3_api.py +++ /dev/null @@ -1,55 +0,0 @@ -# Python built-in modules -import os - -# OpenAI API -import openai - -# Built-in modules -from .utils import load_yaml - - -def get_description( - lyrics: str, artist_name: str, album_name: str, song_names: str -) -> str: - gpt_config = load_yaml( - os.path.join("src/redis_celery/config", "private.yaml"), "gpt" - ) - - # OpenAI API key - # https://platform.openai.com/ - openai.api_key = gpt_config["api_key"] - - # -- 공백, 줄바꿈 제거 - lyrics = lyrics.strip() - lyrics = lyrics.replace("\n\n", " ") - lyrics = lyrics.replace("\n", " ") - - # message - message = [ - f"Describe the atmosphere or vibe of these lyrics into 5 different words seperated with comma. They should be optimal for visualizing a atmosphere. \n\n{lyrics}", - f"Also describe a atmosphere using the following Artist name, Album name and Song names into 5 different words seperated with comma. They should be optimal for visualizing a atmosphere. Artist name : {artist_name} \n Album name : {album_name} \n Song names : {song_names}", - ] - - # Set up the API call - responses = [] - for idx in range(len(message)): - response = openai.ChatCompletion.create( - model=gpt_config["model"], - messages=[ - { - "role": gpt_config["role"], - "content": message[idx], - } - ], - max_tokens=gpt_config[ - "max_tokens" - ], # Adjust the value to control the length of the generated description - temperature=gpt_config[ - "temperature" - ], # Adjust the temperature to control the randomness of the output - n=gpt_config["n_response"], # Generate a single response - stop=gpt_config["stop"], # Stop generating text at any point - ) - responses.append(response["choices"][0]["message"]["content"]) - - return ",".join(responses) diff --git a/src/redis_celery/main.py b/src/redis_celery/main.py deleted file mode 100644 index 2057050f..00000000 --- a/src/redis_celery/main.py +++ /dev/null @@ -1,124 +0,0 @@ -# Python built-in modules -import io -import os -import base64 -import uuid -from datetime import datetime - -# Backend -from fastapi import FastAPI, Request -from fastapi.responses import JSONResponse - -from pydantic import BaseModel - -# Celery -from celery import Celery - -# Other modules -import numpy as np -from pytz import timezone - -# User Defined modules -from .gcp.bigquery import BigQueryLogger -from .gcp.error import ErrorReporter -from .utils import load_yaml - - -# Load config -gcp_config = load_yaml(os.path.join("src/redis_celery/config", "private.yaml"), "gcp") -redis_config = load_yaml( - os.path.join("src/redis_celery/config", "private.yaml"), "redis" -) -celery_config = load_yaml( - os.path.join("src/redis_celery/config", "public.yaml"), "celery" -) -public_config = load_yaml(os.path.join("src/redis_celery/config", "public.yaml")) - -# Start fastapi -app = FastAPI() - -# Initialize Celery -celery_app = Celery( - "tasks", - broker=redis_config["redis_server_ip"], - backend=redis_config["redis_server_ip"], -) - -bigquery_logger = BigQueryLogger(gcp_config) -error_reporter = ErrorReporter(gcp_config) - - -# Album input Schema -class AlbumInput(BaseModel): - song_names: str - artist_name: str - genre: str - album_name: str - lyric: str - - -# Review input Schema -class ReviewInput(BaseModel): - rating: int - comment: str - - -# REST API - Post ~/generate_cover -@app.post("/generate_cover") -async def generate_cover(album: AlbumInput): - # Generate a unique ID for this request - global request_id - request_id = str(uuid.uuid4()) - - # Request time - request_time = datetime.utcnow().astimezone(timezone("Asia/Seoul")).isoformat() - - # Push task to the Celery queue - task = celery_app.send_task("generate_cover", args=[album.dict(), request_id]) - - # Get result (this will block until the task is done) - task_result = task.get() - - album_log = { - "request_id": request_id, - "request_time": request_time, - "song_names": album.song_names, - "artist_name": album.artist_name, - "genre": album.genre, - "album_name": album.album_name, - "lyric": album.lyric, - "summarization": task_result["summarization"], - "image_urls": task_result["image_urls"], - "language": public_config["language"], - } - - # Log to BigQuery - bigquery_logger.log(album_log, "user_album") - - return {"images": task_result["image_urls"]} - - -# REST API - Post ~/review -@app.post("/review") -async def review(review: ReviewInput): - review_log = { - "request_id": request_id, - "request_time": datetime.utcnow() - .astimezone(timezone("Asia/Seoul")) - .isoformat(), - "rating": review.rating, - "comment": review.comment, - "language": public_config["language"], - } - - bigquery_logger.log(review_log, "user_review") - - return review - - -# Exception handling using google cloud -@app.exception_handler(Exception) -async def handle_exceptions(request: Request, exc: Exception): - error_reporter.python_error() - - return JSONResponse(status_code=500, content={"message": "Internal Server Error"}) diff --git a/src/redis_celery/model.py b/src/redis_celery/model.py deleted file mode 100644 index 9422ef1d..00000000 --- a/src/redis_celery/model.py +++ /dev/null @@ -1,41 +0,0 @@ -# huggingface - transformers -from transformers import CLIPTextModel, CLIPTokenizer, AutoModel, AutoTokenizer - -# huggingface - diffusers -from diffusers import StableDiffusionPipeline -from diffusers import UNet2DConditionModel - - -class AlbumModel: - def __init__(self, model_config: dict, lang: str, device: str): - self.device = device - self.model_config = model_config - if lang == "EN": - self.text_encoder = CLIPTextModel.from_pretrained( - "CompVis/stable-diffusion-v1-4", subfolder="text_encoder" - ) - self.tokenizer = CLIPTokenizer.from_pretrained( - "CompVis/stable-diffusion-v1-4", subfolder="tokenizer" - ) - elif lang == "KR": - self.text_encoder = AutoModel.from_pretrained("klue/roberta-base") - self.tokenizer = AutoTokenizer.from_pretrained( - "klue/roberta-base", use_fast=False - ) - - self.pipeline = self.get_model() - - def get_model(self) -> None: - pipeline = StableDiffusionPipeline.from_pretrained( - self.model_config["stable_diffusion"], - unet=UNet2DConditionModel.from_pretrained( - self.model_config["unet"], subfolder="unet" - ), - text_encoder=self.text_encoder, - tokenizer=self.tokenizer, - ) - pipeline = pipeline.to(self.device) - if self.model_config["xformers"]: - pipeline.enable_xformers_memory_efficient_attention() - - return pipeline diff --git a/src/redis_celery/streamlit_frontend.py b/src/redis_celery/streamlit_frontend.py deleted file mode 100644 index 7cfcf13f..00000000 --- a/src/redis_celery/streamlit_frontend.py +++ /dev/null @@ -1,146 +0,0 @@ -# Python built-in modules -import base64 -import os -import io -from io import BytesIO - -# Frontend -import streamlit as st - -# Other modules -import pandas as pd -import requests -from PIL import Image - -# Built-in modules -from .utils import load_yaml - - -def main(): - # Load config - request_config = load_yaml( - os.path.join("src/redis_celery/config", "private.yaml"), "request" - ) - public_config = load_yaml(os.path.join("src/redis_celery/config", "public.yaml")) - language = public_config["language"] - translation_config = load_yaml( - os.path.join("src/redis_celery/config", "translation.yaml"), language - ) - - # Frontend - st.title(translation_config["title"]) - - # 1. Input album information - st.header(translation_config["album_info"]) - - song_names = st.text_input( - translation_config["song_names"]["text"], - placeholder=translation_config["song_names"]["placeholder"], - ) - artist_name = st.text_input( - translation_config["artist_name"]["text"], - placeholder=translation_config["artist_name"]["placeholder"], - ) - genre = st.selectbox( - translation_config["genre"]["text"], - translation_config["genre"]["list"], - ) - album_name = st.text_input( - translation_config["album_name"]["text"], - placeholder=translation_config["album_name"]["placeholder"], - ) - lyric = st.text_area( - translation_config["lyric"]["text"], - placeholder=translation_config["lyric"]["placeholder"], - ) - - info = { - translation_config["info"][0]: song_names, - translation_config["info"][1]: artist_name, - translation_config["info"][2]: genre, - translation_config["info"][3]: album_name, - translation_config["info"][4]: lyric, - } - - # 2. Show info dataframe - info_df = pd.DataFrame( - list(info.values()), - index=list(info.keys()), - columns=[translation_config["dataframe"]["col"]], - ) - info_table = st.dataframe(info_df, use_container_width=True) - - # 3. Inference - request_info = { - request_config["info"][0]: song_names, - request_config["info"][1]: artist_name, - request_config["info"][2]: genre, - request_config["info"][3]: album_name, - request_config["info"][4]: lyric, - } - - st.header(translation_config["inference"]["header"]) - gen_button = st.button( - translation_config["inference"]["button_message"], use_container_width=True - ) - - # Set 'img' variable to session state - if "img" not in st.session_state: - st.session_state.img = ["", "", "", ""] - - # Create 2x2 grid - col1, col2 = st.columns(2) - cols = [col1, col2, col1, col2] - - if gen_button: - with st.spinner(translation_config["inference"]["wait_message"]): - # Call the FastAPI server to generate the album cover - response = requests.post(request_config["gen_address"], json=request_info) - if response.status_code == 200: - image_urls = response.json()["images"] - - # Assign images to the cells in the grid - for i, url in enumerate(image_urls): - response = requests.get(url) - img = Image.open(BytesIO(response.content)) - st.session_state.img[i] = img - cols[i].image(img, width=300) - else: - st.error(translation_config["inference"]["fail_meesage"]["generate"]) - else: - if st.session_state.img == ["", "", "", ""]: - for i in range(len(cols)): - cols[i].empty() - else: - for i in range(len(cols)): - cols[i].image(st.session_state.img[i], width=300) - - # 4. Review - with st.expander(translation_config["review"]["expander"]): - review_rating = st.select_slider( - translation_config["review"]["rating"], options=range(1, 6), value=5 - ) - st.markdown( - "
" - + "❤️" * review_rating - + "🖤" * (5 - review_rating) - + "
", - unsafe_allow_html=True, - ) - review_text = st.text_input( - translation_config["review"]["review_text"]["text"], - placeholder=translation_config["review"]["review_text"]["placeholder"], - label_visibility="hidden", - ) - review_btn = st.button(translation_config["review"]["button"]) - if review_btn: - review = {"rating": review_rating, "comment": review_text} - response = requests.post(request_config["review_address"], json=review) - if response.status_code == 200: - st.write(translation_config["review"]["message"]["success"]) - else: - st.error(translation_config["review"]["message"]["fail"]) - - -if __name__ == "__main__": - main() diff --git a/src/redis_celery/utils.py b/src/redis_celery/utils.py deleted file mode 100644 index aa8bd54f..00000000 --- a/src/redis_celery/utils.py +++ /dev/null @@ -1,9 +0,0 @@ -# Other modules -import yaml - - -def load_yaml(yaml_path: str, key: str = None) -> dict: - with open(yaml_path, "r") as config_file: - config = yaml.safe_load(config_file) - - return config[key] if key else config diff --git a/src/redis_celery/worker.py b/src/redis_celery/worker.py deleted file mode 100644 index 25622216..00000000 --- a/src/redis_celery/worker.py +++ /dev/null @@ -1,111 +0,0 @@ -# Python Built-in modules -import os -import io -import base64 - -# Pytorch -import torch -from torch import cuda - -# ETC -from PIL import Image -import numpy as np -from numpy import random - -# Celery -from celery import Celery -from celery import signals - -# User Defined modules -from .model import AlbumModel -from .gpt3_api import get_description -from .gcp.cloud_storage import GCSUploader -from .utils import load_yaml - - -# Load config -gcp_config = load_yaml(os.path.join("src/redis_celery/config", "private.yaml"), "gcp") -redis_config = load_yaml( - os.path.join("src/redis_celery/config", "private.yaml"), "redis" -) -celery_config = load_yaml( - os.path.join("src/redis_celery/config", "public.yaml"), "celery" -) -public_config = load_yaml(os.path.join("src/redis_celery/config", "public.yaml")) - - -# f'redis://{redis_config["host"]}:{redis_config["port"]}/{redis_config["db"]}' -# Initialize Celery -celery_app = Celery( - "tasks", - broker=redis_config["redis_server_ip"], - backend=redis_config["redis_server_ip"], -) -celery_app.conf.broker_connection_retry_on_startup = celery_config[ - "broker_connection_retry_on_startup" -] -celery_app.conf.worker_pool = celery_config["worker_pool"] - -gcs_uploader = GCSUploader(gcp_config) - - -@signals.worker_process_init.connect -def setup_worker_init(*args, **kwargs): - device = "cuda" if cuda.is_available() else "cpu" - global model - model = AlbumModel(public_config["model"], public_config["language"], device) - model.get_model() - - -@celery_app.task(name="generate_cover") -def generate_cover(album, request_id): - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" - device = "cuda" if cuda.is_available() else "cpu" - - image_urls = [] - - summarization = get_description( - album["lyric"], album["artist_name"], album["album_name"], album["song_names"] - ) - - seeds = np.random.randint( - public_config["generate"]["max_seed"], size=public_config["generate"]["n_gen"] - ) - - for i, seed in enumerate(seeds): - generator = torch.Generator(device=device).manual_seed(int(seed)) - - # Generate Images - with torch.no_grad(): - image = model.pipeline( - prompt=f"A photo of a {album['genre']} album cover with a {summarization} atmosphere visualized.", - num_inference_steps=20, - generator=generator, - ).images[0] - - image = image.resize( - (public_config["generate"]["height"], public_config["generate"]["width"]) - ) - - # Convert to base64-encoded string - byte_arr = io.BytesIO() - image.save(byte_arr, format=public_config["generate"]["save_format"]) - byte_arr = byte_arr.getvalue() - base64_str = base64.b64encode(byte_arr).decode() - - # Upload to GCS - image_url = gcs_uploader.save_image_to_gcs( - byte_arr, - f"{request_id}_image_{i}.{public_config['generate']['save_format']}", - ) - image_urls.append(image_url) - - return { - "image_urls": image_urls, - "summarization": summarization, - } - - -# Start the worker -if __name__ == "__main__": - celery_app.start()