PaliGemma GPU Memory Calculator
An interactive tool to estimate the VRAM requirements for training and inference with PaliGemma models.
Precision Guide
Numerical precision determines how many bytes each model parameter occupies, so the memory needed for the weights alone is roughly the parameter count multiplied by the bytes per parameter. Lower precision reduces memory use and can speed up computation, but it may cost some accuracy.
| Data Type | Bytes per Parameter | Typical Use Case | Key Considerations |
|---|---|---|---|
| FP32 | 4 bytes | Baseline (full precision) | Highest memory cost, but the most numerically stable. |
| BF16 | 2 bytes | Standard for training LLMs | Same exponent range as FP32, so overflow and underflow are rare; less mantissa precision than FP16. |
| FP16 | 2 bytes | Training & inference | Narrow dynamic range (max 65,504); training usually needs loss scaling to avoid underflow/overflow. |
| FP8 | 1 byte | Accelerated training | Requires GPUs with FP8 support (e.g., NVIDIA H100). |
| INT8 | 1 byte | Inference (quantization) | 4x memory reduction vs FP32; small potential accuracy loss. |
| INT4 | 0.5 bytes | Inference (quantization) | 8x memory reduction vs FP32; higher risk of accuracy degradation. |
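The weights-only part of any estimate is just parameter count times bytes per parameter. Below is a minimal sketch of that arithmetic. It assumes roughly 3 billion parameters, inferred from the `3b` in `google/paligemma-3b-pt-224`; the exact count is not given here. It also ignores activations, KV cache, quantization scales, and framework overhead, all of which add to real usage.

```python
# Bytes per parameter, taken from the table above.
BYTES_PER_PARAM = {
    "FP32": 4.0,
    "BF16": 2.0,
    "FP16": 2.0,
    "FP8": 1.0,
    "INT8": 1.0,
    "INT4": 0.5,
}

GIB = 2**30  # 1 GiB = 1,073,741,824 bytes


def weight_memory_gib(num_params: float, dtype: str) -> float:
    """Memory for the model weights alone, in GiB.

    Ignores activations, KV cache, quantization scales, and framework overhead.
    """
    return num_params * BYTES_PER_PARAM[dtype] / GIB


# Assumption: roughly 3 billion parameters, read off the "3b" in the model ID.
NUM_PARAMS = 3e9

for dtype in BYTES_PER_PARAM:
    print(f"{dtype:>4}: {weight_memory_gib(NUM_PARAMS, dtype):6.2f} GiB")
```

Because the formula is linear, the reduction factors in the table carry over directly: INT8 weights take a quarter of the FP32 figure and INT4 weights an eighth.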
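The FP16 and BF16 rows use the same number of bytes but differ in dynamic range. The sketch below uses only the standard library, so no GPU or PyTorch is needed. It rounds two values through each format. `round_to_fp16` relies on Python's `struct` half-precision format code `'e'`. `round_to_bf16` is a hand-written helper that keeps the upper 16 bits of the FP32 bit pattern using round-to-nearest-even, which is how BF16 relates to FP32. Both helper names are illustrative, and NaN and infinity inputs are not handled.

```python
import struct


def round_to_fp16(x: float) -> float:
    """Round to IEEE half precision; return +/-inf if the value overflows."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:  # struct raises instead of producing inf
        return float("inf") if x > 0 else float("-inf")


def round_to_bf16(x: float) -> float:
    """Round a finite FP32-range value to bfloat16 (round-to-nearest-even).

    BF16 is the top 16 bits of an FP32 number: same 8-bit exponent,
    but only 7 explicit mantissa bits.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round half to even on bit 16
    bits &= 0xFFFF0000                   # drop the low 16 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


for value in (70000.0, 1e-8):
    print(f"{value:>8g}: FP16 -> {round_to_fp16(value):g}, BF16 -> {round_to_bf16(value):g}")
```

FP16 overflows to infinity above 65,504 and flushes very small values, such as tiny gradients, to zero. This is why FP16 training typically relies on loss scaling. BF16 keeps both values at the cost of coarser rounding: 70,000 becomes 70,144.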
Example Training Script
This script shows a full fine-tuning setup for PaliGemma using PyTorch, `transformers`, and the `datasets` library. The comments point out the parameters you can change to trade speed against memory, and each one maps to a concept used in this calculator.
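```python
import torch
import requests
from PIL import Image
from datasets import load_dataset
from transformers import (
    PaliGemmaForConditionalGeneration,
    AutoProcessor,
    TrainingArguments,
    Trainer,
)

model_id = "google/paligemma-3b-pt-224"

# Precision of the loaded weights: bfloat16 uses 2 bytes per parameter
# (float32 would use 4). See the precision guide above.
model_precision = torch.bfloat16

# FlashAttention-2 reduces attention time and memory, but needs the
# `flash-attn` package and an Ampere-or-newer NVIDIA GPU.
# Set to False to fall back to the eager implementation.
use_flash_attention = True

processor = AutoProcessor.from_pretrained(model_id)
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=model_precision,
    attn_implementation="flash_attention_2" if use_flash_attention else "eager",
    device_map="auto",  # place the model on the available GPU(s)
)

ds = load_dataset("graphcore/gqa-tiny", split="train")


def preprocess_data(example):
    # Longer sequences increase activation memory; tune together with batch size.
    max_length = 128
    try:
        response = requests.get(example["image_url"], stream=True, timeout=10)
        response.raise_for_status()
        image = Image.open(response.raw).convert("RGB")
    except Exception:
        # Fall back to a blank image if the download fails.
        image = Image.new("RGB", (224, 224))
    prompt = "answer " + example["question"]
    # Passing the answer as `suffix` makes the processor append it to the
    # prompt and build `labels` aligned with `input_ids`. Image, prompt, and
    # padding positions are set to -100 so the loss ignores them.
    inputs = processor(
        text=prompt,
        images=image,
        suffix=example["answer"],
        return_tensors="pt",
        padding="max_length",
        max_length=max_length,
        truncation=True,
    )
    # The processor returns a batch of one; drop that leading dimension so
    # the Trainer's collator can stack examples into batches.
    return {key: value[0] for key, value in inputs.items()}


processed_ds = ds.map(preprocess_data, remove_columns=ds.column_names).with_format("torch")

training_args = TrainingArguments(
    output_dir="./paligemma-finetuned-gqa",
    # Activation memory grows roughly linearly with the per-device batch size.
    per_device_train_batch_size=8,
    # Gradients from 4 micro-batches are accumulated before each optimizer
    # step: an effective batch of 8 x 4 = 32 per device, with the activation
    # memory of a batch of 8.
    gradient_accumulation_steps=4,
    # AdamW keeps two extra state tensors per trainable parameter, often the
    # largest fixed memory term in full fine-tuning.
    optim="adamw_torch",
    # Mixed-precision training in bfloat16 (needs an Ampere-or-newer GPU).
    bf16=True,
    # Recompute activations during the backward pass instead of storing them:
    # much less activation memory at the cost of extra compute.
    gradient_checkpointing=True,
    dataloader_num_workers=4,
    num_train_epochs=1,
    logging_steps=10,
    save_strategy="epoch",
    learning_rate=2e-5,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=processed_ds,
)
trainer.train()
```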
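Several settings in the training script determine the fixed part of training memory: weights, gradients, and optimizer state. The sketch below works through that arithmetic under these assumptions:

- About 3 billion parameters, inferred from the model ID.
- Every parameter is trained, since the script freezes nothing.
- `optim="adamw_torch"` selects PyTorch's `torch.optim.AdamW`, whose two moment buffers are assumed to take the dtype of the parameters (BF16 here).

The function name `static_training_memory_gib` is illustrative. Activations are deliberately left out; they are governed by `per_device_train_batch_size`, `max_length`, and `gradient_checkpointing`. The CUDA context and allocator overhead are also excluded, so treat the result as a lower bound.

```python
GIB = 2**30


def static_training_memory_gib(
    num_params: float,
    weight_bytes: float,
    grad_bytes: float,
    state_bytes: float,
    num_states: int = 2,
    master_weight_bytes: float = 0.0,
) -> dict:
    """Weights + gradients + optimizer state in GiB; activations are NOT included.

    num_states=2 matches AdamW (first and second moment estimates).
    master_weight_bytes covers an optional extra FP32 copy of the weights
    that some mixed-precision schemes keep.
    """
    parts = {
        "weights": num_params * weight_bytes / GIB,
        "gradients": num_params * grad_bytes / GIB,
        "optimizer": num_params * (num_states * state_bytes + master_weight_bytes) / GIB,
    }
    parts["total"] = sum(parts.values())
    return parts


NUM_PARAMS = 3e9  # assumption: ~3B parameters, from the "3b" in the model ID

# The script as written (assumed): BF16 weights, BF16 gradients, BF16 AdamW states.
as_written = static_training_memory_gib(NUM_PARAMS, weight_bytes=2, grad_bytes=2, state_bytes=2)

# For comparison: FP32 weights trained under autocast, with FP32 gradients and states.
fp32_weights = static_training_memory_gib(NUM_PARAMS, weight_bytes=4, grad_bytes=4, state_bytes=4)

for name, parts in [("as written", as_written), ("fp32 weights", fp32_weights)]:
    summary = ", ".join(f"{k}={v:.1f} GiB" for k, v in parts.items())
    print(f"{name}: {summary}")
```

The factor-of-two gap comes entirely from storing weights, gradients, and optimizer state in 2 bytes instead of 4. Keeping AdamW state in BF16 can lose precision on very small updates, which is one reason some setups keep FP32 optimizer state or master weights (modeled here with `state_bytes=4` or `master_weight_bytes=4`).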