forked from google-deepmind/deepmind-research
-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval_ucf101.py
465 lines (388 loc) · 17.5 KB
/
eval_ucf101.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""UCF101 linear evaluation."""
import functools
from typing import Any, Dict, Optional
from absl import app
from absl import flags
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import sklearn
from sklearn import preprocessing
import sklearn.linear_model
import sklearn.svm
import tensorflow as tf
import tensorflow_datasets as tfds
from mmv import config
from mmv.models import mm_embeddings
from mmv.utils import checkpoint
from mmv.utils import ucf101_dataset
# Command-line flags for the UCF101 linear evaluation script.

# Location of the pre-trained weights and of the dataset on disk.
# NOTE(review): the help text says "directory" but the default value looks
# like a path to a single .pkl file -- confirm against
# checkpoint.load_checkpoint.
flags.DEFINE_string('checkpoint_path', '~/tmp/mmv_s3d.pkl',
'The directory to load pre-trained weights from.')
flags.DEFINE_string('dataset_folder', '/tmp/ucf101',
'The directory with the ucf101 dataset.')
# Batch sizes used when extracting features from the test / train splits.
flags.DEFINE_integer('eval_batch_size', 1,
'The batch size for evaluation.')
flags.DEFINE_integer('train_batch_size', 16,
'The batch size for training.')
# The train dataset is repeated this many times while collecting features.
flags.DEFINE_integer('num_train_epochs', 10,
'How many epochs to collect features during training.')
# Number of linearly spaced windows extracted from each test video.
flags.DEFINE_integer('num_test_windows', 10,
'How many windows to average on during test.')
# Spatial preprocessing: resize the smallest side, then crop.
flags.DEFINE_integer('min_resize', 224,
'Min value to resize images to during preprocessing.')
flags.DEFINE_integer('crop_size', 224,
'Value to resize images to during preprocessing.')
# Temporal sampling: number of frames per clip and gap between sampled frames.
flags.DEFINE_integer('num_frames', 32,
'Number of video frames.')
flags.DEFINE_integer('stride', 2,
'Stride for video frames.')
# Used in main() as an index into the tfds Ucf101 BUILDER_CONFIGS list.
flags.DEFINE_integer('ucf101_split', 1,
'Which split of ucf101 to use.')
FLAGS = flags.FLAGS
def get_sampling_offset(sequence: tf.Tensor,
                        num_steps: Optional[int],
                        is_training: bool,
                        stride: int = 1,
                        seed: Optional[int] = None) -> tf.Tensor:
  """Picks the index of the first timestep to sample from `sequence`.

  Args:
    sequence: any tensor whose first dimension is timesteps.
    num_steps: number of timesteps that will be sampled. If None, the offset
      is deterministically the first frame.
    is_training: whether the graph is built for training. If False, the
      starting frame is always the first frame.
    stride: distance between two consecutive sampled timesteps.
    seed: optional op-level seed forwarded to the random sampler.

  Returns:
    A scalar int32 tensor. When the sequence is long enough, the offset is
    drawn uniformly so that every requested step fits, i.e.
    offset + (num_steps - 1) * stride < len(sequence). Otherwise it is drawn
    uniformly from the whole sequence.
  """
  deterministic_start = num_steps is None or not is_training
  if deterministic_start:
    return tf.constant(0)

  length = tf.shape(sequence)[0]
  # Number of timesteps spanned between the first and the last sampled step.
  span = (num_steps - 1) * stride
  upper_bound = tf.cond(
      tf.greater(length, span),
      lambda: length - span,
      lambda: length)
  # `maxval` is exclusive for integer dtypes.
  return tf.random.uniform(
      (),
      maxval=tf.cast(upper_bound, tf.int32),
      dtype=tf.int32,
      seed=seed)
def sample_or_pad_sequence_indices(sequence: tf.Tensor,
                                   num_steps: Optional[int],
                                   is_training: bool,
                                   repeat_sequence: bool = True,
                                   stride: int = 1,
                                   offset: Optional[int] = None) -> tf.Tensor:
  """Builds the indices used to sample or pad a sequence to a fixed size.

  Takes `num_steps` indices spaced by `stride`, starting at `offset`. When no
  offset is given, one is obtained from `get_sampling_offset` (random when
  `is_training` is True, zero otherwise). With `repeat_sequence`, the index
  range is tiled so that sequences that are too short loop around.

  Args:
    sequence: any tensor whose first dimension is timesteps.
    num_steps: how many steps (e.g. frames) to take. If None (or 0), every
      `stride`-th step of the full sequence is returned and `is_training`,
      `offset` and `repeat_sequence` have no effect.
    is_training: whether the graph is built for training; only used to pick
      the offset when `offset` is None.
    repeat_sequence: whether to tile the sequence indices until enough steps
      are available. If False, no tiling is done, so the requested indices may
      fall outside the sequence.
    stride: distance between two consecutive sampled timesteps.
    offset: a deterministic offset to use regardless of `is_training`.

  Returns:
    Indices to gather from `sequence` to obtain the sampled steps.
  """
  length = tf.shape(sequence)[0]
  frame_ids = tf.range(length)

  if not num_steps:
    # No fixed size requested: stride over the whole sequence.
    return tf.gather(frame_ids, tf.range(0, length, stride))

  if offset is None:
    offset = get_sampling_offset(sequence, num_steps, is_training, stride)

  if repeat_sequence:
    # Tile the index range until `offset + num_steps * stride` entries exist.
    needed = tf.cast(num_steps * stride + offset, tf.float32)
    available = tf.cast(length, tf.float32)
    num_repeats = tf.cast(
        tf.math.ceil(tf.math.divide(needed, available)), tf.int32)
    frame_ids = tf.tile(frame_ids, [num_repeats])

  picked = tf.range(offset, offset + num_steps * stride, stride)
  return tf.gather(frame_ids, picked)
def random_sample_sequence(sequence: tf.Tensor,
                           num_steps: int,
                           stride: int = 1) -> tf.Tensor:
  """Gathers `num_steps` strided steps starting at a random offset.

  Args:
    sequence: any tensor whose first dimension is timesteps.
    num_steps: number of steps to sample.
    stride: distance between two consecutive sampled timesteps.

  Returns:
    A tensor with first dimension `num_steps` gathered from `sequence`. The
    sequence is looped if it is too short.
  """
  picked = sample_or_pad_sequence_indices(
      sequence,
      num_steps,
      is_training=True,      # Random starting offset.
      repeat_sequence=True,  # Loop the sequence when more steps are needed.
      stride=stride,
      offset=None)
  # The number of indices is known statically; make the shape explicit.
  picked.set_shape((num_steps,))
  return tf.gather(sequence, picked)
def sample_linspace_sequence(sequence: tf.Tensor,
                             num_windows: int,
                             num_steps: int,
                             stride: int = 1) -> tf.Tensor:
  """Extracts `num_windows` clips whose start offsets are linearly spaced.

  The clips are concatenated along the first dimension so that the output
  keeps one entry per timestep (e.g. one frame). When a window does not fit,
  the sequence is repeated. Intended for evaluation, where the windows should
  span the whole sequence.

  Args:
    sequence: any tensor whose first dimension is timesteps.
    num_windows: number of windows taken from the sequence.
    num_steps: number of steps (e.g. frames) per window.
    stride: distance between two consecutive sampled timesteps.

  Returns:
    A tensor with first dimension `num_windows * num_steps` holding the
    concatenated windows.
  """
  length = tf.shape(sequence)[0]
  # Latest offset at which a full window still fits (0 for short sequences).
  last_start = tf.maximum(0, length - num_steps * stride)
  starts = tf.cast(
      tf.linspace(0.0, tf.cast(last_start, tf.float32), num_windows),
      tf.int32)

  per_window = [
      sample_or_pad_sequence_indices(
          sequence=sequence,
          num_steps=num_steps,
          is_training=False,
          repeat_sequence=True,  # Loop the sequence when more is requested.
          stride=stride,
          offset=starts[window])
      for window in range(num_windows)
  ]

  flat_indices = tf.concat(per_window, axis=0)
  flat_indices.set_shape((num_windows * num_steps,))
  return tf.gather(sequence, flat_indices)
def resize_smallest(frames: tf.Tensor, min_resize: int) -> tf.Tensor:
  """Rescales frames so that their smallest side equals `min_resize`.

  The resize is skipped when the computed output size matches the input size,
  which saves compute.

  Args:
    frames: a tensor of dimension [timesteps, input_h, input_w, channels].
    min_resize: target size of the smallest spatial dimension.

  Returns:
    A tensor of shape [timesteps, output_h, output_w, channels] with the same
    dtype as `frames`, where min(output_h, output_w) = min_resize.
  """
  dims = tf.shape(frames)
  in_h, in_w = dims[1], dims[2]

  # The smaller side maps to `min_resize`; the other side is scaled by the
  # same ratio (integer division).
  out_h = tf.maximum(min_resize, (in_h * min_resize) // in_w)
  out_w = tf.maximum(min_resize, (in_w * min_resize) // in_h)

  def _resized():
    # tf.image.resize returns floats; cast back to the input dtype.
    return tf.cast(tf.image.resize(frames, (out_h, out_w)), frames.dtype)

  size_changed = tf.math.logical_or(
      tf.not_equal(in_w, out_w), tf.not_equal(in_h, out_h))
  return tf.cond(size_changed, _resized, lambda: frames)
def process_samples(features_dict, num_frames=32, stride=1, is_training=True,
                    min_resize=224, crop_size=224, num_windows=1):
  """Samples, resizes, crops and normalizes the video of one example.

  Args:
    features_dict: dict holding the raw frames under the 'video' key. It is
      updated in place and returned.
    num_frames: number of frames per sampled clip.
    stride: distance between two consecutive sampled frames.
    is_training: if True, use a random temporal offset, a random horizontal
      flip and a random crop; otherwise use linearly spaced windows and a
      central crop.
    min_resize: size of the smallest side after resizing.
    crop_size: height and width of the final crop.
    num_windows: number of test windows; must be 1 when `is_training`.

  Returns:
    `features_dict` whose 'video' entry is a float32 tensor in [0, 1].
  """
  clip = features_dict['video']

  # Temporal sampling (plus a random left/right flip at training time).
  if is_training:
    assert num_windows == 1
    sampled = random_sample_sequence(clip, num_frames, stride)
    coin = tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32)
    clip = tf.cond(tf.equal(coin, 1),
                   true_fn=lambda: tf.image.flip_left_right(sampled),
                   false_fn=lambda: sampled)
  else:
    clip = sample_linspace_sequence(clip, num_windows, num_frames, stride)

  # Resize smallest side.
  clip = resize_smallest(clip, min_resize)

  # Spatial crop: random for training, central otherwise.
  if is_training:
    clip = tf.image.random_crop(clip, [num_frames, crop_size, crop_size, 3])
  else:
    clip = tf.image.resize_with_crop_or_pad(clip, crop_size, crop_size)

  # Set between [0, 1].
  features_dict['video'] = tf.cast(clip, tf.float32) / 255.0
  return features_dict
def space_to_depth_batch(features_dict):
  """Folds 2x2x2 space-time blocks of the batched video into channels.

  Args:
    features_dict: dict whose 'video' entry is a 5-D tensor; its last four
      dimensions are read as (length, height, width, channels). It is updated
      in place and returned.

  Returns:
    `features_dict` whose 'video' entry has length, height and width halved
    and 8 times as many channels.
  """
  video = features_dict['video']
  _, length, height, width, channels = video.shape
  half_l, half_h, half_w = length // 2, height // 2, width // 2

  # Split each of the three axes into (half, 2) ...
  blocks = tf.reshape(
      video, [-1, half_l, 2, half_h, 2, half_w, 2, channels])
  # ... move the three size-2 axes next to the channel axis ...
  blocks = tf.transpose(blocks, [0, 1, 3, 5, 2, 4, 6, 7])
  # ... and merge them into the channels.
  features_dict['video'] = tf.reshape(
      blocks, [-1, half_l, half_h, half_w, 8 * channels])
  return features_dict
def reshape_windows(features_dict, num_frames):
  """Regroups the frames of the batched video into clips of `num_frames`.

  Args:
    features_dict: dict whose 'video' entry is a 5-D tensor; dimensions 2, 3
      and 4 are kept as they are. It is updated in place and returned.
    num_frames: number of frames per clip.

  Returns:
    `features_dict` whose 'video' entry has shape
    [-1, num_frames, dim2, dim3, dim4].
  """
  stacked = features_dict['video']
  trailing = (stacked.shape[2], stacked.shape[3], stacked.shape[4])
  features_dict['video'] = tf.reshape(stacked, (-1, num_frames) + trailing)
  return features_dict
def compute_accuracy_metrics(pred, gt, prefix=''):
  """Computes top-1 and top-5 accuracy from per-class scores.

  Args:
    pred: array of shape [num_examples, num_classes] holding one score per
      class; a higher score means a more likely class.
    gt: ground-truth class indices, either as a column of shape
      [num_examples, 1] or as a flat vector of shape [num_examples].
    prefix: string prepended to the metric names.

  Returns:
    A dict with the keys `prefix + 'top1'` and `prefix + 'top5'`, each mapped
    to the fraction of examples whose label is among the 1 (resp. 5) highest
    scoring classes.

  Raises:
    ValueError: if `pred` is not 2-D, or `gt` is neither 1-D nor 2-D.
  """
  pred = np.asarray(pred)
  gt = np.asarray(gt)
  if gt.ndim == 1:
    # Accept flat label vectors by turning them into a column.
    gt = gt[:, None]
  # Explicit check instead of `assert`, which is stripped under `python -O`.
  if pred.ndim != 2 or gt.ndim != 2:
    raise ValueError(
        'Expected 2-D predictions and 1-D or 2-D labels, got shapes '
        f'{pred.shape} and {gt.shape}.')

  # Class indices sorted by increasing score: the best classes come last.
  ranked = np.argsort(pred, axis=1)
  top1_hits = ranked[:, -1:] == gt
  top5_hits = np.any(ranked[:, -5:] == gt, axis=1)
  return {prefix + 'top1': np.mean(top1_hits),
          prefix + 'top5': np.mean(top5_hits)}
def forward_fn(images: jnp.ndarray,
               audio_spectrogram: jnp.ndarray,
               word_ids: jnp.ndarray,
               is_training: bool,
               model_config: Dict[str, Any]):
  """Runs the multimodal embedding model and returns its 'vid_repr' output.

  Args:
    images: video input forwarded to the model.
    audio_spectrogram: audio input forwarded to the model.
    word_ids: text token ids forwarded to the model.
    is_training: forwarded to the model.
    model_config: keyword arguments used to build
      `mm_embeddings.AudioTextVideoEmbedding`.

  Returns:
    The 'vid_repr' entry of the model output.
  """
  # The word embedding matrix is a zero placeholder here: the pre-trained
  # values are expected to come from the checkpoint.
  vocab_size = 65536
  embedding_dim = 300
  placeholder_embeddings = jnp.zeros(shape=(vocab_size, embedding_dim))

  embedder = mm_embeddings.AudioTextVideoEmbedding(
      **model_config,
      word_embedding_matrix=placeholder_embeddings)
  outputs = embedder(images=images,
                     audio_spectrogram=audio_spectrogram,
                     word_ids=word_ids,
                     is_training=is_training)
  return outputs['vid_repr']
def main(argv):
  """Extracts video features on UCF101 and evaluates a linear SVM on them.

  Steps: build the train/test input pipelines, load the pre-trained weights,
  compute the video representation of every batch, fit a `LinearSVC` on the
  standardized train features, average the test scores over the windows of
  each video, and print train/test top-1 and top-5 accuracies.

  Args:
    argv: command-line arguments left over after flag parsing (unused).
  """
  del argv

  sklearn_reg = 0.001
  model_config = config.get_model_config(FLAGS.checkpoint_path)
  uses_s3d = model_config['visual_backbone'] == 's3d'

  forward = hk.without_apply_rng(hk.transform_with_state(forward_fn))
  forward_apply = jax.jit(functools.partial(forward.apply,
                                            is_training=False,
                                            model_config=model_config))

  # Get the UCF101 config.
  dset_config = tfds.video.ucf101.Ucf101.BUILDER_CONFIGS[FLAGS.ucf101_split]
  builder = ucf101_dataset.ModUcf101(
      data_dir=FLAGS.dataset_folder,
      config=dset_config)

  # Create the tfrecord files (no-op if already exists).
  # NOTE(review): SSL certificate verification is disabled for the download.
  dl_config = tfds.download.DownloadConfig(verify_ssl=False)
  builder.download_and_prepare(download_config=dl_config)

  def _preprocess_train(example):
    return process_samples(
        example, num_frames=FLAGS.num_frames, stride=FLAGS.stride,
        is_training=True, min_resize=FLAGS.min_resize,
        crop_size=FLAGS.crop_size)

  def _preprocess_test(example):
    return process_samples(
        example, num_frames=FLAGS.num_frames, stride=FLAGS.stride,
        is_training=False, min_resize=FLAGS.min_resize,
        crop_size=FLAGS.crop_size, num_windows=FLAGS.num_test_windows)

  def _split_windows(example):
    return reshape_windows(example, num_frames=FLAGS.num_frames)

  # Generate the training dataset.
  train_ds = builder.as_dataset(split='train', shuffle_files=False)
  train_ds = train_ds.map(_preprocess_train)
  train_ds = train_ds.batch(batch_size=FLAGS.train_batch_size)
  if uses_s3d:
    train_ds = train_ds.map(space_to_depth_batch)
  train_ds = train_ds.repeat(FLAGS.num_train_epochs)

  # Generate the test dataset.
  test_ds = builder.as_dataset(split='test', shuffle_files=False)
  test_ds = test_ds.map(_preprocess_test)
  test_ds = test_ds.batch(batch_size=FLAGS.eval_batch_size)
  test_ds = test_ds.map(_split_windows)
  if uses_s3d:
    test_ds = test_ds.map(space_to_depth_batch)
  test_ds = test_ds.repeat(1)

  pretrained_weights = checkpoint.load_checkpoint(FLAGS.checkpoint_path)
  params = pretrained_weights['params']
  state = pretrained_weights['state']

  # Sizes of the dummy audio / text inputs fed next to the video.
  audio_frames = 96
  mel_filters = 40
  num_tokens = 16

  def _extract_features(dataset, batch_size):
    """Returns (features, labels) concatenated over all batches of dataset."""
    silent_audio = jnp.zeros(
        shape=(batch_size, audio_frames, mel_filters, 1))
    empty_word_ids = jnp.zeros(
        shape=(batch_size, num_tokens), dtype=jnp.int32)
    batch_features = []
    batch_labels = []
    for batch in iter(tfds.as_numpy(dataset)):
      vid_representation, _ = forward_apply(params=params,
                                            state=state,
                                            images=batch['video'],
                                            audio_spectrogram=silent_audio,
                                            word_ids=empty_word_ids)
      batch_labels.append(batch['label'])
      batch_features.append(vid_representation)
      # The counter is the number of batches processed so far.
      if len(batch_labels) % 50 == 0:
        print(f'Processed {len(batch_labels)} examples.')
    return (np.concatenate(batch_features, axis=0),
            np.concatenate(batch_labels, axis=0))

  # Collect training samples.
  print('Computing features on train')
  train_features, train_labels = _extract_features(
      train_ds, FLAGS.train_batch_size)
  print(f'Finish collecting train features of shape {train_features.shape}')

  # Collect test samples.
  print('Computing features on test')
  test_features, test_labels = _extract_features(
      test_ds, FLAGS.eval_batch_size)
  print(f'Finish collecting test features of shape {test_features.shape}')

  # Train classifier on standardized features.
  print('Training linear classifier!')
  classifier = sklearn.svm.LinearSVC(C=sklearn_reg)
  scaler = preprocessing.StandardScaler().fit(train_features)
  train_features = scaler.transform(train_features)
  classifier.fit(train_features, train_labels.ravel())
  print('Training done !')

  # Evaluation.
  test_features = scaler.transform(test_features)
  print('Running inference on train')
  pred_train = classifier.decision_function(train_features)
  print('Running inference on test')
  pred_test = classifier.decision_function(test_features)
  if FLAGS.num_test_windows > 1:
    # Average the scores of the windows belonging to the same test video.
    num_classes = pred_test.shape[1]
    pred_test = np.reshape(
        pred_test, (test_labels.shape[0], -1, num_classes)).mean(axis=1)

  # Compute accuracies.
  metrics = compute_accuracy_metrics(pred_train, train_labels[:, None],
                                     prefix='train_')
  metrics.update(
      compute_accuracy_metrics(pred_test, test_labels[:, None],
                               prefix='test_'))
  print(metrics)
# Script entry point: absl parses the flags, then calls main().
if __name__ == '__main__':
  app.run(main)