Skip to content

Commit

Permalink
.generate_image and .generate_text now returning request status class…
Browse files Browse the repository at this point in the history
… for more info about request (like kudos cost)
  • Loading branch information
lapismyt committed May 25, 2024
1 parent 6722da1 commit 6584446
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ async def image():
'AAM XL', 'AlbedoBase XL (SDXL)', 'Animagine XL', 'Anime Illust Diffusion', 'DreamShaper XL',
'ICBINP XL', 'Juggernaut XL', 'Quiet Goodnight XL', 'Unstable Diffusers XL']
)
generations: list[models.GenerationStable] = await client.generate_image(generation_input)
generations: list[models.GenerationStable] = (await client.generate_image(generation_input)).generations
for generation in generations:
print(f'{generation.model}: {generation.img}')

Expand All @@ -52,7 +52,7 @@ async def text():
params=params,
models=['koboldcpp/Kunoichi-DPO-v2-7B-Q8_0-imatrix'] # i dont know which models are good for text :D
)
results: list[models.GenerationKobold] = await client.generate_text(generation_input)
results: list[models.GenerationKobold] = (await client.generate_text(generation_input)).generations
for result in results:
print(result.text.strip())

Expand Down
8 changes: 4 additions & 4 deletions aihorde/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ async def image_generation_status(

async def generate_images(
self, request: models.GenerationInputStable
) -> list[models.GenerationStable]:
) -> models.RequestStatusStable:
resp: models.RequestAsync = await self.generate_image_request(request)
request_id = resp.id
while True:
status = await self.image_generation_status(request_id)
if status.done:
return status.generations
return status
else:
await asyncio.sleep(int(status.wait_time / 1.5))

Expand All @@ -77,13 +77,13 @@ async def text_generation_status(

async def generate_text(
self, request: models.GenerationInputKobold
) -> list[models.GenerationKobold]:
) -> models.RequestStatusKobold:
resp: models.RequestAsync = await self.generate_text_request(request)
request_id = resp.id
while True:
status = await self.text_generation_status(request_id)
if status.done:
return status.generations
return status
else:
await asyncio.sleep(int(status.wait_time / 1.5))

Expand Down
4 changes: 2 additions & 2 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ async def text():
params=params,
models=['koboldcpp/Kunoichi-DPO-v2-7B-Q8_0-imatrix'] # i dont know which models are good for text
)
results: list[models.GenerationKobold] = await client.generate_text(generation_input)
results: list[models.GenerationKobold] = (await client.generate_text(generation_input)).generations
for result in results:
print(result.text.strip())

Expand All @@ -29,7 +29,7 @@ async def image():
'AAM XL', 'AlbedoBase XL (SDXL)', 'Animagine XL', 'Anime Illust Diffusion', 'DreamShaper XL',
'ICBINP XL', 'Juggernaut XL', 'Quiet Goodnight XL', 'Unstable Diffusers XL']
)
generations: list[models.GenerationStable] = await client.generate_image(generation_input)
generations: list[models.GenerationStable] = (await client.generate_image(generation_input)).generations
for generation in generations:
print(f'{generation.model}: {generation.img}')

Expand Down

0 comments on commit 6584446

Please sign in to comment.