Text classification with a transformer language model using JAX
I am currently looking into Google Colab's fastest deep learning accelerator: the v6e-1 (Trillium) TPU!! It has twice the high-bandwidth memory of the v5e-1 (32 GB) and a whopping peak rating of 918 BF16 TFLOPS (nearly 3x an A100).
I will be trying them out with the JAX & Flax example notebook
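Before running the notebook, it's worth confirming that JAX actually sees the TPU in the Colab runtime. A minimal check (the exact device string printed depends on the runtime; on a CPU-only machine this reports CPU devices instead):

```python
import jax

# List the accelerator devices JAX can see; on a v6e-1 Colab runtime
# this should report a TPU device and the "tpu" backend.
devices = jax.devices()
print(devices)
print("Default backend:", jax.default_backend())
```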
This is a fork of a JAX Stack notebook.
Reference: find the original [here](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_transformer_text_classification.ipynb).
Quoting:
This tutorial demonstrates how to build a text classifier based on a transformer language model with JAX, [Flax NNX](http://flax.readthedocs.io) and [Optax](http://optax.readthedocs.io). It performs sentiment analysis on movie reviews from the [IMDB dataset](https://www.tensorflow.org/api_docs/python/tf/keras/datasets/imdb), classifying each review as positive or negative, following the steps outlined in the reference above:
- Load and preprocess the dataset.
- Define the transformer model with Flax NNX and JAX.
- Create loss and training step functions.
- Train the model.
- Evaluate the model with an example.
