diff --git a/.env.example b/.env.example
new file mode 100644
index 0000000..41ba230
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,24 @@
+# Model paths (the Docker image overrides this with /app/models)
+CKPT_DIR=./models
+
+# Default generation parameters
+DEFAULT_NUM_INFERENCE_STEPS=40
+DEFAULT_NUM_IMAGES_PER_PROMPT=1
+DEFAULT_GUIDANCE_SCALE=3.0
+DEFAULT_HEIGHT=480
+DEFAULT_WIDTH=704
+DEFAULT_NUM_FRAMES=121
+DEFAULT_FRAME_RATE=25
+DEFAULT_NEGATIVE_PROMPT="worst quality, inconsistent motion, blurry, jittery, distorted"
+
+# Model precision configuration
+USE_BFLOAT16=true
+
+# Output configuration (the Docker image overrides this with /app/outputs)
+OUTPUT_DIR=./outputs
+
+# Version configuration for Docker tags
+VERSION_PREFIX=v
+
+# Logging level
+LOG_LEVEL=INFO
diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml
new file mode 100644
index 0000000..ec449cc
--- /dev/null
+++ b/.github/workflows/docker-build.yml
@@ -0,0 +1,63 @@
+name: Docker Build
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+
+jobs:
+  ltx-video-api:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4.1.7
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3.6.1
+
+      - name: Cache Docker layers
+        uses: actions/cache@v3
+        with:
+          path: /tmp/.buildx-cache
+          key: ${{ runner.os }}-buildx-${{ github.sha }}
+          restore-keys: |
+            ${{ runner.os }}-buildx-
+
+      - name: Docker Login
+        uses: docker/login-action@v3.3.0
+        with:
+          username: ${{ secrets.DOCKERHUB_USERNAME }}
+          password: ${{ secrets.DOCKERHUB_PASSWORD }}
+
+      - name: Generate version tags
+        id: tags
+        run: |
+          # Get current date in YYYYMMDD format
+          DATE_TAG=$(date +'%Y%m%d')
+          # Get short SHA
+          SHA_TAG=$(echo ${{ github.sha }} | cut -c1-7)
+          # Load version prefix from .env.example (only the example file is committed)
+          VERSION_PREFIX=$(grep VERSION_PREFIX .env.example | cut -d '=' -f2)
+          # Create tag list
+          TAGS="${{ secrets.DOCKERHUB_USERNAME }}/ltx-video-api:latest,${{ secrets.DOCKERHUB_USERNAME }}/ltx-video-api:${VERSION_PREFIX}${DATE_TAG},${{ secrets.DOCKERHUB_USERNAME }}/ltx-video-api:${VERSION_PREFIX}${SHA_TAG}"
+          echo "tags=${TAGS}" >> $GITHUB_OUTPUT
+
+      - name: Build and push LTX Video API Docker image
+        uses: docker/build-push-action@v6.7.0
+        with:
+          context: .
+          file: ./Dockerfile
+          push: true
+          tags: ${{ steps.tags.outputs.tags }}
+          cache-from: type=local,src=/tmp/.buildx-cache
+          cache-to: type=local,dest=/tmp/.buildx-cache-new,mode=max
+          platforms: linux/amd64
+
+      # Move cache to prevent cache growth
+      - name: Move cache
+        run: |
+          rm -rf /tmp/.buildx-cache
+          mv /tmp/.buildx-cache-new /tmp/.buildx-cache
diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml
deleted file mode 100644
index a07ba7b..0000000
--- a/.github/workflows/pylint.yml
+++ /dev/null
@@ -1,27 +0,0 @@
-name: Ruff
-
-on: [push]
-
-jobs:
-  build:
-    runs-on: ubuntu-latest
-    strategy:
-      matrix:
-        python-version: ["3.10"]
-    steps:
-      - name: Checkout repository and submodules
-        uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v3
-        with:
-          python-version: ${{ matrix.python-version }}
-      - name: Install dependencies
-        run: |
-          python -m pip install --upgrade pip
-          pip install ruff==0.2.2 black==24.2.0
-      - name: Analyzing the code with ruff
-        run: |
-          ruff $(git ls-files '*.py')
-      - name: Verify that no Black changes are required
-        run: |
-          black --check $(git ls-files '*.py')
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..23a8c66
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,44 @@
+# Use NVIDIA CUDA base image
+FROM nvidia/cuda:12.2.0-runtime-ubuntu22.04
+
+# Set working directory
+WORKDIR /app
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    python3-pip \
+    python3-dev \
+    git \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy requirements first to leverage Docker cache
+COPY requirements.txt .
+
+# Install Python dependencies
+RUN pip3 install --no-cache-dir -r requirements.txt
+
+# Copy application files
+COPY . .
+
+# Install LTX-Video package and inference dependencies
+RUN pip3 install . && \
+    pip3 install accelerate matplotlib "imageio[ffmpeg]"
+
+# Set environment variables
+ENV PYTHONUNBUFFERED=1
+ENV CKPT_DIR=/app/models
+ENV OUTPUT_DIR=/app/outputs
+
+# Create directories
+RUN mkdir -p /app/models /app/outputs
+
+# Expose port
+EXPOSE 8000
+
+# Run the FastAPI server
+CMD ["python3", "api.py"]
+
+# Add labels
+LABEL maintainer="Lightricks"
+LABEL description="LTX-Video API service"
+LABEL version="1.0"
diff --git a/api.py b/api.py
new file mode 100644
index 0000000..7540363
--- /dev/null
+++ b/api.py
@@ -0,0 +1,428 @@
+import os
+import json
+import logging
+from contextlib import asynccontextmanager
+from datetime import datetime
+from pathlib import Path
+from typing import Optional, Dict, Any
+from fastapi import FastAPI, File, Form, UploadFile, HTTPException
+from fastapi.responses import FileResponse, JSONResponse
+from pydantic import BaseModel, Field
+from dotenv import load_dotenv
+import torch
+from huggingface_hub import snapshot_download
+from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
+from inference import (
+    load_vae,
+    load_unet,
+    load_scheduler,
+    load_image_to_tensor_with_resize_and_crop,
+    calculate_padding,
+    get_unique_filename,
+    seed_everething,
+    SymmetricPatchifier,
+    ConditioningMethod
+)
+from transformers import T5EncoderModel, T5Tokenizer
+import imageio
+import numpy as np
+
+# Load environment variables
+load_dotenv()
+
+# Configure logging
+logging.basicConfig(
+    level=os.getenv("LOG_LEVEL", "INFO"),
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+# Global variables for loaded models
+pipeline = None
+generator = None
+
+# Constants from inference.py
+MAX_HEIGHT = 720
+MAX_WIDTH = 1280
+MAX_NUM_FRAMES = 257
+
+# Model precision configuration
+USE_BFLOAT16 = os.getenv("USE_BFLOAT16", "true").lower() == "true"
+
+# Download configuration
+MAX_WORKERS = int(os.getenv("MAX_WORKERS", "16"))  # Number of parallel workers for downloading
+
+def resolve_dir(env_var: str, default: str) -> Path:
+    """Resolve a directory from the environment, keeping absolute paths
+    (e.g. /app/models inside the container) and anchoring relative ones
+    at the current working directory."""
+    path = Path(os.getenv(env_var, default))
+    return path if path.is_absolute() else Path.cwd() / path
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    global pipeline, generator
+
+    logger.info("Starting LTX Video API server...")
+
+    try:
+        ckpt_dir = resolve_dir("CKPT_DIR", "models")
+
+        # Download model if not exists
+        if not ckpt_dir.exists() or not any(ckpt_dir.iterdir()):
+            logger.info(f"Model not found in {ckpt_dir}. Downloading with {MAX_WORKERS} workers...")
+            ckpt_dir.mkdir(parents=True, exist_ok=True)
+
+            # Note: interrupted downloads resume automatically with recent
+            # huggingface_hub, so no resume_download flag is needed
+            snapshot_download(
+                "Lightricks/LTX-Video",
+                local_dir=str(ckpt_dir),
+                repo_type='model',
+                max_workers=MAX_WORKERS,  # Enable parallel downloading
+                etag_timeout=30  # Increase timeout for better stability
+            )
+            logger.info("Model download completed successfully")
+
+        unet_dir = ckpt_dir / "unet"
+        vae_dir = ckpt_dir / "vae"
+        scheduler_dir = ckpt_dir / "scheduler"
+
+        logger.info("Loading VAE model...")
+        vae = load_vae(vae_dir)  # This will load the VAE in bfloat16
+
+        logger.info("Loading UNet model...")
+        unet = load_unet(unet_dir)
+        if USE_BFLOAT16 and unet.dtype != torch.bfloat16:
+            logger.info("Converting UNet to bfloat16...")
+            unet = unet.to(torch.bfloat16)
+
+        logger.info("Loading scheduler...")
+        scheduler = load_scheduler(scheduler_dir)
+        patchifier = SymmetricPatchifier(patch_size=1)
+
+        logger.info("Loading text encoder and tokenizer...")
+        text_encoder = T5EncoderModel.from_pretrained(
+            "PixArt-alpha/PixArt-XL-2-1024-MS",
+            subfolder="text_encoder"
+        )
+        if torch.cuda.is_available():
+            text_encoder = text_encoder.to("cuda")
+
+        tokenizer = T5Tokenizer.from_pretrained(
+            "PixArt-alpha/PixArt-XL-2-1024-MS",
+            subfolder="tokenizer"
+        )
+
+        # Initialize pipeline
+        logger.info("Initializing LTX Video pipeline...")
+        pipeline = LTXVideoPipeline(
+            transformer=unet,
+            patchifier=patchifier,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            scheduler=scheduler,
+            vae=vae
+        )
+
+        if torch.cuda.is_available():
+            pipeline = pipeline.to("cuda")
+            logger.info("Pipeline moved to CUDA")
+
+        # Initialize generator
+        generator = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu")
+
+        logger.info("Server startup complete!")
+
+        yield
+
+    except Exception as e:
+        # No request context exists here, so let startup failures propagate
+        # instead of raising an HTTPException
+        logger.error(f"Error loading models: {str(e)}")
+        raise
+    finally:
+        # Release model references on shutdown
+        pipeline = None
+        generator = None
+
+# Create FastAPI app with lifespan
+app = FastAPI(
+    title="LTX Video Generation API",
+    description="API for generating videos using LTX-Video model",
+    version="1.0.0",
+    lifespan=lifespan
+)
+
+class GenerationParams(BaseModel):
+    prompt: str
+    negative_prompt: Optional[str] = Field(
+        default=os.getenv(
+            "DEFAULT_NEGATIVE_PROMPT",
+            "worst quality, inconsistent motion, blurry, jittery, distorted"
+        ),
+        description="Negative prompt for undesired features"
+    )
+    num_inference_steps: Optional[int] = Field(
+        default=int(os.getenv("DEFAULT_NUM_INFERENCE_STEPS", 40)),
+        description="Number of inference steps"
+    )
+    guidance_scale: Optional[float] = Field(
+        default=float(os.getenv("DEFAULT_GUIDANCE_SCALE", 3.0)),
+        description="Guidance scale for the pipeline"
+    )
+    height: Optional[int] = Field(
+        default=int(os.getenv("DEFAULT_HEIGHT", 480)),
+        description="Height of the output video frames"
+    )
+    width: Optional[int] = Field(
+        default=int(os.getenv("DEFAULT_WIDTH", 704)),
+        description="Width of the output video frames"
+    )
+    num_frames: Optional[int] = Field(
+        default=int(os.getenv("DEFAULT_NUM_FRAMES", 121)),
+        description="Number of frames to generate"
+    )
+    frame_rate: Optional[int] = Field(
+        default=int(os.getenv("DEFAULT_FRAME_RATE", 25)),
+        description="Frame rate for the output video"
+    )
+    seed: Optional[int] = Field(
+        default=171198,
+        description="Random seed for generation"
+    )
+    use_mixed_precision: Optional[bool] = Field(
+        default=not USE_BFLOAT16,
+        description="Whether to use mixed precision during inference"
+    )
+
+    class Config:
+        json_schema_extra = {
+            "example": {
+                "prompt": "A clear, turquoise river flows through a rocky canyon",
+                "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
+                "num_inference_steps": 40,
+                "guidance_scale": 3.0,
+                "height": 480,
+                "width": 704,
+                "num_frames": 121,
+                "frame_rate": 25,
+                "seed": 171198,
+                "use_mixed_precision": False
+            }
+        }
+
+def validate_dimensions(height: int, width: int, num_frames: int):
+    """Validate input dimensions against maximum allowed values"""
+    if height > MAX_HEIGHT or width > MAX_WIDTH or num_frames > MAX_NUM_FRAMES:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Input dimensions {height}x{width}x{num_frames} exceed maximum allowed values ({MAX_HEIGHT}x{MAX_WIDTH}x{MAX_NUM_FRAMES})"
+        )
+
+@app.get("/health")
+async def health_check() -> Dict[str, Any]:
+    """Check if the service is healthy and models are loaded"""
+    if pipeline is None:
+        return JSONResponse(
+            status_code=503,
+            content={"status": "unhealthy", "message": "Models not loaded"}
+        )
+    return {
+        "status": "healthy",
+        "cuda_available": torch.cuda.is_available(),
+        "cuda_device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
+        "using_bfloat16": USE_BFLOAT16
+    }
+
+@app.post("/generate/text-to-video")
+async def generate_text_to_video(params: GenerationParams):
+    """Generate video from text prompt"""
+    try:
+        logger.info(f"Starting text-to-video generation with prompt: {params.prompt}")
+
+        # Validate dimensions
+        validate_dimensions(params.height, params.width, params.num_frames)
+
+        # Set seed
+        seed_everething(params.seed)
+        generator.manual_seed(params.seed)
+
+        # Pad spatial dimensions up to multiples of 32 and the frame count
+        # to 8k + 1, as the model requires
+        height_padded = ((params.height - 1) // 32 + 1) * 32
+        width_padded = ((params.width - 1) // 32 + 1) * 32
+        num_frames_padded = ((params.num_frames - 2) // 8 + 1) * 8 + 1
+
+        logger.info(f"Using padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}")
+
+        # Prepare input
+        sample = {
+            "prompt": params.prompt,
+            "prompt_attention_mask": None,
+            "negative_prompt": params.negative_prompt,
+            "negative_prompt_attention_mask": None,
+            "media_items": None,
+        }
+
+        # Generate video
+        logger.info("Generating video...")
+        images = pipeline(
+            num_inference_steps=params.num_inference_steps,
+            num_images_per_prompt=1,
+            guidance_scale=params.guidance_scale,
+            generator=generator,
+            output_type="pt",
+            callback_on_step_end=None,
+            height=height_padded,
+            width=width_padded,
+            num_frames=num_frames_padded,
+            frame_rate=params.frame_rate,
+            **sample,
+            is_video=True,
+            vae_per_channel_normalize=True,
+            conditioning_method=ConditioningMethod.UNCONDITIONAL,
+            mixed_precision=params.use_mixed_precision,
+        ).images
+
+        # Save video under OUTPUT_DIR, one subdirectory per day
+        output_dir = resolve_dir("OUTPUT_DIR", "outputs") / datetime.today().strftime('%Y-%m-%d')
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        # Process video frames
+        video_np = images[0].permute(1, 2, 3, 0).cpu().float().numpy()
+        video_np = (video_np * 255).astype(np.uint8)
+
+        # Get output filename
+        output_filename = get_unique_filename(
+            "text_to_vid_0",
+            ".mp4",
+            prompt=params.prompt,
+            seed=params.seed,
+            resolution=(params.height, params.width, params.num_frames),
+            dir=output_dir
+        )
+
+        # Write video file
+        logger.info(f"Saving video to {output_filename}")
+        with imageio.get_writer(output_filename, fps=params.frame_rate) as video:
+            for frame in video_np:
+                video.append_data(frame)
+
+        return FileResponse(
+            output_filename,
+            media_type="video/mp4",
+            filename=output_filename.name
+        )
+
+    except HTTPException:
+        # Preserve deliberate HTTP errors (e.g. 400 from dimension validation)
+        raise
+    except Exception as e:
+        logger.error(f"Error in text-to-video generation: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.post("/generate/image-to-video")
+async def generate_image_to_video(
+    file: UploadFile = File(...),
+    params_json: str = Form(...),
+):
+    """Generate video from input image and text prompt"""
+    try:
+        # A multipart request cannot carry a Pydantic body model directly,
+        # so the generation parameters arrive as a JSON-encoded form field
+        params = GenerationParams(**json.loads(params_json))
+
+        logger.info(f"Starting image-to-video generation with prompt: {params.prompt}")
+
+        # Validate dimensions
+        validate_dimensions(params.height, params.width, params.num_frames)
+
+        # Save uploaded image temporarily
+        temp_image_path = f"temp_{file.filename}"
+        try:
+            with open(temp_image_path, "wb") as buffer:
+                content = await file.read()
+                buffer.write(content)
+
+            logger.info("Processing input image...")
+            # Load and process image
+            media_items_prepad = load_image_to_tensor_with_resize_and_crop(
+                temp_image_path,
+                params.height,
+                params.width
+            )
+        finally:
+            # Clean up temporary file
+            if os.path.exists(temp_image_path):
+                os.remove(temp_image_path)
+
+        # Pad spatial dimensions up to multiples of 32 and the frame count
+        # to 8k + 1, as the model requires
+        height_padded = ((params.height - 1) // 32 + 1) * 32
+        width_padded = ((params.width - 1) // 32 + 1) * 32
+        num_frames_padded = ((params.num_frames - 2) // 8 + 1) * 8 + 1
+
+        logger.info(f"Using padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}")
+
+        # Calculate padding
+        padding = calculate_padding(params.height, params.width, height_padded, width_padded)
+
+        # Apply padding to media items
+        media_items = torch.nn.functional.pad(
+            media_items_prepad,
+            padding,
+            mode="constant",
+            value=-1
+        )
+
+        # Set seed
+        seed_everething(params.seed)
+        generator.manual_seed(params.seed)
+
+        # Prepare input
+        sample = {
+            "prompt": params.prompt,
+            "prompt_attention_mask": None,
+            "negative_prompt": params.negative_prompt,
+            "negative_prompt_attention_mask": None,
+            "media_items": media_items,
+        }
+
+        # Generate video
+        logger.info("Generating video...")
+        images = pipeline(
+            num_inference_steps=params.num_inference_steps,
+            num_images_per_prompt=1,
+            guidance_scale=params.guidance_scale,
+            generator=generator,
+            output_type="pt",
+            callback_on_step_end=None,
+            height=height_padded,
+            width=width_padded,
+            num_frames=num_frames_padded,
+            frame_rate=params.frame_rate,
+            **sample,
+            is_video=True,
+            vae_per_channel_normalize=True,
+            conditioning_method=ConditioningMethod.FIRST_FRAME,
+            mixed_precision=params.use_mixed_precision,
+        ).images
+
+        # Save video under OUTPUT_DIR, one subdirectory per day
+        output_dir = resolve_dir("OUTPUT_DIR", "outputs") / datetime.today().strftime('%Y-%m-%d')
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        # Process video frames
+        video_np = images[0].permute(1, 2, 3, 0).cpu().float().numpy()
+        video_np = (video_np * 255).astype(np.uint8)
+
+        # Get output filename
+        output_filename = get_unique_filename(
+            "img_to_vid_0",
+            ".mp4",
+            prompt=params.prompt,
+            seed=params.seed,
+            resolution=(params.height, params.width, params.num_frames),
+            dir=output_dir
+        )
+
+        # Write video file
+        logger.info(f"Saving video to {output_filename}")
+        with imageio.get_writer(output_filename, fps=params.frame_rate) as video:
+            for frame in video_np:
+                video.append_data(frame)
+
+        return FileResponse(
+            output_filename,
+            media_type="video/mp4",
+            filename=output_filename.name
+        )
+
+    except HTTPException:
+        # Preserve deliberate HTTP errors (e.g. 400 from dimension validation)
+        raise
+    except Exception as e:
+        logger.error(f"Error in image-to-video generation: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
diff --git a/docker-compose.yaml b/docker-compose.yaml
new file mode 100644
index 0000000..0614271
--- /dev/null
+++ b/docker-compose.yaml
@@ -0,0 +1,24 @@
+services:
+  ltx-video-api:
+    build:
+      context: .
+      dockerfile: Dockerfile
+    image: ltx-video-api
+    container_name: ltx-video-api
+    ports:
+      - "8000:8000"
+    volumes:
+      - ./models:/app/models
+      - ./outputs:/app/outputs
+    environment:
+      - CKPT_DIR=/app/models
+      - OUTPUT_DIR=/app/outputs
+      - LOG_LEVEL=INFO
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              count: 1
+              capabilities: [gpu]
+    restart: unless-stopped
diff --git a/pyproject.toml b/pyproject.toml
index 3f17249..0101ce6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,13 +22,22 @@ dependencies = [
     "transformers~=4.44.2",
     "sentencepiece>=0.1.96",
     "huggingface-hub~=0.25.2",
-    "einops"
+    "einops",
+    "fastapi",
+    "pydantic",
+    "python-dotenv",
+    "uvicorn",
+    "imageio[ffmpeg]"
 ]
 
 [project.optional-dependencies]
 # Instead of thinking of them as optional, think of them as specific modes
 inference-script = [
     "accelerate",
-    "matplotlib",
-    "imageio[ffmpeg]"
+    "matplotlib"
 ]
+
+# Use a find directive so subpackages (ltx_video.models, ltx_video.pipelines, ...)
+# are included; a bare packages = ["ltx_video"] would miss them
+[tool.setuptools.packages.find]
+include = ["ltx_video*"]
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..e7cf01e
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,6 @@
+fastapi==0.104.1
+python-dotenv==1.0.0
+uvicorn==0.24.0
+python-multipart==0.0.6
+huggingface_hub
+requests==2.31.0
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..2fdf24c
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,16 @@
+[metadata]
+name = ltx-video
+
+[options]
+packages = find:
+package_dir =
+    = .
+
+[options.packages.find]
+where = .
+exclude =
+    docs*
+    tests*
+    *.tests
+    *.tests.*
+    tests.*
diff --git a/testapi.py b/testapi.py
new file mode 100644
index 0000000..daedf1c
--- /dev/null
+++ b/testapi.py
@@ -0,0 +1,112 @@
+import os
+import logging
+from typing import Optional, Dict, Any
+import requests
+from dotenv import load_dotenv
+
+# Load environment variables
+load_dotenv()
+
+# Configure logging based on .env configuration
+logging.basicConfig(
+    level=getattr(logging, os.getenv('LOG_LEVEL', 'INFO')),
+    format='%(asctime)s - %(levelname)s: %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+class LTXVideoAPIClient:
+    def __init__(
+        self,
+        endpoint: Optional[str] = None,
+        timeout: int = 300  # Increased timeout for longer video generation
+    ):
+        """
+        Initialize LTX Video API Client
+
+        Args:
+            endpoint (str, optional): API endpoint. Defaults to localhost.
+            timeout (int, optional): Request timeout in seconds. Defaults to 300.
+        """
+        self.base_url = endpoint or 'http://localhost:8000'
+        self.endpoint = f'{self.base_url}/generate/text-to-video'
+        self.timeout = timeout
+
+    def generate_video(
+        self,
+        prompt: str,
+        params: Optional[Dict[str, Any]] = None,
+        output_file: str = 'output.mp4'
+    ) -> str:
+        """
+        Generate video from text prompt
+
+        Args:
+            prompt (str): Detailed text description for video generation
+            params (dict, optional): Additional generation parameters
+            output_file (str, optional): Path where the returned MP4 is saved
+
+        Returns:
+            Path of the saved video file
+        """
+        # Default generation parameters with maximum frames
+        default_params = {
+            'guidance_scale': 3.5,
+            'num_inference_steps': 40,
+            'height': 720,
+            'width': 1280,
+            'num_frames': 257  # Maximum supported frame count
+        }
+
+        # Merge default and user-provided params
+        generation_params = {**default_params, **(params or {})}
+
+        payload = {
+            'prompt': prompt,
+            **generation_params
+        }
+
+        try:
+            logger.info(f"Sending video generation request for prompt: {prompt}")
+            logger.info(f"Generation parameters: {generation_params}")
+
+            response = requests.post(
+                self.endpoint,
+                json=payload,
+                timeout=self.timeout
+            )
+
+            response.raise_for_status()  # Raise exception for bad status codes
+
+            # The endpoint returns the rendered MP4 file, not JSON
+            with open(output_file, 'wb') as f:
+                f.write(response.content)
+
+            logger.info(f"Video generation request successful, saved to {output_file}")
+            return output_file
+
+        except requests.RequestException as e:
+            logger.error(f"Video generation failed: {e}")
+            logger.error(f"Request details - Endpoint: {self.endpoint}, Payload: {payload}")
+            raise
+
+def main():
+    """Example usage of LTXVideoAPIClient"""
+    try:
+        client = LTXVideoAPIClient()
+
+        # Example prompt with added detail
+        prompt = (
+            "A woman with long brown hair walks through a sunlit forest. "
+            "She moves gracefully between tall pine trees, her hand gently "
+            "brushing against the bark. The camera follows her from behind, "
+            "capturing her movement and the dappled light filtering through "
+            "the canopy. Her blue jacket contrasts with the green forest, "
+            "highly detailed, cinematic, smooth camera movement."
+        )
+
+        # Generate video with maximum frames
+        output_path = client.generate_video(prompt)
+        logger.info("Video generation successful!")
+        logger.info(f"Video saved to: {output_path}")
+
+    except Exception as e:
+        logger.error(f"Unexpected error in video generation: {e}")
+        import traceback
+        traceback.print_exc()
+
+if __name__ == '__main__':
+    main()
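
Usage note (a sketch, not part of the patch): a minimal smoke test once the stack
is running, for example after "docker compose up -d --build". It assumes the
service is reachable on localhost:8000 and uses only the endpoints defined in
api.py; the output file name is a placeholder.

    import requests

    # Readiness: /health returns 503 while the models are still downloading/loading
    health = requests.get("http://localhost:8000/health", timeout=10)
    print(health.status_code, health.json())

    # Text-to-video: the endpoint streams back an MP4, so write the body to disk
    response = requests.post(
        "http://localhost:8000/generate/text-to-video",
        json={"prompt": "A clear, turquoise river flows through a rocky canyon"},
        timeout=600,  # generation can take several minutes
    )
    response.raise_for_status()
    with open("river.mp4", "wb") as f:
        f.write(response.content)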
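testapi.py only exercises the text-to-video route. Below is a sketch of an
image-to-video call matching the multipart signature in api.py above (a file
upload plus a JSON-encoded "params_json" form field); the input image path and
the parameter values are placeholders.

    import json

    import requests

    # Only "prompt" is required; other fields fall back to GenerationParams defaults
    params = {
        "prompt": "The photo comes to life with slow, cinematic camera motion",
        "num_frames": 121,
    }

    with open("input.jpg", "rb") as image:  # placeholder input image
        response = requests.post(
            "http://localhost:8000/generate/image-to-video",
            files={"file": ("input.jpg", image, "image/jpeg")},
            data={"params_json": json.dumps(params)},
            timeout=600,  # generation can take several minutes
        )

    response.raise_for_status()
    with open("img_to_vid.mp4", "wb") as out:
        out.write(response.content)  # the endpoint returns the MP4 bytes directly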