Using TPUs to speed up and recreate GPT-2 results
This project was inspired by Andrej Karpathy’s video. I follow the GPT-2 and GPT-3 papers to reproduce the model and training techniques: gradient accumulation, distributed data parallel (on GPU and TPU), half-precision, flash attention, and “nice numbers” (sizes divisible by high powers of 2). I trained on FineWeb-Edu, the same dataset used in Karpathy’s reproduction as an open stand-in for the WebText data GPT-2 was originally trained on. For evaluation, I used a separate dataset plus HellaSwag for comparison against the GPT-2 paper.
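As a rough illustration of the gradient-accumulation part (the model, optimizer, and batch sizes below are toy placeholders, not the exact values from my notebook), the inner loop looks roughly like this:

```python
import torch
import torch.nn as nn

# Toy stand-in for the GPT model; the real run uses a 124M-parameter GPT-2.
model = nn.Linear(64, 64)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# "Nice numbers": an effective batch of 2**19 = 524,288 tokens per optimizer step,
# reached by accumulating gradients over many small micro-batches.
total_batch_tokens = 2**19
micro_batch_size, seq_len = 8, 1024
grad_accum_steps = total_batch_tokens // (micro_batch_size * seq_len)

optimizer.zero_grad()
for micro_step in range(grad_accum_steps):
    x = torch.randn(micro_batch_size, 64)   # placeholder batch
    y = torch.randn(micro_batch_size, 64)
    loss = nn.functional.mse_loss(model(x), y)
    # Divide by the number of accumulation steps so the summed gradients
    # match the mean gradient over the full effective batch.
    (loss / grad_accum_steps).backward()
optimizer.step()
```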
But what if you don’t have powerful GPUs or money to rent them? We can still reach GPT-2 124M performance with a TPU on Kaggle! There are just a few problems to solve first.
On Kaggle we only get 40 GB of disk space. If we use the save-to-disk data sharding technique from Andrej Karpathy’s video, we run out of disk space before training even begins.
Solution: stream the data with the datasets library, which lets us feed training and evaluation efficiently without ever writing the shards to disk.
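A minimal sketch of the streaming approach, assuming the HuggingFace `datasets` library, the `HuggingFaceFW/fineweb-edu` dataset with the `sample-10BT` config, and GPT-2’s tiktoken tokenizer (the exact dataset config and helper names in my notebook may differ):

```python
from datasets import load_dataset
import tiktoken

# Stream FineWeb-Edu instead of downloading it: nothing is written to Kaggle's
# 40 GB disk, and documents are fetched lazily as the training loop consumes them.
ds = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-10BT",
                  split="train", streaming=True)

enc = tiktoken.get_encoding("gpt2")

def token_stream(dataset, seq_len=1024):
    """Yield fixed-length token chunks from the streamed documents."""
    buffer = []
    for example in dataset:
        buffer.extend(enc.encode_ordinary(example["text"]) + [enc.eot_token])
        while len(buffer) >= seq_len + 1:
            yield buffer[:seq_len + 1]      # +1 so inputs and targets overlap by one token
            buffer = buffer[seq_len + 1:]

# Usage: grab one (input, target) chunk from the stream
chunk = next(token_stream(ds))
```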
The T4 GPU doesn’t support BF16, and if we fall back to float16 the loss increases because training becomes unstable.
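The instability comes from float16’s much narrower dynamic range; a quick way to see the difference in plain PyTorch (not from the notebook):

```python
import torch

# float16 has more precision bits but a tiny representable range, so activations
# and gradients can overflow without loss scaling; bfloat16 keeps float32's range.
print(torch.finfo(torch.float16))    # max ~6.55e+04
print(torch.finfo(torch.bfloat16))   # max ~3.39e+38
```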
Using float32 instead is a pain in the neck because it slows training to a crawl: we achieved only 7,400 tokens/sec, while Andrej Karpathy achieved 1,242,000 tokens/sec. After 12 hours of training we had only reached step 295 of 19,073. Not even close!
Solution: the TPU supports BF16, and it is even faster!
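A minimal sketch of running a training step on one TPU core in BF16 with PyTorch/XLA, as available on Kaggle. The `XLA_USE_BF16` flag makes XLA execute float32 ops in bfloat16; the model and data here are toy placeholders, and the exact setup in my notebook differs (multi-core, the full GPT-2 model, etc.):

```python
import os
os.environ["XLA_USE_BF16"] = "1"      # run float32 ops in bfloat16 on the TPU

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()              # one TPU core
model = nn.Linear(64, 64).to(device)  # toy stand-in for the GPT-2 model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for step in range(10):
    x = torch.randn(8, 64, device=device)
    y = torch.randn(8, 64, device=device)
    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    xm.optimizer_step(optimizer)      # steps the optimizer (and syncs grads across cores)
    xm.mark_step()                    # flush the pending XLA graph for this step
```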
After moving to the TPU, enabling BF16, applying some other TPU-specific optimizations, and training for 18 hours, I finally surpassed the GPT-2 124M model: validation loss 3.2754 vs. 3.2924 and HellaSwag accuracy 0.2962 vs. 0.294463! We reached 243,000 tokens/sec, a speedup of 243,000 / 7,400 ≈ 33× over the two T4 GPUs!
Here are some fun text samples generated by my model:
For a detailed exploration of the code, datasets, and methods, you can view this Kaggle Notebook.
This project demonstrates the feasibility of training large language models like GPT-2 on Kaggle’s TPU, overcoming hardware limitations and achieving significant performance improvements.