Source code for flax_extra.data._bytes_tokenizer

r"""Bytes tokenizer."""

from typing import AnyStr, List, Optional, Mapping
from dataclasses import dataclass
from jax import numpy as jnp
import numpy as np

# Type alias used by the annotations in this module for array-valued
# arguments and return values.
Array = jnp.ndarray


@dataclass
class BytesTokenizer:
    r"""Tokenizer mapping text strings to their UTF-8 bytes.

    Identifiers below :attr:`d_reserved` are treated as reserved tokens.
    A byte value ``b`` is represented by the identifier ``b + d_reserved``.
    """

    reserved_ids: Mapping[str, int]
    r"""a mapping from reserved token names to their identifiers."""

    @property
    def vocab_size(self) -> int:
        r"""a vocabulary size: the reserved tokens plus 256 byte values."""
        return self.d_reserved + 256

    @property
    def d_reserved(self) -> int:
        r"""a number of reserved tokens."""
        return len(self.reserved_ids)

    def to_tokens(self, ids: Array) -> str:
        r"""Maps token identifiers back to text characters.

        Identifiers below :attr:`d_reserved` (reserved tokens) are dropped,
        the remaining ones are shifted back to byte values and decoded as
        UTF-8. Byte sequences that are not valid UTF-8 are decoded with
        ``errors="replace"`` instead of raising.

        Args:
            ids: a sequence of token identifiers. Multi-dimensional input
                is flattened by the boolean mask that removes reserved
                identifiers.

        Returns:
            a sequence of text characters.
        """
        # Normalise first so that plain Python sequences and non-NumPy
        # arrays support the element-wise comparison and masking below.
        ids = np.asarray(ids)
        ids_no_special = ids[ids >= self.d_reserved] - self.d_reserved
        buf: bytes = ids_no_special.astype(np.uint8).tobytes()
        return buf.decode("utf-8", errors="replace")

    def to_ids(self, tokens: AnyStr) -> Array:
        r"""Maps text characters to token identifiers.

        Args:
            tokens: a sequence of text characters, either a ``str`` (which
                is encoded as UTF-8) or already encoded ``bytes``.

        Returns:
            a one-dimensional ``int32`` sequence of token identifiers, one
            per byte, each shifted up by :attr:`d_reserved`.
        """
        if isinstance(tokens, str):
            tokens = tokens.encode("utf-8")  # type: ignore
        # Widen from uint8 before shifting so that adding the reserved
        # offset cannot wrap around.
        ids = np.frombuffer(tokens, np.uint8).astype(np.int32)
        return ids + self.d_reserved  # type: ignore

    def pad(self, inputs: Array, max_length: int) -> Array:
        r"""Pads the sequence up to desired length.

        The padding value is the identifier of the reserved token ``PAD``;
        the padding itself is delegated to the module-level ``_pad`` helper.

        Args:
            inputs: an input sequence.
            max_length: desired sequence length.

        Returns:
            a padded sequence.

        Raises:
            ValueError: if the reserved token ``PAD`` is missing, or (from
                ``_pad``) if the input is longer than ``max_length``.
        """
        if "PAD" not in self.reserved_ids:
            raise ValueError(
                "Cannot pad a token sequence, because "
                "reserved token `PAD` is missing."
            )
        return _pad(inputs, max_length, self.reserved_ids["PAD"])
def bytes_tokenizer(reserved_tokens: List[str]) -> BytesTokenizer:
    r"""Creates a tokenizer mapping text strings to their UTF-8 bytes.

    Each reserved token is assigned its position in ``reserved_tokens`` as
    its identifier.

    Args:
        reserved_tokens: a list of reserved tokens.

    Returns:
        a tokenizer.

    Raises:
        ValueError: if the same reserved token is listed more than once.
    """
    reserved_tokens = list(reserved_tokens)
    reserved_ids = {token: index for index, token in enumerate(reserved_tokens)}
    # A repeated token would collapse into a single mapping entry, making the
    # tokenizer count fewer reserved tokens than identifiers were handed out,
    # so the highest reserved identifiers would collide with byte identifiers.
    if len(reserved_ids) != len(reserved_tokens):
        raise ValueError(
            "Cannot create a tokenizer, because "
            "reserved tokens contain duplicates."
        )
    return BytesTokenizer(reserved_ids=reserved_ids)
def _pad(inputs: Array, max_length: int, pad_token: Optional[int]) -> Array: length = inputs.shape[1] if length > max_length: raise ValueError( "Cannot pad a token sequence, because " f"an example length {length} exceeds the maximum length {max_length}." ) pad_length = max_length - length padded_inputs = np.pad( inputs, pad_width=((0, 0), (0, pad_length)), constant_values=pad_token, ) return padded_inputs # type: ignore