Skip to content

kwon13/summarization-rloo

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

36 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Summarization with Reinforcement Learning

Back to Basics: 강화학습을 적용한 요약 모델 개발

개요

이 프로젝트는 RLOO 학습 방법을 사용하여, 강화학습을 통한 요약 모델의 성능 향상을 탐구합니다. 실험 결과, 강화학습 기반의 접근 방식이 요약 모델의 성능을 효과적으로 향상시킬 수 있음을 보여주었습니다. rloo

설치

  1. 레포지토리 클론:

    git clone https://github.com/kwon13/summarization-rloo.git
    cd summarization-rloo
  2. 필요한 패키지 설치:

    pip install -r requirements.txt
  3. 전체 파일 경로:

summarization-rloo
├── CSFT
   ├── src
      ├── data.py 
      ├── utils.py
      └── resource
   └── train.py
├── RLOO
   ├── reward.py
   └── rloo.py
├── inference.py
├── utils.py
├── vllm_inference.py
├── requirements.txt
└── deepspeed_zero3.yaml

실행

  1. CSFT:

    cd CSFT
    CUDA_VISIBLE_DEVICES=1,3 python -m run.train \
        --model_id `학습시킬 모델 경로` \
        --batch_size 1 \
        --gradient_accumulation_steps 64 \
        --epoch 5 \
        --lr 2e-5 \
        --warmup_steps 20
  2. Reward Model:

    cd RLOO
    python3 reward.py --base_model=maywell/EXAONE-3.0-7.8B-Instruct-Llamafied --sft_model_path=fiveflow/exa-base --lr=3e-6 --deepspeed --track --output_dir=models/exaone_reward_model --local_eval_batch_size=1 --seed=44413
  3. RLOO:

    cd RLOO
    accelerate launch --config_file deepspeed_zero3.yaml \
     rloo.py \
     --output_dir models/rloo_tldr_t=0.1_ppo=1 \
     --num_ppo_epochs 1 \
     --num_mini_batches 1 \
     --learning_rate 3e-6 \
     --per_device_train_batch_size 1 \
     --gradient_accumulation_steps 2 \
     --total_episodes 450 \
     --model_name_or_path fiveflow/exa-base \
     --sft_model_path fiveflow/exa-base \
     --reward_model_path exaone_reward_modelv2 \
     --local_rollout_forward_batch_size 1 \
     --non_eos_penalty True \
     --response_length 512 \
     --stop_token eos \
     --temperature 0.1 \
     --rloo_k 2
  4. 모델 추론:

    python3 inference.py --model_path fiveflow/exa_rlo --input_file test.json --output_file output.json

결과

  • RLOO 방식은 G-Eval Metric 기준으로 Base 모델 대비 0.12점 향상 (7.62/10)
  • RLOO 방식이 가장 높은 성능을 기록하여, 강화학습의 효과를 입증

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages