diff --git a/README.md b/README.md index 0558759..ee43f54 100644 --- a/README.md +++ b/README.md @@ -96,11 +96,13 @@ trainset = mz.dataloader.Dataset( data_pack=train_processed, mode='pair', num_dup=1, - num_neg=4 + num_neg=4, + batch_size=32 ) validset = mz.dataloader.Dataset( data_pack=valid_processed, - mode='point' + mode='point', + batch_size=32 ) ``` @@ -110,13 +112,11 @@ padding_callback = mz.models.ArcI.get_default_padding_callback() trainloader = mz.dataloader.DataLoader( dataset=trainset, - batch_size=32, stage='train', callback=padding_callback ) validloader = mz.dataloader.DataLoader( dataset=validset, - batch_size=32, stage='dev', callback=padding_callback ) @@ -127,6 +127,8 @@ Initialize the model, fine-tune the hyper-parameters: ```python model = mz.models.ArcI() model.params['task'] = ranking_task +model.params['embedding_output_dim'] = 100 +model.params['embedding_input_dim'] = preprocessor.context['embedding_input_dim'] model.guess_and_fill_missing_params() model.build() ``` diff --git a/matchzoo/auto/preparer/preparer.py b/matchzoo/auto/preparer/preparer.py index e12c0c7..a7d9905 100644 --- a/matchzoo/auto/preparer/preparer.py +++ b/matchzoo/auto/preparer/preparer.py @@ -159,14 +159,20 @@ def _build_matrix(self, preprocessor, embedding): return np.random.uniform(-0.2, 0.2, matrix_shape) def _build_dataset_builder(self, model, embedding_matrix, preprocessor): - builder_kwargs = dict(callbacks=[]) + builder_kwargs = dict( + callbacks=[], + batch_size=self._config['batch_size'], + shuffle=self._config['shuffle'], + sort=self._config['sort'] + ) if isinstance(self._task.losses[0], (mz.losses.RankHingeLoss, mz.losses.RankCrossEntropyLoss)): builder_kwargs.update(dict( mode='pair', num_dup=self._config['num_dup'], - num_neg=self._config['num_neg'] + num_neg=self._config['num_neg'], + resample=self._config['resample'], )) if isinstance(model, mz.models.CDSSM): @@ -201,11 +207,7 @@ def _build_dataset_builder(self, model, embedding_matrix, preprocessor): def _build_dataloader_builder(self, model, callback): builder_kwargs = dict( - batch_size=self._config['batch_size'], stage=self._config['stage'], - resample=self._config['resample'], - shuffle=self._config['shuffle'], - sort=self._config['sort'], callback=callback ) return DataLoaderBuilder(**builder_kwargs) diff --git a/matchzoo/dataloader/callbacks/__init__.py b/matchzoo/dataloader/callbacks/__init__.py index 3bb8164..eaeb31f 100755 --- a/matchzoo/dataloader/callbacks/__init__.py +++ b/matchzoo/dataloader/callbacks/__init__.py @@ -1,5 +1,4 @@ from .lambda_callback import LambdaCallback -from .dynamic_pooling import DynamicPooling from .histogram import Histogram from .ngram import Ngram from .padding import BasicPadding diff --git a/matchzoo/dataloader/callbacks/dynamic_pooling.py b/matchzoo/dataloader/callbacks/dynamic_pooling.py deleted file mode 100755 index 2746313..0000000 --- a/matchzoo/dataloader/callbacks/dynamic_pooling.py +++ /dev/null @@ -1,92 +0,0 @@ -import numpy as np - -from matchzoo.engine.base_callback import BaseCallback - - -class DynamicPooling(BaseCallback): - """:class:`DPoolPairDataGenerator` constructor. - - :param fixed_length_left: max length of left text. - :param fixed_length_right: max length of right text. - :param compress_ratio_left: the length change ratio, - especially after normal pooling layers. - :param compress_ratio_right: the length change ratio, - especially after normal pooling layers. 
- """ - - def __init__( - self, - fixed_length_left: int, - fixed_length_right: int, - compress_ratio_left: float = 1, - compress_ratio_right: float = 1, - ): - """Init.""" - self._fixed_length_left = fixed_length_left - self._fixed_length_right = fixed_length_right - self._compress_ratio_left = compress_ratio_left - self._compress_ratio_right = compress_ratio_right - - def on_batch_unpacked(self, x, y): - """ - Insert `dpool_index` into `x`. - - :param x: unpacked x. - :param y: unpacked y. - """ - x['dpool_index'] = _dynamic_pooling_index( - x['length_left'], - x['length_right'], - self._fixed_length_left, - self._fixed_length_right, - self._compress_ratio_left, - self._compress_ratio_right - ) - - -def _dynamic_pooling_index(length_left: np.array, - length_right: np.array, - fixed_length_left: int, - fixed_length_right: int, - compress_ratio_left: float, - compress_ratio_right: float) -> np.array: - def _dpool_index(one_length_left: int, - one_length_right: int, - fixed_length_left: int, - fixed_length_right: int): - if one_length_left == 0: - stride_left = fixed_length_left - else: - stride_left = 1.0 * fixed_length_left / one_length_left - - if one_length_right == 0: - stride_right = fixed_length_right - else: - stride_right = 1.0 * fixed_length_right / one_length_right - - one_idx_left = [int(i / stride_left) - for i in range(fixed_length_left)] - one_idx_right = [int(i / stride_right) - for i in range(fixed_length_right)] - mesh1, mesh2 = np.meshgrid(one_idx_left, one_idx_right) - index_one = np.transpose( - np.stack([mesh1, mesh2]), (2, 1, 0)) - return index_one - - index = [] - dpool_bias_left = dpool_bias_right = 0 - if fixed_length_left % compress_ratio_left != 0: - dpool_bias_left = 1 - if fixed_length_right % compress_ratio_right != 0: - dpool_bias_right = 1 - cur_fixed_length_left = int( - fixed_length_left // compress_ratio_left) + dpool_bias_left - cur_fixed_length_right = int( - fixed_length_right // compress_ratio_right) + dpool_bias_right - for i in range(len(length_left)): - index.append(_dpool_index( - length_left[i] // compress_ratio_left, - length_right[i] // compress_ratio_right, - cur_fixed_length_left, - cur_fixed_length_right)) - return np.array(index) diff --git a/matchzoo/dataloader/callbacks/padding.py b/matchzoo/dataloader/callbacks/padding.py index 27a176b..b38ce46 100755 --- a/matchzoo/dataloader/callbacks/padding.py +++ b/matchzoo/dataloader/callbacks/padding.py @@ -1,10 +1,35 @@ import typing +from collections import Iterable import numpy as np from matchzoo.engine.base_callback import BaseCallback +def _infer_dtype(value): + """Infer the dtype for the features. + + It is required as the input is usually array of objects before padding. + """ + while isinstance(value, (list, tuple)) and len(value) > 0: + value = value[0] + + if not isinstance(value, Iterable): + return np.array(value).dtype + + if value is not None and len(value) > 0 and np.issubdtype( + np.array(value).dtype, np.generic): + dtype = np.array(value[0]).dtype + else: + dtype = value.dtype + + # Single Precision + if dtype == np.double: + dtype = np.float32 + + return dtype + + def _padding_2D(input, output, mode: str = 'pre'): """ Pad the input 2D-tensor to the output 2D-tensor. 
@@ -122,24 +147,26 @@ def on_batch_unpacked(self, x: dict, y: np.ndarray): pad_length_right = self._fixed_length_right for key, value in x.items(): + dtype = _infer_dtype(value) + if key == 'text_left': padded_value = np.full([batch_size, pad_length_left], - self._pad_word_value, dtype=value.dtype) + self._pad_word_value, dtype=dtype) _padding_2D(value, padded_value, self._pad_word_mode) elif key == 'text_right': padded_value = np.full([batch_size, pad_length_right], - self._pad_word_value, dtype=value.dtype) + self._pad_word_value, dtype=dtype) _padding_2D(value, padded_value, self._pad_word_mode) elif key == 'ngram_left': padded_value = np.full( [batch_size, pad_length_left, ngram_length], - self._pad_ngram_value, dtype=value.dtype + self._pad_ngram_value, dtype=dtype ) _padding_3D(value, padded_value, self._pad_ngram_mode) elif key == 'ngram_right': padded_value = np.full( [batch_size, pad_length_right, ngram_length], - self._pad_ngram_value, dtype=value.dtype + self._pad_ngram_value, dtype=dtype ) _padding_3D(value, padded_value, self._pad_ngram_mode) else: @@ -193,18 +220,21 @@ def on_batch_unpacked(self, x: dict, y: np.ndarray): if key != 'text_left' and key != 'text_right' and \ key != 'match_histogram': continue - elif key == 'text_left': + + dtype = _infer_dtype(value) + + if key == 'text_left': padded_value = np.full([batch_size, pad_length_left], - self._pad_value, dtype=value.dtype) + self._pad_value, dtype=dtype) _padding_2D(value, padded_value, self._pad_mode) elif key == 'text_right': padded_value = np.full([batch_size, pad_length_right], - self._pad_value, dtype=value.dtype) + self._pad_value, dtype=dtype) _padding_2D(value, padded_value, self._pad_mode) else: # key == 'match_histogram' padded_value = np.full( [batch_size, pad_length_left, bin_size], - self._pad_value, dtype=value.dtype) + self._pad_value, dtype=dtype) _padding_3D(value, padded_value, self._pad_mode) x[key] = padded_value diff --git a/matchzoo/dataloader/dataloader.py b/matchzoo/dataloader/dataloader.py index ca33387..2266118 100755 --- a/matchzoo/dataloader/dataloader.py +++ b/matchzoo/dataloader/dataloader.py @@ -1,16 +1,13 @@ """Basic data loader.""" import typing - import math -import random -import collections + import numpy as np import torch from torch.utils import data +from matchzoo.dataloader.dataset import Dataset from matchzoo.engine.base_callback import BaseCallback -from matchzoo.dataloader.sampler import (SequentialSampler, RandomSampler, - SortedSampler, BatchSampler) class DataLoader(object): @@ -18,16 +15,10 @@ class DataLoader(object): DataLoader that loads batches of data from a Dataset. :param dataset: The Dataset object to load data from. - :param batch_size: Batch_size. (default: 32) :param device: The desired device of returned tensor. Default: if None, use the current device. If `torch.device` or int, use device specified by user. If list, the first item will be used. :param stage: One of "train", "dev", and "test". (default: "train") - :param resample: Whether to resample data between epochs. only effective - when `mode` of dataset is "pair". (default: `True`) - :param shuffle: Whether to shuffle data between epochs. (default: `False`) - :param sort: Whether to sort data according to length_right. (default: - `True`) :param callback: BaseCallback. See `matchzoo.engine.base_callback.BaseCallback` for more details. 
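The `on_batch_unpacked` hunks above route every padded buffer through `_padding_2D` or `_padding_3D`, whose bodies fall outside this diff. Below is a self-contained sketch of the pre/post convention they are assumed to implement, with 'post' left-aligning each row and 'pre' right-aligning it; `pad_2d` is a hypothetical stand-in, not the library function:

```python
import numpy as np

def pad_2d(rows, pad_length, pad_value=0, mode='post', dtype=np.int64):
    """Pad variable-length rows into a (len(rows), pad_length) matrix."""
    out = np.full((len(rows), pad_length), pad_value, dtype=dtype)
    for i, row in enumerate(rows):
        row = row[:pad_length]                    # truncate overlong rows
        if mode == 'post':
            out[i, :len(row)] = row               # pad values on the right
        else:                                     # mode == 'pre'
            out[i, pad_length - len(row):] = row  # pad values on the left
    return out

print(pad_2d([[1, 2, 3], [4, 5]], pad_length=5, mode='post'))
# [[1 2 3 0 0]
#  [4 5 0 0 0]]
print(pad_2d([[1, 2, 3], [4, 5]], pad_length=5, mode='pre'))
# [[0 0 1 2 3]
#  [0 0 0 4 5]]
```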
:param pin_momory: If set to `True`, tensors will be copied into @@ -45,7 +36,8 @@ class DataLoader(object): >>> data_pack = mz.datasets.toy.load_data(stage='train') >>> preprocessor = mz.preprocessors.BasicPreprocessor() >>> data_processed = preprocessor.fit_transform(data_pack) - >>> dataset = mz.dataloader.Dataset(data_processed, mode='point') + >>> dataset = mz.dataloader.Dataset( + ... data_processed, mode='point', batch_size=32) >>> padding_callback = mz.dataloader.callbacks.BasicPadding() >>> dataloader = mz.dataloader.DataLoader( ... dataset, stage='train', callback=padding_callback) @@ -56,13 +48,9 @@ class DataLoader(object): def __init__( self, - dataset: data.Dataset, - batch_size: int = 32, + dataset: Dataset, device: typing.Union[torch.device, int, list, None] = None, stage='train', - resample: bool = True, - shuffle: bool = False, - sort: bool = True, callback: BaseCallback = None, pin_memory: bool = False, timeout: int = 0, @@ -74,10 +62,6 @@ def __init__( raise ValueError(f"{stage} is not a valid stage type." f"Must be one of `train`, `dev`, `test`.") - if shuffle and sort: - raise ValueError(f"parameters `shuffle` and `sort` conflict, " - f"should not both be `True`.") - if isinstance(device, list) and len(device): device = device[0] elif not (isinstance(device, torch.device) or isinstance(device, int)): @@ -85,68 +69,44 @@ def __init__( "cuda" if torch.cuda.is_available() else "cpu") self._dataset = dataset - self._batch_size = batch_size - self._shuffle = shuffle - self._sort = sort - self._resample = resample - self._pin_momory = pin_memory self._timeout = timeout self._num_workers = num_workers self._worker_init_fn = worker_init_fn - self._device = device self._stage = stage self._callback = callback - self._dataloader = None + self._dataloader = data.DataLoader( + self._dataset, + batch_size=None, + shuffle=False, + collate_fn=lambda x: x, + batch_sampler=None, + num_workers=self._num_workers, + pin_memory=self._pin_momory, + timeout=self._timeout, + worker_init_fn=self._worker_init_fn, + ) def __len__(self) -> int: """Get the total number of batches.""" - return math.ceil(len(self._dataset) / self._batch_size) + return len(self._dataset) @property def id_left(self) -> np.ndarray: """`id_left` getter.""" - indices = sum(self._dataset.index_pool[:], []) - x, _ = self._dataset[indices] + x, _ = self._dataset[:] return x['id_left'] @property def label(self) -> np.ndarray: """`label` getter.""" - indices = sum(self._dataset.index_pool[:], []) - _, y = self._dataset[indices] + _, y = self._dataset[:] return y.squeeze() if y is not None else None - def init_epoch(self): - """Resample, shuffle or sort the dataset for a new epoch.""" - if self._resample: - self._dataset.sample() - - if not self._shuffle and not self._sort: - sampler = SequentialSampler(self._dataset) - elif not self._shuffle and self._sort: - sampler = SortedSampler(self._dataset) - elif self._shuffle and not self._sort: - sampler = RandomSampler(self._dataset) - - batch_sampler = BatchSampler( - sampler, self._batch_size) - - self._dataloader = data.DataLoader( - self._dataset, - collate_fn=mz_collate, - batch_sampler=batch_sampler, - num_workers=self._num_workers, - pin_memory=False, - timeout=self._timeout, - worker_init_fn=self._worker_init_fn, - ) - def __iter__(self) -> typing.Tuple[dict, torch.tensor]: """Iteration.""" - self.init_epoch() for batch_data in self._dataloader: x, y = batch_data self._handle_callbacks_on_batch_unpacked(x, y) @@ -156,48 +116,19 @@ def __iter__(self) -> typing.Tuple[dict, 
torch.tensor]:
                 if key == 'id_left' or key == 'id_right':
                     continue
                 batch_x[key] = torch.tensor(
-                    value.tolist(),
-                    device=self._device,
-                    pin_memory=self._pin_momory)
+                    value, device=self._device)

             if self._stage == 'test':
                 yield batch_x, None
             else:
                 if y.dtype == 'int':  # task='classification'
                     batch_y = torch.tensor(
-                        y.squeeze(axis=-1), dtype=torch.long,
-                        device=self._device, pin_memory=self._pin_momory
-                    )
+                        y.squeeze(axis=-1), dtype=torch.long, device=self._device)
                 else:  # task='ranking'
                     batch_y = torch.tensor(
-                        y, dtype=torch.float,
-                        device=self._device, pin_memory=self._pin_momory
-                    )
+                        y, dtype=torch.float, device=self._device)
                 yield batch_x, batch_y

     def _handle_callbacks_on_batch_unpacked(self, x, y):
         if self._callback is not None:
             self._callback.on_batch_unpacked(x, y)
-
-
-def mz_collate(batch):
-    """Put each data field into an array with outer dimension batch size."""
-
-    batch_x = collections.defaultdict(list)
-    batch_y = []
-
-    for x, y in batch:
-        for key in x.keys():
-            batch_x[key].append(np.squeeze(x[key], axis=0))
-        if y is not None:
-            batch_y.append(np.squeeze(y, axis=0))
-
-    for key in batch_x.keys():
-        batch_x[key] = np.array(batch_x[key])
-
-    if len(batch_y) == 0:
-        batch_y = None
-    else:
-        batch_y = np.array(batch_y)
-
-    return batch_x, batch_y
diff --git a/matchzoo/dataloader/dataset.py b/matchzoo/dataloader/dataset.py
index 9891102..6c71db7 100755
--- a/matchzoo/dataloader/dataset.py
+++ b/matchzoo/dataloader/dataset.py
@@ -1,7 +1,8 @@
 """A basic class representing a Dataset."""
 import typing
+import math
+from collections.abc import Iterable

-import functools
 import numpy as np
 import pandas as pd
 from torch.utils import data
@@ -10,7 +11,7 @@ from matchzoo.engine.base_callback import BaseCallback


-class Dataset(data.Dataset):
+class Dataset(data.IterableDataset):
     """
     Dataset that is built from a data pack.

@@ -20,21 +21,26 @@ class Dataset(data.Dataset):
         `mode` is "pair". (default: 1)
     :param num_neg: Number of negative samples per instance, only effective
         when `mode` is "pair". (default: 1)
-    :param callbacks: Callbacks. See `matchzoo.data_generator.callbacks` for
-        more details.
+    :param batch_size: Batch size. (default: 32)
+    :param resample: Whether to resample for each epoch, only effective when
+        `mode` is "pair". (default: `False`)
+    :param shuffle: Whether to shuffle the samples/instances. (default: `True`)
+    :param sort: Whether to sort data according to length_right. (default: `False`)
+    :param callbacks: Callbacks. See `matchzoo.dataloader.callbacks` for more details.

     Examples:
         >>> import matchzoo as mz
         >>> data_pack = mz.datasets.toy.load_data(stage='train')
         >>> preprocessor = mz.preprocessors.BasicPreprocessor()
         >>> data_processed = preprocessor.fit_transform(data_pack)
-        >>> dataset_point = mz.dataloader.Dataset(data_processed, mode='point')
+        >>> dataset_point = mz.dataloader.Dataset(
+        ...     data_processed, mode='point', batch_size=32)
         >>> len(dataset_point)
-        100
+        4
         >>> dataset_pair = mz.dataloader.Dataset(
-        ...     data_processed, mode='pair', num_neg=2)
+        ...     data_processed, mode='pair', num_dup=2, num_neg=2, batch_size=32)
         >>> len(dataset_pair)
-        5
+        1

     """

@@ -44,6 +50,10 @@ def __init__(
         mode='point',
         num_dup: int = 1,
         num_neg: int = 1,
+        batch_size: int = 32,
+        resample: bool = False,
+        shuffle: bool = True,
+        sort: bool = False,
         callbacks: typing.List[BaseCallback] = None
     ):
         """Init."""
@@ -54,48 +64,86 @@ def __init__(
             raise ValueError(f"{mode} is not a valid mode type."
f"Must be one of `point`, `pair` or `list`.") + if shuffle and sort: + raise ValueError(f"parameters `shuffle` and `sort` conflict, " + f"should not both be `True`.") + + data_pack = data_pack.copy() self._mode = mode self._num_dup = num_dup self._num_neg = num_neg + self._batch_size = batch_size + self._resample = (resample if mode != 'point' else False) + self._shuffle = shuffle + self._sort = sort self._orig_relation = data_pack.relation self._callbacks = callbacks + + if mode == 'pair': + data_pack.relation = self._reorganize_pair_wise( + relation=self._orig_relation, + num_dup=num_dup, + num_neg=num_neg + ) + self._data_pack = data_pack - self._index_pool = None - self.sample() + self._batch_indices = None - def __len__(self) -> int: - """Get the total number of instances.""" - return len(self._index_pool) + self.reset_index() - def __getitem__(self, item: int) -> typing.Tuple[dict, np.ndarray]: - """Get a set of instances from index idx. + def __getitem__(self, item) -> typing.Tuple[dict, np.ndarray]: + """Get a batch from index idx. - :param item: the index of the instance. + :param item: the index of the batch. """ - item_data_pack = self._data_pack[item] - self._handle_callbacks_on_batch_data_pack(item_data_pack) - x, y = item_data_pack.unpack() + if isinstance(item, slice): + indices = sum(self._batch_indices[item], []) + elif isinstance(item, Iterable): + indices = [self._batch_indices[i] for i in item] + else: + indices = self._batch_indices[item] + batch_data_pack = self._data_pack[indices] + self._handle_callbacks_on_batch_data_pack(batch_data_pack) + x, y = batch_data_pack.unpack() self._handle_callbacks_on_batch_unpacked(x, y) return x, y - def _handle_callbacks_on_batch_data_pack(self, batch_data_pack): - for callback in self._callbacks: - callback.on_batch_data_pack(batch_data_pack) - - def _handle_callbacks_on_batch_unpacked(self, x, y): - for callback in self._callbacks: - callback.on_batch_unpacked(x, y) + def __len__(self) -> int: + """Get the total number of batches.""" + return len(self._batch_indices) + + def __iter__(self): + """Create a generator that iterate over the Batches.""" + if self._resample or self._shuffle: + self.on_epoch_end() + for i in range(len(self)): + yield self[i] + + def on_epoch_end(self): + """Reorganize the index array if needed.""" + if self._resample: + self.resample_data() + self.reset_index() + + def resample_data(self): + """Reorganize data.""" + if self.mode != 'point': + self._data_pack.relation = self._reorganize_pair_wise( + relation=self._orig_relation, + num_dup=self._num_dup, + num_neg=self._num_neg + ) - def get_index_pool(self): + def reset_index(self): """ - Set the:attr:`_index_pool`. + Set the :attr:`_batch_indices`. - Here the :attr:`_index_pool` records the index of all the instances. + Here the :attr:`_batch_indices` records the index of all the instances. 
""" + # index pool: index -> instance index if self._mode == 'point': num_instances = len(self._data_pack) - index_pool = np.expand_dims(range(num_instances), axis=1).tolist() - return index_pool + index_pool = list(range(num_instances)) elif self._mode == 'pair': index_pool = [] step_size = self._num_neg + 1 @@ -106,48 +154,44 @@ def get_index_pool(self): indices = list(range(lower, upper)) if indices: index_pool.append(indices) - return index_pool elif self._mode == 'list': raise NotImplementedError( - f'{self._mode} data generator not implemented.') + f'{self._mode} dataset not implemented.') else: raise ValueError(f"{self._mode} is not a valid mode type" f"Must be one of `point`, `pair` or `list`.") - def sample(self): - """Resample the instances from data pack.""" - if self._mode == 'pair': - self._data_pack.relation = self._reorganize_pair_wise( - relation=self._orig_relation, - num_dup=self._num_dup, - num_neg=self._num_neg - ) - self._index_pool = self.get_index_pool() + if self._shuffle: + np.random.shuffle(index_pool) - def shuffle(self): - """Shuffle the instances.""" - np.random.shuffle(self._index_pool) + if self._sort: + old_index_pool = index_pool - def sort(self): - """Sort the instances by length_right.""" - old_index_pool = self._index_pool - max_instance_right_length = [] - for row in range(len(old_index_pool)): - instance = self._data_pack[old_index_pool[row]].unpack()[0] - max_instance_right_length.append(max(instance['length_right'])) - sort_index = np.argsort(max_instance_right_length) + max_instance_right_length = [] + for row in range(len(old_index_pool)): + instance = self._data_pack[old_index_pool[row]].unpack()[0] + max_instance_right_length.append(max(instance['length_right'])) + sort_index = np.argsort(max_instance_right_length) - self._index_pool = [old_index_pool[index] for index in sort_index] + index_pool = [old_index_pool[index] for index in sort_index] - @property - def data_pack(self): - """`data_pack` getter.""" - return self._data_pack + # batch_indices: index -> batch of indices + self._batch_indices = [] + for i in range(math.ceil(num_instances / self._batch_size)): + lower = self._batch_size * i + upper = self._batch_size * (i + 1) + candidates = index_pool[lower:upper] + if self._mode == 'pair': + candidates = sum(candidates, []) + self._batch_indices.append(candidates) + + def _handle_callbacks_on_batch_data_pack(self, batch_data_pack): + for callback in self._callbacks: + callback.on_batch_data_pack(batch_data_pack) - @data_pack.setter - def data_pack(self, value): - """`data_pack` setter.""" - self._data_pack = value + def _handle_callbacks_on_batch_unpacked(self, x, y): + for callback in self._callbacks: + callback.on_batch_unpacked(x, y) @property def callbacks(self): @@ -168,6 +212,8 @@ def num_neg(self): def num_neg(self, value): """`num_neg` setter.""" self._num_neg = value + self.resample_data() + self.reset_index() @property def num_dup(self): @@ -178,21 +224,62 @@ def num_dup(self): def num_dup(self, value): """`num_dup` setter.""" self._num_dup = value + self.resample_data() + self.reset_index() @property def mode(self): """`mode` getter.""" return self._mode - @mode.setter - def mode(self, value): - """`mode` setter.""" - self._mode = value + @property + def batch_size(self): + """`batch_size` getter.""" + return self._batch_size + + @batch_size.setter + def batch_size(self, value): + """`batch_size` setter.""" + self._batch_size = value + self.reset_index() + + @property + def shuffle(self): + """`shuffle` getter.""" + return 
self._shuffle + + @shuffle.setter + def shuffle(self, value): + """`shuffle` setter.""" + self._shuffle = value + self.reset_index() + + @property + def sort(self): + """`sort` getter.""" + return self._sort + + @sort.setter + def sort(self, value): + """`sort` setter.""" + self._sort = value + self.reset_index() + + @property + def resample(self): + """`resample` getter.""" + return self._resample + + @resample.setter + def resample(self, value): + """`resample` setter.""" + self._resample = value + self.reset_index() @property - def index_pool(self): - """`index_pool` getter.""" - return self._index_pool + def batch_indices(self): + """`batch_indices` getter.""" + return self._batch_indices @classmethod def _reorganize_pair_wise( diff --git a/matchzoo/dataloader/sampler.py b/matchzoo/dataloader/sampler.py deleted file mode 100644 index 4455b76..0000000 --- a/matchzoo/dataloader/sampler.py +++ /dev/null @@ -1,105 +0,0 @@ -"""Sampler class for dataloader.""" -import typing - -import math -import numpy as np -from torch.utils.data import Sampler, Dataset - -import matchzoo as mz - - -class SequentialSampler(Sampler): - """ - Samples elements sequentially, always in the same order. - - :param dataset: The dataset to sample from. - """ - - def __init__(self, dataset: Dataset): - """Init.""" - self._dataset = dataset - - def __iter__(self): - """Get the indices of a batch.""" - return iter(self._dataset.index_pool) - - def __len__(self): - """Get the total number of instances.""" - return len(self._dataset) - - -class SortedSampler(Sampler): - """ - Samples elements according to `length_right`. - - :param dataset: The dataset to sample from. - """ - - def __init__(self, dataset: Dataset): - """Init.""" - self._dataset = dataset - - def __iter__(self): - """Get the indices of a batch.""" - self._dataset.sort() - return iter(self._dataset.index_pool) - - def __len__(self): - """Get the total number of instances.""" - return len(self._dataset) - - -class RandomSampler(Sampler): - """ - Samples elements randomly. - - :param dataset: The dataset to sample from. - """ - - def __init__(self, dataset: Dataset): - """Init.""" - self._dataset = dataset - - def __iter__(self): - """Get the indices of a batch.""" - self._dataset.shuffle() - return iter(self._dataset.index_pool) - - def __len__(self): - """Get the total number of instances.""" - return len(self._dataset) - - -class BatchSampler(Sampler): - """ - Wraps another sampler to yield the indices of a batch. - - :param sampler: Base sampler. - :param batch_size: Size of a batch. - """ - - def __init__( - self, - sampler: Sampler, - batch_size: int = 32, - ): - """Init.""" - self._sampler = sampler - self._batch_size = batch_size - - def __iter__(self): - """Get the indices of a batch.""" - batch = [] - for idx in self._sampler: - batch.append(idx) - if len(batch) == self._batch_size: - batch = sum(batch, []) - yield batch - batch = [] - if len(batch) > 0: - batch = sum(batch, []) - yield batch - - def __len__(self): - """Get the total number of batch.""" - return math.ceil(len(self._sampler) / self._batch_size) diff --git a/matchzoo/preprocessors/bert_preprocessor.py b/matchzoo/preprocessors/bert_preprocessor.py index 43e6ee4..b3332c2 100644 --- a/matchzoo/preprocessors/bert_preprocessor.py +++ b/matchzoo/preprocessors/bert_preprocessor.py @@ -34,6 +34,8 @@ def transform(self, data_pack: DataPack, verbose: int = 1) -> DataPack: :return: Transformed data as :class:`DataPack` object. 
""" + data_pack = data_pack.copy() + data_pack.apply_on_text(self._tokenizer.encode, mode='both', inplace=True, verbose=verbose) data_pack.append_text_length(inplace=True, verbose=verbose) diff --git a/matchzoo/preprocessors/naive_preprocessor.py b/matchzoo/preprocessors/naive_preprocessor.py index 15e0e47..843929f 100755 --- a/matchzoo/preprocessors/naive_preprocessor.py +++ b/matchzoo/preprocessors/naive_preprocessor.py @@ -54,6 +54,8 @@ def transform(self, data_pack: DataPack, verbose: int = 1) -> DataPack: :return: Transformed data as :class:`DataPack` object. """ + data_pack = data_pack.copy() + units_ = self._default_units() units_.append(self._context['vocab_unit']) units_.append( diff --git a/matchzoo/preprocessors/units/vocabulary.py b/matchzoo/preprocessors/units/vocabulary.py index 783df94..5c3d393 100755 --- a/matchzoo/preprocessors/units/vocabulary.py +++ b/matchzoo/preprocessors/units/vocabulary.py @@ -59,7 +59,8 @@ def fit(self, tokens: list): self._context['term_index'][self._oov] = 1 self._context['index_term'][0] = self._pad self._context['index_term'][1] = self._oov - terms = set(tokens) + + terms = sorted(set(tokens)) for index, term in enumerate(terms): self._context['term_index'][term] = index + 2 self._context['index_term'][index + 2] = term diff --git a/matchzoo/version.py b/matchzoo/version.py index a570ec0..64c95ca 100644 --- a/matchzoo/version.py +++ b/matchzoo/version.py @@ -1,3 +1,3 @@ """Matchzoo version file.""" -__version__ = '1.1' +__version__ = '1.1.1' diff --git a/requirements.txt b/requirements.txt index 0cc6e09..928e5c6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ pytorch-transformers >= 1.1.0 tabulate >= 0.8.3 nltk >= 3.4.3 numpy >= 1.16.4 -tqdm >= 4.32.2 +tqdm == 4.38.0 dill >= 0.2.9 hyperopt == 0.1.2 pandas == 0.24.2 diff --git a/setup.py b/setup.py index 7483da5..7def580 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ 'pytorch-transformers >= 1.1.0', 'nltk >= 3.4.3', 'numpy >= 1.16.4', - 'tqdm >= 4.32.2', + 'tqdm == 4.38.0', 'dill >= 0.2.9', 'pandas == 0.24.2', 'networkx >= 2.3', diff --git a/tests/dataloader/test_callbacks.py b/tests/dataloader/test_callbacks.py index 337702b..b37c287 100644 --- a/tests/dataloader/test_callbacks.py +++ b/tests/dataloader/test_callbacks.py @@ -16,18 +16,17 @@ def train_raw(): def test_basic_padding(train_raw): preprocessor = preprocessors.BasicPreprocessor() data_preprocessed = preprocessor.fit_transform(train_raw, verbose=0) - dataset = Dataset(data_preprocessed, mode='point') + dataset = Dataset(data_preprocessed, batch_size=5, mode='point') pre_fixed_padding = callbacks.BasicPadding( fixed_length_left=5, fixed_length_right=5, pad_word_mode='pre', with_ngram=False) - dataloader = DataLoader( - dataset, batch_size=5, callback=pre_fixed_padding) + dataloader = DataLoader(dataset, callback=pre_fixed_padding) for batch in dataloader: assert batch[0]['text_left'].shape == (5, 5) assert batch[0]['text_right'].shape == (5, 5) post_padding = callbacks.BasicPadding(pad_word_mode='post', with_ngram=False) - dataloader = DataLoader(dataset, batch_size=5, callback=post_padding) + dataloader = DataLoader(dataset, callback=post_padding) for batch in dataloader: max_left_len = max(batch[0]['length_left'].detach().cpu().numpy()) max_right_len = max(batch[0]['length_right'].detach().cpu().numpy()) @@ -45,19 +44,18 @@ def test_drmm_padding(train_raw): histgram_callback = callbacks.Histogram( embedding_matrix=embedding_matrix, bin_size=30, hist_mode='LCH') dataset = Dataset( - data_preprocessed, 
mode='point', callbacks=[histgram_callback]) + data_preprocessed, mode='point', batch_size=5, callbacks=[histgram_callback]) pre_fixed_padding = callbacks.DRMMPadding( fixed_length_left=5, fixed_length_right=5, pad_mode='pre') - dataloader = DataLoader( - dataset, batch_size=5, callback=pre_fixed_padding) + dataloader = DataLoader(dataset, callback=pre_fixed_padding) for batch in dataloader: assert batch[0]['text_left'].shape == (5, 5) assert batch[0]['text_right'].shape == (5, 5) assert batch[0]['match_histogram'].shape == (5, 5, 30) post_padding = callbacks.DRMMPadding(pad_mode='post') - dataloader = DataLoader(dataset, batch_size=5, callback=post_padding) + dataloader = DataLoader(dataset, callback=post_padding) for batch in dataloader: max_left_len = max(batch[0]['length_left'].detach().cpu().numpy()) max_right_len = max(batch[0]['length_right'].detach().cpu().numpy()) @@ -69,18 +67,17 @@ def test_drmm_padding(train_raw): def test_bert_padding(train_raw): preprocessor = preprocessors.BertPreprocessor() data_preprocessed = preprocessor.transform(train_raw, verbose=0) - dataset = Dataset(data_preprocessed, mode='point') + dataset = Dataset(data_preprocessed, mode='point', batch_size=5) pre_fixed_padding = callbacks.BertPadding( fixed_length_left=5, fixed_length_right=5, pad_mode='pre') - dataloader = DataLoader( - dataset, batch_size=5, callback=pre_fixed_padding) + dataloader = DataLoader(dataset, callback=pre_fixed_padding) for batch in dataloader: assert batch[0]['text_left'].shape == (5, 7) assert batch[0]['text_right'].shape == (5, 6) post_padding = callbacks.BertPadding(pad_mode='post') - dataloader = DataLoader(dataset, batch_size=5, callback=post_padding) + dataloader = DataLoader(dataset, callback=post_padding) for batch in dataloader: max_left_len = max(batch[0]['length_left'].detach().cpu().numpy()) max_right_len = max(batch[0]['length_right'].detach().cpu().numpy()) diff --git a/tests/dataloader/test_dataset.py b/tests/dataloader/test_dataset.py new file mode 100644 index 0000000..1573f24 --- /dev/null +++ b/tests/dataloader/test_dataset.py @@ -0,0 +1,42 @@ +import matchzoo as mz +from matchzoo import preprocessors +from matchzoo.dataloader import Dataset + + +def test_dataset(): + data_pack = mz.datasets.toy.load_data('train', task='ranking') + preprocessor = mz.preprocessors.BasicPreprocessor() + data_processed = preprocessor.fit_transform(data_pack) + + dataset_point = mz.dataloader.Dataset( + data_processed, + mode='point', + batch_size=1, + resample=False, + shuffle=True, + sort=False + ) + dataset_point.batch_size = 10 + dataset_point.shuffle = not dataset_point.shuffle + dataset_point.sort = not dataset_point.sort + assert len(dataset_point.batch_indices) == 10 + + dataset_pair = mz.dataloader.Dataset( + data_processed, + mode='pair', + num_dup=1, + num_neg=1, + batch_size=1, + resample=True, + shuffle=False, + sort=False + ) + assert len(dataset_pair) == 5 + dataset_pair.num_dup = dataset_pair.num_dup + 1 + assert len(dataset_pair) == 10 + dataset_pair.num_neg = dataset_pair.num_neg + 2 + assert len(dataset_pair) == 10 + dataset_pair.batch_size = dataset_pair.batch_size + 1 + assert len(dataset_pair) == 5 + dataset_pair.resample = not dataset_pair.resample + assert len(dataset_pair) == 5 diff --git a/tutorials/ranking/drmmtks.ipynb b/tutorials/ranking/drmmtks.ipynb index ee2bf2c..d074fd8 100644 --- a/tutorials/ranking/drmmtks.ipynb +++ b/tutorials/ranking/drmmtks.ipynb @@ -14,20 +14,66 @@ "name": "stdout", "output_type": "stream", "text": [ - "matchzoo version 1.0\n", 
- "`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]\n", + "matchzoo version 1.1.1\n" + ] + } + ], + "source": [ + "import torch\n", + "import numpy as np\n", + "import pandas as pd\n", + "import matchzoo as mz\n", + "print('matchzoo version', mz.__version__)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]\n" + ] + } + ], + "source": [ + "ranking_task = mz.tasks.Ranking(losses=mz.losses.RankHingeLoss())\n", + "ranking_task.metrics = [\n", + " mz.metrics.NormalizedDiscountedCumulativeGain(k=3),\n", + " mz.metrics.NormalizedDiscountedCumulativeGain(k=5),\n", + " mz.metrics.MeanAveragePrecision()\n", + "]\n", + "print(\"`ranking_task` initialized with metrics\", ranking_task.metrics)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ "data loading ...\n", "data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`\n" ] } ], "source": [ - "%run init.ipynb" + "print('data loading ...')\n", + "train_pack_raw = mz.datasets.wiki_qa.load_data('train', task=ranking_task)\n", + "dev_pack_raw = mz.datasets.wiki_qa.load_data('dev', task=ranking_task, filtered=True)\n", + "test_pack_raw = mz.datasets.wiki_qa.load_data('test', task=ranking_task, filtered=True)\n", + "print('data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`')" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2019-03-20T09:35:56.633000Z", @@ -45,7 +91,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2019-03-20T09:36:06.249211Z", @@ -57,41 +103,41 @@ "name": "stderr", "output_type": "stream", "text": [ - "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 9631.61it/s]\n", - "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:03<00:00, 5597.16it/s]\n", - "Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 998192.22it/s]\n", - "Building FrequencyFilter from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 144536.21it/s]\n", - "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 162628.74it/s]\n", - "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 860891.16it/s]\n", - "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 864576.45it/s]\n", - "Building Vocabulary from a datapack.: 100%|██████████| 404432/404432 [00:00<00:00, 3220926.79it/s]\n", - "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 10322.95it/s]\n", - "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:03<00:00, 5702.03it/s]\n", - "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 93226.11it/s]\n", - "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 162708.08it/s]\n", - "Processing text_right 
with transform: 100%|██████████| 18841/18841 [00:00<00:00, 121160.69it/s]\n", - "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 737406.48it/s]\n", - "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 856954.13it/s]\n", - "Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 508269.59it/s]\n", - "Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 637858.13it/s]\n", - "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 122/122 [00:00<00:00, 7459.91it/s]\n", - "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 1115/1115 [00:00<00:00, 4777.49it/s]\n", - "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 141066.87it/s]\n", - "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 75550.73it/s]\n", - "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 93212.33it/s]\n", - "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 82947.82it/s]\n", - "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 396091.21it/s]\n", - "Processing length_left with len: 100%|██████████| 122/122 [00:00<00:00, 121487.44it/s]\n", - "Processing length_right with len: 100%|██████████| 1115/1115 [00:00<00:00, 522939.61it/s]\n", - "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 237/237 [00:00<00:00, 9573.09it/s]\n", - "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2300/2300 [00:00<00:00, 4603.63it/s]\n", - "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 135324.80it/s]\n", - "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 134658.64it/s]\n", - "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 130363.50it/s]\n", - "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 265080.01it/s]\n", - "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 619423.35it/s]\n", - "Processing length_left with len: 100%|██████████| 237/237 [00:00<00:00, 292539.74it/s]\n", - "Processing length_right with len: 100%|██████████| 2300/2300 [00:00<00:00, 680077.49it/s]\n" + "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 11001.23it/s]\n", + "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:03<00:00, 5972.72it/s]\n", + "Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 995451.11it/s]\n", + "Building FrequencyFilter from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 165418.50it/s]\n", + "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 106788.94it/s]\n", + "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 603091.37it/s]\n", + "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 730393.10it/s]\n", + "Building Vocabulary from a datapack.: 100%|██████████| 404432/404432 [00:00<00:00, 3237271.33it/s]\n", + "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 11244.91it/s]\n", + "Processing text_right with chain_transform of Tokenize => Lowercase => 
PuncRemoval: 100%|██████████| 18841/18841 [00:03<00:00, 5965.35it/s]\n", + "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 151505.92it/s]\n", + "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 222132.82it/s]\n", + "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 142060.31it/s]\n", + "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 556055.07it/s]\n", + "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 787531.83it/s]\n", + "Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 779462.65it/s]\n", + "Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 908091.90it/s]\n", + "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 122/122 [00:00<00:00, 10722.40it/s]\n", + "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 1115/1115 [00:00<00:00, 5876.72it/s]\n", + "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 153493.80it/s]\n", + "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 87024.67it/s]\n", + "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 112074.60it/s]\n", + "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 137407.38it/s]\n", + "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 487506.41it/s]\n", + "Processing length_left with len: 100%|██████████| 122/122 [00:00<00:00, 162704.32it/s]\n", + "Processing length_right with len: 100%|██████████| 1115/1115 [00:00<00:00, 637493.04it/s]\n", + "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 237/237 [00:00<00:00, 10475.15it/s]\n", + "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2300/2300 [00:00<00:00, 5817.60it/s]\n", + "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 153176.44it/s]\n", + "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 172249.19it/s]\n", + "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 141452.21it/s]\n", + "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 240631.82it/s]\n", + "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 707043.33it/s]\n", + "Processing length_left with len: 100%|██████████| 237/237 [00:00<00:00, 281122.75it/s]\n", + "Processing length_right with len: 100%|██████████| 2300/2300 [00:00<00:00, 782900.44it/s]\n" ] } ], @@ -103,7 +149,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2019-03-20T09:36:06.262937Z", @@ -114,13 +160,13 @@ { "data": { "text/plain": [ - "{'filter_unit': ,\n", - " 'vocab_unit': ,\n", + "{'filter_unit': ,\n", + " 'vocab_unit': ,\n", " 'vocab_size': 16675,\n", " 'embedding_input_dim': 16675}" ] }, - "execution_count": 4, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -131,7 +177,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -144,7 +190,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -152,16 +198,20 @@ " data_pack=train_pack_processed,\n", " 
mode='pair',\n", " num_dup=2,\n", - " num_neg=1\n", + " num_neg=1,\n", + " batch_size=20,\n", + " resample=True,\n", + " sort=False\n", ")\n", "testset = mz.dataloader.Dataset(\n", - " data_pack=test_pack_processed\n", + " data_pack=test_pack_processed,\n", + " batch_size=20\n", ")" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -169,15 +219,11 @@ "\n", "trainloader = mz.dataloader.DataLoader(\n", " dataset=trainset,\n", - " batch_size=20,\n", " stage='train',\n", - " resample=True,\n", - " sort=True,\n", " callback=padding_callback\n", ")\n", "testloader = mz.dataloader.DataLoader(\n", " dataset=testset,\n", - " batch_size=20,\n", " stage='dev',\n", " callback=padding_callback\n", ")" @@ -185,7 +231,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": { "ExecuteTime": { "end_time": "2019-03-20T09:36:06.413530Z", @@ -198,26 +244,26 @@ "output_type": "stream", "text": [ "DRMMTKS(\n", - " (embedding): Embedding(16675, 100)\n", + " (embedding): Embedding(16675, 100, padding_idx=0)\n", " (attention): Attention(\n", " (linear): Linear(in_features=100, out_features=1, bias=False)\n", " )\n", " (mlp): Sequential(\n", " (0): Sequential(\n", " (0): Linear(in_features=10, out_features=128, bias=True)\n", - " (1): ReLU()\n", + " (1): Tanh()\n", " )\n", " (1): Sequential(\n", " (0): Linear(in_features=128, out_features=128, bias=True)\n", - " (1): ReLU()\n", + " (1): Tanh()\n", " )\n", " (2): Sequential(\n", " (0): Linear(in_features=128, out_features=128, bias=True)\n", - " (1): ReLU()\n", + " (1): Tanh()\n", " )\n", " (3): Sequential(\n", " (0): Linear(in_features=128, out_features=1, bias=True)\n", - " (1): ReLU()\n", + " (1): Tanh()\n", " )\n", " )\n", " (out): Linear(in_features=1, out_features=1, bias=True)\n", @@ -233,7 +279,7 @@ "model.params['embedding'] = embedding_matrix\n", "model.params['mask_value'] = 0\n", "model.params['top_k'] = 10\n", - "model.params['mlp_activation_func'] = 'relu'\n", + "model.params['mlp_activation_func'] = 'tanh'\n", "\n", "model.build()\n", "\n", @@ -243,7 +289,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "metadata": { "ExecuteTime": { "end_time": "2019-03-20T09:36:06.422264Z", @@ -267,7 +313,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, "metadata": { "ExecuteTime": { "end_time": "2019-03-20T09:37:59.341616Z", @@ -278,7 +324,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "a8f2a42921f84720b169df671a88f6de", + "model_id": "9ce5526ec3f048c58e1ae640f98fcdc3", "version_major": 2, "version_minor": 0 }, @@ -293,15 +339,15 @@ "name": "stdout", "output_type": "stream", "text": [ - "[Iter-102 Loss-0.790]:\n", - " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5187 - normalized_discounted_cumulative_gain@5(0.0): 0.5819 - mean_average_precision(0.0): 0.5332\n", + "[Iter-102 Loss-0.817]:\n", + " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5808 - normalized_discounted_cumulative_gain@5(0.0): 0.6451 - mean_average_precision(0.0): 0.6026\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "97b0c215b42843b29435efa15bbda747", + "model_id": "56e22d6fedcb4b759c103ab803f63d4d", "version_major": 2, "version_minor": 0 }, @@ -316,15 +362,15 @@ "name": "stdout", "output_type": "stream", "text": [ - "[Iter-204 Loss-0.490]:\n", - " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5987 - 
normalized_discounted_cumulative_gain@5(0.0): 0.6472 - mean_average_precision(0.0): 0.61\n", + "[Iter-204 Loss-0.541]:\n", + " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5919 - normalized_discounted_cumulative_gain@5(0.0): 0.6427 - mean_average_precision(0.0): 0.6097\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "1d6a758a990543d993e540ba4e45436c", + "model_id": "e851902807b44f00aad7066af56acf53", "version_major": 2, "version_minor": 0 }, @@ -339,15 +385,15 @@ "name": "stdout", "output_type": "stream", "text": [ - "[Iter-306 Loss-0.382]:\n", - " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6006 - normalized_discounted_cumulative_gain@5(0.0): 0.6582 - mean_average_precision(0.0): 0.6125\n", + "[Iter-306 Loss-0.403]:\n", + " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6074 - normalized_discounted_cumulative_gain@5(0.0): 0.6558 - mean_average_precision(0.0): 0.6144\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "d20cda7567f9453ba3c918cc81c1e9d0", + "model_id": "d514f5bf08b14ecd8f21ff7756273488", "version_major": 2, "version_minor": 0 }, @@ -362,15 +408,15 @@ "name": "stdout", "output_type": "stream", "text": [ - "[Iter-408 Loss-0.259]:\n", - " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.607 - normalized_discounted_cumulative_gain@5(0.0): 0.6632 - mean_average_precision(0.0): 0.6084\n", + "[Iter-408 Loss-0.300]:\n", + " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5891 - normalized_discounted_cumulative_gain@5(0.0): 0.6367 - mean_average_precision(0.0): 0.5983\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "aea9c003f2fc49c4bec134d951e94ce2", + "model_id": "63d2f392a26d466cb242ae3021a6ac93", "version_major": 2, "version_minor": 0 }, @@ -385,15 +431,15 @@ "name": "stdout", "output_type": "stream", "text": [ - "[Iter-510 Loss-0.148]:\n", - " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6156 - normalized_discounted_cumulative_gain@5(0.0): 0.6649 - mean_average_precision(0.0): 0.6207\n", + "[Iter-510 Loss-0.240]:\n", + " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.57 - normalized_discounted_cumulative_gain@5(0.0): 0.6332 - mean_average_precision(0.0): 0.5815\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "bd0f8403661d4ded8d4543d891803b6f", + "model_id": "914aab2d97aa4f9f9ea581ef16f08d30", "version_major": 2, "version_minor": 0 }, @@ -408,15 +454,15 @@ "name": "stdout", "output_type": "stream", "text": [ - "[Iter-612 Loss-0.105]:\n", - " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6154 - normalized_discounted_cumulative_gain@5(0.0): 0.6655 - mean_average_precision(0.0): 0.6206\n", + "[Iter-612 Loss-0.184]:\n", + " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5711 - normalized_discounted_cumulative_gain@5(0.0): 0.6315 - mean_average_precision(0.0): 0.5864\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "7c64b2be24704e47aec7a06e1ada1ed5", + "model_id": "3163044009064314a506b0caf062d535", "version_major": 2, "version_minor": 0 }, @@ -431,15 +477,15 @@ "name": "stdout", "output_type": "stream", "text": [ - "[Iter-714 Loss-0.088]:\n", - " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6083 - normalized_discounted_cumulative_gain@5(0.0): 0.6641 - mean_average_precision(0.0): 0.6185\n", + "[Iter-714 Loss-0.136]:\n", + " Validation: 
normalized_discounted_cumulative_gain@3(0.0): 0.5783 - normalized_discounted_cumulative_gain@5(0.0): 0.6402 - mean_average_precision(0.0): 0.5988\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "5150bc5dabd34d76bb75dd8e37cbf198", + "model_id": "6b11570e69924c4da75946a746346cfa", "version_major": 2, "version_minor": 0 }, @@ -454,15 +500,15 @@ "name": "stdout", "output_type": "stream", "text": [ - "[Iter-816 Loss-0.059]:\n", - " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5778 - normalized_discounted_cumulative_gain@5(0.0): 0.6493 - mean_average_precision(0.0): 0.5995\n", + "[Iter-816 Loss-0.109]:\n", + " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5713 - normalized_discounted_cumulative_gain@5(0.0): 0.6242 - mean_average_precision(0.0): 0.5819\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "f240e1dba2e74be5b1cab49082d720cd", + "model_id": "7690d4c4b8aa410eabdf3a9b7cf11a4b", "version_major": 2, "version_minor": 0 }, @@ -477,15 +523,15 @@ "name": "stdout", "output_type": "stream", "text": [ - "[Iter-918 Loss-0.040]:\n", - " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6121 - normalized_discounted_cumulative_gain@5(0.0): 0.6726 - mean_average_precision(0.0): 0.6226\n", + "[Iter-918 Loss-0.087]:\n", + " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5872 - normalized_discounted_cumulative_gain@5(0.0): 0.6456 - mean_average_precision(0.0): 0.5928\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "b1ba7e90933442ffb94f360c52c47c9e", + "model_id": "46c95fd6416345e7b980a94b213a046f", "version_major": 2, "version_minor": 0 }, @@ -500,10 +546,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "[Iter-1020 Loss-0.049]:\n", - " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.609 - normalized_discounted_cumulative_gain@5(0.0): 0.6678 - mean_average_precision(0.0): 0.6164\n", + "[Iter-1020 Loss-0.092]:\n", + " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5799 - normalized_discounted_cumulative_gain@5(0.0): 0.6441 - mean_average_precision(0.0): 0.5929\n", "\n", - "Cost time: 1233.8079171180725s\n" + "Cost time: 133.27396488189697s\n" ] } ], diff --git a/tutorials/ranking/match_pyramid.ipynb b/tutorials/ranking/match_pyramid.ipynb index 8cb3aa7..c6c81af 100644 --- a/tutorials/ranking/match_pyramid.ipynb +++ b/tutorials/ranking/match_pyramid.ipynb @@ -9,35 +9,67 @@ "name": "stdout", "output_type": "stream", "text": [ - "matchzoo version 1.0\n", - "`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]\n", - "data loading ...\n", - "data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`\n" + "matchzoo version 1.1.1\n" ] } ], "source": [ - "%run init.ipynb" + "import torch\n", + "import numpy as np\n", + "import pandas as pd\n", + "import matchzoo as mz\n", + "print('matchzoo version', mz.__version__)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]\n" + ] + } + ], "source": [ "ranking_task = mz.tasks.Ranking(losses=mz.losses.RankCrossEntropyLoss(num_neg=1))\n", "ranking_task.metrics = [\n", " 
mz.metrics.NormalizedDiscountedCumulativeGain(k=3),\n", " mz.metrics.NormalizedDiscountedCumulativeGain(k=5),\n", " mz.metrics.MeanAveragePrecision()\n", - "]" + "]\n", + "print(\"`ranking_task` initialized with metrics\", ranking_task.metrics)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "data loading ...\n", + "data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`\n" + ] + } + ], + "source": [ + "print('data loading ...')\n", + "train_pack_raw = mz.datasets.wiki_qa.load_data('train', task=ranking_task)\n", + "dev_pack_raw = mz.datasets.wiki_qa.load_data('dev', task=ranking_task, filtered=True)\n", + "test_pack_raw = mz.datasets.wiki_qa.load_data('test', task=ranking_task, filtered=True)\n", + "print('data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`')" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, "outputs": [], "source": [ "preprocessor = mz.models.MatchPyramid.get_default_preprocessor()" @@ -45,48 +77,42 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 8090.96it/s]\n", - "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:03<00:00, 4945.43it/s]\n", - "Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 504384.09it/s]\n", - "Building FrequencyFilter from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 130361.07it/s]\n", - "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 95535.75it/s]\n", - "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 685510.91it/s]\n", - "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 740786.50it/s]\n", - "Building Vocabulary from a datapack.: 100%|██████████| 418401/418401 [00:00<00:00, 2945067.59it/s]\n", - "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 7370.35it/s]\n", - "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:04<00:00, 4462.96it/s]\n", - "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 128347.99it/s]\n", - "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 197609.52it/s]\n", - "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 67602.09it/s]\n", - "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 653537.55it/s]\n", - "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 702793.23it/s]\n", - "Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 697130.65it/s]\n", - "Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 923867.82it/s]\n", - "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 122/122 [00:00<00:00, 9681.12it/s]\n", - "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 1115/1115 [00:00<00:00, 5133.06it/s]\n", - "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 135257.08it/s]\n", - "Processing text_left with 
-      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 126419.84it/s]\n",
-      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 196281.20it/s]\n",
-      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 584288.98it/s]\n",
-      "Processing length_left with len: 100%|██████████| 122/122 [00:00<00:00, 221134.44it/s]\n",
-      "Processing length_right with len: 100%|██████████| 1115/1115 [00:00<00:00, 681032.32it/s]\n",
-      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 237/237 [00:00<00:00, 9885.73it/s]\n",
-      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2300/2300 [00:00<00:00, 5171.45it/s]\n",
-      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 115119.50it/s]\n",
-      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 154499.54it/s]\n",
-      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 110933.63it/s]\n",
-      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 297263.77it/s]\n",
-      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 520328.98it/s]\n",
-      "Processing length_left with len: 100%|██████████| 237/237 [00:00<00:00, 364655.19it/s]\n",
-      "Processing length_right with len: 100%|██████████| 2300/2300 [00:00<00:00, 752429.55it/s]\n"
+      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 8411.71it/s]\n",
+      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:03<00:00, 5277.38it/s]\n",
+      "Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 967884.69it/s]\n",
+      "Building FrequencyFilter from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 165206.51it/s]\n",
+      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 98960.59it/s] \n",
+      "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 566949.77it/s]\n",
+      "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 725631.35it/s]\n",
+      "Building Vocabulary from a datapack.: 100%|██████████| 418401/418401 [00:00<00:00, 3323443.22it/s]\n",
+      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 10832.68it/s]\n",
+      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:03<00:00, 5802.73it/s]\n",
+      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 144186.52it/s]\n",
+      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 214371.04it/s]\n",
+      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 130193.98it/s]\n",
+      "Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 538787.96it/s]\n",
+      "Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 871585.14it/s]\n",
+      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 122/122 [00:00<00:00, 9960.20it/s]\n",
+      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 1115/1115 [00:00<00:00, 5433.38it/s]\n",
+      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 138731.80it/s]\n",
+      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 109291.99it/s]\n",
+      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 118971.46it/s]\n",
+      "Processing length_left with len: 100%|██████████| 122/122 [00:00<00:00, 164218.58it/s]\n",
+      "Processing length_right with len: 100%|██████████| 1115/1115 [00:00<00:00, 664768.86it/s]\n",
+      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 237/237 [00:00<00:00, 10497.94it/s]\n",
+      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2300/2300 [00:00<00:00, 5785.64it/s]\n",
+      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 121333.96it/s]\n",
+      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 148698.59it/s]\n",
+      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 130630.06it/s]\n",
+      "Processing length_left with len: 100%|██████████| 237/237 [00:00<00:00, 257592.65it/s]\n",
+      "Processing length_right with len: 100%|██████████| 2300/2300 [00:00<00:00, 755079.77it/s]\n"
      ]
     }
    ],
@@ -98,19 +124,19 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "{'filter_unit': <matchzoo.preprocessors.units.frequency_filter.FrequencyFilter object at 0x...>,\n",
-       " 'vocab_unit': <matchzoo.preprocessors.units.vocabulary.Vocabulary object at 0x...>,\n",
+       "{'filter_unit': <matchzoo.preprocessors.units.frequency_filter.FrequencyFilter object at 0x...>,\n",
+       " 'vocab_unit': <matchzoo.preprocessors.units.vocabulary.Vocabulary object at 0x...>,\n",
        " 'vocab_size': 30058,\n",
        " 'embedding_input_dim': 30058}"
       ]
      },
-     "execution_count": 5,
+     "execution_count": 6,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
@@ -121,7 +147,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -134,7 +160,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -142,16 +168,23 @@
     "    data_pack=train_pack_processed,\n",
     "    mode='pair',\n",
     "    num_dup=2,\n",
-    "    num_neg=1\n",
+    "    num_neg=1,\n",
+    "    batch_size=20,\n",
+    "    resample=True,\n",
+    "    sort=False,\n",
+    "    shuffle=True\n",
     ")\n",
     "testset = mz.dataloader.Dataset(\n",
-    "    data_pack=test_pack_processed\n",
+    "    data_pack=test_pack_processed,\n",
+    "    batch_size=20,\n",
+    "    sort=False,\n",
+    "    shuffle=False\n",
     ")"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -159,26 +192,19 @@
     "\n",
     "trainloader = mz.dataloader.DataLoader(\n",
     "    dataset=trainset,\n",
-    "    batch_size=20,\n",
     "    stage='train',\n",
-    "    resample=True,\n",
-    "    sort=False,\n",
-    "    shuffle=True,\n",
     "    callback=padding_callback\n",
     ")\n",
     "testloader = mz.dataloader.DataLoader(\n",
     "    dataset=testset,\n",
-    "    batch_size=20,\n",
     "    stage='dev',\n",
-    "    sort=False,\n",
-    "    shuffle=False,\n",
     "    callback=padding_callback\n",
     ")"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 10,
    "metadata": {},
    "outputs": [
     {
@@ -186,7 +212,7 @@
      "output_type": "stream",
      "text": [
       "MatchPyramid(\n",
-      "  (embedding): Embedding(30058, 300)\n",
+      "  (embedding): Embedding(30058, 300, padding_idx=0)\n",
       "  (matching): Matching()\n",
       "  (conv2d): Sequential(\n",
       "    (0): Sequential(\n",
@@ -203,7 +229,8 @@
       "  (dpool_layer): AdaptiveAvgPool2d(output_size=[3, 10])\n",
       "  (dropout): Dropout(p=0.1, inplace=False)\n",
       "  (out): Linear(in_features=960, out_features=1, bias=True)\n",
-      ") 9023161\n"
+      ")\n",
+      "Trainable params: 9023161\n"
      ]
     }
    ],
    "source": [
     "model = mz.models.MatchPyramid()\n",
     "\n",
"model.params['task'] = ranking_task\n", "model.params['embedding'] = embedding_matrix\n", - "\n", "model.params['kernel_count'] = [16, 32]\n", "model.params['kernel_size'] = [[3, 3], [3, 3]]\n", "model.params['dpool_size'] = [3, 10]\n", @@ -220,12 +246,13 @@ "\n", "model.build()\n", "\n", - "print(model, sum(p.numel() for p in model.parameters() if p.requires_grad))" + "print(model)\n", + "print('Trainable params: ', sum(p.numel() for p in model.parameters() if p.requires_grad))" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -243,13 +270,13 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "de0b155e4ef84f9b876313ab74808cdc", + "model_id": "f87305c8fda64edbadd14b23ff5b795b", "version_major": 2, "version_minor": 0 }, @@ -264,15 +291,15 @@ "name": "stdout", "output_type": "stream", "text": [ - "[Iter-102 Loss-0.632]:\n", - " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5599 - normalized_discounted_cumulative_gain@5(0.0): 0.6125 - mean_average_precision(0.0): 0.5815\n", + "[Iter-102 Loss-0.642]:\n", + " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4809 - normalized_discounted_cumulative_gain@5(0.0): 0.546 - mean_average_precision(0.0): 0.5082\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "589e42d46c2c48d5bbefc9c2c2fd7d0a", + "model_id": "4f9bb1db1100445196a89c3c16dee0ce", "version_major": 2, "version_minor": 0 }, @@ -287,15 +314,15 @@ "name": "stdout", "output_type": "stream", "text": [ - "[Iter-204 Loss-0.243]:\n", - " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6082 - normalized_discounted_cumulative_gain@5(0.0): 0.6572 - mean_average_precision(0.0): 0.6115\n", + "[Iter-204 Loss-0.314]:\n", + " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5282 - normalized_discounted_cumulative_gain@5(0.0): 0.5779 - mean_average_precision(0.0): 0.5469\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "c3a439b98e234d22b07bc5c07d169861", + "model_id": "2b4e627c28934d5db12c3bbf7fac5d9e", "version_major": 2, "version_minor": 0 }, @@ -310,15 +337,15 @@ "name": "stdout", "output_type": "stream", "text": [ - "[Iter-306 Loss-0.032]:\n", - " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5497 - normalized_discounted_cumulative_gain@5(0.0): 0.6012 - mean_average_precision(0.0): 0.5497\n", + "[Iter-306 Loss-0.051]:\n", + " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5248 - normalized_discounted_cumulative_gain@5(0.0): 0.583 - mean_average_precision(0.0): 0.5439\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "2775b8d6232443878e1262e7cdb8b210", + "model_id": "5b052335238e440c8ae4f9d2969b31ee", "version_major": 2, "version_minor": 0 }, @@ -333,15 +360,15 @@ "name": "stdout", "output_type": "stream", "text": [ - "[Iter-408 Loss-0.011]:\n", - " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5517 - normalized_discounted_cumulative_gain@5(0.0): 0.6107 - mean_average_precision(0.0): 0.5628\n", + "[Iter-408 Loss-0.016]:\n", + " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5113 - normalized_discounted_cumulative_gain@5(0.0): 0.5663 - mean_average_precision(0.0): 0.5263\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "ac12f63cbb1748e7b0dc5df070c54b64", + 
"model_id": "37b43a7f767245ee941a9a794c40dcfd", "version_major": 2, "version_minor": 0 }, @@ -356,10 +383,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "[Iter-510 Loss-0.005]:\n", - " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5472 - normalized_discounted_cumulative_gain@5(0.0): 0.6119 - mean_average_precision(0.0): 0.5603\n", + "[Iter-510 Loss-0.019]:\n", + " Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4994 - normalized_discounted_cumulative_gain@5(0.0): 0.5559 - mean_average_precision(0.0): 0.5164\n", "\n", - "Cost time: 510.47750544548035s\n" + "Cost time: 111.33789110183716s\n" ] } ], @@ -370,9 +397,9 @@ ], "metadata": { "kernelspec": { - "display_name": "match-zoo", + "display_name": "Python 3", "language": "python", - "name": "match-zoo" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -384,7 +411,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.4" + "version": "3.6.8" } }, "nbformat": 4,