forked from facebookresearch/xformers
-
Notifications
You must be signed in to change notification settings - Fork 1
/
fourier_mix.py
35 lines (27 loc) · 1.15 KB
/
fourier_mix.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch.cuda.amp import autocast
from xformers.components.attention import Attention, AttentionConfig, register_attention
@register_attention("fourier_mix", AttentionConfig)
class FourierMix(Attention):
    def __init__(self, dropout: float, *_, **__):
        """
        FFT-based pseudo-attention mechanism, from
        "FNet: Mixing Tokens with Fourier Transforms"
        Lee-Thorp et al., 2021, https://arxiv.org/pdf/2105.03824.pdf

        Args:
            dropout: probability handed to the ``torch.nn.Dropout`` applied
                to the mixed output.
            *_, **__: any extra positional / keyword arguments are accepted
                and ignored.
        """
        super().__init__()
        self.attn_drop = torch.nn.Dropout(dropout, inplace=False)

        # Properties specific to this attention mechanism
        self.supports_attention_mask = False
        self.requires_input_projection = False

    def forward(self, q: torch.Tensor, *_, **__) -> torch.Tensor:
        """
        Mix the input with a 2D FFT and keep the real part of the result.

        Args:
            q: tensor to mix. ``torch.fft.fft2`` transforms its last two
                dimensions (presumably sequence and embedding - confirm
                against the caller).
            *_, **__: any extra positional / keyword arguments (keys, values,
                masks, ...) are accepted and ignored.

        Returns:
            The real part of ``fft2(q)`` after dropout, with the same shape
            and the same dtype as ``q``.
        """
        # Disabling autocast only stops autocast from picking op dtypes: a
        # tensor which already reaches us in fp16 / bf16 (e.g. produced by an
        # upstream autocast-ed layer) keeps its dtype. Upcast it explicitly so
        # that the FFT really runs in float32.
        reduced_precision = q.dtype in (torch.float16, torch.bfloat16)

        # Guard against autocast / fp16, not supported by torch.fft.fft2
        with autocast(enabled=False):
            fft_input = q.float() if reduced_precision else q
            att = torch.fft.fft2(fft_input).real

        if reduced_precision:
            # Hand back the dtype the caller gave us, so that downstream
            # layers see the same dtype as before the upcast.
            att = att.to(q.dtype)

        att = self.attn_drop(att)

        return att