From 4f9e70bc31f16cf8474683f959cf4d3c014d4606 Mon Sep 17 00:00:00 2001 From: alex bergvall Date: Thu, 26 Jan 2023 01:15:43 -0600 Subject: [PATCH] upscaling works --- src/components/projects/CustomZoomContent.tsx | 20 ++++-- src/components/projects/ShotCard.tsx | 15 +++-- src/components/shared/ConfirmationModal.tsx | 64 ++++++++++++------- src/core/clients/replicate.ts | 2 +- src/core/utils/bucketHelpers.ts | 12 +++- src/pages/api/projects/[id]/index.ts | 2 + .../[id]/predictions/[predictionId].tsx | 12 ++-- src/pages/api/projects/index.ts | 2 + src/pages/api/shots/upscale.ts | 21 ++---- 9 files changed, 92 insertions(+), 58 deletions(-) diff --git a/src/components/projects/CustomZoomContent.tsx b/src/components/projects/CustomZoomContent.tsx index a082df6..0fdd550 100644 --- a/src/components/projects/CustomZoomContent.tsx +++ b/src/components/projects/CustomZoomContent.tsx @@ -1,6 +1,6 @@ import { ShotsPick } from "@/pages/api/projects"; import { Box, Text } from "@chakra-ui/layout"; -import { Button, useDisclosure } from "@chakra-ui/react"; +import { Button, Portal, useDisclosure } from "@chakra-ui/react"; import axios from "axios"; import { useRouter } from "next/router"; import { ReactElement, FC } from "react"; @@ -62,16 +62,22 @@ const CustomZoomContent: FC = ({ {modalLoaded && ( - )} - mutateUpscale(shot)} - /> + + mutateUpscale(shot)} + /> + ); }; diff --git a/src/components/projects/ShotCard.tsx b/src/components/projects/ShotCard.tsx index 1bcb928..14f76b3 100644 --- a/src/components/projects/ShotCard.tsx +++ b/src/components/projects/ShotCard.tsx @@ -24,15 +24,22 @@ const ShotCard = ({ ) .then((res) => res.data), { - refetchInterval: (data) => (data?.shot.imageUrl ? false : 5000), + refetchInterval: (data) => + !data?.shot.imageUrl || + (data?.shot.upscaleId && !data?.shot.upscaledImageUrl) + ? 5000 + : false, refetchOnWindowFocus: false, - enabled: !initialShot.imageUrl && initialShot.status !== "failed", + enabled: + (initialShot.upscaleId && + initialShot.status !== "failed" && + !initialShot.upscaledImageUrl) || + (!initialShot.imageUrl && initialShot.status !== "failed"), initialData: { shot: initialShot }, } ); const shot = data!.shot; - return ( {shot.status === "failed" && ( @@ -54,7 +61,7 @@ const ShotCard = ({ > diff --git a/src/components/shared/ConfirmationModal.tsx b/src/components/shared/ConfirmationModal.tsx index 0de55fb..4e89f4f 100644 --- a/src/components/shared/ConfirmationModal.tsx +++ b/src/components/shared/ConfirmationModal.tsx @@ -1,4 +1,5 @@ import { + Box, Button, Modal, ModalBody, @@ -6,11 +7,11 @@ import { ModalFooter, ModalHeader, ModalOverlay, + Portal, Text, } from "@chakra-ui/react"; import { useSession } from "next-auth/react"; import { useRouter } from "next/router"; -import { useState } from "react"; interface CreditsModalProps { /** @@ -34,29 +35,44 @@ const ConfirmationModal = ({ const userId = userSession?.user.id; // Options for credit packages radio buttons, we only want the id here to render options return ( - { - onClose(); - }} - > - - - Upscale Image - - - Are you sure you want to spend 1 credit to upscale this image? - - - - - - - +
+ { + onClose(); + }} + isCentered={true} + > + + + Upscale Image + + + Are you sure you want to spend 1 credit to upscale this image? + + + + + + + + + + +
); }; diff --git a/src/core/clients/replicate.ts b/src/core/clients/replicate.ts index 7c03281..a8e3b64 100644 --- a/src/core/clients/replicate.ts +++ b/src/core/clients/replicate.ts @@ -248,7 +248,7 @@ export type UpscaleRequest = { * * default: Real-World Image Super-Resolution-Large **/ - task: + task_type: | "Real-World Image Super-Resolution-Large" | "Real-World Image Super-Resolution-Medium" | "Grayscale Image Denoising" diff --git a/src/core/utils/bucketHelpers.ts b/src/core/utils/bucketHelpers.ts index 885eecb..6e62306 100644 --- a/src/core/utils/bucketHelpers.ts +++ b/src/core/utils/bucketHelpers.ts @@ -35,11 +35,11 @@ export const getShotsUrlPath = ( */ export const fetchImageAndStoreIt = async ( url: string, - shot: Pick + shot: Pick ): Promise => { const response = await fetch(url); const buffer = await response.arrayBuffer(); - const imagePath = getShotsUrlPath(shot); + const imagePath = getShotsUrlPath(shot, shot.upscaleId !== null); // This should always just return the `{id}.png` part of the url const bucketFileName = imagePath.split("/").slice(-1)[0]; @@ -57,3 +57,11 @@ export const fetchImageAndStoreIt = async ( return bucketFileName; }; + +export const fetchImageAndGetDataUrl = async (url: string): Promise => { + const response = await fetch(url); + const buffer = await response.arrayBuffer(); + const myBlob = new Blob([buffer]); + const dataUrl = URL.createObjectURL(myBlob); + return dataUrl; +}; diff --git a/src/pages/api/projects/[id]/index.ts b/src/pages/api/projects/[id]/index.ts index a046a98..5db029d 100644 --- a/src/pages/api/projects/[id]/index.ts +++ b/src/pages/api/projects/[id]/index.ts @@ -17,6 +17,7 @@ type ShotsPick = Pick< | "upscaledImageUrl" | "projectId" | "status" + | "upscaleId" >; export type ProjectIdResponse = { @@ -63,6 +64,7 @@ const handler = async ( upscaledImageUrl: true, projectId: true, status: true, + upscaleId: true, }, orderBy: { createdAt: "desc" }, }, diff --git a/src/pages/api/projects/[id]/predictions/[predictionId].tsx b/src/pages/api/projects/[id]/predictions/[predictionId].tsx index eaefdee..20597c7 100644 --- a/src/pages/api/projects/[id]/predictions/[predictionId].tsx +++ b/src/pages/api/projects/[id]/predictions/[predictionId].tsx @@ -14,14 +14,15 @@ const handler = async (req: NextApiRequest, res: NextApiResponse) => { const shot = await db.shot.findFirstOrThrow({ where: { projectId: projectId, id: predictionId }, }); - + const fetchId = shot.upscaleId ? shot.upscaleId : shot.replicateId; const { data: prediction } = await replicateClient.get( - `https://api.replicate.com/v1/predictions/${shot.replicateId}` + `https://api.replicate.com/v1/predictions/${fetchId}` ); - // If the initial shot status changes from the prediction, update the shot in database. if (shot.status !== prediction.status) { - const outputUrl = prediction.output?.[0]; + const outputUrl = shot.upscaleId + ? prediction.output + : prediction.output?.[0]; // If the prediction has an output, download it and store it in the bucket. if (outputUrl) { const fileName = await fetchImageAndStoreIt(outputUrl, shot); @@ -30,7 +31,8 @@ const handler = async (req: NextApiRequest, res: NextApiResponse) => { where: { id: shot.id }, data: { status: prediction.status, - imageUrl: fileName, + imageUrl: shot.upscaleId ? shot.imageUrl : fileName, + upscaledImageUrl: shot.upscaleId ? fileName : null, }, }); } diff --git a/src/pages/api/projects/index.ts b/src/pages/api/projects/index.ts index 94d0126..7b70911 100644 --- a/src/pages/api/projects/index.ts +++ b/src/pages/api/projects/index.ts @@ -17,6 +17,7 @@ export type ShotsPick = Pick< | "status" | "imageUrl" | "upscaledImageUrl" + | "upscaleId" >; export type ProjectWithShots = { @@ -103,6 +104,7 @@ const handler = async ( upscaledImageUrl: true, projectId: true, status: true, + upscaleId: true, }, ...(shotAmount && { take: Number(shotAmount) }), }, diff --git a/src/pages/api/shots/upscale.ts b/src/pages/api/shots/upscale.ts index 21bf6fd..fe4c059 100644 --- a/src/pages/api/shots/upscale.ts +++ b/src/pages/api/shots/upscale.ts @@ -9,6 +9,7 @@ import replicateClient, { } from "@/core/clients/replicate"; import { getRefinedInstanceClass } from "@/core/utils/predictions"; import supabase from "@/core/clients/supabase"; +import { fetchImageAndGetDataUrl } from "@/core/utils/bucketHelpers"; const SWINIR_VERSION = "660d922d33153019e8c263a3bba265de882e7f4f70396546b6c9c8f9d47a021a"; @@ -22,7 +23,7 @@ const handler = async (req: NextApiRequest, res: NextApiResponse) => { if (!session?.user) { return res.status(401).json({ message: "Not authenticated" }); } - console.log("shotId", shotId); + let shot = await db.shot.findFirstOrThrow({ where: { id: shotId, @@ -31,30 +32,20 @@ const handler = async (req: NextApiRequest, res: NextApiResponse) => { }); const { data: data_url } = supabase.storage - .from(process.env.NEXT_PUBLIC_UPLOAD_BUCKET_NAME!) - .getPublicUrl(`${shot.projectId}/standard/${shot.id}`); - - const trainingData: UpscaleRequest = { - input: { - image: data_url.publicUrl, - task: TASK_TYPE, - }, - version: SWINIR_VERSION, - }; + .from(process.env.NEXT_SHOT_BUCKET_NAME!) + .getPublicUrl(`${shot.projectId}/standard/${shot.id}.png`); const { data } = await replicateClient.post( `https://api.replicate.com/v1/predictions`, { - input: { prompt }, + input: { image: data_url.publicUrl, task_type: TASK_TYPE }, version: SWINIR_VERSION, } ); - console.log(data); shot = await db.shot.update({ where: { id: shot.id }, - /// TEMPORARY, CHANGE TOMORROW - data: { upscaleId: data.id }, + data: { upscaleId: data.id, status: data.status }, }); // Decrement the user's credits