Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LongformerSelfAttention code분석 #20

Open
changyong93 opened this issue Dec 2, 2021 · 0 comments
Open

LongformerSelfAttention code분석 #20

changyong93 opened this issue Dec 2, 2021 · 0 comments
Assignees
Labels
report Sharing information or results of analysis

Comments

@changyong93
Copy link
Contributor

changyong93 commented Dec 2, 2021

Longformer: The Long-Document Transformer paper view paper veview


transformers.models.longformer.modeling_longformer.py

# attention_window = config.attention_window[self.layer_id]
num_heads = 4
attention_window = 4 #임시로 attention_window 값 지정
one_sided_attn_window_size = attention_window // 2 # 2

def forward(...) 함수 실행 시

# hidden_state 임시로 지정
hidden_states = torch.arange(4*12*768).view(4,12,768).contiguous() # b,s,d = 4,12,768
# 0~36863 값이 (4,12,768) = (batch(b), seqlen(s), h_dim(d))로 변환

hidden_states = hidden_states.transpose(0, 1) # (s,b,d) = (12,4,768)
q,k,v = hidden_states ,hidden_states ,hidden_states # 이해를 위한 임시 지정
seq_len, batch_size, embed_dim = 12,4,768

query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) #(s,b,d) => (s,b,h,w) => (b,s,h,w)
key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) #(s,b,d) => (s,b,h,w) => (b,s,h,w)

# 아래 코드가 실행되어 sliding window를 활용하여 local attention 값을 구함
attn_scores = self._sliding_chunks_query_key_matmul(query_vectors, key_vectors, self.one_sided_attn_window_size)

self._sliding_chunks_query_key_matmul(query_vectors, key_vectors, self.one_sided_attn_window_size) 분석

b, s, h, w= query.size() # (b,s,h,w) = (4,12,4,192)
chunks_count = s// window_overlap - 1
# window_overlap =one_sided_attn_window_size = 2 이므로, 12 // 2 -1 = 5

# group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2
query = query.transpose(1, 2).reshape(b* h, s, d) #(b,s,h,d) => (b,h,s,d) => (b*h,s,d)
key = key.transpose(1, 2).reshape(b* h, s, d) #(b,s,h,d) => (b,h,s,d) => (b*h,s,d)
#(b,h,s,d) => 각 batch별로, head에 따라 seuqence의 dim 정보가 나뉘고
# (b*h,s,d) => b*h로 하여 차원을 하나 축소를 해줌, 이후 sequence를 chunk로 나누기 위한 과정
# 따라서, batch별 head에 따른 sequence의 dim 정보가 순서대로 출력이 됨.
# b[0]*h[0]에는, 첫 번째 batch에서 각 sequence의 1번째 head에 해당하는 정보를 가지고 있음.
# b[0]*h[1]에는, 첫 번째 batch에서 각 sequence의 2번째 head에 해당하는 정보를 가지고 있음.
# b[n]*h[k]에는, n번째 batch에서 각 sequence의 k번째 head에 해당하는 정보를 가지고 있음.

query = self._chunk(query, window_overlap)
key = self._chunk(key, window_overlap)

_sliding_chunks_query_key_matmul 함수 내부의 _chunk(query, window_overlap)분석

"""convert into overlapping chunks. Chunk size = 2w, overlap size = w"""

# non-overlapping chunks of size = 2w
# hidden_states.size() = (b*h,s,w) = (16,12,192)
#우선 b*h별 각 sequence를 overlapping되지 않게 window size로 재생성
hidden_states = hidden_states.view( # (b*h, s,w) =>(b*h, s//window_size, window_size, w) => (16,12,192)=>(16,3,4,192)
    hidden_states.size(0), # 16
    hidden_states.size(1) // (window_overlap * 2), #12 //(2*2) = 3
    window_overlap * 2, # (2*2) = 4
    hidden_states.size(2), # 192
)

# use `as_strided` to make the chunks overlap with an overlap size = window_overlap
chunk_size = list(hidden_states.size()) # (b*h, s//window_size, window_size, w) = (16,3,4,192)
chunk_size[1] = chunk_size[1] * 2 - 1 #3=> 3*2-1 = 5 => (16,5,4,192)
#이전에 overlapping 되지 않도록 변경 후, window_size의 절반에 해당하는 길이만큼 오버랩을 시켰을 때
# 변경되는 형태의 해당 chunk_size를 계산
# 아래와 같이 되도록 변경해야 함
# [1,2,3,4],[5,6,7,8],[9,10,11,12] => [1,2,3,4],[3,4,5,6],[5,6,7,8],[7,8,9,10],[9,10,11,12]

chunk_stride = list(hidden_states.stride()) #2304,768,192,1)
chunk_stride[1] = chunk_stride[1] // 2
#원하는 형태로 변경을 위해서 기존의 hidden_states를 어떻게 변경할 지 지정
# chunk_stride[1] = 384//2 = 384
# chunk_stride = (2304,384,192,1)
# stride는 각 dim별로 다음 값으로 이동하기 위한 거리를 출력해줌
# sstride 값 계산은, 우선 storage(?) 형태로 tensor 값을 1차원으로 변경
# 이 때, tensor (-1) dim으로 한 칸 이동하기 위해선 오른쪽으로 한 칸만 이동하면 됨
# 반면 tensor (-2) dim으로 한 칸 이동하기 위해선 (-1) dim을 모두 이동해야 하므로 (-1) dim 크기만큼 됨
# 이런 방식으로 전체 stride가 계산됨.

return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
# as_strided는 tensor.as_stride(...) 형태로 쓰이며, 해석하자면, tensor를 chunk_size로 변경하는데, 이 때 chunk_stride의 stride로 이동하며 값을 취득
# [1,2,3,4],[5,6,7,8],[9,10,11,12] => [1,2,3,4],[3,4,5,6],[5,6,7,8],[7,8,9,10],[9,10,11,12] 형태로 변경됨

다시 _sliding_chunks_query_key_matmul() 함수 분석

# matrix multiplication
# bcxd: (batch_size * num_head)s x (chunks) x (2window_overlap) x (head_dim)
# bcyd: (batch_size * num_heads) x (chunks) x (2window_overlap) x (head_dim)
# bcxy: (batch_size * num_head)s x (chunks) x (2window_overlap) x (2window_overlap)
diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key))  # multiply
# 우리가 알던, query,key matmul이며 matmul 대신 einsum() 함수를 활용
# 단 matmul 계산 시 key 또는 query를 Transpose 하여 계산했지만, enisum은 현재 상태에서 추가 가공 없이 연산이 가능함
# (16,5,4,192) * (16,5,4,192) => (16,5,4,4)

# convert diagonals into columns
diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(
    diagonal_chunked_attention_scores, padding=(0, 0, 0, 1)
)

self._pad_and_transpose_last_two_dims(diagonal_chunked_attention_scores, padding=(0, 0, 0, 1)) 분석

"""pads rows and then flips rows and columns"""
hidden_states_padded = nn.functional.pad(
    hidden_states_padded, padding)  # padding value is not important because it will be overwritten
# torch.nn.functional.pad(diagonal_chunked_attention_scores, (0, 0, 0, 1))
# diagonal_chunked_attention_scores에서, 마지막 dimension에 padding 추가
# ex
## tmp = [[1,2,3],[4,5,6],[7,8,9]])
## torch.nn.functional.pad(tmp,(0,1)) => [[1,2,3],[4,5,6],[7,8,9],[0,0,0]]

hidden_states_padded = hidden_states_padded.view( 
    *hidden_states_padded.size()[:-2], hidden_states_padded.size(-1), hidden_states_padded.size(-2))
# upper trianguler matrix 정보를 가져오기 위하여 변환
#[[1,2,3],[4,5,6]]에서 upper trianguler matrix 값은 [1,2,3] 과 [5,6]],[9]
# 따라서 해당 변환 시 [[1,2,3],[4,5,6],[0,0,0]] => [1,2,3,4],[5,6,7,8],[9,0,0,0]
return hidden_states_padded
# (16,5,4,4) => (16,5,5,4) => (16,5,4,5) #(batch*head, chunk, window, side_window_overlap + token, side_window_overlap )
## _pad_and_transpose_last_two_dims의 출력으로 나오는 tensor에 최종적으로 담을 정보는 각 token별 token 양측의 window/2에 대한 정보를 담는 것 
## 

다시 _sliding_chunks_query_key_matmul 분석

diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
    (batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap * 2 + 1)
)
# 값을 저장하기 위한 임시 공간 생성
# b*h = 16
# chunks_count + 1 = 6
# window_overlap = 2
#  window_overlap * 2 + 1 = 5
# batch *head별 sequence length가 12이므로, 이 때 2개의 token씩 overlapping되면 총 5개의 chunk가 생성됨
## e.g.) seq = [1,2,3,4,5,6,7,8,9,10,11,12]
## chunk = [1,2,3,4,],3,4,5,6,],[5,6,7,8,],[7,8,9,10],[9,10,11,12]
# 이 때, 중복되는 2개씩 따로 생성할 경우, set(중복개수) 시 총 6개로 나뉨
## 1,2 / 3,4, / 5,6 / 7,8 / ,9,10 / 11,12
# 그리고 window size가 4일 경우, 기준 토큰 좌우 2개씩 보므로 총 5개의 값을 저정해야 해서 마지막 dim은 5로 됨


# copy parts from diagonal_chunked_attention_scores into the combined matrix of attentions
# 아래 코드를 통하여 기준 token 기준 좌/우 window_size/2만큼의 attention 정보를 저장
# - copying the main diagonal and the upper triangle
diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
    :, :, :window_overlap, : window_overlap + 1
]
diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
    :, -1, window_overlap:, : window_overlap + 1
]
# - copying the lower triangle
diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
    :, :, -(window_overlap + 1) : -1, window_overlap + 1 :
]

diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
    :, 0, : window_overlap - 1, 1 - window_overlap :
]

# separate batch_size and num_heads dimensions again
diagonal_attention_scores = diagonal_attention_scores.view(
    batch_size, num_heads, seq_len, 2 * window_overlap + 1
).transpose(2, 1)
#(b*h, chunks_count + 1, window_overlap, window_overlap*2+1), (16,6,2,5) 
# => (b,h,s,window_overlap*2+1), (4,4,12,5)
# => (b,s,h,window_overlap*2+1) (4,12,4,5)

self._mask_invalid_locations(diagonal_attention_scores, window_overlap)
return diagonal_attention_scores

self._mask_invalid_locations(diagonal_attention_scores, window_overlap) 분석

@staticmethod
def _mask_invalid_locations(input_tensor, affected_seq_len) -> torch.Tensor:
    # token별로 window_size//2만큼 좌우측 token들과의 attention을 계산함
   # 첫 번째 토큰과, 마지막 토큰의 경우 가장자리에 위치하여 편측 정보만을 활용해야 함
   # 단, 문제는 현재까지 구현된 코드로는 sequence를 벗어나는 값이 random하게 지정되어 있음
   # 이 부분을 해결하고자 해당 함수를 사용함

   #window가 sequence를 벗어나는 경우에 해당하는 sparse matrix 생성
   #ex) affected_seq_len(==windowsize//2) = 2 일때, window가 sequence를 벗어나는 token은 1,2번째 토큰
   # 토큰의 좌측에 2개의 token과도 self-attention을 해야 하는데 존재하지 않음
   # 따라서 해당 값을 maksing 해주기 위해 아래와 같이 코드 실행됨
    beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0]) #(w//2, w//2+1), reversed lower triangle matrix [[1,1,0][1,0,0]]
 
    # attention과 동일한 차원을 갖도록 변경 
    beginning_mask = beginning_mask_2d[None, :, None, :] #(1,w//2, 1, w//2+1)
    
    # 마지막 토큰의 경우도 동일하므로, beginning_mask 를 flip
    ending_mask = beginning_mask.flip(dims=(1, 3))

    # input_tensor은 위에서 구한 attention tensor이며, masking이 필요한 구간만 선택
    beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1] #(batch, w//2, head, w//2+1)

    # mask를 beginning_input 과 동일한 차원이 되도록 expand
    beginning_mask = beginning_mask.expand(beginning_input.size())
   
    # beginning_mask 중 원소 값이 1을 갖는 경우, beginning_input의 해당 위치에 -float("inf")로 지정
    beginning_input.masked_fill_(beginning_mask == 1, -float("inf"))  # `== 1` converts to bool or uint8

    # end도 동일한 방식으로 적용
    ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
    ending_mask = ending_mask.expand(ending_input.size())
    ending_input.masked_fill_(ending_mask == 1, -float("inf"))  # `== 1` converts to bool or uint8

self._sliding_chunks_query_key_matmul() 함수가 끝났으므로 다시 forward 분석!

attn_scores = _sliding_chunks_query_key_matmul(...) #이전 과정을 통해 구한 attn_scores

# values to pad for attention probs
remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None]
# max_seq_length보다 짧아서 pad가 추가된 부분을 masking

# cast to fp32/fp16 then replace 1's with -inf
# remove_from_windowed_attention_mask에서 값이 True인 경우 -10000으로 변환
float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
    remove_from_windowed_attention_mask, -10000.0)

# diagonal mask with zeros everywhere and -inf inplace of padding
# 앞서 확인한 _sliding_chunks_query_key_matmul을 통하여 mask 생성
diagonal_mask = self._sliding_chunks_query_key_matmul(
    float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size
)

# pad local attention probs
attn_scores += diagonal_mask #기존 attention에 mask 추가!

assert list(attn_scores.size()) == [
    batch_size,
    seq_len,
    self.num_heads,
    self.one_sided_attn_window_size * 2 + 1,
], f"local_attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}"

우선 여기까지 구했으면, sliding window를 활용한 local attention은 계산 완료
그리고 다음 코드 부턴 global attntion!!

# compute local attention probs from global attention keys and contact over window dim
if is_global_attn: # is_global_attn 있을 경우만 진행
    # compute global attn indices required through out forward fn
    (
        max_num_global_attn_indices,
        is_index_global_attn_nonzero,
        is_local_index_global_attn_nonzero,
        is_local_index_no_global_attn_nonzero,
    ) = self._get_global_attn_indices(is_index_global_attn) #해당 코드는 어렵지 않으니 직접 확인해볼 것
    # calculate global attn probs from global key

    global_key_attn_scores = self._concat_with_global_key_attn_probs( #global attention idx를 기반으로 attention을 계산하여 반환
        query_vectors=query_vectors,
        key_vectors=key_vectors,
        max_num_global_attn_indices=max_num_global_attn_indices,
        is_index_global_attn_nonzero=is_index_global_attn_nonzero,
        is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
        is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
    )
    # concat to local_attn_probs
    # (batch_size, seq_len, num_heads, extra attention count + 2*window+1)
    attn_scores = torch.cat((global_key_attn_scores, attn_scores), dim=-1) #global과 local attention을 합침

    # free memory
    del global_key_attn_scores #리소스 관리를 위한 삭제

#attention weight 계산
attn_probs = nn.functional.softmax(
    attn_scores, dim=-1, dtype=torch.float32
)  # use fp32 for numerical stability

if layer_head_mask is not None:
    assert layer_head_mask.size() == ( #head 자체에 masking
        self.num_heads,
    ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
    attn_probs = layer_head_mask.view(1, 1, -1, 1) * attn_probs

# softmax sometimes inserts NaN if all positions are masked, replace them with 0
attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
attn_probs = attn_probs.type_as(attn_scores)

# free memory
del attn_scores

# apply dropout
attn_probs = nn.functional.dropout(attn_probs, p=self.dropout, training=self.training)

value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)

# compute local attention output with global attention value and add
# 로컬 어텐션 아웃풋 계산
if is_global_attn:
    # compute sum of global and local attn
    attn_output = self._compute_attn_output_with_global_indices(
        value_vectors=value_vectors,
        attn_probs=attn_probs,
        max_num_global_attn_indices=max_num_global_attn_indices,
        is_index_global_attn_nonzero=is_index_global_attn_nonzero,
        is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
    )
else:
    # compute local attn only
    attn_output = self._sliding_chunks_matmul_attn_probs_value(
        attn_probs, value_vectors, self.one_sided_attn_window_size
    )

assert attn_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size"
attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous()

# compute value for global attention and overwrite to attention output
# TODO: remove the redundant computation
#글로벌 어텐션 아웃풋 계산
if is_global_attn:
    global_attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden(
        hidden_states=hidden_states,
        max_num_global_attn_indices=max_num_global_attn_indices,
        layer_head_mask=layer_head_mask,
        is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
        is_index_global_attn_nonzero=is_index_global_attn_nonzero,
        is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
        is_index_masked=is_index_masked,
    )

    # get only non zero global attn output
    nonzero_global_attn_output = global_attn_output[
        is_local_index_global_attn_nonzero[0], :, is_local_index_global_attn_nonzero[1]
    ]

    # overwrite values with global attention
    attn_output[is_index_global_attn_nonzero[::-1]] = nonzero_global_attn_output.view(
        len(is_local_index_global_attn_nonzero[0]), -1
    )
    # The attention weights for tokens with global attention are
    # just filler values, they were never used to compute the output.
    # Fill with 0 now, the correct values are in 'global_attn_probs'.
    attn_probs[is_index_global_attn_nonzero] = 0

outputs = (attn_output.transpose(0, 1),)

if output_attentions:
    outputs += (attn_probs,)

return outputs + (global_attn_probs,) if (is_global_attn and output_attentions) else outputs

여기까지 하면 LongformerSelfAttention 코드 분석은 끝입니다.
초반 코드에 비해 밑으로 갈 수록 코드에 대한 분석이 줄어드는데, 이 부분들은 어렵지 않다고 판단하여 작성만 하고 넘어갔습니다.
설명이 부족한 코드는 설명보단 눈으로 한 번 확인하면 쉽게 확인하실 수 있으실 겁니다.

조금 난잡하게 쓰긴 했지만, 조금이라도 도움이 됐으면 합니다.

@changyong93 changyong93 added the report Sharing information or results of analysis label Dec 2, 2021
@changyong93 changyong93 self-assigned this Dec 2, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
report Sharing information or results of analysis
Projects
None yet
Development

No branches or pull requests

1 participant