PerceiverIO multimodal autoencoding
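An example of a Perceiver IO autoencoder pretrained on multimodal inputs (audio, video, and class label) from the Kinetics-700-2020 dataset. The label input is masked, so the model reconstructs the audio and video and predicts the Kinetics class.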
@article{Jaegle2021PerceiverIA,
title={Perceiver IO: A General Architecture for Structured Inputs \& Outputs},
author={Andrew Jaegle and Sebastian Borgeaud and Jean-Baptiste Alayrac and Carl Doersch and Catalin Ionescu and David Ding and Skanda Koppula and Andrew Brock and Evan Shelhamer and Olivier J. H{\'e}naff and Matthew M. Botvinick and Andrew Zisserman and Oriol Vinyals and Jo{\~a}o Carreira},
journal={ArXiv},
year={2021},
volume={abs/2107.14795}
}
@article{Smaira2020ASN,
title={A Short Note on the Kinetics-700-2020 Human Action Dataset},
author={Lucas Smaira and Jo{\~a}o Carreira and Eric Noland and Ellen Clancy and Amy Wu and Andrew Zisserman},
journal={ArXiv},
year={2020},
volume={abs/2010.10864}
}
[ ]:
%%bash
## Data preprocessing.
pip install -U imageio opencv-python scipy
## Haiku is used to convert weights of the original model.
pip install -U dm-haiku
[2]:
from functools import partial
import jax
from jax import numpy as jnp
from flax import linen as nn
from flax_extra import random
from flax_extra import operator as op
from flax_extra.layer import (
KVQAttention,
Encoding,
Decoding,
TrainablePositionalEncoding,
FourierPositionEncoding,
)
from flax_extra.layer.io import (
input_encoding,
query_encoding,
output_decoding,
)
from flax_extra.model import PerceiverIO
from util.data import (
load_audio,
load_video,
output_positions,
LABELS,
)
from util.original_model import variables
# Shorthand alias for the JAX array type, for use in type annotations.
Array = jax.numpy.ndarray
[3]:
# Sampling rates the example clip is prepared at. The audio rate matches the
# `-ar 48000` used when extracting the audio track; 25 video frames per second
# is assumed (TODO confirm against the source clip).
AUDIO_SAMPLE_RATE = 48000  # Hz
VIDEO_FPS = 25  # frames per second

N_VIDEO_FRAMES = 16  # video frames per example
# Audio samples spanning the same duration as N_VIDEO_FRAMES video frames
# (AUDIO_SAMPLE_RATE // VIDEO_FPS samples per video frame).
N_AUDIO_FRAMES = AUDIO_SAMPLE_RATE // VIDEO_FPS * N_VIDEO_FRAMES  # per example
# Audio is grouped into packets: this many consecutive samples form one
# sequence position.
N_PACKED_AUDIO_FRAMES = 16
# Number of video patches along each spatial axis (height and width), not the
# size of a patch. Frames are split into 4x4-pixel patches (see the
# VideoEncoding bindings), so 56 patches per axis means 224x224-pixel frames.
VIDEO_PATCH_SIZE = 56
# Number of chunks the decoder output positions are split into at inference.
N_CHUNKS = 128
# Number of Kinetics-700 action classes.
N_CLASSES = 700
## Audio.
# Fourier position features over the packed audio sequence: one position per
# packet of N_PACKED_AUDIO_FRAMES samples.
AudioPositionalEncoding = partial(
    FourierPositionEncoding,
    n_bands=192,
    seqshape=(N_AUDIO_FRAMES // N_PACKED_AUDIO_FRAMES,),
)
# Encoder side: fold every N_PACKED_AUDIO_FRAMES consecutive samples (together
# with their channels) into one feature vector per sequence position.
AudioEncoding = partial(
    Encoding,
    positional_encoding=AudioPositionalEncoding,
    preprocessing=partial(
        op.Rearrange,
        bindings={"dt": N_PACKED_AUDIO_FRAMES},
        pattern="b (t dt) dc -> b t (dt dc)",
    ),
)
# Decoder side: project each output position back to N_PACKED_AUDIO_FRAMES
# values; ReshapeBatch with shape (-1,) presumably flattens each batch element.
AudioDecoding = partial(
    Decoding,
    postprocessing=partial(op.ReshapeBatch, shape=(-1,)),
    embedding_decoding=partial(nn.Dense, features=N_PACKED_AUDIO_FRAMES),
)
## Video.
# Fourier position features over the (frame, patch-row, patch-column) grid.
VideoPositionalEncoding = partial(
    FourierPositionEncoding,
    n_bands=32,
    seqshape=(N_VIDEO_FRAMES, VIDEO_PATCH_SIZE, VIDEO_PATCH_SIZE),
)
# Encoder side: cut every frame into 4x4-pixel patches and stack each patch's
# pixels and channels into one feature vector (dt=1: no temporal grouping).
VideoEncoding = partial(
    Encoding,
    positional_encoding=VideoPositionalEncoding,
    preprocessing=partial(
        op.Rearrange,
        bindings={"dt": 1, "dh": 4, "dw": 4},
        pattern="b (t dt) (h dh) (w dw) dc -> b t h w (dt dh dw dc)",
    ),
)
# Decoder side: project each output position to 3 values (presumably the RGB
# channels of one pixel).
VideoDecoding = partial(
    Decoding,
    embedding_decoding=partial(nn.Dense, features=3),
)
## Labels.
# A single learned 1024-dim position, used as the decoder query for the label.
LabelPositionalEncoding = partial(
    TrainablePositionalEncoding,
    dimension=1024,
    seqlen=1,
)
# Encoder side: turn the (batch, classes) label vector into a length-1 sequence.
LabelEncoding = partial(
    Encoding,
    preprocessing=partial(
        op.Rearrange,
        bindings={},
        pattern="b dc -> b 1 dc",
    ),
)
# Decoder side: project to N_CLASSES logits; ReshapeBatch with shape (-1,)
# presumably flattens each batch element.
LabelDecoding = partial(
    Decoding,
    postprocessing=partial(op.ReshapeBatch, shape=(-1,)),
    embedding_decoding=partial(nn.Dense, features=N_CLASSES),
)
# Assemble the multimodal Perceiver IO autoencoder. Modalities are listed in
# the same order everywhere: audio, video, label.
model = PerceiverIO(
    # Inputs: audio and video use mask rate 0.0; the label uses mask rate 1.0
    # (presumably fully masked, so the model has to predict it).
    input_encoding=input_encoding(
        AudioEncoding,
        VideoEncoding,
        LabelEncoding,
        mask_rates=[0.0, 0.0, 1.0],
        d_reserved=4,
    ),
    # Encoder queries: 784 learned positions of width 512.
    encoder_query_encoding=query_encoding(
        TrainablePositionalEncoding,
        seqlen=784,
        dimension=512,
    ),
    # Decoder queries reuse the per-modality positional encodings.
    decoder_query_encoding=query_encoding(
        AudioPositionalEncoding,
        VideoPositionalEncoding,
        LabelPositionalEncoding,
        d_reserved=2,
    ),
    output_decoding=output_decoding(
        AudioDecoding,
        VideoDecoding,
        LabelDecoding,
        multimodal_embedding_decoding=partial(nn.Dense, features=512),
    ),
    n_processor_shards=1,
    n_processor_blocks=8,
    encoder_attention=partial(KVQAttention, n_heads=1),
    decoder_attention=partial(KVQAttention, n_heads=1),
)
# Bind the module's init/apply methods once for use in later cells.
model_init, model_apply = model.init, model.apply
[ ]:
%%bash
set -e  # stop at the first failed download/conversion instead of continuing
# Pretrained checkpoint. NOTE(review): it is loaded from a .pickle path later;
# unpickling can execute arbitrary code - only use checkpoints you trust.
wget -cO "/tmp/perceiver_autoencoding.pickle" "https://storage.googleapis.com/perceiver_io/video_autoencoding_checkpoint.pystate"
# Example UCF101 clip. NOTE: TLS certificate verification is disabled for this host.
wget --check-certificate=quiet -cO "/tmp/perceiver_autoencoding_video_example.avi" "https://www.crcv.ucf.edu/THUMOS14/UCF101/UCF101/v_ApplyEyeMakeup_g01_c01.avi"
# Extract the audio track as 32-bit float PCM resampled to 48 kHz (-y overwrites
# an existing output). The previous command passed `pcm_f32le` as a bare
# argument, which ffmpeg treats as an extra output file, and combined
# `-c copy` with resampling. Assumes load_audio accepts float32 WAV - TODO confirm.
ffmpeg -y -i "/tmp/perceiver_autoencoding_video_example.avi" -map 0:a -c:a pcm_f32le -ar 48000 -f wav "/tmp/perceiver_autoencoding_audio_example.wav"
[7]:
def autoencode(variables, inputs, rng, n_chunks=N_CHUNKS, rng_labels=None):
    """Reconstruct audio and video and predict the label, chunk by chunk.

    The model is applied ``n_chunks`` times. Each call requests one chunk of
    the audio and video output positions (as produced by ``output_positions``),
    so no single call has to decode every output position. The decoded chunks
    are joined along axis 1 and reshaped to the shapes of the inputs.

    Args:
        variables: Model variables, e.g. the converted pretrained weights.
        inputs: ``[audio, video, label]`` arrays as passed to the model.
        rng: Iterator of PRNG keys; one key is consumed per chunk.
        n_chunks: Number of chunks to split the output positions into.
        rng_labels: RNG collection labels passed to ``random.into_collection``.
            Defaults to ``["params"]``.

    Returns:
        ``[audio_outputs, video_outputs, label_outputs]``. Audio and video
        outputs have the same shapes as the corresponding inputs. Label outputs
        are taken from the last call (presumably identical across calls, since
        label positions are not chunked).

    Raises:
        ValueError: If ``n_chunks`` is smaller than 1.
    """
    if n_chunks < 1:
        raise ValueError(f"n_chunks must be >= 1, got {n_chunks}")
    if rng_labels is None:
        rng_labels = ["params"]
    audio_inputs, video_inputs, _ = inputs
    audio_chunks = []
    video_chunks = []
    label_outputs = None
    for chunk_index in range(n_chunks):
        audio_chunk, video_chunk, label_outputs = model_apply(
            variables,
            inputs=inputs,
            output_positions=[
                output_positions(
                    chunk_shape=audio_inputs.shape,
                    chunk_index=chunk_index,
                    n_chunks=n_chunks,
                    n_frames=N_PACKED_AUDIO_FRAMES,
                ),
                output_positions(
                    chunk_shape=video_inputs.shape,
                    chunk_index=chunk_index,
                    n_chunks=n_chunks,
                ),
                None,  # label positions are not chunked
            ],
            rngs=random.into_collection(key=next(rng), labels=rng_labels),
        )
        audio_chunks.append(audio_chunk)
        video_chunks.append(video_chunk)
    # Concatenate once at the end; growing the arrays inside the loop copies
    # all previous chunks on every iteration (quadratic in n_chunks).
    audio_outputs = jnp.reshape(jnp.concatenate(audio_chunks, axis=1), audio_inputs.shape)
    video_outputs = jnp.reshape(jnp.concatenate(video_chunks, axis=1), video_inputs.shape)
    return [audio_outputs, video_outputs, label_outputs]
# Fixed seed so the run is reproducible; `rng` is consumed with next().
rng = random.sequence(seed=0)
# RNG collection labels used when applying the model. It lives at module level
# because `autoencode` may look it up as a global - check before renaming.
# Note that the name shadows the stdlib `collections` module.
collections = ["params"]
# NOTE(review): the checkpoint appears to be a pickle downloaded from the
# network; only load checkpoints from trusted sources.
initial_variables = variables("/tmp/perceiver_autoencoding.pickle")
# Assumes load_audio returns (samples, channels) and load_video returns
# (frames, height, width, channels) - TODO confirm in util.data.
audio = load_audio("/tmp/perceiver_autoencoding_audio_example.wav")
video = load_video("/tmp/perceiver_autoencoding_video_example.avi")
# Add a batch axis and crop each modality to the clip length the model was
# built for (audio: first channel only). The label input is all zeros; the
# model uses mask rate 1.0 for it and predicts the class.
inputs = [
    audio[None, :N_AUDIO_FRAMES, :1],
    video[None, :N_VIDEO_FRAMES],
    jnp.zeros((1, N_CLASSES)),
]
# Slicing silently truncates clips that are too short; fail early with a clear
# message instead of a shape error deep inside the model.
assert inputs[0].shape[1] == N_AUDIO_FRAMES, (
    f"expected {N_AUDIO_FRAMES} audio samples, got input shape {inputs[0].shape}"
)
assert inputs[1].shape[1] == N_VIDEO_FRAMES, (
    f"expected {N_VIDEO_FRAMES} video frames, got input shape {inputs[1].shape}"
)
# The video encoder cuts frames into 4x4-pixel patches on a
# VIDEO_PATCH_SIZE x VIDEO_PATCH_SIZE grid.
assert inputs[1].shape[2:4] == (4 * VIDEO_PATCH_SIZE, 4 * VIDEO_PATCH_SIZE), (
    f"unexpected frame size, got input shape {inputs[1].shape}"
)
audio_outputs, video_outputs, label_outputs = autoencode(initial_variables, inputs, rng)
# Kinetics 700 Labels: the 5 most probable classes.
scores, indices = jax.lax.top_k(jax.nn.softmax(label_outputs), 5)
for score, index in zip(scores[0], indices[0]):
    print("%s: %s" % (LABELS[index], score))
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
trimming or shaving beard: 0.21497257
dyeing hair: 0.19800161
raising eyebrows: 0.09644758
winking: 0.09643903
playing harmonica: 0.083919466