import bitsandbytes as bnb
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM
from transformers.utils.quantization_config import BitsAndBytesConfig
from transformers.pytorch_utils import Conv1D
import transformers
from transformers import LlamaConfig, LlamaForCausalLM
from transformers.integrations.bitsandbytes import replace_with_bnb_linear
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from peft.tuners.lora.config import LoraConfig
from peft.mapping import get_peft_model
from peft.utils.peft_types import *
import os
import gc
import inspect
from accelerate.utils import set_seed
from functools import partial
from pathlib import Path

Base Model

save_dir = Path("profile_snapshots/")
os.makedirs(save_dir, exist_ok=True)

transformers.logging.set_verbosity_warning()

def malloc_in_gb():
    return torch.cuda.memory_allocated() / 1e9

def get_model_size_config(model_size):
    if model_size == "DEBUG":
        model_size_config = dict(hidden_size=128,
                                 num_hidden_layers=2,
                                 num_attention_heads=2,
                                 num_key_value_heads=2,
                                 intermediate_size=256)
    elif model_size == "60M":
        model_size_config = dict(hidden_size=512,
                                 num_hidden_layers=4,
                                 num_attention_heads=4,
                                 num_key_value_heads=4,
                                 intermediate_size=1024)
    elif model_size == "120M":
        model_size_config = dict(hidden_size=768,
                                 num_hidden_layers=12,
                                 num_attention_heads=12,
                                 num_key_value_heads=12,
                                 intermediate_size=1536)
    elif model_size == "290M":
        model_size_config = dict(hidden_size=1024,
                                 num_hidden_layers=12,
                                 num_attention_heads=16,
                                 num_key_value_heads=16,
                                 intermediate_size=4096)
    elif model_size == "1B":
        model_size_config = dict(hidden_size=2048,
                                 num_hidden_layers=24,
                                 num_attention_heads=16,
                                 num_key_value_heads=16,
                                 intermediate_size=4096)
    elif model_size == "7B":
        model_size_config = {}
    return model_size_config

def create_model(model_size="1B"):
    model_size_config = get_model_size_config(model_size)
    # build a config of the requested size; weights are randomly initialized (nothing is downloaded).
    config = LlamaConfig()
    config.update(model_size_config)
    model = LlamaForCausalLM(config)
    return model

def free_memory():
    gc.collect()
    torch.cuda.empty_cache()

print(f"Memory allocated: {malloc_in_gb():.3f} GB")

Memory allocated: 0.000 GB
# create dummy inputs
model = create_model("DEBUG")
vocab_size = model.model.embed_tokens.weight.size(0)
inputs = [torch.randint(0, vocab_size, (1, sl)) for sl in [512, 1024, 2048, 3072]]
print(f"Memory allocated: {malloc_in_gb():.3f} GB")

Memory allocated: 0.000 GB
def profile_model(create_model_func, inference=False, save_filename="mem_profile.pickle"):
    """
    https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#demonstrating-speedups
    https://pytorch.org/docs/stable/torch_cuda_memory.html
    https://medium.com/pytorch/how-activation-checkpointing-enables-scaling-up-training-deep-learning-models-7a93ae01ff2d
    https://pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html
    """
    set_seed(42)
    torch.cuda.memory._record_memory_history()
    for x in inputs:
        print(f"Input Size:{tuple(x.size())}")
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        if inference:
            with torch.no_grad():
                model = create_model_func()
                model.to("cuda", torch.bfloat16);
                print(f"Memory allocated [MODEL]: {malloc_in_gb():.3f} GB")
                output = model(x.to("cuda"))
                print(f"Memory allocated [FWD]: {malloc_in_gb():.3f} GB")
        else:
            model = create_model_func()
            model.to("cuda", torch.bfloat16);
            print(f"Memory allocated [MODEL]: {malloc_in_gb():.3f} GB")
            output = model(x.to("cuda"))
            print(f"Memory allocated [FWD]: {malloc_in_gb():.3f} GB")
            output.logits.mean().backward()
            print(f"Memory allocated [BWD]: {malloc_in_gb():.3f} GB")
        end.record()
        torch.cuda.synchronize()
        secs = start.elapsed_time(end) / 1000
        print(f"Elapsed time: {secs:.3f}\n\n")
        output, model = None, None
        free_memory()
    torch.cuda.memory._dump_snapshot(save_filename)
    print(f"Memory allocated [finish]: {malloc_in_gb():.3f} GB")

# warmup
profile_model(partial(create_model, "DEBUG"), inference=True, save_filename=save_dir/"debug-inference.pickle")

Input Size:(1, 512)
Memory allocated [MODEL]: 0.018 GB
Memory allocated [FWD]: 0.093 GB
Elapsed time: 0.562
Input Size:(1, 1024)
Memory allocated [MODEL]: 0.027 GB
Memory allocated [FWD]: 0.160 GB
Elapsed time: 0.111
Input Size:(1, 2048)
Memory allocated [MODEL]: 0.027 GB
Memory allocated [FWD]: 0.291 GB
Elapsed time: 0.096
Input Size:(1, 3072)
Memory allocated [MODEL]: 0.027 GB
Memory allocated [FWD]: 0.425 GB
Elapsed time: 0.104
Memory allocated [finish]: 0.009 GB
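The saved snapshot files can be inspected after the fact; below is a minimal sketch, assuming the debug-inference.pickle written above (the PyTorch memory docs linked in the docstring also describe dragging the file into https://pytorch.org/memory_viz for an interactive view).

import pickle

with open(save_dir/"debug-inference.pickle", "rb") as f:
    snapshot = pickle.load(f)
# The snapshot is a plain dict of allocator segments and traced allocation events.
print(type(snapshot), list(snapshot.keys()))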
profile_model(partial(create_model, "1B"), inference=True, save_filename=save_dir/"base-inference.pickle")

Input Size:(1, 512)
Memory allocated [MODEL]: 2.311 GB
Memory allocated [FWD]: 2.478 GB
Elapsed time: 12.858
Input Size:(1, 1024)
Memory allocated [MODEL]: 2.311 GB
Memory allocated [FWD]: 2.645 GB
Elapsed time: 12.719
Input Size:(1, 2048)
Memory allocated [MODEL]: 2.311 GB
Memory allocated [FWD]: 2.976 GB
Elapsed time: 12.735
Input Size:(1, 3072)
Memory allocated [MODEL]: 2.311 GB
Memory allocated [FWD]: 3.322 GB
Elapsed time: 12.682
Memory allocated [finish]: 0.009 GB
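As a sanity check on the ~2.3 GB [MODEL] figure: the "1B" config works out to roughly 1.1-1.2B parameters, and at 2 bytes per parameter in bfloat16 that is about 2.3 GB. A quick way to confirm (illustrative snippet, run on CPU):

n_params = sum(p.numel() for p in create_model("1B").parameters())
print(f"{n_params/1e9:.2f}B params -> ~{n_params*2/1e9:.2f} GB in bf16")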
# (1, 4096) OOMs with a 16GB GPU
profile_model(partial(create_model, "1B"), inference=False, save_filename=save_dir/"base-training.pickle")

Input Size:(1, 512)
Memory allocated [MODEL]: 2.311 GB
Memory allocated [FWD]: 3.605 GB
Memory allocated [BWD]: 4.764 GB
Elapsed time: 11.823
Input Size:(1, 1024)
Memory allocated [MODEL]: 2.320 GB
Memory allocated [FWD]: 4.907 GB
Memory allocated [BWD]: 4.930 GB
Elapsed time: 12.106
Input Size:(1, 2048)
Memory allocated [MODEL]: 2.320 GB
Memory allocated [FWD]: 7.493 GB
Memory allocated [BWD]: 5.260 GB
Elapsed time: 12.611
Input Size:(1, 3072)
Memory allocated [MODEL]: 2.320 GB
Memory allocated [FWD]: 10.093 GB
Memory allocated [BWD]: 5.606 GB
Elapsed time: 13.033
Memory allocated [finish]: 0.017 GB
LoRA
def create_lora_model(model_size="1B", gc_enabled=False):
    model_size_config = get_model_size_config(model_size)
    # build a config of the requested size; weights are randomly initialized (nothing is downloaded).
    config = LlamaConfig()
    config.update(model_size_config)
    model = LlamaForCausalLM(config)
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
    )
    model = get_peft_model(model, peft_config)
    if gc_enabled:
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    return model

profile_model(partial(create_lora_model, "1B"), inference=True, save_filename=save_dir/"lora-inference.pickle")

Input Size:(1, 512)
Memory allocated [MODEL]: 2.323 GB
Memory allocated [FWD]: 2.489 GB
Elapsed time: 12.622
Input Size:(1, 1024)
Memory allocated [MODEL]: 2.323 GB
Memory allocated [FWD]: 2.657 GB
Elapsed time: 12.293
Input Size:(1, 2048)
Memory allocated [MODEL]: 2.323 GB
Memory allocated [FWD]: 2.988 GB
Elapsed time: 12.341
Input Size:(1, 3072)
Memory allocated [MODEL]: 2.323 GB
Memory allocated [FWD]: 3.334 GB
Elapsed time: 12.339
Memory allocated [finish]: 0.017 GB
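The LoRA adapters add only a small number of trainable parameters on top of the frozen base weights, which is why the [MODEL] footprint barely moves compared to the base model (2.323 vs 2.311 GB). PEFT can report this directly; a minimal sketch:

lora_model = create_lora_model("1B")
# prints trainable vs. total parameter counts (the trainable share is well under 1% here)
lora_model.print_trainable_parameters()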
profile_model(partial(create_lora_model, "1B"), inference=False, save_filename=save_dir/"lora-training.pickle")

Input Size:(1, 512)
Memory allocated [MODEL]: 2.323 GB
Memory allocated [FWD]: 3.451 GB
Memory allocated [BWD]: 2.492 GB
Elapsed time: 11.359
Input Size:(1, 1024)
Memory allocated [MODEL]: 2.323 GB
Memory allocated [FWD]: 4.580 GB
Memory allocated [BWD]: 2.660 GB
Elapsed time: 11.946
Input Size:(1, 2048)
Memory allocated [MODEL]: 2.323 GB
Memory allocated [FWD]: 6.835 GB
Memory allocated [BWD]: 2.991 GB
Elapsed time: 12.710
Input Size:(1, 3072)
Memory allocated [MODEL]: 2.323 GB
Memory allocated [FWD]: 9.105 GB
Memory allocated [BWD]: 3.337 GB
Elapsed time: 13.298
Memory allocated [finish]: 0.017 GB
LoRA + Gradient Ckpt.
Using the default HF gradient checkpointing strategy, which wraps each decoder layer individually.
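For intuition, here is a toy sketch of what per-layer checkpointing amounts to (illustrative only, using plain nn.Linear blocks rather than LlamaDecoderLayer; transformers wires this up internally when gradient_checkpointing_enable is called):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layers = nn.ModuleList([nn.Linear(512, 512) for _ in range(4)])
x = torch.randn(1, 512, requires_grad=True)
h = x
for layer in layers:
    # Only the layer inputs are kept; intermediate activations are recomputed during backward.
    h = checkpoint(layer, h, use_reentrant=False)
h.sum().backward()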
profile_model(partial(create_lora_model, "1B", gc_enabled=True), inference=False, save_filename=save_dir/"lora-gc-training.pickle")

Input Size:(1, 512)
Memory allocated [MODEL]: 2.315 GB
Memory allocated [FWD]: 2.439 GB
Memory allocated [BWD]: 2.392 GB
Elapsed time: 11.923
Input Size:(1, 1024)
Memory allocated [MODEL]: 2.323 GB
Memory allocated [FWD]: 2.573 GB
Memory allocated [BWD]: 2.458 GB
Elapsed time: 12.374
Input Size:(1, 2048)
Memory allocated [MODEL]: 2.323 GB
Memory allocated [FWD]: 2.820 GB
Memory allocated [BWD]: 2.588 GB
Elapsed time: 12.543
Input Size:(1, 3072)
Memory allocated [MODEL]: 2.323 GB
Memory allocated [FWD]: 3.082 GB
Memory allocated [BWD]: 2.733 GB
Elapsed time: 13.120
Memory allocated [finish]: 0.017 GB
QLoRA
def replace_with_bnb_4bit_linear(
    model,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
    quant_storage=torch.uint8,
    keep_trainable=False,
):
    """
    Wraps the recursion for module replacement.

    Returns the converted model and a boolean that indicates whether the conversion was successful.
    """
    if modules_to_not_convert is None:
        modules_to_not_convert = []
    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)
        if (isinstance(module, nn.Linear) or isinstance(module, Conv1D)) and name not in modules_to_not_convert:
            # Check if the current key is not in the `modules_to_not_convert`
            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
                # with init_empty_weights():
                if isinstance(module, Conv1D):
                    in_features, out_features = module.weight.shape
                else:
                    in_features = module.in_features
                    out_features = module.out_features
                model._modules[name] = bnb.nn.Linear4bit(
                    in_features,
                    out_features,
                    module.bias is not None,
                    quantization_config.bnb_4bit_compute_dtype,
                    compress_statistics=quantization_config.bnb_4bit_use_double_quant,
                    quant_type=quantization_config.bnb_4bit_quant_type,
                    quant_storage=quant_storage,
                )
                has_been_replaced = True
                # Store the module class in case we need to transpose the weight later
                model._modules[name].source_cls = type(module)
                # Force requires grad to False (unless keep_trainable) to avoid unexpected errors
                if keep_trainable:
                    model._modules[name].requires_grad_(True)
                else:
                    model._modules[name].requires_grad_(False)
        if len(list(module.children())) > 0:
            # Propagate quant_storage/keep_trainable so nested layers use the same settings.
            _, has_been_replaced = replace_with_bnb_4bit_linear(
                module,
                modules_to_not_convert,
                current_key_name,
                quantization_config,
                has_been_replaced=has_been_replaced,
                quant_storage=quant_storage,
                keep_trainable=keep_trainable,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced

def create_qlora_model(model_size="1B", with_lora=True, gc_enabled=False, keep_trainable=False):
    model_size_config = get_model_size_config(model_size)
    # build a config of the requested size; weights are randomly initialized (nothing is downloaded).
    config = LlamaConfig()
    config.update(model_size_config)
    model = LlamaForCausalLM(config)
    qconfig = BitsAndBytesConfig(load_in_4bit=True,
                                 bnb_4bit_quant_type="nf4",
                                 bnb_4bit_use_double_quant=False,
                                 bnb_4bit_compute_dtype=torch.bfloat16)
    model, has_been_replaced = replace_with_bnb_4bit_linear(model,
                                                            modules_to_not_convert=["lm_head"],
                                                            quantization_config=qconfig,
                                                            keep_trainable=keep_trainable,
                                                            quant_storage=torch.bfloat16)
    assert has_been_replaced
    if with_lora:
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
        )
        model = get_peft_model(model, peft_config)
    if gc_enabled:
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    return model

profile_model(partial(create_qlora_model, "1B"), inference=True, save_filename=save_dir/"qlora-inference.pickle")

Input Size:(1, 512)
Memory allocated [MODEL]: 0.859 GB
Memory allocated [FWD]: 1.034 GB
Elapsed time: 19.783
Input Size:(1, 1024)
Memory allocated [MODEL]: 0.868 GB
Memory allocated [FWD]: 1.201 GB
Elapsed time: 17.461
Input Size:(1, 2048)
Memory allocated [MODEL]: 0.868 GB
Memory allocated [FWD]: 1.532 GB
Elapsed time: 17.779
Input Size:(1, 3072)
Memory allocated [MODEL]: 0.868 GB
Memory allocated [FWD]: 1.878 GB
Elapsed time: 17.819
Memory allocated [finish]: 0.009 GB
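The manual replace_with_bnb_4bit_linear swap above is used so that randomly initialized models of arbitrary size can be quantized in place. When loading a real pretrained checkpoint, the same 4-bit setup is normally requested through transformers directly; a minimal sketch, assuming a hypothetical model id and that bitsandbytes is installed:

qconfig = BitsAndBytesConfig(load_in_4bit=True,
                             bnb_4bit_quant_type="nf4",
                             bnb_4bit_use_double_quant=False,
                             bnb_4bit_compute_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",  # hypothetical model id
                                             quantization_config=qconfig)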
profile_model(partial(create_qlora_model, "1B"), inference=False, save_filename=save_dir/"qlora-training.pickle")

Input Size:(1, 512)
Memory allocated [MODEL]: 0.868 GB
Memory allocated [FWD]: 2.195 GB
Memory allocated [BWD]: 1.295 GB
Elapsed time: 17.303
Input Size:(1, 1024)
Memory allocated [MODEL]: 0.876 GB
Memory allocated [FWD]: 3.532 GB
Memory allocated [BWD]: 1.712 GB
Elapsed time: 17.051
Input Size:(1, 2048)
Memory allocated [MODEL]: 0.876 GB
Memory allocated [FWD]: 6.185 GB
Memory allocated [BWD]: 2.542 GB
Elapsed time: 17.963
Input Size:(1, 3072)
Memory allocated [MODEL]: 0.876 GB
Memory allocated [FWD]: 8.853 GB
Memory allocated [BWD]: 3.387 GB
Elapsed time: 18.167
Memory allocated [finish]: 0.017 GB
QLoRA + Gradient Ckpt.
Using the default HF gradient checkpointing strategy, which wraps each decoder layer individually.
profile_model(partial(create_qlora_model, "1B", gc_enabled=True), inference=False, save_filename=save_dir/"qlora-gc-training.pickle")

Input Size:(1, 512)
Memory allocated [MODEL]: 0.876 GB
Memory allocated [FWD]: 1.250 GB
Memory allocated [BWD]: 1.194 GB
Elapsed time: 17.265
Input Size:(1, 1024)
Memory allocated [MODEL]: 0.876 GB
Memory allocated [FWD]: 1.625 GB
Memory allocated [BWD]: 1.511 GB
Elapsed time: 16.252
Input Size:(1, 2048)
Memory allocated [MODEL]: 0.876 GB
Memory allocated [FWD]: 2.371 GB
Memory allocated [BWD]: 2.140 GB
Elapsed time: 17.468
Input Size:(1, 3072)
Memory allocated [MODEL]: 0.876 GB
Memory allocated [FWD]: 3.133 GB
Memory allocated [BWD]: 2.783 GB
Elapsed time: 18.704
Memory allocated [finish]: 0.017 GB
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
Re-run the QLoRA + gradient checkpointing setup while also tracking peak allocated memory, since the point-in-time readings above can miss transient allocation peaks.

for x in inputs:
    set_seed(42)
    model = create_qlora_model("1B", gc_enabled=True)
    model.to("cuda", torch.bfloat16);
    # warmup forward pass before timing and peak-memory measurement
    with torch.no_grad():
        model(inputs[0].to("cuda"))
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    torch.cuda.reset_peak_memory_stats()
    print(f"Memory allocated [MODEL]: {malloc_in_gb():.3f} GB")
    output = model(x.to("cuda"))
    print(f"Memory allocated [FWD]: {malloc_in_gb():.3f} GB")
    output.logits.mean().backward()
    print(f"Memory allocated [BWD]: {malloc_in_gb():.3f} GB")
    max_memory = torch.cuda.memory.max_memory_allocated() / 1e9
    print(f"Max MemAlloc: {max_memory}")
    end.record()
    torch.cuda.synchronize()
    secs = start.elapsed_time(end) / 1000
    print(f"Elapsed time: {secs:.3f}\n\n")
    output, model = None, None
    free_memory()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
Memory allocated [MODEL]: 0.882 GB
Memory allocated [FWD]: 1.260 GB
Memory allocated [BWD]: 1.210 GB
Max MemAlloc: 1.360229376
Elapsed time: 0.195
Memory allocated [MODEL]: 0.891 GB
Memory allocated [FWD]: 1.646 GB
Memory allocated [BWD]: 1.532 GB
Max MemAlloc: 1.844102144
Elapsed time: 0.194
Memory allocated [MODEL]: 0.891 GB
Memory allocated [FWD]: 2.397 GB
Memory allocated [BWD]: 2.174 GB
Max MemAlloc: 2.791502848
Elapsed time: 0.231
Memory allocated [MODEL]: 0.891 GB
Memory allocated [FWD]: 3.142 GB
Memory allocated [BWD]: 2.784 GB
Max MemAlloc: 3.733136384
Elapsed time: 0.417