{
"cells": [
{
"cell_type": "markdown",
"source": [
"# PerceiverIO multimodal autoencoding\n",
"\n",
"An example of an **autoencoder model** pretrained on multimodal input (audio, video, and label) of the Kinetics-700-2020 dataset.\n",
"\n",
"```\n",
"@article{Jaegle2021PerceiverIA,\n",
"  title={Perceiver IO: A General Architecture for Structured Inputs \\& Outputs},\n",
"  author={Andrew Jaegle and Sebastian Borgeaud and Jean-Baptiste Alayrac and Carl Doersch and Catalin Ionescu and David Ding and Skanda Koppula and Andrew Brock and Evan Shelhamer and Olivier J. H{\\'e}naff and Matthew M. Botvinick and Andrew Zisserman and Oriol Vinyals and Jo{\\~a}o Carreira},\n",
"  journal={ArXiv},\n",
"  year={2021},\n",
"  volume={abs/2107.14795}\n",
"}\n",
"@article{Smaira2020ASN,\n",
"  title={A Short Note on the Kinetics-700-2020 Human Action Dataset},\n",
"  author={Lucas Smaira and Jo{\\~a}o Carreira and Eric Noland and Ellen Clancy and Amy Wu and Andrew Zisserman},\n",
"  journal={ArXiv},\n",
"  year={2020},\n",
"  volume={abs/2010.10864}\n",
"}\n",
"```"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# `%pip` installs into the environment of the running kernel,\n",
"# unlike a bare `pip` which uses whatever interpreter is first on PATH.\n",
"\n",
"# Data preprocessing.\n",
"%pip install -U imageio opencv-python scipy\n",
"\n",
"# Haiku is used to convert weights of the original model.\n",
"%pip install -U dm-haiku"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 2,
"source": [
"from functools import partial\n",
"import jax\n",
"from jax import numpy as jnp\n",
"from flax import linen as nn\n",
"from flax_extra import random\n",
"from flax_extra import operator as op\n",
"from flax_extra.layer import (\n",
"    KVQAttention,\n",
"    Encoding,\n",
"    Decoding,\n",
"    TrainablePositionalEncoding,\n",
"    FourierPositionEncoding,\n",
")\n",
"from flax_extra.layer.io import (\n",
"    input_encoding,\n",
"    query_encoding,\n",
"    output_decoding,\n",
")\n",
"from flax_extra.model import PerceiverIO\n",
"from util.data import (\n",
"    load_audio,\n",
"    load_video,\n",
"    output_positions,\n",
"    LABELS,\n",
")\n",
"from util.original_model import variables\n",
"\n",
"Array = jnp.ndarray"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 3,
"source": [
"AUDIO_SAMPLE_RATE = 48000  # Hz; matches the `-ar 48000` used when extracting the audio track\n",
"VIDEO_FPS = 25  # assumed video frame rate: AUDIO_SAMPLE_RATE // VIDEO_FPS audio samples per video frame\n",
"N_VIDEO_FRAMES = 16  # per example\n",
"N_AUDIO_FRAMES = AUDIO_SAMPLE_RATE // VIDEO_FPS * N_VIDEO_FRAMES  # per example\n",
"N_PACKED_AUDIO_FRAMES = 16  # audio samples packed into one sequence position\n",
"VIDEO_SPATIAL_PATCH = 4  # each frame is split into 4x4 patches, one per sequence position\n",
"VIDEO_PATCH_SIZE = 56  # number of patches along each spatial axis (height and width)\n",
"N_CHUNKS = 128  # outputs are decoded in this many chunks (see `autoencode`)\n",
"N_CLASSES = 700\n",
"\n",
"## Audio.\n",
"AudioPositionalEncoding = partial(\n",
"    FourierPositionEncoding,\n",
"    seqshape=(N_AUDIO_FRAMES // N_PACKED_AUDIO_FRAMES,),\n",
"    n_bands=192,\n",
")\n",
"AudioEncoding = partial(\n",
"    Encoding,\n",
"    preprocessing=partial(\n",
"        op.Rearrange,\n",
"        pattern=\"b (t dt) dc -> b t (dt dc)\",\n",
"        bindings=dict(dt=N_PACKED_AUDIO_FRAMES),\n",
"    ),\n",
"    positional_encoding=AudioPositionalEncoding,\n",
")\n",
"AudioDecoding = partial(\n",
"    Decoding,\n",
"    embedding_decoding=partial(\n",
"        nn.Dense,\n",
"        features=N_PACKED_AUDIO_FRAMES,\n",
"    ),\n",
"    postprocessing=partial(\n",
"        op.ReshapeBatch,\n",
"        shape=(-1,)\n",
"    ),\n",
")\n",
"\n",
"## Video.\n",
"VideoPositionalEncoding = partial(\n",
"    FourierPositionEncoding,\n",
"    seqshape=(N_VIDEO_FRAMES, VIDEO_PATCH_SIZE, VIDEO_PATCH_SIZE),\n",
"    n_bands=32,\n",
")\n",
"VideoEncoding = partial(\n",
"    Encoding,\n",
"    preprocessing=partial(\n",
"        op.Rearrange,\n",
"        pattern=\"b (t dt) (h dh) (w dw) dc -> b t h w (dt dh dw dc)\",\n",
"        bindings=dict(dt=1, dh=VIDEO_SPATIAL_PATCH, dw=VIDEO_SPATIAL_PATCH),\n",
"    ),\n",
"    positional_encoding=VideoPositionalEncoding,\n",
")\n",
"VideoDecoding = partial(\n",
"    Decoding,\n",
"    embedding_decoding=partial(\n",
"        nn.Dense,\n",
"        features=3,\n",
"    ),\n",
")\n",
"\n",
"## Labels.\n",
"LabelPositionalEncoding = partial(\n",
"    TrainablePositionalEncoding,\n",
"    seqlen=1,\n",
"    dimension=1024,\n",
")\n",
"LabelEncoding = partial(\n",
"    Encoding,\n",
"    preprocessing=partial(\n",
"        op.Rearrange,\n",
"        pattern=\"b dc -> b 1 dc\",\n",
"        bindings=dict(),\n",
"    ),\n",
")\n",
"LabelDecoding = partial(\n",
"    Decoding,\n",
"    embedding_decoding=partial(\n",
"        nn.Dense,\n",
"        features=N_CLASSES,\n",
"    ),\n",
"    postprocessing=partial(\n",
"        op.ReshapeBatch,\n",
"        shape=(-1,)\n",
"    ),\n",
")\n",
"\n",
"model = PerceiverIO(\n",
"    input_encoding=input_encoding(\n",
"        AudioEncoding,\n",
"        VideoEncoding,\n",
"        LabelEncoding,\n",
"        mask_rates=[0., 0., 1.],\n",
"        d_reserved=4,\n",
"    ),\n",
"    encoder_query_encoding=query_encoding(\n",
"        TrainablePositionalEncoding,\n",
"        seqlen=784,\n",
"        dimension=512,\n",
"    ),\n",
"    decoder_query_encoding=query_encoding(\n",
"        AudioPositionalEncoding,\n",
"        VideoPositionalEncoding,\n",
"        LabelPositionalEncoding,\n",
"        d_reserved=2,\n",
"    ),\n",
"    output_decoding=output_decoding(\n",
"        AudioDecoding,\n",
"        VideoDecoding,\n",
"        LabelDecoding,\n",
"        multimodal_embedding_decoding=partial(\n",
"            nn.Dense,\n",
"            features=512,\n",
"        ),\n",
"    ),\n",
"    n_processor_shards=1,\n",
"    n_processor_blocks=8,\n",
"    encoder_attention=partial(KVQAttention, n_heads=1),\n",
"    decoder_attention=partial(KVQAttention, n_heads=1),\n",
")\n",
"model_init = model.init\n",
"model_apply = model.apply"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"%%bash\n",
"\n",
"wget -cO \"/tmp/perceiver_autoencoding.pickle\" \"https://storage.googleapis.com/perceiver_io/video_autoencoding_checkpoint.pystate\"\n",
"wget --check-certificate=quiet -cO \"/tmp/perceiver_autoencoding_video_example.avi\" \"https://www.crcv.ucf.edu/THUMOS14/UCF101/UCF101/v_ApplyEyeMakeup_g01_c01.avi\"\n",
"\n",
"# Extract the audio track as a 48 kHz, 32-bit float PCM WAV file.\n",
"# - the codec must be passed via -c:a; a bare `pcm_f32le` is parsed by ffmpeg as an extra output file;\n",
"# - no `-c copy`: resampling with -ar requires re-encoding;\n",
"# - -y overwrites a previous extraction instead of prompting.\n",
"# NOTE(review): confirm util.data.load_audio accepts float32 WAV samples.\n",
"ffmpeg -y -i \"/tmp/perceiver_autoencoding_video_example.avi\" -map 0:a -c:a pcm_f32le -ar 48000 \"/tmp/perceiver_autoencoding_audio_example.wav\"\n"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 7,
"source": [
"def autoencode(variables, inputs, rng):\n",
"    \"\"\"Autoencode `inputs` chunk by chunk and reassemble the reconstructions.\n",
"\n",
"    Each of the N_CHUNKS calls to `model_apply` decodes one chunk of the audio\n",
"    and video output positions. The chunks are concatenated along axis 1 and\n",
"    reshaped back to the shapes of the corresponding inputs.\n",
"\n",
"    Args:\n",
"        variables: model variables (e.g. from `util.original_model.variables`).\n",
"        inputs: `[audio, video, label]` input arrays.\n",
"        rng: iterator of PRNG keys; one key is consumed per chunk.\n",
"\n",
"    Returns:\n",
"        `[audio_outputs, video_outputs, label_outputs]`, where the label\n",
"        outputs are the ones returned for the last chunk.\n",
"    \"\"\"\n",
"    audio_inputs, video_inputs, _ = inputs\n",
"    audio_chunks = []\n",
"    video_chunks = []\n",
"    label_outputs = None\n",
"    for chunk_index in range(N_CHUNKS):\n",
"        # RNG_COLLECTIONS is a module-level constant defined after this function.\n",
"        audio_chunk, video_chunk, label_outputs = model_apply(\n",
"            variables,\n",
"            inputs=inputs,\n",
"            output_positions=[\n",
"                output_positions(\n",
"                    chunk_shape=audio_inputs.shape,\n",
"                    chunk_index=chunk_index,\n",
"                    n_chunks=N_CHUNKS,\n",
"                    n_frames=N_PACKED_AUDIO_FRAMES,\n",
"                ),\n",
"                output_positions(\n",
"                    chunk_shape=video_inputs.shape,\n",
"                    chunk_index=chunk_index,\n",
"                    n_chunks=N_CHUNKS,\n",
"                ),\n",
"                None,\n",
"            ],\n",
"            rngs=random.into_collection(key=next(rng), labels=RNG_COLLECTIONS),\n",
"        )\n",
"        audio_chunks.append(audio_chunk)\n",
"        video_chunks.append(video_chunk)\n",
"\n",
"    # Concatenate once: growing the arrays inside the loop copies them on every chunk.\n",
"    audio_outputs = jnp.reshape(jnp.concatenate(audio_chunks, axis=1), audio_inputs.shape)\n",
"    video_outputs = jnp.reshape(jnp.concatenate(video_chunks, axis=1), video_inputs.shape)\n",
"    return [audio_outputs, video_outputs, label_outputs]\n",
"\n",
"rng = random.sequence(seed=0)\n",
"RNG_COLLECTIONS = [\"params\"]  # RNG collection names passed to `random.into_collection`\n",
"initial_variables = variables(\"/tmp/perceiver_autoencoding.pickle\")\n",
"\n",
"audio = load_audio(\"/tmp/perceiver_autoencoding_audio_example.wav\")\n",
"video = load_video(\"/tmp/perceiver_autoencoding_video_example.avi\")\n",
"inputs = [\n",
"    audio[None, :N_AUDIO_FRAMES, :1],\n",
"    video[None, :N_VIDEO_FRAMES],\n",
"    jnp.zeros((1, N_CLASSES)),\n",
"]\n",
"\n",
"audio_outputs, video_outputs, label_outputs = autoencode(initial_variables, inputs, rng)\n",
"\n",
"# Kinetics 700 Labels.\n",
"scores, indices = jax.lax.top_k(jax.nn.softmax(label_outputs), 5)\n",
"for score, index in zip(scores[0], indices[0]):\n",
"    print(\"%s: %s\" % (LABELS[index], score))"
],
"outputs": [
{
"output_type": "stream",
"name": "stderr", "text": [ "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "trimming or shaving beard: 0.21497257\n", "dyeing hair: 0.19800161\n", "raising eyebrows: 0.09644758\n", "winking: 0.09643903\n", "playing harmonica: 0.083919466\n" ] } ], "metadata": {} } ], "metadata": { "orig_nbformat": 4, "language_info": { "name": "python", "version": "3.9.7", "mimetype": "text/x-python", "codemirror_mode": { "name": "ipython", "version": 3 }, "pygments_lexer": "ipython3", "nbconvert_exporter": "python", "file_extension": ".py" }, "kernelspec": { "name": "python3", "display_name": "Python 3.9.7 64-bit ('env': venv)" }, "interpreter": { "hash": "8b02d22ec999df61aa0506c585da9f18dfcc9eaff18f3e42f4c27a57de3e3420" } }, "nbformat": 4, "nbformat_minor": 2 }