Skip to content

Commit

Permalink
upscaling works
Browse files Browse the repository at this point in the history
  • Loading branch information
bergvall95 committed Jan 26, 2023
1 parent fdc0264 commit 4f9e70b
Show file tree
Hide file tree
Showing 9 changed files with 92 additions and 58 deletions.
20 changes: 13 additions & 7 deletions src/components/projects/CustomZoomContent.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { ShotsPick } from "@/pages/api/projects";
import { Box, Text } from "@chakra-ui/layout";
import { Button, useDisclosure } from "@chakra-ui/react";
import { Button, Portal, useDisclosure } from "@chakra-ui/react";
import axios from "axios";
import { useRouter } from "next/router";
import { ReactElement, FC } from "react";
Expand Down Expand Up @@ -62,16 +62,22 @@ const CustomZoomContent: FC<CustomZoomContentProps> = ({
</Box>
<Box display={"flex"} alignItems={"center"} justifyContent={"center"}>
{modalLoaded && (
<Button variant={"brand"} onClick={() => mutateUpscale(shot)}>
<Button variant={"brand"} onClick={onOpen}>
Upscale
</Button>
)}
</Box>
<ConfirmationModal
isOpen={isOpen}
onClose={onClose}
onConfirm={() => mutateUpscale(shot)}
/>
<Box
display={isOpen ? "flex" : "none"}
alignItems="center"
justifyContent="center"
>
<ConfirmationModal
isOpen={isOpen}
onClose={onClose}
onConfirm={() => mutateUpscale(shot)}
/>
</Box>
</>
);
};
Expand Down
15 changes: 11 additions & 4 deletions src/components/projects/ShotCard.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,22 @@ const ShotCard = ({
)
.then((res) => res.data),
{
refetchInterval: (data) => (data?.shot.imageUrl ? false : 5000),
refetchInterval: (data) =>
!data?.shot.imageUrl ||
(data?.shot.upscaleId && !data?.shot.upscaledImageUrl)
? 5000
: false,
refetchOnWindowFocus: false,
enabled: !initialShot.imageUrl && initialShot.status !== "failed",
enabled:
(initialShot.upscaleId &&
initialShot.status !== "failed" &&
!initialShot.upscaledImageUrl) ||
(!initialShot.imageUrl && initialShot.status !== "failed"),
initialData: { shot: initialShot },
}
);

const shot = data!.shot;

return (
<Box key={shot.id} backgroundColor="gray.100" overflow="hidden">
{shot.status === "failed" && (
Expand All @@ -54,7 +61,7 @@ const ShotCard = ({
>
<NextImage
alt={shot.filterName || "Stylized image of your pet"}
src={getFullShotUrl(shot)}
src={getFullShotUrl(shot, shot.upscaledImageUrl ? true : false)}
width="512"
height="512"
/>
Expand Down
64 changes: 40 additions & 24 deletions src/components/shared/ConfirmationModal.tsx
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import {
Box,
Button,
Modal,
ModalBody,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
Portal,
Text,
} from "@chakra-ui/react";
import { useSession } from "next-auth/react";
import { useRouter } from "next/router";
import { useState } from "react";

interface CreditsModalProps {
/**
Expand All @@ -34,29 +35,44 @@ const ConfirmationModal = ({
const userId = userSession?.user.id;
// Options for credit packages radio buttons, we only want the id here to render options
return (
<Modal
scrollBehavior="inside"
size="md"
isOpen={isOpen}
onClose={() => {
onClose();
}}
>
<ModalOverlay />
<ModalContent>
<ModalHeader>Upscale Image</ModalHeader>
<ModalBody>
<Text fontSize="sm" mb={4}>
Are you sure you want to spend 1 credit to upscale this image?
</Text>
</ModalBody>
<ModalFooter>
<Button size="lg" variant="brand" onClick={() => onConfirm()}>
Yes
</Button>
</ModalFooter>
</ModalContent>
</Modal>
<div style={{ zIndex: 9999 }}>
<Modal
scrollBehavior="inside"
size="md"
isOpen={isOpen}
onClose={() => {
onClose();
}}
isCentered={true}
>
<ModalOverlay />
<ModalContent zIndex={9999}>
<ModalHeader>Upscale Image</ModalHeader>
<ModalBody>
<Text fontSize="sm" mb={4}>
Are you sure you want to spend 1 credit to upscale this image?
</Text>
</ModalBody>
<ModalFooter>
<Box display="flex" justifyContent={"space-around"} width="100%">
<Button size="lg" variant="brand" onClick={() => onClose()}>
Cancel
</Button>
<Button
size="lg"
variant="brand"
onClick={() => {
onConfirm();
onClose();
}}
>
Yes
</Button>
</Box>
</ModalFooter>
</ModalContent>
</Modal>
</div>
);
};

Expand Down
2 changes: 1 addition & 1 deletion src/core/clients/replicate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ export type UpscaleRequest = {
*
* default: Real-World Image Super-Resolution-Large
**/
task:
task_type:
| "Real-World Image Super-Resolution-Large"
| "Real-World Image Super-Resolution-Medium"
| "Grayscale Image Denoising"
Expand Down
12 changes: 10 additions & 2 deletions src/core/utils/bucketHelpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ export const getShotsUrlPath = (
*/
export const fetchImageAndStoreIt = async (
url: string,
shot: Pick<Shot, "id" | "projectId">
shot: Pick<Shot, "id" | "projectId" | "upscaleId">
): Promise<string | null> => {
const response = await fetch(url);
const buffer = await response.arrayBuffer();
const imagePath = getShotsUrlPath(shot);
const imagePath = getShotsUrlPath(shot, shot.upscaleId !== null);

// This should always just return the `{id}.png` part of the url
const bucketFileName = imagePath.split("/").slice(-1)[0];
Expand All @@ -57,3 +57,11 @@ export const fetchImageAndStoreIt = async (

return bucketFileName;
};

export const fetchImageAndGetDataUrl = async (url: string): Promise<string> => {
const response = await fetch(url);
const buffer = await response.arrayBuffer();
const myBlob = new Blob([buffer]);
const dataUrl = URL.createObjectURL(myBlob);
return dataUrl;
};
2 changes: 2 additions & 0 deletions src/pages/api/projects/[id]/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type ShotsPick = Pick<
| "upscaledImageUrl"
| "projectId"
| "status"
| "upscaleId"
>;

export type ProjectIdResponse = {
Expand Down Expand Up @@ -63,6 +64,7 @@ const handler = async (
upscaledImageUrl: true,
projectId: true,
status: true,
upscaleId: true,
},
orderBy: { createdAt: "desc" },
},
Expand Down
12 changes: 7 additions & 5 deletions src/pages/api/projects/[id]/predictions/[predictionId].tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@ const handler = async (req: NextApiRequest, res: NextApiResponse) => {
const shot = await db.shot.findFirstOrThrow({
where: { projectId: projectId, id: predictionId },
});

const fetchId = shot.upscaleId ? shot.upscaleId : shot.replicateId;
const { data: prediction } = await replicateClient.get<PredictionResponse>(
`https://api.replicate.com/v1/predictions/${shot.replicateId}`
`https://api.replicate.com/v1/predictions/${fetchId}`
);

// If the initial shot status changes from the prediction, update the shot in database.
if (shot.status !== prediction.status) {
const outputUrl = prediction.output?.[0];
const outputUrl = shot.upscaleId
? prediction.output
: prediction.output?.[0];
// If the prediction has an output, download it and store it in the bucket.
if (outputUrl) {
const fileName = await fetchImageAndStoreIt(outputUrl, shot);
Expand All @@ -30,7 +31,8 @@ const handler = async (req: NextApiRequest, res: NextApiResponse) => {
where: { id: shot.id },
data: {
status: prediction.status,
imageUrl: fileName,
imageUrl: shot.upscaleId ? shot.imageUrl : fileName,
upscaledImageUrl: shot.upscaleId ? fileName : null,
},
});
}
Expand Down
2 changes: 2 additions & 0 deletions src/pages/api/projects/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ export type ShotsPick = Pick<
| "status"
| "imageUrl"
| "upscaledImageUrl"
| "upscaleId"
>;

export type ProjectWithShots = {
Expand Down Expand Up @@ -103,6 +104,7 @@ const handler = async (
upscaledImageUrl: true,
projectId: true,
status: true,
upscaleId: true,
},
...(shotAmount && { take: Number(shotAmount) }),
},
Expand Down
21 changes: 6 additions & 15 deletions src/pages/api/shots/upscale.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import replicateClient, {
} from "@/core/clients/replicate";
import { getRefinedInstanceClass } from "@/core/utils/predictions";
import supabase from "@/core/clients/supabase";
import { fetchImageAndGetDataUrl } from "@/core/utils/bucketHelpers";

const SWINIR_VERSION =
"660d922d33153019e8c263a3bba265de882e7f4f70396546b6c9c8f9d47a021a";
Expand All @@ -22,7 +23,7 @@ const handler = async (req: NextApiRequest, res: NextApiResponse) => {
if (!session?.user) {
return res.status(401).json({ message: "Not authenticated" });
}
console.log("shotId", shotId);

let shot = await db.shot.findFirstOrThrow({
where: {
id: shotId,
Expand All @@ -31,30 +32,20 @@ const handler = async (req: NextApiRequest, res: NextApiResponse) => {
});

const { data: data_url } = supabase.storage
.from(process.env.NEXT_PUBLIC_UPLOAD_BUCKET_NAME!)
.getPublicUrl(`${shot.projectId}/standard/${shot.id}`);

const trainingData: UpscaleRequest = {
input: {
image: data_url.publicUrl,
task: TASK_TYPE,
},
version: SWINIR_VERSION,
};
.from(process.env.NEXT_SHOT_BUCKET_NAME!)
.getPublicUrl(`${shot.projectId}/standard/${shot.id}.png`);

const { data } = await replicateClient.post<UpscaleResponse>(
`https://api.replicate.com/v1/predictions`,
{
input: { prompt },
input: { image: data_url.publicUrl, task_type: TASK_TYPE },
version: SWINIR_VERSION,
}
);
console.log(data);

shot = await db.shot.update({
where: { id: shot.id },
/// TEMPORARY, CHANGE TOMORROW
data: { upscaleId: data.id },
data: { upscaleId: data.id, status: data.status },
});

// Decrement the user's credits
Expand Down

0 comments on commit 4f9e70b

Please sign in to comment.