Test Linear4bit Memory-Efficient Loading

import torch
import bitsandbytes as bnb
import safetensors
from safetensors.torch import save_file
from bitsandbytes.nn import Linear4bit, Params4bit
import bitsandbytes.functional as F
from transformers.utils import hub, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
from transformers import AutoConfig, AutoModelForCausalLM
import torch.nn as nn

This tests that, after the memory-efficient load, each rank holds the same quantized params and quant state, and then compares the dequantized weight against the original checkpoint weight.
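The .pt files below are assumed to have been saved from a distributed (FSDP) run; a rough sketch of how they could have been produced with torch.distributed.fsdp's summon_full_params is shown below. The model wrapping, module path, and output directory are assumptions, not the actual script.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def save_summoned_q_proj(fsdp_model, rank, out_dir="../data"):
    # gather the full (unsharded) params on this rank, then save them together with the quant state;
    # the module path below is a guess at a LoRA-wrapped Linear4bit q_proj in layer 0
    with FSDP.summon_full_params(fsdp_model, writeback=False):
        base_layer = fsdp_model.model.layers[0].self_attn.q_proj.base_layer
        torch.save(list(base_layer.parameters()),
                   f"{out_dir}/summoned_lora_layer0_q_proj_base_layer_params_rank{rank}.pt")
        torch.save(base_layer.weight.quant_state,
                   f"{out_dir}/summoned_lora_layer0_q_proj_quant_state_rank{rank}.pt")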

params_rank0 = torch.load("../data/summoned_lora_layer0_q_proj_base_layer_params_rank0.pt")
params_rank1 = torch.load("../data/summoned_lora_layer0_q_proj_base_layer_params_rank1.pt")
quant_state_rank0 = torch.load("../data/summoned_lora_layer0_q_proj_quant_state_rank0.pt", map_location="cpu")
quant_state_rank1 = torch.load("../data/summoned_lora_layer0_q_proj_quant_state_rank1.pt", map_location="cpu")
# check that the gathered quantized weights are the same on each rank (ignoring any NaN entries)
for p1, p2 in zip(params_rank0, params_rank1):
    p1 = p1[~p1.data.isnan()]
    p2 = p2[~p2.data.isnan()]
    assert torch.allclose(p1, p2)
# check that the quant states match on each rank
# (the nested_* entries hold the double-quantization state for the per-block absmax values)
rank0_state, rank1_state = quant_state_rank0.as_dict(), quant_state_rank1.as_dict()
for k, v in rank0_state.items():
    print(k)
    if isinstance(v, torch.Tensor):
        assert torch.equal(v, rank1_state[k])
    else:
        assert v == rank1_state[k]
quant_type
absmax
blocksize
quant_map
dtype
shape
nested_absmax
nested_blocksize
nested_quant_map
nested_dtype
nested_offset
# rebuild a Params4bit from the gathered data and its quant state
quantized_param = Params4bit(data=params_rank0[0],
                             requires_grad=False,
                             quant_state=quant_state_rank0,
                             quant_type=quant_state_rank0.quant_type,
                             quant_storage=params_rank0[0].dtype,
                             bnb_quantized=True)
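A couple of quick sanity checks on the reconstructed parameter (illustrative only; these just inspect attributes the Params4bit was constructed with):

assert quantized_param.quant_state is quant_state_rank0
assert quantized_param.requires_grad is False
print(quantized_param.data.dtype, tuple(quantized_param.data.shape), quant_state_rank0.shape)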
quant_state_rank0.to("cuda");
quant_state_rank0.as_dict()
{'quant_type': 'nf4',
 'absmax': tensor([230, 149,  74,  ..., 194, 175, 203], device='cuda:0',
        dtype=torch.uint8),
 'blocksize': 64,
 'quant_map': tensor([-1.0000, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911,  0.0000,
          0.0796,  0.1609,  0.2461,  0.3379,  0.4407,  0.5626,  0.7230,  1.0000]),
 'dtype': 'bfloat16',
 'shape': (8192, 8192),
 'nested_absmax': tensor([0.0736, 0.0258, 0.0224,  ..., 0.0658, 0.0902, 0.0638], device='cuda:0'),
 'nested_blocksize': 256,
 'nested_quant_map': tensor([-9.9297e-01, -9.7891e-01, -9.6484e-01, -9.5078e-01, -9.3672e-01,
         -9.2266e-01, -9.0859e-01, -8.9453e-01, -8.8047e-01, -8.6641e-01,
         -8.5234e-01, -8.3828e-01, -8.2422e-01, -8.1016e-01, -7.9609e-01,
         -7.8203e-01, -7.6797e-01, -7.5391e-01, -7.3984e-01, -7.2578e-01,
         -7.1172e-01, -6.9766e-01, -6.8359e-01, -6.6953e-01, -6.5547e-01,
         -6.4141e-01, -6.2734e-01, -6.1328e-01, -5.9922e-01, -5.8516e-01,
         -5.7109e-01, -5.5703e-01, -5.4297e-01, -5.2891e-01, -5.1484e-01,
         -5.0078e-01, -4.8672e-01, -4.7266e-01, -4.5859e-01, -4.4453e-01,
         -4.3047e-01, -4.1641e-01, -4.0234e-01, -3.8828e-01, -3.7422e-01,
         -3.6016e-01, -3.4609e-01, -3.3203e-01, -3.1797e-01, -3.0391e-01,
         -2.8984e-01, -2.7578e-01, -2.6172e-01, -2.4766e-01, -2.3359e-01,
         -2.1953e-01, -2.0547e-01, -1.9141e-01, -1.7734e-01, -1.6328e-01,
         -1.4922e-01, -1.3516e-01, -1.2109e-01, -1.0703e-01, -9.8594e-02,
         -9.5781e-02, -9.2969e-02, -9.0156e-02, -8.7344e-02, -8.4531e-02,
         -8.1719e-02, -7.8906e-02, -7.6094e-02, -7.3281e-02, -7.0469e-02,
         -6.7656e-02, -6.4844e-02, -6.2031e-02, -5.9219e-02, -5.6406e-02,
         -5.3594e-02, -5.0781e-02, -4.7969e-02, -4.5156e-02, -4.2344e-02,
         -3.9531e-02, -3.6719e-02, -3.3906e-02, -3.1094e-02, -2.8281e-02,
         -2.5469e-02, -2.2656e-02, -1.9844e-02, -1.7031e-02, -1.4219e-02,
         -1.1406e-02, -9.7187e-03, -9.1562e-03, -8.5938e-03, -8.0312e-03,
         -7.4687e-03, -6.9063e-03, -6.3437e-03, -5.7813e-03, -5.2188e-03,
         -4.6562e-03, -4.0937e-03, -3.5312e-03, -2.9687e-03, -2.4062e-03,
         -1.8438e-03, -1.2812e-03, -9.4375e-04, -8.3125e-04, -7.1875e-04,
         -6.0625e-04, -4.9375e-04, -3.8125e-04, -2.6875e-04, -1.5625e-04,
         -8.8750e-05, -6.6250e-05, -4.3750e-05, -2.1250e-05, -7.7500e-06,
         -3.2500e-06, -5.5000e-07,  0.0000e+00,  5.5000e-07,  3.2500e-06,
          7.7500e-06,  2.1250e-05,  4.3750e-05,  6.6250e-05,  8.8750e-05,
          1.5625e-04,  2.6875e-04,  3.8125e-04,  4.9375e-04,  6.0625e-04,
          7.1875e-04,  8.3125e-04,  9.4375e-04,  1.2812e-03,  1.8438e-03,
          2.4062e-03,  2.9687e-03,  3.5312e-03,  4.0937e-03,  4.6562e-03,
          5.2188e-03,  5.7813e-03,  6.3437e-03,  6.9063e-03,  7.4687e-03,
          8.0312e-03,  8.5938e-03,  9.1562e-03,  9.7187e-03,  1.1406e-02,
          1.4219e-02,  1.7031e-02,  1.9844e-02,  2.2656e-02,  2.5469e-02,
          2.8281e-02,  3.1094e-02,  3.3906e-02,  3.6719e-02,  3.9531e-02,
          4.2344e-02,  4.5156e-02,  4.7969e-02,  5.0781e-02,  5.3594e-02,
          5.6406e-02,  5.9219e-02,  6.2031e-02,  6.4844e-02,  6.7656e-02,
          7.0469e-02,  7.3281e-02,  7.6094e-02,  7.8906e-02,  8.1719e-02,
          8.4531e-02,  8.7344e-02,  9.0156e-02,  9.2969e-02,  9.5781e-02,
          9.8594e-02,  1.0703e-01,  1.2109e-01,  1.3516e-01,  1.4922e-01,
          1.6328e-01,  1.7734e-01,  1.9141e-01,  2.0547e-01,  2.1953e-01,
          2.3359e-01,  2.4766e-01,  2.6172e-01,  2.7578e-01,  2.8984e-01,
          3.0391e-01,  3.1797e-01,  3.3203e-01,  3.4609e-01,  3.6016e-01,
          3.7422e-01,  3.8828e-01,  4.0234e-01,  4.1641e-01,  4.3047e-01,
          4.4453e-01,  4.5859e-01,  4.7266e-01,  4.8672e-01,  5.0078e-01,
          5.1484e-01,  5.2891e-01,  5.4297e-01,  5.5703e-01,  5.7109e-01,
          5.8516e-01,  5.9922e-01,  6.1328e-01,  6.2734e-01,  6.4141e-01,
          6.5547e-01,  6.6953e-01,  6.8359e-01,  6.9766e-01,  7.1172e-01,
          7.2578e-01,  7.3984e-01,  7.5391e-01,  7.6797e-01,  7.8203e-01,
          7.9609e-01,  8.1016e-01,  8.2422e-01,  8.3828e-01,  8.5234e-01,
          8.6641e-01,  8.8047e-01,  8.9453e-01,  9.0859e-01,  9.2266e-01,
          9.3672e-01,  9.5078e-01,  9.6484e-01,  9.7891e-01,  9.9297e-01,
          1.0000e+00], device='cuda:0'),
 'nested_dtype': 'float32',
 'nested_offset': 0.03480497747659683}
# dequantize the gathered 4-bit data back to the original dtype using the gathered quant state
data = params_rank0[0].data.to("cuda")
dequantized_weight = F.dequantize_4bit(data, quant_state_rank0)
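As an extra, optional sanity check (a sketch, not part of the original test): the gathered data and quant state should also work directly as a 4-bit matmul operand, essentially the call Linear4bit makes in its forward pass. This assumes data is the packed storage bnb.matmul_4bit expects; the tolerance is a guess.

x = torch.randn(2, quant_state_rank0.shape[1], dtype=torch.bfloat16, device="cuda")
out_4bit = bnb.matmul_4bit(x, data.t(), quant_state=quant_state_rank0)  # 4-bit weight path
out_ref = torch.nn.functional.linear(x, dequantized_weight)             # reference: plain matmul with the dequantized weight
assert torch.allclose(out_4bit, out_ref, atol=0.1, rtol=10)             # loose tolerance for bf16 accumulation differences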
# set this to the model that the summoned weights above were saved from
model_name = "codellama/CodeLlama-34b-hf"
idx = hub.cached_file(model_name, SAFE_WEIGHTS_INDEX_NAME)
files, _ = hub.get_checkpoint_shard_files(model_name, idx)
orig_weight = None
for filename in files:
    weights = safetensors.torch.load_file(filename)
    if "model.layers.0.self_attn.q_proj.weight" in weights:
        orig_weight = weights["model.layers.0.self_attn.q_proj.weight"]
        break
# some deviation from the original weights is expected after quantize/dequantize
# tolerances taken from peft/tests/.../test_4bit_merge_and_disable_lora -- stricter values needed?
assert orig_weight is not None
assert torch.allclose(dequantized_weight.cpu(), orig_weight, atol=0.01, rtol=10)
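To make the comment above concrete, a quick (illustrative) look at the actual dequantization error:

diff = (dequantized_weight.cpu().float() - orig_weight.float()).abs()
print(f"max abs error: {diff.max().item():.4f}, mean abs error: {diff.mean().item():.6f}")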