Checkpoint Issues #12

Open
vibhatha opened this issue Mar 3, 2020 · 11 comments
@vibhatha

vibhatha commented Mar 3, 2020

I tried the 'never' option for checkpointing. The idea was to see how the pipeline was performing without checkpointing overhead.

What I observed is that performance is consistent across pipeline parallelism 2, 4 and 8. Another important observation is that performance is much lower than with checkpointing enabled.

Is this expected, or are there other tuning parameters that would give better performance?

I checked the backward time and the forward-to-backward time ratio.

Assuming backward time increases with checkpointing, is that reasoning valid for your implementation?
That is, should pipeline performance improve when I turn checkpointing off?

Could you clarify the implementation details on this?

@sublee
Contributor

sublee commented Mar 4, 2020

By design, checkpointing trades speed for memory. It slows down the backward pass in exchange for much more memory capacity: activations are discarded during the forward pass and recomputed during the backward pass.

> What I observed was the performance is consistent for pipeline parallelism 2, 4 and 8. And also another important observation was the performance is much lower than the performance with checkpointing.

If by "performance" you mean "speed", your second observation is unexpected. torchgpipe without checkpointing is identical to typical pipeline parallelism, but it is not GPipe. If you choose the same chunk size in both settings, the concurrency should not decrease. How did you choose the batch size and the number of chunks in the checkpoint='never' and checkpoint='except_last' settings?
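
The reasoning above can be illustrated with an idealized cost model. This is only a sketch under stated assumptions: the function names and the unit costs are hypothetical and not part of torchgpipe, every stage and micro-batch is assumed to cost the same, and the model ignores that 'except_last' skips recomputation for the last micro-batch.

```python
def pipeline_step_time(stages: int, chunks: int, fwd: float, bwd: float,
                       recompute: bool) -> float:
    """Idealized time for one mini-batch through a GPipe-style pipeline.

    fwd and bwd are hypothetical per-micro-batch, per-stage costs.
    """
    # Filling and draining the pipeline takes (chunks + stages - 1) slots
    # in each direction.
    slots = chunks + stages - 1
    # With checkpointing, each micro-batch's forward pass is run again
    # right before its backward pass.
    backward_cost = bwd + (fwd if recompute else 0.0)
    return slots * fwd + slots * backward_cost


def bubble_fraction(stages: int, chunks: int) -> float:
    """Share of pipeline slots in which a stage sits idle."""
    return (stages - 1) / (chunks + stages - 1)
```

In this model the bubble fraction depends only on the number of stages and chunks, not on the checkpoint mode, so with identical batch size and chunks the run without checkpointing should be no slower than the run with it.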

@sublee sublee self-assigned this Mar 4, 2020
@sublee sublee added the question Further information is requested label Mar 4, 2020
@vibhatha
Author

vibhatha commented Mar 9, 2020

For this, I chose a batch size of 60 with 480 data points.
That was the largest batch size with which I could fit the model into memory.

Then I set the checkpoint option to never or except_last.

I also added a few command-line arguments to make this convenient.

#!/bin/bash
# Arguments: $1 = pipeline id, $2 = checkpoint mode (never or except_last)
id=$1
chk=$2
dataset_size=480
epochs=10
exp_type=pipeline-${id}
version=6_checkpoint_${chk}_chunk_variation
batch_size=240

# Bash splits the list on whitespace, so the values must not be comma-separated.
for chunk_size in 10 20 40 60 120
do
   # Print the command before running it.
   echo "python3 main-micro.py ${exp_type} --batch_size ${batch_size} --chunks ${chunk_size} --dataset_size ${dataset_size} --save_file stats_micro_${exp_type}_v${version}.csv --epochs ${epochs} --checkpointing ${chk}"
   python3 main-micro.py "${exp_type}" --batch_size "${batch_size}" --chunks "${chunk_size}" --dataset_size "${dataset_size}" --save_file "stats_micro_${exp_type}_v${version}.csv" --epochs "${epochs}" --checkpointing "${chk}"
done
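
For reference, the micro-batch sizes implied by the chunk counts in the script can be worked out as below. This is a minimal sketch that assumes the batch is split along its first dimension with torch.chunk semantics (equal-sized pieces, a smaller last piece, possibly fewer pieces than requested); the helper name micro_batch_sizes is hypothetical.

```python
import math
from typing import List


def micro_batch_sizes(batch_size: int, chunks: int) -> List[int]:
    """Sizes of the pieces when a batch is split torch.chunk-style."""
    # Each piece holds ceil(batch_size / chunks) samples ...
    size = math.ceil(batch_size / chunks)
    # ... except a possibly smaller remainder at the end.
    full, rest = divmod(batch_size, size)
    return [size] * full + ([rest] if rest else [])


if __name__ == "__main__":
    # Batch size and chunk counts taken from the script above.
    for chunks in (10, 20, 40, 60, 120):
        sizes = micro_batch_sizes(240, chunks)
        print(f"chunks={chunks}: {len(sizes)} micro-batches of size {sizes[0]}")
```

With a batch size of 240, the micro-batch size shrinks from 24 samples at 10 chunks to 2 samples at 120 chunks, so per-micro-batch overheads weigh more heavily at the high end of the sweep.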

@sublee
Contributor

sublee commented Mar 10, 2020

One possibility comes to mind. When a process uses up almost all CUDA memory, CUDACachingAllocator in PyTorch might synchronize with the GPU to release garbage blocks. Frequent synchronization between CPU and GPU is bad for speed. Why don't you try a smaller batch size and profile both options with NVIDIA Nsight Systems?
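
One rough way to check for this kind of memory pressure, alongside a profiler, is to look at the caching allocator's own counters. The sketch below assumes the dictionary returned by torch.cuda.memory_stats(), in particular its "reserved_bytes.all.peak" and "num_alloc_retries" entries; the helper name allocator_pressure is hypothetical, and a nonzero retry count is only a hint, not proof, that cache flushes are slowing the run.

```python
from typing import Dict


def allocator_pressure(stats: Dict[str, int], total_bytes: int) -> Dict[str, float]:
    """Summarize allocator pressure from a torch.cuda.memory_stats() dict."""
    return {
        # Peak memory reserved by the caching allocator, relative to the device.
        "peak_reserved_fraction": stats.get("reserved_bytes.all.peak", 0) / total_bytes,
        # Failed cudaMalloc calls that led to a cache flush and a retry.
        "alloc_retries": float(stats.get("num_alloc_retries", 0)),
    }


if __name__ == "__main__":
    try:
        import torch
    except ImportError:  # PyTorch is optional for this sketch.
        torch = None
    if torch is not None and torch.cuda.is_available():
        device = torch.device("cuda:0")
        total = torch.cuda.get_device_properties(device).total_memory
        # Call this after a few training steps to see meaningful numbers.
        print(allocator_pressure(torch.cuda.memory_stats(device), total))
```

If the peak reserved fraction is close to 1 and the retry counter keeps growing during training, a smaller batch size is worth trying before comparing the two checkpoint settings.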

@vibhatha
Author

Yes, I am heading that way @sublee. I observed some overheads with smaller batch sizes.
I am profiling those parts.

@vibhatha
Author

https://github.com/kakaobrain/torchgpipe/blob/master/benchmarks/unet-speed/main.py

Here, in the speed benchmarks, why isn't a constant mini-batch size used for pipelining? Shouldn't chunks be the variable?

@sublee
Contributor

sublee commented Mar 26, 2020

The constant batch sizes are used in input = torch.rand(batch_size, 3, 192, 192, device=in_device) on line 168.

@vibhatha
Author

class Experiments:

    @staticmethod
    def baseline(model: nn.Module, devices: List[int]) -> Stuffs:
        batch_size = 40
        device = devices[0]
        model.to(device)
        return model, batch_size, [torch.device(device)]

    @staticmethod
    def pipeline1(model: nn.Module, devices: List[int]) -> Stuffs:
        batch_size = 80
        chunks = 2
        balance = [241]

        model = cast(nn.Sequential, model)
        model = GPipe(model, balance, devices=devices, chunks=chunks)
        return model, batch_size, list(model.devices)

    @staticmethod
    def pipeline2(model: nn.Module, devices: List[int]) -> Stuffs:
        batch_size = 512
        chunks = 32
        balance = [104, 137]

        model = cast(nn.Sequential, model)
        model = GPipe(model, balance, devices=devices, chunks=chunks)
        return model, batch_size, list(model.devices)

Does the batch_size used on that line come from these methods?
Or have I misunderstood this?

@sublee
Contributor

sublee commented Mar 26, 2020

Those static methods return batch_size as a result. It is used to initialize input later.

EXPERIMENTS: Dict[str, Experiment] = {
    'baseline': Experiments.baseline,
    'pipeline-1': Experiments.pipeline1,
    'pipeline-2': Experiments.pipeline2,
    'pipeline-4': Experiments.pipeline4,
    'pipeline-8': Experiments.pipeline8,
}
...
    f: Experiment = EXPERIMENTS[experiment]
    try:
        model, batch_size, _devices = f(model, devices)
...
    input = torch.rand(batch_size, 3, 192, 192, device=in_device)

@vibhatha
Author

Yes, those values are different.
I mean that the batch size differs for each pipeline configuration.
What is the reason for this?

@sublee
Contributor

sublee commented Mar 26, 2020

Sorry for misunderstanding what you meant by "constant". We adjusted the batch sizes to maximize throughput in each configuration. You can find a similar explanation in Section 4.2, "Performance", of v1 of the paper.
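
Choosing a batch size by throughput can be sketched as follows. This is an illustration only, not the procedure used for the benchmark: the helper names are hypothetical, and the step times are expected to come from your own measurements.

```python
from typing import Dict


def throughput(batch_size: int, seconds_per_step: float) -> float:
    """Samples processed per second for one training step."""
    return batch_size / seconds_per_step


def best_batch_size(step_times: Dict[int, float]) -> int:
    """Pick the batch size with the highest measured throughput.

    step_times maps each tried batch size to its measured seconds per step.
    """
    return max(step_times, key=lambda bs: throughput(bs, step_times[bs]))
```

Because throughput is measured in samples per second, configurations with different batch sizes remain comparable even though their per-step times are not.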

@vibhatha
Author

That is totally fine. I just wanted to learn why the numbers were chosen like that.
Thank you very much 👍
