From 9d6c0641c4a3c2c5ecf4d49d7609edd5b745d9bc Mon Sep 17 00:00:00 2001
From: Pavel Iakubovskii
Date: Thu, 25 Jul 2024 19:20:47 +0100
Subject: [PATCH] Fix code snippet for Grounding DINO (#32229)

Fix code snippet for grounding-dino
---
 docs/source/en/model_doc/grounding-dino.md | 61 ++++++++++++----------
 1 file changed, 34 insertions(+), 27 deletions(-)

diff --git a/docs/source/en/model_doc/grounding-dino.md b/docs/source/en/model_doc/grounding-dino.md
index d258f492abf8b5..a6da554f8d5053 100644
--- a/docs/source/en/model_doc/grounding-dino.md
+++ b/docs/source/en/model_doc/grounding-dino.md
@@ -41,33 +41,40 @@ The original code can be found [here](https://github.com/IDEA-Research/Grounding
 Here's how to use the model for zero-shot object detection:
 
 ```python
-import requests
-
-import torch
-from PIL import Image
-from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection,
-
-model_id = "IDEA-Research/grounding-dino-tiny"
-
-processor = AutoProcessor.from_pretrained(model_id)
-model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
-
-image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-image = Image.open(requests.get(image_url, stream=True).raw)
-# Check for cats and remote controls
-text = "a cat. a remote control."
-
-inputs = processor(images=image, text=text, return_tensors="pt").to(device)
-with torch.no_grad():
-    outputs = model(**inputs)
-
-results = processor.post_process_grounded_object_detection(
-    outputs,
-    inputs.input_ids,
-    box_threshold=0.4,
-    text_threshold=0.3,
-    target_sizes=[image.size[::-1]]
-)
+>>> import requests
+
+>>> import torch
+>>> from PIL import Image
+>>> from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
+
+>>> model_id = "IDEA-Research/grounding-dino-tiny"
+>>> device = "cuda"
+
+>>> processor = AutoProcessor.from_pretrained(model_id)
+>>> model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
+
+>>> image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+>>> image = Image.open(requests.get(image_url, stream=True).raw)
+>>> # Check for cats and remote controls
+>>> text = "a cat. a remote control."
+
+>>> inputs = processor(images=image, text=text, return_tensors="pt").to(device)
+>>> with torch.no_grad():
+...     outputs = model(**inputs)
+
+>>> results = processor.post_process_grounded_object_detection(
+...     outputs,
+...     inputs.input_ids,
+...     box_threshold=0.4,
+...     text_threshold=0.3,
+...     target_sizes=[image.size[::-1]]
+... )
+>>> print(results)
+[{'boxes': tensor([[344.6959,  23.1090, 637.1833, 374.2751],
+        [ 12.2666,  51.9145, 316.8582, 472.4392],
+        [ 38.5742,  70.0015, 176.7838, 118.1806]], device='cuda:0'),
+  'labels': ['a cat', 'a cat', 'a remote control'],
+  'scores': tensor([0.4785, 0.4381, 0.4776], device='cuda:0')}]
 ```
 
 ## Grounded SAM