Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

structuredconfig for train.py and eval.py #1051

Merged
merged 240 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
240 commits
Select commit Hold shift + click to select a range
7dc3fbc
first commit for structuredconfig for train.py
milocress Mar 22, 2024
abc2b3f
Merge branch 'main' into milo/foundry-type-cleanup
milocress Mar 22, 2024
aceb2b4
Merge branch 'main' into milo/foundry-type-cleanup
milocress Apr 1, 2024
f7258b7
Merge branch 'main' into milo/foundry-type-cleanup
milocress Apr 1, 2024
ccc48d7
Merge branch 'main' into milo/foundry-type-cleanup
milocress Apr 3, 2024
6f7c519
revamp configs
milocress Apr 4, 2024
e17d04f
Merge branch 'main' into milo/foundry-type-cleanup
milocress Apr 8, 2024
ea77a27
wip latest issue
milocress Apr 9, 2024
5584de5
merged
milocress Apr 9, 2024
a74aa9e
reorder so mandatory attributes come first
milocress Apr 9, 2024
f767586
fix
milocress Apr 9, 2024
e3134e3
fix
milocress Apr 9, 2024
686fc66
fix fix
milocress Apr 9, 2024
cf1e42e
fix types
milocress Apr 9, 2024
4cf99fe
fix dictconfig
milocress Apr 9, 2024
839c61c
fix union of list|dict configs
milocress Apr 9, 2024
710c9b0
fix type annotation
milocress Apr 9, 2024
142518a
oops
milocress Apr 9, 2024
7408403
fixed configs
milocress Apr 9, 2024
a1bf2b8
add save ignore keys
milocress Apr 10, 2024
a26d4fc
Merge branch 'main' into milo/foundry-type-cleanup
milocress Apr 10, 2024
dcf4142
fix batch size kerfuffle
milocress Apr 10, 2024
8f1177b
fix dictconfig stuff
milocress Apr 10, 2024
4965cd2
fix dictconfig stuff again
milocress Apr 10, 2024
53889fd
fix
milocress Apr 10, 2024
7669954
fix
milocress Apr 10, 2024
ba0783d
updated unit tests for variables
milocress Apr 10, 2024
2349390
last fix?
milocress Apr 10, 2024
0acd8c7
if this test case does not pass I will venmo Mihir 0
milocress Apr 10, 2024
6a3d43a
remove a 'not' -- eg. 'I am not going crazy'
milocress Apr 10, 2024
7f3d913
Merge branch 'main' into milo/foundry-type-cleanup
milocress Apr 10, 2024
dd4a926
Merge branch 'main' into milo/foundry-type-cleanup
milocress Apr 15, 2024
704195e
Update scripts/train/train.py
milocress Apr 16, 2024
ef0554c
set amp bf16 as default precision, etc
milocress Apr 17, 2024
0953267
Merge branch 'milo/foundry-type-cleanup' of github.com:milocress/llm-…
milocress Apr 17, 2024
cb0ad66
temporarily wrap with dictconfig before ** migration
milocress Apr 17, 2024
560e574
fix icl tasks
milocress Apr 17, 2024
346a875
fix
milocress Apr 17, 2024
f57a83a
fix activation checkpointing reentrant
milocress Apr 17, 2024
fff864a
fix extraneous keys
milocress Apr 17, 2024
ed8a94d
first round **
milocress Apr 17, 2024
42939aa
fix?
milocress Apr 17, 2024
8e981aa
quick fsdp config fix
milocress Apr 17, 2024
767f097
updated yamls to make variables explicit
milocress Apr 17, 2024
af250fe
remove precision from mandatory params list
milocress Apr 17, 2024
1c14ea5
I expect many of these to fail in interesting ways
milocress Apr 17, 2024
0b5721e
fix test_model test cases with **
milocress Apr 17, 2024
f770f60
fix many more test cases
milocress Apr 17, 2024
9dfa01b
fix dictconfig objectification
milocress Apr 17, 2024
fc4a86a
fix remaining test cases
milocress Apr 17, 2024
c7bb866
remove unneeded **
milocress Apr 17, 2024
6997d14
fix test case
milocress Apr 17, 2024
c55bafa
changed back argument name
milocress Apr 18, 2024
db8d207
fix
milocress Apr 18, 2024
cd5460e
** for finetuning dataloader
milocress Apr 18, 2024
bbd04d1
fix?
milocress Apr 18, 2024
fc6fb1b
fix dataloader
milocress Apr 18, 2024
1887ed0
fix
milocress Apr 18, 2024
9fd912f
fix finetuning dataloader
milocress Apr 18, 2024
8cd9e65
fix build_text_dataloader
milocress Apr 18, 2024
1048580
left to my own devices
milocress Apr 18, 2024
de2f893
fix packing
milocress Apr 18, 2024
9869057
fix typo
milocress Apr 18, 2024
e2fdf06
fix padding test cases
milocress Apr 18, 2024
4ee17f0
ignore extra parameters and warn
milocress Apr 18, 2024
d06e357
fix style
milocress Apr 18, 2024
b8fd65d
fix quality checks
milocress Apr 18, 2024
01b7419
fix code quality
milocress Apr 18, 2024
d986503
pyright-fu
milocress Apr 18, 2024
41fbe28
fix
milocress Apr 18, 2024
d730270
just one more type constraint bro
milocress Apr 18, 2024
0fbb3c6
OmegaConf -> om
milocress Apr 18, 2024
da962d3
rename variables for clarity
milocress Apr 18, 2024
f838b74
revert file
milocress Apr 18, 2024
4c31b6f
revert file II
milocress Apr 18, 2024
ee46918
revert file III: revert of the sith
milocress Apr 18, 2024
ff108c8
peft revert file
milocress Apr 18, 2024
e6edad1
revert v_mpt
milocress Apr 18, 2024
a59299c
last revert
milocress Apr 18, 2024
702910f
remove redundant checks
milocress Apr 18, 2024
902254c
deprecate
milocress Apr 18, 2024
20a7703
make cleaner
milocress Apr 18, 2024
b7db045
pyright is bullying me again
milocress Apr 18, 2024
40324c8
further clean config_utils
milocress Apr 18, 2024
e7a2bfc
polish train
milocress Apr 18, 2024
4c403fd
polish train and eval
milocress Apr 18, 2024
f716642
fix dist
milocress Apr 18, 2024
4109b5a
Merge branch 'main' into milo/foundry-type-cleanup
milocress Apr 18, 2024
0baae32
fix style
milocress Apr 18, 2024
acb0f84
merged
milocress Apr 18, 2024
e4ee9fc
organize eval and train
milocress Apr 18, 2024
5ff1816
fix
milocress Apr 18, 2024
8e5bc1d
used helper function to make main cleaner
milocress Apr 18, 2024
f916a15
fix stuff
milocress Apr 18, 2024
0baca2e
fix pyright
milocress Apr 18, 2024
05342b2
added fix and explanation
milocress Apr 18, 2024
41d9255
fix typo in unit test update smh
milocress Apr 18, 2024
0f8b26b
Update llmfoundry/registry.py
milocress Apr 19, 2024
5d805c3
Update scripts/train/train.py
milocress Apr 19, 2024
cab542f
Update scripts/train/train.py
milocress Apr 19, 2024
ebcafc4
Update scripts/train/train.py
milocress Apr 19, 2024
f744bd4
Apply suggestions from code review
milocress Apr 19, 2024
fba1f63
see if this fails
milocress Apr 19, 2024
a2de27a
reject name and device rather than ignoring
milocress Apr 19, 2024
b6ebccc
pretrained is not a bool
milocress Apr 19, 2024
604f254
add validation to make sure the user doesn't set both
milocress Apr 19, 2024
ba9391b
forbid config keys
milocress Apr 19, 2024
fe7d7b9
oops forgot eval
milocress Apr 19, 2024
bcedcfd
address coomments
milocress Apr 19, 2024
8a2bb19
removed redundant check
milocress Apr 19, 2024
2fdbbc5
updated callsites not to use name
milocress Apr 19, 2024
def05c4
merged
milocress Apr 19, 2024
75ed2e5
Merge branch 'main' into milo/foundry-type-cleanup
milocress Apr 19, 2024
acb9e1d
fix
milocress Apr 19, 2024
f831c61
validate extraneous keys in dataloader
milocress Apr 19, 2024
1c16dec
fix
milocress Apr 19, 2024
83cad9c
fix more
milocress Apr 19, 2024
d1b26f3
fix III: revenge of the fix
milocress Apr 19, 2024
bc2d5d3
fix IV: a new hope
milocress Apr 19, 2024
cd73f60
fix V: the empire fixes back
milocress Apr 19, 2024
a59e09c
fixed some more types
milocress Apr 19, 2024
b7fb56a
fix VI: return of the fix
milocress Apr 19, 2024
1eb809e
fix VII: the fix awakens
milocress Apr 19, 2024
c4922a5
fix VIII: the last bug
milocress Apr 19, 2024
3fc2d82
fix
milocress Apr 19, 2024
b9db81f
final fix I think
milocress Apr 19, 2024
e87b9b2
fixed
milocress Apr 20, 2024
48fa58e
fix style
milocress Apr 21, 2024
3e77198
fix
milocress Apr 21, 2024
961b034
fix fix
milocress Apr 21, 2024
d245cd1
fix fix style
milocress Apr 21, 2024
aec718b
icl task config
milocress Apr 21, 2024
d9e6f13
fix train
milocress Apr 21, 2024
4609950
fix finetuning dataloader
milocress Apr 21, 2024
722aeb1
fix train types
milocress Apr 21, 2024
8d08b17
fix token counting
milocress Apr 21, 2024
87d7cdf
fix train types
milocress Apr 21, 2024
d0c2b4f
oopsie
milocress Apr 21, 2024
fa639c6
fix straggler issues
milocress Apr 21, 2024
f71396d
fix tests
milocress Apr 21, 2024
02a50de
fix???
milocress Apr 21, 2024
76c413b
fix hf v mpt gpu test and fmapi test
milocress Apr 21, 2024
66e86dc
pop device
milocress Apr 21, 2024
b57a107
to_str_dict -> to_dict_recursive
milocress Apr 21, 2024
f680ea2
fix this darn unit test one more time
milocress Apr 21, 2024
a6ed534
Merge branch 'main' into milo/foundry-type-cleanup
milocress Apr 21, 2024
7650b50
fix ComposerMPTCausalLM constructor invocation
milocress Apr 22, 2024
70af288
Delete tests/models/hf/test_hf_fsdp.py
milocress Apr 22, 2024
bc6c545
unwrap model in unit tests
milocress Apr 22, 2024
e1e59fe
merge
milocress Apr 22, 2024
be5105d
model.model.model.model.model
milocress Apr 22, 2024
6206db6
abstract away dataclass construction
milocress Apr 22, 2024
de67266
updated docstrings and removed dictconfig from logging logic
milocress Apr 22, 2024
31111a7
flag icl tasks required or not
milocress Apr 22, 2024
8c4aaa4
updated a couple yamls
milocress Apr 22, 2024
4758af3
updated train and eval scripts
milocress Apr 22, 2024
169f1a3
un-delete global train batch size
milocress Apr 22, 2024
df955a9
fix
milocress Apr 22, 2024
060c216
I don't understand why this doesn't work
milocress Apr 22, 2024
feea2d1
that was the sneakiest bug I've ever fixed
milocress Apr 22, 2024
adfa165
try to fix the regression test
milocress Apr 22, 2024
2c1f4d6
remove device train grad accum
milocress Apr 22, 2024
774db46
fix validate config
milocress Apr 22, 2024
c014baa
removed unused import
milocress Apr 22, 2024
05d9c68
use variables
milocress Apr 22, 2024
c2e1c4f
missing mandatory value fix
milocress Apr 22, 2024
e733b9f
use correct type of error
milocress Apr 22, 2024
3a5a960
fix
milocress Apr 22, 2024
e309366
import TrainConfig just in case?
milocress Apr 22, 2024
8704f3f
moved trainconfig and evalconfig into utils
milocress Apr 22, 2024
98164c8
works
milocress Apr 22, 2024
9ab6b8f
no cheating
milocress Apr 22, 2024
94fd55b
dicts everywhere gah
milocress Apr 22, 2024
84ba917
try no recursive just
milocress Apr 22, 2024
3e03cb0
Merge branch 'main' into milo/foundry-type-cleanup
milocress Apr 22, 2024
155a484
rename typed helpers
milocress Apr 22, 2024
0403330
fix the test cases with deep magic
milocress Apr 22, 2024
d33eb10
towards a peaceful resolution
milocress Apr 22, 2024
bd203e6
remove comments
milocress Apr 22, 2024
853c173
fix type warnings
milocress Apr 23, 2024
22bbccc
Merge branch 'main' into milo/foundry-type-cleanup
milocress Apr 23, 2024
0e74185
Update llmfoundry/utils/config_utils.py
milocress Apr 23, 2024
9fa0418
address low-hanging fruit
milocress Apr 23, 2024
22ede8d
merged
milocress Apr 23, 2024
6665575
remove peft wrapping extra model
milocress Apr 23, 2024
9cdc7a4
python :handshake: haskell
milocress Apr 23, 2024
08814e1
dataset config should be dict
milocress Apr 23, 2024
80acfb3
just because omega starts with OMMMM does not mean it's zen
milocress Apr 23, 2024
2dd350c
fix
milocress Apr 23, 2024
e8ecfcd
fix
milocress Apr 23, 2024
0842b36
structured settlement
milocress Apr 23, 2024
4141d48
precision further down
milocress Apr 23, 2024
53a2a80
throws TypeError instead of MissingMandatoryValue or whatever
milocress Apr 23, 2024
fc86f6f
remove debugging statement
milocress Apr 23, 2024
dc73a4f
remove to_container calls everywhere
milocress Apr 23, 2024
4987145
wrap then unwrap
milocress Apr 23, 2024
b9c3cbf
pyright
milocress Apr 23, 2024
cbfec68
error early on missing mandatory values
milocress Apr 23, 2024
f2ed1d7
remove unnecessory ignore
milocress Apr 23, 2024
063ab43
merged and resolved
milocress Apr 24, 2024
c586978
update unit tests
milocress Apr 24, 2024
fbe436e
update eval yamls
milocress Apr 24, 2024
bf21b14
Merge branch 'main' into milo/foundry-type-cleanup
milocress Apr 24, 2024
c16e359
Merge branch 'main' into milo/foundry-type-cleanup
milocress Apr 24, 2024
7c611f0
Update train.py
milocress Apr 24, 2024
96a620d
make log level optional again
milocress Apr 24, 2024
e549d9b
Merge branch 'main' into milo/foundry-type-cleanup
dakinggg Apr 25, 2024
28b86e4
Merge branch 'main' into milo/foundry-type-cleanup
milocress Apr 25, 2024
a9f252e
resolved merge conflict
milocress Apr 25, 2024
a3ed2af
merge II
milocress Apr 25, 2024
e6d0923
oopsie
milocress Apr 25, 2024
b452d8d
resolve conflict II
milocress Apr 25, 2024
b37d3f8
Merge branch 'main' into milo/foundry-type-cleanup
milocress Apr 28, 2024
02567f6
merged and resolved conflicts
milocress Apr 29, 2024
a1eedde
merged
milocress Apr 29, 2024
f7cede6
use keywords for arg clarity
milocress Apr 29, 2024
e13c75a
use keywords for arg clarity
milocress Apr 29, 2024
a1b9adf
style
milocress Apr 29, 2024
459fce5
style
milocress Apr 29, 2024
27c81a5
dist timeout
milocress Apr 29, 2024
99922fe
Merge branch 'main' into milo/foundry-type-cleanup
milocress Apr 30, 2024
55d7a66
Merge branch 'main' into milo/foundry-type-cleanup
dakinggg May 1, 2024
a30618b
Merge branch 'main' into milo/foundry-type-cleanup
dakinggg May 1, 2024
048e0e2
Merge branch 'main' into milo/foundry-type-cleanup
milocress May 2, 2024
38ae6b1
resolved merge conflict but expect errors
milocress May 3, 2024
bcaad4b
resolve deeper conflict issues
milocress May 3, 2024
e05ab76
Merge branch 'main' into milo/foundry-type-cleanup
dakinggg May 3, 2024
108268d
fix train.py
milocress May 3, 2024
cf29ede
fix eval
milocress May 3, 2024
6ab9bc2
fix registry
milocress May 3, 2024
61e30ca
fix dataloader
milocress May 3, 2024
b3cd8ce
fix train II
milocress May 3, 2024
760abb4
fix dataloader and utils
milocress May 3, 2024
c896437
fix dictconfig
milocress May 3, 2024
682d2bf
skill issue
milocress May 3, 2024
9d229b9
add new keys
milocress May 3, 2024
1d994a9
Merge branch 'main' into milo/foundry-type-cleanup
milocress May 6, 2024
eccf849
remove pop_config
milocress May 6, 2024
244d3e3
Merge branch 'main' into milo/foundry-type-cleanup
milocress May 8, 2024
8fb5e4c
fix
milocress May 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions llmfoundry/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@

"""Dataloader builder utilities."""

from typing import Union
from typing import Any, Dict

from composer import DataSpec
from omegaconf import DictConfig
from transformers import PreTrainedTokenizerBase

from llmfoundry import registry
Expand All @@ -18,9 +17,9 @@


def build_dataloader(
cfg: DictConfig,
cfg: Dict[str, Any],
tokenizer: PreTrainedTokenizerBase,
device_batch_size: Union[int, float],
device_batch_size: int,
) -> DataSpec:
"""Builds a dataloader from a config.

Expand All @@ -30,14 +29,15 @@ def build_dataloader(
device_batch_size (int): The size of the batches (number of examples)
that the dataloader will produce.
"""
kwargs = {
'cfg': cfg,
name = cfg.pop('name')
kwargs: Dict[str, Any] = {
**cfg,
'tokenizer': tokenizer,
'device_batch_size': device_batch_size,
}

return construct_from_registry(
name=cfg.name,
name=name,
registry=registry.dataloaders,
partial_function=False,
pre_validation_function=None,
Expand Down
349 changes: 214 additions & 135 deletions llmfoundry/data/finetuning/dataloader.py

Large diffs are not rendered by default.

63 changes: 32 additions & 31 deletions llmfoundry/data/packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@

import logging
import tempfile
from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple

import numpy as np
import torch
from composer.utils import dist
from omegaconf import DictConfig
from transformers import PreTrainedTokenizerBase

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -318,7 +317,7 @@ def pad_tensor(tensor: torch.Tensor, pad_value: int):


def auto_packing_ratio(
dataloader_cfg: DictConfig,
dataloader_cfg: Dict[str, Any],
tokenizer: PreTrainedTokenizerBase,
device_batch_size: int,
num_packing_ratios: int = 20,
Expand Down Expand Up @@ -352,20 +351,21 @@ def auto_packing_ratio(
# Set the seed so that auto packing is deterministic.
reproducibility.seed_all(0)

max_seq_len = dataloader_cfg.dataset.max_seq_len
# If max_seq_len is very small, skip profiling and select packing ratio of 1.
dataset_config = dataloader_cfg['dataset']
max_seq_len = dataset_config.get('max_seq_len')
if max_seq_len <= 100:
return 1

min_ratio = 1
max_ratio = max_seq_len / 100
profiling_results = profile_packing(
dataloader_cfg,
tokenizer,
min_ratio,
max_ratio,
num_packing_ratios,
device_batch_size,
dataloader_cfg=dataloader_cfg,
tokenizer=tokenizer,
min_ratio=min_ratio,
max_ratio=max_ratio,
num_packing_ratios=num_packing_ratios,
device_batch_size=device_batch_size,
)

# Obtain the maximum packing_ratio/minimum padding that has no waste.
Expand All @@ -392,7 +392,7 @@ def auto_packing_ratio(


def profile_packing(
dataloader_cfg: DictConfig,
dataloader_cfg: Dict[str, Any],
tokenizer: PreTrainedTokenizerBase,
min_ratio: float,
max_ratio: float,
Expand All @@ -416,39 +416,40 @@ def profile_packing(

from llmfoundry.data.dataloader import build_dataloader

max_seq_len = dataloader_cfg.dataset.get('max_seq_len')
max_leftovers_to_keep = dataloader_cfg.dataset.get(
'max_leftovers_to_keep',
None,
)
dataset_cfg = dataloader_cfg['dataset']
max_seq_len = dataset_cfg.get('max_seq_len')
max_leftovers_to_keep = dataset_cfg.get('max_leftovers_to_keep', None)

# Turn off packing and sequence parallelism for the dataloader (we want raw, pre-packed, full-length examples)
dataloader_cfg = copy.deepcopy(dataloader_cfg)
dataloader_cfg.dataset.packing_ratio = 1.0
dataloader_cfg.dataset.auto_packing_replication = dataloader_cfg.dataset.get(
'seq_parallel_replication',
1,
) or 1
dataloader_cfg.dataset.seq_parallel_replication = 1
dataloader_cfg.drop_last = False
dataloader_cfg.num_workers = 0
dataloader_cfg.prefetch_factor = None
dataloader_cfg.persistent_workers = False
dataloader_cfg.update({
'drop_last': False,
'num_workers': 0,
'prefetch_factor': None,
'persistent_workers': False,
})
dataloader_cfg['dataset']['packing_ratio'] = 1.0
dataloader_cfg['dataset']['auto_packing_replication'
] = dataloader_cfg['dataset'].get(
'seq_parallel_replication',
1,
) or 1
dataloader_cfg['dataset']['seq_parallel_replication'] = 1

# If streaming dataset, use a temporary local folder for profiling
local_rank_zero = dist.get_global_rank() - dist.get_local_rank()
if dataloader_cfg.dataset.get('remote') is not None:
if dataloader_cfg['dataset'].get('remote') is not None:
tmp_path_to_broadcast = tempfile.TemporaryDirectory().name
gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
tmp_path = gathered_paths[local_rank_zero]
dataloader_cfg.dataset.local = tmp_path
dataloader_cfg['dataset']['local'] = tmp_path

if dataloader_cfg.dataset.get('streams') is not None:
for stream_config in dataloader_cfg.dataset.streams.values():
if dataloader_cfg['dataset'].get('streams') is not None:
for stream_config in dataloader_cfg['dataset']['streams'].values():
tmp_path_to_broadcast = tempfile.TemporaryDirectory().name
gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
tmp_path = gathered_paths[local_rank_zero]
stream_config.local = tmp_path
stream_config['local'] = tmp_path

# Determine the packing_ratio values we'll try
packing_ratios, raw_batch_sizes = [], []
Expand Down
80 changes: 49 additions & 31 deletions llmfoundry/data/text_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
import numpy as np
import torch
from composer.core.data_spec import DataSpec
from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from streaming import Stream, StreamingDataset
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizerBase
Expand Down Expand Up @@ -268,79 +266,96 @@ def get_sequence_id_from_batch(
return torch.cat([left_zeros, cumulative_sep[:, :-1]], dim=1)


def build_streams(dataset_cfg: DictConfig):
streams_dict = dataset_cfg.pop('streams', None)
def build_streams(streams: Optional[Dict[str, Any]] = None,):
streams_dict = streams
# build streams
streams = None
streams_ret = []
if streams_dict is not None:
streams = []
for stream in streams_dict.values():
# stream is the streams kwargs
# fwd all kwargs with **stream allows streaming to check args
streams.append(Stream(**stream))
return streams
streams_ret = [Stream(**stream) for stream in streams_dict.values()]
return streams_ret


def build_text_dataloader(
cfg: DictConfig,
tokenizer: PreTrainedTokenizerBase,
device_batch_size: Union[int, float],
device_batch_size: int,
dataset: Dict[str, Any],
drop_last: bool,
num_workers: int,
pin_memory: bool = True,
prefetch_factor: int = 2,
persistent_workers: bool = True,
timeout: int = 0,
) -> DataSpec:
assert cfg.name == 'text', f'Tried to build text dataloader with cfg.name={cfg.name}'

dataset_cfg = dataset

# get kwargs
cfg.dataset['replication'], dataset_batch_size = construct_from_registry(
dataset_cfg['replication'], dataset_batch_size = construct_from_registry(
name='dataset_replication_validator',
registry=registry.dataset_replication_validators,
partial_function=False,
kwargs={
'cfg': cfg,
'dataset_cfg': dataset_cfg,
'tokenizer': tokenizer,
'device_batch_size': device_batch_size,
},
)

streams = build_streams(cfg.dataset)
streams = build_streams(
streams=dataset_cfg.pop('streams')
if 'streams' in dataset_cfg else None,
)

valid_streaming_text_dataset_parameters = inspect.signature(
StreamingTextDataset,
).parameters

dataset_config_subset_for_streaming_text_dataset = {
k: v
for k, v in cfg.dataset.items()
for k, v in dataset_cfg.items()
if k in valid_streaming_text_dataset_parameters
}

# build dataset potentially with streams
dataset = StreamingTextDataset(
text_dataset = StreamingTextDataset(
tokenizer=tokenizer,
streams=streams,
batch_size=dataset_batch_size,
**dataset_config_subset_for_streaming_text_dataset,
)

dataloader_cfg = {
'name': 'text',
'dataset': dataset_cfg,
'drop_last': drop_last,
'num_workers': num_workers,
'pin_memory': pin_memory,
'prefetch_factor': prefetch_factor,
'persistent_workers': persistent_workers,
'timeout': timeout,
}

collate_fn, dataloader_batch_size = construct_from_registry(
name='text_collator',
registry=registry.collators,
partial_function=False,
kwargs={
'cfg': cfg,
'tokenizer': dataset.tokenizer,
'dataloader_cfg': dataloader_cfg,
'tokenizer': tokenizer,
'dataset_batch_size': dataset_batch_size,
},
)

dl = DataLoader(
dataset,
text_dataset,
collate_fn=collate_fn,
batch_size=dataloader_batch_size,
drop_last=cfg.drop_last,
num_workers=cfg.num_workers,
pin_memory=cfg.get('pin_memory', True),
prefetch_factor=cfg.get('prefetch_factor', 2),
persistent_workers=cfg.get('persistent_workers', True),
timeout=cfg.get('timeout', 0),
drop_last=drop_last,
num_workers=num_workers,
pin_memory=pin_memory,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
timeout=timeout,
)

return construct_from_registry(
Expand All @@ -349,7 +364,7 @@ def build_text_dataloader(
partial_function=False,
kwargs={
'dl': dl,
'dataset_cfg': cfg.dataset,
'dataset_cfg': dataset_cfg,
},
)

Expand Down Expand Up @@ -415,14 +430,17 @@ def build_text_dataloader(
'drop_last': False,
'num_workers': 4,
}
cfg = om.create(cfg)
device_batch_size = 2

tokenizer_name = args.tokenizer
tokenizer_kwargs = {'model_max_length': args.max_seq_len}
tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)

loader = build_text_dataloader(cfg, tokenizer, device_batch_size).dataloader
loader = build_text_dataloader(
**cfg,
tokenizer=tokenizer,
device_batch_size=device_batch_size,
).dataloader
assert isinstance(loader, DataLoader)
assert isinstance(loader.dataset, StreamingTextDataset)
tokenizer = loader.dataset.tokenizer
Expand Down
Loading
Loading