diff --git a/.github/workflows/nv-pre-compile-ops.yml b/.github/workflows/nv-pre-compile-ops.yml
index e03cbb353bd3..ccb6c25e14f7 100644
--- a/.github/workflows/nv-pre-compile-ops.yml
+++ b/.github/workflows/nv-pre-compile-ops.yml
@@ -33,7 +33,7 @@ jobs:
#python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
- name: Compile DeepSpeed Ops
run: |
- TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 pip3 install .
+ TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
- name: DS Report
run: |
ds_report
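
Note: the op pre-compile job now also skips the new Evoformer attention build (`DS_BUILD_EVOFORMER_ATTN=0`). As a rough sketch of how the same compatibility information can be checked outside CI, the snippet below queries individual op builders the way `ds_report` summarizes them; the `EvoformerAttnBuilder` class name is an assumption about the builder added elsewhere in this PR and is guarded accordingly.

```python
# Sketch: query op-builder compatibility the way `ds_report` summarizes it.
# `EvoformerAttnBuilder` is an assumed class name and is guarded with a fallback.
from deepspeed.ops.op_builder import CPUAdamBuilder

builders = [CPUAdamBuilder()]
try:
    from deepspeed.ops.op_builder import EvoformerAttnBuilder  # assumed name
    builders.append(EvoformerAttnBuilder())
except ImportError:
    pass

for b in builders:
    # is_compatible() reports whether the op can be built on this machine,
    # which is what the DS_BUILD_* flags ultimately toggle at install time.
    print(f"{b.NAME}: compatible={b.is_compatible()}")
```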
diff --git a/.github/workflows/nv-torch110-p40.yml b/.github/workflows/nv-torch110-p40.yml
index 95fccb2de9d3..45f3e0438233 100644
--- a/.github/workflows/nv-torch110-p40.yml
+++ b/.github/workflows/nv-torch110-p40.yml
@@ -3,6 +3,7 @@ name: nv-torch110-p40
on:
schedule:
- cron: "0 0 * * *"
+ workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
diff --git a/.github/workflows/nv-torch110-v100.yml b/.github/workflows/nv-torch110-v100.yml
index a3e39a9e5b22..1fd8aaac0ffa 100644
--- a/.github/workflows/nv-torch110-v100.yml
+++ b/.github/workflows/nv-torch110-v100.yml
@@ -3,6 +3,7 @@ name: nv-torch110-v100
on:
schedule:
- cron: "0 0 * * *"
+ workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 36fa34a42744..6b11b3acba51 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -58,7 +58,7 @@ repos:
# Do not check files that are automatically generated
'--skip=docs/Gemfile.lock,tests/unit/gpt2-merges.txt,tests/unit/gpt2-vocab.json',
'--ignore-regex=\\n', # Do not count the 'n' in an escaped newline as part of a word
- '--ignore-words-list=youn,unsupport', # Word used in error messages that need rewording
+ '--ignore-words-list=youn,unsupport,noe', # Word used in error messages that need rewording
--check-filenames,
--check-hidden
]
diff --git a/README.md b/README.md
index 4999a485f4ce..6aef71b8e66e 100755
--- a/README.md
+++ b/README.md
@@ -15,11 +15,11 @@
## Latest News
DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat).
+* [2023/09] Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies [[DeepSpeed4Science website](https://deepspeed4science.ai/)] [[Tutorials](https://www.deepspeed.ai/deepspeed4science/)] [[Blog](https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/japanese/README.md)]
* [2023/08] [DeepSpeed ZeRO-Inference: 20X faster inference through weight quantization and KV cache offloading](https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/README.md)
* [2023/08] [DeepSpeed-Chat: Llama/Llama-2 system support, efficiency boost, and training stability improvements](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/ds-chat-release-8-31/README.md)
-* [2023/08] [DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ulysses)
+* [2023/08] [DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ulysses) [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-ulysses/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-ulysses/japanese/README.md)]
* [2023/06] [ZeRO++: A leap in speed for LLM and chat model training with 4X less communication](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/)[[English](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/japanese/README.md)]
-* [2023/04] 🚀 [DeepSpeed Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/japanese/README.md)]🚀
---
@@ -35,9 +35,9 @@
---
-# DeepSpeed's three innovation pillars
+# DeepSpeed's four innovation pillars
-
+
## DeepSpeed-Training
@@ -53,6 +53,10 @@ DeepSpeed brings together innovations in parallelism technology such as tensor,
To further increase the inference efficiency, DeepSpeed offers easy-to-use and flexible-to-compose compression techniques for researchers and practitioners to compress their models while delivering faster speed, smaller model size, and significantly reduced compression cost. Moreover, SoTA innovations on compression like ZeroQuant and XTC are included under the compression pillar. Learn more: [DeepSpeed-Compression](https://www.deepspeed.ai/compression)
+## DeepSpeed4Science
+
+In line with Microsoft's mission to solve humanity's most pressing challenges, the DeepSpeed team at Microsoft is responding to this opportunity by launching a new initiative called *DeepSpeed4Science*, aiming to build unique capabilities through AI system technology innovations to help domain experts unlock today's biggest science mysteries. Learn more: [DeepSpeed4Science website](https://deepspeed4science.ai/) and [tutorials](https://www.deepspeed.ai/deepspeed4science/)
+
---
# DeepSpeed Software Suite
diff --git a/accelerator/abstract_accelerator.py b/accelerator/abstract_accelerator.py
index fe0e66768d45..a87ff3c1d223 100644
--- a/accelerator/abstract_accelerator.py
+++ b/accelerator/abstract_accelerator.py
@@ -185,6 +185,10 @@ def lazy_call(self, callback):
def communication_backend_name(self):
...
+ @abc.abstractmethod
+ def is_triton_supported(self):
+ ...
+
# Tensor operations
@property
@abc.abstractmethod
diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py
index 11518d31e069..4de4ad93c2bb 100644
--- a/accelerator/cpu_accelerator.py
+++ b/accelerator/cpu_accelerator.py
@@ -182,6 +182,9 @@ def lazy_call(self, callback):
def communication_backend_name(self):
return self._communication_backend_name
+ def is_triton_supported(self):
+ return False
+
# Data types
def is_bf16_supported(self):
return True
diff --git a/accelerator/cuda_accelerator.py b/accelerator/cuda_accelerator.py
index 9c1e0d22785e..045cce510a90 100644
--- a/accelerator/cuda_accelerator.py
+++ b/accelerator/cuda_accelerator.py
@@ -173,6 +173,13 @@ def lazy_call(self, callback):
def communication_backend_name(self):
return self._communication_backend_name
+ def is_triton_supported(self):
+ major, _ = torch.cuda.get_device_capability()
+ if major >= 8:
+ return True
+ else:
+ return False
+
# Tensor operations
@property
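
The `is_triton_supported()` hook added across the accelerator implementations lets callers gate Triton kernels on the running hardware (compute capability 8.0+ on CUDA, `False` elsewhere). A minimal usage sketch through the accelerator abstraction; the fallback branch is illustrative only.

```python
# Sketch: gate Triton kernels on the new accelerator capability hook.
from deepspeed.accelerator import get_accelerator

accel = get_accelerator()
if accel.is_triton_supported():
    # CUDA devices with compute capability >= 8.0 (Ampere or newer)
    print(f"{accel.device_name()}: Triton kernels enabled")
else:
    # CPU, MPS, NPU, or pre-Ampere CUDA devices take the non-Triton path
    print(f"{accel.device_name()}: using the non-Triton fallback")
```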
diff --git a/accelerator/mps_accelerator.py b/accelerator/mps_accelerator.py
index 68e3b6df05c9..63a92f250898 100644
--- a/accelerator/mps_accelerator.py
+++ b/accelerator/mps_accelerator.py
@@ -138,6 +138,9 @@ def is_bf16_supported(self):
def is_fp16_supported(self):
return False
+ def supported_dtypes(self):
+ return [torch.float]
+
# Misc
def amp(self):
return
@@ -157,6 +160,9 @@ def lazy_call(self, callback):
def communication_backend_name(self):
return self._communication_backend_name
+ def is_triton_supported(self):
+ return False
+
# Tensor operations
@property
def BFloat16Tensor(self):
diff --git a/accelerator/npu_accelerator.py b/accelerator/npu_accelerator.py
index 5678a0266386..206bc1dfaa1b 100644
--- a/accelerator/npu_accelerator.py
+++ b/accelerator/npu_accelerator.py
@@ -158,6 +158,9 @@ def lazy_call(self, callback):
def communication_backend_name(self):
return self._communication_backend_name
+ def is_triton_supported(self):
+ return False
+
# Tensor operations
@property
diff --git a/blogs/deepspeed4science/README.md b/blogs/deepspeed4science/README.md
new file mode 100644
index 000000000000..2a80ea2e749e
--- /dev/null
+++ b/blogs/deepspeed4science/README.md
@@ -0,0 +1,7 @@
+
+
+# Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies
+
+
+
+[https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/](https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/)
diff --git a/blogs/deepspeed4science/chinese/README.md b/blogs/deepspeed4science/chinese/README.md
new file mode 100644
index 000000000000..3ffddfb16fe5
--- /dev/null
+++ b/blogs/deepspeed4science/chinese/README.md
@@ -0,0 +1,145 @@
+
+
+# DeepSpeed4Science:利用先进的AI系统优化技术实现科学发现
+
+
+
+*此博客为英文博客[Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies](https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/)的官方翻译*
+
+
+
+
+*图1:DeepSpeed4Science方法概述:专为加速科学发现和应对其复杂性而量身定制的AI系统技术开发。*
+
+
+## 简介
+
+在接下来的十年中,深度学习可能会彻底改变自然科学,增强我们对自然现象进行建模和预测的能力。这可能预示着科学探索的新时代,为从药物开发到可再生能源的各个领域带来重大进展。为了响应这一机会以及微软“予力全球每一人、每一组织,成就不凡”的使命,[微软DeepSpeed团队](https://www.deepspeed.ai/)启动了一个名为[DeepSpeed4Science](https://deepspeed4science.ai/)的新计划,旨在通过AI系统技术创新帮助领域专家解锁当今最大的科学之谜。
+
+[DeepSpeed](https://www.deepspeed.ai/)系统是由微软开发的业界领先的开源AI系统框架,它为各种AI硬件上的深度学习训练和推理提供了前所未有的规模和速度。图1展示了我们对DeepSpeed4Science这一新计划的基本方法。通过利用DeepSpeed当前的技术方案(训练、推理和压缩)作为基础技术推动器,DeepSpeed4Science将创建一套专为加速科学发现而量身定制的AI系统技术,以应对其独特的复杂性,超越用于加速通用大型语言模型(LLMs)的常见技术方法。我们与拥有科学AI模型的内部和外部团队紧密合作,以发现和解决领域特定AI系统的挑战。这包括气候科学、药物设计、生物学理解、分子动力学模拟、癌症诊断和监测、催化剂/材料发现、和其他领域。
+
+我们的长期愿景是将DeepSpeed4Science发展成一个用于分享支持科学发现的先进AI技术的软件平台和统一代码仓库。DeepSpeed4Science的设计旨在包容性,呼应微软的[“AI for Good”承诺](https://www.microsoft.com/en-us/ai/ai-for-good)。这体现在该计划对一系列标志性科学模型的支持上,他们代表了一些最关键的AI4Science应用场景。在这篇博客中,我们展示了DeepSpeed4Science如何帮助解决结构生物学研究中的两个关键AI系统挑战:(1) 解决了以Evoformer为中心的蛋白质结构预测模型中的内存爆炸问题,以及(2)为更好地理解引发大流行的病毒的进化提供AI模型长序列支持。
+
+## 我们的初期主要合作者
+
+DeepSpeed4Science的新系统技术可以用于很多推动科学边界的标志性模型,赋能AI驱动的科学发现。目前,DeepSpeed4Science很荣幸地支持来自[微软研究院AI4Science](https://www.microsoft.com/en-us/research/lab/microsoft-research-ai4science/)、[微软WebXT/Bing](https://www.msn.com/en-us/weather/forecast/)、[美国能源部国家实验室](https://www.energy.gov/national-laboratories)和多所大学的几个关键科学模型。
+
+### 微软内部合作伙伴
+
+#### 科学基础模型(Scientific Foundation Model,SFM),微软研究院AI4Science
+
+
+
+
+
+*图2:科学基础模型(Scientific Foundation Model,SFM)及其当前探索:Distributional Graphormer。*
+
+
+科学基础模型(SFM)旨在创建一个统一的大规模基础模型,以支持自然科学发现,支持多种输入、多个科学领域(例如,药物、材料、生物学、健康等)和计算任务。DeepSpeed4Science合作伙伴关系将为SFM团队提供新的训练和推理技术,以支持他们的新生成AI方法(例如[Distributional Graphormer](https://www.microsoft.com/en-us/research/blog/distributional-graphormer-toward-equilibrium-distribution-prediction-for-molecular-systems/))这样的项目进行持续研究。
+
+#### ClimaX,微软研究院AI4Science
+
+
+
+
+*图3:ClimaX是第一个设计用于执行各种天气和气候建模任务的基础模型。*
+
+
+我们的气候正在发生变化,导致极端天气事件的频率增加。为了减轻负面影响,预测这些事件将发生的地方变得越来越重要。[ClimaX](https://www.microsoft.com/en-us/research/group/autonomous-systems-group-robotics/articles/introducing-climax-the-first-foundation-model-for-weather-and-climate/)是第一个设计用于执行各种天气和气候建模任务的基础模型。它可以吸收许多具有不同变量和分辨率的数据集以提高天气预报的准确性。DeepSpeed4Science正在为ClimaX创建新的系统支持和加速策略,以高效地预训练/微调更大的基础模型,同时处理非常大的高分辨率图像数据(例如,数十到数百PB)和长序列。
+
+#### AI驱动的第一性原理分子动力学(AI Powered Ab Initio Molecular Dynamics,AI2MD),微软研究院AI4Science
+
+
+
+
+*图4:一百万步的分子动力学模拟:RBD-蛋白(RBD-protein)与蛋白抑制剂(protein inhibitor)相互作用。*
+
+
+这个项目模拟了使用[AI驱动的力场模型](https://www.microsoft.com/en-us/research/publication/ai2bmd-efficient-characterization-of-protein-dynamics-with-ab-initio-accuracy/)进行近似第一性原理计算精度的大型(百万原子)分子系统的动态模拟,同时保持了经典分子动力学的效率和可扩展性。这些模拟足够高效,可以生成足够长的轨迹来观察化学上有意义的事件。通常,这个过程需要数百万甚至数十亿的推理步骤。这对优化图神经网络(GNN)+ LLM模型的推理速度提出了重大挑战,DeepSpeed4Science将为此提供新的加速策略。
+
+#### 微软天气,微软WebXT/Bing
+
+
+
+
+*图5:微软降水预报(每4分钟一次对接下来4小时进行预测)。*
+
+
+[微软天气](https://www.msn.com/en-us/weather/forecast/)提供精确的天气信息,[帮助用户为他们的生活方式、健康、工作和活动做出更好的决策](https://blogs.windows.com/windowsexperience/2022/08/31/microsoft-joins-noaas-weather-ready-nation-ambassador-initiative-to-help-improve-americas-readiness-and-response-to-weather-events/)——包括每小时多次更新的准确的10天全球天气预报。此前,微软天气受益于DeepSpeed技术,加速了他们的多GPU训练环境。目前,DeepSpeed4Science正在与微软WebXT天气预报团队合作,进一步增强微软天气预报服务的最新功能和改进。
+
+### 外部合作者
+
+DeepSpeed4Science的旅程始于两个开创性的基于LLM的结构生物学研究AI模型:来自哥伦比亚大学的[OpenFold](https://openfold.io/),一个开源的高保真蛋白质结构预测模型;以及来自[阿贡国家实验室](https://www.anl.gov/)的[GenSLMs](https://github.com/ramanathanlab/genslm),一个获得[ACM戈登贝尔奖](https://www.acm.org/media-center/2022/november/gordon-bell-special-prize-covid-research-2022)的用于学习SARS-CoV-2(COVID-19)基因组的进化的语言模型。作为此次发布的特色展示,它们代表了当今AI驱动的结构生物学研究面临的两个常见AI系统挑战。我们将在下一节中讨论DeepSpeed4Science如何赋能这些科学研究。
+
+此外,DeepSpeed4Science最近扩大了其范围,以支持更多样的科学模型。例如,在我们与阿贡国家实验室合作训练[Aurora Exascale系统](https://www.anl.gov/aurora)上的万亿参数科学模型的工作中,DeepSpeed4Science技术将帮助他们达到这一关键任务所需的性能要求和可扩展性。此外,通过与[橡树岭国家实验室](https://ai-roadmap.ornl.gov/)和[国家癌症研究所(NCI)](https://www.cancer.gov/)合作进行癌症监测,DeepSpeed4Science将帮助从非结构化的临床文本中高保真地提取和分类信息,以供[MOSSAIC项目](https://www.olcf.ornl.gov/tag/mossaic/)使用。[Brookhaven国家实验室](https://www.bnl.gov/world/)还将采用DeepSpeed4Science技术,支持使用LLMs开发大型数字双胞胎模型,以便为清洁能源研究产生更真实的模拟数据。您可以在[deepspeed4science.ai](https://deepspeed4science.ai/)上找到有关我们外部合作者及其科学任务的更多详细信息。
+
+## 合作展示
+
+### 展示(I):DeepSpeed4Science通过DS4Sci_EvoformerAttention消除以Evoformer为中心的结构生物学模型的内存爆炸问题
+
+
+
+
+
+*图6:在训练过程中OpenFold对PDB链7B3A_A的预测。*
+
+
+[OpenFold](https://github.com/aqlaboratory/openfold)是DeepMind的[AlphaFold2](https://alphafold.com/)的开源社区再现,使其可以在新数据集上训练或微调AlphaFold2。研究人员已经使用它从头开始重新训练AlphaFold2,生成新的模型参数集,研究AlphaFold2的早期训练阶段(图6),并开发新的蛋白质折叠系统。
+
+
+
+
+*图7:在OpenFold中,对多序列比对(MSA)Attention内核(包含偏差)变体的训练峰值内存需求。 (左) 使用在AlphaFold2中的EvoformerAttention的原始OpenFold实现。对于这些类型的蛋白质结构预测模型,在训练/推理中的内存爆炸问题是常见的。最先进的FlashAttention无法有效支持这样的Attention变体。 (右) DeepSpeed4Science的一种新解决方案DS4Sci_EvoformerAttention在不影响模型品质的条件下显著地减少了OpenFold的训练峰值内存需求(最多13倍)。*
+
+
+尽管OpenFold有使用最先进的系统技术进行性能和内存优化,但从头开始训练AlphaFold2仍然在计算上很昂贵。目前阶段的模型参数很小,只有9300万个参数,但它包含了几个需要非常大的中间内存的特殊Attention变体。在标准AlphaFold2训练的“微调”阶段,只是这些变体中的其中一个在半精度下就生成了超过12GB的张量,使其峰值内存要求远远超过了相同大小的语言模型。即使使用像activation checkpointing和DeepSpeed ZeRO优化这样的技术,这种内存爆炸问题仍然严重限制了可训练模型的序列长度和MSA深度。此外,近似策略可能会显著影响模型的准确性和收敛性,同时仍然导致内存爆炸,如图7左侧(橙色)所示。
+
+为了应对结构生物学研究(例如,蛋白质结构预测和平衡分布预测)中的这一常见系统挑战,DeepSpeed4Science通过为这类科学模型中广泛出现的注意力变体(即EvoformerAttention)设计定制的精确注意力内核来解决这一内存效率问题。具体来说,我们设计了一套由复杂的融合/矩阵分块策略和动态内存减少方法而组成的高内存效率DS4Sci_EvoformerAttention内核,作为高质量机器学习模块供更广泛的生物学研究社区使用。通过整合到OpenFold中,这些定制内核在训练期间提供了显著的加速,并显著减少了模型的训练和推理的峰值内存需求。这使得OpenFold可以用更大、更复杂的模型,使用更长的序列在更广泛的硬件上进行实验。关于这项技术的详细信息可以在[这里](https://deepspeed4science.ai/2023/09/18/model-showcase-openfold/)找到。
+
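+作为参考,下面给出一个调用 DS4Sci_EvoformerAttention 内核的示意代码(仅为草图:导入路径与张量形状基于公开教程的接口形式,属于假设,具体请以 DeepSpeed 官方文档为准):
+
+```python
+# 示意:调用 DS4Sci_EvoformerAttention(导入路径与形状均为假设,仅供参考)
+import torch
+from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention  # 假设的导入路径
+
+batch, n_seq, n_res, heads, dim = 1, 128, 256, 4, 32
+dtype, device = torch.float16, "cuda"
+
+Q = torch.randn(batch, n_seq, n_res, heads, dim, dtype=dtype, device=device, requires_grad=True)
+K = torch.randn(batch, n_seq, n_res, heads, dim, dtype=dtype, device=device, requires_grad=True)
+V = torch.randn(batch, n_seq, n_res, heads, dim, dtype=dtype, device=device, requires_grad=True)
+# 两个可选的偏置:MSA 行掩码与来自 pair 表示的偏置(形状为假设)
+res_mask = torch.randn(batch, n_seq, 1, 1, n_res, dtype=dtype, device=device)
+pair_bias = torch.randn(batch, 1, heads, n_res, n_res, dtype=dtype, device=device, requires_grad=True)
+
+out = DS4Sci_EvoformerAttention(Q, K, V, [res_mask, pair_bias])
+out.sum().backward()
+```
+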
+### 展示(II):DeepSpeed4Science通过系统和算法方法为基因组基础模型(例如,GenSLMs)提供长序列支持
+
+
+
+
+*图8:GenSLMs:获2022年ACM 戈登贝尔奖的COVID基因组模型(基于GPT-NeoX的25B/33B模型)。它用于学习描述SARS-CoV-2基因组生物学意义的潜在空间。这个GIF展示了一个重要的蛋白质家族苹果酸脱氢酶(malate dehydrogenase)的根据重要特征(如序列长度和GC含量(核酸鸟嘌呤和胞嘧啶的含量与腺嘌呤和胸腺嘧啶的比率。它测量DNA链抵抗热的能力))着色的潜在空间的投影。*
+
+
+[GenSLMs](https://github.com/ramanathanlab/genslm),一个来自阿贡国家实验室的[2022年ACM 戈登贝尔奖获奖](https://www.acm.org/media-center/2022/november/gordon-bell-special-prize-covid-research-2022)的基因组模型,可以通过大型语言模型(LLMs)的基因组数据训练来学习SARS-CoV-2(COVID-19)基因组的进化。它旨在改变如何识别和分类引发大流行的病毒(特别是SARS-CoV-2)的新变种。GenSLMs代表了第一批可以泛化到其他预测任务的基因组基础模型。对潜在空间的良好理解可以帮助GenSLMs处理超出仅仅是病毒序列的新领域,并扩展它们模拟细菌病原体甚至真核生物的能力(例如,理解功能、途径成员资格和进化关系等事物)。为了实现这一科学目标,GenSLMs和类似的模型需要非常长的序列支持用于训练和推理,这超出了像[FlashAttention](https://arxiv.org/abs/2307.08691)这样的通用LLM的长序列策略。通过DeepSpeed4Science的新设计,科学家现在可以构建和训练具有显著更长的上下文窗口的模型,允许他们探索以前无法访问的关系。
+
+
+
+
+*图9:由不同框架在不同规模下支持的两个GenSLMs模型的最大序列长度。使用NVIDIA DGX,每个节点有八个40G A100 GPU。*
+
+
+具体在系统层面,我们发布了包括[长序列支持和其他新优化](https://github.com/microsoft/Megatron-DeepSpeed/tree/main/examples_deepspeed/deepspeed4science/megatron_long_seq_support)的最新的[Megatron-DeepSpeed框架](https://github.com/microsoft/Megatron-DeepSpeed)。科学家现在可以通过我们新添加的内存优化技术(如注意力掩码异步处理和位置码分割)、张量并行、流水线并行、序列并行、基于ZeRO的数据并行和模型状态异步处理等技术的协同组合,用更长的序列训练他们的GenSLMs等大型科学模型。图9展示了我们的新版本使GenSLMs的25B和33B模型的最长序列长度分别比之前的Megatron-DeepSpeed版本增加了12倍和14倍。在支持的序列长度方面,这个新Megatron-DeepSpeed框架也显著地超过了NVIDIA的Megatron-LM(对于25B和33B模型分别高达9.8倍和9.1倍)。例如,阿贡实验室团队的GenSLMs 25B模型在64个GPU上的原始序列长度为42K,而现在可以用512K的核苷酸序列进行训练。这在不损失准确性的条件下大大提高了模型质量和科学发现的范围。对于那些更喜欢相对位置编码技术这样的算法策略的领域科学家,这个[新版本](https://deepspeed4science.ai/2023/09/18/model-showcase-genslms/)也进行了集成。
+
+## 总结和路线图
+
+我们非常自豪和兴奋地宣布DeepSpeed4Science计划以及几个研发亮点和成果。从今天开始,我们将在[deepspeed4science.ai](https://deepspeed4science.ai/)上介绍我们的新计划,包括关于我们的外部合作者的信息,以及当前和未来的DeepSpeed4Science技术发布。我们的一个高层次目标是推广广泛解决大规模科学发现的主要系统痛点的AI系统技术。我们希望全球的科学家们能够从DeepSpeed4Science通过开源软件解锁的新功能中受益。我们期待更好地了解阻碍您的科学发现的AI系统设计挑战。我们真诚地欢迎您的参与,帮助构建一个更有前途的AI4Science未来。请给我们发送电子邮件至。我们鼓励您在我们的[DeepSpeed GitHub](https://github.com/microsoft/DeepSpeed/)上报告问题、贡献PR、参与讨论。
+
+## 致谢
+
+**Core DeepSpeed4Science Team:**
+
+Shuaiwen Leon Song (DeepSpeed4Science lead), Minjia Zhang, Conglong Li, Shiyang Chen, Chengming Zhang, Xiaoxia (Shirley) Wu, Masahiro Tanaka, Martin Cai, Adam Graham, Charlie Zhou, Yuxiong He (DeepSpeed team lead)
+
+**Our Founding Collaborators (in alphabetical order):**
+
+**Argonne National Lab team:** Rick Stevens, Cristina Negri, Rao Kotamarthi, Venkatram Vishwanath, Arvind Ramanathan, Sam Foreman, Kyle Hippe, Troy Arcomano, Romit Maulik, Maxim Zvyagin, Alexander Brace, Yuntian Deng, Bin Zhang, Cindy Orozco Bohorquez, Austin Clyde, Bharat Kale, Danilo Perez-Rivera, Heng Ma, Carla M. Mann, Michael Irvin, J. Gregory Pauloski, Logan Ward, Valerie Hayot, Murali Emani, Zhen Xie, Diangen Lin, Maulik Shukla, Weili Nie, Josh Romero, Christian Dallago, Arash Vahdat, Chaowei Xiao, Thomas Gibbs, Ian Foster, James J. Davis, Michael E. Papka, Thomas Brettin, Anima Anandkumar
+
+**AMD:** Ivo Bolsen, Micheal Schulte, Bo Begole, Angela Dalton, Steve Reinhart, Ashwin Aji, Jalal Mahmud, Mahesh Balashibramanian
+
+**Brookhaven National Lab team:** Adolfy Hoisie, Shinjae Yoo, Yihui Ren.
+
+**Columbia University OpenFold team:** Mohammed AlQuraishi, Gustaf Ahdritz
+
+**Microsoft Research AI4Science team:** Christopher Bishop, Bonnie Kruft, Max Welling, Tie-Yan Liu, Christian Bodnar, Johannes Brandsetter, Wessel Bruinsma, Chan Cao, Yuan-Jyue Chen, Peggy Dai, Patrick Garvan, Liang He, Elizabeth Heider, PiPi Hu, Peiran Jin, Fusong Ju, Yatao Li, Chang Liu, Renqian Luo, Qi Meng, Frank Noe, Tao Qin, Janwei Zhu, Bin Shao, Yu Shi, Wenlei Shi, Gregor Simm, Megan Stanley, Lixin Sun, Yue Wang, Tong Wang, Zun Wang, Lijun Wu, Yingce Xia, Leo Xia, Shufang Xie, Shuxin Zheng, Jianwei Zhu
+
+**Oakridge National Lab team:** Prassana Balaprakash, Georgia Tourass
+
+**Princeton University:** William Tang, Kyle Felker, Alexey Svyatkovskiy (Microsoft liaison)
+
+**Rutgers University:** Hang Liu
+
+**WebXT Weather team:** Pete Luferenko, Divya Kumar, Jonathan Weyn, Ruixiong Zhang, Sylwester Klocek, Volodymyr Vragov
diff --git a/blogs/deepspeed4science/japanese/README.md b/blogs/deepspeed4science/japanese/README.md
new file mode 100644
index 000000000000..80fc137e16bb
--- /dev/null
+++ b/blogs/deepspeed4science/japanese/README.md
@@ -0,0 +1,145 @@
+
+
+# DeepSpeed4Scienceイニシアティブ: 洗練されたAIシステムのテクノロジーにより大規模な科学的発見を可能に
+
+
+
+*こちらは英語ブログ[Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies](https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/)の公式の翻訳です*
+
+
+
+
+*図1:DeepSpeed4Scienceのアプローチ: 汎用の言語モデルのサポートを超え、科学的発見とその複雑さの解決に特化したAI技術を開発*
+
+
+## はじめに
+
+自然の出来事をモデル化し予測する深層学習の能力は急速に高まっており、次の10年間に、自然科学に革命を起こすかも知れません。薬の開発から再生可能エネルギーまでの各セクターで、大きな進展をもたらす新しい科学的探求の時代が到来するでしょう。「地球上のすべての人と組織がもっと多くのことを成し遂げられるようにする」というMicrosoftのミッションに従い、この機会に、[DeepSpeedチーム](https://www.deepspeed.ai/)では[DeepSpeed4Science](https://deepspeed4science.ai/)という新しいイニシアティブを立ち上げました。これは、AIシステム技術のイノベーションを通じて他に類を見ない技術を構築し、様々な分野の専門家が、科学分野における大きな謎を解き明かす手助けをすることを目指しています。
+
+[DeepSpeed](https://www.deepspeed.ai/)システムは、Microsoftが開発した、AI分野をリードするオープンソースのAIシステムのフレームワークであり、多様なAIハードウェア上での深層学習の訓練と推論において、前例のない規模と速度を実現します。図1は、この新しいDeepSpeed4Scienceイニシアティブでの基本的なアプローチを示しています。DeepSpeedの現在の柱となる技術(訓練、推論、圧縮)を基盤として活用しつつ、DeepSpeed4Scienceでは、大規模言語モデル(LLM)を加速するための汎用の技術的アプローチを超え、科学的発見を加速する目的で新たに構築された、一連のAIシステム技術を提供します。私たちは、重要な科学的ミッションを推進している、代表的な科学分野向けAIモデルを所有する内外のチームと連携し、ドメイン固有のAIシステムの課題を特定し、解決していきます。これには、気候科学、薬物設計、生物学的理解、分子動力学シミュレーション、がんの診断と監視、触媒/材料の発見、およびその他の分野が含まれます。
+
+私たちの長期的なビジョンは、DeepSpeed4Scienceを、科学的発見をサポートする先進的なAIシステム技術を共有するための新しいソフトウェアプラットフォームおよび統一的なリポジトリに発展させることです。DeepSpeed4Scienceは、Microsoftの[AI for Good](https://www.microsoft.com/en-us/ai/ai-for-good)のコミットメントを反映して、包括的に設計されています。このことは、AI4Scienceへのもっとも重要な投資の成果として構築された、様々な代表的モデルへの、DeepSpeed4Scienceイニシアティブによるサポートに現れています。このブログでは、DeepSpeed4Scienceが、構造生物学の研究における2つの重要なシステムの課題にどのように対処するかを紹介します:(1) Evoformer中心のタンパク質構造予測モデルをスケールアップする際に極めて大きなメモリが必要となる問題を解決し、(2) パンデミックを引き起こすウイルスの進化の様子をよりよく理解するための非常に長いシーケンスのサポートを可能にします。
+
+## 主要な初期コラボレータ
+
+DeepSpeed4Scienceによる新しいシステム技術はAI駆動の幅広い科学研究を強化するものです。現在、DeepSpeed4Scienceは、[Microsoft Research AI4Science](https://www.microsoft.com/en-us/research/lab/microsoft-research-ai4science/)、[Microsoft WebXT/Bing](https://www.msn.com/en-us/weather/forecast/)、[U.S. DoE National Labs](https://www.energy.gov/national-laboratories)、および複数の大学のいくつかの重要な科学モデルをサポートしています。
+
+### Microsoft内のパートナーシップ
+
+#### 科学基盤モデル (Scientific Foundation Model, SFM), Microsoft Research AI4Science
+
+
+
+
+
+*図2: 科学基盤モデル (Scientific foundation model, SFM) とその探索: Distributional Graphormer*
+
+
+科学的基盤モデル(SFM)は、多様なインプット、複数の科学領域(薬物、材料、生物学、健康など)、および計算タスクをサポートする、自然科学的発見を強化するための統一された大規模基盤モデルを作成することを目的としています。DeepSpeed4Scienceパートナーシップは、[Distributional Graphormer](https://www.microsoft.com/en-us/research/blog/distributional-graphormer-toward-equilibrium-distribution-prediction-for-molecular-systems/)などのMicrosoftの新しい生成AI手法などのプロジェクトに関する、SFMチームの継続的な研究を強化するための新しい訓練および推論テクノロジーを提供します。
+
+#### ClimaX, Microsoft Research AI4Science
+
+
+
+
+*図3: 天気・気候の多様なモデリングタスクのための最初の基盤モデルClimaX*
+
+
+気候の変化は、より頻繁な異常気象を引き起こしています。悪影響を軽減するため、これらのイベントが発生する場所を予測することがますます重要になっています。[ClimaX](https://www.microsoft.com/en-us/research/group/autonomous-systems-group-robotics/articles/introducing-climax-the-first-foundation-model-for-weather-and-climate/)は、さまざまな気象および気候モデリングタスクを実行するために設計された最初の基盤モデルです。さまざまな変数と解像度を持つ多くの異なるデータセットを扱えるため、天気予報の精度が向上する可能性があります。DeepSpeed4Scienceは、非常に大きな高解像度画像データ(数十から数百ペタバイトなど)を長いシーケンスで処理しながら、より大きな基盤モデルを効率的に事前訓練/ファインチューニングするためのClimaXの新しいシステムサポートを提供しています。
+
+#### AIを用いたAb Initio分子動力学法(AI Powered Ab Initio Molecular Dynamics,AI2MD),Microsoft Research AI4Science
+
+
+
+
+*図4: 100万ステップの分子動力学シミュレーション: RBD-proteinとprotein inhibitorの相互作用*
+
+
+このプロジェクトは、古典的な分子動力学の効率とスケーラビリティを維持しながら、[AIを利用した力場モデル](https://www.microsoft.com/en-us/research/publication/ai2bmd-efficient-characterization-of-protein-dynamics-with-ab-initio-accuracy/)を使用して、原理に基づく精度(ab initio accuracy)に近い精度で大規模(原子数で100万規模)な分子システムの力学をシミュレートします。このシミュレーションは、化学的に重要なイベントを観察するのに十分な長さの軌道を生成できる効率を実現しています。通常、このプロセスには数百万から数十億の推論ステップが必要です。これは、グラフニューラルネットワーク(GNN)+ LLMモデルの推論速度を最適化する上で大きな課題となります。DeepSpeed4Scienceは、この課題に対して、新しいシステムサポートを提供します。
+
+#### 天気 from Microsoft Start, Microsoft WebXT/Bing
+
+
+
+
+*図5: Microsoft Startにおける降水予想 (次の4時間について4分ごと)*
+
+
+[天気 from Microsoft Start](https://www.msn.com/en-us/weather/forecast/)は、[ユーザーがライフスタイル、健康、仕事、活動についてより適切な決定を下せるよう](https://blogs.windows.com/windowsexperience/2022/08/31/microsoft-joins-noaas-weather-ready-nation-ambassador-initiative-to-help-improve-americas-readiness-and-response-to-weather-events/)、正確な気象情報を提供します。 (1 時間ごとに複数回更新される、10 日間に渡る正確かつグローバルな天気予報など)。 以前にも、この天気予報は、DeepSpeedの技術を使用して、マルチ GPU を用いた訓練を高速化していました。現在、DeepSpeed4ScienceはMicrosoft WebXT気象チームと協力して、最先端の機能と更なる改善により、マイクロソフトの気象サービスをさらに強化しています。
+
+### 外部のコラボレータ
+
+DeepSpeed4Scienceは、構造生物学研究のための2つの先駆的なLLMベースのAIモデルを扱うことから始まりました: オープンソースのハイフィデリティタンパク質構造予測モデルであるコロンビア大学の[OpenFold](https://openfold.io/)と、SARS-CoV-2(COVID-19)ゲノムの進化を学習する、[Gordon Bell Special Prize](https://www.acm.org/media-center/2022/november/gordon-bell-special-prize-covid-research-2022)を受賞したゲノム用言語モデルである[アルゴンヌ国立研究所](https://www.anl.gov/)の[GenSLMs](https://github.com/ramanathanlab/genslm)です。次のセクションでは、今日のAI主導の構造生物学研究が直面している2つの一般的なAIシステムの課題を紹介し、DeepSpeed4Scienceが科学研究をどのように強化したかについて説明します。
+
+またDeepSpeed4Scienceは最近、より多様な科学モデルをサポートするために、その対象を拡大しました。たとえば、[Aurora Exascaleシステム](https://www.anl.gov/aurora)で、1兆パラメータの科学モデルを訓練するアルゴンヌ国立研究所との協力にあたって、DeepSpeed4Scienceテクノロジーは、求められるパフォーマンス要件とスケーラビリティを実現するのに重要な役割を果たします。さらに、DeepSpeed4Scienceは、がんの調査に関して、[オークリッジ国立研究所](https://ai-roadmap.ornl.gov/)および[国立がん研究所(NCI)](https://www.cancer.gov/)と協力することにより、[MOSSAICプロジェクト](https://www.olcf.ornl.gov/tag/mossaic/)の非構造化臨床テキストからの情報の高信頼度抽出と分類にも用いられます。さらに、DeepSpeed4Scienceのテクノロジーは、[ブルックヘブン国立研究所](https://www.bnl.gov/world/)にも採用され、LLMを使用してより現実的なシミュレーションデータを生成することにより、クリーンエネルギー研究用の大規模なデジタルツインモデルの開発をサポートします。外部のコラボレータとその科学ミッションに関するより詳細な情報は、[deepspeed4science.ai](https://deepspeed4science.ai/)に掲載しています。
+
+## パートナーシップの事例
+
+### 事例(I): DeepSpeed4ScienceのDS4Sci_EvoformerAttentionにより、Evoformerで構成された生物学モデルをスケールアップする際のメモリ問題を解決
+
+
+
+
+
+*図6: モデル学習の進行に伴うPDB chain 7B3A_AについてのOpenFoldの予測*
+
+
+[OpenFold](https://github.com/aqlaboratory/openfold)は、DeepMindによる[AlphaFold2](https://alphafold.com/)をオープンソースで再現したものであり、新しいデータセットでAlphaFold2を訓練またはファインチューニングすることを可能にします。研究者は、これを使用して、AlphaFold2をゼロから再訓練して新しいモデルパラメータを作成し、AlphaFold2の初期訓練フェーズを研究し(図6)、新しいタンパク質フォールディングシステムを開発しました。
+
+
+
+
+*図7: OpenFoldで可能な最大の訓練サンプル次元を持つ多重配列アライメント(MSA)アテンションカーネル(バイアス付き)のバリエーションを訓練するために必要なピークメモリ。(左)AlphaFold2で使用されているEvoformerAttentionを用いたオリジナルのOpenFold実装。この種のタンパク質構造予測モデルの訓練/推論では、極めて多くのメモリが必要とされることは一般的な課題となっている。特に、最新技術として広く知られるFlashAttentionでも、このような科学研究のためのアテンションのバリエーションを効果的にサポートできない。(右)DS4Sci_EvoformerAttentionと呼ばれるDeepSpeed4Scienceの新しい技術は、精度を落とすことなく、OpenFoldモデルの訓練に必要なピークメモリを1/13に大幅に削減する。*
+
+
+OpenFoldには、最先端のシステムテクノロジーを使用したパフォーマンスとメモリの最適化が含まれていますが、AlphaFold2をゼロから訓練することは依然として大きな計算コストがかかります。現段階でのモデルは、パラメータ数の絶対値は小さい(9,300万個)のですが、極めて大きなアクティベーションを持つアテンションのバリエーションが含まれています。標準的なAlphaFold2訓練のファインチューニングフェーズでは、これらのバリエーションのうちのの1つが生成したロジットテンソル(入力としてモデルに供給されるディープタンパク質MSAに対応するように設計されたもの)は、半精度浮動小数で12GBを超え、同等のサイズの言語モデルが使用するメモリを大幅に上回ります。Activation checkpointingや、DeepSpeed ZeRO 最適化などの手法を使用しても、非常に多くのメモリが必要とされるため、モデルを訓練できるシーケンスの長さと MSA の深さが大幅に制限されます。さらに、近似解を与えるような戦略を用いると、モデルの精度と収束に大きな影響を与える可能性があり、それでもメモリが爆発的に増加します(図7の左側のバー(オレンジ色))。
+
+DeepSpeed4Scienceは、構造生物学研究(タンパク質構造予測や平衡分布予測など)におけるこの一般的なシステムの課題に対処するために、このカテゴリの科学モデルに広く見られるアテンションのバリエーション(つまりEvoformerAttention)用にカスタマイズされた正確なアテンションのカーネルを設計することにより、このメモリの非効率性の問題に対処しています。具体的には、高度なフュージョン/タイリング戦略とオンザフライのメモリ削減方法によって可能になるメモリ効率の高いDS4Sci_EvoformerAttentionカーネルのセットを、高品質の機械学習プリミティブとして、より広いコミュニティ向けに作成しました。これらをOpenFoldに組み込むことで、訓練中の速度が大幅に向上し、訓練と推論のためのモデルのピークメモリが大幅に削減されます。これにより、OpenFoldはより大きく、より複雑なモデル、より長いシーケンスで実験し、より幅広いハードウェアで訓練することができます。この技術の詳細については、[こちら](https://deepspeed4science.ai/2023/09/18/model-showcase-openfold/)をご覧ください。
+
+### 事例(II): DeepSpeed4Scienceのシステムとアルゴリズムの両方からのアプローチにより、ゲノム基盤モデルでの非常に長い系列の使用をサポート
+
+
+
+
+*図8: GenSLMs:2022年ACM Gordon Bell Special Prize受賞COVIDゲノム用モデル(GPT-NeoXに基づく25B/33Bモデル)。SARS-CoV-2ゲノムの生物学的に意味のある特性を記述する潜在空間を学習するために使用される。このGIFは、重要なタンパク質ファミリーであるリンゴ酸デヒドロゲナーゼ(malate dehydrogenase)を可視化し、配列の長さやGC含量(アデニンとチミンと比較した核酸グアニンとシトシンの含量の比率。これはDNA鎖が熱に耐える能力を測るものである。)などの重要な特徴で色付けされた潜在空間の投影を表示している。*
+
+
+アルゴンヌ国立研究所が開発し、[2022年ACM Gordon Bell Special Prize](https://www.acm.org/media-center/2022/november/gordon-bell-special-prize-covid-research-2022)を受賞したゲノム用言語モデルである[GenSLMs](https://github.com/ramanathanlab/genslm)は、ゲノムデータに大規模言語モデル(LLM)を適用することにより、SARS-CoV-2(COVID-19)ゲノムの進化を学習します。これは、パンデミックを引き起こすウイルス、特にSARS-CoV-2の新たに出現する亜種を特定し、分類する方法を変えるように設計されています。GenSLMsは、他の予測タスクに一般化できる最初のゲノム基盤モデルの1つです。潜在空間をうまく表現することにより、GenSLMsはウイルス配列だけでなく新しいドメインに適用し、細菌性病原体や真核生物をモデル化する能力を拡大し、機能、経路のメンバーシップ、進化的関係などを理解することができます。この科学的目標を達成するために、GenSLMsおよび同様のモデルは、[FlashAttention](https://arxiv.org/abs/2307.08691)のように、長いシーケンスのための一般的な戦略では扱うことが困難なレベルの、非常に長いシーケンスサポートを、訓練と推論の両方に対して必要とします。DeepSpeed4Scienceの新しい設計により、科学者はより長いシーケンスでモデルを構築および訓練できるようになり、以前は扱えなかった科学探索が可能になりました。
+
+
+
+
+*図9: 異なるスケールで異なるフレームワークがサポートする2つのGenSLMsモデルの最大シーケンス長。1ノードあたり8個の40G A100 GPUを搭載したNVIDIA DGXノードを使用。*
+
+
+システムレベルでは、非常に長いシーケンスをサポートするための最新の[Megatron-DeepSpeedフレームワーク](https://github.com/microsoft/Megatron-DeepSpeed)を、[他の新しい最適化とともにリリースします](https://github.com/microsoft/Megatron-DeepSpeed/tree/main/examples_deepspeed/deepspeed4science/megatron_long_seq_support)。科学者は、(アテンションマスクと位置の埋め込みに関する)新しく追加されたメモリ最適化手法、テンソル並列処理、パイプライン並列処理、シーケンス並列処理、ZeROスタイルのデータ並列処理、モデル状態のオフロードなどの技術を相乗的な組み合わせにより、GenSLMsのような大規模な科学モデルをはるかに長いシーケンスで訓練できるようになりました。図9は、新しいリリースにより、GenSLMsの25Bおよび33Bモデルで、以前のMegatron-DeepSpeedよりもそれぞれ最大12倍および14倍の最長シーケンス長を処理できることを示しています。サポートされているシーケンス長に関しては、この新しいMegatron-DeepSpeedは、25Bモデルと33Bモデルでそれぞれ最大9.8倍と9.1倍でNVIDIAのMegatron-LMを大幅に上回っています。たとえば、GenSLMsの25Bモデルは、64個のGPUでのアルゴンヌチームの元の42Kシーケンス長と比較して、512Kのヌクレオチド配列で訓練できるようになりました。これにより、精度を損なうことなく、モデルの品質と科学的発見の範囲が大幅に向上します。Relative position embeddingなどのアルゴリズム戦略を必要とする科学者向けの追加サポートも、[このリリース](https://deepspeed4science.ai/2023/09/18/model-showcase-genslms/)に統合されています。
+
+## まとめとロードマップ
+
+DeepSpeed4Scienceイニシアティブを、いくつかのR&Dのハイライトや成果と共に発表できることを嬉しく思います。本日から、外部の協力者に関する情報や、現在および将来のDeepSpeed4Scienceテクノロジーリリースなど、新しいイニシアティブでの活動を[deepspeed4science.ai](https://deepspeed4science.ai/)上で進めていきます。私たちの高レベルな目標の1つは、大規模な科学的発見のための主要なシステムの問題点に広く対処するAIシステムテクノロジーを一般化することです。世界中の科学者によって、オープンソースのソフトウェアを通じてDeepSpeed4Scienceによって利用可能になる新機能が活用されることを願っています。科学的発見の障害となるAIシステム設計の課題を解決していくことを楽しみにしています。AI4Scienceの有望な未来を築くために、皆様の参加を歓迎します。お問い合わせはまでお願いします。問題の報告や、PRを通じての貢献、ディスカッションへの参加は、[DeepSpeed GitHub](https://github.com/microsoft/DeepSpeed/)でお願いします。
+
+## 謝辞
+
+**Core DeepSpeed4Science Team:**
+
+Shuaiwen Leon Song (DeepSpeed4Science lead), Minjia Zhang, Conglong Li, Shiyang Chen, Chengming Zhang, Xiaoxia (Shirley) Wu, Masahiro Tanaka, Martin Cai, Adam Graham, Charlie Zhou, Yuxiong He (DeepSpeed team lead)
+
+**Our Founding Collaborators (in alphabetical order):**
+
+**Argonne National Lab team:** Rick Stevens, Cristina Negri, Rao Kotamarthi, Venkatram Vishwanath, Arvind Ramanathan, Sam Foreman, Kyle Hippe, Troy Arcomano, Romit Maulik, Maxim Zvyagin, Alexander Brace, Yuntian Deng, Bin Zhang, Cindy Orozco Bohorquez, Austin Clyde, Bharat Kale, Danilo Perez-Rivera, Heng Ma, Carla M. Mann, Michael Irvin, J. Gregory Pauloski, Logan Ward, Valerie Hayot, Murali Emani, Zhen Xie, Diangen Lin, Maulik Shukla, Weili Nie, Josh Romero, Christian Dallago, Arash Vahdat, Chaowei Xiao, Thomas Gibbs, Ian Foster, James J. Davis, Michael E. Papka, Thomas Brettin, Anima Anandkumar
+
+**AMD:** Ivo Bolsen, Micheal Schulte, Bo Begole, Angela Dalton, Steve Reinhart, Ashwin Aji, Jalal Mahmud, Mahesh Balashibramanian
+
+**Brookhaven National Lab team:** Adolfy Hoisie, Shinjae Yoo, Yihui Ren.
+
+**Columbia University OpenFold team:** Mohammed AlQuraishi, Gustaf Ahdritz
+
+**Microsoft Research AI4Science team:** Christopher Bishop, Bonnie Kruft, Max Welling, Tie-Yan Liu, Christian Bodnar, Johannes Brandsetter, Wessel Bruinsma, Chan Cao, Yuan-Jyue Chen, Peggy Dai, Patrick Garvan, Liang He, Elizabeth Heider, PiPi Hu, Peiran Jin, Fusong Ju, Yatao Li, Chang Liu, Renqian Luo, Qi Meng, Frank Noe, Tao Qin, Janwei Zhu, Bin Shao, Yu Shi, Wenlei Shi, Gregor Simm, Megan Stanley, Lixin Sun, Yue Wang, Tong Wang, Zun Wang, Lijun Wu, Yingce Xia, Leo Xia, Shufang Xie, Shuxin Zheng, Jianwei Zhu
+
+**Oakridge National Lab team:** Prassana Balaprakash, Georgia Tourass
+
+**Princeton University:** William Tang, Kyle Felker, Alexey Svyatkovskiy (Microsoft liaison)
+
+**Rutgers University:** Hang Liu
+
+**WebXT Weather team:** Pete Luferenko, Divya Kumar, Jonathan Weyn, Ruixiong Zhang, Sylwester Klocek, Volodymyr Vragov
diff --git a/blogs/deepspeed4science/media/Figure1.png b/blogs/deepspeed4science/media/Figure1.png
new file mode 100644
index 000000000000..614c4b40d6a1
Binary files /dev/null and b/blogs/deepspeed4science/media/Figure1.png differ
diff --git a/blogs/deepspeed4science/media/Figure2-1.jpg b/blogs/deepspeed4science/media/Figure2-1.jpg
new file mode 100644
index 000000000000..6008ccd91d09
Binary files /dev/null and b/blogs/deepspeed4science/media/Figure2-1.jpg differ
diff --git a/blogs/deepspeed4science/media/Figure2-2.gif b/blogs/deepspeed4science/media/Figure2-2.gif
new file mode 100644
index 000000000000..0890be7d7e31
Binary files /dev/null and b/blogs/deepspeed4science/media/Figure2-2.gif differ
diff --git a/blogs/deepspeed4science/media/Figure3.png b/blogs/deepspeed4science/media/Figure3.png
new file mode 100644
index 000000000000..465e80e15a25
Binary files /dev/null and b/blogs/deepspeed4science/media/Figure3.png differ
diff --git a/blogs/deepspeed4science/media/Figure4.gif b/blogs/deepspeed4science/media/Figure4.gif
new file mode 100644
index 000000000000..b45a5f28fd36
Binary files /dev/null and b/blogs/deepspeed4science/media/Figure4.gif differ
diff --git a/blogs/deepspeed4science/media/Figure5.gif b/blogs/deepspeed4science/media/Figure5.gif
new file mode 100644
index 000000000000..a26c20103269
Binary files /dev/null and b/blogs/deepspeed4science/media/Figure5.gif differ
diff --git a/blogs/deepspeed4science/media/Figure6-1.png b/blogs/deepspeed4science/media/Figure6-1.png
new file mode 100644
index 000000000000..65f7f9309f71
Binary files /dev/null and b/blogs/deepspeed4science/media/Figure6-1.png differ
diff --git a/blogs/deepspeed4science/media/Figure6-2.gif b/blogs/deepspeed4science/media/Figure6-2.gif
new file mode 100644
index 000000000000..b50588c227d7
Binary files /dev/null and b/blogs/deepspeed4science/media/Figure6-2.gif differ
diff --git a/blogs/deepspeed4science/media/Figure7.jpg b/blogs/deepspeed4science/media/Figure7.jpg
new file mode 100644
index 000000000000..eaa92007268b
Binary files /dev/null and b/blogs/deepspeed4science/media/Figure7.jpg differ
diff --git a/blogs/deepspeed4science/media/Figure8.gif b/blogs/deepspeed4science/media/Figure8.gif
new file mode 100644
index 000000000000..624384910f2a
Binary files /dev/null and b/blogs/deepspeed4science/media/Figure8.gif differ
diff --git a/blogs/deepspeed4science/media/Figure9.png b/blogs/deepspeed4science/media/Figure9.png
new file mode 100644
index 000000000000..f00fd9b6917f
Binary files /dev/null and b/blogs/deepspeed4science/media/Figure9.png differ
diff --git a/csrc/deepspeed4science/evoformer_attn/attention.cpp b/csrc/deepspeed4science/evoformer_attn/attention.cpp
new file mode 100644
index 000000000000..ac3364539ff1
--- /dev/null
+++ b/csrc/deepspeed4science/evoformer_attn/attention.cpp
@@ -0,0 +1,62 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#include <torch/extension.h>
+
+void attention_impl(torch::Tensor& q,
+ torch::Tensor& k,
+ torch::Tensor& v,
+ torch::Tensor& bias1,
+ torch::Tensor& bias2,
+ torch::Tensor& o,
+ torch::Tensor& lse);
+void attention(torch::Tensor& q,
+ torch::Tensor& k,
+ torch::Tensor& v,
+ torch::Tensor& bias1,
+ torch::Tensor& bias2,
+ torch::Tensor& o,
+ torch::Tensor& lse)
+{
+ attention_impl(q, k, v, bias1, bias2, o, lse);
+}
+
+void attention_back_impl(torch::Tensor& go,
+ torch::Tensor& q,
+ torch::Tensor& k,
+ torch::Tensor& v,
+ torch::Tensor& o,
+ torch::Tensor& lse,
+ torch::Tensor& delta,
+ torch::Tensor& bias1,
+ torch::Tensor& bias2,
+ torch::Tensor& gq,
+ torch::Tensor& gk,
+ torch::Tensor& gv,
+ torch::Tensor& gb1,
+ torch::Tensor& gb2);
+void attention_bwd(torch::Tensor& go,
+ torch::Tensor& q,
+ torch::Tensor& k,
+ torch::Tensor& v,
+ torch::Tensor& o,
+ torch::Tensor& lse,
+ torch::Tensor& delta,
+ torch::Tensor& bias1,
+ torch::Tensor& bias2,
+ torch::Tensor& gq,
+ torch::Tensor& gk,
+ torch::Tensor& gv,
+ torch::Tensor& gb1,
+ torch::Tensor& gb2)
+{
+ attention_back_impl(go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("attention", &attention, "");
+ m.def("attention_bwd", &attention_bwd, "");
+}
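
The binding above exposes `attention` and `attention_bwd`, which fill caller-allocated output (`o`), logsumexp (`lse`), and gradient tensors in place. Below is a minimal sketch of how such bindings are commonly wrapped in a `torch.autograd.Function`; the module handle and the `lse`/`delta` shapes are placeholders, not the actual DeepSpeed Python wrapper.

```python
# Sketch: a thin autograd wrapper over the bound ops. `module` is the loaded
# extension handle; shapes for lse/delta are placeholders for illustration.
import torch

class EvoformerAttnFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, bias1, bias2, module):
        o = torch.empty_like(q)
        # lse is kept in fp32 (the kernel takes a float* logsumexp pointer)
        lse = torch.empty(q.shape[:-1], dtype=torch.float32, device=q.device)  # placeholder shape
        module.attention(q, k, v, bias1, bias2, o, lse)
        ctx.save_for_backward(q, k, v, o, lse, bias1, bias2)
        ctx.module = module
        return o

    @staticmethod
    def backward(ctx, grad_out):
        q, k, v, o, lse, bias1, bias2 = ctx.saved_tensors
        gq, gk, gv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        gb1, gb2 = torch.empty_like(bias1), torch.empty_like(bias2)
        delta = torch.empty_like(lse)  # scratch buffer; the kernel computes delta itself
        ctx.module.attention_bwd(grad_out.contiguous(), q, k, v, o, lse, delta,
                                 bias1, bias2, gq, gk, gv, gb1, gb2)
        return gq, gk, gv, gb1, gb2, None
```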
diff --git a/csrc/deepspeed4science/evoformer_attn/attention.cu b/csrc/deepspeed4science/evoformer_attn/attention.cu
new file mode 100644
index 000000000000..37636c4bf988
--- /dev/null
+++ b/csrc/deepspeed4science/evoformer_attn/attention.cu
@@ -0,0 +1,160 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include "gemm_kernel_utils.h"
+#include "kernel_forward.h"
+#include "transform/bias_broadcast.h"
+
+template
+ class Broadcast1_,
+ template
+ class Broadcast2_>
+typename std::enable_if::value>::type attention_impl_template(
+ torch::Tensor& q,
+ torch::Tensor& k,
+ torch::Tensor& v,
+ torch::Tensor& bias1,
+ torch::Tensor& bias2,
+ torch::Tensor& o,
+ float* lse_ptr)
+{
+ EVOFORMER_CHECK(false, "Unsupported GPU and data type combination")
+}
+
+template
+ class Broadcast1_,
+ template
+ class Broadcast2_>
+typename std::enable_if::value>::type attention_impl_template(
+ torch::Tensor& q,
+ torch::Tensor& k,
+ torch::Tensor& v,
+ torch::Tensor& bias1,
+ torch::Tensor& bias2,
+ torch::Tensor& o,
+ float* lse_ptr)
+{
+ // Attention definition goes here, replaced with BroadcastType1 and
+ // BroadcastType2
+ using Attention = AttentionKernel;
+
+ static_assert(!Attention::kNeedsOutputAccumulatorBuffer,
+ "This test does not support output accumulator buffer");
+ int head_size = q.size(-1);
+ int head_number = q.size(-2);
+ int seq_length = q.size(-3);
+ auto q_view = q.view({-1, seq_length, head_number, head_size});
+ auto k_view = k.view({-1, seq_length, head_number, head_size});
+ auto v_view = v.view({-1, seq_length, head_number, head_size});
+ auto o_view = o.view({-1, seq_length, head_number, head_size});
+ int batch_size = q_view.size(0);
+ auto q_ptr = reinterpret_cast(q.data_ptr());
+ auto k_ptr = reinterpret_cast(k.data_ptr());
+ auto v_ptr = reinterpret_cast(v.data_ptr());
+ auto o_ptr = reinterpret_cast(o.data_ptr());
+
+ auto bias1_ptr = reinterpret_cast(bias1.data_ptr());
+ auto bias2_ptr = reinterpret_cast(bias2.data_ptr());
+
+ typename Attention::Params p;
+ { // set parameters
+ p.query_ptr = q_ptr;
+ p.key_ptr = k_ptr;
+ p.value_ptr = v_ptr;
+ p.logsumexp_ptr = lse_ptr; // Only needed for bw
+ p.output_accum_ptr = nullptr;
+ p.output_ptr = o_ptr;
+ p.scale = 1.0f / sqrt(float(head_size));
+
+ p.bias1_ptr = bias1_ptr;
+ p.bias2_ptr = bias2_ptr;
+ p.B = q.size(0);
+ p.N = q.size(1);
+
+ p.num_heads = head_number;
+ p.num_batches = batch_size;
+ p.head_dim = head_size;
+ p.head_dim_value = head_size;
+ p.num_queries = seq_length;
+ p.num_keys = seq_length;
+
+ // All tensors are in BMHK shapes
+ p.q_strideH = q_view.stride(-2);
+ p.k_strideH = k_view.stride(-2);
+ p.v_strideH = v_view.stride(-2);
+ p.q_strideM = q_view.stride(-3);
+ p.k_strideM = k_view.stride(-3);
+ p.v_strideM = v_view.stride(-3);
+ p.o_strideM = o_view.stride(-3);
+ p.q_strideB = q_view.stride(-4);
+ p.k_strideB = k_view.stride(-4);
+ p.v_strideB = v_view.stride(-4);
+ }
+
+ constexpr auto kernel_fn = attention_kernel_batched_impl<Attention>;
+ int smem_bytes = sizeof(typename Attention::SharedStorage);
+ if (smem_bytes > 0xc000) {
+ cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
+ }
+ if (!Attention::check_supported(p)) { throw std::runtime_error("Parameters not supported"); }
+ kernel_fn<<>>(p);
+}
+
+#define CODE(scalar_t, torch_scalar_t) \
+ do { \
+ if (bias1.size(0) == 0 && bias2.size(0) == 0) { \
+ attention_impl_template(q, k, v, bias1, bias2, o, lse_ptr); \
+ } else if (bias1.size(0) == 0) { \
+ attention_impl_template(q, k, v, bias1, bias2, o, lse_ptr); \
+ } else if (bias2.size(0) == 0) { \
+ attention_impl_template(q, k, v, bias1, bias2, o, lse_ptr); \
+ } else { \
+ attention_impl_template( \
+ q, k, v, bias1, bias2, o, lse_ptr); \
+ } \
+ } while (0)
+
+// Function to select and call the correct template based on biases sizes
+void attention_impl(torch::Tensor& q,
+ torch::Tensor& k,
+ torch::Tensor& v,
+ torch::Tensor& bias1,
+ torch::Tensor& bias2,
+ torch::Tensor& o,
+ torch::Tensor& lse)
+{
+ auto lse_ptr = lse.size(0) == 0 ? nullptr : reinterpret_cast<float*>(lse.data_ptr());
+ cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
+ DISPATCH_ARCHTAG(prop->major * 10 + prop->minor,
+ DISPATCH_TYPES(q, { CODE(scalar_t, torch_scalar_t); }));
+}
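
A note on the calling convention visible in `attention_impl` and the `CODE` macro: optional tensors are signaled by emptiness, i.e. `bias1.size(0) == 0` / `bias2.size(0) == 0` select the no-bias kernel variants, and `lse.size(0) == 0` turns the logsumexp pointer into `nullptr`. A hedged Python-side sketch of that convention follows (names and shapes are placeholders).

```python
# Sketch: the "empty tensor means absent" convention used by attention_impl's dispatch.
import torch

def call_forward(module, q, k, v, bias1=None, bias2=None, need_lse=True):
    empty = torch.empty(0, dtype=q.dtype, device=q.device)
    bias1 = empty if bias1 is None else bias1  # size(0) == 0 -> no-bias1 kernel variant
    bias2 = empty if bias2 is None else bias2  # size(0) == 0 -> no-bias2 kernel variant
    o = torch.empty_like(q)
    if need_lse:
        lse = torch.empty(q.shape[:-1], dtype=torch.float32, device=q.device)  # placeholder shape
    else:
        lse = torch.empty(0, dtype=torch.float32, device=q.device)  # lse_ptr becomes nullptr
    module.attention(q, k, v, bias1, bias2, o, lse)
    return o, lse
```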
diff --git a/csrc/deepspeed4science/evoformer_attn/attention_back.cu b/csrc/deepspeed4science/evoformer_attn/attention_back.cu
new file mode 100644
index 000000000000..a82c4ec68a13
--- /dev/null
+++ b/csrc/deepspeed4science/evoformer_attn/attention_back.cu
@@ -0,0 +1,218 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#include
+#include
+#include
+#include "gemm_kernel_utils.h"
+#include "kernel_backward.h"
+#include "transform/bias_broadcast.h"
+
+constexpr auto kBlockSizeI = 64;
+constexpr auto kBlockSizeJ = 64;
+
+template
+ class Broadcast1_,
+ template
+ class Broadcast2_>
+typename std::enable_if::value>::type attention_back_impl_template(
+ torch::Tensor& go,
+ torch::Tensor& q,
+ torch::Tensor& k,
+ torch::Tensor& v,
+ torch::Tensor& o,
+ torch::Tensor& lse,
+ torch::Tensor& delta,
+ torch::Tensor& bias1,
+ torch::Tensor& bias2,
+ torch::Tensor& gq,
+ torch::Tensor& gk,
+ torch::Tensor& gv,
+ torch::Tensor& gb1,
+ torch::Tensor& gb2)
+{
+ EVOFORMER_CHECK(false, "Unsupported GPU and data type combination")
+}
+
+template
+ class Broadcast1_,
+ template
+ class Broadcast2_>
+typename std::enable_if::value>::type attention_back_impl_template(
+ torch::Tensor& go,
+ torch::Tensor& q,
+ torch::Tensor& k,
+ torch::Tensor& v,
+ torch::Tensor& o,
+ torch::Tensor& lse,
+ torch::Tensor& delta,
+ torch::Tensor& bias1,
+ torch::Tensor& bias2,
+ torch::Tensor& gq,
+ torch::Tensor& gk,
+ torch::Tensor& gv,
+ torch::Tensor& gb1,
+ torch::Tensor& gb2)
+{
+ constexpr bool kPreload_ = arch::kMinComputeCapability >= 80;
+ using Kernel = AttentionBackwardKernel;
+ int head_size = q.size(-1);
+ int head_number = q.size(-2);
+ int seq_length = q.size(-3);
+ auto q_view = q.view({-1, seq_length, head_number, head_size});
+ auto k_view = k.view({-1, seq_length, head_number, head_size});
+ auto v_view = v.view({-1, seq_length, head_number, head_size});
+ auto o_view = o.view({-1, seq_length, head_number, head_size});
+ auto do_view = go.view({-1, seq_length, head_number, head_size});
+ auto dk_view = gk.view({-1, seq_length, head_number, head_size});
+ auto dv_view = gv.view({-1, seq_length, head_number, head_size});
+ auto dq_view = gq.view({-1, seq_length, head_number, head_size});
+ auto q_ptr = reinterpret_cast(q.data_ptr());
+ auto k_ptr = reinterpret_cast(k.data_ptr());
+ auto v_ptr = reinterpret_cast(v.data_ptr());
+ auto o_ptr = reinterpret_cast(o.data_ptr());
+ auto do_ptr = reinterpret_cast(go.data_ptr());
+ auto dk_ptr = reinterpret_cast(gk.data_ptr());
+ auto dv_ptr = reinterpret_cast(gv.data_ptr());
+ auto dq_ptr = reinterpret_cast(gq.data_ptr());
+ auto db1_ptr = gb1.size(0) > 0 ? reinterpret_cast(gb1.data_ptr()) : nullptr;
+ auto db2_ptr = gb2.size(0) > 0 ? reinterpret_cast(gb2.data_ptr()) : nullptr;
+ auto lse_ptr = reinterpret_cast(lse.data_ptr());
+ auto delta_ptr = reinterpret_cast(delta.data_ptr());
+ auto bias1_ptr = reinterpret_cast(bias1.data_ptr());
+ auto bias2_ptr = reinterpret_cast(bias2.data_ptr());
+ static_assert(Kernel::kKernelComputesDelta, "Kernel must compute delta");
+
+ typename Kernel::Params p;
+ p.query_ptr = q_ptr;
+ p.key_ptr = k_ptr;
+ p.value_ptr = v_ptr;
+ p.logsumexp_ptr = lse_ptr;
+ p.output_ptr = o_ptr;
+ p.grad_output_ptr = do_ptr;
+ p.delta_ptr = delta_ptr;
+ p.grad_query_ptr = dq_ptr;
+ p.grad_key_ptr = dk_ptr;
+ p.grad_value_ptr = dv_ptr;
+
+ p.grad_bias1_ptr = db1_ptr;
+ p.grad_bias2_ptr = db2_ptr;
+ p.B = q.size(0);
+ p.N = q.size(1);
+ p.bias1_ptr = bias1.size(0) ? bias1_ptr : nullptr;
+ p.bias2_ptr = bias2.size(0) ? bias2_ptr : nullptr;
+
+ p.scale = 1.0f / sqrtf(head_size);
+
+ p.head_dim = head_size;
+ p.head_dim_value = head_size;
+ p.num_queries = seq_length;
+ p.num_keys = seq_length;
+ p.num_heads = head_number;
+
+ p.q_strideM = q_view.stride(-3);
+ p.k_strideM = k_view.stride(-3);
+ p.v_strideM = v_view.stride(-3);
+ p.gO_strideM = do_view.stride(-3);
+ p.o_strideH = o_view.stride(-2);
+ p.q_strideH = q_view.stride(-2);
+ p.k_strideH = k_view.stride(-2);
+ p.v_strideH = v_view.stride(-2);
+ p.o_strideB = o_view.stride(-4);
+ p.q_strideB = q_view.stride(-4);
+ p.k_strideB = k_view.stride(-4);
+ p.v_strideB = v_view.stride(-4);
+ p.lse_strideB = lse.stride(-3);
+ p.lse_strideH = lse.stride(-2);
+ p.delta_strideB = delta.stride(-3);
+ p.delta_strideH = delta.stride(-2);
+ p.num_batches = q_view.size(-4);
+
+ p.gO_strideB = do_view.stride(-4);
+ p.gQ_strideB = dq_view.stride(-4);
+ p.gK_strideB = dk_view.stride(-4);
+ p.gV_strideB = dv_view.stride(-4);
+ p.gO_strideH = do_view.stride(-2);
+ p.gQ_strideH = dq_view.stride(-2);
+ p.gK_strideH = dk_view.stride(-2);
+ p.gV_strideH = dv_view.stride(-2);
+
+ torch::Tensor workspace = torch::empty(p.workspace_size() / 4, lse.options());
+ p.workspace = workspace.data_ptr();
+
+ auto kernel_fn = attention_kernel_backward_batched_impl<Kernel>;
+ size_t smem_bytes = sizeof(typename Kernel::SharedStorage);
+ cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, int(smem_bytes));
+ if (!Kernel::check_supported(p)) { throw std::runtime_error("Unsupported parameters"); }
+ kernel_fn<<>>(p);
+}
+
+#define CODE(scalar_t, torch_scalar_t) \
+ do { \
+ if (bias1.size(0) == 0 && bias2.size(0) == 0) { \
+ attention_back_impl_template( \
+ go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \
+ } else if (bias1.size(0) > 0 && bias2.size(0) > 0) { \
+ attention_back_impl_template( \
+ go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \
+ } else if (bias1.size(0) > 0) { \
+ attention_back_impl_template( \
+ go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \
+ } else { \
+ attention_back_impl_template( \
+ go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \
+ } \
+ } while (0)
+
+void attention_back_impl(torch::Tensor& go,
+ torch::Tensor& q,
+ torch::Tensor& k,
+ torch::Tensor& v,
+ torch::Tensor& o,
+ torch::Tensor& lse,
+ torch::Tensor& delta,
+ torch::Tensor& bias1,
+ torch::Tensor& bias2,
+ torch::Tensor& gq,
+ torch::Tensor& gk,
+ torch::Tensor& gv,
+ torch::Tensor& gb1,
+ torch::Tensor& gb2)
+{
+ cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
+ DISPATCH_ARCHTAG(prop->major * 10 + prop->minor,
+ DISPATCH_TYPES(q, { CODE(scalar_t, torch_scalar_t); }));
+}
diff --git a/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_grad_bias.h b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_grad_bias.h
new file mode 100644
index 000000000000..17b6479ed8c5
--- /dev/null
+++ b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_grad_bias.h
@@ -0,0 +1,250 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#pragma once
+#include
+#include
+#include "../iterators/predicated_tile_iterator_atomic.h"
+#include "cutlass/epilogue/threadblock/epilogue.h"
+
+namespace cutlass {
+namespace epilogue {
+namespace threadblock {
+template
+struct EpilogueTensorOpAffineRankN : public DefaultEpilogueTensorOpAffineRankN {
+ using Base = DefaultEpilogueTensorOpAffineRankN;
+ using OutputTileIterator =
+ cutlass::epilogue::threadblock::PredicatedTileIteratorAffineRankNAtomic<
+ typename Base::OutputTileThreadMap,
+ typename Base::ElementOutput,
+ Rank>;
+
+ using Epilogue =
+ cutlass::epilogue::threadblock::Epilogue;
+};
+
+template
+struct EpilogueVoltaTensorOpAffineRankN
+ : public DefaultEpilogueVoltaTensorOpAffineRankN {
+ using Base = DefaultEpilogueVoltaTensorOpAffineRankN;
+ using OutputTileIterator =
+ cutlass::epilogue::threadblock::PredicatedTileIteratorAffineRankNAtomic<
+ typename Base::OutputTileThreadMap,
+ typename Base::ElementOutput,
+ Rank>;
+
+ using Epilogue =
+ cutlass::epilogue::threadblock::Epilogue;
+};
+
+template
+struct EpilogueTensorOp : public DefaultEpilogueTensorOp {
+ using Base = DefaultEpilogueTensorOp;
+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorAtomic<
+ typename Base::OutputTileThreadMap,
+ typename Base::ElementOutput,
+ ScatterD,
+ PermuteDLayout>;
+ using Epilogue =
+ cutlass::epilogue::threadblock::Epilogue;
+};
+
+template
+struct EpilogueVoltaTensorOp : public DefaultEpilogueVoltaTensorOp {
+ using Base = DefaultEpilogueVoltaTensorOp;
+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorAtomic<
+ typename Base::OutputTileThreadMap,
+ typename Base::ElementOutput,
+ ScatterD,
+ PermuteDLayout>;
+ using Epilogue =
+ cutlass::epilogue::threadblock::Epilogue;
+};
+} // namespace threadblock
+} // namespace epilogue
+} // namespace cutlass
+
+template
+struct BiasGradEpilogue {
+ using Epilogue =
+ typename cutlass::epilogue::threadblock::EpilogueTensorOp::Epilogue;
+};
+
+template
+struct BiasGradEpilogue {
+ using Epilogue =
+ typename cutlass::epilogue::threadblock::EpilogueVoltaTensorOp::Epilogue;
+};
+
+template
+struct BiasGradEpilogueAffineRankN {
+ using Epilogue = typename cutlass::epilogue::threadblock::EpilogueTensorOpAffineRankN<
+ Rank,
+ Shape_,
+ WarpMmaTensorOp_,
+ PartitionsK,
+ OutputOp_,
+ ElementsPerAccess>::Epilogue;
+};
+
+template
+struct BiasGradEpilogueAffineRankN {
+ using Epilogue = typename cutlass::epilogue::threadblock::EpilogueVoltaTensorOpAffineRankN<
+ Rank,
+ Shape_,
+ WarpMmaTensorOp_,
+ PartitionsK,
+ OutputOp_,
+ ElementsPerAccess>::Epilogue;
+};
diff --git a/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_pipelined.h b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_pipelined.h
new file mode 100644
index 000000000000..3b7b32d61452
--- /dev/null
+++ b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_pipelined.h
@@ -0,0 +1,592 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+/*! \file
+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
+
+ File copied from "cutlass/epilogue/threadblock/epilogue.h"
+ then modified to:
+ (1) load 2 source fragments at the same time (pipelining)
+ (2) support reading from a different dtype
+ (3) pass the row id to the OutputOp if it takes it
+ (see MemoryEfficientAttentionNormalize)
+ Note that in general the fragment passed to the OutputOp could
+ span multiple rows but it does not happen with the configurations we have
+*/
+
+#pragma once
+
+#if defined(__CUDACC_RTC__)
+#include <cuda/std/cassert>
+#else
+#include <assert.h>
+#endif
+
+#include "cutlass/aligned_buffer.h"
+#include "cutlass/array.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/functional.h"
+#include "cutlass/layout/tensor.h"
+#include "cutlass/layout/vector.h"
+#include "cutlass/numeric_types.h"
+#include "cutlass/tensor_coord.h"
+
+#include "cutlass/gemm/gemm.h"
+
+#include "cutlass/transform/pitch_linear_thread_map.h"
+#include "cutlass/transform/threadblock/regular_tile_iterator.h"
+
+#include "cutlass/epilogue/threadblock/epilogue_base.h"
+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
+#include "cutlass/numeric_types.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace epilogue {
+namespace threadblock {
+
+template <typename Op>
+struct ApplyEpilogueOp {
+ static CUTLASS_DEVICE typename Op::FragmentOutput apply(
+ Op const& output_op,
+ int row_id,
+ typename Op::FragmentAccumulator const& accum,
+ typename Op::FragmentOutput const& source)
+ {
+ return output_op(accum, source);
+ }
+ static CUTLASS_DEVICE typename Op::FragmentOutput
+ apply(Op const& output_op, int row_id, typename Op::FragmentAccumulator const& accum)
+ {
+ return output_op(accum);
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Epilogue operator
+template ::value),
+ typename OutputTileSourceIterator_ =
+ OutputTileIterator_ ///< Tile iterator reading tensors
+ >
+class EpiloguePipelined : public EpilogueBase {
+public:
+ using Base = EpilogueBase;
+
+ using Shape = Shape_;
+ using WarpMmaOperator = WarpMmaOperator_;
+ static int const kPartitionsK = PartitionsK;
+ using OutputTileIterator = OutputTileIterator_;
+ using OutputTileSourceIterator = OutputTileSourceIterator_;
+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
+ using WarpTileIterator = WarpTileIterator_;
+ using SharedLoadIterator = SharedLoadIterator_;
+ using OutputOp = OutputOp_;
+ using Padding = Padding_;
+
+ using Layout = layout::RowMajor;
+ using LongIndex = typename Layout::LongIndex;
+
+ /// The complete warp-level accumulator tile
+ using AccumulatorTile = typename Base::AccumulatorTile;
+
+ /// Accumulator element
+ using ElementAccumulator = typename WarpTileIterator::Element;
+
+ /// Output element
+ using ElementOutput = typename OutputTileIterator::Element;
+ using ElementSource = typename OutputTileSourceIterator::Element;
+
+ /// Output access size
+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
+
+ /// Tensor reference to destination tensor
+ using TensorRef = typename OutputTileIterator::TensorRef;
+
+ /// Tensor reference to sync tensor
+ using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
+
+ /// Const tensor reference to source tensor
+ using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
+
+ /// Array type used to output
+ using OutputAccessType =
+ Array<typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
+ using SourceAccessType = Array<ElementSource, OutputTileIterator::kElementsPerAccess>;
+
+ /// Array type used by output functor
+ using AccumulatorAccessType =
+ Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
+
+ /// Number of warps
+ using WarpCount = typename Base::WarpCount;
+
+ static int constexpr kSmemTiles =
+ Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK;
+ static int constexpr kSmemPointerOffset =
+ Base::SharedStorage::StorageShape::kCount / kSmemTiles;
+
+public:
+ static_assert(OutputTileSourceIterator::Fragment::kElements ==
+ OutputTileIterator::Fragment::kElements,
+ "Mismatch between input tile and output tile iterator (kElements)");
+ static_assert(OutputTileSourceIterator::kIterations == OutputTileIterator::kIterations,
+ "Mismatch between input tile and output tile iterator (kIterations)");
+ static_assert(SharedLoadIterator::Fragment::kElements ==
+ OutputTileIterator::Fragment::kElements,
+ "Mismatch between shared load iterator and output tile iterator.");
+
+ static_assert(OutputTileIterator::kElementsPerAccess,
+ "OutputTileIterator::kElementsPerAccess must not be zero.");
+
+ static_assert(!(OutputTileIterator::Fragment::kElements %
+ OutputTileIterator::kElementsPerAccess),
+ "Divisibility");
+
+private:
+ /// Loads fragment from shared memory aligned with output tensor
+ SharedLoadIterator shared_load_iterator_;
+
+public:
+ /// Constructor
+ CUTLASS_DEVICE
+ EpiloguePipelined(typename Base::SharedStorage& shared_storage, ///< Shared storage object
+ int thread_idx, ///< ID of a thread within the threadblock
+ int warp_idx, ///< ID of warp within threadblock
+ int lane_idx ///< Id of thread within warp
+ )
+ : Base(shared_storage, thread_idx, warp_idx, lane_idx),
+ shared_load_iterator_(shared_storage.reference(), thread_idx)
+ {
+ }
+
+ /// Streams the result to global memory
+ CUTLASS_DEVICE
+ void operator()(OutputOp const& output_op, ///< Output operator
+ OutputTileIterator destination_iterator, ///< Tile iterator for destination
+ AccumulatorTile const& accumulators, ///< Complete warp-level accumulator tile
+ OutputTileSourceIterator source_iterator)
+ { ///< Threadblock tile coordinate in GEMM (in units
+ ///< of threadblock tiles)
+
+ if (!output_op.is_source_needed()) {
+ compute_source_not_needed_(output_op, destination_iterator, accumulators);
+ } else {
+ compute_source_needed_(output_op, destination_iterator, accumulators, source_iterator);
+ }
+ }
+ CUTLASS_DEVICE
+ void operator()(OutputOp const& output_op, ///< Output operator
+ OutputTileIterator destination_iterator, ///< Tile iterator for destination
+ AccumulatorTile const& accumulators)
+ { ///< Complete warp-level accumulator tile
+ compute_source_not_needed_(output_op, destination_iterator, accumulators);
+ }
+
+private:
+ template <typename Seq>
+ struct acc2smem_source_not_needed;
+
+ template <size_t... Seq>
+ struct acc2smem_source_not_needed<cutlass::index_sequence<Seq...>> {
+ template <int Advance>
+ CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
+ WarpTileIterator& warp_tile_iterator)
+ {
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < Advance; i++) { ++accum_fragment_iterator; }
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
+ typename AccumulatorFragmentIterator::Fragment accum_fragment;
+
+ accum_fragment_iterator.load(accum_fragment);
+ ++accum_fragment_iterator;
+
+ warp_tile_iterator.store(accum_fragment);
+ if (p < Base::kFragmentsPerIteration - 1) {
+ warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
+ }
+ }
+
+ if (Base::kFragmentsPerIteration > 1) {
+ warp_tile_iterator.add_pointer_offset(kSmemPointerOffset *
+ (1 - Base::kFragmentsPerIteration));
+ }
+ }
+
+ CUTLASS_DEVICE
+ static void push(size_t pos,
+ AccumulatorFragmentIterator const& iterator_begin,
+ WarpTileIterator& warp_tile_iterator)
+ {
+ int dummy[] = {
+ (pos == (Seq * Base::kFragmentsPerIteration)) &&
+ (helper<Seq * Base::kFragmentsPerIteration>(iterator_begin, warp_tile_iterator),
+ 0)...};
+
+ CUTLASS_UNUSED(dummy[0]);
+ }
+ };
+
+ static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1,
+ "One of these must be exactly 1.");
+
+ /// Streams the result to global memory
+ CUTLASS_DEVICE
+ void compute_source_not_needed_(
+ OutputOp const& output_op, ///< Output operator
+ OutputTileIterator destination_iterator, ///< Tile iterator for destination
+ AccumulatorTile const& accumulators ///< Complete warp-level accumulator tile
+ )
+ {
+ //
+ // Iterator over warp-level accumulator fragment
+ //
+
+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
+
+ //
+ // Iterate over accumulator tile
+ //
+
+#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration \
+ : 1)
+ for (int iter = 0; iter < OutputTileIterator::kIterations;
+ iter += Base::kFragmentsPerIteration) {
+ //
+ // Convert and store fragment
+ //
+
+ __syncthreads();
+
+ acc2smem_source_not_needed<cutlass::make_index_sequence<
+ OutputTileIterator::kIterations / Base::kFragmentsPerIteration>>::
+ push(iter, accum_fragment_iterator, this->warp_tile_iterator_);
+
+ __syncthreads();
+
+ //
+ // Load fragments from shared memory
+ //
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
+ typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
+
+ shared_load_iterator_.load(aligned_accum_fragment[0]);
+
+ if (p < Base::kFragmentsPerIteration - 1) {
+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
+ } else if (kPartitionsK > 1) {
+ plus<typename SharedLoadIterator::Fragment> add_fragments;
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 1; i < kPartitionsK; ++i) {
+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
+ shared_load_iterator_.load(aligned_accum_fragment[i]);
+ aligned_accum_fragment[0] =
+ add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
+ }
+
+ shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) *
+ kSmemPointerOffset);
+ }
+
+ //
+ // Compute the output result
+ //
+
+ typename OutputTileIterator::Fragment output_fragment;
+
+ apply_output_operator_source_not_needed_(destination_iterator.thread_start_row(),
+ output_fragment,
+ output_op,
+ aligned_accum_fragment[0]);
+
+ //
+ // Store the final result
+ //
+
+ destination_iterator.store(output_fragment);
+ ++destination_iterator;
+ }
+
+ if (Base::kFragmentsPerIteration > 1) {
+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset *
+ (1 - Base::kFragmentsPerIteration));
+ }
+ }
+ }
+
+ template <typename Seq>
+ struct acc2smem_source_needed;
+
+ template <size_t... Seq>
+ struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
+ template <int Advance>
+ CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
+ WarpTileIterator& warp_tile_iterator)
+ {
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < Advance; i++) { ++accum_fragment_iterator; }
+
+ typename AccumulatorFragmentIterator::Fragment accum_fragment;
+ accum_fragment_iterator.load(accum_fragment);
+ warp_tile_iterator.store(accum_fragment);
+ }
+
+ CUTLASS_DEVICE
+ static void push(size_t pos,
+ AccumulatorFragmentIterator const& iterator_begin,
+ WarpTileIterator& warp_tile_iterator)
+ {
+ int dummy[] = {(pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
+ }
+ };
+
+ /// Streams the result to global memory
+ CUTLASS_DEVICE
+ void compute_source_needed_(
+ OutputOp const& output_op, ///< Output operator
+ OutputTileIterator destination_iterator, ///< Tile iterator for destination
+ AccumulatorTile const& accumulators, ///< Complete warp-level accumulator tile
+ OutputTileSourceIterator source_iterator ///< Threadblock tile coordinate in GEMM (in units
+ ///< of threadblock tiles)
+ )
+ {
+ typename OutputTileSourceIterator::Fragment source_fragment[2];
+
+ source_fragment[0].clear();
+ source_iterator.load(source_fragment[0]);
+ ++source_iterator;
+ source_fragment[1].clear();
+
+ //
+ // Iterator over warp-level accumulator fragment
+ //
+
+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
+
+ //
+ // Iterate over accumulator tile
+ //
+
+#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
+ for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
+ if (iter > 0) { __syncthreads(); }
+ //
+ // Load the source for next iteration (pipelining)
+ //
+
+ if (iter + 1 < OutputTileIterator::kIterations) {
+ source_iterator.load(source_fragment[(iter + 1) % 2]);
+ }
+ ++source_iterator;
+ acc2smem_source_needed<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::
+ push(iter, accum_fragment_iterator, this->warp_tile_iterator_);
+
+ __syncthreads();
+
+ //
+ // Load fragments from shared memory
+ //
+
+ typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
+
+ shared_load_iterator_.load(aligned_accum_fragment[0]);
+
+ // If the number of k-slices is > 1, perform a reduction amongst the
+ // k-slices
+ if (kPartitionsK > 1) {
+ plus<typename SharedLoadIterator::Fragment> add_fragments;
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 1; i < kPartitionsK; ++i) {
+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
+ shared_load_iterator_.load(aligned_accum_fragment[i]);
+ aligned_accum_fragment[0] =
+ add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
+ }
+
+ shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
+ }
+
+ //
+ // Compute the output result
+ //
+
+ typename OutputTileIterator::Fragment output_fragment;
+
+ apply_output_operator_(destination_iterator.thread_start_row(),
+ output_fragment,
+ output_op,
+ aligned_accum_fragment[0],
+ source_fragment[iter % 2]);
+
+ //
+ // Store the final result
+ //
+
+ destination_iterator.store(output_fragment);
+ ++destination_iterator;
+ }
+ }
+
+ /// Helper to invoke the output functor over each vector of output
+ CUTLASS_DEVICE
+ void apply_output_operator_(int begin_row,
+ typename OutputTileIterator::Fragment& output_fragment,
+ OutputOp const& output_op, ///< Output operator
+ typename SharedLoadIterator::Fragment const& aligned_accum_fragment,
+ typename OutputTileSourceIterator::Fragment const& source_fragment)
+ {
+ OutputAccessType* output_frag_ptr = reinterpret_cast<OutputAccessType*>(&output_fragment);
+
+ AccumulatorAccessType const* compute_frag_ptr =
+ reinterpret_cast<AccumulatorAccessType const*>(&aligned_accum_fragment);
+
+ SourceAccessType const* source_frag_ptr =
+ reinterpret_cast<SourceAccessType const*>(&source_fragment);
+
+ int const kOutputOpIterations =
+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < kOutputOpIterations; ++i) {
+ // Call the output operator
+ output_frag_ptr[i] = ApplyEpilogueOp<OutputOp>::apply(
+ output_op,
+ begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess),
+ compute_frag_ptr[i],
+ source_frag_ptr[i]);
+ }
+ }
+
+ /// Helper to invoke the output functor over each vector of output
+ CUTLASS_DEVICE
+ void apply_output_operator_source_not_needed_(
+ int begin_row,
+ typename OutputTileIterator::Fragment& output_fragment,
+ OutputOp const& output_op, ///< Output operator
+ typename SharedLoadIterator::Fragment const& aligned_accum_fragment)
+ {
+ OutputAccessType* output_frag_ptr = reinterpret_cast<OutputAccessType*>(&output_fragment);
+
+ AccumulatorAccessType const* compute_frag_ptr =
+ reinterpret_cast<AccumulatorAccessType const*>(&aligned_accum_fragment);
+
+ int const kOutputOpIterations =
+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < kOutputOpIterations; ++i) {
+ // Call the output operator
+ output_frag_ptr[i] = ApplyEpilogueOp<OutputOp>::apply(
+ output_op,
+ begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess),
+ compute_frag_ptr[i]);
+ }
+ }
+
+ // This should be constexpr, but it is only supported since C++14
+ static int CUTLASS_HOST_DEVICE getRowOffset(int i)
+ {
+ using ThreadMap = typename OutputTileIterator::ThreadMap;
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
+ int row_offset = row * ThreadMap::Delta::kRow +
+ group * ThreadMap::Delta::kGroup +
+ cluster * ThreadMap::Delta::kCluster;
+ int frag_row_idx =
+ (row + ThreadMap::Iterations::kRow *
+ (group + ThreadMap::Iterations::kGroup * cluster));
+ CUTLASS_PRAGMA_UNROLL
+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
+ int frag_idx = ThreadMap::kElementsPerAccess *
+ (frag_row_idx * ThreadMap::Iterations::kColumn + column);
+ if (i < frag_idx + ThreadMap::kElementsPerAccess) { return row_offset; }
+ }
+ }
+ }
+ }
+ return -1;
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace threadblock
+} // namespace epilogue
+} // namespace cutlass
+
+////////////////////////////////////////////////////////////////////////////////
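The `ApplyEpilogueOp` hook at the top of epilogue_pipelined.h is the seam that lets this epilogue feed a row index to row-aware output operators while ordinary operators keep the stock CUTLASS call shape. A minimal standalone sketch of that dispatch pattern, with no CUTLASS dependency and purely hypothetical `PlainOp` / `RowAwareOp` types:

#include <array>
#include <cstddef>

using Fragment = std::array<float, 4>;  // stand-in for a CUTLASS fragment

struct PlainOp {  // ignores the row index, like a stock linear-combination epilogue
    Fragment operator()(Fragment const& accum) const { return accum; }
};

struct RowAwareOp {  // needs the row index to pick a per-row scale
    float const* row_scale;
    Fragment operator()(int row, Fragment const& accum) const
    {
        Fragment out{};
        for (std::size_t i = 0; i < accum.size(); ++i) { out[i] = row_scale[row] * accum[i]; }
        return out;
    }
};

// Primary template: drop the row id and forward (mirrors ApplyEpilogueOp above).
template <typename Op>
struct ApplyOp {
    static Fragment apply(Op const& op, int /*row*/, Fragment const& accum) { return op(accum); }
};

// Specialization for the row-aware operator (mirrors the specialization that
// epilogue_rescale_output.h adds for MemoryEfficientAttentionNormalize).
template <>
struct ApplyOp<RowAwareOp> {
    static Fragment apply(RowAwareOp const& op, int row, Fragment const& accum)
    {
        return op(row, accum);
    }
};

// The epilogue always writes ApplyOp<Op>::apply(op, row, accum); which
// operator() overload runs is decided by the specialization, not the caller.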
diff --git a/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_rescale_output.h b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_rescale_output.h
new file mode 100644
index 000000000000..f81a09f74f1e
--- /dev/null
+++ b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_rescale_output.h
@@ -0,0 +1,251 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+/*! \file
+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
+
+ The epilogue rearranges the result of a matrix product through shared memory
+ to match canonical tensor layouts in global memory. Epilogues support
+ conversion and reduction operations.
+
+ This is a copy of cutlass/epilogue/threadblock/epilogue.h that can
+ handle "row_id" as a first argument, as uses it to get the corresponding
+ `m_prime` / `s_prime` to rescale the output.
+*/
+
+#pragma once
+
+#if defined(__CUDACC_RTC__)
+#include
+#else
+#include
+#endif
+
+#include "cutlass/aligned_buffer.h"
+#include "cutlass/array.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/functional.h"
+#include "cutlass/layout/tensor.h"
+#include "cutlass/layout/vector.h"
+#include "cutlass/numeric_types.h"
+#include "cutlass/tensor_coord.h"
+
+#include "cutlass/gemm/gemm.h"
+
+#include "cutlass/transform/pitch_linear_thread_map.h"
+#include "cutlass/transform/threadblock/regular_tile_iterator.h"
+
+#include "cutlass/epilogue/threadblock/epilogue_base.h"
+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
+#include "cutlass/numeric_types.h"
+
+#include "cutlass/array.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/epilogue/thread/scale_type.h"
+#include "cutlass/functional.h"
+#include "cutlass/numeric_conversion.h"
+#include "cutlass/numeric_types.h"
+#include "epilogue_pipelined.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace epilogue {
+namespace thread {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Applies a linear combination operator to an array of elements.
+// output <- alpha * accumulator + beta * source
+// with:
+// alpha = 1 / s_prime (to normalize when isLast=True, 1 otherwise)
+// beta = alpha * m_prime (renormalize the output when the max changes)
+// source is the current output
+template <typename ElementOutput_, ///< Data type used to load and store tensors
+ typename ElementSource_, ///< Data type of the source (the current output)
+ int Count, ///< Number of elements computed per operation,
+ ///< usually 128 / sizeof_bits<ElementOutput_>,
+ ///< but we use 64 or 32 sometimes when there are not enough data
+ ///< to store
+ typename ElementAccumulator_, ///< Accumulator data type
+ typename ElementCompute_, ///< Data type used to compute linear combination
+ bool isFirst,
+ bool isLast,
+ typename FragmentAlphaBeta_,
+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>
+class MemoryEfficientAttentionNormalize {
+public:
+ using ElementOutput = ElementOutput_;
+ using ElementSource = ElementSource_;
+ using ElementAccumulator = ElementAccumulator_;
+ using ElementCompute = ElementCompute_;
+
+ static int const kCount = Count;
+
+ using FragmentOutput = Array<ElementOutput, kCount>;
+ using FragmentSource = Array<ElementSource, kCount>;
+ using FragmentAccumulator = Array<ElementAccumulator, kCount>;
+ using ComputeFragment = Array<ElementCompute, kCount>;
+ using FragmentAlphaBeta = FragmentAlphaBeta_;
+
+ static FloatRoundStyle const kRound = Round;
+
+private:
+ //
+ // Data members
+ //
+
+ FragmentAlphaBeta const& s_prime_;
+ FragmentAlphaBeta const& m_prime_;
+
+public:
+ /// Constructs the function object, possibly loading from pointers in host
+ /// memory
+ CUTLASS_HOST_DEVICE
+ MemoryEfficientAttentionNormalize(FragmentAlphaBeta const& s_prime,
+ FragmentAlphaBeta const& m_prime)
+ : s_prime_(s_prime), m_prime_(m_prime)
+ {
+ }
+
+ /// Returns true if source is needed
+ CUTLASS_HOST_DEVICE
+ bool is_source_needed() const { return !isFirst; }
+
+ /// Functionally required for serial reduction in the epilogue
+ CUTLASS_HOST_DEVICE
+ void set_k_partition(int k_partition, int k_partition_count) {}
+
+ /// Computes linear scaling: D = alpha * accumulator + beta * source
+ CUTLASS_HOST_DEVICE
+ FragmentOutput operator()(int row,
+ FragmentAccumulator const& accumulator,
+ FragmentSource const& source) const
+ {
+ assert(!isFirst);
+
+ // Convert source to internal compute numeric type
+ NumericArrayConverter<ElementCompute, ElementSource, kCount, Round> source_converter;
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
+ accumulator_converter;
+
+ // Convert to destination numeric type
+ NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
+
+ ComputeFragment converted_source = source_converter(source);
+ ComputeFragment converted_accumulator = accumulator_converter(accumulator);
+
+ // Perform binary operations
+ ComputeFragment intermediate;
+
+ multiplies<ComputeFragment> mul_add_source;
+ multiply_add<ComputeFragment, ComputeFragment, ComputeFragment> mul_add_accumulator;
+
+ ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;
+ ElementCompute beta = alpha * m_prime_[row];
+
+ intermediate = mul_add_source(beta, converted_source); // X = beta * C
+
+ intermediate = mul_add_accumulator(
+ alpha, converted_accumulator, intermediate); // D = alpha * Accum + X
+
+ return destination_converter(intermediate);
+ }
+
+ /// Computes linear scaling: D = alpha * accumulator
+ CUTLASS_HOST_DEVICE
+ FragmentOutput operator()(int row, FragmentAccumulator const& accumulator) const
+ {
+ assert(isFirst);
+
+ // Convert source to internal compute numeric type
+ NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
+ accumulator_converter;
+
+ // Convert to destination numeric type
+ NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
+
+ ComputeFragment converted_accumulator = accumulator_converter(accumulator);
+
+ ComputeFragment intermediate;
+ multiplies<ComputeFragment> mul_accumulator;
+
+ ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;
+
+ intermediate = mul_accumulator(alpha, converted_accumulator); // X = alpha * C + uniform
+
+ return destination_converter(intermediate);
+ }
+};
+
+} // namespace thread
+
+namespace threadblock {
+template <typename ElementOutput, typename ElementSource, int Count,
+ typename ElementAccumulator, typename ElementCompute, bool isFirst,
+ bool isLast, typename FragmentAlphaBeta, FloatRoundStyle Round>
+struct ApplyEpilogueOp<
+ thread::MemoryEfficientAttentionNormalize<ElementOutput, ElementSource, Count,
+ ElementAccumulator, ElementCompute, isFirst, isLast, FragmentAlphaBeta, Round>> {
+ using Op = thread::MemoryEfficientAttentionNormalize<ElementOutput, ElementSource, Count,
+ ElementAccumulator, ElementCompute, isFirst, isLast, FragmentAlphaBeta, Round>;
+ static CUTLASS_DEVICE typename Op::FragmentOutput apply(
+ Op const& output_op,
+ int row_id,
+ typename Op::FragmentAccumulator const& accum,
+ typename Op::FragmentSource const& source)
+ {
+ return output_op(row_id, accum, source);
+ }
+ static CUTLASS_DEVICE typename Op::FragmentOutput
+ apply(Op const& output_op, int row_id, typename Op::FragmentAccumulator const& accum)
+ {
+ return output_op(row_id, accum);
+ }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace threadblock
+} // namespace epilogue
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
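In scalar form, the rescaling that `MemoryEfficientAttentionNormalize` applies per output row reduces to the sketch below, assuming `s_prime` holds the running softmax row sum and `m_prime` the correction factor carried when the row maximum changes between tiles (the helper name is illustrative, not part of the kernel):

// Scalar sketch only; the real functor works on CUTLASS fragments and is
// routed through ApplyEpilogueOp so it receives the row index.
inline float normalize_output(bool is_first, bool is_last,
                              float accum,    // partial GEMM result for this tile
                              float source,   // previously written partial output
                              float s_prime,  // softmax row sum (assumed meaning)
                              float m_prime)  // max-correction factor (assumed meaning)
{
    float alpha = is_last ? (1.0f / s_prime) : 1.0f;  // normalize only on the last tile
    if (is_first) { return alpha * accum; }           // no earlier output to rescale
    float beta = alpha * m_prime;                     // renormalize the earlier partial output
    return alpha * accum + beta * source;             // D = alpha * Accum + beta * C
}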
diff --git a/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_thread_apply_logsumexp.h b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_thread_apply_logsumexp.h
new file mode 100644
index 000000000000..46fb2bf17c1c
--- /dev/null
+++ b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_thread_apply_logsumexp.h
@@ -0,0 +1,168 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
+ *reserved. SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice,
+ *this list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ *POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+/*! \file
+ \brief Functor performing linear combination operations used by epilogues.
+*/
+
+#pragma once
+
+#include
+
+#include "cutlass/array.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/epilogue/thread/activation.h"
+#include "cutlass/functional.h"
+#include "cutlass/numeric_conversion.h"
+#include "cutlass/numeric_types.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace epilogue {
+namespace thread {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace detail {
+
+template <typename Element, int ElementsPerAccess>
+struct ArrayExponential {
+ CUTLASS_HOST_DEVICE
+ Array<Element, ElementsPerAccess> operator()(
+ Array<Element, ElementsPerAccess> const& input) const
+ {
+ Array<Element, ElementsPerAccess> result;
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < ElementsPerAccess; ++i) { result[i] = expf(input[i]); }
+
+ return result;
+ }
+};
+
+template <int ElementsPerAccess>
+struct ArrayExponential<half_t, ElementsPerAccess> {
+ CUTLASS_DEVICE
+ Array<half_t, ElementsPerAccess> operator()(Array<half_t, ElementsPerAccess> const& input) const
+ {
+ Array<half_t, ElementsPerAccess> result;
+
+ int const kVectorCount = ElementsPerAccess / 2;
+
+ __half2 const* input_ptr = reinterpret_cast<__half2 const*>(input.raw_data());
+ __half2* res_ptr = reinterpret_cast<__half2*>(result.raw_data());
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < kVectorCount; ++i) { res_ptr[i] = h2exp(input_ptr[i]); }
+
+ return result;
+ }
+};
+} // namespace detail
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Applies:
+/// output <- (input - lse).exp()
+template <typename ElementOutput_, typename ElementAccumulator_,
+ typename ElementCompute_, typename ElementLSE_, int ElementsPerAccess>
+class ApplyLogSumExp {
+public:
+ using ElementOutput = ElementOutput_;
+ using ElementAccumulator = ElementAccumulator_;
+ using ElementCompute = ElementCompute_;
+ using ElementLSE = ElementLSE_;
+
+ static int const kElementsPerAccess = ElementsPerAccess;
+ static int const kCount = kElementsPerAccess;
+ static const ScaleType::Kind kScale = cutlass::epilogue::thread::ScaleType::NoBetaScaling;
+
+ using FragmentOutput = Array<ElementOutput, kElementsPerAccess>;
+ using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
+ using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
+ using FragmentLSE = Array<ElementLSE, kElementsPerAccess>;
+ using FragmentScaleBias = FragmentLSE; // Used by epilogue_smem_accumulator.h
+
+public:
+ //
+ // Methods
+ //
+
+ CUTLASS_HOST_DEVICE
+ ApplyLogSumExp() {}
+
+ /// Returns true if source is needed
+ CUTLASS_HOST_DEVICE
+ bool is_source_needed() const { return true; }
+
+ /// Functionally required for serial reduction in the epilogue
+ CUTLASS_HOST_DEVICE
+ void set_k_partition(int k_partition, int k_partition_count) {}
+
+ CUTLASS_HOST_DEVICE
+ FragmentOutput operator()(FragmentAccumulator const& AB,
+ FragmentLSE const& scale_unused,
+ // bias used as LSE
+ FragmentLSE const& bias) const
+ {
+ FragmentCompute frag_AB =
+ NumericArrayConverter<ElementCompute, ElementAccumulator, ElementsPerAccess>()(AB);
+ FragmentCompute frag_lse_compute =
+ NumericArrayConverter<ElementCompute, ElementLSE, ElementsPerAccess>()(bias);
+ FragmentCompute frag_compute;
+
+ minus<FragmentCompute> minus_lse;
+ detail::ArrayExponential<ElementCompute, ElementsPerAccess> apply_exp;
+ frag_compute = minus_lse(frag_AB, frag_lse_compute);
+ frag_compute = apply_exp(frag_compute);
+
+ return NumericArrayConverter<ElementOutput, ElementCompute, ElementsPerAccess>()(
+ frag_compute);
+ }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace thread
+} // namespace epilogue
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
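Per element, `ApplyLogSumExp` computes `exp(x - lse)`; subtracting the row's log-sum-exp before exponentiating yields the normalized softmax value without a separate division. A scalar sketch (the `half_t` specialization of `detail::ArrayExponential` just vectorizes the same computation with `h2exp`):

#include <cmath>

// exp(x - lse) == exp(x) / sum_j exp(x_j) when lse = log(sum_j exp(x_j)).
// Illustrative scalar helper only; the functor operates on fragments.
inline float apply_logsumexp(float accum, float lse) { return std::exp(accum - lse); }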
diff --git a/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma.h b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma.h
new file mode 100644
index 000000000000..75833bbfe7d2
--- /dev/null
+++ b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma.h
@@ -0,0 +1,119 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#pragma once
+
+#include "custom_mma_multistage.h"
+#include "custom_mma_pipelined.h"
+#include "cutlass/gemm/threadblock/mma_multistage.h"
+#include "cutlass/gemm/threadblock/mma_pipelined.h"
+
+template <typename Mma, int kMaxK>
+struct MakeCustomMma;
+
+template
+struct MakeCustomMma,
+ kMaxK> {
+ // Reduce the number of stages if we don't need that many
+ static int constexpr kStages =
+ kMaxK == cutlass::platform::numeric_limits<int>::max()
+ ? Stages
+ : cutlass::const_min(Stages, (kMaxK + int(Shape::kK) - 1) / int(Shape::kK));
+ using Mma = cutlass::gemm::threadblock::CustomMmaMultistage;
+};
+
+template
+struct MakeCustomMma,
+ kMaxK> {
+ using Mma = cutlass::gemm::threadblock::CustomMmaPipelined;
+};
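The stage reduction in `MakeCustomMma` amounts to: never allocate more pipeline stages than there are K tiles once `kMaxK` is bounded. A standalone sketch of that arithmetic (constant values are hypothetical):

// min(requested stages, ceil(kMaxK / Shape::kK)), matching the const_min expression above.
constexpr int reduced_stage_count(int requested_stages, int max_k, int tile_k)
{
    int k_tiles = (max_k + tile_k - 1) / tile_k;  // ceil division
    return requested_stages < k_tiles ? requested_stages : k_tiles;
}

// A kernel built for 5 stages but bounded at kMaxK = 64 with a 32-wide K tile
// only ever needs 2 stages of shared memory; an unbounded K keeps all 5.
static_assert(reduced_stage_count(5, 64, 32) == 2, "");
static_assert(reduced_stage_count(5, 1024, 32) == 5, "");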
diff --git a/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_base.h b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_base.h
new file mode 100644
index 000000000000..bbf91240b900
--- /dev/null
+++ b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_base.h
@@ -0,0 +1,181 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
+ *reserved. SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice,
+ *this list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ *POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+/*! \file
+ \brief Template for a double-buffered threadblock-scoped GEMM kernel.
+*/
+
+#pragma once
+
+#include "cutlass/aligned_buffer.h"
+#include "cutlass/arch/memory.h"
+#include "cutlass/array.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/gemm/threadblock/mma_base.h"
+#include "cutlass/matrix_shape.h"
+#include "cutlass/numeric_types.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+namespace threadblock {
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Structure to compute the matrix product targeting CUDA cores and SIMT math
+/// instructions.
+template <
+ /// Size of the Gemm problem - concept: gemm::GemmShape<>
+ typename Shape_,
+ /// Policy describing tuning details (concept: MmaPolicy)
+ typename Policy_,
+ /// Number of stages,
+ int Stages,
+ /// Used for partial specialization
+ typename Enable = bool>
+class CustomMmaBase {
+public:
+ ///< Size of the Gemm problem - concept: gemm::GemmShape<>
+ using Shape = Shape_;
+
+ ///< Policy describing tuning details
+ using Policy = Policy_;
+
+ //
+ // Dependent types
+ //
+
+ /// Warp-level Mma
+ using Operator = typename Policy::Operator;
+
+ /// Shape describing the overall GEMM computed from shared memory
+ /// by each warp.
+ using WarpGemm = typename Policy::Operator::Shape;
+
+ /// Shape describing the number of warps filling the CTA
+ using WarpCount = GemmShape<Shape::kM / WarpGemm::kM,
+ Shape::kN / WarpGemm::kN,
+ Shape::kK / WarpGemm::kK>;
+
+ /// Number of warp-level GEMM operations
+ static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);
+
+ /// Number of stages
+ static int const kStages = Stages;
+
+ //
+ // Nested structs
+ //
+
+ /// Shared storage object needed by threadblock-scoped GEMM
+ template <typename Element, typename OperandShape, typename OperandLayout>
+ struct OperandSharedStorage {
+ AlignedBuffer<Element, OperandShape::kCount> buffer;
+ using TensorRef = cutlass::TensorRef<Element, OperandLayout>;
+
+ CUTLASS_DEVICE
+ static OperandLayout Layout()
+ {
+ return OperandLayout::packed({OperandShape::kRow, OperandShape::kColumn});
+ }
+
+ /// Returns a TensorRef to the operand
+ CUTLASS_HOST_DEVICE
+ TensorRef ref() { return TensorRef{buffer.data(), Layout()}; }
+ };
+
+ /// Shape of the A matrix operand in shared memory
+ using ShapeA = MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
+ Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;
+
+ /// Shape of the B matrix operand in shared memory
+ using ShapeB = MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
+ Shape::kN + Policy::SmemPaddingB::kColumn>;
+
+ using SharedStorageA =
+ OperandSharedStorage<typename Operator::ElementA, ShapeA, typename Operator::LayoutA>;
+ using SharedStorageB =
+ OperandSharedStorage<typename Operator::ElementB, ShapeB, typename Operator::LayoutB>;
+ using TensorRefA = typename SharedStorageA::TensorRef;
+ using TensorRefB = typename SharedStorageB::TensorRef;
+
+ struct SharedStorage {
+ /// Buffer for A operand
+ SharedStorageA operand_A;
+
+ /// Buffer for B operand
+ SharedStorageB operand_B;
+ };
+
+protected:
+ //
+ // Data members
+ //
+
+ /// Iterator to load a warp-scoped tile of A operand from shared memory
+ typename Operator::IteratorA warp_tile_iterator_A_;
+
+ /// Iterator to load a warp-scoped tile of B operand from shared memory
+ typename Operator::IteratorB warp_tile_iterator_B_;
+
+public:
+ /// Construct from tensor references
+ CUTLASS_DEVICE
+ CustomMmaBase(
+ ///< Shared storage needed for internal use by threadblock-scoped GEMM
+ SharedStorageA& shared_storageA,
+ SharedStorageB& shared_storageB,
+ ///< ID within the threadblock
+ int thread_idx,
+ ///< ID of warp
+ int warp_idx,
+ ///< ID of each thread within a warp
+ int lane_idx)
+ : warp_tile_iterator_A_(shared_storageA.ref(), lane_idx),
+ warp_tile_iterator_B_(shared_storageB.ref(), lane_idx)
+ {
+ }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace threadblock
+} // namespace gemm
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
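The shapes in `CustomMmaBase` are bookkeeping: the warp grid is the threadblock tile divided by the warp tile, and `kWarpGemmIterations` is the warp tile's K extent divided by the instruction's K extent. With illustrative (not prescribed) tile sizes:

// Hypothetical sizes, chosen only to make the arithmetic concrete.
constexpr int kTileM = 128, kTileN = 128, kTileK = 32;  // threadblock tile (Shape)
constexpr int kWarpM = 64, kWarpN = 64, kWarpK = 32;    // warp tile (WarpGemm)
constexpr int kMmaK = 16;                               // instruction K (Policy MmaShape::kK)

constexpr int kWarpCountM = kTileM / kWarpM;            // 2
constexpr int kWarpCountN = kTileN / kWarpN;            // 2
constexpr int kWarpCountK = kTileK / kWarpK;            // 1
constexpr int kWarpGemmIterations = kWarpK / kMmaK;     // 2 MMAs per warp-level K slice

static_assert(kWarpCountM * kWarpCountN * kWarpCountK == 4, "4 warps per threadblock");
static_assert(kWarpGemmIterations == 2, "");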
diff --git a/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_multistage.h b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_multistage.h
new file mode 100644
index 000000000000..50ba58b1d1dd
--- /dev/null
+++ b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_multistage.h
@@ -0,0 +1,706 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
+ *reserved. SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice,
+ *this list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ *POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+/*! \file
+ \brief Template for a double-buffered threadblock-scoped GEMM kernel.
+*/
+
+#pragma once
+
+#include "cutlass/aligned_buffer.h"
+#include "cutlass/arch/cache_operation.h"
+#include "cutlass/arch/memory.h"
+#include "cutlass/array.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/matrix_shape.h"
+#include "cutlass/numeric_types.h"
+
+#include "custom_mma_base.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+namespace threadblock {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Structure to compute the matrix product targeting CUDA cores and SIMT math
+/// instructions.
+template <
+ /// Size of the Gemm problem - concept: gemm::GemmShape<>
+ typename Shape_,
+ /// Iterates over tiles of A operand in global memory
+ // (concept: ReadableTileIterator | ForwardTileIterator |
+ // MaskedTileIterator)
+ typename IteratorA_,
+ /// Iterates over tiles of A operand in shared memory
+ /// (concept: WriteableTileIterator | RandomAccessTileIterator)
+ typename SmemIteratorA_,
+ /// Cache operation for operand A
+ cutlass::arch::CacheOperation::Kind CacheOpA,
+ /// Iterates over tiles of B operand in global memory
+ // (concept: ReadableTileIterator | ForwardTileIterator |
+ // MaskedTileIterator)
+ typename IteratorB_,
+ /// Iterates over tiles of B operand in shared memory
+ /// (concept: WriteableTileIterator | RandomAccessTileIterator)
+ typename SmemIteratorB_,
+ /// Cache operation for operand B
+ cutlass::arch::CacheOperation::Kind CacheOpB,
+ /// Data type of accumulator matrix
+ typename ElementC_,
+ /// Data type of accumulator matrix
+ typename LayoutC_,
+ /// Policy describing tuning details (concept: MmaPolicy)
+ typename Policy_,
+ /// Number of stages,
+ int Stages,
+ /// Use zfill or predicate for out-of-bound cp.async
+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
+ /// Upper bound on the K dimension
+ int kMaxK = cutlass::platform::numeric_limits<int>::max(),
+ /// Used for partial specialization
+ typename Enable = bool>
+class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
+public:
+ ///< Base class
+ using Base = CustomMmaBase<Shape_, Policy_, Stages>;
+ ///< Size of the Gemm problem - concept: gemm::GemmShape<>
+ using Shape = Shape_;
+ ///< Iterates over tiles of A operand in global memory
+ using IteratorA = IteratorA_;
+ ///< Iterates over tiles of B operand in global memory
+ using IteratorB = IteratorB_;
+ ///< Data type of accumulator matrix
+ using ElementC = ElementC_;
+ ///< Layout of accumulator matrix
+ using LayoutC = LayoutC_;
+ ///< Policy describing tuning details
+ using Policy = Policy_;
+
+ using SmemIteratorA = SmemIteratorA_;
+ using SmemIteratorB = SmemIteratorB_;
+
+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
+
+ //
+ // Dependent types
+ //
+
+ /// Fragment of accumulator tile
+ using FragmentC = typename Policy::Operator::FragmentC;
+
+ /// Warp-level Mma
+ using Operator = typename Policy::Operator;
+
+ /// Minimum architecture is Sm80 to support cp.async
+ using ArchTag = arch::Sm80;
+
+ /// Complex transform on A operand
+ static ComplexTransform const kTransformA = Operator::kTransformA;
+
+ /// Complex transform on B operand
+ static ComplexTransform const kTransformB = Operator::kTransformB;
+
+ /// Internal structure exposed for introspection.
+ struct Detail {
+ static_assert(Base::kWarpGemmIterations > 1,
+ "The pipelined structure requires at least two warp-level "
+ "GEMM operations.");
+
+ /// Number of cp.async instructions to load one stage of operand A
+ static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount;
+
+ /// Number of cp.async instructions to load one stage of operand B
+ static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount;
+
+ /// Number of stages
+ static int const kStages = Stages;
+
+ /// Number of cp.async instructions to load on group of operand A
+ static int const kAccessesPerGroupA =
+ (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) /
+ Base::kWarpGemmIterations;
+
+ /// Number of cp.async instructions to load on group of operand B
+ static int const kAccessesPerGroupB =
+ (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) /
+ Base::kWarpGemmIterations;
+ };
+
+ static bool const kSmemContainsEntireMat = kMaxK <= Shape::kK * Stages;
+ static constexpr int kNumStagesConcurrentLoad = kSmemContainsEntireMat ? Stages : Stages - 1;
+
+private:
+ using WarpLoadedFragmentA = typename Operator::FragmentA;
+ using WarpLoadedFragmentB = typename Operator::FragmentB;
+ using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
+ using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
+
+private:
+ //
+ // Data members
+ //
+
+ /// Iterator to write threadblock-scoped tile of A operand to shared memory
+ SmemIteratorA smem_iterator_A_;
+
+ /// Iterator to write threadblock-scoped tile of B operand to shared memory
+ SmemIteratorB smem_iterator_B_;
+
+ bool prologue_done_;
+
+ // Set to `True` to ensure the accumulator will be zero outside the GEMM
+ // footprint
+ bool zero_outside_bounds_;
+
+public:
+ /// Construct from tensor references
+ CUTLASS_DEVICE
+ CustomMmaMultistage(
+ ///< Shared storage needed for internal use by threadblock-scoped GEMM
+ typename Base::SharedStorageA& shared_storageA,
+ typename Base::SharedStorageB& shared_storageB,
+ ///< ID within the threadblock
+ int thread_idx,
+ ///< ID of warp
+ int warp_idx,
+ ///< ID of each thread within a warp
+ int lane_idx)
+ : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx),
+ smem_iterator_A_(shared_storageA.ref(), thread_idx),
+ smem_iterator_B_(shared_storageB.ref(), thread_idx),
+ prologue_done_(false),
+ zero_outside_bounds_(false)
+ {
+ // Compute warp location within threadblock tile by mapping the warp_id to
+ // three coordinates:
+ // _m: the warp's position within the threadblock along the M dimension
+ // _n: the warp's position within the threadblock along the N dimension
+ // _k: the warp's position within the threadblock along the K dimension
+
+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
+
+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
+
+ // Add per-warp offsets in units of warp-level tiles
+ this->warp_tile_iterator_A_.add_tile_offset(
+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
+ this->warp_tile_iterator_B_.add_tile_offset(
+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
+ }
+ CUTLASS_DEVICE
+ CustomMmaMultistage(
+ ///< Shared storage needed for internal use by threadblock-scoped GEMM
+ typename Base::SharedStorage& st,
+ ///< ID within the threadblock
+ int thread_idx,
+ ///< ID of warp
+ int warp_idx,
+ ///< ID of each thread within a warp
+ int lane_idx)
+ : CustomMmaMultistage(st.operand_A, st.operand_B, thread_idx, warp_idx, lane_idx)
+ {
+ }
+
+ CUTLASS_DEVICE
+ void set_prologue_done(bool value) { prologue_done_ = value; }
+
+ CUTLASS_DEVICE
+ void set_zero_outside_bounds(bool value) { zero_outside_bounds_ = value; }
+
+ template <bool kLoadA = true, bool kLoadB = true>
+ CUTLASS_DEVICE static void prologue(typename Base::SharedStorage& shared_storage,
+ ///< iterator over A operand in global memory
+ IteratorA iterator_A,
+ ///< iterator over B operand in global memory
+ IteratorB iterator_B,
+ int thread_idx,
+ int problem_size_k)
+ {
+ prologue<kLoadA, kLoadB>(shared_storage.operand_A,
+ shared_storage.operand_B,
+ iterator_A,
+ iterator_B,
+ thread_idx,
+ problem_size_k);
+ }
+
+ template <bool kLoadA = true, bool kLoadB = true>
+ CUTLASS_DEVICE static void prologue(typename Base::SharedStorageA& shared_storageA,
+ typename Base::SharedStorageB& shared_storageB,
+ ///< iterator over A operand in global memory
+ IteratorA iterator_A,
+ ///< iterator over B operand in global memory
+ IteratorB iterator_B,
+ int thread_idx,
+ int problem_size_k)
+ {
+ SmemIteratorA smem_iterator_A(shared_storageA.ref(), thread_idx);
+ SmemIteratorB smem_iterator_B(shared_storageB.ref(), thread_idx);
+ int32_t iter = (problem_size_k + Base::Shape::kK - 1) / Base::Shape::kK;
+ _prologue<kLoadA, kLoadB>(iterator_A, iterator_B, iter, smem_iterator_A, smem_iterator_B);
+ }
+
+ CUTLASS_DEVICE
+ void copy_tiles_and_advance(IteratorA& iterator_A,
+ IteratorB& iterator_B,
+ int group_start_A = 0,
+ int group_start_B = 0)
+ {
+ iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector);
+ this->smem_iterator_A_.set_iteration_index(group_start_A);
+
+ // Async Copy for operand A
+ CUTLASS_PRAGMA_UNROLL
+ for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
+ if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) {
+ typename IteratorA::AccessType* dst_ptr =
+ reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
+
+ int const kSrcBytes = sizeof_bits::value *
+ IteratorA::ThreadMap::kElementsPerAccess /
+ IteratorA::kAccessesPerVector / 8;
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
+ auto gmem_ptr = iterator_A.get();
+
+ if (zero_outside_bounds_ ||
+ SharedMemoryClear == SharedMemoryClearOption::kZfill) {
+ cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
+ dst_ptr + v, gmem_ptr, iterator_A.valid());
+ } else {
+ cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(
+ dst_ptr + v, gmem_ptr, iterator_A.valid());
+ }
+
+ ++iterator_A;
+ }
+
+ ++this->smem_iterator_A_;
+ }
+ }
+
+ iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector);
+ this->smem_iterator_B_.set_iteration_index(group_start_B);
+
+ // Async Copy for operand B
+ CUTLASS_PRAGMA_UNROLL
+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
+ if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
+ typename IteratorB::AccessType* dst_ptr =
+ reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
+
+ int const kSrcBytes = sizeof_bits::value *
+ IteratorB::ThreadMap::kElementsPerAccess /
+ IteratorB::kAccessesPerVector / 8;
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
+ auto gmem_ptr = iterator_B.get();
+
+ if (zero_outside_bounds_ ||
+ SharedMemoryClear == SharedMemoryClearOption::kZfill) {
+ cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
+ dst_ptr + v, gmem_ptr, iterator_B.valid());
+ } else {
+ cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(
+ dst_ptr + v, gmem_ptr, iterator_B.valid());
+ }
+
+ ++iterator_B;
+ }
+ ++this->smem_iterator_B_;
+ }
+ }
+ }
+
+ template <bool kLoadA, bool kLoadB>
+ CUTLASS_DEVICE static void _prologue(IteratorA& iterator_A,
+ IteratorB& iterator_B,
+ int32_t& gemm_k_iterations,
+ SmemIteratorA& smem_iterator_A_,
+ SmemIteratorB& smem_iterator_B_)
+ {
+ // Issue several complete stages
+ CUTLASS_PRAGMA_UNROLL
+ for (int stage = 0; stage < kNumStagesConcurrentLoad; ++stage, --gemm_k_iterations) {
+ iterator_A.clear_mask(gemm_k_iterations == 0);
+ iterator_B.clear_mask(gemm_k_iterations == 0);
+
+ iterator_A.set_iteration_index(0);
+ smem_iterator_A_.set_iteration_index(0);
+
+ // Async Copy for operand A
+ CUTLASS_PRAGMA_UNROLL
+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
+ typename IteratorA::AccessType* dst_ptr =
+ reinterpret_cast<typename IteratorA::AccessType*>(smem_iterator_A_.get());
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
+ int const kSrcBytes = sizeof_bits::value *
+ IteratorA::ThreadMap::kElementsPerAccess /
+ IteratorA::kAccessesPerVector / 8;
+
+ int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);
+
+ if (kLoadA) {
+ cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
+ dst_ptr + v, iterator_A.get(), iterator_A.valid());
+ }
+
+ ++iterator_A;
+ }
+
+ ++smem_iterator_A_;
+ }
+
+ iterator_B.set_iteration_index(0);
+ smem_iterator_B_.set_iteration_index(0);
+
+ // Async Copy for operand B
+ CUTLASS_PRAGMA_UNROLL
+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
+ typename IteratorB::AccessType* dst_ptr =
+ reinterpret_cast<typename IteratorB::AccessType*>(smem_iterator_B_.get());
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
+ int const kSrcBytes = sizeof_bits::value *
+ IteratorB::ThreadMap::kElementsPerAccess /
+ IteratorB::kAccessesPerVector / 8;
+
+ if (kLoadB) {
+ cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
+ dst_ptr + v, iterator_B.get(), iterator_B.valid());
+ }
+
+ ++iterator_B;
+ }
+
+ ++smem_iterator_B_;
+ }
+
+ // Move to the next stage
+ iterator_A.add_tile_offset({0, 1});
+ iterator_B.add_tile_offset({1, 0});
+
+ smem_iterator_A_.add_tile_offset({0, 1});
+ smem_iterator_B_.add_tile_offset({1, 0});
+
+ // Defines the boundary of a stage of cp.async.
+ cutlass::arch::cp_async_fence();
+ }
+ }
+
+ /// Perform a threadblock-scoped matrix multiply-accumulate
+ CUTLASS_DEVICE
+ void operator()(
+ ///< problem size of GEMM
+ int gemm_k_iterations,
+ ///< destination accumulator tile
+ FragmentC& accum,
+ ///< iterator over A operand in global memory
+ IteratorA iterator_A,
+ ///< iterator over B operand in global memory
+ IteratorB iterator_B,
+ ///< initial value of accumulator
+ FragmentC const& src_accum)
+ {
+ //
+ // Prologue
+ //
+
+ if (!prologue_done_) {
+ _prologue<true, true>(
+ iterator_A, iterator_B, gemm_k_iterations, smem_iterator_A_, smem_iterator_B_);
+ } else if (!kSmemContainsEntireMat) {
+ _prologue<false, false>(
+ iterator_A, iterator_B, gemm_k_iterations, smem_iterator_A_, smem_iterator_B_);
+ } else {
+ gemm_k_iterations -= kNumStagesConcurrentLoad;
+ }
+
+ // Perform accumulation in the 'd' output operand
+ accum = src_accum;
+
+ //
+ // Clear the remaining tiles of SMEM. This is a functional requirement for
+ // some kernels so that all accumulator elements outside the GEMM footprint
+ // are zero.
+ //
+
+ if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) {
+ /// Iterator to write threadblock-scoped tile of A operand to shared
+ /// memory
+ SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
+
+ typename IteratorA::AccessType zero_A;
+ zero_A.clear();
+
+ last_smem_iterator_A.set_iteration_index(0);
+
+ // Async Copy for operand A
+ CUTLASS_PRAGMA_UNROLL
+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
+ typename IteratorA::AccessType* dst_ptr =
+ reinterpret_cast<typename IteratorA::AccessType*>(last_smem_iterator_A.get());
+
+ *dst_ptr = zero_A;
+
+ ++last_smem_iterator_A;
+ }
+
+ /// Iterator to write threadblock-scoped tile of B operand to shared
+ /// memory
+ SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
+ typename IteratorB::AccessType zero_B;
+
+ zero_B.clear();
+ last_smem_iterator_B.set_iteration_index(0);
+
+ // Async Copy for operand B
+ CUTLASS_PRAGMA_UNROLL
+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
+ typename IteratorB::AccessType* dst_ptr =
+ reinterpret_cast<typename IteratorB::AccessType*>(last_smem_iterator_B.get());
+
+ *dst_ptr = zero_B;
+
+ ++last_smem_iterator_B;
+ }
+ }
+
+ // Waits until kStages-2 stages have committed.
+ cutlass::arch::cp_async_wait<Base::kStages - 2>();
+ __syncthreads();
+
+ // Pair of fragments used to overlap shared memory loads and math
+ // instructions
+ WarpLoadedFragmentA warp_loaded_frag_A[2];
+ WarpLoadedFragmentB warp_loaded_frag_B[2];
+ WarpTransformedFragmentA warp_transformed_frag_A[2];
+ WarpTransformedFragmentB warp_transformed_frag_B[2];
+
+ Operator warp_mma;
+
+ this->warp_tile_iterator_A_.set_kgroup_index(0);
+ this->warp_tile_iterator_B_.set_kgroup_index(0);
+
+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]);
+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]);
+
+ ++this->warp_tile_iterator_A_;
+ ++this->warp_tile_iterator_B_;
+
+ iterator_A.clear_mask(gemm_k_iterations == 0);
+ iterator_B.clear_mask(gemm_k_iterations == 0);
+
+ int smem_write_stage_idx = Base::kStages - 1;
+ int smem_read_stage_idx = 0;
+
+ warp_mma.transform(warp_transformed_frag_A[0],
+ warp_transformed_frag_B[0],
+ warp_loaded_frag_A[0],
+ warp_loaded_frag_B[0]);
+
+ // tf32x3 kernels use staging accumulation. warp_mma uses a temporary
+ // accumulator and this temporary accumulator is added to the final
+ // accumulator once in every mainloop iteration.
+ plus<FragmentC> plus_accum;
+
+ FragmentC tmp_accum;
+
+ if (platform::is_same<typename Operator::MathOperator, arch::OpMultiplyAddFastF32>::value ||
+ platform::is_same<typename Operator::MathOperator,
+ arch::OpMultiplyAddComplexFastF32>::value) {
+ tmp_accum.clear();
+ }
+
+ //
+ // Mainloop
+ //
+
+ CUTLASS_GEMM_LOOP
+ for (; gemm_k_iterations > (-kNumStagesConcurrentLoad);) {
+ //
+ // Loop over GEMM K dimension
+ //
+
+ // Computes a warp-level GEMM on data held in shared memory
+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
+ CUTLASS_PRAGMA_UNROLL
+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
+ // Load warp-level tiles from shared memory, wrapping to k offset if
+ // this is the last group as the case may be.
+
+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) %
+ Base::kWarpGemmIterations);
+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) %
+ Base::kWarpGemmIterations);
+
+ // In case of a non-circular buffer ("kSmemContainsEntireMat")
+ // make sure we don't load out of bounds data.
+ if (!kSmemContainsEntireMat || gemm_k_iterations > (-kNumStagesConcurrentLoad) ||
+ warp_mma_k < Base::kWarpGemmIterations - 1) {
+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]);
+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]);
+ }
+
+ ++this->warp_tile_iterator_A_;
+ ++this->warp_tile_iterator_B_;
+
+ if (warp_mma_k > 0)
+ warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2],
+ warp_transformed_frag_B[warp_mma_k % 2],
+ warp_loaded_frag_A[warp_mma_k % 2],
+ warp_loaded_frag_B[warp_mma_k % 2]);
+
+ if (platform::is_same<typename Operator::MathOperator, arch::OpMultiplyAddFastF32>::value ||
+ platform::is_same<typename Operator::MathOperator,
+ arch::OpMultiplyAddComplexFastF32>::value) {
+ warp_mma(tmp_accum,
+ warp_transformed_frag_A[warp_mma_k % 2],
+ warp_transformed_frag_B[warp_mma_k % 2],
+ tmp_accum);
+
+ if (warp_mma_k == 0) {
+ accum = plus_accum(accum, tmp_accum);
+ tmp_accum.clear();
+ }
+ } else {
+ warp_mma(accum,
+ warp_transformed_frag_A[warp_mma_k % 2],
+ warp_transformed_frag_B[warp_mma_k % 2],
+ accum);
+ }
+
+ // Issue global->shared copies for this stage
+ if (!kSmemContainsEntireMat && warp_mma_k < Base::kWarpGemmIterations - 1) {
+ int group_start_iteration_A, group_start_iteration_B;
+
+ group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
+ group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
+
+ copy_tiles_and_advance(
+ iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
+ }
+
+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
+ if (!kSmemContainsEntireMat) {
+ int group_start_iteration_A, group_start_iteration_B;
+ group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
+ group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
+
+ copy_tiles_and_advance(iterator_A,
+ iterator_B,
+ group_start_iteration_A,
+ group_start_iteration_B);
+ }
+
+ // Inserts a memory fence between stages of cp.async instructions.
+ cutlass::arch::cp_async_fence();
+
+ // Waits until kStages-2 stages have committed.
+ cutlass::arch::cp_async_wait<Base::kStages - 2>();
+ __syncthreads();
+
+ // Move to the next stage
+ iterator_A.add_tile_offset({0, 1});
+ iterator_B.add_tile_offset({1, 0});
+
+ this->smem_iterator_A_.add_tile_offset({0, 1});
+ this->smem_iterator_B_.add_tile_offset({1, 0});
+
+ // Add negative offsets to return iterators to the 'start' of the
+ // circular buffer in shared memory
+ if (smem_write_stage_idx == (Base::kStages - 1)) {
+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
+ smem_write_stage_idx = 0;
+ } else {
+ ++smem_write_stage_idx;
+ }
+
+ if (!kSmemContainsEntireMat && smem_read_stage_idx == (Base::kStages - 1)) {
+ this->warp_tile_iterator_A_.add_tile_offset(
+ {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
+ this->warp_tile_iterator_B_.add_tile_offset(
+ {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
+ smem_read_stage_idx = 0;
+ } else {
+ ++smem_read_stage_idx;
+ }
+
+ --gemm_k_iterations;
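+ // Once the final k tile has been fetched, mask the global iterators so the
+ // remaining drain iterations do not issue further loads.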
+ iterator_A.clear_mask(gemm_k_iterations == 0);
+ iterator_B.clear_mask(gemm_k_iterations == 0);
+ }
+
+ // Do any conversions feeding the first stage at the end of the loop so
+ // we can start right away on mma instructions
+ if (warp_mma_k + 1 == Base::kWarpGemmIterations)
+ warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
+ warp_transformed_frag_B[(warp_mma_k + 1) % 2],
+ warp_loaded_frag_A[(warp_mma_k + 1) % 2],
+ warp_loaded_frag_B[(warp_mma_k + 1) % 2]);
+ }
+ }
+
+ if (platform::is_same<typename Operator::MathOperator, arch::OpMultiplyAddFastF32>::value ||
+ platform::is_same<typename Operator::MathOperator, arch::OpMultiplyAddComplexFastF32>::value) {
+ accum = plus_accum(accum, tmp_accum);
+ }
+ }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace threadblock
+} // namespace gemm
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_pipelined.h b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_pipelined.h
new file mode 100644
index 000000000000..07b26ca31299
--- /dev/null
+++ b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_pipelined.h
@@ -0,0 +1,388 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
+ *reserved. SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice,
+ *this list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ *POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+/*! \file
+ \brief Template for a double-buffered threadblock-scoped GEMM kernel.
+*/
+
+#pragma once
+
+#include "cutlass/aligned_buffer.h"
+#include "cutlass/array.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/numeric_conversion.h"
+
+#include "cutlass/matrix_shape.h"
+#include "cutlass/numeric_types.h"
+
+#include "custom_mma_base.h"
+#include "cutlass/gemm/gemm.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+namespace threadblock {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Structure to compute the matrix product targeting CUDA cores and SIMT math
+/// instructions.
+template <
+ /// Size of the Gemm problem - concept: gemm::GemmShape<>
+ typename Shape_,
+ /// Iterates over tiles of A operand in global memory
+ // (concept: ReadableTileIterator | ForwardTileIterator |
+ // MaskedTileIterator)
+ typename IteratorA_,
+ /// Iterates over tiles of A operand in shared memory
+ /// (concept: WriteableTileIterator | RandomAccessTileIterator)
+ typename SmemIteratorA_,
+ /// Iterates over tiles of B operand in global memory
+ // (concept: ReadableTileIterator | ForwardTileIterator |
+ // MaskedTileIterator)
+ typename IteratorB_,
+ /// Iterates over tiles of B operand in shared memory
+ /// (concept: WriteableTileIterator | RandomAccessTileIterator)
+ typename SmemIteratorB_,
+ /// Data type of accumulator matrix
+ typename ElementC_,
+ /// Data type of accumulator matrix
+ typename LayoutC_,
+ /// Policy describing tuning details (concept: MmaPolicy)
+ typename Policy_,
+ /// Transformation applied to A operand
+ typename TransformA_ = NumericArrayConverter<typename SmemIteratorA_::Element,
+ typename IteratorA_::Element,
+ IteratorA_::Fragment::kElements>,
+ ///
+ /// Transformation applied to B operand
+ typename TransformB_ = NumericArrayConverter<typename SmemIteratorB_::Element,
+ typename IteratorB_::Element,
+ IteratorB_::Fragment::kElements>,
+ /// Used for partial specialization
+ typename Enable = bool>
+class CustomMmaPipelined : public CustomMmaBase<Shape_, Policy_, 2> {
+public:
+ ///< Base class
+ using Base = CustomMmaBase<Shape_, Policy_, 2>;
+
+ using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
+ using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
+ using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
+ using ElementC = ElementC_; ///< Data type of accumulator matrix
+ using LayoutC = LayoutC_; ///< Layout of accumulator matrix
+ using Policy = Policy_; ///< Policy describing tuning details
+
+ using SmemIteratorA = SmemIteratorA_;
+ using SmemIteratorB = SmemIteratorB_;
+
+ using TransformA = TransformA_;
+ using TransformB = TransformB_;
+
+ //
+ // Dependent types
+ //
+
+ /// Fragment of operand A loaded from global memory
+ using FragmentA = typename IteratorA::Fragment;
+
+ /// Fragment of operand B loaded from global memory
+ using FragmentB = typename IteratorB::Fragment;
+
+ /// Fragment of accumulator tile
+ using FragmentC = typename Policy::Operator::FragmentC;
+
+ /// Warp-level Mma
+ using Operator = typename Policy::Operator;
+
+ /// Obtain the arch tag from the warp-level operator
+ using ArchTag = typename Policy::Operator::ArchTag;
+
+ /// Complex transform on A operand
+ static ComplexTransform const kTransformA = Operator::kTransformA;
+
+ /// Complex transform on B operand
+ static ComplexTransform const kTransformB = Operator::kTransformB;
+
+ // statically assert kStages for MmaPipelined is two (double-buffered pipeline)
+ static_assert((Base::kStages == 2), "MmaPipelined requires kStages set to value 2");
+
+ static bool const kSmemContainsEntireMat = false;
+
+private:
+ using WarpFragmentA = typename Operator::FragmentA;
+ using WarpFragmentB = typename Operator::FragmentB;
+
+protected:
+ /// Iterator to write threadblock-scoped tile of A operand to shared memory
+ SmemIteratorA smem_iterator_A_;
+
+ /// Iterator to write threadblock-scoped tile of B operand to shared memory
+ SmemIteratorB smem_iterator_B_;
+
+public:
+ /// Construct from tensor references
+ CUTLASS_DEVICE
+ CustomMmaPipelined(typename Base::SharedStorageA& shared_storageA,
+ typename Base::SharedStorageB& shared_storageB,
+ int thread_idx, ///< ID within the threadblock
+ int warp_idx, ///< ID of warp
+ int lane_idx ///< ID of each thread within a warp
+ )
+ : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx),
+ smem_iterator_A_(shared_storageA.ref(), thread_idx),
+ smem_iterator_B_(shared_storageB.ref(), thread_idx)
+ {
+ // Compute warp location within threadblock tile by mapping the warp_id to
+ // three coordinates:
+ // _m: the warp's position within the threadblock along the M dimension
+ // _n: the warp's position within the threadblock along the N dimension
+ // _k: the warp's position within the threadblock along the K dimension
+
+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
+
+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
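+ // Worked example (hypothetical 2x2x1 warp arrangement): warp_idx == 3 gives
+ // warp_idx_k == 0, warp_idx_m == 1, warp_idx_n == 1.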
+
+ // Add per-warp offsets in units of warp-level tiles
+ this->warp_tile_iterator_A_.add_tile_offset(
+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
+ this->warp_tile_iterator_B_.add_tile_offset(
+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
+ }
+ CUTLASS_DEVICE
+ CustomMmaPipelined(
+ ///< Shared storage needed for internal use by threadblock-scoped GEMM
+ typename Base::SharedStorage& st,
+ ///< ID within the threadblock
+ int thread_idx,
+ ///< ID of warp
+ int warp_idx,
+ ///< ID of each thread within a warp
+ int lane_idx)
+ : CustomMmaPipelined(st.operand_A, st.operand_B, thread_idx, warp_idx, lane_idx)
+ {
+ }
+
+ CUTLASS_DEVICE
+ bool set_prologue_done(bool value)
+ {
+ // NOT IMPLEMENTED FOR PIPELINED
+ }
+
+ CUTLASS_DEVICE
+ bool set_zero_outside_bounds(bool value)
+ {
+ // NOT NEEDED FOR PIPELINED
+ // shared memory will always be zero-filled
+ }
+
+ template <bool kLoadA = true, bool kLoadB = true>
+ CUTLASS_DEVICE static void prologue(typename Base::SharedStorage& shared_storage,
+ ///< iterator over A operand in global memory
+ IteratorA iterator_A,
+ ///< iterator over B operand in global memory
+ IteratorB iterator_B,
+ int thread_idx,
+ int problem_size_k)
+ {
+ prologue(shared_storage.operand_A,
+ shared_storage.operand_B,
+ iterator_A,
+ iterator_B,
+ thread_idx,
+ problem_size_k);
+ }
+
+ template <bool kLoadA = true, bool kLoadB = true>
+ CUTLASS_DEVICE static void prologue(typename Base::SharedStorageA& shared_storageA,
+ typename Base::SharedStorageB& shared_storageB,
+ ///< iterator over A operand in global memory
+ IteratorA iterator_A,
+ ///< iterator over B operand in global memory
+ IteratorB iterator_B,
+ int thread_idx,
+ int problem_size_k)
+ {
+ // NOT IMPLEMENTED FOR PIPELINED
+ }
+
+ /// Perform a threadblock-scoped matrix multiply-accumulate
+ CUTLASS_DEVICE
+ void operator()(
+ int gemm_k_iterations, ///< number of iterations of the mainloop
+ FragmentC& accum, ///< destination accumulator tile
+ IteratorA iterator_A, ///< iterator over A operand in global memory
+ IteratorB iterator_B, ///< iterator over B operand in global memory
+ FragmentC const& src_accum, ///< source accumulator tile
+ TransformA transform_A = TransformA(), ///< transformation applied to A fragment
+ TransformB transform_B = TransformB())
+ { ///< transformation applied to B fragment
+
+ //
+ // Prologue
+ //
+
+ // Perform accumulation in the 'd' output operand
+ accum = src_accum;
+
+ FragmentA tb_frag_A;
+ FragmentB tb_frag_B;
+
+ tb_frag_A.clear();
+ tb_frag_B.clear();
+
+ // The last kblock is loaded in the prolog
+ iterator_A.load(tb_frag_A);
+ iterator_B.load(tb_frag_B);
+
+ ++iterator_A;
+ ++iterator_B;
+
+ this->smem_iterator_A_.store(transform_A(tb_frag_A));
+ this->smem_iterator_B_.store(transform_B(tb_frag_B));
+
+ ++this->smem_iterator_A_;
+ ++this->smem_iterator_B_;
+
+ __syncthreads();
+
+ // Pair of fragments used to overlap shared memory loads and math
+ // instructions
+ WarpFragmentA warp_frag_A[2];
+ WarpFragmentB warp_frag_B[2];
+
+ this->warp_tile_iterator_A_.set_kgroup_index(0);
+ this->warp_tile_iterator_B_.set_kgroup_index(0);
+
+ this->warp_tile_iterator_A_.load(warp_frag_A[0]);
+ this->warp_tile_iterator_B_.load(warp_frag_B[0]);
+
+ ++this->warp_tile_iterator_A_;
+ ++this->warp_tile_iterator_B_;
+
+ Operator warp_mma;
+
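+ // Stage 0 of the double buffer was filled by the prologue above, so the next
+ // shared-memory write targets stage 1.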
+ int smem_write_stage_idx = 1;
+
+ // Avoid reading out of bounds
+ iterator_A.clear_mask(gemm_k_iterations <= 1);
+ iterator_B.clear_mask(gemm_k_iterations <= 1);
+
+ // Issue loads during the first warp-level matrix multiply-add *AFTER*
+ // issuing shared memory loads (which have the tightest latency requirement).
+
+ //
+ // Mainloop
+ //
+
+ // Note: The main loop does not support Base::kWarpGemmIterations == 2.
+ CUTLASS_GEMM_LOOP
+ for (; gemm_k_iterations > 0; --gemm_k_iterations) {
+ //
+ // Loop over GEMM K dimension
+ //
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
+ // Load warp-level tiles from shared memory, wrapping to the k offset if
+ // this is the last group.
+
+ if (warp_mma_k == Base::kWarpGemmIterations - 1) {
+ // Write fragments to shared memory
+ this->smem_iterator_A_.store(transform_A(tb_frag_A));
+
+ this->smem_iterator_B_.store(transform_B(tb_frag_B));
+
+ __syncthreads();
+
+ ++this->smem_iterator_A_;
+ ++this->smem_iterator_B_;
+
+ // Add negative offsets to return iterators to the 'start' of the
+ // circular buffer in shared memory
+ if (smem_write_stage_idx == 1) {
+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
+ } else {
+ this->warp_tile_iterator_A_.add_tile_offset(
+ {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
+ this->warp_tile_iterator_B_.add_tile_offset(
+ {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
+ }
+
+ smem_write_stage_idx ^= 1;
+ }
+
+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) %
+ Base::kWarpGemmIterations);
+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) %
+ Base::kWarpGemmIterations);
+
+ this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
+ this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]);
+
+ ++this->warp_tile_iterator_A_;
+ ++this->warp_tile_iterator_B_;
+
+ if (warp_mma_k == 0) {
+ iterator_A.load(tb_frag_A);
+ iterator_B.load(tb_frag_B);
+
+ ++iterator_A;
+ ++iterator_B;
+
+ // Avoid reading out of bounds if this was the last loop iteration
+ iterator_A.clear_mask(gemm_k_iterations <= 2);
+ iterator_B.clear_mask(gemm_k_iterations <= 2);
+ }
+
+ warp_mma(accum, warp_frag_A[warp_mma_k % 2], warp_frag_B[warp_mma_k % 2], accum);
+ }
+ }
+ }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace threadblock
+} // namespace gemm
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/csrc/deepspeed4science/evoformer_attn/gemm/find_default_mma.h b/csrc/deepspeed4science/evoformer_attn/gemm/find_default_mma.h
new file mode 100644
index 000000000000..163dcbf85259
--- /dev/null
+++ b/csrc/deepspeed4science/evoformer_attn/gemm/find_default_mma.h
@@ -0,0 +1,191 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+/*! \file
+ \brief CUTLASS provides helper template functions to figure out the right
+ data structures to instantiate to run a GEMM with various parameters (see
+ `cutlass/gemm/threadblock/default_mma.h`). However, due to template
+ instantiation priority rules, it will only create an MmaMultistage with
+ kStages=3 (otherwise it creates an MmaPipelined - which is not compatible with
+ FastF32). kStages=3 uses too much shared memory and we want to use kStages=2,
+ so we copy-pasted some code from `default_mma.h` and `default_mma_core.h`
+ and wrapped this template to allow our use case; see the illustrative usage
+ sketch after the primary template below.
+
+ This is really only for the FastF32 case - i.e. using TensorCores with fp32.
+*/
+
+#pragma once
+
+#include "cutlass/gemm/threadblock/default_mma.h"
+#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
+
+namespace cutlass {
+namespace gemm {
+namespace threadblock {
+
+template <
+ /// Element type for A matrix operand
+ typename ElementA,
+ /// Layout type for A matrix operand
+ typename LayoutA,
+ /// Access granularity of A matrix in units of elements
+ int kAlignmentA,
+ /// Element type for B matrix operand
+ typename ElementB,
+ /// Layout type for B matrix operand
+ typename LayoutB,
+ /// Access granularity of B matrix in units of elements
+ int kAlignmentB,
+ /// Element type for internal accumulation
+ typename ElementAccumulator,
+ /// Layout type for C and D matrix operand
+ typename LayoutC,
+ /// Operator class tag
+ typename OperatorClass,
+ /// Tag indicating architecture to tune for
+ typename ArchTag,
+ /// Threadblock-level tile size (concept: GemmShape)
+ typename ThreadblockShape,
+ /// Warp-level tile size (concept: GemmShape)
+ typename WarpShape,
+ /// Instruction-level tile size (concept: GemmShape)
+ typename InstructionShape,
+ /// Number of stages used in the pipelined mainloop
+ int Stages,
+ /// Operation performed by GEMM
+ typename Operator,
+ typename Enable_ = void>
+struct FindDefaultMma {
+ static constexpr bool AccumulatorsInRowMajor = false;
+ static constexpr SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone;
+ using DefaultMma = cutlass::gemm::threadblock::DefaultMma;
+};
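+
+// Illustrative usage (a sketch only; the concrete element types, tile shapes and the
+// `Mma` alias are assumptions, not part of this header):
+//
+//   using Mma = typename cutlass::gemm::threadblock::FindDefaultMma<
+//       cutlass::half_t, cutlass::layout::RowMajor, /*kAlignmentA=*/8,
+//       cutlass::half_t, cutlass::layout::ColumnMajor, /*kAlignmentB=*/8,
+//       float, cutlass::layout::RowMajor,
+//       cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
+//       cutlass::gemm::GemmShape<64, 64, 32>,  // threadblock tile
+//       cutlass::gemm::GemmShape<32, 32, 32>,  // warp tile
+//       cutlass::gemm::GemmShape<16, 8, 16>,   // instruction shape
+//       /*Stages=*/2, cutlass::arch::OpMultiplyAdd>::DefaultMma::ThreadblockMma;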
+
+/// Specialization for sm80 / FastF32 / multistage with kStages=2
+template
+struct FindDefaultMma 1)>::type> {
+ using LayoutC = layout::RowMajor;
+ using OperatorClass = arch::OpClassTensorOp;
+ using ArchTag = arch::Sm80;
+
+ using DefaultMma_ = cutlass::gemm::threadblock::DefaultMma;
+ struct DefaultMma : DefaultMma_ {
+ using MmaCore_ = typename DefaultMma_::MmaCore;
+ // Define the threadblock-scoped multistage matrix multiply
+ using ThreadblockMma =
+ cutlass::gemm::threadblock::MmaMultistage;
+ };
+};
+
+} // namespace threadblock
+} // namespace gemm
+} // namespace cutlass
diff --git a/csrc/deepspeed4science/evoformer_attn/gemm/mma_accum_lambda_iterator.h b/csrc/deepspeed4science/evoformer_attn/gemm/mma_accum_lambda_iterator.h
new file mode 100644
index 000000000000..5e2f0cf681bf
--- /dev/null
+++ b/csrc/deepspeed4science/evoformer_attn/gemm/mma_accum_lambda_iterator.h
@@ -0,0 +1,347 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#pragma once
+
+#include "cutlass/functional.h"
+#include "cutlass/gemm/warp/mma_simt_tile_iterator.h"
+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h"
+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
+#include "cutlass/matrix_shape.h"
+
+/*
+Tensor cores have different accumulator layouts depending on the architecture.
+This file provides a class that maps the i-th element of an accumulator
+fragment to the corresponding matrix row/col.
+*/
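+
+// Illustrative usage (a sketch only; `Iterator`, `lane_id`, `warp_id` and `tile_offset`
+// are assumed to come from the surrounding kernel):
+//
+//   auto lane_offset = Iterator::get_lane_offset(lane_id, warp_id, tile_offset);
+//   Iterator::iterateRows(
+//       lane_offset,
+//       [&](int accum_m) { /* called once before each accumulator row */ },
+//       [&](int accum_m, int accum_n, int idx) {
+//           /* frag[idx] holds the accumulator value at (accum_m, accum_n) */
+//       },
+//       [&](int accum_m) { /* called once after each row, e.g. a row-wise max */ });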
+
+template <typename T, typename accum_t, int kWarpSize>
+struct AccumLambdaIteratorSm80 {
+ static_assert(cutlass::platform::is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
+ "only RowMajor is supported");
+
+ using Policy = typename T::Policy;
+ using InstructionShape = typename T::InstructionShape;
+ using OpDelta = typename T::OpDelta;
+ using Shape = typename T::Shape;
+ static int const kElementsPerAccess = InstructionShape::kN / 4;
+ static int const kRowsPerTile = 8;
+ static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile;
+
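+ // Each lane of the m16n8 tensor-core MMA owns kAccumulatorRows rows (kRowsPerTile
+ // apart), each with kElementsPerAccess consecutive columns; get_lane_offset() below
+ // returns the first (row, column) owned by a given lane.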
+ static cutlass::MatrixCoord CUTLASS_DEVICE
+ get_lane_offset(int8_t lane_id, int8_t warp_id, typename T::TensorCoord const& tile_offset)
+ {
+ int quad = (lane_id >> 2);
+ int lane_in_quad = (lane_id & 3);
+ return cutlass::MatrixCoord(
+ quad + tile_offset.row() * Shape::kRow,
+ lane_in_quad * kElementsPerAccess + tile_offset.column() * Shape::kColumn);
+ }
+
+ template <typename FA, typename FB, typename FC>
+ CUTLASS_DEVICE static void iterateRows(cutlass::MatrixCoord& lane_offset,
+ FA beginRow,
+ FB op,
+ FC endRow)
+ {
+ // See cutlass/gemm/warp/mma_tensor_op_tile_iterator.h
+ CUTLASS_PRAGMA_UNROLL
+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int row = 0; row < kAccumulatorRows; ++row) {
+ int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + row * kRowsPerTile +
+ lane_offset.row();
+ beginRow(accum_m);
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) {
+ int mma_accum_start = kAccumulatorRows * kElementsPerAccess *
+ (mma_n * Policy::MmaIterations::kRow + mma_m);
+ CUTLASS_PRAGMA_UNROLL
+ for (int col = 0; col < kElementsPerAccess; ++col) {
+ int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col +
+ lane_offset.column();
+ int idx = mma_accum_start + row * kElementsPerAccess + col;
+ op(accum_m, accum_n, idx);
+ }
+ }
+
+ endRow(accum_m);
+ }
+ }
+ }
+
+ template <typename DT, typename F>
+ CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn)
+ {
+ // In each warp, 4 threads will work on the same row
+ // - the ones with the same `quad`
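+ // The two XOR shuffles (masks 1 and 2) perform a butterfly reduction across the
+ // four lanes of the quad; every lane ends up with the reduced value, but only
+ // lane_in_quad == 0 reports true so a single lane writes the result.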
+ auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1);
+ myValue = fn(myValue, otherV);
+ otherV = __shfl_xor_sync(0xffffffff, myValue, 2);
+ myValue = fn(myValue, otherV);
+ int lane_in_quad = (lane_id & 3);
+ return lane_in_quad == 0;
+ }
+};
+
+template <typename T, typename accum_t, int kWarpSize>
+struct AccumLambdaIteratorSm70 {
+ static_assert(cutlass::platform::is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
+ "only RowMajor is supported");
+
+ using Policy = typename T::Policy;
+ using InstructionShape = typename T::InstructionShape;
+ using OpDelta = typename T::OpDelta;
+ using Shape = typename T::Shape;
+ using Element = accum_t;
+
+ static int const kElementsPerPartial = 4;
+ using EleShapePerPatial =
+ typename cutlass::platform::conditional<cutlass::platform::is_same<Element, float>::value,
+ cutlass::MatrixShape<2, 2>,
+ cutlass::MatrixShape<1, 4>>::type;
+ static int const kElementsPerMma = 8;
+ static int const kAccumulatorPatials = 2;
+ using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>;
+
+ static cutlass::MatrixCoord CUTLASS_DEVICE
+ get_lane_offset(int8_t lane_id, int8_t warp_id, typename T::TensorCoord const& tile_offset)
+ {
+ int quad = (lane_id >> 2);
+ int lane_in_quad = (lane_id & 3);
+ int accum_m, accum_n;
+
+ if (cutlass::platform::is_same<Element, float>::value) {
+ // (quad[2],quad[0])+lane_in_quad[0]
+ accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1);
+ // (quad[1])+lane_in_quad[1]
+ accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials +
+ (lane_in_quad & 2);
+ } else {
+ accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + lane_in_quad; // (quad[2],quad[0])
+ accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials;
+ }
+ return cutlass::MatrixCoord(accum_m + tile_offset.row() * Shape::kRow,
+ accum_n + tile_offset.column() * Shape::kColumn);
+ }
+
+ template <typename DT, typename F>
+ CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn)
+ {
+ static_assert(cutlass::platform::is_same<Element, float>::value,
+ "update to support non-float accum");
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
+ // T0 & T2 share same line within a quad
+ auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 1);
+ myValue = fn(myValue, otherV);
+ // quad 0 and quad 2 are on the same lines
+ otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 3);
+ myValue = fn(myValue, otherV);
+ return (lane_id & ((1 << 1) | (1 << 3))) == 0;
+ }
+
+ template <typename FA, typename FB, typename FC>
+ CUTLASS_DEVICE static void iterateRows(cutlass::MatrixCoord& lane_offset,
+ FA beginRow,
+ FB op,
+ FC endRow)
+ {
+ CUTLASS_PRAGMA_UNROLL
+ for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int m = 0; m < EleShapePerPatial::kRow; ++m) {
+ int accum_m = tile_m * Policy::InterleavedTile::kRow +
+ mma_m * QuadShapePerPatialMma::kRow + m * 2 + lane_offset.row();
+ beginRow(accum_m);
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int p = 0; p < kAccumulatorPatials; ++p) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int n = 0; n < EleShapePerPatial::kColumn; ++n) {
+ int mma_accum_start =
+ (((tile_n * Policy::TileIterations::kRow + tile_m) *
+ Policy::MmaIterations::kColumn +
+ mma_n) *
+ Policy::MmaIterations::kRow +
+ mma_m) *
+ kElementsPerMma;
+ int accum_n = tile_n * Policy::InterleavedTile::kColumn +
+ mma_n * QuadShapePerPatialMma::kColumn +
+ p * Policy::InterleavedTile::kColumn / 2 + n +
+ lane_offset.column();
+ int idx = mma_accum_start + p * kElementsPerPartial +
+ m * EleShapePerPatial::kColumn + n;
+ op(accum_m, accum_n, idx);
+ }
+ }
+ }
+ }
+ endRow(accum_m);
+ }
+ }
+ }
+ }
+};
+
+template <typename T, typename accum_t, int kWarpSize>
+struct AccumLambdaIteratorSimt {
+ using Policy = typename T::Policy;
+ using Iterations = typename T::Iterations;
+ using Element = typename T::Element;
+ using Delta = typename T::Delta;
+ using Shape = typename T::Shape;
+ static_assert(cutlass::platform::is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
+ "only RowMajor is supported");
+
+ template <typename DT, typename F>
+ CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn)
+ {
+ CUTLASS_PRAGMA_UNROLL
+ for (int bit = 1; bit < Policy::WarpShape::kColumn; bit *= 2) {
+ auto otherV = __shfl_xor_sync(0xffffffff, myValue, bit);
+ myValue = fn(myValue, otherV);
+ }
+ return (lane_id & (Policy::WarpShape::kColumn - 1)) == 0;
+ }
+
+ template <typename FA, typename FB, typename FC>
+ CUTLASS_DEVICE static void iterateRows(cutlass::MatrixCoord& lane_offset,
+ FA beginRow,
+ FB op,
+ FC endRow)
+ {
+ CUTLASS_PRAGMA_UNROLL
+ for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) {
+ int accum_m = mma_m * Delta::kRow + m + lane_offset.row();
+ beginRow(accum_m);
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {
+ int accum_n = mma_n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN +
+ lane_offset.column();
+ CUTLASS_PRAGMA_UNROLL
+ for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) {
+ int idx = n + Policy::LaneMmaShape::kN *
+ (mma_n + Iterations::kColumn *
+ (m + mma_m * Policy::LaneMmaShape::kM));
+ op(accum_m, accum_n + n, idx);
+ }
+ }
+ endRow(accum_m);
+ }
+ }
+ }
+
+ static cutlass::MatrixCoord CUTLASS_DEVICE
+ get_lane_offset(int8_t lane_id, int8_t warp_id, typename T::TensorCoord const& tile_offset)
+ {
+ static_assert(cutlass::platform::is_same<typename Policy::LaneLayout,
+ cutlass::layout::RowMajorInterleaved<1>>::value,
+ "");
+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
+
+ cutlass::MatrixCoord lane_offset =
+ lane_layout.inverse(lane_id) *
+ cutlass::MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN);
+ return lane_offset + tile_offset * cutlass::MatrixCoord(Shape::kRow, Shape::kColumn);
+ }
+};
+
+template <typename T, typename accum_t, int kWarpSize>
+struct DefaultMmaAccumLambdaIterator;
+
+// Simt
+template
+struct DefaultMmaAccumLambdaIterator<
+ cutlass::gemm::warp::MmaSimtTileIterator,
+ accum_t,
+ kWarpSize> {
+ using WarpIterator =
+ typename cutlass::gemm::warp::MmaSimtTileIterator;
+ using Iterator = AccumLambdaIteratorSimt