forked from brevdev/simple-dreambooth-api
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path main.py
115 lines (96 loc) · 3.98 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import base64
from celery import Celery
from fastapi import FastAPI
from uuid import uuid4
from fastapi import FastAPI, File, UploadFile, status
from fastapi.exceptions import HTTPException
import aiofiles
import os
import zipfile
import shutil
from celery.result import AsyncResult
from fastapi.responses import FileResponse
app = FastAPI()

# Celery app used both to enqueue jobs (from the API endpoints) and to run the
# `train` / `inference` tasks; broker and result backend are the same local Redis DB.
celery = Celery(
    __name__,
    broker="redis://127.0.0.1:6379/0",
    backend="redis://127.0.0.1:6379/0"
)

CHUNK_SIZE = 1024 * 1024  # bytes read per chunk when streaming an upload to disk; adjust as desired

# Working directories, relative to the process's current working directory.
temporaryZipFileDirectory = './userzipfiles/'  # uploaded archives are written here first
extractedFilesDirectory = './extractedfiles/'  # archives are unpacked here before filtering
outputModelsDirectory = './outputmodels/'  # one sub-directory per fine-tuned model

# exist_ok=True replaces the exists()-then-makedirs() pattern, which raises
# FileExistsError if another process (presumably the API server and a Celery
# worker both import this module) creates the directory in between.
os.makedirs(temporaryZipFileDirectory, exist_ok=True)
os.makedirs(extractedFilesDirectory, exist_ok=True)
os.makedirs(outputModelsDirectory, exist_ok=True)
@app.post("/finetune")
async def upload(zipFile: UploadFile = File(...)):
    """Accept a zip of training images and queue a DreamBooth fine-tuning job.

    The archive is saved and unpacked by ``saveZipFile``; the resulting image
    directory and a fresh per-model output directory under
    ``outputModelsDirectory`` are handed to the ``train`` Celery task.

    Returns:
        JSON object with a human-readable message, the Celery task ID (to poll
        ``/finetunejobstatus``) and the model ID (to pass to ``/inference`` and
        ``/inferenceoutput``).
    """
    # A single random ID names every per-job artifact (upload, images, model).
    modelId = str(uuid4())
    imageDirectory = await saveZipFile(modelId, zipFile)
    outputDirectory = os.path.join(outputModelsDirectory, modelId)
    task = train.delay(imageDirectory, outputDirectory)
    return {
        "message": "Job successfully submitted. This should take about 5 minutes to run depending on the queue.",
        "Task ID": task.id,
        "Model ID": modelId,
    }
@app.get("/finetunejobstatus")
async def jobStatus(jobid: str):
    """Return the Celery state string of the fine-tuning job ``jobid``.

    ``jobid`` is the "Task ID" value returned by ``/finetune``; the state is
    read from the result backend configured on ``celery``.

    NOTE(review): Celery reports task IDs it has never seen as PENDING, so a
    mistyped ID looks like a queued job — confirm whether callers care.
    """
    taskResult = AsyncResult(jobid, app=celery)
    currentState = taskResult.status
    return currentState
@app.post("/inference")
async def inference(prompt: str, modelId: str):
    """Queue an image-generation job for a previously fine-tuned model.

    Args:
        prompt: text prompt for the image.
        modelId: the "Model ID" returned by ``/finetune``. It is validated
            before use because it becomes part of filesystem paths and of the
            command the worker task runs.

    Returns:
        JSON object with a message, the model ID (to poll ``/inferenceoutput``)
        and the Celery task ID of the queued job.

    Raises:
        HTTPException: 400 if ``modelId`` is not a canonical UUID string.
    """
    from uuid import UUID

    # Only accept the exact form /finetune hands out (lower-case, hyphenated);
    # this rules out path separators, "..", quotes and shell metacharacters.
    try:
        isCanonicalId = str(UUID(modelId)) == modelId
    except ValueError:
        isCanonicalId = False
    if not isCanonicalId:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
                            detail='modelId must be a Model ID returned by /finetune')

    # "800" mirrors --max_train_steps=800 in `train`; presumably
    # train_dreambooth.py writes the final weights to a step-numbered
    # sub-directory — verify if the training step count changes.
    modelPath = os.path.join(outputModelsDirectory, modelId, "800")
    # NOTE: the module-level name `inference` is rebound further down this file
    # to the Celery task of the same name, so at request time this call
    # resolves to the task, not to this endpoint function.
    task = inference.delay(prompt, modelId, modelPath)
    return {
        "message": "Your inference job has been submitted. You can get the result by querying the /inferenceoutput endpoint with the model ID.",
        "Model Id": modelId,
        "Task ID": task.id,
    }
@app.get("/inferenceoutput")
async def inferenceoutput(modelId: str):
    """Return the generated image for ``modelId`` once the inference task has written it.

    The image is looked up at ``./<modelId>.png``. Until that file exists a
    JSON "not ready" message is returned instead.

    NOTE: if the inference task failed, the file never appears and this
    endpoint keeps returning the not-ready message.

    Raises:
        HTTPException: 400 if ``modelId`` is not a canonical UUID string.
    """
    from uuid import UUID

    # Reject anything but the ID format /finetune issues, so the value cannot
    # steer the lookup to another path (e.g. "../something").
    try:
        isCanonicalId = str(UUID(modelId)) == modelId
    except ValueError:
        isCanonicalId = False
    if not isCanonicalId:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
                            detail='modelId must be a Model ID returned by /finetune')

    outputImgPath = f'./{modelId}.png'
    if not os.path.exists(outputImgPath):
        return {"message": "Your inference job is still either in the queue or is running. Please try again later."}
    return FileResponse(outputImgPath)
@celery.task
def train(dataDirectory, outputDirectory):
    """Fine-tune CompVis/stable-diffusion-v1-4 with DreamBooth on ``dataDirectory``.

    Runs ``train_dreambooth.py`` via ``accelerate launch`` (both resolved from
    the worker's working directory / PATH) and writes the result to
    ``outputDirectory``. Hyper-parameters are fixed below.

    Raises:
        subprocess.CalledProcessError: if the training command exits non-zero.
            Letting it propagate makes Celery record the task as failed, so
            ``/finetunejobstatus`` no longer reports a broken run as a success
            (``os.system`` previously discarded the exit status).
    """
    # TODO: maybe expose more of the training parameters.
    import subprocess

    # Arguments are passed as a list with no shell, so directory names are
    # never re-parsed as shell syntax.
    subprocess.run(
        [
            "accelerate", "launch", "train_dreambooth.py",
            "--pretrained_model_name_or_path=CompVis/stable-diffusion-v1-4",
            f"--instance_data_dir={dataDirectory}",
            f"--output_dir={outputDirectory}",
            "--instance_prompt=photo of sks dog",
            "--resolution=512",
            "--train_batch_size=1",
            "--gradient_accumulation_steps=1",
            "--learning_rate=5e-6",
            "--lr_scheduler=constant",
            "--lr_warmup_steps=0",
            "--max_train_steps=800",
        ],
        check=True,
    )
@celery.task
def inference(prompt, modelId, modelPath):
    """Generate an image for ``prompt`` with the model stored at ``modelPath``.

    Runs ``python3 inference.py <modelPath> <prompt> <imgSavePath>`` and the
    script is expected to write the image to ``./<modelId>.png``, the path
    ``/inferenceoutput`` polls.

    ``prompt`` comes straight from the HTTP request, so it is passed as its own
    argv element with no shell involved: quotes, ``;`` or ``$(...)`` in the
    prompt can no longer inject commands.

    Raises:
        subprocess.CalledProcessError: if ``inference.py`` exits non-zero, so
            Celery records the task as failed instead of successful.
    """
    import subprocess

    imgSavePath = f'./{modelId}.png'
    subprocess.run(["python3", "inference.py", modelPath, prompt, imgSavePath], check=True)
def _uniqueDestination(directory, fileName):
    """Return a path for ``fileName`` inside ``directory`` that does not exist yet."""
    stem, extension = os.path.splitext(fileName)
    candidate = os.path.join(directory, fileName)
    counter = 1
    while os.path.exists(candidate):
        candidate = os.path.join(directory, f"{stem}_{counter}{extension}")
        counter += 1
    return candidate


def _collectImages(sourceDirectory, destinationDirectory, extensions=(".jpg", ".jpeg")):
    """Move every image under ``sourceDirectory`` (recursively) flat into ``destinationDirectory``.

    Extensions match case-insensitively (``.JPG`` from cameras counts).
    macOS archive metadata (``__MACOSX`` folders and ``._`` AppleDouble files)
    is skipped: those entries carry the image extension but are not images.
    Files whose names collide are renamed instead of making ``shutil.move`` fail.

    Returns:
        The number of files moved.
    """
    moved = 0
    for root, dirs, files in os.walk(sourceDirectory):
        # Pruning in place stops os.walk from descending into these folders.
        dirs[:] = [d for d in dirs if d != "__MACOSX"]
        for fileName in files:
            if fileName.startswith("._") or not fileName.lower().endswith(extensions):
                continue
            shutil.move(os.path.join(root, fileName),
                        _uniqueDestination(destinationDirectory, fileName))
            moved += 1
    return moved


def _extractImages(zipFilepath, modelId):
    """Unzip ``zipFilepath`` and gather its images into ``./finaldatadirectory/<modelId>``.

    Blocking disk I/O, intended to run off the event loop. The archive and the
    temporary extraction directory are deleted whether or not this succeeds.

    Returns:
        ``(finalOutputDirectory, imageCount)``.

    Raises:
        zipfile.BadZipFile: if ``zipFilepath`` is not a valid zip archive.
    """
    dataDirectory = os.path.join(extractedFilesDirectory, modelId)
    finalOutputDirectory = './finaldatadirectory/' + modelId
    try:
        os.makedirs(dataDirectory, exist_ok=True)
        # NOTE(review): zipfile's extractall strips absolute paths and ".."
        # components, but nothing here limits decompressed size (zip bombs).
        with zipfile.ZipFile(zipFilepath, 'r') as zipRef:
            zipRef.extractall(dataDirectory)
        os.makedirs(finalOutputDirectory, exist_ok=True)
        imageCount = _collectImages(dataDirectory, finalOutputDirectory)
    finally:
        try:
            os.remove(zipFilepath)
        except FileNotFoundError:
            pass
        shutil.rmtree(dataDirectory, ignore_errors=True)
    return finalOutputDirectory, imageCount


async def saveZipFile(uuid, file):
    """Save an uploaded zip, unpack it and return the directory holding its images.

    Args:
        uuid: job/model identifier (a UUID or its string form) used to name the
            per-upload files and directories.
        file: the uploaded archive (FastAPI ``UploadFile``); always closed here.

    Returns:
        ``./finaldatadirectory/<uuid>``, containing every ``.jpg``/``.jpeg``
        image from the archive (any letter case) flattened into one directory.

    Raises:
        HTTPException: 500 if the upload cannot be written to disk; 400 if the
            upload is not a valid zip archive or contains no images.
    """
    import asyncio

    modelId = str(uuid)
    zipFilepath = os.path.join(temporaryZipFileDirectory, modelId + '.zip')
    try:
        async with aiofiles.open(zipFilepath, 'wb') as f:
            while chunk := await file.read(CHUNK_SIZE):
                await f.write(chunk)
    except Exception as err:
        # Do not leave a partially written archive behind.
        try:
            os.remove(zipFilepath)
        except OSError:
            pass
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                            detail='There was an error uploading the file') from err
    finally:
        await file.close()

    # Unzipping and moving files block; run them in the default thread pool so
    # other requests are not stalled while a large archive is processed.
    loop = asyncio.get_running_loop()
    try:
        finalOutputDirectory, imageCount = await loop.run_in_executor(
            None, _extractImages, zipFilepath, modelId)
    except zipfile.BadZipFile as err:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
                            detail='The uploaded file is not a valid zip archive') from err

    # Without images the training job could only fail later; reject it now.
    if imageCount == 0:
        shutil.rmtree(finalOutputDirectory, ignore_errors=True)
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
                            detail='The zip archive contains no .jpg images')
    return finalOutputDirectory