PerceiverIO language modeling
An example of a masked language model pretrained on a large text corpus that combines English Wikipedia and C4.
@article{Jaegle2021PerceiverIA,
title={{Perceiver IO}: A General Architecture for Structured Inputs \& Outputs},
author={Andrew Jaegle and Sebastian Borgeaud and Jean-Baptiste Alayrac and Carl Doersch and Catalin Ionescu and David Ding and Skanda Koppula and Andrew Brock and Evan Shelhamer and Olivier J. H{\'e}naff and Matthew M. Botvinick and Andrew Zisserman and Oriol Vinyals and Jo{\~a}o Carreira},
journal={ArXiv},
year={2021},
volume={abs/2107.14795}
}
@article{Raffel2020ExploringTL,
title={Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer},
author={Colin Raffel and Noam M. Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu},
journal={ArXiv},
year={2020},
volume={abs/1910.10683}
}
[1]:
%%bash
# Haiku is needed only to convert the weights of the original model.
pip install --upgrade dm-haiku
Requirement already satisfied: dm-haiku in /Users/manifest/development/github/flax-extra/env/lib/python3.9/site-packages (0.0.4)
Requirement already satisfied: absl-py>=0.7.1 in /Users/manifest/development/github/flax-extra/env/lib/python3.9/site-packages (from dm-haiku) (0.14.1)
Requirement already satisfied: tabulate>=0.8.9 in /Users/manifest/development/github/flax-extra/env/lib/python3.9/site-packages (from dm-haiku) (0.8.9)
Requirement already satisfied: numpy>=1.18.0 in /Users/manifest/development/github/flax-extra/env/lib/python3.9/site-packages (from dm-haiku) (1.19.5)
Requirement already satisfied: six in /Users/manifest/development/github/flax-extra/env/lib/python3.9/site-packages (from absl-py>=0.7.1->dm-haiku) (1.16.0)
[2]:
from typing import Any, List, Optional, Union
from functools import partial
import jax.numpy as jnp
from jax import numpy as jnp
from flax import linen as nn
from flax_extra import random
from flax_extra import combinator as cb
from flax_extra.layer import (
FeedForward,
FeedForwardCt,
KVQAttention,
KVQAttentionCt,
SelfAttention,
SelfAttentionCt,
Encoding,
MultimodalEncodingCt,
MultimodalPositionalEncodingCt,
Decoding,
TrainablePositionalEncoding,
EmbedDecoding,
)
from flax_extra import data
from flax_extra.layer.io import (
input_encoding,
target_encoding,
query_encoding,
output_decoding,
)
from flax_extra.model import PerceiverIO
from util.original_model import variables
# Type aliases used in the model signatures below.
Precision = Any  # forwarded to PerceiverIO unchanged; left untyped here
Array = jnp.ndarray
Positions = List[int]  # explicit output positions to decode
MaybePositions = Optional[Positions]  # None means "no explicit positions"
[3]:
class PerceiverMLM(nn.Module):
    """Perceiver IO masked language model with a single, tied token embedding.

    One embedding module embeds the input tokens and, through its
    `.embedding` table, decodes the outputs, so inputs and outputs share a
    vocabulary. All remaining configuration is forwarded to `PerceiverIO`.
    """

    # Factory called with name="EmbeddingEncoder"; the module it builds is
    # shared between input encoding and output decoding.
    input_embedding: MultimodalEncodingCt
    # Positional encoding combined (via cb.add()) with the embedded inputs.
    input_positional_encoding: Union[MultimodalPositionalEncodingCt, MultimodalEncodingCt]
    # Encoding forwarded to PerceiverIO as its encoder query encoding.
    encoder_query_encoding: MultimodalPositionalEncodingCt
    # Used as-is for decoder queries without targets; with targets it becomes
    # the positional part combined with the embedded targets.
    decoder_query_encoding: MultimodalPositionalEncodingCt
    n_processor_shards: int = 8
    n_processor_blocks: int = 6
    processor_attention: SelfAttentionCt = SelfAttention
    processor_feed_forward: FeedForwardCt = FeedForward
    encoder_attention: KVQAttentionCt = KVQAttention
    encoder_feed_forward: FeedForwardCt = FeedForward
    use_encoder_q_residual: bool = True
    decoder_attention: KVQAttentionCt = KVQAttention
    decoder_feed_forward: FeedForwardCt = FeedForward
    use_decoder_q_residual: bool = False
    deterministic: bool = True
    precision: Optional[Precision] = None

    @nn.compact
    def __call__(
        self,
        inputs: Union[Array, List[Array]],
        input_mask: Optional[Array] = None,
        targets: Optional[Union[Array, List[Array]]] = None,
        target_mask: Optional[Array] = None,
        output_positions: MaybePositions = None,
    ) -> Array:
        """Encodes `inputs` and decodes them through the tied token embedding.

        Args:
            inputs: input token ids (or a list of arrays), passed to PerceiverIO.
            input_mask: optional mask for `inputs`.
            targets: optional targets; when given, decoder queries are built from
                the embedded targets (teacher forcing).
            target_mask: optional mask for `targets`.
            output_positions: optional positions to decode.

        Returns:
            The output of the inner `PerceiverIO` module. NOTE(review): the
            notebook indexes it as (batch, position, vocab) logits; confirm
            against PerceiverIO.
        """
        embedder = self.input_embedding(name="EmbeddingEncoder")

        def tied_vocabulary():
            # Hands out the same embedding module for inputs and outputs.
            return embedder

        # Built in the same order as the PerceiverIO keyword arguments so any
        # auto-named helper modules are created in an unchanged order.
        encoder_inputs = input_encoding(
            Encoding,
            preprocessing=tied_vocabulary,
            aggregation=cb.add(),
            positional_encoding=self.input_positional_encoding,
        )
        latent_queries = self.encoder_query_encoding

        if targets is None:
            decoder_queries = self.decoder_query_encoding
        else:
            # Teacher forcing: embed the targets and combine them with the
            # decoder positional encoding.
            decoder_queries = target_encoding(
                Encoding,
                preprocessing=tied_vocabulary,
                aggregation=cb.add(),
                positional_encoding=self.decoder_query_encoding,
            )

        vocabulary_decoding = output_decoding(
            Decoding,
            embedding_decoding=partial(
                EmbedDecoding,
                embedding=tied_vocabulary().embedding,
            ),
        )

        perceiver = PerceiverIO(
            input_encoding=encoder_inputs,
            encoder_query_encoding=latent_queries,
            decoder_query_encoding=decoder_queries,
            output_decoding=vocabulary_decoding,
            n_processor_shards=self.n_processor_shards,
            n_processor_blocks=self.n_processor_blocks,
            processor_attention=self.processor_attention,
            processor_feed_forward=self.processor_feed_forward,
            encoder_attention=self.encoder_attention,
            encoder_feed_forward=self.encoder_feed_forward,
            use_encoder_q_residual=self.use_encoder_q_residual,
            decoder_attention=self.decoder_attention,
            decoder_feed_forward=self.decoder_feed_forward,
            use_decoder_q_residual=self.use_decoder_q_residual,
            deterministic=self.deterministic,
            precision=self.precision,
            name="PerceiverIO",
        )
        return perceiver(inputs, input_mask, targets, target_mask, output_positions)
[4]:
# Hyper-parameters of the original pretrained byte-level checkpoint; they must
# match the shapes of the weights loaded from the original model.
MAX_INPUT_LENGTH = 2048  # maximum number of tokens per input sequence
D_INPUT = 768  # width of token embeddings and input/decoder positional encodings
N_LATENTS = 256  # number of trainable encoder queries
D_LATENT = 1280  # width of the encoder queries; value width of encoder/processor attention
N_HEADS = 8  # attention heads in every attention layer
D_QK = 256  # query/key width in every attention layer

# Byte-level tokenizer with six reserved special tokens.
tokenizer = data.bytes_tokenizer(["PAD", "BOS", "EOS", "MASK", "CLS", "SEP"])
model = PerceiverMLM(
    input_embedding=partial(
        nn.Embed,
        num_embeddings=tokenizer.vocab_size,
        features=D_INPUT,
    ),
    input_positional_encoding=partial(
        TrainablePositionalEncoding,
        seqlen=MAX_INPUT_LENGTH,
        dimension=D_INPUT,
    ),
    encoder_query_encoding=query_encoding(
        TrainablePositionalEncoding,
        seqlen=N_LATENTS,
        dimension=D_LATENT,
    ),
    decoder_query_encoding=query_encoding(
        TrainablePositionalEncoding,
        seqlen=MAX_INPUT_LENGTH,
        dimension=D_INPUT,
    ),
    n_processor_shards=1,
    n_processor_blocks=26,
    processor_attention=partial(SelfAttention, n_heads=N_HEADS, d_qk=D_QK, d_v=D_LATENT),
    encoder_attention=partial(KVQAttention, n_heads=N_HEADS, d_qk=D_QK, d_v=D_LATENT),
    # The decoder projects back to the token-embedding width.
    decoder_attention=partial(KVQAttention, n_heads=N_HEADS, d_qk=D_QK, d_v=D_INPUT),
)
model_init = model.init
model_apply = model.apply
[ ]:
%%bash
# Fetch the original pretrained checkpoint; --continue resumes a partial download.
wget --continue --output-document="/tmp/perceiver_mlm_bytes.pickle" "https://storage.googleapis.com/perceiver_io/language_perceiver_io_bytes.pickle"
[7]:
# NOTE(review): `variables` presumably unpickles the downloaded file. Unpickling
# can execute arbitrary code, so only load checkpoints from a trusted source.
initial_variables = variables("/tmp/perceiver_mlm_bytes.pickle")
rng = random.sequence(seed=0)
collections = ["params"]

input_tokens = "This is an incomplete sentence where some words are missing."
# Mask " missing.". Note that the model performs much better if the masked chunk starts with a space.
masked_text = " missing."
# Locate the masked chunk by its UTF-8 byte offset instead of hard-coding the
# slice in two places. Assumes the bytes tokenizer emits one id per byte with
# no prefix token (consistent with the original 51:60 slice for this sentence).
# bytes.index raises ValueError if the chunk is absent, rather than masking the wrong bytes.
masked_bytes = masked_text.encode("utf-8")
mask_start = input_tokens.encode("utf-8").index(masked_bytes)
mask_span = slice(mask_start, mask_start + len(masked_bytes))

input_ids = tokenizer.to_ids(input_tokens)
input_ids[mask_span] = tokenizer.reserved_ids.get("MASK")
print(f"Input sequence without masked text:\n`{tokenizer.to_tokens(input_ids)}`")
inputs = tokenizer.pad(input_ids[None], max_length=MAX_INPUT_LENGTH)
input_mask = tokenizer.pad(jnp.ones_like(input_ids)[None], max_length=MAX_INPUT_LENGTH)
outputs = model_apply(
    initial_variables,
    inputs=inputs,
    input_mask=input_mask,
    rngs=random.into_collection(key=next(rng), labels=collections),
)
# Read predictions only at the masked positions of the first (only) example.
output_ids = outputs[0, mask_span].argmax(axis=-1)
print(f"Predicted text:\n`{tokenizer.to_tokens(output_ids)}` <- {output_ids}")
Input sequence without masked text:
`This is an incomplete sentence where some words are`
Predicted text:
` missing.` <- [ 38 115 111 121 121 111 116 109 52]
[ ]: