You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
# hidden_state 임시로 지정hidden_states=torch.arange(4*12*768).view(4,12,768).contiguous() # b,s,d = 4,12,768# 0~36863 값이 (4,12,768) = (batch(b), seqlen(s), h_dim(d))로 변환hidden_states=hidden_states.transpose(0, 1) # (s,b,d) = (12,4,768)q,k,v=hidden_states ,hidden_states ,hidden_states# 이해를 위한 임시 지정seq_len, batch_size, embed_dim=12,4,768query_vectors=query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) #(s,b,d) => (s,b,h,w) => (b,s,h,w)key_vectors=key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) #(s,b,d) => (s,b,h,w) => (b,s,h,w)# 아래 코드가 실행되어 sliding window를 활용하여 local attention 값을 구함attn_scores=self._sliding_chunks_query_key_matmul(query_vectors, key_vectors, self.one_sided_attn_window_size)
self._sliding_chunks_query_key_matmul(query_vectors, key_vectors, self.one_sided_attn_window_size) 분석
b, s, h, w=query.size() # (b,s,h,w) = (4,12,4,192)chunks_count=s//window_overlap-1# window_overlap =one_sided_attn_window_size = 2 이므로, 12 // 2 -1 = 5# group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2query=query.transpose(1, 2).reshape(b*h, s, d) #(b,s,h,d) => (b,h,s,d) => (b*h,s,d)key=key.transpose(1, 2).reshape(b*h, s, d) #(b,s,h,d) => (b,h,s,d) => (b*h,s,d)#(b,h,s,d) => 각 batch별로, head에 따라 seuqence의 dim 정보가 나뉘고# (b*h,s,d) => b*h로 하여 차원을 하나 축소를 해줌, 이후 sequence를 chunk로 나누기 위한 과정# 따라서, batch별 head에 따른 sequence의 dim 정보가 순서대로 출력이 됨.# b[0]*h[0]에는, 첫 번째 batch에서 각 sequence의 1번째 head에 해당하는 정보를 가지고 있음.# b[0]*h[1]에는, 첫 번째 batch에서 각 sequence의 2번째 head에 해당하는 정보를 가지고 있음.# b[n]*h[k]에는, n번째 batch에서 각 sequence의 k번째 head에 해당하는 정보를 가지고 있음.query=self._chunk(query, window_overlap)
key=self._chunk(key, window_overlap)
_sliding_chunks_query_key_matmul 함수 내부의 _chunk(query, window_overlap)분석
"""convert into overlapping chunks. Chunk size = 2w, overlap size = w"""# non-overlapping chunks of size = 2w# hidden_states.size() = (b*h,s,w) = (16,12,192)#우선 b*h별 각 sequence를 overlapping되지 않게 window size로 재생성hidden_states=hidden_states.view( # (b*h, s,w) =>(b*h, s//window_size, window_size, w) => (16,12,192)=>(16,3,4,192)hidden_states.size(0), # 16hidden_states.size(1) // (window_overlap*2), #12 //(2*2) = 3window_overlap*2, # (2*2) = 4hidden_states.size(2), # 192
)
# use `as_strided` to make the chunks overlap with an overlap size = window_overlapchunk_size=list(hidden_states.size()) # (b*h, s//window_size, window_size, w) = (16,3,4,192)chunk_size[1] =chunk_size[1] *2-1#3=> 3*2-1 = 5 => (16,5,4,192)#이전에 overlapping 되지 않도록 변경 후, window_size의 절반에 해당하는 길이만큼 오버랩을 시켰을 때# 변경되는 형태의 해당 chunk_size를 계산# 아래와 같이 되도록 변경해야 함# [1,2,3,4],[5,6,7,8],[9,10,11,12] => [1,2,3,4],[3,4,5,6],[5,6,7,8],[7,8,9,10],[9,10,11,12]chunk_stride=list(hidden_states.stride()) #2304,768,192,1)chunk_stride[1] =chunk_stride[1] //2#원하는 형태로 변경을 위해서 기존의 hidden_states를 어떻게 변경할 지 지정# chunk_stride[1] = 384//2 = 384# chunk_stride = (2304,384,192,1)# stride는 각 dim별로 다음 값으로 이동하기 위한 거리를 출력해줌# sstride 값 계산은, 우선 storage(?) 형태로 tensor 값을 1차원으로 변경# 이 때, tensor (-1) dim으로 한 칸 이동하기 위해선 오른쪽으로 한 칸만 이동하면 됨# 반면 tensor (-2) dim으로 한 칸 이동하기 위해선 (-1) dim을 모두 이동해야 하므로 (-1) dim 크기만큼 됨# 이런 방식으로 전체 stride가 계산됨.returnhidden_states.as_strided(size=chunk_size, stride=chunk_stride)
# as_strided는 tensor.as_stride(...) 형태로 쓰이며, 해석하자면, tensor를 chunk_size로 변경하는데, 이 때 chunk_stride의 stride로 이동하며 값을 취득# [1,2,3,4],[5,6,7,8],[9,10,11,12] => [1,2,3,4],[3,4,5,6],[5,6,7,8],[7,8,9,10],[9,10,11,12] 형태로 변경됨
다시 _sliding_chunks_query_key_matmul() 함수 분석
# matrix multiplication# bcxd: (batch_size * num_head)s x (chunks) x (2window_overlap) x (head_dim)# bcyd: (batch_size * num_heads) x (chunks) x (2window_overlap) x (head_dim)# bcxy: (batch_size * num_head)s x (chunks) x (2window_overlap) x (2window_overlap)diagonal_chunked_attention_scores=torch.einsum("bcxd,bcyd->bcxy", (query, key)) # multiply# 우리가 알던, query,key matmul이며 matmul 대신 einsum() 함수를 활용# 단 matmul 계산 시 key 또는 query를 Transpose 하여 계산했지만, enisum은 현재 상태에서 추가 가공 없이 연산이 가능함# (16,5,4,192) * (16,5,4,192) => (16,5,4,4)# convert diagonals into columnsdiagonal_chunked_attention_scores=self._pad_and_transpose_last_two_dims(
diagonal_chunked_attention_scores, padding=(0, 0, 0, 1)
)
self._pad_and_transpose_last_two_dims(diagonal_chunked_attention_scores, padding=(0, 0, 0, 1)) 분석
"""pads rows and then flips rows and columns"""hidden_states_padded=nn.functional.pad(
hidden_states_padded, padding) # padding value is not important because it will be overwritten# torch.nn.functional.pad(diagonal_chunked_attention_scores, (0, 0, 0, 1))# diagonal_chunked_attention_scores에서, 마지막 dimension에 padding 추가# ex## tmp = [[1,2,3],[4,5,6],[7,8,9]])## torch.nn.functional.pad(tmp,(0,1)) => [[1,2,3],[4,5,6],[7,8,9],[0,0,0]]hidden_states_padded=hidden_states_padded.view(
*hidden_states_padded.size()[:-2], hidden_states_padded.size(-1), hidden_states_padded.size(-2))
# upper trianguler matrix 정보를 가져오기 위하여 변환#[[1,2,3],[4,5,6]]에서 upper trianguler matrix 값은 [1,2,3] 과 [5,6]],[9]# 따라서 해당 변환 시 [[1,2,3],[4,5,6],[0,0,0]] => [1,2,3,4],[5,6,7,8],[9,0,0,0]returnhidden_states_padded# (16,5,4,4) => (16,5,5,4) => (16,5,4,5) #(batch*head, chunk, window, side_window_overlap + token, side_window_overlap )## _pad_and_transpose_last_two_dims의 출력으로 나오는 tensor에 최종적으로 담을 정보는 각 token별 token 양측의 window/2에 대한 정보를 담는 것 ##
다시 _sliding_chunks_query_key_matmul 분석
diagonal_attention_scores=diagonal_chunked_attention_scores.new_empty(
(batch_size*num_heads, chunks_count+1, window_overlap, window_overlap*2+1)
)
# 값을 저장하기 위한 임시 공간 생성# b*h = 16# chunks_count + 1 = 6# window_overlap = 2# window_overlap * 2 + 1 = 5# batch *head별 sequence length가 12이므로, 이 때 2개의 token씩 overlapping되면 총 5개의 chunk가 생성됨## e.g.) seq = [1,2,3,4,5,6,7,8,9,10,11,12]## chunk = [1,2,3,4,],3,4,5,6,],[5,6,7,8,],[7,8,9,10],[9,10,11,12]# 이 때, 중복되는 2개씩 따로 생성할 경우, set(중복개수) 시 총 6개로 나뉨## 1,2 / 3,4, / 5,6 / 7,8 / ,9,10 / 11,12# 그리고 window size가 4일 경우, 기준 토큰 좌우 2개씩 보므로 총 5개의 값을 저정해야 해서 마지막 dim은 5로 됨# copy parts from diagonal_chunked_attention_scores into the combined matrix of attentions# 아래 코드를 통하여 기준 token 기준 좌/우 window_size/2만큼의 attention 정보를 저장# - copying the main diagonal and the upper trianglediagonal_attention_scores[:, :-1, :, window_overlap:] =diagonal_chunked_attention_scores[
:, :, :window_overlap, : window_overlap+1
]
diagonal_attention_scores[:, -1, :, window_overlap:] =diagonal_chunked_attention_scores[
:, -1, window_overlap:, : window_overlap+1
]
# - copying the lower trianglediagonal_attention_scores[:, 1:, :, :window_overlap] =diagonal_chunked_attention_scores[
:, :, -(window_overlap+1) : -1, window_overlap+1 :
]
diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] =diagonal_chunked_attention_scores[
:, 0, : window_overlap-1, 1-window_overlap :
]
# separate batch_size and num_heads dimensions againdiagonal_attention_scores=diagonal_attention_scores.view(
batch_size, num_heads, seq_len, 2*window_overlap+1
).transpose(2, 1)
#(b*h, chunks_count + 1, window_overlap, window_overlap*2+1), (16,6,2,5) # => (b,h,s,window_overlap*2+1), (4,4,12,5)# => (b,s,h,window_overlap*2+1) (4,12,4,5)self._mask_invalid_locations(diagonal_attention_scores, window_overlap)
returndiagonal_attention_scores
self._mask_invalid_locations(diagonal_attention_scores, window_overlap) 분석
@staticmethoddef_mask_invalid_locations(input_tensor, affected_seq_len) ->torch.Tensor:
# token별로 window_size//2만큼 좌우측 token들과의 attention을 계산함# 첫 번째 토큰과, 마지막 토큰의 경우 가장자리에 위치하여 편측 정보만을 활용해야 함# 단, 문제는 현재까지 구현된 코드로는 sequence를 벗어나는 값이 random하게 지정되어 있음# 이 부분을 해결하고자 해당 함수를 사용함#window가 sequence를 벗어나는 경우에 해당하는 sparse matrix 생성#ex) affected_seq_len(==windowsize//2) = 2 일때, window가 sequence를 벗어나는 token은 1,2번째 토큰# 토큰의 좌측에 2개의 token과도 self-attention을 해야 하는데 존재하지 않음# 따라서 해당 값을 maksing 해주기 위해 아래와 같이 코드 실행됨beginning_mask_2d=input_tensor.new_ones(affected_seq_len, affected_seq_len+1).tril().flip(dims=[0]) #(w//2, w//2+1), reversed lower triangle matrix [[1,1,0][1,0,0]]# attention과 동일한 차원을 갖도록 변경 beginning_mask=beginning_mask_2d[None, :, None, :] #(1,w//2, 1, w//2+1)# 마지막 토큰의 경우도 동일하므로, beginning_mask 를 flipending_mask=beginning_mask.flip(dims=(1, 3))
# input_tensor은 위에서 구한 attention tensor이며, masking이 필요한 구간만 선택beginning_input=input_tensor[:, :affected_seq_len, :, : affected_seq_len+1] #(batch, w//2, head, w//2+1)# mask를 beginning_input 과 동일한 차원이 되도록 expandbeginning_mask=beginning_mask.expand(beginning_input.size())
# beginning_mask 중 원소 값이 1을 갖는 경우, beginning_input의 해당 위치에 -float("inf")로 지정beginning_input.masked_fill_(beginning_mask==1, -float("inf")) # `== 1` converts to bool or uint8# end도 동일한 방식으로 적용ending_input=input_tensor[:, -affected_seq_len:, :, -(affected_seq_len+1) :]
ending_mask=ending_mask.expand(ending_input.size())
ending_input.masked_fill_(ending_mask==1, -float("inf")) # `== 1` converts to bool or uint8
self._sliding_chunks_query_key_matmul() 함수가 끝났으므로 다시 forward 분석!
attn_scores=_sliding_chunks_query_key_matmul(...) #이전 과정을 통해 구한 attn_scores# values to pad for attention probsremove_from_windowed_attention_mask= (attention_mask!=0)[:, :, None, None]
# max_seq_length보다 짧아서 pad가 추가된 부분을 masking# cast to fp32/fp16 then replace 1's with -inf# remove_from_windowed_attention_mask에서 값이 True인 경우 -10000으로 변환float_mask=remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
remove_from_windowed_attention_mask, -10000.0)
# diagonal mask with zeros everywhere and -inf inplace of padding# 앞서 확인한 _sliding_chunks_query_key_matmul을 통하여 mask 생성diagonal_mask=self._sliding_chunks_query_key_matmul(
float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size
)
# pad local attention probsattn_scores+=diagonal_mask#기존 attention에 mask 추가!assertlist(attn_scores.size()) == [
batch_size,
seq_len,
self.num_heads,
self.one_sided_attn_window_size*2+1,
], f"local_attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size*2+1}), but is of size {attn_scores.size()}"
우선 여기까지 구했으면, sliding window를 활용한 local attention은 계산 완료
그리고 다음 코드 부턴 global attntion!!
# compute local attention probs from global attention keys and contact over window dimifis_global_attn: # is_global_attn 있을 경우만 진행# compute global attn indices required through out forward fn
(
max_num_global_attn_indices,
is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
) =self._get_global_attn_indices(is_index_global_attn) #해당 코드는 어렵지 않으니 직접 확인해볼 것# calculate global attn probs from global keyglobal_key_attn_scores=self._concat_with_global_key_attn_probs( #global attention idx를 기반으로 attention을 계산하여 반환query_vectors=query_vectors,
key_vectors=key_vectors,
max_num_global_attn_indices=max_num_global_attn_indices,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
)
# concat to local_attn_probs# (batch_size, seq_len, num_heads, extra attention count + 2*window+1)attn_scores=torch.cat((global_key_attn_scores, attn_scores), dim=-1) #global과 local attention을 합침# free memorydelglobal_key_attn_scores#리소스 관리를 위한 삭제#attention weight 계산attn_probs=nn.functional.softmax(
attn_scores, dim=-1, dtype=torch.float32
) # use fp32 for numerical stabilityiflayer_head_maskisnotNone:
assertlayer_head_mask.size() == ( #head 자체에 maskingself.num_heads,
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"attn_probs=layer_head_mask.view(1, 1, -1, 1) *attn_probs# softmax sometimes inserts NaN if all positions are masked, replace them with 0attn_probs=torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
attn_probs=attn_probs.type_as(attn_scores)
# free memorydelattn_scores# apply dropoutattn_probs=nn.functional.dropout(attn_probs, p=self.dropout, training=self.training)
value_vectors=value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
# compute local attention output with global attention value and add# 로컬 어텐션 아웃풋 계산ifis_global_attn:
# compute sum of global and local attnattn_output=self._compute_attn_output_with_global_indices(
value_vectors=value_vectors,
attn_probs=attn_probs,
max_num_global_attn_indices=max_num_global_attn_indices,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
)
else:
# compute local attn onlyattn_output=self._sliding_chunks_matmul_attn_probs_value(
attn_probs, value_vectors, self.one_sided_attn_window_size
)
assertattn_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size"attn_output=attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous()
# compute value for global attention and overwrite to attention output# TODO: remove the redundant computation#글로벌 어텐션 아웃풋 계산ifis_global_attn:
global_attn_output, global_attn_probs=self._compute_global_attn_output_from_hidden(
hidden_states=hidden_states,
max_num_global_attn_indices=max_num_global_attn_indices,
layer_head_mask=layer_head_mask,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
is_index_masked=is_index_masked,
)
# get only non zero global attn outputnonzero_global_attn_output=global_attn_output[
is_local_index_global_attn_nonzero[0], :, is_local_index_global_attn_nonzero[1]
]
# overwrite values with global attentionattn_output[is_index_global_attn_nonzero[::-1]] =nonzero_global_attn_output.view(
len(is_local_index_global_attn_nonzero[0]), -1
)
# The attention weights for tokens with global attention are# just filler values, they were never used to compute the output.# Fill with 0 now, the correct values are in 'global_attn_probs'.attn_probs[is_index_global_attn_nonzero] =0outputs= (attn_output.transpose(0, 1),)
ifoutput_attentions:
outputs+= (attn_probs,)
returnoutputs+ (global_attn_probs,) if (is_global_attnandoutput_attentions) elseoutputs
여기까지 하면 LongformerSelfAttention 코드 분석은 끝입니다.
초반 코드에 비해 밑으로 갈 수록 코드에 대한 분석이 줄어드는데, 이 부분들은 어렵지 않다고 판단하여 작성만 하고 넘어갔습니다.
설명이 부족한 코드는 설명보단 눈으로 한 번 확인하면 쉽게 확인하실 수 있으실 겁니다.
조금 난잡하게 쓰긴 했지만, 조금이라도 도움이 됐으면 합니다.
The text was updated successfully, but these errors were encountered:
Longformer: The Long-Document Transformer paper view paper veview
transformers.models.longformer.modeling_longformer.py
def forward(...) 함수 실행 시
self._sliding_chunks_query_key_matmul(query_vectors, key_vectors, self.one_sided_attn_window_size) 분석
_sliding_chunks_query_key_matmul 함수 내부의 _chunk(query, window_overlap)분석
다시 _sliding_chunks_query_key_matmul() 함수 분석
self._pad_and_transpose_last_two_dims(diagonal_chunked_attention_scores, padding=(0, 0, 0, 1)) 분석
다시 _sliding_chunks_query_key_matmul 분석
self._mask_invalid_locations(diagonal_attention_scores, window_overlap) 분석
self._sliding_chunks_query_key_matmul() 함수가 끝났으므로 다시 forward 분석!
우선 여기까지 구했으면, sliding window를 활용한 local attention은 계산 완료
그리고 다음 코드 부턴 global attntion!!
여기까지 하면 LongformerSelfAttention 코드 분석은 끝입니다.
초반 코드에 비해 밑으로 갈 수록 코드에 대한 분석이 줄어드는데, 이 부분들은 어렵지 않다고 판단하여 작성만 하고 넘어갔습니다.
설명이 부족한 코드는 설명보단 눈으로 한 번 확인하면 쉽게 확인하실 수 있으실 겁니다.
조금 난잡하게 쓰긴 했지만, 조금이라도 도움이 됐으면 합니다.
The text was updated successfully, but these errors were encountered: