FSDP

import time
import torch
import torch.nn as nn
import hqq_aten
from hqq.core.quantize import Quantizer, HQQLinear, BaseQuantizeConfig, HQQBackend
hqq_aten package available. Set backend to HQQBackend.ATEN for faster inference and HQQBackend.ATEN_BACKPROP for faster training!
from typing import List
from torch import Tensor
from torch.nn import functional as F
from accelerate.utils import set_seed
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.utils import hub, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
import safetensors
from fastcore.parallel import parallel
# Optional sanity check: disable the math and memory-efficient SDPA backends while
# keeping flash attention enabled, so the attention call below exercises a fused kernel.
query, key, value = (
    torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") for _ in range(3)
)
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    F.scaled_dot_product_attention(query, key, value)
# Build a small Linear layer (16 -> 128) and wrap it in a 4-bit HQQLinear with group
# size 64; zero points and scales are left unquantized and metadata is not offloaded.
set_seed(42)
m = torch.nn.Linear(in_features=16, out_features=128)
quant_config = BaseQuantizeConfig(
    nbits=4,
    group_size=64,
    quant_zero=False,
    quant_scale=False,
    offload_meta=False,
)
hqq_linear = HQQLinear(m, quant_config=quant_config)
# Displayed cell value: the dtype HQQLinear computes in.
hqq_linear.compute_dtype
torch.float16
# Inspect the first registered parameter of the quantized layer.
# NOTE(review): the values displayed look like packed quantized data reinterpreted as
# floats rather than real weights -- confirm against HQQLinear's storage layout.
next(hqq_linear.parameters())
Parameter containing:
tensor([[-1.8690e+31, -1.7469e-07, -9.8312e-20,  4.3347e+23, -1.0372e-23,
         -5.6423e+16,  1.3304e-05,  6.1785e-24],
        [-5.7602e+10,  5.1494e+18, -1.7353e+27, -7.9082e-32,  8.7318e+06,
         -4.3186e-06,  1.4261e-18,  3.5633e+17],
        [ 2.8733e-02, -6.6121e-15,  4.6052e-22, -5.8633e+18,  1.6486e+06,
          1.2226e-18,  9.0436e+25,  5.9841e-04],
        [ 6.3572e-37,  2.1430e-10,  5.6341e-01, -5.9994e-36,  1.9233e+11,
          2.9263e-09,  3.3071e-09,  1.0180e-20],
        [-1.0810e-13,  8.8023e+08,  6.2707e+18,  1.3579e-24, -4.7377e+23,
          3.5615e+17,  2.6324e-14,  4.2122e-09],
        [ 2.4662e-25, -3.4900e+27,  9.6193e+29,  2.6624e+03,  2.2651e-29,
          3.0514e+14,  6.9221e+30,  1.6402e+19],
        [ 7.4646e+22, -9.6859e-28, -4.3350e-10,  5.1519e-34, -4.1487e-07,
         -7.7171e+37,  9.2547e+13,  8.3544e+23],
        [-1.6869e-09, -2.6847e+18, -8.0041e-29,  9.5645e-38,  1.3935e-02,
         -1.4938e-13,  1.0959e-11,  1.0414e-32],
        [-3.7106e-07,  1.6020e-09,  5.3166e+36,  1.1653e-30,  5.6269e+17,
          1.7686e-32,  2.3617e+02, -4.2526e+28],
        [ 1.7555e+13,  7.6786e-05,  9.5206e+14,  4.9653e-02, -2.7269e-24,
         -1.1017e-01, -4.1573e-16, -4.8174e-23],
        [-2.9936e+07,  1.9641e-36, -8.3284e-35,  1.8591e-26,  1.4642e+25,
          5.6287e-28,  7.7592e+09, -5.0669e+06],
        [-1.8897e-21, -2.0112e+20,  4.7147e+34,  9.6051e-25, -5.1717e+05,
          9.1546e+00,  5.4721e-24, -1.5698e+24],
        [ 1.0694e+16,  5.4373e+04,  1.2801e-03,  4.4126e-09, -1.2773e-35,
          3.7246e+07,  3.6701e+15,  6.3485e+06],
        [ 2.6589e-09, -2.5449e+06,  9.6047e-39,  4.2585e+20, -1.7479e+02,
         -4.3529e-26, -1.1987e+24, -1.1508e+25],
        [ 4.6449e-32, -1.5308e-26,  3.9841e-18,  1.1292e-21,  3.8489e-08,
         -2.8361e+01, -3.1611e+09, -2.5271e-27],
        [-9.7359e-24,  2.7734e+28, -4.8315e-12,  3.0113e+32,  3.9759e+09,
         -8.1162e+25,  1.6537e+08,  7.9032e-37],
        [ 3.6381e-26,  1.4493e+38, -2.5790e+05, -2.4838e-34,  1.4304e+06,
         -1.1399e-36, -2.0599e+23, -4.4556e-23],
        [-4.8743e+26, -3.2384e-06,  8.0767e-16, -6.6715e+24,  3.5411e-24,
          3.4405e+07,  4.9961e-37,  7.5914e+18],
        [ 4.9612e+04, -1.9965e+25,  2.3972e+35, -9.3756e+10,  1.6764e-25,
         -3.3598e-22,  3.7503e+10,  3.1760e+21],
        [ 2.4561e-08,  1.1222e+35, -1.7132e+34,  4.8265e-19, -5.3818e-17,
          4.3160e+01,  1.5106e+13,  4.2396e+25],
        [-8.7586e+18,  2.2979e+16,  2.8853e-02, -5.4119e+12, -4.8991e+27,
         -1.3176e+05, -1.5185e-35, -5.2663e-08],
        [-4.9525e+22,  2.6456e+21, -6.6132e-16,  5.9137e+08, -6.8673e+30,
         -1.1277e+03, -8.7609e+29,  5.9418e-28],
        [-3.2768e-10, -5.1658e-14, -2.3504e+27,  3.2130e+06, -2.6921e+19,
          7.4000e-20,  1.3070e-24, -1.1684e+29],
        [-1.9485e+33, -1.6401e+27,  5.9458e-18, -1.1368e-24,  7.1163e-09,
         -5.2176e+34,  1.3326e-02,  1.3937e-38],
        [-3.4272e-07,  7.0026e+22,  3.3191e+23, -3.8086e-24, -3.1557e-28,
         -1.4411e+19,  8.2169e-20, -2.2000e+35],
        [-3.9428e+01, -4.0882e-06, -6.5982e-25,  1.6298e+12, -1.0176e+12,
          3.0798e+06,  4.0689e+02,  1.3383e+38],
        [-1.6804e+08,  3.0361e-01,  5.0893e-34,  1.2463e+18,  1.4580e+06,
         -1.8916e+05, -9.8710e+36,  2.9459e+04],
        [-2.7046e-11, -4.2445e+21,  5.9648e+01,  4.2992e+14, -3.0052e+05,
          4.9578e+23,  1.8172e+25, -2.4127e-17],
        [ 6.3310e+13,  1.4881e+32, -6.1006e-36, -6.1947e+11,  5.1969e+05,
          1.7885e+25, -1.1800e-37, -4.9508e+04],
        [ 1.3706e+17,  5.2504e-05,  8.2312e+13,  8.1923e+08,  5.6115e-25,
          4.6359e+16,  1.9769e-20, -8.4875e-32],
        [ 1.9187e+23,  9.1218e+25, -1.9125e-17,  5.3448e+23, -1.4947e+32,
         -2.7552e+25, -1.3683e-25, -8.3450e-10],
        [ 1.8771e+06,  7.4212e-37, -9.7615e-27,  5.3814e+07,  1.0501e-27,
         -2.9047e+08, -5.6822e+03,  5.3259e-01]], device='cuda:0')
# Grab the unquantized weight of `m` to compare against HQQ's quantize/dequantize round trip.
w = m.weight.data
w.shape
torch.Size([128, 16])
# Quantize the weight directly with the low-level Quantizer API (Quantizer's own
# default bit width / group size). With view_as_float=False the packed codes stay in an
# integer dtype; the displayed 32x32 uint8 result for 2048 weights suggests two 4-bit
# values per byte -- confirm against Quantizer's defaults.
W_q, meta = Quantizer.quantize(w, round_zero=True, optimize=True, view_as_float=False)
W_q.shape, W_q.dtype
(torch.Size([32, 32]), torch.uint8)
# dtype of the scale tensor stored in the quantization metadata returned above.
meta['scale'].dtype
torch.float16
# Reconstruct a dense weight from the packed codes + metadata and show it next to the
# original for a visual comparison (the displayed w_dq comes back as float16).
w_dq = Quantizer.dequantize(W_q, meta)
w, w_dq
(tensor([[ 0.1196,  0.0683, -0.0960,  ..., -0.2410, -0.1544, -0.0864],
         [-0.0278, -0.0483,  0.1141,  ...,  0.0873,  0.0023,  0.2011],
         [ 0.0982, -0.0460,  0.0086,  ...,  0.0627, -0.0216, -0.0140],
         ...,
         [-0.0208,  0.1148, -0.0562,  ..., -0.0961,  0.2354,  0.2077],
         [ 0.1820,  0.1345, -0.0235,  ...,  0.0432, -0.1749,  0.1510],
         [-0.2125,  0.0024, -0.2045,  ..., -0.1916,  0.1080,  0.0231]]),
 tensor([[ 0.1224,  0.0717, -0.0930,  ..., -0.2524, -0.1595, -0.0937],
         [-0.0320, -0.0627,  0.1289,  ...,  0.0945,  0.0091,  0.1919],
         [ 0.0917, -0.0519,  0.0014,  ...,  0.0705, -0.0320,  0.0009],
         ...,
         [-0.0320,  0.1304, -0.0645,  ..., -0.0981,  0.2344,  0.1919],
         [ 0.1841,  0.1334, -0.0301,  ...,  0.0382, -0.1595,  0.1584],
         [-0.2222,  0.0016, -0.1934,  ..., -0.1943,  0.1057,  0.0273]],
        dtype=torch.float16))
# Round-trip reconstruction error, measured with torch.norm at p=0.7.
torch.norm(w - w_dq, p=0.7)
tensor(390.0982)
# Display the plain dict BaseQuantizeConfig builds, to see which keys can be edited.
BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=False, quant_scale=False, offload_meta=False)
{'weight_quant_params': {'nbits': 4,
  'channel_wise': True,
  'group_size': 64,
  'optimize': True,
  'round_zero': True},
 'scale_quant_params': None,
 'zero_quant_params': None,
 'offload_meta': False}
# Sweep HQQ metadata options (quantized zero points / scales, metadata offload) on the
# same layer `m`: quantize with each config, dequantize, and compare to the original weight.
quant_configs = [
    BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=False, quant_scale=False, offload_meta=False),
    BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=True, quant_scale=False, offload_meta=False),
    BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=False, quant_scale=True, offload_meta=False),
    BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=True, quant_scale=True, offload_meta=False),
    BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=True, quant_scale=True, offload_meta=True),
    BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=False, quant_scale=False, offload_meta=True),
]

# set_backend is called on the HQQLinear class, not on an instance, so one call before
# any layer is initialized covers the whole sweep.
HQQLinear.set_backend(HQQBackend.ATEN_BACKPROP)

w_dqs = []
for quant_cfg in quant_configs:
    # Override the group sizes used when quantizing the metadata (scales / zeros) itself.
    if quant_cfg['scale_quant_params']:
        quant_cfg['scale_quant_params']['group_size'] = 8
    if quant_cfg['zero_quant_params']:
        if quant_cfg['offload_meta']:
            quant_cfg['zero_quant_params']['group_size'] = 8
            quant_cfg['zero_quant_params']['channel_wise'] = True
        else:
            quant_cfg['zero_quant_params']['group_size'] = None
            quant_cfg['zero_quant_params']['channel_wise'] = False
    # initialize=False defers quantization until the explicit initialize() call.
    mq = HQQLinear(m, quant_cfg, compute_dtype=torch.bfloat16, initialize=False)
    mq.initialize()
    print(mq.W_q.dtype, mq.meta)
    print()
    w_dqs.append(mq.dequantize_aten())

# Reconstruction error for every config in the sweep (same p=0.7 metric as above).
w_ref = w.cuda()
tuple(torch.norm(w_ref - w_dq, p=0.7) for w_dq in w_dqs)
(tensor(390.9176, device='cuda:0'),
 tensor(390.5967, device='cuda:0'),
 tensor(390.7930, device='cuda:0'),
 tensor(390.1439, device='cuda:0'),
 tensor(392.0999, device='cuda:0'))
def replace_linear_hqq(model:nn.Module, quant_config, skip_modules:"List[str] | None"=None, **kwargs):
    """
    Recursively replace `torch.nn.Linear` children of `model` with `HQQLinear` modules.

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        quant_config (`Dict[str, Any]`):
            The quantization configuration passed to every new `HQQLinear`.
        skip_modules (`List[str]`, *optional*, defaults to `["lm_head"]`):
            Child names that are left as plain `nn.Linear`. Names are matched against the
            immediate attribute name of each child, not its full dotted path.
        **kwargs:
            Forwarded unchanged to the `HQQLinear` constructor.

    Returns:
        `model`, modified in place.
    """
    # A None sentinel instead of a list literal avoids a shared mutable default.
    if skip_modules is None:
        skip_modules = ["lm_head"]

    for name, module in model.named_children():
        # Recurse into containers first; leaf modules have no children.
        if next(module.children(), None) is not None:
            replace_linear_hqq(module, quant_config, skip_modules, **kwargs)

        if isinstance(module, torch.nn.Linear) and name not in skip_modules:
            # Rebinding an existing child name does not change the size of the dict being
            # iterated, so replacing during iteration is safe.
            setattr(model, name, HQQLinear(module, quant_config, **kwargs))
    return model
def load_and_quantize_hqq(module:nn.Module, name:str, value:Tensor, device:torch.device=None, dtype:torch.dtype=None,
                          skip_names:"list[str] | None"=None, is_meta_rank:bool=False, low_memory:bool=True,
                          verbose:bool=False):
    """
    Load checkpoint tensor `value` into the submodule of `module` addressed by `name`.

    Behaviour by target:
      * `name` contains any entry of `skip_names`: nothing is loaded.
      * `HQQLinear` weight: the layer's `linear_layer` is materialized on CPU with
        `to_empty`, the tensor is copied in, and `initialize()` quantizes it. `W_q` is then
        moved to "meta" if `is_meta_rank`, else to "cpu" if `low_memory`, and `in_gpu` is
        set to False. NOTE(review): when neither flag is set, `W_q` stays wherever
        `initialize()` left it while `in_gpu` is still cleared -- confirm that is intended.
      * `HQQLinear` bias: raises `ValueError`.
      * Any other parameter or buffer: cast to `dtype` and placed on "meta" if
        `is_meta_rank`, else "cpu" if `low_memory`, else `device`. Parameters keep their
        original Parameter class.
    A timing line is printed after each load unless `is_meta_rank`.

    Args:
        module: root module that the dotted `name` is resolved against.
        name: dotted path of the tensor inside `module`, e.g. "model.layers.0.input_layernorm.weight".
        value: tensor read from the checkpoint.
        device: target device when neither `is_meta_rank` nor `low_memory` is set.
        dtype: dtype for non-HQQ parameters/buffers; None keeps `value`'s dtype.
        skip_names: substrings; if any occurs in `name` the tensor is skipped. Defaults to none.
        is_meta_rank: place loaded tensors on the meta device.
        low_memory: place loaded tensors on CPU instead of `device`.
        verbose: print a message when a tensor is skipped.

    Raises:
        ValueError: if a bias is loaded into an `HQQLinear`.
    """
    # A None sentinel instead of a list literal avoids a shared mutable default.
    if skip_names is None:
        skip_names = []

    def place_on_device(tensor):
        # Bind the destination to a new name: assigning to `device` here would make it a
        # local of this closure and raise UnboundLocalError when neither flag is set.
        if is_meta_rank:
            target = 'meta'
        elif low_memory:
            target = 'cpu'
        else:
            target = device
        return tensor.to(device=target, dtype=dtype)

    if any(skip_name in name for skip_name in skip_names):
        if verbose:
            print(f"Skipping {name} because it is in skip_names")
        return

    module_key, _, value_key = name.rpartition('.')
    try:
        submodule = module.get_submodule(module_key)
    except AttributeError as e:
        print(f"Module {module_key} not found:\n{e}")
        return

    start = time.time()
    if isinstance(submodule, HQQLinear):
        if value_key == "weight":
            # materialize the meta weights as empty (uninitialized) CPU storage
            submodule.linear_layer.to_empty(device="cpu")
            # copy pretrained weights
            submodule.linear_layer.weight.data.copy_(value)
            # quantize and update metadata
            submodule.initialize()

            if is_meta_rank:
                setattr(submodule, "W_q", nn.Parameter(submodule.W_q.to("meta")))
            elif low_memory:
                setattr(submodule, "W_q", nn.Parameter(submodule.W_q.to("cpu")))
            submodule.in_gpu = False

        if value_key == "bias":
            raise ValueError("Bias not supported in HQQLinear yet!")

        end = time.time()
        if not is_meta_rank:
            print(f"Loaded HQQLinear quantized {module_key} in {end-start:.3f} seconds")
        return

    # Keep the try narrow: only get_parameter's AttributeError means "this is a buffer".
    # Errors raised anywhere else (e.g. during quantization above) must not be swallowed.
    try:
        param = submodule.get_parameter(value_key)
    except AttributeError:
        # not a Parameter -> it's a buffer
        value = place_on_device(value)
    else:
        value = type(param)(place_on_device(value).data)

    setattr(submodule, value_key, value)
    end = time.time()
    torch.cuda.empty_cache()
    if not is_meta_rank:
        print(f"Loaded {module_key} and {value_key} in {end-start:.3f} seconds")
# Build a meta-device skeleton of the model, swap its nn.Linear layers for uninitialized
# HQQLinear modules, then stream the safetensors shards in, quantizing each linear weight
# as it is loaded.
import safetensors.torch  # `import safetensors` alone does not import the `torch` submodule

compute_dtype = torch.bfloat16

model_name = "meta-llama/Llama-2-7b-hf"

# Resolve the checkpoint shard files; this needs `model_name`, so it must come after it.
idx = hub.cached_file(model_name, SAFE_WEIGHTS_INDEX_NAME)
files, _ = hub.get_checkpoint_shard_files(model_name, idx)

cfg = AutoConfig.from_pretrained(model_name)
cfg.use_cache = False
cfg._attn_implementation = "sdpa"
# cfg.num_hidden_layers = 8 # DEBUG

# Load the model on the meta device without calling init and replace nn.Linear with
# HQQLinear (quantization deferred via initialize=False until weights are loaded).
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(cfg)
    # TODO: Tune BaseQuantizeConfig.
    quant_config = BaseQuantizeConfig(nbits=4,
                                      group_size=64,
                                      quant_zero=True,
                                      quant_scale=True,
                                      offload_meta=True)
    model.model = replace_linear_hqq(model.model, quant_config, device_n=torch.cuda.current_device(),
                                     compute_dtype=compute_dtype, del_orig=True, initialize=False)
    HQQLinear.set_backend(HQQBackend.ATEN_BACKPROP)
model.is_loaded_in_4bit = True

local_rank = 0
low_memory = True
load_param_skip_names = []
rank = 0

print("Loading model", rank)
start = time.time()
for filename in files:
    weights = safetensors.torch.load_file(filename)
    for name, param in weights.items():
        load_and_quantize_hqq(model, name, param, dtype=torch.bfloat16, device=local_rank, skip_names=load_param_skip_names,
                              is_meta_rank=(low_memory and rank!=0), verbose=True)
    # Release this shard's tensors before the next shard is read to lower peak host memory.
    del weights
print(f"Loaded model weights in {time.time()-start:.3f} seconds")
Loading model 0
Loaded model.embed_tokens and weight in 0.067 seconds
Loaded model.layers.0.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.0.mlp.down_proj in 0.271 seconds
Loaded HQQLinear quantized model.layers.0.mlp.gate_proj in 0.243 seconds
Loaded HQQLinear quantized model.layers.0.mlp.up_proj in 0.236 seconds
Loaded model.layers.0.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.0.self_attn.k_proj in 0.065 seconds
Loaded HQQLinear quantized model.layers.0.self_attn.o_proj in 0.062 seconds
Loaded HQQLinear quantized model.layers.0.self_attn.q_proj in 0.063 seconds
Loaded model.layers.0.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.0.self_attn.v_proj in 0.060 seconds
Loaded model.layers.1.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.1.mlp.down_proj in 0.239 seconds
Loaded HQQLinear quantized model.layers.1.mlp.gate_proj in 0.247 seconds
Loaded HQQLinear quantized model.layers.1.mlp.up_proj in 0.283 seconds
Loaded model.layers.1.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.1.self_attn.k_proj in 0.078 seconds
Loaded HQQLinear quantized model.layers.1.self_attn.o_proj in 0.065 seconds
Loaded HQQLinear quantized model.layers.1.self_attn.q_proj in 0.061 seconds
Loaded model.layers.1.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.1.self_attn.v_proj in 0.074 seconds
Loaded model.layers.10.input_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.10.mlp.down_proj in 0.976 seconds
Loaded HQQLinear quantized model.layers.10.mlp.gate_proj in 1.748 seconds
Loaded HQQLinear quantized model.layers.10.mlp.up_proj in 1.001 seconds
Loaded model.layers.10.post_attention_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.10.self_attn.k_proj in 0.358 seconds
Loaded HQQLinear quantized model.layers.10.self_attn.o_proj in 0.383 seconds
Loaded HQQLinear quantized model.layers.10.self_attn.q_proj in 0.390 seconds
Loaded model.layers.10.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.10.self_attn.v_proj in 0.394 seconds
Loaded model.layers.11.input_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.11.mlp.down_proj in 0.971 seconds
Loaded HQQLinear quantized model.layers.11.mlp.gate_proj in 0.959 seconds
Loaded HQQLinear quantized model.layers.11.mlp.up_proj in 1.649 seconds
Loaded model.layers.11.post_attention_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.11.self_attn.k_proj in 0.410 seconds
Loaded HQQLinear quantized model.layers.11.self_attn.o_proj in 0.391 seconds
Loaded HQQLinear quantized model.layers.11.self_attn.q_proj in 0.375 seconds
Loaded model.layers.11.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.11.self_attn.v_proj in 0.401 seconds
Loaded model.layers.12.input_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.12.mlp.down_proj in 0.961 seconds
Loaded HQQLinear quantized model.layers.12.mlp.gate_proj in 0.927 seconds
Loaded HQQLinear quantized model.layers.12.mlp.up_proj in 0.967 seconds
Loaded model.layers.12.post_attention_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.12.self_attn.k_proj in 0.418 seconds
Loaded HQQLinear quantized model.layers.12.self_attn.o_proj in 1.161 seconds
Loaded HQQLinear quantized model.layers.12.self_attn.q_proj in 0.388 seconds
Loaded model.layers.12.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.12.self_attn.v_proj in 0.385 seconds
Loaded model.layers.13.input_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.13.mlp.down_proj in 0.953 seconds
Loaded HQQLinear quantized model.layers.13.mlp.gate_proj in 0.949 seconds
Loaded HQQLinear quantized model.layers.13.mlp.up_proj in 0.950 seconds
Loaded model.layers.13.post_attention_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.13.self_attn.k_proj in 0.382 seconds
Loaded HQQLinear quantized model.layers.13.self_attn.o_proj in 0.370 seconds
Loaded HQQLinear quantized model.layers.13.self_attn.q_proj in 0.386 seconds
Loaded model.layers.13.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.13.self_attn.v_proj in 1.341 seconds
Loaded model.layers.14.input_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.14.mlp.down_proj in 0.947 seconds
Loaded HQQLinear quantized model.layers.14.mlp.gate_proj in 0.946 seconds
Loaded HQQLinear quantized model.layers.14.mlp.up_proj in 0.984 seconds
Loaded model.layers.14.post_attention_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.14.self_attn.k_proj in 0.386 seconds
Loaded HQQLinear quantized model.layers.14.self_attn.o_proj in 0.387 seconds
Loaded HQQLinear quantized model.layers.14.self_attn.q_proj in 0.378 seconds
Loaded model.layers.14.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.14.self_attn.v_proj in 0.376 seconds
Loaded model.layers.15.input_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.15.mlp.down_proj in 1.806 seconds
Loaded HQQLinear quantized model.layers.15.mlp.gate_proj in 0.921 seconds
Loaded HQQLinear quantized model.layers.15.mlp.up_proj in 0.939 seconds
Loaded model.layers.15.post_attention_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.15.self_attn.k_proj in 0.386 seconds
Loaded HQQLinear quantized model.layers.15.self_attn.o_proj in 0.378 seconds
Loaded HQQLinear quantized model.layers.15.self_attn.q_proj in 0.377 seconds
Loaded model.layers.15.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.15.self_attn.v_proj in 0.391 seconds
Loaded model.layers.16.input_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.16.mlp.down_proj in 0.981 seconds
Loaded HQQLinear quantized model.layers.16.mlp.gate_proj in 1.731 seconds
Loaded HQQLinear quantized model.layers.16.mlp.up_proj in 0.962 seconds
Loaded model.layers.16.post_attention_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.16.self_attn.k_proj in 0.387 seconds
Loaded HQQLinear quantized model.layers.16.self_attn.o_proj in 0.382 seconds
Loaded HQQLinear quantized model.layers.16.self_attn.q_proj in 0.361 seconds
Loaded model.layers.16.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.16.self_attn.v_proj in 0.365 seconds
Loaded model.layers.17.input_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.17.mlp.down_proj in 0.938 seconds
Loaded HQQLinear quantized model.layers.17.mlp.gate_proj in 0.966 seconds
Loaded HQQLinear quantized model.layers.17.mlp.up_proj in 1.776 seconds
Loaded model.layers.17.post_attention_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.17.self_attn.k_proj in 0.397 seconds
Loaded HQQLinear quantized model.layers.17.self_attn.o_proj in 0.401 seconds
Loaded HQQLinear quantized model.layers.17.self_attn.q_proj in 0.400 seconds
Loaded model.layers.17.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.17.self_attn.v_proj in 0.359 seconds
Loaded model.layers.18.input_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.18.mlp.down_proj in 0.956 seconds
Loaded HQQLinear quantized model.layers.18.mlp.gate_proj in 0.964 seconds
Loaded HQQLinear quantized model.layers.18.mlp.up_proj in 0.946 seconds
Loaded model.layers.18.post_attention_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.18.self_attn.k_proj in 0.429 seconds
Loaded HQQLinear quantized model.layers.18.self_attn.o_proj in 1.168 seconds
Loaded HQQLinear quantized model.layers.18.self_attn.q_proj in 0.363 seconds
Loaded model.layers.18.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.18.self_attn.v_proj in 0.367 seconds
Loaded model.layers.19.input_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.19.mlp.down_proj in 0.962 seconds
Loaded HQQLinear quantized model.layers.19.mlp.gate_proj in 0.942 seconds
Loaded HQQLinear quantized model.layers.19.mlp.up_proj in 0.956 seconds
Loaded model.layers.19.post_attention_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.19.self_attn.k_proj in 0.407 seconds
Loaded HQQLinear quantized model.layers.19.self_attn.o_proj in 0.373 seconds
Loaded HQQLinear quantized model.layers.19.self_attn.q_proj in 0.404 seconds
Loaded model.layers.19.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.19.self_attn.v_proj in 1.342 seconds
Loaded model.layers.2.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.2.mlp.down_proj in 0.251 seconds
Loaded HQQLinear quantized model.layers.2.mlp.gate_proj in 0.241 seconds
Loaded HQQLinear quantized model.layers.2.mlp.up_proj in 0.238 seconds
Loaded model.layers.2.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.2.self_attn.k_proj in 0.094 seconds
Loaded HQQLinear quantized model.layers.2.self_attn.o_proj in 0.094 seconds
Loaded HQQLinear quantized model.layers.2.self_attn.q_proj in 0.093 seconds
Loaded model.layers.2.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.2.self_attn.v_proj in 0.094 seconds
Loaded model.layers.20.input_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.20.mlp.down_proj in 0.951 seconds
Loaded HQQLinear quantized model.layers.20.mlp.gate_proj in 0.962 seconds
Loaded HQQLinear quantized model.layers.20.mlp.up_proj in 0.947 seconds
Loaded model.layers.20.post_attention_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.20.self_attn.k_proj in 0.370 seconds
Loaded HQQLinear quantized model.layers.20.self_attn.o_proj in 0.401 seconds
Loaded HQQLinear quantized model.layers.20.self_attn.q_proj in 1.345 seconds
Loaded model.layers.20.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.20.self_attn.v_proj in 0.411 seconds
Loaded model.layers.21.input_layernorm and weight in 0.002 seconds
Loaded HQQLinear quantized model.layers.21.mlp.down_proj in 0.966 seconds
Loaded HQQLinear quantized model.layers.21.mlp.gate_proj in 0.923 seconds
Loaded HQQLinear quantized model.layers.21.mlp.up_proj in 0.971 seconds
Loaded model.layers.21.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.21.self_attn.k_proj in 0.391 seconds
Loaded HQQLinear quantized model.layers.21.self_attn.o_proj in 0.376 seconds
Loaded HQQLinear quantized model.layers.21.self_attn.q_proj in 0.398 seconds
Loaded model.layers.21.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.21.self_attn.v_proj in 0.408 seconds
Loaded model.layers.22.input_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.22.mlp.down_proj in 1.392 seconds
Loaded HQQLinear quantized model.layers.22.mlp.gate_proj in 0.947 seconds
Loaded HQQLinear quantized model.layers.22.mlp.up_proj in 0.970 seconds
Loaded model.layers.22.post_attention_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.22.self_attn.k_proj in 0.398 seconds
Loaded HQQLinear quantized model.layers.22.self_attn.o_proj in 0.383 seconds
Loaded HQQLinear quantized model.layers.22.self_attn.q_proj in 0.443 seconds
Loaded model.layers.22.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.22.self_attn.v_proj in 0.375 seconds
Loaded model.layers.23.input_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.23.mlp.down_proj in 0.961 seconds
Loaded HQQLinear quantized model.layers.23.mlp.gate_proj in 1.622 seconds
Loaded HQQLinear quantized model.layers.23.mlp.up_proj in 0.976 seconds
Loaded model.layers.23.post_attention_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.23.self_attn.k_proj in 0.362 seconds
Loaded HQQLinear quantized model.layers.23.self_attn.o_proj in 0.406 seconds
Loaded HQQLinear quantized model.layers.23.self_attn.q_proj in 0.391 seconds
Loaded model.layers.23.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.23.self_attn.v_proj in 0.384 seconds
Loaded model.layers.3.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.3.mlp.down_proj in 0.250 seconds
Loaded HQQLinear quantized model.layers.3.mlp.gate_proj in 0.237 seconds
Loaded HQQLinear quantized model.layers.3.mlp.up_proj in 0.246 seconds
Loaded model.layers.3.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.3.self_attn.k_proj in 0.091 seconds
Loaded HQQLinear quantized model.layers.3.self_attn.o_proj in 0.091 seconds
Loaded HQQLinear quantized model.layers.3.self_attn.q_proj in 0.094 seconds
Loaded model.layers.3.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.3.self_attn.v_proj in 0.089 seconds
Loaded model.layers.4.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.4.mlp.down_proj in 0.235 seconds
Loaded HQQLinear quantized model.layers.4.mlp.gate_proj in 0.253 seconds
Loaded HQQLinear quantized model.layers.4.mlp.up_proj in 0.233 seconds
Loaded model.layers.4.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.4.self_attn.k_proj in 0.094 seconds
Loaded HQQLinear quantized model.layers.4.self_attn.o_proj in 0.093 seconds
Loaded HQQLinear quantized model.layers.4.self_attn.q_proj in 0.095 seconds
Loaded model.layers.4.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.4.self_attn.v_proj in 0.092 seconds
Loaded model.layers.5.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.5.mlp.down_proj in 1.329 seconds
Loaded HQQLinear quantized model.layers.5.mlp.gate_proj in 0.250 seconds
Loaded HQQLinear quantized model.layers.5.mlp.up_proj in 0.232 seconds
Loaded model.layers.5.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.5.self_attn.k_proj in 0.094 seconds
Loaded HQQLinear quantized model.layers.5.self_attn.o_proj in 0.094 seconds
Loaded HQQLinear quantized model.layers.5.self_attn.q_proj in 0.092 seconds
Loaded model.layers.5.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.5.self_attn.v_proj in 0.093 seconds
Loaded model.layers.6.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.6.mlp.down_proj in 0.248 seconds
Loaded HQQLinear quantized model.layers.6.mlp.gate_proj in 0.242 seconds
Loaded HQQLinear quantized model.layers.6.mlp.up_proj in 0.233 seconds
Loaded model.layers.6.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.6.self_attn.k_proj in 0.098 seconds
Loaded HQQLinear quantized model.layers.6.self_attn.o_proj in 0.094 seconds
Loaded HQQLinear quantized model.layers.6.self_attn.q_proj in 0.095 seconds
Loaded model.layers.6.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.6.self_attn.v_proj in 0.091 seconds
Loaded model.layers.7.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.7.mlp.down_proj in 0.250 seconds
Loaded HQQLinear quantized model.layers.7.mlp.gate_proj in 0.232 seconds
Loaded HQQLinear quantized model.layers.7.mlp.up_proj in 0.234 seconds
Loaded model.layers.7.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.7.self_attn.k_proj in 0.096 seconds
Loaded HQQLinear quantized model.layers.7.self_attn.o_proj in 0.095 seconds
Loaded HQQLinear quantized model.layers.7.self_attn.q_proj in 0.096 seconds
Loaded model.layers.7.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.7.self_attn.v_proj in 0.092 seconds
Loaded model.layers.8.input_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.8.mlp.down_proj in 0.955 seconds
Loaded HQQLinear quantized model.layers.8.mlp.gate_proj in 2.081 seconds
Loaded HQQLinear quantized model.layers.8.mlp.up_proj in 0.952 seconds
Loaded model.layers.8.post_attention_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.8.self_attn.k_proj in 0.378 seconds
Loaded HQQLinear quantized model.layers.8.self_attn.o_proj in 0.388 seconds
Loaded HQQLinear quantized model.layers.8.self_attn.q_proj in 0.365 seconds
Loaded model.layers.8.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.8.self_attn.v_proj in 0.383 seconds
Loaded model.layers.9.input_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.9.mlp.down_proj in 0.943 seconds
Loaded HQQLinear quantized model.layers.9.mlp.gate_proj in 0.949 seconds
Loaded HQQLinear quantized model.layers.9.mlp.up_proj in 1.898 seconds
Loaded model.layers.9.post_attention_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.9.self_attn.k_proj in 0.375 seconds
Loaded HQQLinear quantized model.layers.9.self_attn.o_proj in 0.392 seconds
Loaded HQQLinear quantized model.layers.9.self_attn.q_proj in 0.389 seconds
Loaded model.layers.9.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.9.self_attn.v_proj in 0.385 seconds
Loaded lm_head and weight in 0.066 seconds
Loaded model.layers.24.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.24.mlp.down_proj in 0.239 seconds
Loaded HQQLinear quantized model.layers.24.mlp.gate_proj in 0.252 seconds
Loaded HQQLinear quantized model.layers.24.mlp.up_proj in 0.248 seconds
Loaded model.layers.24.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.24.self_attn.k_proj in 0.096 seconds
Loaded HQQLinear quantized model.layers.24.self_attn.o_proj in 0.093 seconds
Loaded HQQLinear quantized model.layers.24.self_attn.q_proj in 0.101 seconds
Loaded model.layers.24.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.24.self_attn.v_proj in 0.095 seconds
Loaded model.layers.25.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.25.mlp.down_proj in 0.238 seconds
Loaded HQQLinear quantized model.layers.25.mlp.gate_proj in 0.261 seconds
Loaded HQQLinear quantized model.layers.25.mlp.up_proj in 0.250 seconds
Loaded model.layers.25.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.25.self_attn.k_proj in 0.095 seconds
Loaded HQQLinear quantized model.layers.25.self_attn.o_proj in 0.093 seconds
Loaded HQQLinear quantized model.layers.25.self_attn.q_proj in 0.095 seconds
Loaded model.layers.25.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.25.self_attn.v_proj in 0.103 seconds
Loaded model.layers.26.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.26.mlp.down_proj in 0.244 seconds
Loaded HQQLinear quantized model.layers.26.mlp.gate_proj in 0.241 seconds
Loaded HQQLinear quantized model.layers.26.mlp.up_proj in 1.210 seconds
Loaded model.layers.26.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.26.self_attn.k_proj in 0.098 seconds
Loaded HQQLinear quantized model.layers.26.self_attn.o_proj in 0.093 seconds
Loaded HQQLinear quantized model.layers.26.self_attn.q_proj in 0.096 seconds
Loaded model.layers.26.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.26.self_attn.v_proj in 0.152 seconds
Loaded model.layers.27.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.27.mlp.down_proj in 0.242 seconds
Loaded HQQLinear quantized model.layers.27.mlp.gate_proj in 0.237 seconds
Loaded HQQLinear quantized model.layers.27.mlp.up_proj in 0.235 seconds
Loaded model.layers.27.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.27.self_attn.k_proj in 0.097 seconds
Loaded HQQLinear quantized model.layers.27.self_attn.o_proj in 0.094 seconds
Loaded HQQLinear quantized model.layers.27.self_attn.q_proj in 0.096 seconds
Loaded model.layers.27.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.27.self_attn.v_proj in 0.097 seconds
Loaded model.layers.28.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.28.mlp.down_proj in 0.249 seconds
Loaded HQQLinear quantized model.layers.28.mlp.gate_proj in 0.236 seconds
Loaded HQQLinear quantized model.layers.28.mlp.up_proj in 0.235 seconds
Loaded model.layers.28.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.28.self_attn.k_proj in 0.094 seconds
Loaded HQQLinear quantized model.layers.28.self_attn.o_proj in 0.095 seconds
Loaded HQQLinear quantized model.layers.28.self_attn.q_proj in 0.096 seconds
Loaded model.layers.28.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.28.self_attn.v_proj in 0.095 seconds
Loaded model.layers.29.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.29.mlp.down_proj in 0.254 seconds
Loaded HQQLinear quantized model.layers.29.mlp.gate_proj in 0.240 seconds
Loaded HQQLinear quantized model.layers.29.mlp.up_proj in 0.240 seconds
Loaded model.layers.29.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.29.self_attn.k_proj in 0.095 seconds
Loaded HQQLinear quantized model.layers.29.self_attn.o_proj in 0.096 seconds
Loaded HQQLinear quantized model.layers.29.self_attn.q_proj in 0.096 seconds
Loaded model.layers.29.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.29.self_attn.v_proj in 0.095 seconds
Loaded model.layers.30.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.30.mlp.down_proj in 0.240 seconds
Loaded HQQLinear quantized model.layers.30.mlp.gate_proj in 0.236 seconds
Loaded HQQLinear quantized model.layers.30.mlp.up_proj in 0.236 seconds
Loaded model.layers.30.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.30.self_attn.k_proj in 0.097 seconds
Loaded HQQLinear quantized model.layers.30.self_attn.o_proj in 0.095 seconds
Loaded HQQLinear quantized model.layers.30.self_attn.q_proj in 0.098 seconds
Loaded model.layers.30.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.30.self_attn.v_proj in 0.097 seconds
Loaded model.layers.31.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.31.mlp.down_proj in 1.292 seconds
Loaded HQQLinear quantized model.layers.31.mlp.gate_proj in 0.255 seconds
Loaded HQQLinear quantized model.layers.31.mlp.up_proj in 0.235 seconds
Loaded model.layers.31.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.31.self_attn.k_proj in 0.095 seconds
Loaded HQQLinear quantized model.layers.31.self_attn.o_proj in 0.094 seconds
Loaded HQQLinear quantized model.layers.31.self_attn.q_proj in 0.094 seconds
Loaded model.layers.31.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.31.self_attn.v_proj in 0.094 seconds
Loaded model.norm and weight in 0.000 seconds
Loaded model weights in 103.558 seconds
def load_and_quantize_parallel(name_param, load_func, model, **kwargs):
    """Unpack one ``(name, tensor)`` item and hand it to ``load_func``.

    Used as the worker for ``fastcore.parallel`` over ``weights.items()``:
    each item is a 2-tuple that is split into the parameter name and its
    tensor, then forwarded as ``load_func(model, name, tensor, **kwargs)``.
    Any return value of ``load_func`` is discarded; this returns None.
    A ``name_param`` that is not exactly a pair raises ``ValueError``.
    """
    (param_name, tensor) = name_param
    load_func(model, param_name, tensor, **kwargs)
compute_dtype = torch.bfloat16

model_name = "meta-llama/Llama-2-7b-hf"

cfg = AutoConfig.from_pretrained(model_name)
cfg.use_cache = False
cfg._attn_implementation = "sdpa"
# cfg.num_hidden_layers = 8 # DEBUG

# Build the model on the meta device (no weight init), then swap its nn.Linear
# layers for uninitialised HQQLinear layers that are filled in by the loader below.
with init_empty_weights():
    model_fast = AutoModelForCausalLM.from_config(cfg)
    # TODO: Tune BaseQuantizeConfig.
    quant_config = BaseQuantizeConfig(nbits=4,
                                      group_size=64,
                                      quant_zero=True,
                                      quant_scale=True,
                                      offload_meta=True)
    model_fast.model = replace_linear_hqq(model_fast.model, quant_config, device_n=torch.cuda.current_device(),
                                          compute_dtype=compute_dtype, del_orig=True, initialize=False)
    HQQLinear.set_backend(HQQBackend.ATEN_BACKPROP)
model_fast.is_loaded_in_4bit = True
local_rank = 0
low_memory = True
load_param_skip_names = []
rank = 0

# `import safetensors` is not guaranteed to load the `safetensors.torch`
# submodule; import it explicitly instead of relying on another library
# having done so as a side effect.
import safetensors.torch

print("Loading model", rank)
start = time.time()
for filename in files:
    # Each shard is read fully into a CPU state dict, then its (name, tensor)
    # pairs are loaded/quantized concurrently on a thread pool.
    weights = safetensors.torch.load_file(filename)
    parallel(load_and_quantize_parallel, weights.items(), n_workers=8, threadpool=True,
             load_func=load_and_quantize_hqq, model=model_fast,
             dtype=torch.bfloat16, device=local_rank, skip_names=load_param_skip_names,
             is_meta_rank=(low_memory and rank!=0), verbose=True)
    # Release this shard's CPU tensors before reading the next one, so that at
    # most one shard is resident at a time and none lingers after the loop.
    del weights
print(f"Loaded model weights in {time.time()-start:.3f} seconds")
Loading model 0
Loaded model.layers.0.input_layernorm and weight in 0.003 seconds
Loaded model.layers.0.post_attention_layernorm and weight in 0.004 seconds
Loaded model.layers.0.self_attn.rotary_emb and inv_freq in 0.032 seconds
Loaded model.embed_tokens and weight in 0.203 seconds
Loaded model.layers.1.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.0.self_attn.k_proj in 1.016 seconds
Loaded HQQLinear quantized model.layers.0.mlp.gate_proj in 1.065 seconds
Loaded HQQLinear quantized model.layers.0.mlp.down_proj in 1.201 seconds
Loaded model.layers.1.post_attention_layernorm and weight in 0.008 seconds
Loaded HQQLinear quantized model.layers.0.self_attn.v_proj in 1.155 seconds
Loaded HQQLinear quantized model.layers.0.self_attn.q_proj in 1.211 seconds
Loaded HQQLinear quantized model.layers.0.mlp.up_proj in 1.252 seconds
Loaded model.layers.1.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.0.self_attn.o_proj in 1.386 seconds
Loaded model.layers.10.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.1.mlp.down_proj in 1.298 seconds
Loaded HQQLinear quantized model.layers.1.self_attn.o_proj in 0.402 seconds
Loaded HQQLinear quantized model.layers.1.self_attn.v_proj in 1.823 seconds
Loaded model.layers.10.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.1.self_attn.k_proj in 2.032 seconds
Loaded HQQLinear quantized model.layers.1.mlp.up_proj in 2.188 seconds
Loaded HQQLinear quantized model.layers.1.self_attn.q_proj in 2.030 seconds
Loaded model.layers.10.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.1.mlp.gate_proj in 2.246 seconds
Loaded model.layers.11.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.10.mlp.down_proj in 2.360 seconds
Loaded HQQLinear quantized model.layers.10.mlp.gate_proj in 2.378 seconds
Loaded HQQLinear quantized model.layers.10.self_attn.v_proj in 0.571 seconds
Loaded model.layers.11.post_attention_layernorm and weight in 0.018 seconds
Loaded HQQLinear quantized model.layers.10.self_attn.k_proj in 0.867 seconds
Loaded HQQLinear quantized model.layers.10.mlp.up_proj in 2.499 seconds
Loaded HQQLinear quantized model.layers.10.self_attn.q_proj in 0.913 seconds
Loaded HQQLinear quantized model.layers.10.self_attn.o_proj in 0.953 seconds
Loaded model.layers.11.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded model.layers.12.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.11.mlp.down_proj in 0.997 seconds
Loaded HQQLinear quantized model.layers.11.self_attn.k_proj in 0.773 seconds
Loaded HQQLinear quantized model.layers.11.self_attn.o_proj in 1.063 seconds
Loaded model.layers.12.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.11.self_attn.v_proj in 0.863 seconds
Loaded HQQLinear quantized model.layers.12.mlp.down_proj in 0.906 seconds
Loaded HQQLinear quantized model.layers.11.self_attn.q_proj in 1.017 seconds
Loaded model.layers.12.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.11.mlp.gate_proj in 1.516 seconds
Loaded model.layers.13.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.11.mlp.up_proj in 1.494 seconds
Loaded HQQLinear quantized model.layers.12.mlp.gate_proj in 1.054 seconds
Loaded HQQLinear quantized model.layers.12.self_attn.o_proj in 0.673 seconds
Loaded model.layers.13.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.12.self_attn.v_proj in 0.639 seconds
Loaded HQQLinear quantized model.layers.12.mlp.up_proj in 1.140 seconds
Loaded HQQLinear quantized model.layers.12.self_attn.k_proj in 0.902 seconds
Loaded model.layers.13.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.12.self_attn.q_proj in 0.934 seconds
Loaded model.layers.14.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.13.mlp.down_proj in 0.963 seconds
Loaded HQQLinear quantized model.layers.13.mlp.up_proj in 0.965 seconds
Loaded HQQLinear quantized model.layers.13.mlp.gate_proj in 1.018 seconds
Loaded model.layers.14.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.13.self_attn.k_proj in 0.812 seconds
Loaded HQQLinear quantized model.layers.13.self_attn.q_proj in 0.942 seconds
Loaded HQQLinear quantized model.layers.13.self_attn.v_proj in 0.828 seconds
Loaded model.layers.14.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.14.mlp.down_proj in 0.778 seconds
Loaded HQQLinear quantized model.layers.13.self_attn.o_proj in 1.024 seconds
Loaded model.layers.15.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.14.self_attn.o_proj in 0.542 seconds
Loaded HQQLinear quantized model.layers.14.self_attn.k_proj in 1.054 seconds
Loaded model.layers.15.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.14.mlp.up_proj in 1.978 seconds
Loaded HQQLinear quantized model.layers.14.mlp.gate_proj in 2.594 seconds
Loaded HQQLinear quantized model.layers.14.self_attn.v_proj in 2.121 seconds
Loaded model.layers.15.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.14.self_attn.q_proj in 2.161 seconds
Loaded model.layers.16.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.15.mlp.down_proj in 2.245 seconds
Loaded HQQLinear quantized model.layers.15.self_attn.k_proj in 1.701 seconds
Loaded HQQLinear quantized model.layers.15.mlp.up_proj in 2.032 seconds
Loaded model.layers.16.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.15.mlp.gate_proj in 2.374 seconds
Loaded HQQLinear quantized model.layers.15.self_attn.o_proj in 1.184 seconds
Loaded HQQLinear quantized model.layers.15.self_attn.v_proj in 0.704 seconds
Loaded model.layers.16.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.15.self_attn.q_proj in 0.981 seconds
Loaded model.layers.17.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.16.self_attn.k_proj in 0.747 seconds
Loaded HQQLinear quantized model.layers.16.self_attn.o_proj in 0.767 seconds
Loaded HQQLinear quantized model.layers.16.self_attn.v_proj in 0.632 seconds
Loaded HQQLinear quantized model.layers.16.self_attn.q_proj in 0.738 seconds
Loaded model.layers.17.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.16.mlp.gate_proj in 1.288 seconds
Loaded HQQLinear quantized model.layers.16.mlp.up_proj in 1.285 seconds
Loaded HQQLinear quantized model.layers.16.mlp.down_proj in 1.503 seconds
Loaded model.layers.17.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded model.layers.18.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.17.mlp.down_proj in 1.219 seconds
Loaded HQQLinear quantized model.layers.17.mlp.gate_proj in 1.209 seconds
Loaded HQQLinear quantized model.layers.17.self_attn.o_proj in 0.855 seconds
Loaded model.layers.18.post_attention_layernorm and weight in 0.029 seconds
Loaded HQQLinear quantized model.layers.17.self_attn.k_proj in 0.922 seconds
Loaded HQQLinear quantized model.layers.17.self_attn.q_proj in 0.810 seconds
Loaded HQQLinear quantized model.layers.17.self_attn.v_proj in 0.849 seconds
Loaded model.layers.18.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.17.mlp.up_proj in 1.460 seconds
Loaded model.layers.19.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.18.mlp.down_proj in 1.052 seconds
Loaded HQQLinear quantized model.layers.18.self_attn.k_proj in 0.612 seconds
Loaded HQQLinear quantized model.layers.18.self_attn.v_proj in 0.581 seconds
Loaded model.layers.19.post_attention_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.18.self_attn.o_proj in 1.007 seconds
Loaded HQQLinear quantized model.layers.18.self_attn.q_proj in 1.012 seconds
Loaded HQQLinear quantized model.layers.18.mlp.gate_proj in 1.167 seconds
Loaded model.layers.19.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.18.mlp.up_proj in 1.337 seconds
Loaded model.layers.2.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.19.mlp.down_proj in 1.059 seconds
Loaded HQQLinear quantized model.layers.19.mlp.gate_proj in 1.102 seconds
Loaded HQQLinear quantized model.layers.19.self_attn.k_proj in 1.013 seconds
Loaded model.layers.2.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.19.mlp.up_proj in 1.142 seconds
Loaded HQQLinear quantized model.layers.19.self_attn.v_proj in 0.642 seconds
Loaded HQQLinear quantized model.layers.19.self_attn.q_proj in 0.751 seconds
Loaded model.layers.2.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.19.self_attn.o_proj in 0.763 seconds
Loaded model.layers.20.input_layernorm and weight in 0.006 seconds
Loaded HQQLinear quantized model.layers.2.self_attn.q_proj in 0.689 seconds
Loaded HQQLinear quantized model.layers.2.self_attn.o_proj in 0.734 seconds
Loaded HQQLinear quantized model.layers.2.self_attn.k_proj in 0.771 seconds
Loaded model.layers.20.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.2.self_attn.v_proj in 0.785 seconds
Loaded HQQLinear quantized model.layers.2.mlp.down_proj in 1.439 seconds
Loaded HQQLinear quantized model.layers.2.mlp.up_proj in 2.440 seconds
Loaded model.layers.20.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.2.mlp.gate_proj in 2.582 seconds
Loaded model.layers.21.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.20.mlp.down_proj in 2.197 seconds
Loaded HQQLinear quantized model.layers.20.self_attn.o_proj in 1.730 seconds
Loaded HQQLinear quantized model.layers.20.self_attn.q_proj in 1.778 seconds
Loaded model.layers.21.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.20.self_attn.v_proj in 0.687 seconds
Loaded HQQLinear quantized model.layers.20.mlp.up_proj in 2.315 seconds
Loaded HQQLinear quantized model.layers.20.self_attn.k_proj in 2.336 seconds
Loaded model.layers.21.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.21.mlp.down_proj in 1.099 seconds
Loaded HQQLinear quantized model.layers.20.mlp.gate_proj in 2.594 seconds
Loaded model.layers.22.input_layernorm and weight in 0.007 seconds
Loaded HQQLinear quantized model.layers.21.mlp.gate_proj in 1.152 seconds
Loaded HQQLinear quantized model.layers.21.self_attn.o_proj in 0.748 seconds
Loaded model.layers.22.post_attention_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.21.self_attn.k_proj in 0.829 seconds
Loaded HQQLinear quantized model.layers.21.mlp.up_proj in 1.203 seconds
Loaded HQQLinear quantized model.layers.21.self_attn.v_proj in 0.771 seconds
Loaded model.layers.22.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.21.self_attn.q_proj in 0.923 seconds
Loaded model.layers.23.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.22.mlp.down_proj in 0.902 seconds
Loaded HQQLinear quantized model.layers.22.self_attn.q_proj in 0.727 seconds
Loaded HQQLinear quantized model.layers.22.self_attn.o_proj in 0.917 seconds
Loaded model.layers.23.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.22.self_attn.v_proj in 0.663 seconds
Loaded HQQLinear quantized model.layers.22.mlp.gate_proj in 1.293 seconds
Loaded HQQLinear quantized model.layers.22.self_attn.k_proj in 1.033 seconds
Loaded model.layers.23.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.22.mlp.up_proj in 1.217 seconds
Loaded model.layers.3.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.23.self_attn.v_proj in 0.604 seconds
Loaded HQQLinear quantized model.layers.23.self_attn.o_proj in 0.804 seconds
Loaded HQQLinear quantized model.layers.23.mlp.down_proj in 1.380 seconds
Loaded model.layers.3.post_attention_layernorm and weight in 0.021 seconds
Loaded HQQLinear quantized model.layers.23.mlp.up_proj in 1.099 seconds
Loaded HQQLinear quantized model.layers.23.self_attn.k_proj in 1.108 seconds
Loaded HQQLinear quantized model.layers.23.mlp.gate_proj in 1.493 seconds
Loaded model.layers.3.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.3.mlp.down_proj in 1.088 seconds
Loaded model.layers.4.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.23.self_attn.q_proj in 1.148 seconds
Loaded HQQLinear quantized model.layers.3.self_attn.v_proj in 0.351 seconds
Loaded HQQLinear quantized model.layers.3.mlp.gate_proj in 1.057 seconds
Loaded model.layers.4.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.3.self_attn.o_proj in 0.767 seconds
Loaded HQQLinear quantized model.layers.3.self_attn.k_proj in 0.978 seconds
Loaded HQQLinear quantized model.layers.3.self_attn.q_proj in 0.947 seconds
Loaded model.layers.4.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.3.mlp.up_proj in 1.494 seconds
Loaded model.layers.5.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.4.mlp.gate_proj in 1.188 seconds
Loaded HQQLinear quantized model.layers.4.mlp.down_proj in 1.268 seconds
Loaded HQQLinear quantized model.layers.4.self_attn.q_proj in 0.671 seconds
Loaded model.layers.5.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.4.self_attn.k_proj in 2.018 seconds
Loaded HQQLinear quantized model.layers.4.self_attn.o_proj in 1.968 seconds
Loaded HQQLinear quantized model.layers.4.self_attn.v_proj in 1.807 seconds
Loaded model.layers.5.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.4.mlp.up_proj in 2.425 seconds
Loaded model.layers.6.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.5.mlp.up_proj in 1.880 seconds
Loaded HQQLinear quantized model.layers.5.self_attn.q_proj in 0.679 seconds
Loaded HQQLinear quantized model.layers.5.self_attn.o_proj in 0.709 seconds
Loaded model.layers.6.post_attention_layernorm and weight in 0.007 seconds
Loaded HQQLinear quantized model.layers.5.self_attn.v_proj in 0.771 seconds
Loaded HQQLinear quantized model.layers.5.self_attn.k_proj in 2.119 seconds
Loaded HQQLinear quantized model.layers.5.mlp.gate_proj in 2.472 seconds
Loaded model.layers.6.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.5.mlp.down_proj in 2.591 seconds
Loaded model.layers.7.input_layernorm and weight in 0.003 seconds
Loaded HQQLinear quantized model.layers.6.mlp.down_proj in 1.020 seconds
Loaded HQQLinear quantized model.layers.6.self_attn.q_proj in 0.825 seconds
Loaded HQQLinear quantized model.layers.6.mlp.up_proj in 1.041 seconds
Loaded model.layers.7.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.6.self_attn.k_proj in 1.067 seconds
Loaded HQQLinear quantized model.layers.6.self_attn.o_proj in 0.937 seconds
Loaded HQQLinear quantized model.layers.6.self_attn.v_proj in 0.784 seconds
Loaded model.layers.7.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.6.mlp.gate_proj in 1.527 seconds
Loaded model.layers.8.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.7.mlp.gate_proj in 1.046 seconds
Loaded HQQLinear quantized model.layers.7.mlp.down_proj in 1.137 seconds
Loaded HQQLinear quantized model.layers.7.self_attn.v_proj in 0.752 seconds
Loaded model.layers.8.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.7.self_attn.q_proj in 0.925 seconds
Loaded HQQLinear quantized model.layers.7.mlp.up_proj in 1.073 seconds
Loaded HQQLinear quantized model.layers.7.self_attn.o_proj in 1.033 seconds
Loaded model.layers.8.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.7.self_attn.k_proj in 1.133 seconds
Loaded model.layers.9.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.8.mlp.down_proj in 1.100 seconds
Loaded HQQLinear quantized model.layers.8.mlp.gate_proj in 1.235 seconds
Loaded HQQLinear quantized model.layers.8.self_attn.v_proj in 0.645 seconds
Loaded model.layers.9.post_attention_layernorm and weight in 0.002 seconds
Loaded HQQLinear quantized model.layers.8.self_attn.o_proj in 0.756 seconds
Loaded HQQLinear quantized model.layers.8.mlp.up_proj in 1.346 seconds
Loaded HQQLinear quantized model.layers.8.self_attn.k_proj in 0.991 seconds
Loaded model.layers.9.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.8.self_attn.q_proj in 0.897 seconds
Loaded HQQLinear quantized model.layers.9.mlp.down_proj in 1.155 seconds
Loaded HQQLinear quantized model.layers.9.self_attn.o_proj in 0.619 seconds
Loaded HQQLinear quantized model.layers.9.self_attn.k_proj in 0.670 seconds
Loaded HQQLinear quantized model.layers.9.self_attn.q_proj in 0.528 seconds
Loaded HQQLinear quantized model.layers.9.mlp.gate_proj in 0.970 seconds
Loaded HQQLinear quantized model.layers.9.self_attn.v_proj in 0.566 seconds
Loaded HQQLinear quantized model.layers.9.mlp.up_proj in 0.756 seconds
Loaded lm_head and weight in 0.330 seconds
Loaded model.layers.24.input_layernorm and weight in 0.006 seconds
Loaded model.layers.24.post_attention_layernorm and weight in 0.016 seconds
Loaded model.layers.24.self_attn.rotary_emb and inv_freq in 0.001 seconds
Loaded model.layers.25.input_layernorm and weight in 0.008 seconds
Loaded HQQLinear quantized model.layers.24.self_attn.o_proj in 1.008 seconds
Loaded HQQLinear quantized model.layers.24.self_attn.v_proj in 1.013 seconds
Loaded HQQLinear quantized model.layers.24.mlp.down_proj in 1.464 seconds
Loaded model.layers.25.post_attention_layernorm and weight in 0.002 seconds
Loaded HQQLinear quantized model.layers.24.self_attn.k_proj in 1.130 seconds
Loaded HQQLinear quantized model.layers.24.mlp.up_proj in 1.169 seconds
Loaded HQQLinear quantized model.layers.24.self_attn.q_proj in 1.338 seconds
Loaded model.layers.25.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.24.mlp.gate_proj in 1.436 seconds
Loaded model.layers.26.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.25.mlp.down_proj in 1.402 seconds
Loaded HQQLinear quantized model.layers.25.self_attn.k_proj in 0.522 seconds
Loaded HQQLinear quantized model.layers.25.self_attn.o_proj in 0.653 seconds
Loaded model.layers.26.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.25.mlp.up_proj in 0.961 seconds
Loaded HQQLinear quantized model.layers.25.self_attn.q_proj in 0.841 seconds
Loaded HQQLinear quantized model.layers.25.mlp.gate_proj in 1.216 seconds
Loaded model.layers.26.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.25.self_attn.v_proj in 0.897 seconds
Loaded model.layers.27.input_layernorm and weight in 0.008 seconds
Loaded HQQLinear quantized model.layers.26.mlp.gate_proj in 0.943 seconds
Loaded HQQLinear quantized model.layers.26.self_attn.k_proj in 0.647 seconds
Loaded HQQLinear quantized model.layers.26.self_attn.q_proj in 0.673 seconds
Loaded model.layers.27.post_attention_layernorm and weight in 0.003 seconds
Loaded HQQLinear quantized model.layers.26.mlp.up_proj in 1.228 seconds
Loaded HQQLinear quantized model.layers.26.self_attn.o_proj in 0.894 seconds
Loaded HQQLinear quantized model.layers.26.mlp.down_proj in 1.497 seconds
Loaded model.layers.27.self_attn.rotary_emb and inv_freq in 0.002 seconds
Loaded HQQLinear quantized model.layers.26.self_attn.v_proj in 0.723 seconds
Loaded model.layers.28.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.27.mlp.gate_proj in 1.199 seconds
Loaded HQQLinear quantized model.layers.27.mlp.up_proj in 1.211 seconds
Loaded HQQLinear quantized model.layers.27.self_attn.o_proj in 0.845 seconds
Loaded model.layers.28.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.27.self_attn.k_proj in 1.028 seconds
Loaded HQQLinear quantized model.layers.27.self_attn.q_proj in 0.857 seconds
Loaded HQQLinear quantized model.layers.27.self_attn.v_proj in 0.933 seconds
Loaded model.layers.28.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.27.mlp.down_proj in 1.740 seconds
Loaded HQQLinear quantized model.layers.28.mlp.down_proj in 1.025 seconds
Loaded model.layers.29.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.28.self_attn.q_proj in 0.835 seconds
Loaded HQQLinear quantized model.layers.28.self_attn.o_proj in 0.862 seconds
Loaded model.layers.29.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.28.self_attn.v_proj in 0.866 seconds
Loaded HQQLinear quantized model.layers.28.mlp.up_proj in 1.158 seconds
Loaded HQQLinear quantized model.layers.28.self_attn.k_proj in 1.129 seconds
Loaded model.layers.29.self_attn.rotary_emb and inv_freq in 0.002 seconds
Loaded HQQLinear quantized model.layers.28.mlp.gate_proj in 1.404 seconds
Loaded model.layers.30.input_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.29.mlp.gate_proj in 1.084 seconds
Loaded HQQLinear quantized model.layers.29.mlp.down_proj in 1.131 seconds
Loaded HQQLinear quantized model.layers.29.self_attn.v_proj in 1.754 seconds
Loaded model.layers.30.post_attention_layernorm and weight in 0.003 seconds
Loaded HQQLinear quantized model.layers.29.self_attn.k_proj in 2.057 seconds
Loaded HQQLinear quantized model.layers.29.self_attn.o_proj in 1.930 seconds
Loaded HQQLinear quantized model.layers.29.self_attn.q_proj in 2.034 seconds
Loaded model.layers.30.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.29.mlp.up_proj in 2.393 seconds
Loaded model.layers.31.input_layernorm and weight in 0.001 seconds
Loaded HQQLinear quantized model.layers.30.mlp.up_proj in 1.942 seconds
Loaded HQQLinear quantized model.layers.30.mlp.gate_proj in 2.062 seconds
Loaded HQQLinear quantized model.layers.30.mlp.down_proj in 2.221 seconds
Loaded model.layers.31.post_attention_layernorm and weight in 0.000 seconds
Loaded HQQLinear quantized model.layers.30.self_attn.o_proj in 0.757 seconds
Loaded HQQLinear quantized model.layers.30.self_attn.v_proj in 0.664 seconds
Loaded HQQLinear quantized model.layers.30.self_attn.k_proj in 1.169 seconds
Loaded model.layers.31.self_attn.rotary_emb and inv_freq in 0.000 seconds
Loaded HQQLinear quantized model.layers.30.self_attn.q_proj in 1.238 seconds
Loaded model.norm and weight in 0.015 seconds
Loaded HQQLinear quantized model.layers.31.self_attn.k_proj in 0.725 seconds
Loaded HQQLinear quantized model.layers.31.self_attn.q_proj in 0.440 seconds
Loaded HQQLinear quantized model.layers.31.self_attn.o_proj in 0.576 seconds
Loaded HQQLinear quantized model.layers.31.mlp.gate_proj in 0.969 seconds
Loaded HQQLinear quantized model.layers.31.mlp.down_proj in 1.118 seconds
Loaded HQQLinear quantized model.layers.31.mlp.up_proj in 0.988 seconds
Loaded HQQLinear quantized model.layers.31.self_attn.v_proj in 0.358 seconds
Loaded model weights in 36.317 seconds
# Sanity check: the parallel loader (model_fast) must reproduce exactly the
# parameters produced by the sequential loader (model). Mismatched parameter
# counts or names are reported as failures instead of being silently skipped.
_params_ref = list(model.named_parameters())
_params_fast = list(model_fast.named_parameters())
assert len(_params_ref) == len(_params_fast), (
    f"parameter count mismatch: {len(_params_ref)} vs {len(_params_fast)}"
)
for (n1, p1), (n2, p2) in zip(_params_ref, _params_fast):
    assert n1 == n2, f"parameter name mismatch: {n1} vs {n2}"
    if "proj" in n1:
        # Quantized projection weights: compare the raw bytes for exact equality.
        assert torch.equal(p1.view(torch.uint8), p2.view(torch.uint8)), n1
    else:
        assert torch.allclose(p1, p2), n1
class HQQDORA(nn.Module):
    """DoRA adapter wrapped around an HQQ-quantized linear layer.

    The effective weight is ``m * (W + A @ B) / ||W + A @ B||_c``, where:

    - ``W`` is the dequantized base weight (``base_layer.dequantize_aten()``).
    - ``A`` and ``B`` are ``lora_A`` and ``lora_B``.
    - ``||.||_c`` is the L2 norm over ``dim=0``.
    - ``m`` is a learnable magnitude initialized to ``||W||_c``.

    The base layer's bias is not applied.

    Args:
        base_layer: quantized layer exposing ``dequantize_aten()``, ``in_features``
            and ``out_features`` (e.g. ``HQQLinear``). Its first parameter fixes
            the adapter device. The adapter dtype is its ``compute_dtype``, falling
            back to that parameter's dtype.
        lora_rank: rank of the low-rank update.
        lora_dropout: dropout probability applied to the input of the low-rank
            branch while training.
    """

    def __init__(self, base_layer, lora_rank, lora_dropout):
        super().__init__()
        self.base_layer = base_layer
        first_param = next(base_layer.parameters())
        dtype = getattr(base_layer, "compute_dtype", first_param.dtype)
        device = first_param.device

        # A ~ N(0, 1/rank) and B = 0, so the low-rank update starts at exactly zero.
        std_dev = 1 / torch.sqrt(torch.tensor(lora_rank).float())
        self.lora_A = nn.Parameter(torch.randn(base_layer.out_features, lora_rank).to(device=device, dtype=dtype) * std_dev)
        self.lora_B = nn.Parameter(torch.zeros(lora_rank, base_layer.in_features).to(device=device, dtype=dtype))
        # Dropout on the low-rank branch input (LoRA convention); a no-op when p == 0.
        self.lora_dropout = nn.Dropout(lora_dropout)

        # Magnitude: column-wise L2 norm of the dequantized base weight.
        self.m = nn.Parameter(self.base_layer.dequantize_aten().clone().norm(p=2, dim=0, keepdim=True))

    def forward(self, x):
        """Apply the DoRA-adapted weight to ``x``, whose last dim must be ``in_features``."""
        # Dequantize once and reuse it for every term below.
        weight = self.base_layer.dequantize_aten()
        lora = torch.matmul(self.lora_A, self.lora_B)
        adapted = weight + lora
        column_norm = adapted.norm(p=2, dim=0, keepdim=True)

        # No equality asserts here: m and column_norm legitimately diverge once
        # m or lora_B have been trained.
        if self.training and self.lora_dropout.p > 0:
            # Split m * (W + AB) / ||W + AB||_c into base and low-rank parts, so that
            # dropout only touches the input of the low-rank branch.
            scale = self.m / column_norm
            base_out = torch.matmul(x, (weight * scale).t())
            lora_out = torch.matmul(self.lora_dropout(x), (lora * scale).t())
            return base_out + lora_out

        calc_weights = self.m * (adapted / column_norm)
        return torch.matmul(x, calc_weights.t())
# Sanity-check HQQDORA on a small HQQ-quantized layer. At initialization lora_B is
# zero and m is the column norm of the dequantized weight, so the DoRA output should
# match a plain matmul with the dequantized weight.
quant_config = BaseQuantizeConfig(nbits=4, 
                                  group_size=64, 
                                  quant_zero=True, 
                                  quant_scale=True, 
                                  offload_meta=True)

base_layer = HQQLinear(nn.Linear(128,256), quant_config, compute_dtype=torch.float32)
dora = HQQDORA(base_layer, 8, 0)
# `x` is reused by the DoRALayer comparison cells further down.
x = torch.randn(2,4,128).cuda()
# Fraction of output elements that are torch.isclose to the reference matmul
# (neither side applies the base layer's bias).
torch.isclose(dora(x), torch.matmul(x, base_layer.dequantize_aten().t())).float().mean()
tensor(0.9985, device='cuda:0')
class DoRALayer(nn.Module):
    """Plain (unquantized) DoRA linear layer.

    Computes ``F.linear(x, m * (W + A @ B) / ||W + A @ B||_c, bias)``, where:

    - ``W`` (and ``bias``, if given) is frozen.
    - ``||.||_c`` is the L2 norm over ``dim=0``, i.e. one value per input column.
    - ``m``, ``A`` (``lora_A``) and ``B`` (``lora_B``) are trainable.

    Args:
        d_in: number of input features.
        d_out: number of output features.
        rank: rank of the low-rank update.
        weight: optional frozen weight of shape ``(d_out, d_in)``. When omitted it is
            initialized like ``nn.Linear`` (instead of uninitialized memory).
        bias: optional frozen bias of shape ``(d_out,)``. When omitted the layer has
            no bias (``self.bias is None``) instead of adding uninitialized memory.
    """

    def __init__(self, d_in, d_out, rank=4, weight=None, bias=None):
        super().__init__()

        if weight is None:
            # Same initialization as nn.Linear.reset_parameters.
            weight = torch.empty(d_out, d_in)
            nn.init.kaiming_uniform_(weight, a=5 ** 0.5)
        self.weight = nn.Parameter(weight, requires_grad=False)

        if bias is None:
            # F.linear accepts bias=None, so simply register "no bias".
            self.register_parameter("bias", None)
        else:
            self.bias = nn.Parameter(bias, requires_grad=False)

        # m = column-wise magnitude (L2 norm over the output dimension).
        self.m = nn.Parameter(self.weight.norm(p=2, dim=0, keepdim=True))

        # A ~ N(0, 1/rank) and B = 0, so the layer initially reproduces W exactly.
        # Both are placed on the weight's device and dtype.
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.lora_A = nn.Parameter(torch.randn(d_out, rank).to(self.weight) * std_dev)
        self.lora_B = nn.Parameter(torch.zeros(rank, d_in).to(self.weight))

    def forward(self, x):
        """Apply the magnitude-rescaled, direction-normalized adapted weight to ``x``."""
        lora = torch.matmul(self.lora_A, self.lora_B)
        adapted = self.weight + lora
        column_norm = adapted.norm(p=2, dim=0, keepdim=True)
        norm_adapted = adapted / column_norm
        calc_weights = self.m * norm_adapted
        return F.linear(x, calc_weights, self.bias)
# Compare DoRALayer against a plain bias-free nn.Linear that shares its weight.
# With lora_B initialized to zero, the adapted weight equals m.weight.
# NOTE(review): dora(x) matching m(x) relies on DoRALayer adding no (or a zero) bias
# when bias=None — confirm.
m = nn.Linear(128,256,bias=False).cuda()
dora = DoRALayer(128,256,weight=m.weight).cuda()
dora(x)
tensor([[[-0.2144, -0.1476, -0.0111,  ...,  0.3745,  0.1425, -0.1142],
         [ 0.3202, -0.2039,  0.7589,  ..., -0.2859, -1.4159,  0.9623],
         [-0.1714,  0.4437, -0.3377,  ...,  1.4839,  1.1261,  0.1933],
         [-0.5015,  0.3812,  1.3170,  ...,  0.3666,  0.0282,  0.3237]],

        [[ 0.2638,  0.0497,  0.2547,  ...,  0.5097,  0.0237,  0.8447],
         [ 0.2788, -0.1295, -0.6743,  ...,  0.1924,  1.0936,  0.3154],
         [-0.4722,  0.2377,  0.0317,  ..., -0.6017, -0.4683, -0.1920],
         [-0.4582,  0.4022, -0.5113,  ...,  0.9794,  1.3093, -0.3878]]],
       device='cuda:0', grad_fn=<ViewBackward0>)
# Reference output of the plain nn.Linear, for comparison with dora(x) above.
m(x)
tensor([[[-0.2144, -0.1476, -0.0111,  ...,  0.3745,  0.1425, -0.1142],
         [ 0.3202, -0.2039,  0.7589,  ..., -0.2859, -1.4159,  0.9623],
         [-0.1714,  0.4437, -0.3377,  ...,  1.4839,  1.1261,  0.1933],
         [-0.5015,  0.3812,  1.3170,  ...,  0.3666,  0.0282,  0.3237]],

        [[ 0.2638,  0.0497,  0.2547,  ...,  0.5097,  0.0237,  0.8447],
         [ 0.2788, -0.1295, -0.6743,  ...,  0.1924,  1.0936,  0.3154],
         [-0.4722,  0.2377,  0.0317,  ..., -0.6017, -0.4683, -0.1920],
         [-0.4582,  0.4022, -0.5113,  ...,  0.9794,  1.3093, -0.3878]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>)
# Confirm x is a materialized tensor rather than a meta-device placeholder.
x.is_meta
False

Tests

from hqq.engine.hf import HQQModelForCausalLM, AutoTokenizer
hqq_aten package available. Set backend to HQQBackend.ATEN for faster inference and HQQBackend.ATEN_BACKPROP for faster training!
compute_dtype = torch.bfloat16
model_name = "meta-llama/Llama-2-7b-hf"

cfg = AutoConfig.from_pretrained(model_name)
cfg.use_cache = False
cfg._attn_implementation = "sdpa"  # use the SDPA attention modules (LlamaSdpaAttention)
cfg.num_hidden_layers = 2 # DEBUG: only 2 decoder layers to keep the test small

# Build the model from the config alone: no pretrained checkpoint is loaded, and the
# layers are regular nn.Linear modules. They are quantized with HQQ in a later cell
# via HQQModelForCausalLM.quantize_model_.
model = AutoModelForCausalLM.from_config(cfg)
model
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-1): 2 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
# 4-bit HQQ with group size 64. view_as_float=True is an hqq option controlling how
# the packed W_q is exposed (verify its exact semantics against the installed hqq).
quant_config = BaseQuantizeConfig(nbits=4, group_size=64, view_as_float=True)
# Reuse `compute_dtype` from the config cell so the dtype is defined in one place.
HQQModelForCausalLM.quantize_model_(model, quant_config, compute_dtype=compute_dtype)
100%|███████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 144.69it/s]
100%|████████████████████████████████████████████████████████████████████████| 2/2 [00:06<00:00,  3.38s/it]
# Inspect the HQQ metadata and packed weights of one quantized layer.
model.model.layers[0].self_attn.q_proj.meta
model.model.layers[0].self_attn.q_proj.W_q
# Round-trip: save the quantized model, inspect the saved files, then reload it.
save_dir = "/weka/home-keremturgutlu/models"
model.save_quantized(save_dir)
import json
# Context manager so the config file handle is closed instead of leaked.
with open(f"{save_dir}/config.json", encoding="utf-8") as f:
    quantized_config = json.load(f)
# NOTE: torch.load unpickles arbitrary Python objects; only load trusted checkpoints.
quantized_weights = torch.load(f"{save_dir}/qmodel.pt")
quantized_config
list(quantized_weights.keys())
quantized_weights['model.layers.0.self_attn.q_proj']
model_qt = HQQModelForCausalLM.from_quantized(save_dir)
100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 1804.39it/s]
100%|███████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 364.04it/s]
[module_name for module_name, _ in model_qt.named_modules()]
['',
 'model',
 'model.embed_tokens',
 'model.layers',
 'model.layers.0',
 'model.layers.0.self_attn',
 'model.layers.0.self_attn.q_proj',
 'model.layers.0.self_attn.k_proj',
 'model.layers.0.self_attn.v_proj',
 'model.layers.0.self_attn.o_proj',
 'model.layers.0.self_attn.rotary_emb',
 'model.layers.0.mlp',
 'model.layers.0.mlp.gate_proj',
 'model.layers.0.mlp.up_proj',
 'model.layers.0.mlp.down_proj',
 'model.layers.0.mlp.act_fn',
 'model.layers.0.input_layernorm',
 'model.layers.0.post_attention_layernorm',
 'model.layers.1',
 'model.layers.1.self_attn',
 'model.layers.1.self_attn.q_proj',
 'model.layers.1.self_attn.k_proj',
 'model.layers.1.self_attn.v_proj',
 'model.layers.1.self_attn.o_proj',
 'model.layers.1.self_attn.rotary_emb',
 'model.layers.1.mlp',
 'model.layers.1.mlp.gate_proj',
 'model.layers.1.mlp.up_proj',
 'model.layers.1.mlp.down_proj',
 'model.layers.1.mlp.act_fn',
 'model.layers.1.input_layernorm',
 'model.layers.1.post_attention_layernorm',
 'model.norm',
 'lm_head']
def assert_state_dict(v1, v2, exact=False):
    """Assert that state_dict entry ``v2`` matches the reference entry ``v1``.

    - Tensors must have the same shape. When ``exact`` is False (the default, used for
      top-level entries), more than 99% of elements must be ``torch.isclose`` with
      ``rtol=1e-5``. When True, they must be ``torch.equal``.
    - dicts (e.g. quantization metadata): every key of ``v1`` must be present in
      ``v2``, and values are compared recursively and exactly.
    - lists/tuples: same length, compared element-wise recursively and exactly.
    - Anything else is compared with ``==``. Previously such values were silently
      accepted at the top level.

    Raises:
        AssertionError: on any mismatch.
    """
    if isinstance(v1, torch.Tensor):
        assert isinstance(v2, torch.Tensor), f"expected a tensor, got {type(v2).__name__}"
        assert v1.shape == v2.shape, f"shape mismatch: {tuple(v1.shape)} vs {tuple(v2.shape)}"
        if exact:
            assert torch.equal(v1, v2)
        else:
            assert torch.isclose(v1, v2, rtol=1e-5).float().mean().item() > 0.99
    elif isinstance(v1, dict):
        assert isinstance(v2, dict), f"expected a dict, got {type(v2).__name__}"
        for key, value in v1.items():
            assert key in v2, f"missing key {key!r}"
            # Recurse so nested dicts holding tensors don't hit an ambiguous tensor ==.
            assert_state_dict(value, v2[key], exact=True)
    elif isinstance(v1, (list, tuple)):
        assert isinstance(v2, (list, tuple)) and len(v1) == len(v2), f"{v1!r} != {v2!r}"
        for a, b in zip(v1, v2):
            assert_state_dict(a, b, exact=True)
    else:
        assert v1 == v2, f"{v1!r} != {v2!r}"
# Compare every submodule that owns parameters in the quantized `model` against its
# counterpart in the reloaded `model_qt`. Each submodule is checked once (not once
# per parameter), and differing key lists are an error instead of being truncated
# silently by zip().
checked_modules = set()
for n, _ in model.named_parameters():
    module_key = n.rpartition('.')[0]
    if module_key in checked_modules:
        continue
    checked_modules.add(module_key)

    d1 = model.get_submodule(module_key).state_dict()
    d2 = model_qt.get_submodule(module_key).state_dict()

    assert list(d1) == list(d2), f"state_dict keys differ for {module_key!r}"
    for k, v1 in d1.items():
        assert_state_dict(v1, d2[k])
import safetensors
from safetensors.torch import save_file
import torch
# Load two saved HQQ-LoRA state dicts to see which tensors changed between them.
# Presumably `hqq_lora_dummy_init` is the checkpoint written at initialization and
# `hqq_lora_dummy` the one written after training — confirm against the training script.
# NOTE(review): the displayed keys of `weights` carry FSDP wrapper prefixes
# (`_fsdp_wrapped_module.`, `_checkpoint_wrapped_module.`). The comparison loop below
# indexes `weights` with keys from `weights_init`, so both files must use the same key
# format or that lookup raises KeyError.
weights_init = safetensors.torch.load_file("/weka/home-keremturgutlu/models/hqq_lora_dummy_init/model_state_dict.safetensors")
weights = safetensors.torch.load_file("/weka/home-keremturgutlu/models/hqq_lora_dummy/model_state_dict.safetensors")
weights
{'_fsdp_wrapped_module.model.layers.0._fsdp_wrapped_module._checkpoint_wrapped_module.mlp._fsdp_wrapped_module.down_proj.lora_AB.0.weight': tensor([[-9.1553e-03,  6.0120e-03, -1.9379e-03,  ..., -7.8201e-04,
          -6.0120e-03,  7.2861e-04],
         [ 1.8616e-03,  8.5449e-03,  6.9275e-03,  ..., -1.3885e-03,
           7.6599e-03,  3.2043e-03],
         [ 7.6599e-03,  3.3417e-03,  4.3030e-03,  ...,  4.6082e-03,
          -5.3711e-03, -1.1139e-03],
         ...,
         [-4.0894e-03, -4.3945e-03,  8.1787e-03,  ...,  5.4321e-03,
          -8.4839e-03, -8.4839e-03],
         [-6.6757e-05,  3.9368e-03,  6.0272e-04,  ..., -5.1270e-03,
          -4.8218e-03, -5.3711e-03],
         [ 4.9744e-03,  1.6556e-03, -1.5640e-03,  ...,  4.1504e-03,
           7.7515e-03,  6.8359e-03]], dtype=torch.bfloat16),
 '_fsdp_wrapped_module.model.layers.0._fsdp_wrapped_module._checkpoint_wrapped_module.mlp._fsdp_wrapped_module.down_proj.lora_AB.1.weight': tensor([[-6.2943e-05,  7.9155e-05, -7.9632e-05,  ...,  7.5340e-05,
           7.9632e-05,  7.6294e-05],
         [-6.8665e-05, -7.5817e-05,  7.2002e-05,  ...,  6.6757e-05,
          -7.6771e-05, -7.1526e-05],
         [ 5.6744e-05,  7.1049e-05,  3.7432e-05,  ..., -6.0320e-05,
           7.2956e-05,  6.6757e-05],
         ...,
         [ 7.4387e-05,  8.0109e-05, -8.0109e-05,  ...,  7.5817e-05,
           7.9155e-05,  7.8678e-05],
         [-7.5817e-05, -7.6771e-05, -7.2002e-05,  ..., -2.3365e-05,
          -7.7248e-05, -7.4863e-05],
         [-7.5817e-05, -7.9155e-05,  7.9632e-05,  ..., -7.4387e-05,
          -7.9632e-05, -7.8201e-05]], dtype=torch.bfloat16),
 '_fsdp_wrapped_module.model.layers.0._fsdp_wrapped_module._checkpoint_wrapped_module.mlp._fsdp_wrapped_module.gate_proj.lora_AB.0.weight': tensor([[ 0.0073,  0.0133, -0.0061,  ..., -0.0149, -0.0030, -0.0018],
         [ 0.0068, -0.0081, -0.0049,  ...,  0.0010,  0.0132,  0.0133],
         [ 0.0018,  0.0052,  0.0026,  ..., -0.0033, -0.0059,  0.0154],
         ...,
         [ 0.0055, -0.0043,  0.0087,  ..., -0.0020,  0.0033, -0.0044],
         [-0.0128, -0.0116,  0.0094,  ...,  0.0137,  0.0044, -0.0029],
         [ 0.0077,  0.0098,  0.0051,  ..., -0.0092, -0.0049, -0.0122]],
        dtype=torch.bfloat16),
 '_fsdp_wrapped_module.model.layers.0._fsdp_wrapped_module._checkpoint_wrapped_module.mlp._fsdp_wrapped_module.gate_proj.lora_AB.1.weight': tensor([[ 6.4850e-05,  6.5327e-05,  4.8876e-05,  ..., -6.1512e-05,
          -6.6280e-05,  1.1921e-06],
         [ 7.6294e-05, -7.4387e-05, -7.2002e-05,  ...,  7.8678e-05,
          -7.8678e-05,  6.2466e-05],
         [-5.6744e-05, -3.1710e-05,  2.6226e-05,  ...,  5.3644e-05,
           4.9353e-05,  4.8637e-05],
         ...,
         [ 6.4850e-05,  4.3392e-05, -7.0572e-05,  ...,  7.5817e-05,
          -7.5340e-05,  3.7432e-05],
         [-4.5300e-05, -3.4809e-05,  6.9618e-05,  ..., -7.2956e-05,
           7.2479e-05, -1.7881e-05],
         [ 5.6744e-05, -4.6968e-05, -4.1723e-05,  ...,  6.9141e-05,
          -6.2466e-05, -2.6345e-05]], dtype=torch.bfloat16),
 '_fsdp_wrapped_module.model.layers.0._fsdp_wrapped_module._checkpoint_wrapped_module.mlp._fsdp_wrapped_module.up_proj.lora_AB.0.weight': tensor([[ 0.0087,  0.0010,  0.0009,  ..., -0.0128,  0.0009, -0.0126],
         [-0.0003, -0.0109,  0.0051,  ...,  0.0079,  0.0143,  0.0076],
         [ 0.0022, -0.0090, -0.0013,  ...,  0.0071, -0.0138, -0.0023],
         ...,
         [-0.0103, -0.0153, -0.0061,  ..., -0.0076, -0.0004,  0.0093],
         [ 0.0066,  0.0066, -0.0040,  ...,  0.0046, -0.0043, -0.0063],
         [ 0.0049, -0.0040, -0.0118,  ...,  0.0065,  0.0112,  0.0110]],
        dtype=torch.bfloat16),
 '_fsdp_wrapped_module.model.layers.0._fsdp_wrapped_module._checkpoint_wrapped_module.mlp._fsdp_wrapped_module.up_proj.lora_AB.1.weight': tensor([[-3.7909e-05, -6.8665e-05, -7.6294e-05,  ...,  6.9141e-05,
           6.9618e-05,  7.4387e-05],
         [-6.0081e-05,  7.7724e-05,  7.8678e-05,  ..., -7.5817e-05,
          -7.6771e-05, -7.8201e-05],
         [-6.3419e-05, -6.9618e-05, -7.7248e-05,  ...,  6.9618e-05,
           7.0572e-05,  7.5817e-05],
         ...,
         [-6.2943e-05,  6.1512e-05,  6.5327e-05,  ..., -3.7432e-05,
          -5.5075e-05, -6.2466e-05],
         [ 4.5300e-05, -6.1512e-05, -6.9141e-05,  ...,  5.0068e-05,
           5.7936e-05,  6.5804e-05],
         [-2.7776e-05,  7.1526e-05,  7.6294e-05,  ..., -6.6757e-05,
          -7.1049e-05, -7.4387e-05]], dtype=torch.bfloat16),
 '_fsdp_wrapped_module.model.layers.0._fsdp_wrapped_module._checkpoint_wrapped_module.self_attn._fsdp_wrapped_module.k_proj.lora_AB.0.weight': tensor([[ 9.3994e-03, -5.8594e-03,  1.2085e-02,  ..., -5.6152e-03,
           1.2573e-02, -1.9531e-03],
         [-9.6436e-03,  8.5449e-04,  5.6152e-03,  ..., -1.2207e-04,
          -1.3672e-02,  5.6152e-03],
         [-2.4414e-04, -9.0332e-03,  1.5259e-02,  ..., -7.3242e-03,
           1.2451e-02,  1.4893e-02],
         ...,
         [-1.2085e-02,  1.0620e-02,  1.5503e-02,  ...,  1.1841e-02,
           8.9111e-03, -4.6387e-03],
         [ 1.2573e-02, -8.4229e-03, -1.0376e-02,  ..., -1.3794e-02,
           1.5381e-02,  8.5449e-04],
         [ 3.7842e-03, -8.0566e-03,  9.0804e-08,  ...,  7.5251e-07,
           1.2207e-04, -1.0986e-03]], dtype=torch.bfloat16),
 '_fsdp_wrapped_module.model.layers.0._fsdp_wrapped_module._checkpoint_wrapped_module.self_attn._fsdp_wrapped_module.k_proj.lora_AB.1.weight': tensor([[-2.3842e-05, -2.2292e-05, -3.2187e-05,  ...,  3.5524e-05,
          -3.4094e-05, -1.0967e-05],
         [ 1.1444e-05,  2.6107e-05,  3.2425e-05,  ..., -3.6240e-05,
           3.7193e-05,  1.1206e-05],
         [-2.9325e-05, -1.8954e-05, -3.6955e-05,  ...,  4.2200e-05,
          -3.5048e-05,  3.7402e-06],
         ...,
         [ 3.1948e-05,  4.5300e-06,  1.0192e-05,  ..., -9.8944e-06,
           2.6941e-05,  7.8678e-06],
         [-5.2929e-05, -1.3590e-05, -2.5392e-05,  ...,  3.3855e-05,
          -5.3644e-05, -2.2173e-05],
         [ 3.5286e-05, -1.1623e-06,  1.7524e-05,  ..., -2.5988e-05,
           4.7445e-05,  2.3961e-05]], dtype=torch.bfloat16),
 '_fsdp_wrapped_module.model.layers.0._fsdp_wrapped_module._checkpoint_wrapped_module.self_attn._fsdp_wrapped_module.q_proj.lora_AB.0.weight': tensor([[-0.0155, -0.0035,  0.0033,  ..., -0.0059,  0.0007, -0.0093],
         [ 0.0115, -0.0034,  0.0081,  ...,  0.0051,  0.0127, -0.0049],
         [-0.0087,  0.0144,  0.0103,  ..., -0.0065,  0.0093,  0.0146],
         ...,
         [ 0.0151, -0.0115, -0.0122,  ..., -0.0070, -0.0148, -0.0117],
         [-0.0115, -0.0093, -0.0039,  ..., -0.0133,  0.0023,  0.0063],
         [-0.0115,  0.0020,  0.0040,  ..., -0.0060, -0.0133,  0.0048]],
        dtype=torch.bfloat16),
 '_fsdp_wrapped_module.model.layers.0._fsdp_wrapped_module._checkpoint_wrapped_module.self_attn._fsdp_wrapped_module.q_proj.lora_AB.1.weight': tensor([[ 1.3888e-05, -2.3246e-05,  2.0504e-05,  ...,  3.6478e-05,
           1.6332e-05,  1.6570e-05],
         [-3.5763e-05,  3.3379e-05, -3.5048e-05,  ..., -4.7922e-05,
          -2.2650e-05, -3.0756e-05],
         [ 3.5524e-05, -7.8082e-06,  1.1206e-05,  ...,  2.5749e-05,
          -1.3113e-05,  2.5034e-05],
         ...,
         [-6.2943e-05,  5.7936e-05, -4.1246e-05,  ..., -6.4850e-05,
           3.9339e-05, -6.4373e-05],
         [ 5.0306e-05, -1.9185e-07,  4.5538e-05,  ...,  5.2214e-05,
          -3.9101e-05,  4.6730e-05],
         [ 1.1802e-05,  3.9101e-05, -3.6716e-05,  ..., -5.8651e-05,
          -4.5776e-05, -3.1948e-05]], dtype=torch.bfloat16),
 '_fsdp_wrapped_module.model.layers.0._fsdp_wrapped_module._checkpoint_wrapped_module.self_attn._fsdp_wrapped_module.v_proj.lora_AB.0.weight': tensor([[-0.0106,  0.0065, -0.0109,  ...,  0.0062,  0.0038,  0.0002],
         [-0.0055,  0.0057,  0.0050,  ..., -0.0070, -0.0024, -0.0087],
         [ 0.0095,  0.0143,  0.0037,  ...,  0.0115,  0.0078, -0.0049],
         ...,
         [-0.0072,  0.0030,  0.0105,  ..., -0.0118,  0.0081, -0.0072],
         [-0.0040, -0.0140, -0.0146,  ..., -0.0135, -0.0066, -0.0125],
         [ 0.0120,  0.0150,  0.0098,  ..., -0.0070,  0.0013,  0.0040]],
        dtype=torch.bfloat16),
 '_fsdp_wrapped_module.model.layers.0._fsdp_wrapped_module._checkpoint_wrapped_module.self_attn._fsdp_wrapped_module.v_proj.lora_AB.1.weight': tensor([[ 7.9155e-05, -7.3910e-05, -6.2466e-05,  ...,  7.9632e-05,
           7.8678e-05,  7.9632e-05],
         [ 7.8201e-05, -7.6294e-05,  7.6294e-05,  ...,  7.8678e-05,
           7.8678e-05,  7.9632e-05],
         [-6.9618e-05, -7.8678e-05,  5.3883e-05,  ..., -7.8678e-05,
          -7.9155e-05, -7.9632e-05],
         ...,
         [ 5.9128e-05, -7.6294e-05,  7.0572e-05,  ..., -3.2425e-05,
          -7.6294e-05, -7.6294e-05],
         [ 7.6771e-05,  3.5048e-05,  6.8665e-05,  ...,  7.8678e-05,
           7.6771e-05,  7.9155e-05],
         [-7.9155e-05,  7.2002e-05, -7.4863e-05,  ..., -7.9632e-05,
          -7.9632e-05, -7.9632e-05]], dtype=torch.bfloat16),
 '_fsdp_wrapped_module.model.layers.1._fsdp_wrapped_module._checkpoint_wrapped_module.mlp._fsdp_wrapped_module.down_proj.lora_AB.0.weight': tensor([[-0.0056,  0.0058,  0.0009,  ...,  0.0093,  0.0085, -0.0095],
         [ 0.0070,  0.0086,  0.0059,  ...,  0.0032, -0.0076,  0.0060],
         [-0.0048, -0.0082, -0.0031,  ..., -0.0081,  0.0025,  0.0034],
         ...,
         [-0.0009, -0.0007, -0.0081,  ...,  0.0042,  0.0076,  0.0089],
         [ 0.0038,  0.0073,  0.0059,  ..., -0.0019,  0.0092, -0.0081],
         [ 0.0038,  0.0071, -0.0018,  ...,  0.0075, -0.0034,  0.0079]],
        dtype=torch.bfloat16),
 '_fsdp_wrapped_module.model.layers.1._fsdp_wrapped_module._checkpoint_wrapped_module.mlp._fsdp_wrapped_module.down_proj.lora_AB.1.weight': tensor([[-7.5817e-05,  7.8678e-05, -7.9632e-05,  ...,  7.8678e-05,
          -7.8678e-05,  7.9632e-05],
         [ 7.6294e-05, -7.8678e-05,  7.9632e-05,  ..., -7.9155e-05,
           7.8678e-05, -7.9632e-05],
         [-7.7724e-05,  7.9155e-05, -7.9632e-05,  ...,  7.9632e-05,
          -7.9632e-05,  7.9632e-05],
         ...,
         [-7.8201e-05,  7.9632e-05, -8.0109e-05,  ...,  7.9632e-05,
          -7.9632e-05,  7.9632e-05],
         [ 7.7724e-05, -7.8678e-05,  7.9632e-05,  ..., -7.9632e-05,
           7.8678e-05, -7.9632e-05],
         [ 7.9155e-05, -7.9632e-05,  8.0109e-05,  ..., -7.9632e-05,
           8.0109e-05, -8.0109e-05]], dtype=torch.bfloat16),
 '_fsdp_wrapped_module.model.layers.1._fsdp_wrapped_module._checkpoint_wrapped_module.mlp._fsdp_wrapped_module.gate_proj.lora_AB.0.weight': tensor([[-0.0096, -0.0100,  0.0037,  ..., -0.0073, -0.0101, -0.0040],
         [-0.0025,  0.0040,  0.0065,  ..., -0.0127,  0.0104, -0.0142],
         [-0.0060, -0.0090, -0.0045,  ..., -0.0031,  0.0145,  0.0132],
         ...,
         [ 0.0122, -0.0121,  0.0054,  ...,  0.0054, -0.0125,  0.0112],
         [-0.0071,  0.0063,  0.0035,  ..., -0.0060, -0.0054,  0.0007],
         [ 0.0020,  0.0083, -0.0073,  ..., -0.0084,  0.0153, -0.0142]],
        dtype=torch.bfloat16),
 '_fsdp_wrapped_module.model.layers.1._fsdp_wrapped_module._checkpoint_wrapped_module.mlp._fsdp_wrapped_module.gate_proj.lora_AB.1.weight': tensor([[-7.6294e-05, -7.6771e-05, -7.6294e-05,  ..., -6.8665e-05,
          -7.8678e-05,  7.6294e-05],
         [-6.2466e-05, -6.2943e-05, -5.5075e-05,  ..., -4.1008e-05,
          -7.0095e-05,  6.0558e-05],
         [ 7.1049e-05,  7.2479e-05,  7.2002e-05,  ...,  6.0558e-05,
           7.6771e-05, -7.2002e-05],
         ...,
         [-3.7193e-05, -5.5313e-05, -6.4373e-05,  ..., -3.5286e-05,
          -7.1049e-05,  6.0320e-05],
         [-6.6757e-05, -6.7234e-05, -6.2466e-05,  ..., -4.4584e-05,
          -7.2956e-05,  6.5804e-05],
         [ 3.0249e-06, -2.0504e-05, -4.1723e-05,  ..., -1.6570e-05,
          -5.3167e-05,  3.1233e-05]], dtype=torch.bfloat16),
 '_fsdp_wrapped_module.model.layers.1._fsdp_wrapped_module._checkpoint_wrapped_module.mlp._fsdp_wrapped_module.up_proj.lora_AB.0.weight': tensor([[-0.0005,  0.0094, -0.0146,  ..., -0.0083, -0.0120, -0.0103],
         [ 0.0025, -0.0045, -0.0135,  ...,  0.0118, -0.0095, -0.0140],
         [ 0.0032,  0.0143, -0.0052,  ...,  0.0096, -0.0054, -0.0072],
         ...,
         [-0.0143, -0.0050, -0.0090,  ..., -0.0144, -0.0083, -0.0112],
         [-0.0150,  0.0100,  0.0040,  ...,  0.0137, -0.0118,  0.0140],
         [-0.0010,  0.0009, -0.0063,  ...,  0.0103, -0.0009, -0.0050]],
        dtype=torch.bfloat16),
 '_fsdp_wrapped_module.model.layers.1._fsdp_wrapped_module._checkpoint_wrapped_module.mlp._fsdp_wrapped_module.up_proj.lora_AB.1.weight': tensor([[-7.6294e-05,  6.2943e-05,  4.2677e-05,  ...,  6.9141e-05,
           6.8188e-05,  6.5327e-05],
         [ 7.2002e-05, -4.6492e-05, -3.1948e-05,  ..., -5.7697e-05,
          -5.5552e-05, -5.2929e-05],
         [-7.7248e-05,  6.8665e-05,  6.1035e-05,  ...,  7.1049e-05,
           7.2479e-05,  6.9141e-05],
         ...,
         [ 7.5340e-05, -6.4850e-05, -5.9366e-05,  ..., -5.9843e-05,
          -6.9618e-05, -5.6267e-05],
         [-7.8678e-05,  7.2956e-05,  6.1512e-05,  ...,  7.6294e-05,
           7.5817e-05,  7.5340e-05],
         [-6.3896e-05,  2.2888e-05, -1.5199e-05,  ...,  4.9353e-05,
           2.7776e-05,  4.1962e-05]], dtype=torch.bfloat16),
 '_fsdp_wrapped_module.model.layers.1._fsdp_wrapped_module._checkpoint_wrapped_module.self_attn._fsdp_wrapped_module.k_proj.lora_AB.0.weight': tensor([[-0.0049, -0.0134, -0.0111,  ...,  0.0154,  0.0094,  0.0090],
         [-0.0101, -0.0021, -0.0040,  ...,  0.0038, -0.0110, -0.0116],
         [-0.0076,  0.0057, -0.0142,  ...,  0.0046,  0.0100,  0.0110],
         ...,
         [ 0.0057,  0.0115, -0.0063,  ...,  0.0096,  0.0128,  0.0013],
         [-0.0142, -0.0150, -0.0146,  ...,  0.0126,  0.0061,  0.0038],
         [ 0.0066, -0.0099,  0.0096,  ..., -0.0072,  0.0090, -0.0112]],
        dtype=torch.bfloat16),
 '_fsdp_wrapped_module.model.layers.1._fsdp_wrapped_module._checkpoint_wrapped_module.self_attn._fsdp_wrapped_module.k_proj.lora_AB.1.weight': tensor([[-8.8811e-06,  1.8120e-05, -8.8215e-06,  ..., -1.4424e-05,
           2.6464e-05,  8.2254e-06],
         [ 6.9439e-06, -1.4365e-05,  6.8843e-06,  ...,  1.4782e-05,
          -2.8014e-05, -1.2636e-05],
         [-9.8348e-07, -2.2650e-06, -1.0133e-06,  ..., -1.2591e-06,
           5.5507e-07,  1.9372e-06],
         ...,
         [-5.1975e-05, -7.4387e-05, -6.6280e-05,  ..., -7.2479e-05,
           7.2956e-05,  6.9618e-05],
         [ 5.2452e-05,  7.4387e-05,  6.5327e-05,  ...,  7.2479e-05,
          -7.2956e-05, -6.9618e-05],
         [-5.8889e-05, -7.5340e-05, -6.9141e-05,  ..., -7.4387e-05,
           7.5340e-05,  7.2002e-05]], dtype=torch.bfloat16),
 '_fsdp_wrapped_module.model.layers.1._fsdp_wrapped_module._checkpoint_wrapped_module.self_attn._fsdp_wrapped_module.q_proj.lora_AB.0.weight': tensor([[-0.0050,  0.0043, -0.0043,  ..., -0.0057,  0.0144,  0.0094],
         [-0.0121, -0.0088, -0.0100,  ...,  0.0059, -0.0149, -0.0121],
         [-0.0100,  0.0126, -0.0060,  ..., -0.0100,  0.0118, -0.0099],
         ...,
         [ 0.0122, -0.0095, -0.0039,  ..., -0.0140, -0.0016, -0.0140],
         [-0.0048,  0.0043, -0.0027,  ..., -0.0020, -0.0090, -0.0046],
         [-0.0150, -0.0138, -0.0146,  ...,  0.0029,  0.0095,  0.0100]],
        dtype=torch.bfloat16),
 '_fsdp_wrapped_module.model.layers.1._fsdp_wrapped_module._checkpoint_wrapped_module.self_attn._fsdp_wrapped_module.q_proj.lora_AB.1.weight': tensor([[ 3.4809e-05, -9.7752e-06,  2.8491e-05,  ..., -3.5524e-05,
           1.6928e-05,  4.7445e-05],
         [ 2.4319e-05, -4.1425e-06,  1.8716e-05,  ..., -2.5153e-05,
           8.8215e-06,  3.6001e-05],
         [ 1.1563e-05, -1.8254e-06,  7.8678e-06,  ..., -1.0610e-05,
           2.9206e-06,  1.8239e-05],
         ...,
         [ 7.2956e-05, -3.6240e-05,  6.8665e-05,  ..., -7.0095e-05,
           6.5804e-05,  7.5340e-05],
         [-7.3433e-05,  3.6240e-05, -6.8665e-05,  ...,  7.1049e-05,
          -6.5327e-05, -7.5817e-05],
         [ 7.2002e-05, -3.2187e-05,  6.6280e-05,  ..., -6.8665e-05,
           6.2943e-05,  7.4863e-05]], dtype=torch.bfloat16),
 '_fsdp_wrapped_module.model.layers.1._fsdp_wrapped_module._checkpoint_wrapped_module.self_attn._fsdp_wrapped_module.v_proj.lora_AB.0.weight': tensor([[-3.7537e-03,  1.1108e-02, -3.1281e-04,  ...,  7.4463e-03,
           1.1230e-02, -1.5015e-02],
         [-1.4114e-03,  9.1553e-03,  1.2695e-02,  ..., -8.4229e-03,
           1.2817e-02,  8.0566e-03],
         [-1.5259e-02, -1.5335e-03,  2.9907e-03,  ..., -1.2817e-02,
          -1.4114e-03, -1.2329e-02],
         ...,
         [ 1.0254e-02,  4.5166e-03,  1.2939e-02,  ..., -1.1108e-02,
           6.7139e-03, -1.3062e-02],
         [ 6.4373e-05, -1.0452e-03, -1.0452e-03,  ..., -1.7624e-03,
           6.7139e-03,  1.1841e-02],
         [-5.8594e-03, -1.2329e-02, -1.1841e-02,  ..., -2.2583e-03,
          -3.7384e-03,  9.1553e-03]], dtype=torch.bfloat16),
 '_fsdp_wrapped_module.model.layers.1._fsdp_wrapped_module._checkpoint_wrapped_module.self_attn._fsdp_wrapped_module.v_proj.lora_AB.1.weight': tensor([[-7.9632e-05,  7.9632e-05, -7.9632e-05,  ..., -7.9632e-05,
          -7.9632e-05,  7.8678e-05],
         [-7.5340e-05, -7.7724e-05,  7.9632e-05,  ...,  7.9155e-05,
           7.4387e-05, -6.9141e-05],
         [ 8.0109e-05, -8.0109e-05,  8.0109e-05,  ...,  8.0109e-05,
           8.0109e-05, -8.0109e-05],
         ...,
         [ 7.9632e-05, -8.0109e-05,  7.9632e-05,  ...,  7.9632e-05,
           7.9632e-05, -7.9632e-05],
         [ 7.8678e-05, -7.8678e-05,  7.9632e-05,  ...,  7.9632e-05,
           7.9632e-05, -7.8678e-05],
         [ 7.9632e-05, -8.0109e-05,  7.9632e-05,  ...,  7.9632e-05,
           7.9632e-05, -7.9632e-05]], dtype=torch.bfloat16)}
# Print every tensor whose value differs between the two checkpoints. Keys naming
# base-layer / W_q tensors are compared on their raw bytes (bit-exact); all other
# tensors are compared with plain torch.equal.
for k, v in weights_init.items():
    compare_raw_bytes = 'base_layer' in k or 'W_q' in k
    if compare_raw_bytes:
        unchanged = torch.equal(v.view(torch.uint8), weights[k].view(torch.uint8))
    else:
        unchanged = torch.equal(v, weights[k])
    if not unchanged:
        print("Changed", k)
Changed model.layers.0.mlp.down_proj.lora_AB.0.weight
Changed model.layers.0.mlp.down_proj.lora_AB.1.weight
Changed model.layers.0.mlp.gate_proj.lora_AB.0.weight
Changed model.layers.0.mlp.gate_proj.lora_AB.1.weight
Changed model.layers.0.mlp.up_proj.lora_AB.0.weight
Changed model.layers.0.mlp.up_proj.lora_AB.1.weight
Changed model.layers.0.self_attn.k_proj.lora_AB.0.weight
Changed model.layers.0.self_attn.k_proj.lora_AB.1.weight
Changed model.layers.0.self_attn.q_proj.lora_AB.0.weight
Changed model.layers.0.self_attn.q_proj.lora_AB.1.weight
Changed model.layers.0.self_attn.v_proj.lora_AB.0.weight
Changed model.layers.0.self_attn.v_proj.lora_AB.1.weight
Changed model.layers.1.mlp.down_proj.lora_AB.0.weight
Changed model.layers.1.mlp.down_proj.lora_AB.1.weight
Changed model.layers.1.mlp.gate_proj.lora_AB.0.weight
Changed model.layers.1.mlp.gate_proj.lora_AB.1.weight
Changed model.layers.1.mlp.up_proj.lora_AB.0.weight
Changed model.layers.1.mlp.up_proj.lora_AB.1.weight
Changed model.layers.1.self_attn.k_proj.lora_AB.0.weight
Changed model.layers.1.self_attn.k_proj.lora_AB.1.weight
Changed model.layers.1.self_attn.q_proj.lora_AB.0.weight
Changed model.layers.1.self_attn.q_proj.lora_AB.1.weight
Changed model.layers.1.self_attn.v_proj.lora_AB.0.weight
Changed model.layers.1.self_attn.v_proj.lora_AB.1.weight