PerceiverIO classification
An example of classifying an image with a Perceiver IO model pretrained on ImageNet.
@article{Jaegle2021PerceiverIA,
title={Perceiver IO: A General Architecture for Structured Inputs \& Outputs},
author={Andrew Jaegle and Sebastian Borgeaud and Jean-Baptiste Alayrac and Carl Doersch and Catalin Ionescu and David Ding and Skanda Koppula and Andrew Brock and Evan Shelhamer and Olivier J. H{\'e}naff and Matthew M. Botvinick and Andrew Zisserman and Oriol Vinyals and Jo{\~a}o Carreira},
journal={ArXiv},
year={2021},
volume={abs/2107.14795}
}
@article{Deng2009ImageNetAL,
title={ImageNet: A large-scale hierarchical image database},
author={Jia Deng and Wei Dong and Richard Socher and Li-Jia Li and K. Li and Li Fei-Fei},
journal={2009 IEEE Conference on Computer Vision and Pattern Recognition},
year={2009},
pages={248-255}
}
[ ]:
# Install the extra dependencies into the environment of the running kernel.
# `%pip` targets this kernel's interpreter; a bare `pip` inside a `%%bash`
# cell uses whatever `pip` is first on PATH, which may be a different Python.
# Data preprocessing.
%pip install -U imageio opencv-python
# Haiku is used to convert weights of the original model.
%pip install -U dm-haiku
[2]:
from functools import partial
import jax
from jax import numpy as jnp
from flax import linen as nn
from flax_extra import random
from flax_extra.operator import ReshapeBatch
from flax_extra.layer import (
KVQAttention,
Encoding,
Decoding,
FourierPositionEncoding,
TrainablePositionalEncoding,
)
from flax_extra.layer.io import (
input_encoding,
query_encoding,
output_decoding,
)
from flax_extra.model import PerceiverIO
from util.data import (
load_image,
normalize,
resize_and_center_crop,
LABELS,
)
from util.original_model import variables
[3]:
# Architecture hyperparameters of the pretrained ImageNet checkpoint. The
# converted weights loaded further down are passed to `model_apply` as-is, so
# every value here must stay in sync with that checkpoint.
IMAGE_SHAPE = (224, 224)  # grid covered by the Fourier position encoding
N_FOURIER_BANDS = 64  # frequency bands of the Fourier position encoding
N_LATENTS = 512  # number of encoder query (latent) positions
LATENT_DIM = 1024  # width of each encoder query (latent) position
N_OUTPUT_QUERIES = 1  # a single decoder query for the whole image
OUTPUT_QUERY_DIM = 1024  # width of the decoder query
N_CLASSES = 1000  # ImageNet classes; output size of the final Dense layer
N_ATTENTION_HEADS = 1  # heads used by both the encoder and decoder attention

model = PerceiverIO(
    # Input encoding paired with a Fourier position encoding over the image grid.
    input_encoding=input_encoding(
        Encoding,
        positional_encoding=partial(
            FourierPositionEncoding,
            seqshape=IMAGE_SHAPE,
            n_bands=N_FOURIER_BANDS,
        ),
    ),
    # Learned queries for the encoder (the latent array).
    encoder_query_encoding=query_encoding(
        TrainablePositionalEncoding,
        seqlen=N_LATENTS,
        dimension=LATENT_DIM,
    ),
    # Learned query (or queries) for the decoder.
    decoder_query_encoding=query_encoding(
        TrainablePositionalEncoding,
        seqlen=N_OUTPUT_QUERIES,
        dimension=OUTPUT_QUERY_DIM,
    ),
    # Project decoder outputs to class logits, then reshape per batch element.
    # NOTE(review): ReshapeBatch with shape=(-1,) presumably flattens to
    # (batch, N_CLASSES); the top-k cell indexes the result that way. Confirm.
    output_decoding=output_decoding(
        Decoding,
        embedding_decoding=partial(
            nn.Dense,
            features=N_CLASSES,
        ),
        postprocessing=partial(
            ReshapeBatch,
            shape=(-1,),
        ),
    ),
    encoder_attention=partial(KVQAttention, n_heads=N_ATTENTION_HEADS),
    decoder_attention=partial(KVQAttention, n_heads=N_ATTENTION_HEADS),
    # NOTE(review): by its name this adds the decoder query back as a residual
    # around the decoder attention; confirm in flax_extra's PerceiverIO.
    use_decoder_q_residual=True,
)
model_init = model.init
model_apply = model.apply
[ ]:
%%bash
# Fail the cell on the first error. Otherwise only the last command's status
# counts, and a failed checkpoint download would surface later as a
# confusing load error.
set -euo pipefail
BASE_URL="https://storage.googleapis.com/perceiver_io"
# Pretrained ImageNet classifier weights (Fourier position encoding variant).
# `-c` resumes a partially downloaded file instead of starting over.
wget -cO "/tmp/perceiver_classification_fourier_position_encoding.pystate" "${BASE_URL}/imagenet_fourier_position_encoding.pystate"
# Example image to classify.
wget -cO "/tmp/perceiver_classification_image_example.jpg" "${BASE_URL}/dalmation.jpg"
[5]:
# Files written by the download cell above.
CHECKPOINT_PATH = "/tmp/perceiver_classification_fourier_position_encoding.pystate"
IMAGE_PATH = "/tmp/perceiver_classification_image_example.jpg"

# Converted pretrained weights; passed to the model as its variables.
initial_variables = variables(CHECKPOINT_PATH)

# Preprocess the example image, then prepend a leading axis of size 1
# so the model receives a batch containing a single image.
image = load_image(IMAGE_PATH)
centered_image = resize_and_center_crop(image)
inputs = normalize(centered_image)[None]

# Forward pass. PRNG keys for the labels in `collections` are built with
# flax_extra's `random.into_collection` from a seeded key sequence.
rng = random.sequence(seed=0)
collections = ["params"]
rngs = random.into_collection(key=next(rng), labels=collections)
outputs = model_apply(initial_variables, inputs=inputs, rngs=rngs)
# Report the TOP_K most probable labels for the first (here: only) image in the batch.
TOP_K = 5
probs = jax.nn.softmax(outputs[0])
# Softmax is monotonic, so taking top-k on the probabilities selects the same
# classes as top-k on the logits, and returns their probabilities directly.
top_probs, indices = jax.lax.top_k(probs, TOP_K)
print(f'Top {TOP_K} labels:')
# `.tolist()` moves the k results to host once, instead of indexing device
# arrays element by element inside the loop.
for label_index, prob in zip(indices.tolist(), top_probs.tolist()):
    print(f'{LABELS[label_index]}: {prob}')
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Top 5 labels:
dalmatian, coach dog, carriage dog: 0.8736159801483154
Great Dane: 0.01089583057910204
English setter: 0.002538368571549654
muzzle: 0.0010286346077919006
American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier: 0.0007839840836822987