{ "cells": [ { "cell_type": "markdown", "source": [ "# PerceiverIO language modeling\n", "\n", "An example of a byte-level **masked-language model** pretrained on a large text corpus that combines English Wikipedia and C4.\n", "\n", "```\n", "@article{Jaegle2021PerceiverIA,\n", " title={Perceiver IO: A General Architecture for Structured Inputs \\& Outputs},\n", " author={Andrew Jaegle and Sebastian Borgeaud and Jean-Baptiste Alayrac and Carl Doersch and Catalin Ionescu and David Ding and Skanda Koppula and Andrew Brock and Evan Shelhamer and Olivier J. H{\\'e}naff and Matthew M. Botvinick and Andrew Zisserman and Oriol Vinyals and Jo{\\~a}o Carreira},\n", " journal={ArXiv},\n", " year={2021},\n", " volume={abs/2107.14795}\n", "}\n", "@article{Raffel2020ExploringTL,\n", " title={Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer},\n", " author={Colin Raffel and Noam M. Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. 
Liu},\n", " journal={ArXiv},\n", " year={2020},\n", " volume={abs/1910.10683}\n", "}\n", "```" ], "metadata": {} }, { "cell_type": "code", "execution_count": 1, "source": [ "%%bash\n", "\n", "## Haiku is used to convert weights of the original model.\n", "pip install -U dm-haiku" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: dm-haiku in /Users/manifest/development/github/flax-extra/env/lib/python3.9/site-packages (0.0.4)\n", "Requirement already satisfied: absl-py>=0.7.1 in /Users/manifest/development/github/flax-extra/env/lib/python3.9/site-packages (from dm-haiku) (0.14.1)\n", "Requirement already satisfied: tabulate>=0.8.9 in /Users/manifest/development/github/flax-extra/env/lib/python3.9/site-packages (from dm-haiku) (0.8.9)\n", "Requirement already satisfied: numpy>=1.18.0 in /Users/manifest/development/github/flax-extra/env/lib/python3.9/site-packages (from dm-haiku) (1.19.5)\n", "Requirement already satisfied: six in /Users/manifest/development/github/flax-extra/env/lib/python3.9/site-packages (from absl-py>=0.7.1->dm-haiku) (1.16.0)\n" ] } ], "metadata": {} }, { "cell_type": "code", "execution_count": 2, "source": [ "from typing import Any, List, Optional, Union\n", "from functools import partial\n", "import jax.numpy as jnp\n", "\n", "from jax import numpy as jnp\n", "from flax import linen as nn\n", "from flax_extra import random\n", "from flax_extra import combinator as cb\n", "from flax_extra.layer import (\n", " FeedForward,\n", " FeedForwardCt,\n", " KVQAttention,\n", " KVQAttentionCt,\n", " SelfAttention,\n", " SelfAttentionCt,\n", " Encoding,\n", " MultimodalEncodingCt,\n", " MultimodalPositionalEncodingCt,\n", " Decoding,\n", " TrainablePositionalEncoding,\n", " EmbedDecoding,\n", ")\n", "from flax_extra import data\n", "from flax_extra.layer.io import (\n", " input_encoding,\n", " target_encoding,\n", " query_encoding,\n", " output_decoding,\n", ")\n", "from flax_extra.model import 
PerceiverIO\n", "from util.original_model import variables\n", "\n", "Array = jnp.ndarray\n", "Precision = Any\n", "Positions = List[int]\n", "MaybePositions = Optional[Positions]" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 3, "source": [ "class PerceiverMLM(nn.Module):\n", " input_embedding: MultimodalEncodingCt\n", " input_positional_encoding: Union[MultimodalPositionalEncodingCt, MultimodalEncodingCt]\n", " encoder_query_encoding: MultimodalPositionalEncodingCt\n", " decoder_query_encoding: MultimodalPositionalEncodingCt\n", " n_processor_shards: int = 8\n", " n_processor_blocks: int = 6\n", " processor_attention: SelfAttentionCt = SelfAttention\n", " processor_feed_forward: FeedForwardCt = FeedForward\n", " encoder_attention: KVQAttentionCt = KVQAttention\n", " encoder_feed_forward: FeedForwardCt = FeedForward\n", " use_encoder_q_residual: bool = True\n", " decoder_attention: KVQAttentionCt = KVQAttention\n", " decoder_feed_forward: FeedForwardCt = FeedForward\n", " use_decoder_q_residual: bool = False\n", " deterministic: bool = True\n", " precision: Optional[Precision] = None\n", "\n", " @nn.compact\n", " def __call__(\n", " self,\n", " inputs: Union[Array, List[Array]],\n", " input_mask: Optional[Array] = None,\n", " targets: Optional[Union[Array, List[Array]]] = None,\n", " target_mask: Optional[Array] = None,\n", " output_positions: MaybePositions = None,\n", " ) -> Array:\n", " input_embedding = self.input_embedding(name=\"EmbeddingEncoder\")\n", "\n", " # Use the same vocabulary for inputs and outputs.\n", " def shared_output_embedding():\n", " return input_embedding\n", "\n", " def decoder_query_encoding(use_teacher_forcing: bool) -> type:\n", " if use_teacher_forcing:\n", " return target_encoding(\n", " Encoding,\n", " preprocessing=shared_output_embedding,\n", " aggregation=cb.add(),\n", " positional_encoding=self.decoder_query_encoding,\n", " )\n", " else:\n", " return self.decoder_query_encoding\n", "\n", " return 
PerceiverIO(\n", " input_encoding=input_encoding(\n", " Encoding,\n", " preprocessing=shared_output_embedding,\n", " aggregation=cb.add(),\n", " positional_encoding=self.input_positional_encoding,\n", " ),\n", " encoder_query_encoding=self.encoder_query_encoding,\n", " decoder_query_encoding=decoder_query_encoding(\n", " use_teacher_forcing=targets is not None,\n", " ),\n", " output_decoding=output_decoding(\n", " Decoding,\n", " embedding_decoding=partial(\n", " EmbedDecoding,\n", " embedding=shared_output_embedding().embedding,\n", " )\n", " ),\n", " n_processor_shards=self.n_processor_shards,\n", " n_processor_blocks=self.n_processor_blocks,\n", " processor_attention=self.processor_attention,\n", " processor_feed_forward=self.processor_feed_forward,\n", " encoder_attention=self.encoder_attention,\n", " encoder_feed_forward=self.encoder_feed_forward,\n", " use_encoder_q_residual=self.use_encoder_q_residual,\n", " decoder_attention=self.decoder_attention,\n", " decoder_feed_forward=self.decoder_feed_forward,\n", " use_decoder_q_residual=self.use_decoder_q_residual,\n", " deterministic=self.deterministic,\n", " precision=self.precision,\n", " name=\"PerceiverIO\",\n", " )(inputs, input_mask, targets, target_mask, output_positions)" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 4, "source": [ "MAX_INPUT_LENGTH = 2048\n", "D_INPUT = 768\n", "\n", "tokenizer = data.bytes_tokenizer([\"PAD\", \"BOS\", \"EOS\", \"MASK\", \"CLS\", \"SEP\"])\n", "model = PerceiverMLM(\n", " input_embedding=partial(\n", " nn.Embed,\n", " num_embeddings=tokenizer.vocab_size,\n", " features=D_INPUT,\n", " ),\n", " input_positional_encoding=partial(\n", " TrainablePositionalEncoding,\n", " seqlen=MAX_INPUT_LENGTH,\n", " dimension=D_INPUT,\n", " ),\n", " encoder_query_encoding=query_encoding(\n", " TrainablePositionalEncoding,\n", " seqlen=256,\n", " dimension=1280,\n", " ),\n", " decoder_query_encoding=query_encoding(\n", " TrainablePositionalEncoding,\n", " 
seqlen=MAX_INPUT_LENGTH,\n", " dimension=D_INPUT,\n", " ),\n", " n_processor_shards=1,\n", " n_processor_blocks=26,\n", " processor_attention=partial(SelfAttention, n_heads=8, d_qk=256, d_v=1280),\n", " encoder_attention=partial(KVQAttention, n_heads=8, d_qk=256, d_v=1280),\n", " decoder_attention=partial(KVQAttention, n_heads=8, d_qk=256, d_v=768),\n", ")\n", "model_init = model.init\n", "model_apply = model.apply" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": null, "source": [ "%%bash\n", "\n", "wget -cO \"/tmp/perceiver_mlm_bytes.pickle\" \"https://storage.googleapis.com/perceiver_io/language_perceiver_io_bytes.pickle\"" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 7, "source": [ "initial_variables = variables(\"/tmp/perceiver_mlm_bytes.pickle\")\n", "rng = random.sequence(seed=0)\n", "collections = [\"params\"]\n", "\n", "input_tokens = \"This is an incomplete sentence where some words are missing.\"\n", "input_ids = tokenizer.to_ids(input_tokens)\n", "# Mask \" missing.\". 
Note that the model performs much better if the masked chunk starts with a space.\n", "input_ids[51:60] = tokenizer.reserved_ids.get(\"MASK\")\n", "print(f\"Input sequence without masked text:\\n`{tokenizer.to_tokens(input_ids)}`\")\n", "inputs = tokenizer.pad(input_ids[None], max_length=MAX_INPUT_LENGTH)\n", "input_mask = tokenizer.pad(jnp.ones_like(input_ids)[None], max_length=MAX_INPUT_LENGTH)\n", "\n", "outputs = model_apply(\n", " initial_variables,\n", " inputs=inputs,\n", " input_mask=input_mask,\n", " rngs=random.into_collection(key=next(rng), labels=collections),\n", ")\n", "\n", "output_ids = outputs[0, 51:60].argmax(axis=-1)\n", "print(f\"Predicted text:\\n`{tokenizer.to_tokens(output_ids)}` <- {output_ids}\")" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Input sequence without masked text:\n", "`This is an incomplete sentence where some words are`\n", "Predicted text:\n", "` missing.` <- [ 38 115 111 121 121 111 116 109 52]\n" ] } ], "metadata": {} }, { "cell_type": "code", "execution_count": null, "source": [], "outputs": [], "metadata": {} } ], "metadata": { "orig_nbformat": 4, "language_info": { "name": "python", "version": "3.9.7", "mimetype": "text/x-python", "codemirror_mode": { "name": "ipython", "version": 3 }, "pygments_lexer": "ipython3", "nbconvert_exporter": "python", "file_extension": ".py" }, "kernelspec": { "name": "python3", "display_name": "Python 3.9.7 64-bit ('env': venv)" }, "interpreter": { "hash": "8b02d22ec999df61aa0506c585da9f18dfcc9eaff18f3e42f4c27a57de3e3420" } }, "nbformat": 4, "nbformat_minor": 2 }