{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# PerceiverIO classification\n",
    "\n",
    "An example of a **classification model** pretrained on images from ImageNet.\n",
    "\n",
    "```\n",
    "@article{Jaegle2021PerceiverIA,\n",
    "  title={Perceiver IO: A General Architecture for Structured Inputs \\& Outputs},\n",
    "  author={Andrew Jaegle and Sebastian Borgeaud and Jean-Baptiste Alayrac and Carl Doersch and Catalin Ionescu and David Ding and Skanda Koppula and Andrew Brock and Evan Shelhamer and Olivier J. H{\\'e}naff and Matthew M. Botvinick and Andrew Zisserman and Oriol Vinyals and Jo{\\~a}o Carreira},\n",
    "  journal={ArXiv},\n",
    "  year={2021},\n",
    "  volume={abs/2107.14795}\n",
    "}\n",
    "@article{Deng2009ImageNetAL,\n",
    "  title={ImageNet: A large-scale hierarchical image database},\n",
    "  author={Jia Deng and Wei Dong and Richard Socher and Li-Jia Li and K. Li and Li Fei-Fei},\n",
    "  journal={2009 IEEE Conference on Computer Vision and Pattern Recognition},\n",
    "  year={2009},\n",
    "  pages={248-255}\n",
    "}\n",
    "```"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Install dependencies\n",
    "\n",
    "`imageio` and `opencv-python` are used for data preprocessing; `dm-haiku` is used to convert the weights of the original model. The `%pip` magic is used so that the packages are installed into the environment of the running kernel."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Data preprocessing.\n",
    "%pip install -U imageio opencv-python\n",
    "\n",
    "# Haiku is used to convert weights of the original model.\n",
    "%pip install -U dm-haiku"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Imports\n",
    "\n",
    "`jax`, `flax`, `flax_extra` and the `util` helpers are not installed by the cell above and must already be importable in the kernel's environment."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "source": [
    "from functools import partial\n",
    "import jax\n",
    "from jax import numpy as jnp\n",
    "from flax import linen as nn\n",
    "from flax_extra import random\n",
    "from flax_extra.operator import ReshapeBatch\n",
    "from flax_extra.layer import (\n",
    "    KVQAttention,\n",
    "    Encoding,\n",
    "    Decoding,\n",
    "    FourierPositionEncoding,\n",
    "    TrainablePositionalEncoding,\n",
    ")\n",
    "from flax_extra.layer.io import (\n",
    "    input_encoding,\n",
    "    query_encoding,\n",
    "    output_decoding,\n",
    ")\n",
    "from flax_extra.model import PerceiverIO\n",
    "from util.data import (\n",
    "    load_image,\n",
    "    normalize,\n",
    "    resize_and_center_crop,\n",
    "    LABELS,\n",
    ")\n",
    "from util.original_model import variables"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Model\n",
    "\n",
    "The model encodes its inputs with a Fourier position encoding (`seqshape=(224, 224)`, 64 bands), uses a trainable encoder query of length 512 and a single trainable decoder query (both of dimension 1024), and decodes the output with a dense layer of 1000 features."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "source": [
    "model = PerceiverIO(\n",
    "    input_encoding=input_encoding(\n",
    "        Encoding,\n",
    "        positional_encoding=partial(\n",
    "            FourierPositionEncoding,\n",
    "            seqshape=(224, 224),\n",
    "            n_bands=64,\n",
    "        ),\n",
    "    ),\n",
    "    encoder_query_encoding=query_encoding(\n",
    "        TrainablePositionalEncoding,\n",
    "        seqlen=512,\n",
    "        dimension=1024,\n",
    "    ),\n",
    "    # A single decoder query: the model produces one output element per input.\n",
    "    decoder_query_encoding=query_encoding(\n",
    "        TrainablePositionalEncoding,\n",
    "        seqlen=1,\n",
    "        dimension=1024,\n",
    "    ),\n",
    "    output_decoding=output_decoding(\n",
    "        Decoding,\n",
    "        # NOTE(review): 1000 features presumably correspond to the entries of LABELS -- confirm.\n",
    "        embedding_decoding=partial(\n",
    "            nn.Dense,\n",
    "            features=1000,\n",
    "        ),\n",
    "        postprocessing=partial(\n",
    "            ReshapeBatch,\n",
    "            shape=(-1,)\n",
    "        ),\n",
    "    ),\n",
    "    encoder_attention=partial(KVQAttention, n_heads=1),\n",
    "    decoder_attention=partial(KVQAttention, n_heads=1),\n",
    "    use_decoder_q_residual=True,\n",
    ")\n",
    "model_init = model.init\n",
    "model_apply = model.apply"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Download pretrained weights and an example image\n",
    "\n",
    "Both files are written to `/tmp`. The `-c` flag lets `wget` resume a partial download instead of starting over."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "%%bash\n",
    "\n",
    "wget -cO \"/tmp/perceiver_classification_fourier_position_encoding.pystate\" \"https://storage.googleapis.com/perceiver_io/imagenet_fourier_position_encoding.pystate\"\n",
    "wget -cO \"/tmp/perceiver_classification_image_example.jpg\" \"https://storage.googleapis.com/perceiver_io/dalmation.jpg\""
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Classify the example image\n",
    "\n",
    "The image is resized, center-cropped and normalized, the model is applied with the variables loaded from the downloaded state file, and the five highest-scoring labels are printed together with their softmax probabilities."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "source": [
    "rng = random.sequence(seed=0)\n",
    "collections = [\"params\"]\n",
    "\n",
    "initial_variables = variables(\"/tmp/perceiver_classification_fourier_position_encoding.pystate\")\n",
    "\n",
    "image = load_image(\"/tmp/perceiver_classification_image_example.jpg\")\n",
    "centered_image = resize_and_center_crop(image)\n",
    "# [None] adds a leading axis; presumably the batch dimension expected by the model.\n",
    "inputs = normalize(centered_image)[None]\n",
    "\n",
    "outputs = model_apply(\n",
    "    initial_variables,\n",
    "    inputs=inputs,\n",
    "    rngs=random.into_collection(key=next(rng), labels=collections),\n",
    ")\n",
    "\n",
    "# Only the indices of the five largest scores are needed; the probabilities\n",
    "# are taken from the softmax over all scores of the first (only) input.\n",
    "_, indices = jax.lax.top_k(outputs[0], 5)\n",
    "probs = jax.nn.softmax(outputs[0])\n",
    "print('Top 5 labels:')\n",
    "for i in list(indices):\n",
    "    print(f'{LABELS[i]}: {probs[i]}')"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Top 5 labels:\n",
      "dalmatian, coach dog, carriage dog: 0.8736159801483154\n",
      "Great Dane: 0.01089583057910204\n",
      "English setter: 0.002538368571549654\n",
      "muzzle: 0.0010286346077919006\n",
      "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier: 0.0007839840836822987\n"
     ]
    }
   ],
   "metadata": {}
  }
 ],
 "metadata": {
  "orig_nbformat": 4,
  "language_info": {
   "name": "python",
   "version": "3.9.7",
   "mimetype": "text/x-python",
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "pygments_lexer": "ipython3",
   "nbconvert_exporter": "python",
   "file_extension": ".py"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.9.7 64-bit ('env': venv)"
  },
  "interpreter": {
   "hash": "8b02d22ec999df61aa0506c585da9f18dfcc9eaff18f3e42f4c27a57de3e3420"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}