From 815dfa6f54178d72c28457dd52bb7cf07b722317 Mon Sep 17 00:00:00 2001 From: Omer Spillinger Date: Mon, 4 Nov 2019 23:45:57 -0800 Subject: [PATCH] Update examples (cherry picked from commit 6aea035cab387e5a6fd47741af99def29c709d93) --- examples/pytorch/image-classifier/predictor.py | 6 +++--- examples/pytorch/image-classifier/sample.json | 2 +- examples/pytorch/iris-classifier/src/predictor.py | 4 +++- examples/pytorch/text-generator/cortex.yaml | 2 +- examples/pytorch/text-generator/predictor.py | 14 +++++++------- examples/tensorflow/image-classifier/sample.json | 2 +- 6 files changed, 16 insertions(+), 14 deletions(-) diff --git a/examples/pytorch/image-classifier/predictor.py b/examples/pytorch/image-classifier/predictor.py index 4ae89bc4d2..6dea708b1e 100644 --- a/examples/pytorch/image-classifier/predictor.py +++ b/examples/pytorch/image-classifier/predictor.py @@ -1,9 +1,9 @@ import requests +import torch +import torchvision +from torchvision import transforms from PIL import Image from io import BytesIO -from torchvision import transforms -import torchvision -import torch model = torchvision.models.alexnet(pretrained=True) model.eval() diff --git a/examples/pytorch/image-classifier/sample.json b/examples/pytorch/image-classifier/sample.json index 1c7deae607..eb72ddb869 100644 --- a/examples/pytorch/image-classifier/sample.json +++ b/examples/pytorch/image-classifier/sample.json @@ -1,3 +1,3 @@ { - "url": "https://bowwowinsurance.com.au/wp-content/uploads/2018/10/akita-700x700.jpg" + "url": "https://i.imgur.com/PzXprwl.jpg" } diff --git a/examples/pytorch/iris-classifier/src/predictor.py b/examples/pytorch/iris-classifier/src/predictor.py index 5fc67c19e7..3da8100840 100644 --- a/examples/pytorch/iris-classifier/src/predictor.py +++ b/examples/pytorch/iris-classifier/src/predictor.py @@ -2,7 +2,6 @@ import torch from model import IrisNet -labels = ["iris-setosa", "iris-versicolor", "iris-virginica"] model = IrisNet() @@ -12,6 +11,9 @@ def init(model_path, metadata): model.eval() +labels = ["iris-setosa", "iris-versicolor", "iris-virginica"] + + def predict(sample, metadata): input_tensor = torch.FloatTensor( [ diff --git a/examples/pytorch/text-generator/cortex.yaml b/examples/pytorch/text-generator/cortex.yaml index 7245950aed..7867b28029 100644 --- a/examples/pytorch/text-generator/cortex.yaml +++ b/examples/pytorch/text-generator/cortex.yaml @@ -6,7 +6,7 @@ predictor: path: predictor.py metadata: - num_words: 20 + num_words: 50 device: cuda # use "cpu" to run on CPUs compute: gpu: 1 diff --git a/examples/pytorch/text-generator/predictor.py b/examples/pytorch/text-generator/predictor.py index b67b980728..216fbeff35 100644 --- a/examples/pytorch/text-generator/predictor.py +++ b/examples/pytorch/text-generator/predictor.py @@ -1,18 +1,18 @@ +# This file includes code which was modified from https://github.com/huggingface/transformers/blob/master/examples/run_generation.py + from __future__ import absolute_import, division, print_function, unicode_literals -import numpy as np -import argparse -import logging -from tqdm import trange -import torch.nn.functional as F import torch +import torch.nn.functional as F from transformers import GPT2Tokenizer, GPT2LMHeadModel +from tqdm import trange + +tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2") model = GPT2LMHeadModel.from_pretrained("distilgpt2") model.eval() -tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2") -# adapted from: https://github.com/huggingface/transformers/blob/master/examples/run_generation.py + def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")): """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering Args: diff --git a/examples/tensorflow/image-classifier/sample.json b/examples/tensorflow/image-classifier/sample.json index 94c1693e7a..667652007a 100644 --- a/examples/tensorflow/image-classifier/sample.json +++ b/examples/tensorflow/image-classifier/sample.json @@ -1,3 +1,3 @@ { - "url": "https://bowwowinsurance.com.au/wp-content/uploads/2018/10/akita-700x700.jpg" + "url": "https://i.imgur.com/PzXprwl.jpg" }