Base Model

from accelerate.utils import set_seed
from hqq.core.peft import PeftUtils
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM
from transformers.utils.quantization_config import BitsAndBytesConfig
from transformers.pytorch_utils import Conv1D

import transformers
from transformers import LlamaConfig, LlamaForCausalLM
from transformers.integrations.bitsandbytes import replace_with_bnb_linear
from transformers.utils.quantization_config import BitsAndBytesConfig
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

from peft.tuners.lora.config import LoraConfig
from peft.mapping import get_peft_model
from peft.utils.peft_types import *

import os
import gc
import inspect
from accelerate.utils import set_seed
from functools import partial
from pathlib import Path
# Directory that receives the CUDA memory snapshots dumped by the profiling runs.
save_dir = Path("profile_snapshots/")
save_dir.mkdir(parents=True, exist_ok=True)
# Keep transformers output limited to warnings and above.
transformers.logging.set_verbosity_warning()
def malloc_in_gb():
    """Return the value of ``torch.cuda.memory_allocated()`` converted to GB (bytes / 1e9)."""
    allocated_bytes = torch.cuda.memory_allocated()
    return allocated_bytes / 1e9
# Config overrides per model-size label. An empty mapping leaves the config
# defaults untouched (presumably the 7B-sized LlamaConfig defaults -- confirm
# against the installed transformers version).
_MODEL_SIZE_CONFIGS = {
    "DEBUG": dict(hidden_size=128,
                  num_hidden_layers=2,
                  num_attention_heads=2,
                  num_key_value_heads=2,
                  intermediate_size=256),
    "60M": dict(hidden_size=512,
                num_hidden_layers=4,
                num_attention_heads=4,
                num_key_value_heads=4,
                intermediate_size=1024),
    "120M": dict(hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 num_key_value_heads=12,
                 intermediate_size=1536),
    "290M": dict(hidden_size=1024,
                 num_hidden_layers=12,
                 num_attention_heads=16,
                 num_key_value_heads=16,
                 intermediate_size=4096),
    "1B": dict(hidden_size=2048,
               num_hidden_layers=24,
               num_attention_heads=16,
               num_key_value_heads=16,
               intermediate_size=4096),
    "7B": {},
}


def get_model_size_config(model_size):
    """
    Return the config overrides for a named model size.

    Args:
        model_size: One of "DEBUG", "60M", "120M", "290M", "1B" or "7B".

    Returns:
        A new dict of config fields (hidden size, layer/head counts,
        intermediate size). It is a copy, so callers may mutate it freely.

    Raises:
        ValueError: If ``model_size`` is not a known label. Previously this
            surfaced as an ``UnboundLocalError`` at the ``return`` statement.
    """
    try:
        overrides = _MODEL_SIZE_CONFIGS[model_size]
    except (KeyError, TypeError):
        valid = ", ".join(repr(label) for label in _MODEL_SIZE_CONFIGS)
        raise ValueError(
            f"Unknown model_size {model_size!r}; expected one of: {valid}"
        ) from None
    return dict(overrides)
def create_model(model_size="1B"):
    """
    Build a Llama causal LM of the requested size directly from a config.

    The model is constructed from a ``LlamaConfig`` object (there is no
    ``from_pretrained`` call here), with the size-specific fields from
    ``get_model_size_config`` applied on top of the config defaults.

    Args:
        model_size: Size label understood by ``get_model_size_config``.

    Returns:
        The constructed ``LlamaForCausalLM`` instance.
    """
    size_overrides = get_model_size_config(model_size)
    llama_config = LlamaConfig()
    llama_config.update(size_overrides)
    return LlamaForCausalLM(llama_config)
def free_memory():
    """Run Python garbage collection, then ask PyTorch to empty its CUDA cache."""
    # Collect first so tensors that are only kept alive by reference cycles are
    # dropped before the cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()
print(f"Memory allocated: {malloc_in_gb():.3f} GB")
Memory allocated: 0.000 GB
# Build one random token-id batch (batch size 1) per sequence length. The
# DEBUG model is instantiated only to read the vocabulary size off its
# embedding table.
model = create_model("DEBUG")
vocab_size = model.model.embed_tokens.weight.shape[0]
inputs = [
    torch.randint(0, vocab_size, (1, seq_len))
    for seq_len in (512, 1024, 2048, 3072)
]
print(f"Memory allocated: {malloc_in_gb():.3f} GB")
Memory allocated: 0.000 GB
def profile_model(create_model_func, inference=False, save_filename="mem_profile.pickle"):
    """
    Profile CUDA memory and elapsed time for each tensor in the module-level ``inputs``.

    For every input a fresh model is built with ``create_model_func``, moved to
    the GPU in bfloat16 and run forward once; in training mode the mean of the
    logits is then back-propagated. Allocated memory is printed after each
    stage. The timed region starts before the model is built, so the reported
    time includes model construction and the transfer to the GPU. After each
    input the model and output are dropped and ``free_memory`` is called so the
    next run starts clean. When all inputs are done, the recorded CUDA memory
    history is dumped to ``save_filename``.

    Args:
        create_model_func: Zero-argument callable returning the model to profile.
        inference: If True, build and run the model under ``torch.no_grad()``
            and skip the backward pass.
        save_filename: Destination of the memory snapshot.

    References:
        https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#demonstrating-speedups

        https://pytorch.org/docs/stable/torch_cuda_memory.html

        https://medium.com/pytorch/how-activation-checkpointing-enables-scaling-up-training-deep-learning-models-7a93ae01ff2d

        https://pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html
    """
    from contextlib import nullcontext

    set_seed(42)
    torch.cuda.memory._record_memory_history()
    for x in inputs:
        print(f"Input Size:{tuple(x.size())}")
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        start.record()
        # Inference builds and runs the model without autograd; training
        # leaves the ambient grad mode untouched.
        with (torch.no_grad() if inference else nullcontext()):
            model = create_model_func()
            model.to("cuda", torch.bfloat16)
            print(f"Memory allocated [MODEL]: {malloc_in_gb():.3f} GB")
            output = model(x.to("cuda"))
            print(f"Memory allocated [FWD]: {malloc_in_gb():.3f} GB")
            if not inference:
                output.logits.mean().backward()
                print(f"Memory allocated [BWD]: {malloc_in_gb():.3f} GB")
        end.record()
        # elapsed_time is only valid once both events have completed.
        torch.cuda.synchronize()
        secs = start.elapsed_time(end) / 1000
        print(f"Elapsed time: {secs:.3f}\n\n")
        # Drop the references before collecting so the memory can be released.
        output, model = None, None
        free_memory()
    torch.cuda.memory._dump_snapshot(save_filename)
    print(f"Memory allocated [finish]: {malloc_in_gb():.3f} GB")
# Warm-up: profile the tiny DEBUG model in inference mode before the real runs.
# NOTE(review): presumably this absorbs one-time CUDA start-up cost so it is not
# charged to the 1B measurements that follow -- confirm.
profile_model(partial(create_model, "DEBUG"), inference=True, save_filename=save_dir/"debug-inference.pickle")
Input Size:(1, 512)
Memory allocated [MODEL]: 0.051 GB
Memory allocated [FWD]: 0.125 GB
Elapsed time: 1.338


Input Size:(1, 1024)
Memory allocated [MODEL]: 0.059 GB
Memory allocated [FWD]: 0.193 GB
Elapsed time: 0.142


Input Size:(1, 2048)
Memory allocated [MODEL]: 0.059 GB
Memory allocated [FWD]: 0.324 GB
Elapsed time: 0.135


Input Size:(1, 3072)
Memory allocated [MODEL]: 0.059 GB
Memory allocated [FWD]: 0.425 GB
Elapsed time: 0.201


Memory allocated [finish]: 0.009 GB
profile_model(partial(create_model, "1B"), inference=True, save_filename=save_dir/"base-inference.pickle")
Input Size:(1, 512)
Memory allocated [MODEL]: 2.320 GB
Memory allocated [FWD]: 2.492 GB
Elapsed time: 15.401


Input Size:(1, 1024)
Memory allocated [MODEL]: 2.320 GB
Memory allocated [FWD]: 2.666 GB
Elapsed time: 14.997


Input Size:(1, 2048)
Memory allocated [MODEL]: 2.320 GB
Memory allocated [FWD]: 3.010 GB
Elapsed time: 14.370


Input Size:(1, 3072)
Memory allocated [MODEL]: 2.320 GB
Memory allocated [FWD]: 3.322 GB
Elapsed time: 14.218


Memory allocated [finish]: 0.009 GB
# Training profile (forward + backward) of the base 1B model.
# The longest sequence profiled is 3072 tokens: a (1, 4096) input runs out of
# memory on a 16GB GPU.
profile_model(partial(create_model, "1B"), inference=False, save_filename=save_dir/"base-training.pickle")
Input Size:(1, 512)
Memory allocated [MODEL): 2.320 GB
Memory allocated [FWD]: 3.521 GB
Memory allocated [BWD]: 4.779 GB
Elapsed time: 13.765


Input Size:(1, 1024)
Memory allocated [MODEL): 2.328 GB
Memory allocated [FWD]: 4.757 GB
Memory allocated [BWD]: 4.952 GB
Elapsed time: 13.277


Input Size:(1, 2048)
Memory allocated [MODEL): 2.328 GB
Memory allocated [FWD]: 7.283 GB
Memory allocated [BWD]: 5.294 GB
Elapsed time: 13.706


Input Size:(1, 3072)
Memory allocated [MODEL): 2.328 GB
Memory allocated [FWD]: 9.879 GB
Memory allocated [BWD]: 5.606 GB
Elapsed time: 14.203


Memory allocated [finish]: 0.017 GB

LoRA

def create_lora_model(model_size="1B", gc_enabled=False):
    """
    Build a config-initialised Llama causal LM and wrap it with PEFT LoRA adapters.

    The base model is constructed from a ``LlamaConfig`` (no ``from_pretrained``
    call), then wrapped via ``get_peft_model`` using a causal-LM ``LoraConfig``
    with r=8, lora_alpha=32 and lora_dropout=0.1.

    Args:
        model_size: Size label understood by ``get_model_size_config``.
        gc_enabled: If True, enable non-reentrant gradient checkpointing on the
            wrapped model.

    Returns:
        The PEFT-wrapped model.
    """
    size_overrides = get_model_size_config(model_size)
    llama_config = LlamaConfig()
    llama_config.update(size_overrides)
    base_model = LlamaForCausalLM(llama_config)

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
    )
    model = get_peft_model(base_model, lora_config)

    if gc_enabled:
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )
    return model
profile_model(partial(create_lora_model, "1B"), inference=True, save_filename=save_dir/"lora-inference.pickle")
Input Size:(1, 512)
Memory allocated [MODEL]: 2.314 GB
Memory allocated [FWD]: 2.495 GB
Elapsed time: 17.398


Input Size:(1, 1024)
Memory allocated [MODEL]: 2.323 GB
Memory allocated [FWD]: 2.669 GB
Elapsed time: 15.746


Input Size:(1, 2048)
Memory allocated [MODEL]: 2.323 GB
Memory allocated [FWD]: 3.013 GB
Elapsed time: 15.481


Input Size:(1, 3072)
Memory allocated [MODEL]: 2.323 GB
Memory allocated [FWD]: 3.325 GB
Elapsed time: 15.432


Memory allocated [finish]: 0.009 GB
profile_model(partial(create_lora_model, "1B"), inference=False, save_filename=save_dir/"lora-training.pickle")
Input Size:(1, 512)
Memory allocated [MODEL): 2.323 GB
Memory allocated [FWD]: 3.363 GB
Memory allocated [BWD]: 2.507 GB
Elapsed time: 16.125


Input Size:(1, 1024)
Memory allocated [MODEL): 2.331 GB
Memory allocated [FWD]: 4.437 GB
Memory allocated [BWD]: 2.681 GB
Elapsed time: 15.417


Input Size:(1, 2048)
Memory allocated [MODEL): 2.331 GB
Memory allocated [FWD]: 6.642 GB
Memory allocated [BWD]: 3.025 GB
Elapsed time: 15.374


Input Size:(1, 3072)
Memory allocated [MODEL): 2.331 GB
Memory allocated [FWD]: 8.916 GB
Memory allocated [BWD]: 3.337 GB
Elapsed time: 15.821


Memory allocated [finish]: 0.017 GB

LoRA + Gradient Ckpt.

Using the default HF gradient checkpointing strategy, which wraps each individual decoder layer.

profile_model(partial(create_lora_model, "1B", gc_enabled=True), inference=False, save_filename=save_dir/"lora-gc-training.pickle")
Input Size:(1, 512)
Memory allocated [MODEL): 2.331 GB
Memory allocated [FWD]: 2.466 GB
Memory allocated [BWD]: 2.406 GB
Elapsed time: 15.596


Input Size:(1, 1024)
Memory allocated [MODEL): 2.331 GB
Memory allocated [FWD]: 2.594 GB
Memory allocated [BWD]: 2.479 GB
Elapsed time: 14.345


Input Size:(1, 2048)
Memory allocated [MODEL): 2.331 GB
Memory allocated [FWD]: 2.845 GB
Memory allocated [BWD]: 2.622 GB
Elapsed time: 14.974


Input Size:(1, 3072)
Memory allocated [MODEL): 2.331 GB
Memory allocated [FWD]: 3.091 GB
Memory allocated [BWD]: 2.733 GB
Elapsed time: 15.887


Memory allocated [finish]: 0.017 GB
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.

HQQ LoRA

from hqq.core.quantize import BaseQuantizeConfig, HQQLinear, HQQBackend
from hqq.models.hf.llama import LlamaHQQ
def replace_with_hqq_linear(
    model,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
    quant_storage=torch.uint8,
    compute_dtype=torch.bfloat16,
    keep_trainable=False,
):
    """
    Recursively replace ``nn.Linear`` / ``Conv1D`` submodules of ``model`` with ``HQQLinear``.

    The replacement happens in place on ``model``; new layers are created on
    the current CUDA device.

    Args:
        model: Module whose children are scanned (and recursed into).
        modules_to_not_convert: Names to leave untouched. A layer is skipped
            when its own name is in this list, or when any entry is a
            substring of its dotted path. ``None`` means "convert everything".
        current_key_name: Path components accumulated during recursion;
            callers normally leave this as ``None``.
        quantization_config: Quantization config handed to ``HQQLinear``.
        has_been_replaced: Running flag carried through the recursion.
        quant_storage: Accepted for interface compatibility; it is not
            forwarded to ``HQQLinear`` by this function.
        compute_dtype: Compute dtype handed to ``HQQLinear``.
        keep_trainable: If True, converted layers are set to require gradients;
            otherwise gradients are disabled on them.

    Returns:
        ``(model, has_been_replaced)``: the same model object and a boolean
        telling whether at least one layer was converted.
    """
    # Normalise the optional containers once so the membership tests below
    # never run against None.
    if modules_to_not_convert is None:
        modules_to_not_convert = []
    if current_key_name is None:
        current_key_name = []

    for name, module in model.named_children():
        current_key_name.append(name)

        is_linear = isinstance(module, nn.Linear) or isinstance(module, Conv1D)
        if is_linear and name not in modules_to_not_convert:
            # Also skip when an excluded key appears anywhere in the dotted path.
            qualified_name = ".".join(current_key_name)
            if not any(key in qualified_name for key in modules_to_not_convert):
                quantized = HQQLinear(module,
                                      quantization_config,
                                      del_orig=True,
                                      compute_dtype=compute_dtype,
                                      device_n=torch.cuda.current_device())
                # Store the module class in case we need to transpose the weight later
                quantized.source_cls = type(module)
                quantized.requires_grad_(bool(keep_trainable))
                model._modules[name] = quantized
                has_been_replaced = True

        if len(list(module.children())) > 0:
            # Forward every option: the Linear layers of a real model are
            # nested, so anything not passed here would silently fall back to
            # the defaults for all of them.
            _, has_been_replaced = replace_with_hqq_linear(
                module,
                modules_to_not_convert,
                current_key_name,
                quantization_config,
                has_been_replaced=has_been_replaced,
                quant_storage=quant_storage,
                compute_dtype=compute_dtype,
                keep_trainable=keep_trainable,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced
def create_qlora_model(model_size="1B", with_lora=True, gc_enabled=False, keep_trainable=False, backend=HQQBackend.ATEN):
    """
    Build a config-initialised Llama causal LM with HQQ-quantized linear layers.

    Every linear layer except ``lm_head`` is replaced through
    ``replace_with_hqq_linear`` using a 4-bit, group-size-64 quantization
    config. Optionally, HQQ LoRA adapters (r=8, lora_alpha=32, dropout=0.1,
    bfloat16) are attached to the attention and MLP projections.

    Args:
        model_size: Size label understood by ``get_model_size_config``.
        with_lora: If True, add LoRA adapters via ``PeftUtils.add_lora``.
        gc_enabled: If True, enable non-reentrant gradient checkpointing.
        keep_trainable: Forwarded to ``replace_with_hqq_linear``.
        backend: HQQ backend selected through ``HQQLinear.set_backend`` before
            the model is returned.

    Returns:
        The quantized (and optionally LoRA-augmented) model.
    """
    size_overrides = get_model_size_config(model_size)

    llama_config = LlamaConfig()
    llama_config.update(size_overrides)
    model = LlamaForCausalLM(llama_config)

    quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=True, quant_scale=False)
    model, has_been_replaced = replace_with_hqq_linear(
        model,
        modules_to_not_convert=["lm_head"],
        quantization_config=quant_config,
        keep_trainable=keep_trainable,
        quant_storage=torch.bfloat16,
        compute_dtype=torch.bfloat16,
    )
    assert has_been_replaced

    if with_lora:
        # One settings dict shared by every targeted projection.
        shared_lora_params = {
            'lora_type': 'default',
            'r': 8,
            'lora_alpha': 32,
            'dropout': 0.1,
            'compute_dtype': torch.bfloat16,
            'train_dtype': torch.bfloat16,
        }
        target_projections = (
            'self_attn.q_proj',
            'self_attn.k_proj',
            'self_attn.v_proj',
            'self_attn.o_proj',
            'mlp.gate_proj',
            'mlp.up_proj',
            'mlp.down_proj',
        )
        lora_params = {target: shared_lora_params for target in target_projections}
        PeftUtils.add_lora(model, lora_params, base_class=LlamaHQQ, verbose=True)

    if gc_enabled:
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )
    HQQLinear.set_backend(backend)
    return model
# set_seed(42)
# model = create_qlora_model(model_size="DEBUG", with_lora=True,
#                            gc_enabled=False, keep_trainable=False, backend=HQQBackend.PYTORCH_BACKPROP_COMPILE)
# model.to(0).to(torch.bfloat16);
# x = torch.randint(0,100,(4, 128)).cuda()#.to(torch.bfloat16)
# o = model(x)
# loss = o.logits.mean()
# loss.backward()
# for n,p in model.named_parameters(): 
#     if p.requires_grad:
#         print(n, p.dtype, p.device, p.grad.norm().data)
profile_model(partial(create_qlora_model, "1B", backend=HQQBackend.ATEN), inference=True, save_filename=save_dir/"qlora-inference.pickle")
Input Size:(1, 512)
Memory allocated [MODEL]: 0.862 GB
Memory allocated [FWD]: 1.043 GB
Elapsed time: 66.540


Input Size:(1, 1024)
Memory allocated [MODEL]: 0.871 GB
Memory allocated [FWD]: 1.217 GB
Elapsed time: 65.790


Input Size:(1, 2048)
Memory allocated [MODEL]: 0.871 GB
Memory allocated [FWD]: 1.561 GB
Elapsed time: 65.778


Input Size:(1, 3072)
Memory allocated [MODEL]: 0.871 GB
Memory allocated [FWD]: 1.873 GB
Elapsed time: 65.310


Memory allocated [finish]: 0.009 GB
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:00<00:00, 197.56it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:00<00:00, 195.93it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:00<00:00, 203.04it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:00<00:00, 212.20it/s]
profile_model(partial(create_qlora_model, "1B", backend=HQQBackend.ATEN_BACKPROP), inference=False, save_filename=save_dir/"qlora-training.pickle")
Input Size:(1, 512)
Memory allocated [MODEL): 0.871 GB
Memory allocated [FWD]: 2.563 GB
Memory allocated [BWD]: 1.065 GB
Elapsed time: 65.322


Input Size:(1, 1024)
Memory allocated [MODEL): 0.879 GB
Memory allocated [FWD]: 4.289 GB
Memory allocated [BWD]: 1.238 GB
Elapsed time: 64.854


Input Size:(1, 2048)
Memory allocated [MODEL): 0.879 GB
Memory allocated [FWD]: 7.798 GB
Memory allocated [BWD]: 1.582 GB
Elapsed time: 64.948


Input Size:(1, 3072)
Memory allocated [MODEL): 0.879 GB
Memory allocated [FWD]: 11.376 GB
Memory allocated [BWD]: 1.895 GB
Elapsed time: 65.371


Memory allocated [finish]: 0.017 GB
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:00<00:00, 217.77it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:00<00:00, 208.45it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:00<00:00, 207.33it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:00<00:00, 207.45it/s]

QLORA + Gradient Ckpt.

Using default HF grad ckpt strategy which wraps each individual decoder layer.

# Training profile of the DEBUG-sized HQQ + LoRA model with gradient
# checkpointing enabled, using the PYTORCH_BACKPROP backend.
profile_model(partial(create_qlora_model, "DEBUG", gc_enabled=True, backend=HQQBackend.PYTORCH_BACKPROP),
              inference=False, save_filename=save_dir/"qlora-gc-training.pickle")
# for n,p in model.named_parameters():
#     print(n, p.name, p.requires_grad)
model = create_qlora_model("DEBUG", gc_enabled=True, backend=HQQBackend.PYTORCH_BACKPROP)
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 307.93it/s]
model.to("cuda", torch.bfloat16);

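The timings below are the representative ones: unlike the `profile_model` runs above, the timed region here starts after the model has been created and quantized, so it excludes model initialization and quantization.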

# Time forward + backward for the 1B HQQ + LoRA model with gradient
# checkpointing. The model is built, moved to the GPU and warmed up BEFORE the
# start event is recorded, so the elapsed time excludes that setup.
for x in inputs:
    set_seed(42)
    model = create_qlora_model("1B", gc_enabled=True, backend=HQQBackend.ATEN_BACKPROP)
    model.to("cuda", torch.bfloat16)
    # Warm-up forward pass on the first (shortest) input, without autograd.
    with torch.no_grad():
        model(inputs[0].to("cuda"))

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()

    # Reset the peak counter so the maximum reported below only reflects the
    # forward/backward pass that follows, not model setup or the warm-up.
    torch.cuda.reset_peak_memory_stats()
    print(f"Memory allocated [MODEL]: {malloc_in_gb():.3f} GB")
    output = model(x.to("cuda"))
    print(f"Memory allocated [FWD]: {malloc_in_gb():.3f} GB")
    output.logits.mean().backward()
    print(f"Memory allocated [BWD]: {malloc_in_gb():.3f} GB")
    # Peak allocation since the reset above, in GB (bytes / 1e9).
    max_memory = torch.cuda.memory.max_memory_allocated()/1e9
    print(f"Max MemAlloc: {max_memory}")

    end.record()
    # elapsed_time is only valid once both events have completed.
    torch.cuda.synchronize()
    secs = start.elapsed_time(end) / 1000
    print(f"Elapsed time: {secs:.3f}\n\n")

    # Drop the references before collecting so the memory can be released.
    output, model = None, None
    free_memory()
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:00<00:00, 193.32it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:00<00:00, 196.59it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:00<00:00, 197.02it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:00<00:00, 138.58it/s]
Memory allocated [MODEL): 0.964 GB
Memory allocated [FWD]: 1.092 GB
Memory allocated [BWD]: 1.043 GB
Max MemAlloc: 1.190423552
Elapsed time: 0.402


Memory allocated [MODEL): 0.964 GB
Memory allocated [FWD]: 1.220 GB
Memory allocated [BWD]: 1.115 GB
Max MemAlloc: 1.417184256
Elapsed time: 0.401


Memory allocated [MODEL): 0.964 GB
Memory allocated [FWD]: 1.471 GB
Memory allocated [BWD]: 1.258 GB
Max MemAlloc: 1.865462784
Elapsed time: 0.411


Memory allocated [MODEL): 0.964 GB
Memory allocated [FWD]: 1.717 GB
Memory allocated [BWD]: 1.369 GB
Max MemAlloc: 2.307974144
Elapsed time: 0.500