Mlflow datasets #1119

Merged
merged 47 commits into main from mlflow-datasets on Apr 24, 2024
Commits (47)
a99525b
add logger
KuuCi Apr 18, 2024
a9e7b0d
reqs
KuuCi Apr 18, 2024
ecdfaca
small fix
KuuCi Apr 18, 2024
c29cbc2
import mlflow
KuuCi Apr 18, 2024
b787cca
parse_uri
KuuCi Apr 18, 2024
ab7268f
parse_uri
KuuCi Apr 18, 2024
a06fcb2
finished debug
KuuCi Apr 18, 2024
aae821d
precommit
KuuCi Apr 18, 2024
996fb01
more code fix
KuuCi Apr 18, 2024
e507eac
revert setup
KuuCi Apr 18, 2024
f570842
better dovs
KuuCi Apr 18, 2024
1245c64
rm docstr
KuuCi Apr 18, 2024
c9006c5
precommit
KuuCi Apr 18, 2024
d79410c
Update tests to not rely on mistral (#1117)
dakinggg Apr 18, 2024
d47caea
Bump transformers to 4.40 (#1118)
dakinggg Apr 18, 2024
af68170
merge
KuuCi Apr 18, 2024
ca698b8
revert setup
KuuCi Apr 18, 2024
802dd8c
Merge branch 'main' into mlflow-datasets
KuuCi Apr 18, 2024
47bf6cb
precommit
KuuCi Apr 18, 2024
bbcabcc
precommit
KuuCi Apr 19, 2024
cf7c9df
tweaks to resolve comments
KuuCi Apr 19, 2024
eb2afbb
unit test
KuuCi Apr 19, 2024
05c2461
code quality
KuuCi Apr 19, 2024
bb86a78
quotation
KuuCi Apr 19, 2024
f3d8348
quote
KuuCi Apr 19, 2024
c44fafa
more quality
KuuCi Apr 19, 2024
6250a44
optional
KuuCi Apr 19, 2024
e38964b
pyright
KuuCi Apr 19, 2024
0199788
type check
KuuCi Apr 19, 2024
90bcad0
rm typechecking
KuuCi Apr 19, 2024
0d66d1d
yapf
KuuCi Apr 19, 2024
d286e15
first pass
KuuCi Apr 19, 2024
6a6632c
fix
KuuCi Apr 19, 2024
969c1c0
get refactor
KuuCi Apr 19, 2024
d472282
refactor
KuuCi Apr 19, 2024
cbf0c30
local hf path
KuuCi Apr 19, 2024
8dd8cec
dbfs
KuuCi Apr 19, 2024
3457eb5
rm local
KuuCi Apr 19, 2024
48843f9
typo
KuuCi Apr 19, 2024
39ba332
second pass
KuuCi Apr 22, 2024
5e0f853
update
KuuCi Apr 22, 2024
84a7930
Merge branch 'main' into mlflow-datasets
KuuCi Apr 22, 2024
6aae695
Merge branch 'main' into mlflow-datasets
dakinggg Apr 23, 2024
f63c325
Merge branch 'main' into mlflow-datasets
dakinggg Apr 23, 2024
a9fda9c
third pass
KuuCi Apr 23, 2024
85fa2df
os.path.join
KuuCi Apr 24, 2024
6cecaec
precommit
KuuCi Apr 24, 2024
143 changes: 134 additions & 9 deletions llmfoundry/utils/config_utils.py
@@ -4,10 +4,12 @@
 import contextlib
 import logging
 import math
+import os
 import warnings
-from typing import Any, Dict, Literal, Mapping, Optional, Tuple, Union
+from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple, Union
 
-from composer.utils import dist
+import mlflow
+from composer.utils import dist, parse_uri
 from omegaconf import DictConfig, ListConfig
 from omegaconf import OmegaConf as om
@@ -177,10 +179,133 @@ def log_config(cfg: DictConfig) -> None:
     if wandb.run:
         wandb.config.update(om.to_container(cfg, resolve=True))
 
-    if 'mlflow' in cfg.get('loggers', {}):
-        try:
-            import mlflow
-        except ImportError as e:
-            raise e
-        if mlflow.active_run():
-            mlflow.log_params(params=om.to_container(cfg, resolve=True))
+    if 'mlflow' in cfg.get('loggers', {}) and mlflow.active_run():
+        mlflow.log_params(params=om.to_container(cfg, resolve=True))
+        _log_dataset_uri(cfg)


def _parse_source_dataset(cfg: DictConfig) -> List[Tuple[str, str, str]]:
    """Parse a run config for dataset information.

    Given a config dictionary, parse it to determine how the data source
    should be categorized. Possible data sources are Delta tables, UC Volumes,
    HuggingFace paths, remote storage, or local storage.

    Args:
        cfg (DictConfig): A config dictionary of a run

    Returns:
        List[Tuple[str, str, str]]: A list of tuples formatted as (data type, path, split)
    """
    data_paths = []

    # Handle the train loader if it exists
    train_dataset = cfg.get('train_loader', {}).get('dataset', {})
    train_split = train_dataset.get('split', None)
    train_source_path = cfg.get('source_dataset_train', None)
    _process_data_source(train_source_path, train_dataset, train_split, 'train',
                         data_paths)

    # Handle eval_loader, which might be a list or a single dictionary
    eval_data_loaders = cfg.get('eval_loader', {})
    if not isinstance(eval_data_loaders, ListConfig):
        # Normalize to a list if it's a single dictionary
        eval_data_loaders = [eval_data_loaders]

    for eval_data_loader in eval_data_loaders:
        eval_dataset = eval_data_loader.get('dataset', {})
        eval_split = eval_dataset.get('split', None)
        eval_source_path = cfg.get('source_dataset_eval', None)
        _process_data_source(eval_source_path, eval_dataset, eval_split, 'eval',
                             data_paths)

    return data_paths


def _process_data_source(source_dataset_path: Optional[str],
                         dataset: Dict[str, str], cfg_split: Optional[str],
                         true_split: str, data_paths: List[Tuple[str, str,
                                                                 str]]):
    """Add a data source by mutating data_paths.

    Given various dataset attributes, attempt to determine what type of
    dataset is being added, and parse the dataset accordingly.

    Args:
        source_dataset_path (Optional[str]): The source dataset in cfg metadata
        dataset (Dict[str, str]): The dataset from cfg
        cfg_split (Optional[str]): The split listed for the dataset in cfg
        true_split (str): The split of the dataset to be added (i.e. train or eval)
        data_paths (List[Tuple[str, str, str]]): A list of tuples formatted as (data type, path, split)
    """
    # Check for Delta table
    if source_dataset_path and len(source_dataset_path.split('.')) == 3:
        data_paths.append(('delta_table', source_dataset_path, true_split))
    # Check for UC volume
    elif source_dataset_path and source_dataset_path.startswith('dbfs:'):
        data_paths.append(
            ('uc_volume', source_dataset_path[len('dbfs:'):], true_split))
    # Check for HF path
    elif 'hf_name' in dataset:
        hf_path = dataset['hf_name']
        backend, _, _ = parse_uri(hf_path)
        if backend:
            hf_path = os.path.join(hf_path, cfg_split) if cfg_split else hf_path
            data_paths.append((backend, hf_path, true_split))
        elif os.path.exists(hf_path):
            data_paths.append(('local', hf_path, true_split))
        else:
            data_paths.append(('hf', hf_path, true_split))
    # Check for remote path
    elif 'remote' in dataset:
        remote_path = dataset['remote']
        backend, _, _ = parse_uri(remote_path)
        if backend:
            remote_path = os.path.join(
                remote_path, f'{cfg_split}/') if cfg_split else remote_path
            data_paths.append((backend, remote_path, true_split))
        else:
            data_paths.append(('local', remote_path, true_split))
    else:
        log.warning('DataSource Not Found.')
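
# Illustrative examples of the classification above (all paths hypothetical):
#   source_dataset_path='db.schema.table'     -> ('delta_table', 'db.schema.table', split)
#   source_dataset_path='dbfs:/Volumes/data'  -> ('uc_volume', '/Volumes/data', split)
#   dataset={'hf_name': 'org/dataset'}        -> ('hf', 'org/dataset', split)
#   dataset={'remote': 's3://bucket/data'} with cfg_split='train'
#                                             -> ('s3', 's3://bucket/data/train/', split)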


def _log_dataset_uri(cfg: DictConfig) -> None:
    """Logs dataset tracking information to MLflow.

    Args:
        cfg (DictConfig): A config dictionary of a run
    """
    # Figure out which data sources to use
    data_paths = _parse_source_dataset(cfg)

    dataset_source_mapping = {
        's3': mlflow.data.http_dataset_source.HTTPDatasetSource,
        'oci': mlflow.data.http_dataset_source.HTTPDatasetSource,
        'azure': mlflow.data.http_dataset_source.HTTPDatasetSource,
        'gs': mlflow.data.http_dataset_source.HTTPDatasetSource,
        'https': mlflow.data.http_dataset_source.HTTPDatasetSource,
        'hf': mlflow.data.huggingface_dataset_source.HuggingFaceDatasetSource,
        'delta_table': mlflow.data.delta_dataset_source.DeltaDatasetSource,
        'uc_volume': mlflow.data.uc_volume_dataset_source.UCVolumeDatasetSource,
        'local': mlflow.data.http_dataset_source.HTTPDatasetSource,
    }

    # Map data source types to their respective MLflow DataSource.
    for dataset_type, path, split in data_paths:
        if dataset_type in dataset_source_mapping:
            source_class = dataset_source_mapping[dataset_type]
            if dataset_type == 'delta_table':
                source = source_class(delta_table_name=path)
            elif dataset_type == 'hf' or dataset_type == 'uc_volume':
                source = source_class(path=path)
            else:
                source = source_class(url=path)
        else:
            log.info(
                f'{dataset_type} unknown, defaulting to http dataset source')
            source = mlflow.data.http_dataset_source.HTTPDatasetSource(url=path)

        mlflow.log_input(
            mlflow.data.meta_dataset.MetaDataset(source, name=split))
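
For context, a minimal sketch of how the new parsing behaves end to end. The config keys mirror those read by _parse_source_dataset above; the table, bucket, and dataset names are hypothetical:

from omegaconf import OmegaConf
from llmfoundry.utils.config_utils import _parse_source_dataset

cfg = OmegaConf.create({
    # Three-part name, so it is categorized as a Delta table.
    'source_dataset_train': 'catalog.schema.train_table',
    'train_loader': {'dataset': {'hf_name': 'org/train_dataset', 'split': 'train'}},
    # No source_dataset_eval, so the loader's hf_name is inspected instead;
    # the 's3' backend returned by parse_uri wins over the 'hf' fallback.
    'eval_loader': {'dataset': {'hf_name': 's3://bucket/eval_data', 'split': 'val'}},
})

print(_parse_source_dataset(cfg))
# [('delta_table', 'catalog.schema.train_table', 'train'),
#  ('s3', 's3://bucket/eval_data/val', 'eval')]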
133 changes: 133 additions & 0 deletions tests/utils/test_mlflow_logging.py
@@ -0,0 +1,133 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any
from unittest.mock import MagicMock, patch

import pytest
from omegaconf import OmegaConf

from llmfoundry.utils.config_utils import (_log_dataset_uri,
                                           _parse_source_dataset)

mlflow = pytest.importorskip('mlflow')
from mlflow.data.huggingface_dataset_source import HuggingFaceDatasetSource


def create_config(**kwargs: Any):
    """Helper function to create OmegaConf configurations."""
    return OmegaConf.create(kwargs)


def test_parse_source_dataset_delta_table():
    cfg = create_config(source_dataset_train='db.schema.train_table',
                        source_dataset_eval='db.schema.eval_table')
    expected = [('delta_table', 'db.schema.train_table', 'train'),
                ('delta_table', 'db.schema.eval_table', 'eval')]
    assert _parse_source_dataset(cfg) == expected


def test_parse_source_dataset_uc_volume():
    cfg = create_config(source_dataset_train='dbfs:/Volumes/train_data',
                        source_dataset_eval='dbfs:/Volumes/eval_data')
    expected = [('uc_volume', '/Volumes/train_data', 'train'),
                ('uc_volume', '/Volumes/eval_data', 'eval')]
    assert _parse_source_dataset(cfg) == expected


def test_parse_source_dataset_hf():
    cfg = create_config(
        train_loader={'dataset': {
            'hf_name': 'huggingface/train_dataset',
        }},
        eval_loader={'dataset': {
            'hf_name': 'huggingface/eval_dataset',
        }})
    expected = [('hf', 'huggingface/train_dataset', 'train'),
                ('hf', 'huggingface/eval_dataset', 'eval')]
    assert _parse_source_dataset(cfg) == expected


def test_parse_source_dataset_remote():
    cfg = create_config(
        train_loader={
            'dataset': {
                'remote': 'https://remote/train_dataset',
                'split': 'train'
            }
        },
        eval_loader={
            'dataset': {
                'remote': 'https://remote/eval_dataset',
                'split': 'eval'
            }
        })
    expected = [('https', 'https://remote/train_dataset/train/', 'train'),
                ('https', 'https://remote/eval_dataset/eval/', 'eval')]
    assert _parse_source_dataset(cfg) == expected


def test_log_dataset_uri():
    cfg = create_config(
        train_loader={'dataset': {
            'hf_name': 'huggingface/train_dataset'
        }},
        eval_loader={'dataset': {
            'hf_name': 'huggingface/eval_dataset'
        }},
        source_dataset_train='huggingface/train_dataset',
        source_dataset_eval='huggingface/eval_dataset')

    with patch('mlflow.log_input') as mock_log_input:
        _log_dataset_uri(cfg)
        assert mock_log_input.call_count == 2
        meta_dataset_calls = [
            args[0] for args, _ in mock_log_input.call_args_list
        ]
        assert all(
            isinstance(call.source, HuggingFaceDatasetSource)
            for call in meta_dataset_calls), 'Source types are incorrect'
        # Verify the names
        assert meta_dataset_calls[
            0].name == 'train', f"Expected 'train', got {meta_dataset_calls[0].name}"
        assert meta_dataset_calls[
            1].name == 'eval', f"Expected 'eval', got {meta_dataset_calls[1].name}"


def test_multiple_eval_datasets():
    # Set up a configuration with multiple evaluation datasets
    cfg = OmegaConf.create({
        'train_loader': {
            'dataset': {
                'hf_name': 'huggingface/train_dataset',
            },
        },
        'eval_loader': [{
            'dataset': {
                'hf_name': 'huggingface/eval_dataset1',
            },
        }, {
            'dataset': {
                'hf_name': 'huggingface/eval_dataset2',
            },
        }]
    })

    expected_data_paths = [('hf', 'huggingface/train_dataset', 'train'),
                           ('hf', 'huggingface/eval_dataset1', 'eval'),
                           ('hf', 'huggingface/eval_dataset2', 'eval')]

    # Mock mlflow to avoid any actual logging calls
    with patch('mlflow.data.meta_dataset.MetaDataset') as mock_meta_dataset:
        mock_meta_dataset.side_effect = lambda source, name: MagicMock()
        data_paths = _parse_source_dataset(cfg)
        assert sorted(data_paths) == sorted(
            expected_data_paths), 'Data paths did not match expected'


@pytest.fixture
def mock_mlflow_classes():
    with patch('mlflow.data.http_dataset_source.HTTPDatasetSource') as http_source, \
         patch('mlflow.data.huggingface_dataset_source.HuggingFaceDatasetSource') as hf_source, \
         patch('mlflow.data.delta_dataset_source.DeltaDatasetSource') as delta_source, \
         patch('mlflow.data.uc_volume_dataset_source.UCVolumeDatasetSource') as uc_source:
        yield http_source, hf_source, delta_source, uc_source
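
To run just this new test module locally (mlflow must be installed, since the module is gated on pytest.importorskip), something like the following should work:

pytest tests/utils/test_mlflow_logging.py -v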