{ "cells": [ { "cell_type": "markdown", "source": [ "# Supervised and unsupervised learning" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "The training loop accepts loss functions that take an arbitrary number of arguments (model outputs and targets), which makes both supervised and unsupervised learning possible." ], "metadata": {} }, { "cell_type": "code", "execution_count": 1, "source": [ "import jax\n", "from jax import numpy as jnp\n", "import optax\n", "import redex\n", "from flax import linen as nn\n", "from flax_extra import random\n", "from flax_extra.training import TrainLoop, TrainTask" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 2, "source": [ "class PseudoModel(nn.Module):\n", "    n_outputs: int\n", "\n", "    @nn.compact\n", "    def __call__(self, *inputs):\n", "        # The tuple of inputs is treated as one stacked array of shape (n_inputs, ...).\n", "        outputs = nn.Dense(features=1)(inputs)\n", "        # Repeat the output `n_outputs` times; a 1-tuple is unwrapped to a single array.\n", "        outputs = redex.util.squeeze_tuple(tuple([outputs] * self.n_outputs))\n", "        return outputs\n", "\n", "def pseudo_data_stream(n_inputs, n_targets, rnkey, shape=(4, 3), bounds=(1, 256)):\n", "    minval, maxval = bounds\n", "    while True:\n", "        # Split the key so that every batch is drawn with a fresh subkey.\n", "        rnkey, subkey = jax.random.split(rnkey)\n", "        y = x = jax.random.uniform(\n", "            key=subkey,\n", "            shape=shape,\n", "            minval=minval,\n", "            maxval=maxval,\n", "        ).astype(jnp.int32)\n", "        inputs = tuple([x] * n_inputs)\n", "        targets = tuple([y] * n_targets)\n", "        yield (inputs, targets)\n", "\n", "\n", "def train_loop(n_inputs, n_targets, n_outputs, loss):\n", "    rnkeyg = random.sequence(seed=0)\n", "    model = PseudoModel(n_outputs=n_outputs)\n", "    return TrainLoop(\n", "        init=model.init,\n", "        task=TrainTask(\n", "            apply=model.apply,\n", "            optimizer=optax.sgd(learning_rate=0.1),\n", "            loss=loss,\n", "            data=pseudo_data_stream(\n", "                n_inputs=n_inputs,\n", "                n_targets=n_targets,\n", "                rnkey=next(rnkeyg),\n", "            ),\n", "        ),\n", "        rnkey=next(rnkeyg),\n", "    )" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "#### Supervised learning: a single output\n", "\n", "The most common use case for 
supervised learning: the model outputs a single array `o1`, which is passed to the loss function along with the target arrays `y1` and `y2`. \n", "For example, `y1` may hold the labels for the training examples and `y2` the weights for these labels." ], "metadata": {} }, { "cell_type": "code", "execution_count": 3, "source": [ "def pseudo_loss(o1, y1, y2):\n", "    print(\n", "        \"The loss function received:\"\n", "        f\"\\n\\t1 model output of a shape {o1.shape}.\"\n", "        f\"\\n\\t2 targets with shapes {y1.shape} and {y2.shape}.\"\n", "    )\n", "    return 1.\n", "\n", "_ = train_loop(n_inputs=1, n_targets=2, n_outputs=1, loss=pseudo_loss).next_step()" ], "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "Total model initialization time is 0.35 seconds.\n", "Total number of trainable weights: 4 = 16 B.\n", "\n", "The loss function received:\n", "\t1 model output of a shape (1, 4, 1).\n", "\t2 targets with shapes (4, 3) and (4, 3).\n" ] } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "#### Supervised learning: multiple outputs\n", "\n", "A loss function can also receive multiple model outputs." 
], "metadata": {} }, { "cell_type": "code", "execution_count": 4, "source": [ "def pseudo_loss(o1, o2, y1):\n", "    print(\n", "        \"The loss function received:\"\n", "        f\"\\n\\t2 outputs with shapes {o1.shape} and {o2.shape}.\"\n", "        f\"\\n\\t1 target of a shape {y1.shape}.\"\n", "    )\n", "    return 1.\n", "\n", "_ = train_loop(n_inputs=1, n_targets=1, n_outputs=2, loss=pseudo_loss).next_step()" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Total model initialization time is 0.38 seconds.\n", "Total number of trainable weights: 4 = 16 B.\n", "\n", "The loss function received:\n", "\t2 outputs with shapes (1, 4, 1) and (1, 4, 1).\n", "\t1 target of a shape (4, 3).\n" ] } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "#### Supervised learning: multiple outputs and targets\n", "\n", "A loss function can receive multiple model outputs and multiple targets at the same time." ], "metadata": {} }, { "cell_type": "code", "execution_count": 5, "source": [ "def pseudo_loss(o1, o2, y1, y2):\n", "    print(\n", "        \"The loss function received:\"\n", "        f\"\\n\\t2 outputs with shapes {o1.shape} and {o2.shape}.\"\n", "        f\"\\n\\t2 targets with shapes {y1.shape} and {y2.shape}.\"\n", "    )\n", "    return 1.\n", "\n", "_ = train_loop(n_inputs=1, n_targets=2, n_outputs=2, loss=pseudo_loss).next_step()" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Total model initialization time is 0.46 seconds.\n", "Total number of trainable weights: 4 = 16 B.\n", "\n", "The loss function received:\n", "\t2 outputs with shapes (1, 4, 1) and (1, 4, 1).\n", "\t2 targets with shapes (4, 3) and (4, 3).\n" ] } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "#### Unsupervised learning\n", "\n", "A loss function may not require targets at all, as is the case in unsupervised learning." 
], "metadata": {} }, { "cell_type": "code", "execution_count": 6, "source": [ "def pseudo_loss(o1):\n", "    print(\n", "        \"The loss function received:\"\n", "        f\"\\n\\t1 model output of a shape {o1.shape}.\"\n", "    )\n", "    return 1.\n", "\n", "_ = train_loop(n_inputs=3, n_targets=0, n_outputs=1, loss=pseudo_loss).next_step()" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Total model initialization time is 0.35 seconds.\n", "Total number of trainable weights: 4 = 16 B.\n", "\n", "The loss function received:\n", "\t1 model output of a shape (3, 4, 1).\n" ] } ], "metadata": {} } ], "metadata": { "orig_nbformat": 4, "language_info": { "name": "python", "version": "3.9.7", "mimetype": "text/x-python", "codemirror_mode": { "name": "ipython", "version": 3 }, "pygments_lexer": "ipython3", "nbconvert_exporter": "python", "file_extension": ".py" }, "kernelspec": { "name": "python3", "display_name": "Python 3.9.7 64-bit ('env': venv)" }, "interpreter": { "hash": "8b02d22ec999df61aa0506c585da9f18dfcc9eaff18f3e42f4c27a57de3e3420" } }, "nbformat": 4, "nbformat_minor": 2 }