# TFSimilarity.samplers.SingleShotMemorySampler

Base object for fitting to a sequence of data, such as a dataset.

```python
TFSimilarity.samplers.SingleShotMemorySampler(
    x,
    examples_per_batch: int,
    num_augmentations_per_example: int = 2,
    steps_per_epoch: int = 1000,
    warmup: int = -1
) -> None
```

Every `Sequence` must implement the `__getitem__` and `__len__` methods. If you want to modify your dataset between epochs, you may implement `on_epoch_end`. The `__getitem__` method should return a complete batch.

Notes:

Sequences are a safer way to do multiprocessing. This structure guarantees that the network will train only once on each sample per epoch, which is not the case with generators.

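Examples:

```python
import math

import numpy as np
from skimage.io import imread
from skimage.transform import resize
from tensorflow.keras.utils import Sequence

# Here, `x_set` is a list of paths to the images
# and `y_set` are the associated classes.


class CIFAR10Sequence(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch; the last batch may be smaller.
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        # Slice out the file names and labels belonging to batch `idx`.
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

        # Load each image from disk and resize it to 200x200.
        return np.array([
            resize(imread(file_name), (200, 200))
            for file_name in batch_x]), np.array(batch_y)
```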

<!-- Tabular view -->
 <table class="responsive fixed orange">
<colgroup><col width="214px"><col></colgroup>
<tr><th colspan="2"><h2 class="add-link">Args</h2></th></tr>

<tr>
<td>
<b>x</b>
</td>
<td>
Input data. The sampler assumes that each element of <b>x</b> is from a
distinct class.
</td>
</tr><tr>
<td>
<b>augmenter</b>
</td>
<td>
A function that takes a batch of single examples and
returns a batch with additional examples per class.
</td>
</tr><tr>
<td>
<b>steps_per_epoch</b>
</td>
<td>
Number of steps (batches) per epoch. Defaults to 1000.
</td>
</tr><tr>
<td>
<b>examples_per_batch</b>
</td>
<td>
Effectively, the number of elements passed to
the augmenter for each batch request in the single-shot setting.
</td>
</tr><tr>
<td>
<b>num_augmentations_per_example</b>
</td>
<td>
How many augmented examples the augmenter must
return for each example. The augmenter is responsible for
deciding whether one of them is the original.
</td>
</tr><tr>
<td>
<b>warmup</b>
</td>
<td>
Keeps track of warmup epochs and lets the augmenter know
when the warmup is over by passing a boolean <b>is_warmup</b>
along with each batch of data. See <b>self._get_examples()</b>. Defaults to -1.
</td>
</tr>
</table>
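To make the interaction between `examples_per_batch`, `num_augmentations_per_example` and the augmenter concrete, here is a minimal pure-Python sketch of how one single-shot batch could be assembled. It is not TF Similarity's implementation: `single_shot_batch`, `toy_augmenter` and the `(examples, n)` augmenter signature are hypothetical. Following the `x` description above, it assumes each element of `x` is its own class, so an element's index serves as its label, and the resulting batch holds `examples_per_batch * num_augmentations_per_example` items.

```python
import random
from typing import Callable, List, Tuple

# Hypothetical augmenter type: takes a list of examples and the number of
# augmented views to produce per example, returns the views grouped by example.
Augmenter = Callable[[List[str], int], List[str]]


def single_shot_batch(
    x: List[str],
    examples_per_batch: int,
    num_augmentations_per_example: int,
    augmenter: Augmenter,
    rng: random.Random,
) -> Tuple[List[str], List[int]]:
    """Builds one batch in the single-shot setting (illustrative sketch)."""
    # Draw distinct examples; each one is its own class, so its index is its label.
    idxs = rng.sample(range(len(x)), examples_per_batch)
    examples = [x[i] for i in idxs]

    # The augmenter returns num_augmentations_per_example views per example.
    batch_x = augmenter(examples, num_augmentations_per_example)
    # Repeat each label once per view so labels line up with batch_x.
    batch_y = [i for i in idxs for _ in range(num_augmentations_per_example)]
    return batch_x, batch_y


def toy_augmenter(examples: List[str], n: int) -> List[str]:
    # Hypothetical stand-in for an image augmenter: tag each copy with its view number.
    return [f"{e}#aug{k}" for e in examples for k in range(n)]


rng = random.Random(0)
x = ["img_a", "img_b", "img_c", "img_d", "img_e"]
bx, by = single_shot_batch(
    x,
    examples_per_batch=3,
    num_augmentations_per_example=2,
    augmenter=toy_augmenter,
    rng=rng,
)
```

Because every label appears exactly `num_augmentations_per_example` times, each batch contains positive pairs even though the dataset has only one example per class.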





<!-- Tabular view -->
 <table class="responsive fixed orange">
<colgroup><col width="214px"><col></colgroup>
<tr><th colspan="2"><h2 class="add-link">Attributes</h2></th></tr>

<tr>
<td>
<b>example_shape</b>
</td>
<td>
Shape of a single example in <b>x</b>.
</td>
</tr><tr>
<td>
<b>num_examples</b>
</td>
<td>
Number of examples in <b>x</b>.
</td>
</tr>
</table>



## Methods

<h3 id="generate_batch">generate_batch</h3>

<a target="_blank" class="external" href="https://github.com/tensorflow/similarity/blob/main/tensorflow_similarity/samplers/samplers.py#L137-L154">View source</a>

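```python
generate_batch(
    batch_id: int
) -> Tuple[Batch, Batch]
```

Generate a batch of data.

<!-- Tabular view -->
 <table class="responsive fixed orange">
<colgroup><col width="214px"><col></colgroup>
<tr><th colspan="2">Args</th></tr>

<tr>
<td>
<b>batch_id</b>
</td>
<td>
Index of the batch to generate.
</td>
</tr>
</table>

<!-- Tabular view -->
 <table class="responsive fixed orange">
<colgroup><col width="214px"><col></colgroup>
<tr><th colspan="2">Returns</th></tr>
<tr class="alt">
<td colspan="2">
A tuple <b>(x, y)</b> of <b>Batch</b> objects.
</td>
</tr>
</table>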

<h3 id="get_slice">get_slice</h3>


```python
get_slice(
    begin: int = 0, size: int = -1
) -> Tuple[TFSimilarity.callbacks.FloatTensor, TFSimilarity.callbacks.IntTensor]
```

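Extracts an augmented slice over both the `x` and `y` tensors.

This method extracts a slice of size `size` over the first dimension of both the `x` and `y` tensors, starting at the index specified by `begin`.

The value of `begin + size` must be less than `self.num_examples`.

<!-- Tabular view -->
 <table class="responsive fixed orange">
<colgroup><col width="214px"><col></colgroup>
<tr><th colspan="2">Args</th></tr>

<tr>
<td>
<b>begin</b>
</td>
<td>
The starting index.
</td>
</tr><tr>
<td>
<b>size</b>
</td>
<td>
The size of the slice.
</td>
</tr>
</table>

<!-- Tabular view -->
 <table class="responsive fixed orange">
<colgroup><col width="214px"><col></colgroup>
<tr><th colspan="2">Returns</th></tr>
<tr class="alt">
<td colspan="2">
A tuple of <a href="../../TFSimilarity/callbacks/FloatTensor.md"><b>FloatTensor</b></a> and <a href="../../TFSimilarity/callbacks/IntTensor.md"><b>IntTensor</b></a>.
</td>
</tr>
</table>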
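The index arithmetic behind a slice over `x` and `y` can be sketched as below. `get_slice_sketch` is a hypothetical helper, not the library method: it leaves out augmentation entirely, assumes `size=-1` means "through the last example", and treats `begin + size` as an exclusive end index in the usual Python slicing convention.

```python
from typing import List, Sequence, Tuple


def get_slice_sketch(
    x: Sequence[float], y: Sequence[int], begin: int = 0, size: int = -1
) -> Tuple[List[float], List[int]]:
    """Slices x and y along their first dimension (illustrative sketch)."""
    num_examples = len(x)
    # Assumption: size == -1 selects everything from `begin` to the end.
    if size == -1:
        size = num_examples - begin
    # Reject slices that start before 0 or run past the last example.
    if begin < 0 or size < 0 or begin + size > num_examples:
        raise ValueError(
            f"slice [{begin}, {begin + size}) is out of range for {num_examples} examples"
        )
    end = begin + size
    # The same slice is applied to x and y so features and labels stay aligned.
    return list(x[begin:end]), list(y[begin:end])


x = [0.1, 0.2, 0.3, 0.4]
y = [0, 1, 2, 3]
xs, ys = get_slice_sketch(x, y, begin=1, size=2)
```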

<h3 id="on_epoch_end">on_epoch_end</h3>


```python
on_epoch_end() -> None
```

Keeps track of warmup epochs.
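A minimal sketch of this warmup bookkeeping follows. `WarmupTracker` is a hypothetical helper that only mirrors what `on_epoch_end` is documented to do; it assumes warmup lasts for the first `warmup` epochs and that a non-positive value, such as the signature's default of `-1`, disables warmup.

```python
class WarmupTracker:
    """Hypothetical helper mirroring on_epoch_end's warmup tracking."""

    def __init__(self, warmup: int = -1) -> None:
        self.warmup = warmup
        self.epochs_completed = 0

    @property
    def is_warmup(self) -> bool:
        # Assumption: warmup covers the first `warmup` epochs;
        # with warmup <= 0 this is always False.
        return self.epochs_completed < self.warmup

    def on_epoch_end(self) -> None:
        # Called once per epoch, like Sequence.on_epoch_end.
        self.epochs_completed += 1


tracker = WarmupTracker(warmup=2)
flags = []
for epoch in range(4):
    # This is the flag that would accompany each batch as `is_warmup`.
    flags.append(tracker.is_warmup)
    tracker.on_epoch_end()
```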

<h3 id="__getitem__">__getitem__</h3>


```python
__getitem__(
    batch_id: int
) -> Tuple[Batch, Batch]
```

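Gets the batch at position `batch_id`.

<!-- Tabular view -->
 <table class="responsive fixed orange">
<colgroup><col width="214px"><col></colgroup>
<tr><th colspan="2">Args</th></tr>

<tr>
<td>
<b>batch_id</b>
</td>
<td>
Position of the batch in the Sequence.
</td>
</tr>
</table>

<!-- Tabular view -->
 <table class="responsive fixed orange">
<colgroup><col width="214px"><col></colgroup>
<tr><th colspan="2">Returns</th></tr>
<tr class="alt">
<td colspan="2">
A batch.
</td>
</tr>
</table>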

<h3 id="__iter__">__iter__</h3>

```python
__iter__()
```

Creates a generator that iterates over the Sequence.

<h3 id="__len__">__len__</h3>


```python
__len__() -> int
```

Returns the number of batches per epoch.
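`__len__`, `__getitem__` and `__iter__` fit together as in any Keras `Sequence`. The toy class below is a hypothetical stand-in, not the real sampler: it assumes `__len__` returns `steps_per_epoch` and that `__getitem__` delegates to `generate_batch`, while its `__iter__` uses the same index-based loop as `tf.keras.utils.Sequence.__iter__`.

```python
from typing import Iterator, List, Tuple

BatchPair = Tuple[List[int], List[int]]


class ToySampler:
    """Hypothetical stand-in showing how __len__, __getitem__ and __iter__ relate."""

    def __init__(self, steps_per_epoch: int = 3) -> None:
        self.steps_per_epoch = steps_per_epoch

    def generate_batch(self, batch_id: int) -> BatchPair:
        # Placeholder batch; a real sampler would draw and augment examples here.
        return [batch_id, batch_id], [0, 0]

    def __len__(self) -> int:
        # Assumption: one epoch is exactly steps_per_epoch batches.
        return self.steps_per_epoch

    def __getitem__(self, batch_id: int) -> BatchPair:
        # Assumption: __getitem__ simply delegates to generate_batch.
        return self.generate_batch(batch_id)

    def __iter__(self) -> Iterator[BatchPair]:
        # Visit batch indices 0 .. len(self) - 1 in order.
        for i in range(len(self)):
            yield self[i]


sampler = ToySampler(steps_per_epoch=3)
batches = list(sampler)
```

Iterating over the sampler therefore yields exactly one epoch's worth of batches, which is what a Keras training loop expects from a `Sequence`.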