diff --git a/README.md b/README.md index d83ac47d530ad2..3ee8633858c3c3 100644 --- a/README.md +++ b/README.md @@ -397,6 +397,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[LiLT](https://huggingface.co/docs/transformers/model_doc/lilt)** (from South China University of Technology) released with the paper [LiLT: A Simple yet Effective Language-Independent Layout Transformer for Structured Document Understanding](https://arxiv.org/abs/2202.13669) by Jiapeng Wang, Lianwen Jin, Kai Ding. 1. **[LLaMA](https://huggingface.co/docs/transformers/model_doc/llama)** (from The FAIR team of Meta AI) released with the paper [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971) by Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample. 1. **[Llama2](https://huggingface.co/docs/transformers/model_doc/llama2)** (from The FAIR team of Meta AI) released with the paper [Llama2: Open Foundation and Fine-Tuned Chat Models](https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/) by Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, Dan Bikel, Lukas Blecher, Cristian Canton Ferrer, Moya Chen, Guillem Cucurull, David Esiobu, Jude Fernandes, Jeremy Fu, Wenyin Fu, Brian Fuller, Cynthia Gao, Vedanuj Goswami, Naman Goyal, Anthony Hartshorn, Saghar Hosseini, Rui Hou, Hakan Inan, Marcin Kardas, Viktor Kerkez Madian Khabsa, Isabel Kloumann, Artem Korenev, Punit Singh Koura, Marie-Anne Lachaux, Thibaut Lavril, Jenya Lee, Diana Liskovich, Yinghai Lu, Yuning Mao, Xavier Martinet, Todor Mihaylov, Pushka rMishra, Igor Molybog, Yixin Nie, Andrew Poulton, Jeremy Reizenstein, Rashi Rungta, Kalyan Saladi, Alan Schelten, Ruan Silva, Eric Michael Smith, Ranjan Subramanian, Xiaoqing EllenTan, Binh Tang, Ross Taylor, Adina Williams, Jian Xiang Kuan, Puxin Xu, Zheng Yan, Iliyan Zarov, Yuchen Zhang, Angela Fan, Melanie Kambadur, Sharan Narang, Aurelien Rodriguez, Robert Stojnic, Sergey Edunov, Thomas Scialom. +1. **[Llava](https://huggingface.co/docs/transformers/main/model_doc/llava)** (from Microsoft Research & University of Wisconsin-Madison) released with the paper [Improved Baselines with Visual Instruction Tuning](https://arxiv.org/pdf/2310.03744) by Haotian Liu, Chunyuan Li, Yuheng Li and Yong Jae Lee. 1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. 1. **[LongT5](https://huggingface.co/docs/transformers/model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang. 1. **[LUKE](https://huggingface.co/docs/transformers/model_doc/luke)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto. diff --git a/README_es.md b/README_es.md index b67ef86b764f8d..c1a029c7f352e1 100644 --- a/README_es.md +++ b/README_es.md @@ -372,6 +372,7 @@ Número actual de puntos de control: ![](https://img.shields.io/endpoint?url=htt 1. **[LiLT](https://huggingface.co/docs/transformers/model_doc/lilt)** (from South China University of Technology) released with the paper [LiLT: A Simple yet Effective Language-Independent Layout Transformer for Structured Document Understanding](https://arxiv.org/abs/2202.13669) by Jiapeng Wang, Lianwen Jin, Kai Ding. 1. **[LLaMA](https://huggingface.co/docs/transformers/model_doc/llama)** (from The FAIR team of Meta AI) released with the paper [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971) by Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample. 1. **[Llama2](https://huggingface.co/docs/transformers/model_doc/llama2)** (from The FAIR team of Meta AI) released with the paper [Llama2: Open Foundation and Fine-Tuned Chat Models](https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/XXX) by Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, Dan Bikel, Lukas Blecher, Cristian Canton Ferrer, Moya Chen, Guillem Cucurull, David Esiobu, Jude Fernandes, Jeremy Fu, Wenyin Fu, Brian Fuller, Cynthia Gao, Vedanuj Goswami, Naman Goyal, Anthony Hartshorn, Saghar Hosseini, Rui Hou, Hakan Inan, Marcin Kardas, Viktor Kerkez Madian Khabsa, Isabel Kloumann, Artem Korenev, Punit Singh Koura, Marie-Anne Lachaux, Thibaut Lavril, Jenya Lee, Diana Liskovich, Yinghai Lu, Yuning Mao, Xavier Martinet, Todor Mihaylov, Pushka rMishra, Igor Molybog, Yixin Nie, Andrew Poulton, Jeremy Reizenstein, Rashi Rungta, Kalyan Saladi, Alan Schelten, Ruan Silva, Eric Michael Smith, Ranjan Subramanian, Xiaoqing EllenTan, Binh Tang, Ross Taylor, Adina Williams, Jian Xiang Kuan, Puxin Xu, Zheng Yan, Iliyan Zarov, Yuchen Zhang, Angela Fan, Melanie Kambadur, Sharan Narang, Aurelien Rodriguez, Robert Stojnic, Sergey Edunov, Thomas Scialom.. +1. **[Llava](https://huggingface.co/docs/transformers/main/model_doc/llava)** (from Microsoft Research & University of Wisconsin-Madison) released with the paper [Improved Baselines with Visual Instruction Tuning](https://arxiv.org/pdf/2310.03744) by Haotian Liu, Chunyuan Li, Yuheng Li and Yong Jae Lee. 1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. 1. **[LongT5](https://huggingface.co/docs/transformers/model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang. 1. **[LUKE](https://huggingface.co/docs/transformers/model_doc/luke)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto. diff --git a/README_hd.md b/README_hd.md index 975273f2e99b25..43d8295a4f4423 100644 --- a/README_hd.md +++ b/README_hd.md @@ -346,6 +346,7 @@ conda install -c huggingface transformers 1. **[LiLT](https://huggingface.co/docs/transformers/model_doc/lilt)** (दक्षिण चीन प्रौद्योगिकी विश्वविद्यालय से) साथ में कागज [LiLT: एक सरल लेकिन प्रभावी भाषा-स्वतंत्र लेआउट ट्रांसफार्मर संरचित दस्तावेज़ समझ के लिए](https://arxiv.org/abs/2202.13669) जियापेंग वांग, लियानवेन जिन, काई डिंग द्वारा पोस्ट किया गया। 1. **[LLaMA](https://huggingface.co/docs/transformers/model_doc/llama)** (The FAIR team of Meta AI से) Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample. द्वाराअनुसंधान पत्र [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971) के साथ जारी किया गया 1. **[Llama2](https://huggingface.co/docs/transformers/model_doc/llama2)** (The FAIR team of Meta AI से) Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, Dan Bikel, Lukas Blecher, Cristian Canton Ferrer, Moya Chen, Guillem Cucurull, David Esiobu, Jude Fernandes, Jeremy Fu, Wenyin Fu, Brian Fuller, Cynthia Gao, Vedanuj Goswami, Naman Goyal, Anthony Hartshorn, Saghar Hosseini, Rui Hou, Hakan Inan, Marcin Kardas, Viktor Kerkez Madian Khabsa, Isabel Kloumann, Artem Korenev, Punit Singh Koura, Marie-Anne Lachaux, Thibaut Lavril, Jenya Lee, Diana Liskovich, Yinghai Lu, Yuning Mao, Xavier Martinet, Todor Mihaylov, Pushka rMishra, Igor Molybog, Yixin Nie, Andrew Poulton, Jeremy Reizenstein, Rashi Rungta, Kalyan Saladi, Alan Schelten, Ruan Silva, Eric Michael Smith, Ranjan Subramanian, Xiaoqing EllenTan, Binh Tang, Ross Taylor, Adina Williams, Jian Xiang Kuan, Puxin Xu, Zheng Yan, Iliyan Zarov, Yuchen Zhang, Angela Fan, Melanie Kambadur, Sharan Narang, Aurelien Rodriguez, Robert Stojnic, Sergey Edunov, Thomas Scialom.. द्वाराअनुसंधान पत्र [Llama2: Open Foundation and Fine-Tuned Chat Models](https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/XXX) के साथ जारी किया गया +1. **[Llava](https://huggingface.co/docs/transformers/main/model_doc/llava)** (Microsoft Research & University of Wisconsin-Madison से) Haotian Liu, Chunyuan Li, Yuheng Li and Yong Jae Lee. द्वाराअनुसंधान पत्र [Improved Baselines with Visual Instruction Tuning](https://arxiv.org/pdf/2310.03744) के साथ जारी किया गया 1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. 1. **[LongT5](https://huggingface.co/docs/transformers/model_doc/longt5)** (मैंडी गुओ, जोशुआ आइंस्ली, डेविड यूथस, सैंटियागो ओंटानन, जियानमो नि, यूं-हुआन सुंग, यिनफेई यांग द्वारा पोस्ट किया गया। 1. **[LUKE](https://huggingface.co/docs/transformers/model_doc/luke)** (स्टूडियो औसिया से) साथ में पेपर [LUKE: डीप कॉन्टेक्स्टुअलाइज्ड एंटिटी रिप्रेजेंटेशन विद एंटिटी-अवेयर सेल्फ-अटेंशन](https ://arxiv.org/abs/2010.01057) Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto द्वारा। diff --git a/README_ja.md b/README_ja.md index bbe814543f6f12..b811aaf1d20afd 100644 --- a/README_ja.md +++ b/README_ja.md @@ -406,6 +406,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ 1. **[LiLT](https://huggingface.co/docs/transformers/model_doc/lilt)** (South China University of Technology から) Jiapeng Wang, Lianwen Jin, Kai Ding から公開された研究論文: [LiLT: A Simple yet Effective Language-Independent Layout Transformer for Structured Document Understanding](https://arxiv.org/abs/2202.13669) 1. **[LLaMA](https://huggingface.co/docs/transformers/model_doc/llama)** (The FAIR team of Meta AI から) Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample. から公開された研究論文 [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971) 1. **[Llama2](https://huggingface.co/docs/transformers/model_doc/llama2)** (The FAIR team of Meta AI から) Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, Dan Bikel, Lukas Blecher, Cristian Canton Ferrer, Moya Chen, Guillem Cucurull, David Esiobu, Jude Fernandes, Jeremy Fu, Wenyin Fu, Brian Fuller, Cynthia Gao, Vedanuj Goswami, Naman Goyal, Anthony Hartshorn, Saghar Hosseini, Rui Hou, Hakan Inan, Marcin Kardas, Viktor Kerkez Madian Khabsa, Isabel Kloumann, Artem Korenev, Punit Singh Koura, Marie-Anne Lachaux, Thibaut Lavril, Jenya Lee, Diana Liskovich, Yinghai Lu, Yuning Mao, Xavier Martinet, Todor Mihaylov, Pushka rMishra, Igor Molybog, Yixin Nie, Andrew Poulton, Jeremy Reizenstein, Rashi Rungta, Kalyan Saladi, Alan Schelten, Ruan Silva, Eric Michael Smith, Ranjan Subramanian, Xiaoqing EllenTan, Binh Tang, Ross Taylor, Adina Williams, Jian Xiang Kuan, Puxin Xu, Zheng Yan, Iliyan Zarov, Yuchen Zhang, Angela Fan, Melanie Kambadur, Sharan Narang, Aurelien Rodriguez, Robert Stojnic, Sergey Edunov, Thomas Scialom.. から公開された研究論文 [Llama2: Open Foundation and Fine-Tuned Chat Models](https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/XXX) +1. **[Llava](https://huggingface.co/docs/transformers/main/model_doc/llava)** (Microsoft Research & University of Wisconsin-Madison から) Haotian Liu, Chunyuan Li, Yuheng Li and Yong Jae Lee. から公開された研究論文 [Improved Baselines with Visual Instruction Tuning](https://arxiv.org/pdf/2310.03744) 1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (AllenAI から) Iz Beltagy, Matthew E. Peters, Arman Cohan から公開された研究論文: [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) 1. **[LongT5](https://huggingface.co/docs/transformers/model_doc/longt5)** (Google AI から) Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang から公開された研究論文: [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) 1. **[LUKE](https://huggingface.co/docs/transformers/model_doc/luke)** (Studio Ousia から) Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto から公開された研究論文: [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) diff --git a/README_ko.md b/README_ko.md index 43a7dee808cdbf..c4d72235f34c00 100644 --- a/README_ko.md +++ b/README_ko.md @@ -321,6 +321,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[LiLT](https://huggingface.co/docs/transformers/model_doc/lilt)** (South China University of Technology 에서) Jiapeng Wang, Lianwen Jin, Kai Ding 의 [LiLT: A Simple yet Effective Language-Independent Layout Transformer for Structured Document Understanding](https://arxiv.org/abs/2202.13669) 논문과 함께 발표했습니다. 1. **[LLaMA](https://huggingface.co/docs/transformers/model_doc/llama)** (The FAIR team of Meta AI 에서 제공)은 Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample.의 [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971)논문과 함께 발표했습니다. 1. **[Llama2](https://huggingface.co/docs/transformers/model_doc/llama2)** (The FAIR team of Meta AI 에서 제공)은 Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, Dan Bikel, Lukas Blecher, Cristian Canton Ferrer, Moya Chen, Guillem Cucurull, David Esiobu, Jude Fernandes, Jeremy Fu, Wenyin Fu, Brian Fuller, Cynthia Gao, Vedanuj Goswami, Naman Goyal, Anthony Hartshorn, Saghar Hosseini, Rui Hou, Hakan Inan, Marcin Kardas, Viktor Kerkez Madian Khabsa, Isabel Kloumann, Artem Korenev, Punit Singh Koura, Marie-Anne Lachaux, Thibaut Lavril, Jenya Lee, Diana Liskovich, Yinghai Lu, Yuning Mao, Xavier Martinet, Todor Mihaylov, Pushka rMishra, Igor Molybog, Yixin Nie, Andrew Poulton, Jeremy Reizenstein, Rashi Rungta, Kalyan Saladi, Alan Schelten, Ruan Silva, Eric Michael Smith, Ranjan Subramanian, Xiaoqing EllenTan, Binh Tang, Ross Taylor, Adina Williams, Jian Xiang Kuan, Puxin Xu, Zheng Yan, Iliyan Zarov, Yuchen Zhang, Angela Fan, Melanie Kambadur, Sharan Narang, Aurelien Rodriguez, Robert Stojnic, Sergey Edunov, Thomas Scialom..의 [Llama2: Open Foundation and Fine-Tuned Chat Models](https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/XXX)논문과 함께 발표했습니다. +1. **[Llava](https://huggingface.co/docs/transformers/main/model_doc/llava)** (Microsoft Research & University of Wisconsin-Madison 에서 제공)은 Haotian Liu, Chunyuan Li, Yuheng Li and Yong Jae Lee.의 [Improved Baselines with Visual Instruction Tuning](https://arxiv.org/pdf/2310.03744)논문과 함께 발표했습니다. 1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (AllenAI 에서) Iz Beltagy, Matthew E. Peters, Arman Cohan 의 [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) 논문과 함께 발표했습니다. 1. **[LongT5](https://huggingface.co/docs/transformers/model_doc/longt5)** (Google AI 에서) Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang 의 [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) 논문과 함께 발표했습니다. 1. **[LUKE](https://huggingface.co/docs/transformers/model_doc/luke)** (Studio Ousia 에서) Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto 의 [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) 논문과 함께 발표했습니다. diff --git a/README_zh-hans.md b/README_zh-hans.md index 411e5ec1dddd2e..dac4d1cb68c9a6 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -345,6 +345,7 @@ conda install -c huggingface transformers 1. **[LiLT](https://huggingface.co/docs/transformers/model_doc/lilt)** (来自 South China University of Technology) 伴随论文 [LiLT: A Simple yet Effective Language-Independent Layout Transformer for Structured Document Understanding](https://arxiv.org/abs/2202.13669) 由 Jiapeng Wang, Lianwen Jin, Kai Ding 发布。 1. **[LLaMA](https://huggingface.co/docs/transformers/model_doc/llama)** (来自 The FAIR team of Meta AI) 伴随论文 [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971) 由 Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample 发布。 1. **[Llama2](https://huggingface.co/docs/transformers/model_doc/llama2)** (来自 The FAIR team of Meta AI) 伴随论文 [Llama2: Open Foundation and Fine-Tuned Chat Models](https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/XXX) 由 Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, Dan Bikel, Lukas Blecher, Cristian Canton Ferrer, Moya Chen, Guillem Cucurull, David Esiobu, Jude Fernandes, Jeremy Fu, Wenyin Fu, Brian Fuller, Cynthia Gao, Vedanuj Goswami, Naman Goyal, Anthony Hartshorn, Saghar Hosseini, Rui Hou, Hakan Inan, Marcin Kardas, Viktor Kerkez Madian Khabsa, Isabel Kloumann, Artem Korenev, Punit Singh Koura, Marie-Anne Lachaux, Thibaut Lavril, Jenya Lee, Diana Liskovich, Yinghai Lu, Yuning Mao, Xavier Martinet, Todor Mihaylov, Pushka rMishra, Igor Molybog, Yixin Nie, Andrew Poulton, Jeremy Reizenstein, Rashi Rungta, Kalyan Saladi, Alan Schelten, Ruan Silva, Eric Michael Smith, Ranjan Subramanian, Xiaoqing EllenTan, Binh Tang, Ross Taylor, Adina Williams, Jian Xiang Kuan, Puxin Xu, Zheng Yan, Iliyan Zarov, Yuchen Zhang, Angela Fan, Melanie Kambadur, Sharan Narang, Aurelien Rodriguez, Robert Stojnic, Sergey Edunov, Thomas Scialom. 发布。 +1. **[Llava](https://huggingface.co/docs/transformers/main/model_doc/llava)** (来自 Microsoft Research & University of Wisconsin-Madison) 伴随论文 [Improved Baselines with Visual Instruction Tuning](https://arxiv.org/pdf/2310.03744) 由 Haotian Liu, Chunyuan Li, Yuheng Li and Yong Jae Lee 发布。 1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (来自 AllenAI) 伴随论文 [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) 由 Iz Beltagy, Matthew E. Peters, Arman Cohan 发布。 1. **[LongT5](https://huggingface.co/docs/transformers/model_doc/longt5)** (来自 Google AI) released 伴随论文 [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) 由 Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang 发布。 1. **[LUKE](https://huggingface.co/docs/transformers/model_doc/luke)** (来自 Studio Ousia) 伴随论文 [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) 由 Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index 2944be140a40a0..8aaeda231d39f0 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -357,6 +357,7 @@ conda install -c huggingface transformers 1. **[LiLT](https://huggingface.co/docs/transformers/model_doc/lilt)** (from South China University of Technology) released with the paper [LiLT: A Simple yet Effective Language-Independent Layout Transformer for Structured Document Understanding](https://arxiv.org/abs/2202.13669) by Jiapeng Wang, Lianwen Jin, Kai Ding. 1. **[LLaMA](https://huggingface.co/docs/transformers/model_doc/llama)** (from The FAIR team of Meta AI) released with the paper [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971) by Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample. 1. **[Llama2](https://huggingface.co/docs/transformers/model_doc/llama2)** (from The FAIR team of Meta AI) released with the paper [Llama2: Open Foundation and Fine-Tuned Chat Models](https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/XXX) by Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, Dan Bikel, Lukas Blecher, Cristian Canton Ferrer, Moya Chen, Guillem Cucurull, David Esiobu, Jude Fernandes, Jeremy Fu, Wenyin Fu, Brian Fuller, Cynthia Gao, Vedanuj Goswami, Naman Goyal, Anthony Hartshorn, Saghar Hosseini, Rui Hou, Hakan Inan, Marcin Kardas, Viktor Kerkez Madian Khabsa, Isabel Kloumann, Artem Korenev, Punit Singh Koura, Marie-Anne Lachaux, Thibaut Lavril, Jenya Lee, Diana Liskovich, Yinghai Lu, Yuning Mao, Xavier Martinet, Todor Mihaylov, Pushka rMishra, Igor Molybog, Yixin Nie, Andrew Poulton, Jeremy Reizenstein, Rashi Rungta, Kalyan Saladi, Alan Schelten, Ruan Silva, Eric Michael Smith, Ranjan Subramanian, Xiaoqing EllenTan, Binh Tang, Ross Taylor, Adina Williams, Jian Xiang Kuan, Puxin Xu, Zheng Yan, Iliyan Zarov, Yuchen Zhang, Angela Fan, Melanie Kambadur, Sharan Narang, Aurelien Rodriguez, Robert Stojnic, Sergey Edunov, Thomas Scialom.. +1. **[Llava](https://huggingface.co/docs/transformers/main/model_doc/llava)** (from Microsoft Research & University of Wisconsin-Madison) released with the paper [Improved Baselines with Visual Instruction Tuning](https://arxiv.org/pdf/2310.03744) by Haotian Liu, Chunyuan Li, Yuheng Li and Yong Jae Lee. 1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. 1. **[LongT5](https://huggingface.co/docs/transformers/model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang. 1. **[LUKE](https://huggingface.co/docs/transformers/model_doc/luke)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto. diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 7fc684ee830db9..d6312509c49e10 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -703,6 +703,8 @@ title: LayoutXLM - local: model_doc/lilt title: LiLT + - local: model_doc/llava + title: Llava - local: model_doc/lxmert title: LXMERT - local: model_doc/matcha diff --git a/docs/source/en/index.md b/docs/source/en/index.md index b6e40672b327ee..2a055e56661ab8 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -169,6 +169,7 @@ Flax), PyTorch, and/or TensorFlow. | [LiLT](model_doc/lilt) | ✅ | ❌ | ❌ | | [LLaMA](model_doc/llama) | ✅ | ❌ | ✅ | | [Llama2](model_doc/llama2) | ✅ | ❌ | ✅ | +| [Llava](model_doc/llava) | ✅ | ❌ | ❌ | | [Longformer](model_doc/longformer) | ✅ | ✅ | ❌ | | [LongT5](model_doc/longt5) | ✅ | ❌ | ✅ | | [LUKE](model_doc/luke) | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/llava.md b/docs/source/en/model_doc/llava.md new file mode 100644 index 00000000000000..c21438dc115521 --- /dev/null +++ b/docs/source/en/model_doc/llava.md @@ -0,0 +1,67 @@ + + +# Llava + +## Overview + +Llava is an open-source chatbot trained by fine-tuning LlamA/Vicuna on GPT-generated multimodal instruction-following data. It is an auto-regressive language model, based on the transformer architecture. In other words, it is an multi-modal version of LLMs fine-tuned for chat / instructions. + +The Llava model was proposed in [Improved Baselines with Visual Instruction Tuning](https://arxiv.org/pdf/2310.03744) by Haotian Liu, Chunyuan Li, Yuheng Li and Yong Jae Lee. + +The abstract from the paper is the following: + +*Large multimodal models (LMM) have recently shown encouraging progress with visual instruction tuning. In this note, we show that the fully-connected vision-language cross-modal connector in LLaVA is surprisingly powerful and data-efficient. With simple modifications to LLaVA, namely, using CLIP-ViT-L-336px with an MLP projection and adding academic-task-oriented VQA data with simple response formatting prompts, we establish stronger baselines that achieve state-of-the-art across 11 benchmarks. Our final 13B checkpoint uses merely 1.2M publicly available data, and finishes full training in ∼1 day on a single 8-A100 node. We hope this can make state-of-the-art LMM research more accessible. Code and model will be publicly available* + +Tips: + +- We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Simply make sure to call `processor.tokenizer.padding_side = "left"` before generating. + +- Note the model has not been explicitly trained to process multiple images in the same prompt, although this is technically possible, you may experience inaccurate results. + +- For better results, we recommend users to prompt the model with the correct prompt format: + +```bash +"USER: \nASSISTANT:" +``` + +For multiple turns conversation: + +```bash +"USER: \nASSISTANT: USER: ASSISTANT: USER: ASSISTANT:" +``` + +This model was contributed by [ArthurZ](https://huggingface.co/ArthurZ) and [ybelkada](https://huggingface.co/ybelkada). +The original code can be found [here](https://github.com/haotian-liu/LLaVA/tree/main/llava). + +Check out also this [Google Colab demo](https://colab.research.google.com/drive/1qsl6cd2c8gGtEW1xV5io7S8NHh-Cp1TV?usp=sharing) on how to run Llava on a free-tier Google colab instance. + +### Using Flash Attention 2 + +Flash Attention 2 is an even faster, optimized version of the previous optimization, please refer to the [Flash Attention 2 section of performance docs](https://huggingface.co/docs/transformers/perf_infer_gpu_one). + +## LlavaConfig + +[[autodoc]] LlavaConfig + +## LlavaProcessor + +[[autodoc]] LlavaProcessor + +## LlavaForConditionalGeneration + +[[autodoc]] LlavaForConditionalGeneration + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index b762b1ca90c289..f88ed87b68b808 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -542,6 +542,10 @@ "models.levit": ["LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LevitConfig"], "models.lilt": ["LILT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LiltConfig"], "models.llama": ["LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP", "LlamaConfig"], + "models.llava": [ + "LLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP", + "LlavaConfig", + ], "models.longformer": [ "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongformerConfig", @@ -2433,6 +2437,14 @@ "LlamaPreTrainedModel", ] ) + _import_structure["models.llava"].extend( + [ + "LLAVA_PRETRAINED_MODEL_ARCHIVE_LIST", + "LlavaForConditionalGeneration", + "LlavaPreTrainedModel", + "LlavaProcessor", + ] + ) _import_structure["models.longformer"].extend( [ "LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -5170,6 +5182,10 @@ from .models.levit import LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, LevitConfig from .models.lilt import LILT_PRETRAINED_CONFIG_ARCHIVE_MAP, LiltConfig from .models.llama import LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP, LlamaConfig + from .models.llava import ( + LLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP, + LlavaConfig, + ) from .models.longformer import ( LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig, @@ -6869,11 +6885,12 @@ LiltModel, LiltPreTrainedModel, ) - from .models.llama import ( - LlamaForCausalLM, - LlamaForSequenceClassification, - LlamaModel, - LlamaPreTrainedModel, + from .models.llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel + from .models.llava import ( + LLAVA_PRETRAINED_MODEL_ARCHIVE_LIST, + LlavaForConditionalGeneration, + LlavaPreTrainedModel, + LlavaProcessor, ) from .models.longformer import ( LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 41796f9766c397..4442caa811575e 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -119,6 +119,7 @@ levit, lilt, llama, + llava, longformer, longt5, luke, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 1d1a64a36a7f83..2094752fdd73d8 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -56,6 +56,7 @@ ("chinese_clip", "ChineseCLIPConfig"), ("clap", "ClapConfig"), ("clip", "CLIPConfig"), + ("clip_vision_model", "CLIPVisionConfig"), ("clipseg", "CLIPSegConfig"), ("clvp", "ClvpConfig"), ("code_llama", "LlamaConfig"), @@ -126,6 +127,7 @@ ("levit", "LevitConfig"), ("lilt", "LiltConfig"), ("llama", "LlamaConfig"), + ("llava", "LlavaConfig"), ("longformer", "LongformerConfig"), ("longt5", "LongT5Config"), ("luke", "LukeConfig"), @@ -348,6 +350,7 @@ ("levit", "LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("lilt", "LILT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("llama", "LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("llava", "LLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("longformer", "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("longt5", "LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -494,6 +497,7 @@ ("chinese_clip", "Chinese-CLIP"), ("clap", "CLAP"), ("clip", "CLIP"), + ("clip_vision_model", "CLIPVisionModel"), ("clipseg", "CLIPSeg"), ("clvp", "CLVP"), ("code_llama", "CodeLlama"), @@ -573,6 +577,7 @@ ("lilt", "LiLT"), ("llama", "LLaMA"), ("llama2", "Llama2"), + ("llava", "Llava"), ("longformer", "Longformer"), ("longt5", "LongT5"), ("luke", "LUKE"), @@ -740,6 +745,7 @@ ("kosmos-2", "kosmos2"), ("maskformer-swin", "maskformer"), ("xclip", "x_clip"), + ("clip_vision_model", "clip"), ] ) diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 7d26d668ab2a93..32136b75e79c9b 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -75,6 +75,7 @@ ("layoutlmv2", "LayoutLMv2ImageProcessor"), ("layoutlmv3", "LayoutLMv3ImageProcessor"), ("levit", "LevitImageProcessor"), + ("llava", "CLIPImageProcessor"), ("mask2former", "Mask2FormerImageProcessor"), ("maskformer", "MaskFormerImageProcessor"), ("mgp-str", "ViTImageProcessor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index ef9c9d0354cd3f..9f1b509187ae81 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -59,6 +59,7 @@ ("chinese_clip", "ChineseCLIPModel"), ("clap", "ClapModel"), ("clip", "CLIPModel"), + ("clip_vision_model", "CLIPVisionModel"), ("clipseg", "CLIPSegModel"), ("clvp", "ClvpModelForConditionalGeneration"), ("code_llama", "LlamaModel"), @@ -270,6 +271,7 @@ ("ibert", "IBertForMaskedLM"), ("idefics", "IdeficsForVisionText2Text"), ("layoutlm", "LayoutLMForMaskedLM"), + ("llava", "LlavaForConditionalGeneration"), ("longformer", "LongformerForMaskedLM"), ("luke", "LukeForMaskedLM"), ("lxmert", "LxmertForPreTraining"), @@ -592,6 +594,7 @@ ("git", "GitForCausalLM"), ("instructblip", "InstructBlipForConditionalGeneration"), ("kosmos-2", "Kosmos2ForConditionalGeneration"), + ("llava", "LlavaForConditionalGeneration"), ("pix2struct", "Pix2StructForConditionalGeneration"), ("vision-encoder-decoder", "VisionEncoderDecoderModel"), ] diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index e1b3bac2de050c..457fdb107f104e 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -64,6 +64,7 @@ ("kosmos-2", "Kosmos2Processor"), ("layoutlmv2", "LayoutLMv2Processor"), ("layoutlmv3", "LayoutLMv3Processor"), + ("llava", "LlavaProcessor"), ("markuplm", "MarkupLMProcessor"), ("mctct", "MCTCTProcessor"), ("mgp-str", "MgpstrProcessor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 5d0a337f428781..81e17c8d61135f 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -203,6 +203,7 @@ "LlamaTokenizerFast" if is_tokenizers_available() else None, ), ), + ("llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)), ( "longt5", diff --git a/src/transformers/models/llava/__init__.py b/src/transformers/models/llava/__init__.py new file mode 100644 index 00000000000000..11aedf9476cfaa --- /dev/null +++ b/src/transformers/models/llava/__init__.py @@ -0,0 +1,56 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available + + +_import_structure = {"configuration_llava": ["LLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP", "LlavaConfig"]} + + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_llava"] = [ + "LLAVA_PRETRAINED_MODEL_ARCHIVE_LIST", + "LlavaForConditionalGeneration", + "LlavaPreTrainedModel", + ] + _import_structure["processing_llava"] = ["LlavaProcessor"] + + +if TYPE_CHECKING: + from .configuration_llava import LLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP, LlavaConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_llava import ( + LLAVA_PRETRAINED_MODEL_ARCHIVE_LIST, + LlavaForConditionalGeneration, + LlavaPreTrainedModel, + ) + from .processing_llava import LlavaProcessor + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) diff --git a/src/transformers/models/llava/configuration_llava.py b/src/transformers/models/llava/configuration_llava.py new file mode 100644 index 00000000000000..1f174bc1b4237e --- /dev/null +++ b/src/transformers/models/llava/configuration_llava.py @@ -0,0 +1,130 @@ +# coding=utf-8 +# Copyright 2023 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Llava model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto import CONFIG_MAPPING + + +logger = logging.get_logger(__name__) + +LLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "llava-hf/llava-v1.5-7b": "https://huggingface.co/llava-hf/llava-v1.5-7b/resolve/main/config.json", +} + + +class LlavaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LlavaForConditionalGeneration`]. It is used to instantiate an + Llava model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Llava-9B. + + e.g. [llava-hf/llava-9b](https://huggingface.co/llava-hf/llava-9b) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (`LlavaVisionConfig`, *optional*): + Custom vision config or dict + text_config (`Union[AutoConfig, dict]`, *optional*): + The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`. + ignore_index (`int`, *optional*, defaults to -100): + The ignore index for the loss function. + image_token_index (`int`, *optional*, defaults to 32000): + The image token index to encode the image prompt. + projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): + The activation function used by the multimodal projector. + vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): + The feature selection strategy used to select the vision feature from the CLIP backbone. + vision_feature_layer (`int`, *optional*, defaults to -2): + The index of the layer to select the vision feature. + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Llava model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`~LlavaForConditionalGeneration`] + + Example: + + ```python + >>> from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig + + >>> # Initializing a CLIP-vision config + >>> vision_config = CLIPVisionConfig() + + >>> # Initializing a Llama config + >>> text_config = LlamaConfig() + + >>> # Initializing a Llava llava-1.5-7b style configuration + >>> configuration = LlavaConfig(vision_config, text_config) + + >>> # Initializing a model from the llava-1.5-7b style configuration + >>> model = LlavaForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "llava" + is_composition = False + + def __init__( + self, + vision_config=None, + text_config=None, + ignore_index=-100, + image_token_index=32000, + projector_hidden_act="gelu", + vision_feature_select_strategy="default", + vision_feature_layer=-2, + vocab_size=32000, + **kwargs, + ): + self.ignore_index = ignore_index + self.image_token_index = image_token_index + self.projector_hidden_act = projector_hidden_act + self.vision_feature_select_strategy = vision_feature_select_strategy + self.vision_feature_layer = vision_feature_layer + self.vocab_size = vocab_size + + self.vision_config = vision_config + + if isinstance(self.vision_config, dict): + vision_config["model_type"] = ( + vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model" + ) + self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + self.vision_config = CONFIG_MAPPING["clip_vision_model"]( + intermediate_size=4096, + hidden_size=1024, + patch_size=14, + image_size=336, + num_hidden_layers=24, + num_attention_heads=16, + vocab_size=32000, + projection_dim=768, + ) + self.vocab_size = self.vocab_size + + self.text_config = text_config + + if isinstance(self.text_config, dict): + text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama" + self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + self.vocab_size = self.text_config.vocab_size + elif text_config is None: + self.text_config = CONFIG_MAPPING["llama"]() + + super().__init__(**kwargs) diff --git a/src/transformers/models/llava/convert_llava_weights_to_hf.py b/src/transformers/models/llava/convert_llava_weights_to_hf.py new file mode 100644 index 00000000000000..65b58236db1053 --- /dev/null +++ b/src/transformers/models/llava/convert_llava_weights_to_hf.py @@ -0,0 +1,126 @@ +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse + +import torch +from huggingface_hub import hf_hub_download + +from transformers import ( + AddedToken, + AutoConfig, + AutoTokenizer, + CLIPImageProcessor, + LlavaConfig, + LlavaForConditionalGeneration, + LlavaProcessor, +) + + +KEYS_TO_MODIFY_MAPPING = { + "model.vision_tower.": "", + "model.mm_projector": "multi_modal_projector", + "model": "model.model", + "vision_model.model": "vision_model", + "lm_head": "language_model.lm_head", + "model.model": "language_model.model", + "multi_modal_projector.0": "multi_modal_projector.linear_1", + "multi_modal_projector.2": "multi_modal_projector.linear_2", +} + + +def convert_state_dict_to_hf(state_dict): + new_state_dict = {} + for key, value in state_dict.items(): + for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in key: + key = key.replace(key_to_modify, new_key) + + new_state_dict[key] = value + return new_state_dict + + +def convert_llava_llama_to_hf(text_model_id, vision_model_id, output_hub_path, old_state_dict_id): + torch.set_default_dtype(torch.float16) + text_config = AutoConfig.from_pretrained(text_model_id) + + tokenizer = AutoTokenizer.from_pretrained(text_model_id) + tokenizer.add_tokens(AddedToken("", special=True, normalized=False), special=True) + tokenizer.add_special_tokens({"pad_token": ""}) + + image_processor = CLIPImageProcessor.from_pretrained(vision_model_id) + + processor = LlavaProcessor(tokenizer=tokenizer, image_processor=image_processor) + + config = LlavaConfig(text_config=text_config) + config.pad_token_id = 32001 + + with torch.device("meta"): + model = LlavaForConditionalGeneration(config) + + # Pad to 64 for performance reasons + pad_shape = 64 + + state_dict_path = hf_hub_download(old_state_dict_id, "model_state_dict.bin") + + state_dict = torch.load(state_dict_path, map_location="cpu") + state_dict = convert_state_dict_to_hf(state_dict) + model.load_state_dict(state_dict, strict=True, assign=True) + + pre_expansion_embeddings = model.language_model.model.embed_tokens.weight.data + mu = torch.mean(pre_expansion_embeddings, dim=0).float() + n = pre_expansion_embeddings.size()[0] + sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n + dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma) + + # We add an image token so we resize the model + model.resize_token_embeddings(config.text_config.vocab_size + 2, pad_shape) + model.language_model.model.embed_tokens.weight.data[32000:] = torch.stack( + tuple((dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[32000:].shape[0]))), + dim=0, + ) + model.language_model.lm_head.weight.data[32000:] = torch.stack( + tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[32000:].shape[0]))), + dim=0, + ) + model.config.vocab_size = model.config.vocab_size + pad_shape + model.config.text_config.vocab_size = model.config.text_config.vocab_size + pad_shape + + model.push_to_hub(output_hub_path) + processor.push_to_hub(output_hub_path) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--text_model_id", + help="Hub location of the text model", + ) + parser.add_argument( + "--vision_model_id", + help="Hub location of the vision model", + ) + parser.add_argument( + "--output_hub_path", + help="Location on the hub of the converted model", + ) + parser.add_argument( + "--old_state_dict_id", + help="Location on the hub of the raw state dict of the original model. The filename needs to be `model_state_dict.bin`", + ) + args = parser.parse_args() + convert_llava_llama_to_hf(args.text_model_id, args.vision_model_id, args.output_hub_path, args.old_state_dict_id) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py new file mode 100644 index 00000000000000..88e5b62121c746 --- /dev/null +++ b/src/transformers/models/llava/modeling_llava.py @@ -0,0 +1,485 @@ +# coding=utf-8 +# Copyright 2023 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Llava model.""" +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ... import PreTrainedModel +from ...activations import ACT2FN +from ...modeling_outputs import ModelOutput +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ..auto import AutoModel, AutoModelForCausalLM +from .configuration_llava import LlavaConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlavaConfig" + +LLAVA_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "llava-hf/llava-1.5-7b-hf", + "llava-hf/llava-1.5-13b-hf", + "llava-hf/bakLlava-v1-hf", + # See all Llava models at https://huggingface.co/models?filter=llava +] + + +@dataclass +# Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->Llava +class LlavaCausalLMOutputWithPast(ModelOutput): + """ + Base class for Llava causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, + sequence_length, hidden_size)`. + + image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +class LlavaMultiModalProjector(nn.Module): + def __init__(self, config: LlavaConfig): + super().__init__() + + self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +LLAVA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlavaConfig`] or [`LlavaVisionConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAVA_START_DOCSTRING, +) +class LlavaPreTrainedModel(PreTrainedModel): + config_class = LlavaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlavaVisionAttention"] + _supports_flash_attn_2 = True + + def _init_weights(self, module): + # important: this ported version of Llava isn't meant for training from scratch - only + # inference and fine-tuning - so the proper init weights code has been removed - the original codebase + # https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.text_config.initializer_range + ) + + if hasattr(module, "class_embedding"): + module.class_embedding.data.normal_(mean=0.0, std=std) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +LLAVA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([]`LlavaProcessor`] uses + [`CLIPImageProcessor`] for processing images). + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + """The LLAVA model which consists of a vision backbone and a language model.""", + LLAVA_START_DOCSTRING, +) +class LlavaForConditionalGeneration(LlavaPreTrainedModel): + def __init__(self, config: LlavaConfig): + super().__init__(config) + self.vision_tower = AutoModel.from_config(config.vision_config) + + self.multi_modal_projector = LlavaMultiModalProjector(config) + self.vocab_size = config.vocab_size + use_flash_attention_2 = getattr(config, "_flash_attn_2_enabled", False) + self.language_model = AutoModelForCausalLM.from_config( + config.text_config, use_flash_attention_2=use_flash_attention_2 + ) + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + def set_decoder(self, decoder): + self.language_model.set_decoder(decoder) + + def get_decoder(self): + return self.language_model.get_decoder() + + def tie_weights(self): + return self.language_model.tie_weights() + + def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding: + model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + # update vocab size + self.config.text_config.vocab_size = model_embeds.num_embeddings + self.config.vocab_size = model_embeds.num_embeddings + self.vocab_size = model_embeds.num_embeddings + return model_embeds + + def _merge_input_ids_with_image_features( + self, image_features, inputs_embeds, input_ids, attention_mask, position_ids + ): + nb_images, image_hidden_dim, embed_dim = image_features.shape + batch_size, sequence_length = input_ids.shape + left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) + # 1. Create a mask to know where image tokens are + image_token_mask = input_ids == self.config.image_token_index + num_image_tokens = torch.sum(image_token_mask, dim=-1) + # Compute the maximum embed dimension + max_embed_dim = (num_image_tokens.max() * (image_hidden_dim - 1)) + sequence_length + batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index) + + # 2. Compute the positions where text should be written + # Calculate new positions for text tokens in merged image-text sequence. + # `image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. + # `torch.cumsum` computes how each image token shifts subsequent text token positions. + # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. + new_token_positions = torch.cumsum((image_token_mask * (image_hidden_dim - 1) + 1), -1) - 1 + nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1] + if left_padding: + new_token_positions += nb_image_pad[:, None] # offset for left padding + text_to_overwrite = new_token_positions[batch_indices, non_image_indices] + + # 3. Create the full embedding, already padded to the maximum position + final_embedding = torch.zeros( + batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + final_attention_mask = torch.zeros( + batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device + ) + + # 4. Fill the embeddings based on the mask. If we have ["hey" "", "how", "are"] + # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features + final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] + final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] + + # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling + image_to_overwrite = torch.all(final_embedding == 0, dim=-1) + image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None] + + if image_to_overwrite.sum() != image_features.shape[:-1].numel(): + raise ValueError( + f"The input provided to the model are wrong. The number of image tokens is {torch.sum(image_token_mask)} while" + f" the number of image given to the model is {nb_images}. This prevents correct indexing and breaks batch generation." + ) + + final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim) + final_attention_mask |= image_to_overwrite + position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) + return final_embedding, final_attention_mask, position_ids + + @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[int] = None, + vision_feature_select_strategy: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, LlavaCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, LlavaForConditionalGeneration + + >>> model = LlavaForConditionalGeneration.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> processor = AutoProcessor.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "\nUSER: What's the content of the image?\nASSISTANT:" + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(text=text, images=image, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "There seems to be a stop sign" + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if inputs_embeds is None: + # 1. Extra the input embeddings + inputs_embeds = self.get_input_embeddings()(input_ids) + + # 2. Merge text and images + if pixel_values is not None and input_ids.shape[1] != 1: + image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) + # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated. + selected_image_feature = image_outputs.hidden_states[vision_feature_layer] + + if vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise ValueError( + f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}" + ) + + image_features = self.multi_modal_projector(selected_image_feature) + inputs_embeds, attention_mask, position_ids = self._merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids, attention_mask, position_ids + ) + if labels is None: + labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long) + else: + # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of + # generation with cache + if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1: + # Retrieve the first layer to inspect the logits and mask out the hidden states + # that are set to 0 + first_layer_past_key_value = past_key_values[0][0][:, 0, :, 0] + batch_index, non_attended_tokens = torch.where(first_layer_past_key_value == 0) + # Get the target length + target_seqlen = first_layer_past_key_value.shape[-1] + 1 + + extended_attention_mask = torch.ones( + (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + + # Zero-out the places where we don't need to attend + extended_attention_mask[batch_index, non_attended_tokens] = 0 + + attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1) + position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = outputs[0] + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + if attention_mask is not None: + shift_attention_mask = attention_mask[..., 1:] + shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return LlavaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, **kwargs + ): + # Call `prepare_inputs_for_generation` from the LM + model_input = self.language_model.prepare_inputs_for_generation( + input_ids, past_key_values, inputs_embeds=inputs_embeds, **kwargs + ) + model_input.update({"pixel_values": pixel_values}) + return model_input + + def _reorder_cache(self, *args, **kwargs): + return self.language_model._reorder_cache(*args, **kwargs) diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py new file mode 100644 index 00000000000000..373fccbce19286 --- /dev/null +++ b/src/transformers/models/llava/processing_llava.py @@ -0,0 +1,160 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for Llava. +""" + + +import warnings +from typing import Callable, List, Optional, Union + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ProcessorMixin +from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy +from ...utils import TensorType + + +class LlavaProcessor(ProcessorMixin): + r""" + Constructs a Llava processor which wraps a Llava image processor and a Llava tokenizer into a single processor. + + [`LlavaProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`LlamaTokenizerFast`]. See the + [`~LlavaProcessor.__call__`] and [`~LlavaProcessor.decode`] for more information. + + Args: + image_processor ([`CLIPImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`LlamaTokenizerFast`], *optional*): + The tokenizer is a required input. + """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = "CLIPImageProcessor" + tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.__init__ + def __init__(self, image_processor=None, tokenizer=None, **kwargs): + feature_extractor = None + if "feature_extractor" in kwargs: + warnings.warn( + "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`" + " instead.", + FutureWarning, + ) + feature_extractor = kwargs.pop("feature_extractor") + + image_processor = image_processor if image_processor is not None else feature_extractor + if image_processor is None: + raise ValueError("You need to specify an `image_processor`.") + if tokenizer is None: + raise ValueError("You need to specify a `tokenizer`.") + + super().__init__(image_processor, tokenizer) + + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + images: ImageInput = None, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + transform: Callable = None, + max_length=None, + return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to + CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring + of the above two methods for more information. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`, *optional*): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + transform (`Callable`, *optional*): + A custom transform function that accepts a single image can be passed for training. For example, + `torchvision.Compose` can be used to compose multiple functions. If `None` a preset inference-specific + set of transforms will be applied to the images + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + if images is not None: + pixel_values = self.image_processor(images, transform=transform, return_tensors=return_tensors)[ + "pixel_values" + ] + else: + pixel_values = None + text_inputs = self.tokenizer( + text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length + ) + + return BatchFeature(data={**text_inputs, "pixel_values": pixel_values}) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 7fc4ccea5e48d4..0c3cdd67405ef8 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -4641,6 +4641,30 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +LLAVA_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class LlavaForConditionalGeneration(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class LlavaPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class LlavaProcessor(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/models/llava/__init__.py b/tests/models/llava/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py new file mode 100644 index 00000000000000..23daaecc4e62c0 --- /dev/null +++ b/tests/models/llava/test_modeling_llava.py @@ -0,0 +1,285 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch Llava model. """ + +import gc +import unittest + +import requests + +from transformers import ( + AutoProcessor, + LlavaConfig, + LlavaForConditionalGeneration, + is_torch_available, + is_vision_available, +) +from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch +else: + is_torch_greater_or_equal_than_2_0 = False + +if is_vision_available(): + from PIL import Image + + +class LlavaVisionText2TextModelTester: + def __init__( + self, + parent, + ignore_index=-100, + image_token_index=0, + projector_hidden_act="gelu", + seq_length=7, + vision_feature_select_strategy="default", + vision_feature_layer=-1, + text_config={ + "model_type": "llama", + "seq_length": 7, + "is_training": True, + "use_input_mask": True, + "use_token_type_ids": False, + "use_labels": True, + "vocab_size": 99, + "hidden_size": 32, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "intermediate_size": 37, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.1, + "max_position_embeddings": 512, + "type_vocab_size": 16, + "type_sequence_label_size": 2, + "initializer_range": 0.02, + "num_labels": 3, + "num_choices": 4, + "pad_token_id": 0, + }, + is_training=True, + vision_config={ + "batch_size": 12, + "image_size": 30, + "patch_size": 2, + "num_channels": 3, + "is_training": True, + "hidden_size": 32, + "projection_dim": 32, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "intermediate_size": 37, + "dropout": 0.1, + "attention_dropout": 0.1, + "initializer_range": 0.02, + }, + ): + self.parent = parent + self.ignore_index = ignore_index + self.image_token_index = image_token_index + self.projector_hidden_act = projector_hidden_act + self.vision_feature_select_strategy = vision_feature_select_strategy + self.vision_feature_layer = vision_feature_layer + self.text_config = text_config + self.vision_config = vision_config + self.seq_length = seq_length + + self.num_hidden_layers = text_config["num_hidden_layers"] + self.vocab_size = text_config["vocab_size"] + self.hidden_size = text_config["hidden_size"] + self.num_attention_heads = text_config["num_attention_heads"] + self.is_training = is_training + + self.batch_size = 3 + self.num_channels = 3 + self.image_size = 336 + self.encoder_seq_length = 231 + + def get_config(self): + return LlavaConfig( + text_config=self.text_config, + vision_config=self.vision_config, + ignore_index=self.ignore_index, + image_token_index=self.image_token_index, + projector_hidden_act=self.projector_hidden_act, + vision_feature_select_strategy=self.vision_feature_select_strategy, + vision_feature_layer=self.vision_feature_layer, + ) + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor( + [ + self.batch_size, + self.vision_config["num_channels"], + self.vision_config["image_size"], + self.vision_config["image_size"], + ] + ) + config = self.get_config() + + return config, pixel_values + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1 + attention_mask = input_ids.ne(1).to(torch_device) + # we are giving 3 images let's make sure we pass in 3 image tokens + input_ids[:, 1] = config.image_token_index + inputs_dict = { + "pixel_values": pixel_values, + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +@require_torch +class LlavaForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase): + """ + Model tester for `LlavaForConditionalGeneration`. + """ + + all_model_classes = (LlavaForConditionalGeneration,) if is_torch_available() else () + pipeline_model_mapping = {"image-to-text": LlavaForConditionalGeneration} if is_torch_available() else {} + fx_compatible = False + test_pruning = False + test_resize_embeddings = True + test_head_masking = False + + def setUp(self): + self.model_tester = LlavaVisionText2TextModelTester(self) + self.config_tester = ConfigTester(self, config_class=LlavaConfig, has_text_modality=False) + + @unittest.skip( + reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip( + reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip( + reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + +@require_torch +class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase): + def setUp(self): + self.processor = AutoProcessor.from_pretrained("llava-hf/bakLlava-v1-hf") + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + @slow + @require_bitsandbytes + def test_small_model_integration_test(self): + # Let' s make sure we test the preprocessing to replace what is used + model = LlavaForConditionalGeneration.from_pretrained("llava-hf/bakLlava-v1-hf", load_in_4bit=True) + + prompt = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT:" + image_file = "https://llava-vl.github.io/static/images/view.jpg" + raw_image = Image.open(requests.get(image_file, stream=True).raw) + inputs = self.processor(prompt, raw_image, return_tensors="pt") + + EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722,315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456,1633, 28804, 13, 4816, 8048, 12738, 28747]]) # fmt: skip + self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS)) + + output = model.generate(**inputs, max_new_tokens=20) + EXPECTED_DECODED_TEXT = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly," # fmt: skip + + self.assertEqual( + self.processor.decode(output[0], skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + @require_bitsandbytes + def test_small_model_integration_test_llama(self): + # Let' s make sure we test the preprocessing to replace what is used + model_id = "llava-hf/llava-1.5-7b-hf" + + model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf", load_in_4bit=True) + processor = AutoProcessor.from_pretrained(model_id) + + prompt = "USER: \nWhat are the things I should be cautious about when I visit this place?\nASSISTANT:" + image_file = "https://llava-vl.github.io/static/images/view.jpg" + raw_image = Image.open(requests.get(image_file, stream=True).raw) + inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16) + + output = model.generate(**inputs, max_new_tokens=900, do_sample=False) + EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the presence of wildlife, such as birds or fish, and avoid disturbing their natural habitats. Lastly, be aware of any local regulations or guidelines for the use of the pier, as some areas may be restricted or prohibited for certain activities." # fmt: skip + + self.assertEqual( + processor.decode(output[0], skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + @require_bitsandbytes + def test_small_model_integration_test_llama_batched(self): + # Let' s make sure we test the preprocessing to replace what is used + model_id = "llava-hf/llava-1.5-7b-hf" + + model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf", load_in_4bit=True) + processor = AutoProcessor.from_pretrained(model_id, pad_token="") + + prompts = [ + "USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:", + "USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT:", + ] + image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw) + image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) + + inputs = processor(prompts, images=[image1, image2, image1], return_tensors="pt", padding=True) + + output = model.generate(**inputs, max_new_tokens=20) + + EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: the water is calm and clear\n\nThe image shows a wooden pier on a lake, with a', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip + + self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT) + + @slow + @require_bitsandbytes + def test_small_model_integration_test_batch(self): + # Let' s make sure we test the preprocessing to replace what is used + model = LlavaForConditionalGeneration.from_pretrained("llava-hf/bakLlava-v1-hf", load_in_4bit=True) + # The first batch is longer in terms of text, but only has 1 image. The second batch will be padded in text, but the first will be padded because images take more space!. + prompts = [ + "USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:", + "USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT:", + ] + image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw) + image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) + + inputs = self.processor(prompts, images=[image1, image2, image1], return_tensors="pt", padding=True) + + output = model.generate(**inputs, max_new_tokens=20) + + EXPECTED_DECODED_TEXT = ['\nUSER: What are the things I should be cautious about when I visit this place? What should I bring with me\nASSISTANT: When visiting this place, bring a camera, Park rules may apply, Per person, Sunrise over', '\nUSER: What is this?\nASSISTANT: Two cats lying on a bed!\nUSER: And this? a dock on a lake with two cats on it (Photo credit: Jocelyn R.'] # fmt: skip + self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT) diff --git a/tests/pipelines/test_pipelines_image_to_text.py b/tests/pipelines/test_pipelines_image_to_text.py index 7514f17919b1f0..4f5b5780277bfc 100644 --- a/tests/pipelines/test_pipelines_image_to_text.py +++ b/tests/pipelines/test_pipelines_image_to_text.py @@ -252,3 +252,24 @@ def test_large_model_tf(self): [{"generated_text": "a cat laying on a blanket next to a cat laying on a bed "}], ], ) + + @slow + @require_torch + def test_conditional_generation_llava(self): + pipe = pipeline("image-to-text", model="llava-hf/bakLlava-v1-hf") + url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg" + image = Image.open(requests.get(url, stream=True).raw) + + prompt = ( + "\nUSER: What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud?\nASSISTANT:" + ) + + outputs = pipe(image, prompt=prompt, generate_kwargs={"max_new_tokens": 200}) + self.assertEqual( + outputs, + [ + { + "generated_text": " \nUSER: What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud?\nASSISTANT: Lava" + } + ], + ) diff --git a/utils/check_copies.py b/utils/check_copies.py index c536eb800ff084..4fd3a7c23e2b26 100644 --- a/utils/check_copies.py +++ b/utils/check_copies.py @@ -673,6 +673,7 @@ def check_model_list_copy(overwrite: bool = False): "TimmBackbone", "Vision Encoder decoder", "VisionTextDualEncoder", + "CLIPVisionModel", ] # Template for new entries to add in the main README when we have missing models. diff --git a/utils/check_table.py b/utils/check_table.py index deca8fc8d2e569..6eec4754032e8e 100644 --- a/utils/check_table.py +++ b/utils/check_table.py @@ -171,6 +171,7 @@ def _center_text(text: str, width: int) -> str: "XLS-R": "Wav2Vec2", "XLSR-Wav2Vec2": "Wav2Vec2", } +MODEL_NAMES_TO_IGNORE = ["CLIPVisionModel"] def get_model_table_from_auto_modules() -> str: @@ -243,6 +244,8 @@ def get_model_table_from_auto_modules() -> str: check = {True: "✅", False: "❌"} for name in model_names: + if name in MODEL_NAMES_TO_IGNORE: + continue if name in MODEL_NAMES_WITH_SAME_CONFIG.keys(): prefix = model_name_to_prefix[MODEL_NAMES_WITH_SAME_CONFIG[name]] else: diff --git a/utils/not_doctested.txt b/utils/not_doctested.txt index 4e0d518a7ae99c..5d0c4ba2b26a8d 100644 --- a/utils/not_doctested.txt +++ b/utils/not_doctested.txt @@ -146,6 +146,7 @@ docs/source/en/model_doc/levit.md docs/source/en/model_doc/lilt.md docs/source/en/model_doc/llama.md docs/source/en/model_doc/llama2.md +docs/source/en/model_doc/llava.md docs/source/en/model_doc/longformer.md docs/source/en/model_doc/longt5.md docs/source/en/model_doc/luke.md @@ -294,7 +295,7 @@ docs/source/en/serialization.md docs/source/en/tasks/asr.md docs/source/en/tasks/audio_classification.md docs/source/en/tasks/document_question_answering.md -docs/source/en/tasks/idefics.md # causes other tests to fail +docs/source/en/tasks/idefics.md docs/source/en/tasks/image_captioning.md docs/source/en/tasks/image_classification.md docs/source/en/tasks/language_modeling.md @@ -432,7 +433,7 @@ src/transformers/models/blip/modeling_blip_text.py src/transformers/models/blip/modeling_tf_blip_text.py src/transformers/models/blip_2/configuration_blip_2.py src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py -src/transformers/models/blip_2/modeling_blip_2.py # causes other tests to fail +src/transformers/models/blip_2/modeling_blip_2.py src/transformers/models/bloom/convert_bloom_original_checkpoint_to_pytorch.py src/transformers/models/bloom/modeling_bloom.py src/transformers/models/bloom/modeling_flax_bloom.py @@ -634,6 +635,8 @@ src/transformers/models/lilt/configuration_lilt.py src/transformers/models/llama/configuration_llama.py src/transformers/models/llama/convert_llama_weights_to_hf.py src/transformers/models/llama/modeling_llama.py +src/transformers/models/llava/configuration_llava.py +src/transformers/models/llava/modeling_llava.py src/transformers/models/longformer/configuration_longformer.py src/transformers/models/longformer/convert_longformer_original_pytorch_lightning_to_pytorch.py src/transformers/models/longt5/configuration_longt5.py