Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LongformerBart 모델 만들기 및 TAPT 적용하기 #38

Open
changyong93 opened this issue Dec 16, 2021 · 0 comments
Open

LongformerBart 모델 만들기 및 TAPT 적용하기 #38

changyong93 opened this issue Dec 16, 2021 · 0 comments
Assignees
Labels
report Sharing information or results of analysis

Comments

@changyong93
Copy link
Contributor

changyong93 commented Dec 16, 2021

1. 개요


  • 생성 요약 태스크에서 단일 모델 기준 성능이 좋은 Bart를 decoder와,
    최대 문서 길이가 약 2000정도인 입력 정보를 받기 위한 longformer의 encoder를 결합한 모델을 생성
  • longformer의 sliding window attention과 global attention을 통하여 input에 대한 local 및 global 정보를 학습
  • 또한, 입력 문서가 논문/뉴스/사설잡지로 나뉘므로, 문서 타입에 따른 문체 및 제목 유형을 학습하기 위해 document embedding을 추가

2. Model 흐름

LongformerBartWithDoctypeForConditionalGeneration class/
├──LongformerBartEncoderWithDocType class
│ ├── doc_type_shared(=document type embedding) class
│ ├── LongformerBartEncoderLayer class
├──BartDecoder class
│ ├── BartDecoderLayer class

3. modeling_longformerbart.py 구조

  1. LongformerBartConfig

    • 기존 BartConfig에 Longformer Attention을 위한 요소들을 추가 기입하였습니다.
      class LongformerBartConfig(BartConfig):
          def __init__(self,
              attention_window:List[int] = [512]*6,
              attention_dropout:float = 0.1,
              doc_type_size:int = 4,
              architectures:str = 'LongformerBartConditionalGeneration',
              max_position_embeddings:int = 2048, 
              max_target_positions:int = 1026, 
              encoder_layers:int = 6,
              decoder_layers:int = 6,
              encoder_attention_heads:int = 16,
              decoder_attention_heads:int = 16,
              **kwargs):
  2. LongformerSelfAttentionForBart

    • 모델 구조 : Longformer Self Attention 구조를 활용하기 위해서 만든 함수입니다.
    • Encoder에서만 Longformer Self Attention을 활용하기 때문에 EncoderLayer안에 들어가서 동작하게 됩니다.
    class LongformerSelfAttentionForBart(nn.Module):
        def __init__(self, config:LongformerBartConfig, layer_id:int):
            super().__init__()
            self.d_model = config.d_model
            config.num_attention_heads = config.encoder_attention_heads
            config.attention_probs_dropout_prob = config.attention_dropout
            self.longformer_self_attn_layer = LongformerSelfAttention(config,layer_id=layer_id)
            self.outputs = nn.Linear(self.d_model, self.d_model)
  3. LongformerBartEncoderLayer

    • 모델 구조
      • 위에서 만든 LongformerSelfAttentionForBart를 활용해서 attention을 하고 나머지는 기존 BartEncoderLayer와 동일하게 구성됩니다.
      • 이 Layer이 Config.encoder_layer에 정해진 수만큼 생성이 되어서 LongformerBartEncoderWithDocType에 들어가게 됩니다.
    def forward(
        self,
        hidden_states: torch.Tensor,          # (batch_size, seq_size, hidden_size)
        attention_mask: torch.Tensor,         # (batch_size, seq_size)
        layer_head_mask: torch.Tensor, 
        is_index_masked: torch.Tensor,        # (batch_size, seq_size)
        is_index_global_attn: torch.Tensor,   # (batch_size, seq_size)
        is_global_attn:torch.Tensor,          # (1,)
        output_attentions: bool = False,
    ):
    
        residual = hidden_states
        hidden_states, output_attn , global_attn = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            is_index_masked=is_index_masked,           
            is_index_global_attn=is_index_global_attn,  
            is_global_attn=is_global_attn,   
            output_attentions=output_attentions,
        )
    • 입력 인자

      1. hidden_states : input_ids, doc_type_ids, potition_ids을 사용해서 생성된 embedded 된 입력입니다.
      2. attention_mask : attention_mask를 활용해서 padding된 부분을 attention을 할 때 영향을 끼치지 못하도록 하는 입력입니다.
      3. is_index_masked : attention이 적용되는 인자들의 위치가 masking된 입력입니다.
      4. is_index_global_attention : global attention이 적용되는 인자들의 위치가 masking된 입력입니다.
      5. is_global_attn : global attention이 적용될지에 대한 여부를 boolean torch tensor로 입력 받습니다.
    • 동작 과정

      1. residual_connection을 위해서 기존의 hidden_states를 따로 저장합니다.
      2. 기존의 hidden_states를 사용해서 LongformerSelfAttention을 통해서 출력값을 구합니다.
      3. 구한 hidden_states와 기존의 residual을 사용해서 residual_connection을 합니다.
      4. output_attn, global_attn은 output_attention이 True일 때만 출력합니다.
  4. LongformerBartEncoderWithDocType

    • 입력받는 인자는 기존과 같습니다.
    • 단, config에는 doc_type_size가 포함되어 있으며, 해당 정보를 바탕으로 doc_type_tokens로 명명한 embedding을 생성합니다.
     class LongformerBartEncoderWithDocType(BartPretrainedModel):
         def __init__(self,
               config: LongformerBartConfig,
               embed_tokens: Optional[nn.Embedding] = None):
               super().__init__(config)
               self.doc_type_tokens = nn.Embedding(config.doc_type_size, self.embed_dim, 0)
    • forward 부분에서 중요한 부분은 아래 두 내용입니다.
      • global attention으로 사용할 token ids 지정

        • encoder는 longformer의 encoder로, sliding window 방식의 local attention과 기존에 알던 (global) attention을 사용합니다. 단, global attention은 local attention에 비해 매우 작은, 예를 들어 특정 special token 들을 지정하여 사용합니다. 하지만 코드 상에서 따로 global attention을 지정하지 않을 경우, padding idx가 아닌 모든 tokens을 선택하므로 이 부분을 지정해줘야 합니다.
        • 저희는 이 부분을, text의 시작과 끝인 bos_token/eos_token을 global attention으로 지정하기로 하여 저희가 만든 get_is_index_global_attn() 함수를 통해 해당 special token이 위치한 index 위치 정보를 생성합니다. 이후 이 값을 encoder의 self-attention에 전달하여 global attention을 수행합니다.
          is_index_masked = attention_mask < 0
              device = attention_mask.device
              is_index_global_attn = self.get_is_index_global_attn(attention_mask).to(device)
              is_global_attn = is_index_global_attn.flatten().any().item()
      • hidden states with doc_type_ids

        • 기존의 hidden states에 doc_type_ids를 더합니다.
        • 여기서, doc_type_ids는 논문/뉴스/설설잡지가 각각이 정수인코딩을 통해 1/2/3로 변환되어, pad token을 제외한 나머지 토큰 idx 위치에 해당 값들이 입력되어 있습니다.
        doc_type = self.doc_type_tokens(doc_type_ids)
        hidden_states = inputs_embeds + embed_pos + doc_type
  5. LongformerBartWithDoctypeForConditionalGeneration

    • 기본적으로 ***ModelForConditionalGeneration은 Model class를 init에서 생성하지만, 저희 모델은 encoder와 decoder를 받습니다.

    • Scheduled Sampling for Transformers 논문에 설명한 transformer 계열의 teacher forcing 방법은 2-stage decoder가 있는데, feater forcing을 사용할 경우 첫 번째 stage는 동작하지 않고 두 번째 stage만 동작합니다. 반면 이전 time step의 output을 현재 time step의 input으로 넣는 경우, 첫 번째 stage의 decoder output을 이전 step의 output으로 보고 해당 값을 두 번째 stage의 decoder input으로 사용합니다.

    • 다만, 이 경우 학습 step과 무관하게 teacher forcing이 랜덤하게 적용되므로, 저희는 teacher forcing scheduler를 개발하였으며, 전체 학습 step을 기준으로 현재 타임step과 비교하여 선형적으로 teacher forcing 적용 확률을 1->0까지 낮췄습니다.

    • 추가적으로, teacher forcing을 선택적으로 적용할 수 있도록, 만약 모델 config에 전체 학습 step을 입력해주지 않았을 경우 미적용하도록 구현하였습니다.

    • LongformerBartWithDoctypeForConditionalGeneration의 forward()를 참고 부탁드리며, 세미나 때 따로 설명드리도록 하겠습니다.

4. processor.py 구조

  1. 추가 요소

    • gogamza/kobart-base-v1 에서는 tokenizer에서 자동으로 bos, eos 를 붙여주지 않기 때문에 bos, eos를 따로 구해서 tokenized된 input_ids에 붙여주었습니다.
    def add_padding(sample_tokens:List[int],
                    padding:bool,
                    padding_num:int,
                    max_length:Optional[int],
                    bos_token_id:int,
                    eos_token_id:int) -> List:
        sample_tokens_len = len(sample_tokens)
        if len(sample_tokens) > max_length - 2:
            sample_tokens = [bos_token_id] + sample_tokens[:max_length-2] + [eos_token_id]
        else:
            sample_tokens = [bos_token_id] + sample_tokens + [eos_token_id]
            if padding:
                sample_tokens = sample_tokens + [padding_num]*(max_length-sample_tokens_len-2)
        return sample_tokens
    • doc_type_ids는 padding 된 부분에는 0을 기입하고 나머지는 각 문서의 doc_type에 맞는 수를 저장하였습니다.
    def get_doc_type_ids(sample_tokens:List[int],
                         doc_type_id:int) -> List:
        sample_tokens = np.array(sample_tokens)
        doc_type_id_list = list(np.where(sample_tokens == 1, doc_type_id, 0))
        return doc_type_id_list

5. data_collator.py

  1. 개요

    1. Bart 논문에 따르면 Bart 모델을 Pretraining 할 때 다양한 방법으로 입력을 변형하였고 그 입력을 변형한 방법 중에서는 mask infilling 방식이 summarization task에서 가장 성능이 좋았습니다.
    2. collator 단에서 mask infilling 방식을 구현하고 변형된 입력을 넣고 출력은 원래의 입력을 가져감으로써 모델을 pretraining 하고자 하였습니다.
  2. Mask Infilling 알고리즘

    1. Sequence가 주어지면 Sequence의 최대 15% 까지 선택이 될 때 까지 Sequence span을 선택해야 합니다.
    2. Span의 길이를 선택하는 방식은 poisson 분포를 따릅니다. (lambda = 3 => 확률분포의 평균/분산이 3/3)
    3. Span의 길이의 시작지점은 Random하게 선택을 합니다.
    4. Span의 시작점 부터 마지막 지점까지 masking을 합니다.
    5. masking이후 special token mask는 masking이 되면 안되기 때문에 이 부분은 제외합니다.
    6. masking이 연속된다면 이를 하나의 mask token으로 바꾸고 삭제된 길이는 pad_token_id로 채웁니다.

6 trainer with Noam scheduler

  • warmup step에 도달 후 inversesqrt 하게 줄어드는 Noam을 적용하고자 할 경우, 실행 시 --is_noam== True여야 하며, 그 경우 아래 코드가 실행됩니다.
  • 아래 코드는 저희가 구현한 Seq2SeqTrainerWithConditionalDocType의 일부이며, noam scheduler를 생성하기 위해 아래와 같이 overwriting하였습니다.
  • Noam은 warmup까진 linear하게 증가하다가, 위에서 언급한 것과 같이, inverse sqrt하게 줄어듧니다.
  • noam을 사용할 때 주의할 점은, 코드 실행 시 lr은 우리가 도달할 lr값이 아니라, noam의 출력값이 factor로 하고 해당 값과 곱해져서 원하는 lr에 도달할 lr값을 지정해줘야 합니다.
    • ex) noam 출력 0.5, 원하는 peak lr값 10 => 입력할 lr값 20
def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        if not self.args.is_noam:
            super().create_scheduler(num_training_steps, optimizer)
        else:
            if self.lr_scheduler is None:
                self.lr_scheduler = self.get_noam_schedule_with_warmup(
                    optimizer=self.optimizer if optimizer is None else optimizer,
                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                )
            return self.lr_scheduler

    def get_noam_schedule_with_warmup(self, optimizer, num_warmup_steps, last_epoch=-1):
        def lr_lambda(current_step: int):
            return 1 / math.sqrt(self.args.model_config.d_model) * min(1/math.sqrt(current_step+1), (current_step+1) /(num_warmup_steps**(1.5)))
        return LambdaLR(optimizer, lr_lambda, last_epoch)

7. pretrain.py

  1. train.py는 finetuning을 위한 코드였다면 pretrain.py는 pretraining을 위해서 만든 코드입니다.
    1. train.py와 다른 점

      • from pretrained가 아니라 from confing함수를 이용해서 모델을 만듭니다.
      training_args.model_config = config
      def model_init():
          # https://discuss.huggingface.co/t/fixing-the-random-seed-in-the-trainer-does-not-produce-the-same-results-across-runs/3442
          # Producibility parameter initialization
          model = LongformerBartWithDoctypeForConditionalGeneration._from_config(training_args.model_config)
          return model
      • DataCollator를 기존의 Seq2Seq Collator를 쓰는 것이 아니라 TextInfilling을 위한 Collator를 불러옵니다.
      data_collator = DataCollatorForTextInfillingDocType(
          tokenizer,
          label_pad_token_id=label_pad_token_id,
          pad_to_multiple_of=model_args.attention_window_size,
      )
      • pretraining 하는 과정은 시간이 많이 소요되고 중간에 끊기게 되면 처음부터 다시하는 것이 시간 소모가 많이 되기 때문에 이전에 저장된 checkpoint가 있다면 이를 가져와서 학습을 재시작하는 코드를 작성하였습니다.
      if training_args.do_train:
          last_checkpoint = None
          if (
              os.path.isdir(training_args.output_dir)
              and training_args.do_train
              and not training_args.overwrite_output_dir
          ):
      
              last_checkpoint = get_last_checkpoint(training_args.output_dir)
              if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
                  raise ValueError(
                      f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                      "Use --overwrite_output_dir to overcome."
                  )
                  
          train_result = trainer.train(resume_from_checkpoint=last_checkpoint)

8. Shell Script

  1. pretrain.py

    python pretrain.py \
    --do_train \
    --is_pretrain \
    --output_dir checkpoint/longformerbart_pretrain_V2_6to3_Big \
    --num_train_epochs 5 \
    --logging_steps 2000 \
    --save_strategy epoch \
    --evaluation_strategy no \
    --max_source_length 2048 \
    --max_target_length 2048 \
    --project_name longformerbart \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 2 \
    --wandb_unique_tag longformerBart_pretraining_V2_Hidden512 \
    --hidden_size 512 \
    --encoder_layer_size 6 \
    --decoder_layer_size 3 \
    --attention_head_size 4 \
    --attention_window_size 64 \
    --dropout 0.1 \
    --learning_rate 1e-4 \
    --warmup_steps 10000 \
    --weight_decay 1e-4 \
    --adam_beta1  0.9 \
    --adam_beta2  0.98 \
    --adam_epsilon 1e-06 \
    --use_doc_type_ids True \
    
  2. train.py

    • pretraining 할 때 적용했던 model 특성을 그대로 가져와야 합니다.
    python train.py \
    --do_train \
    --output_dir model/baseV1.0_Kobart \
    --use_doc_type_ids \
    --is_noam \
    --use_model longbart \
    --num_train_epochs 3 \
    --learning_rate 3e-05 \
    --max_source_length 2048 \
    --max_target_length 256 \
    --metric_for_best_model rougeLsum \
    --relative_eval_steps 10 \
    --es_patience 3 \
    --load_best_model_at_end True \
    --project_name baseV1.0_Kobart \
    --hidden_size 512 \
    --encoder_layer_size 6 \
    --decoder_layer_size 3 \
    --attention_head_size 4 \
    --attention_window_size 64 \
    --wandb_unique_tag longformerbart_ep3_lr3e05_len2048 
    
@changyong93 changyong93 added the report Sharing information or results of analysis label Dec 16, 2021
@changyong93 changyong93 self-assigned this Dec 16, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
report Sharing information or results of analysis
Projects
None yet
Development

No branches or pull requests

1 participant