From 5ab2b3a4a4c4b021925c40abdd94c8fce2ce9737 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roman=20R=C3=A4dle?= Date: Thu, 2 Jun 2022 16:40:44 -0700 Subject: [PATCH] Update react-native-pytorch-core README.md Summary: Update the example code in the README.md for the `react-native-pytorch-core` package to use the new JSI API. Reviewed By: chrisklaiber Differential Revision: D36882398 fbshipit-source-id: 047aeaf18d69b8a014c6bc22b13b6bf999681f63 --- react-native-pytorch-core/README.md | 156 ++++++++++++++++++++++------ 1 file changed, 124 insertions(+), 32 deletions(-) diff --git a/react-native-pytorch-core/README.md b/react-native-pytorch-core/README.md index dd48a19b0..355028693 100644 --- a/react-native-pytorch-core/README.md +++ b/react-native-pytorch-core/README.md @@ -16,47 +16,139 @@ The full documentation for PyTorch Live can be found on our [website](https://py ## Example Usage ```javascript +// Import dependencies import * as React from 'react'; -import {Button, StyleSheet, Text, View} from 'react-native'; - -import {MobileModel, ImageUtil} from 'react-native-pytorch-core'; - -// Have a look at how to prepare a model for PyTorch Live -// https://pytorch.org/live/docs/tutorials/prepare-custom-model/ -const model = require('./mobilenet_v3_small.ptl'); - -// JSON array of classes that map the max idx to a class label -const IMAGE_CLASSES = require('./image_classes.json'); - -type ImageClassificationResult = { - maxIdx: number; - confidence: number; -}; - -export default function ImageClassificaion() { - const [topClass, setTopClass] = React.useState(''); - async function classifyImage() { - const image = await ImageUtil.fromURL('https://pytorch.org/example.jpg'); - const {metrics, result} = - await MobileModel.execute(model, { - image, - }); - - console.log(metrics); - if (result.confidence > 0.7) { - setTopClass(IMAGE_CLASSES[result.maxIdx]); - } else { - setTopClass('low confidence'); +import { StyleSheet, Text, View } from 'react-native'; +import { + Camera, + 
Image,
+  media,
+  MobileModel,
+  Module,
+  Tensor,
+  torch,
+  torchvision,
+} from 'react-native-pytorch-core';
+import { useSafeAreaInsets } from 'react-native-safe-area-context';
+
+// Alias for torchvision transforms
+const T = torchvision.transforms;
+
+// URL to the image classification model that is used in this example
+const MODEL_URL =
+  'https://github.com/pytorch/live/releases/download/v0.1.0/mobilenet_v3_small.ptl';
+
+// URL to the ImageNetClasses JSON file, which is used below to map the
+// processed model result to a class label
+const IMAGENET_CLASSES_URL =
+  'https://github.com/pytorch/live/releases/download/v0.1.0/ImageNetClasses.json';
+
+// Variable to hold a reference to the loaded ML model
+let model: Module | null = null;
+
+// Variable to hold a reference to the ImageNet classes
+let imageNetClasses: string[] | null = null;
+
+// App function to render a camera and a text
+export default function App() {
+  // Safe area insets to compensate for notches and bottom bars
+  const insets = useSafeAreaInsets();
+  // Create a React state to store the top class returned from the
+  // classifyImage function
+  const [topClass, setTopClass] = React.useState(
+    "Press capture button to classify what's in the camera view!",
+  );
+
+  // Function to handle images whenever the user presses the capture button
+  async function handleImage(image: Image) {
+    // Get image width and height
+    const width = image.getWidth();
+    const height = image.getHeight();
+
+    // Convert image to blob, which is a byte representation of the image
+    // in the format height (H), width (W), and channels (C), or HWC for short
+    const blob = media.toBlob(image);
+
+    // Get a tensor from the image blob and also define in what format
+    // the image blob is. 
+ let tensor = torch.fromBlob(blob, [height, width, 3]); + + // Rearrange the tensor shape to be [CHW] + tensor = tensor.permute([2, 0, 1]); + + // Divide the tensor values by 255 to get values between [0, 1] + tensor = tensor.div(255); + + // Crop the image in the center to be a squared image + const centerCrop = T.centerCrop(Math.min(width, height)); + tensor = centerCrop(tensor); + + // Resize the image tensor to 3 x 224 x 224 + const resize = T.resize(224); + tensor = resize(tensor); + + // Normalize the tensor image with mean and standard deviation + const normalize = T.normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]); + tensor = normalize(tensor); + + // Unsqueeze adds 1 leading dimension to the tensor + tensor = tensor.unsqueeze(0); + + // If the model has not been loaded already, it will be downloaded from + // the URL and then loaded into memory. + if (model === null) { + const filePath = await MobileModel.download(MODEL_URL); + model = await torch.jit._loadForMobile(filePath); + } + + // Run the ML inference with the pre-processed image tensor + const output = await model.forward(tensor); + + // Get the index of the value with the highest probability + const maxIdx = output.argmax().item(); + + if (imageNetClasses === null) { + const response = await fetch(IMAGENET_CLASSES_URL); + imageNetClasses = (await response.json()) as string[]; } + + // Resolve the most likely class label and return it + const result = imageNetClasses[maxIdx]; + + // Set result as top class label state + setTopClass(result); + + // Release the image from memory + image.release(); } return ( -