Supervised and unsupervised learning

The training loop accepts loss functions that take an arbitrary number of arguments (model outputs followed by targets), which makes both supervised and unsupervised learning possible.
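As a rough illustration of this calling convention, the sketch below shows one way a training step could forward outputs and targets to a loss function as positional arguments. The helper `call_loss` is hypothetical and is not part of flax_extra; it only assumes that outputs come first and targets follow.

```python
def call_loss(loss, outputs, targets):
    """Hypothetical helper: pass outputs, then targets, as positional arguments."""
    # Wrap a single value into a one-item tuple so it can be unpacked.
    if not isinstance(outputs, tuple):
        outputs = (outputs,)
    if not isinstance(targets, tuple):
        targets = (targets,)
    return loss(*outputs, *targets)


def supervised_loss(o1, y1, y2):
    # One model output and two targets.
    return ("supervised", o1, y1, y2)


def unsupervised_loss(o1):
    # One model output and no targets.
    return ("unsupervised", o1)


supervised = call_loss(supervised_loss, "o1", ("y1", "y2"))
unsupervised = call_loss(unsupervised_loss, "o1", ())
```

With this convention, the signature of the loss function alone determines how many outputs and targets it expects.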

[1]:
import jax
from jax import numpy as jnp
import optax
import redex
from flax import linen as nn
from flax_extra import random
from flax_extra.training import TrainLoop, TrainTask
[2]:
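class PseudoModel(nn.Module):
    # Number of times the dense output is repeated in the model's return value.
    n_outputs: int

    @nn.compact
    def __call__(self, *inputs):
        # `inputs` is a tuple of arrays, so the dense output has the shape
        # (n_inputs, batch, 1), as the printed shapes below show.
        outputs = nn.Dense(features=1)(inputs)
        # Emit the same output `n_outputs` times.
        outputs = redex.util.squeeze_tuple(tuple([outputs] * self.n_outputs))
        return outputs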

def pseudo_data_stream(n_inputs, n_targets, rnkey, shape=(4, 3), bounds=(1, 256)):
    minval, maxval = bounds
    while True:
        # The same key is reused on every iteration, so every batch is
        # identical. That is sufficient for this demonstration.
        y = x = jax.random.uniform(
            key=rnkey,
            shape=shape,
            minval=minval,
            maxval=maxval,
        ).astype(jnp.int32)
        # Repeat the same array to get the requested number of inputs and targets.
        inputs = tuple([x] * n_inputs)
        targets = tuple([y] * n_targets)
        yield (inputs, targets)


def train_loop(n_inputs, n_targets, n_outputs, loss):
    # Generator of random keys for the data stream and the training loop.
    rnkeyg = random.sequence(seed=0)
    model = PseudoModel(n_outputs=n_outputs)
    return TrainLoop(
        init=model.init,
        task=TrainTask(
            apply=model.apply,
            optimizer=optax.sgd(learning_rate=0.1),
            loss=loss,
            data=pseudo_data_stream(
                n_inputs=n_inputs,
                n_targets=n_targets,
                rnkey=next(rnkeyg),
            ),
        ),
        rnkey=next(rnkeyg),
    )

Supervised learning: a single output

This is the most common supervised setup: the model outputs a single array o1, which is passed to the loss function together with the target arrays y1 and y2.
For example, y1 may hold the labels of the training examples and y2 the weights of those labels.
[3]:
def pseudo_loss(o1, y1, y2):
    print(
        "The loss function received:"
        f"\n\t1 model output of a shape {o1.shape}."
        f"\n\t2 targets with shapes {y1.shape} and {y2.shape}."
    )
    return 1.

_ = train_loop(n_inputs=1, n_targets=2, n_outputs=1, loss=pseudo_loss).next_step()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Total model initialization time is 0.35 seconds.
Total number of trainable weights: 4 = 16 B.

The loss function received:
        1 model output of a shape (1, 4, 1).
        2 targets with shapes (4, 3) and (4, 3).
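
To make the labels-and-weights interpretation concrete, here is a minimal sketch of a loss with the same (o1, y1, y2) signature. The function name `weighted_squared_error` and the sample arrays are illustrative assumptions, and the sketch assumes that predictions, labels, and weights share one shape, unlike the pseudo model above.

```python
from jax import numpy as jnp


def weighted_squared_error(o1, y1, y2):
    # o1: predictions, y1: labels, y2: per-example weights.
    squared_error = (o1 - y1) ** 2
    # Weighted average of the squared errors.
    return jnp.sum(y2 * squared_error) / jnp.sum(y2)


predictions = jnp.array([1.0, 2.0, 3.0])
labels = jnp.array([1.0, 0.0, 3.0])
weights = jnp.array([1.0, 1.0, 2.0])
loss_value = weighted_squared_error(predictions, labels, weights)
```

Only the second example contributes an error here (squared error 4 with weight 1), and the weights sum to 4, so the weighted average is 1.0.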

Supervised learning: multiple outputs

A loss function can also handle multiple model outputs.

[4]:
def pseudo_loss(o1, o2, y1):
    print(
        "The loss function received:"
        f"\n\t2 outputs with shapes {o1.shape} and {o2.shape}."
        f"\n\t1 target of a shape {y1.shape}."
    )
    return 1.

_ = train_loop(n_inputs=1, n_targets=1, n_outputs=2, loss=pseudo_loss).next_step()
Total model initialization time is 0.38 seconds.
Total number of trainable weights: 4 = 16 B.

The loss function received:
        2 outputs with shapes (1, 4, 1) and (1, 4, 1).
        1 target of a shape (4, 3).
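
A loss with two outputs and one target might, for instance, combine a main prediction head with an auxiliary one. The sketch below is an assumption about how such a loss could look; `two_head_loss` and `aux_weight` are hypothetical names, and both outputs are assumed to have the target's shape.

```python
from jax import numpy as jnp


def two_head_loss(o1, o2, y1, aux_weight=0.5):
    # Mean squared error of the main head against the target.
    main = jnp.mean((o1 - y1) ** 2)
    # Mean squared error of the auxiliary head against the same target.
    aux = jnp.mean((o2 - y1) ** 2)
    # The auxiliary term is down-weighted by `aux_weight`.
    return main + aux_weight * aux


target = jnp.array([1.0, 2.0])
main_output = jnp.array([1.0, 2.0])
aux_output = jnp.array([0.0, 0.0])
loss_value = two_head_loss(main_output, aux_output, target)
```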

Supervised learning: multiple outputs and targets

This example demonstrates a use case with multiple model outputs and targets.

[5]:
def pseudo_loss(o1, o2, y1, y2):
    print(
        "The loss function received:"
        f"\n\t2 outputs with shapes {o1.shape} and {o2.shape}."
        f"\n\t2 targets with shapes {y1.shape} and {y2.shape}."
    )
    return 1.

_ = train_loop(n_inputs=1, n_targets=2, n_outputs=2, loss=pseudo_loss).next_step()
Total model initialization time is 0.46 seconds.
Total number of trainable weights: 4 = 16 B.

The loss function received:
        2 outputs with shapes (1, 4, 1) and (1, 4, 1).
        2 targets with shapes (4, 3) and (4, 3).
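
With two outputs and two targets, one plausible arrangement is to pair each output with its own target and sum the per-pair losses. The following sketch assumes that pairing; `paired_loss` is a hypothetical name, and the choice of squared error for the first pair and absolute error for the second is arbitrary.

```python
from jax import numpy as jnp


def paired_loss(o1, o2, y1, y2):
    # First pair: mean squared error.
    first = jnp.mean((o1 - y1) ** 2)
    # Second pair: mean absolute error.
    second = jnp.mean(jnp.abs(o2 - y2))
    return first + second


loss_value = paired_loss(
    jnp.array([1.0, 3.0]),  # o1
    jnp.array([0.0, 2.0]),  # o2
    jnp.array([1.0, 1.0]),  # y1
    jnp.array([1.0, 2.0]),  # y2
)
```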

Unsupervised learning

A loss function may require no targets at all, as is the case in unsupervised learning.

[6]:
def pseudo_loss(o1):
    print(
        "The loss function received:"
        f"\n\t1 model output of a shape {o1.shape}."
    )
    return 1.

_ = train_loop(n_inputs=3, n_targets=0, n_outputs=1, loss=pseudo_loss).next_step()
Total model initialization time is 0.35 seconds.
Total number of trainable weights: 4 = 16 B.

The loss function received:
        1 model output of a shape (3, 4, 1).
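
As a minimal sketch of a target-free objective, the loss below penalizes large activations using nothing but the model output. It is only an illustration of the single-argument signature, not a recommended unsupervised objective; `activity_penalty` is a hypothetical name.

```python
from jax import numpy as jnp


def activity_penalty(o1):
    # Mean of squared activations; no targets are involved.
    return jnp.mean(o1 ** 2)


outputs = jnp.array([[1.0, -1.0], [2.0, 0.0]])
loss_value = activity_penalty(outputs)
```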