flax_extra.data package

Data processing modules and functions.

class flax_extra.data.BytesTokenizer(reserved_ids: Mapping[str, int])

Bases: object

Tokenizer mapping text strings to their UTF-8 bytes.

property d_reserved: int

the number of reserved tokens.

pad(inputs: jax.numpy.ndarray, max_length: int) → jax.numpy.ndarray

Pads the sequence up to the desired length.

Parameters
  • inputs – an input sequence.

  • max_length – desired sequence length.

Returns

a padded sequence.
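A minimal sketch of what padding to `max_length` involves. It assumes the pad id is 0 and that padding is appended on the right; neither detail is stated above, and the behaviour for sequences already longer than `max_length` is not documented, so this sketch simply rejects them. NumPy stands in for `jax.numpy`, whose `jnp.pad` takes the same arguments.

```python
import numpy as np

PAD_ID = 0  # assumption: the id reserved for padding


def pad(inputs: np.ndarray, max_length: int, pad_id: int = PAD_ID) -> np.ndarray:
    """Right-pads a 1-D id sequence with `pad_id` up to `max_length`."""
    n_missing = max_length - inputs.shape[0]
    if n_missing < 0:
        # Assumption: over-long inputs are an error here (no truncation).
        raise ValueError(
            f"sequence of length {inputs.shape[0]} exceeds max_length={max_length}"
        )
    # Append `n_missing` copies of `pad_id` after the last element.
    return np.pad(inputs, (0, n_missing), constant_values=pad_id)


print(pad(np.array([7, 8, 9]), 5))  # [7 8 9 0 0]
```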

reserved_ids: Mapping[str, int]

a mapping from reserved tokens to their identifiers.

to_ids(tokens: AnyStr) → jax.numpy.ndarray

Maps text characters to UTF-8 bytes.

Parameters

tokens – a sequence of text characters.

Returns

a byte sequence.
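A minimal sketch of the text-to-ids direction. It assumes (this is not stated above) that each UTF-8 byte value is shifted up by the number of reserved tokens, so that ids `0 … d_reserved - 1` stay free for reserved tokens; `D_RESERVED = 4` is a hypothetical value. NumPy stands in for `jax.numpy`.

```python
import numpy as np

D_RESERVED = 4  # hypothetical number of reserved tokens


def to_ids(tokens, d_reserved: int = D_RESERVED) -> np.ndarray:
    """Encodes text (str or bytes) as UTF-8 byte values shifted past the reserved ids."""
    # `AnyStr` covers both str and bytes; only str needs encoding.
    data = tokens.encode("utf-8") if isinstance(tokens, str) else bytes(tokens)
    # One id per byte, offset so byte value b becomes id b + d_reserved.
    return np.frombuffer(data, dtype=np.uint8).astype(np.int32) + d_reserved


print(to_ids("hé"))  # [108 199 173]
```

Note that a single character may yield several ids: "é" is two bytes in UTF-8 (0xC3, 0xA9), so the two-character string above becomes three ids.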

to_tokens(ids: jax.numpy.ndarray) → str

Maps UTF-8 bytes to text characters.

Parameters

ids – a byte sequence.

Returns

a sequence of text characters.
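A minimal sketch of the reverse direction, under the same unstated assumption that byte values are offset by the number of reserved tokens (`D_RESERVED = 4` is hypothetical). Reserved ids such as padding are dropped before decoding, and `errors="replace"` is an assumed choice so that invalid byte sequences (for example, from model output) decode to U+FFFD instead of raising.

```python
import numpy as np

D_RESERVED = 4  # hypothetical number of reserved tokens


def to_tokens(ids: np.ndarray, d_reserved: int = D_RESERVED) -> str:
    """Decodes shifted UTF-8 byte ids back to text, skipping reserved ids."""
    ids = np.asarray(ids)
    # Keep only byte ids, then undo the offset to recover raw byte values.
    byte_values = ids[ids >= d_reserved] - d_reserved
    raw = bytes(byte_values.astype(np.uint8).tolist())
    return raw.decode("utf-8", errors="replace")


print(to_tokens(np.array([108, 199, 173, 0, 0])))  # hé
```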

property vocab_size: int

the vocabulary size.
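A sketch of how the vocabulary size would follow from `d_reserved` if, as assumed here but not stated above, the vocabulary consists of the reserved ids plus one id for each of the 256 possible byte values.

```python
N_BYTE_VALUES = 256  # every possible value of a single byte


def vocab_size(d_reserved: int) -> int:
    """Vocabulary size assuming ids = reserved ids + 256 byte ids."""
    return N_BYTE_VALUES + d_reserved


print(vocab_size(4))  # 260
```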

flax_extra.data.bytes_tokenizer(reserved_tokens: List[str]) → flax_extra.data.BytesTokenizer

Creates a tokenizer mapping text strings to their UTF-8 bytes.

Parameters

reserved_tokens – a list of reserved tokens.

Returns

a tokenizer.
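A minimal sketch of how a factory could turn a list of reserved tokens into the `reserved_ids` mapping. It assumes reserved tokens receive consecutive ids 0, 1, … in list order and that the vocabulary is those ids plus 256 byte ids; both are assumptions. `SketchBytesTokenizer` and `make_tokenizer_sketch` are hypothetical stand-ins, not the library's class or function.

```python
from dataclasses import dataclass
from typing import List, Mapping


@dataclass(frozen=True)
class SketchBytesTokenizer:  # hypothetical stand-in, not flax_extra.data.BytesTokenizer
    reserved_ids: Mapping[str, int]

    @property
    def d_reserved(self) -> int:
        # One id per reserved token.
        return len(self.reserved_ids)

    @property
    def vocab_size(self) -> int:
        # Assumption: reserved ids followed by 256 byte ids.
        return 256 + self.d_reserved


def make_tokenizer_sketch(reserved_tokens: List[str]) -> SketchBytesTokenizer:
    """Assigns reserved tokens consecutive ids 0, 1, ... in list order."""
    if len(set(reserved_tokens)) != len(reserved_tokens):
        raise ValueError("reserved tokens must be unique")
    return SketchBytesTokenizer({tok: i for i, tok in enumerate(reserved_tokens)})


tok = make_tokenizer_sketch(["<pad>", "<bos>", "<eos>"])
print(tok.reserved_ids)  # {'<pad>': 0, '<bos>': 1, '<eos>': 2}
print(tok.vocab_size)    # 259
```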