{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "e28fb85c",
"metadata": {},
"outputs": [],
"source": [
"# %pip install gputil\n",
"# %pip install setuptools\n",
"# %pip install transformers\n",
"# %pip install torch\n",
"\n",
"# %pip install auto-gptq #==0.4.0"
]
},
{
"cell_type": "markdown",
"id": "10e0a35a",
"metadata": {},
"source": [
"What happens if you try to rotate an entire LMM model. Will it still work if you consistently rotate all trained matrices?\n",
"\n",
"Doing this is very specific to the internal representations of a particular LMM. Different models have very different internal layers and representations. Layers may have different shapes, or are concatenated (such as the kvq matrices). \n",
"\n",
"Should all matrices be rotated, and which should be conjugated? \n",
"\n",
"This notebook just offers some base code, it's still far removed from the right approach."
]
},
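{
"cell_type": "markdown",
"id": "a1f00b2c",
"metadata": {},
"source": [
"Before touching a real model, a minimal sketch of the identity the experiment rests on (the shapes here are made up for the demo): for an orthogonal $Q$, rotating an input as $x \\to xQ$ and conjugating a square weight as $W \\to Q^T W Q$ keeps the computation consistent in the rotated basis, since $(xQ)(Q^T W Q) = (xW)Q$. Nonlinearities, biases and LayerNorm are ignored here."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1f00b2d",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"d = 8\n",
"x = torch.randn(3, d)  # a batch of hidden states\n",
"W = torch.randn(d, d)  # a square weight acting on the hidden dim\n",
"\n",
"# random orthogonal matrix via QR decomposition\n",
"Q, _ = torch.linalg.qr(torch.randn(d, d))\n",
"\n",
"# rotate the input, conjugate the weight\n",
"x_rot = x @ Q\n",
"W_conj = Q.T @ W @ Q\n",
"\n",
"# (xQ)(Q^T W Q) == (xW)Q, so the rotated pipeline tracks the original\n",
"print(torch.allclose(x_rot @ W_conj, (x @ W) @ Q, atol=1e-5))"
]
},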
{
"cell_type": "code",
"execution_count": 1,
"id": "0667e71a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/mick/pycharmprojects/Frankenstein/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import GPUtil\n",
"\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline\n",
"import torch\n",
"# from auto_gptq import AutoGPTQForCausalLM"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "0273f299",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"No GPU detected on this system.\n"
]
}
],
"source": [
"gpus = GPUtil.getGPUs()\n",
"if not gpus:\n",
" print(\"No GPU detected on this system.\")\n",
"else:\n",
" for gpu in gpus:\n",
" print(f\"GPU Name: {gpu.name}\")\n",
" print(f\"Total VRAM: {gpu.memoryTotal} MB\")\n",
" print(f\"Free VRAM: {gpu.memoryFree} MB\")\n",
" print(f\"Used VRAM: {gpu.memoryUsed} MB\")\n",
" print(\"-\" * 40)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "67d7e006",
"metadata": {},
"outputs": [],
"source": [
"def grab_model(model_name, quantized = False):\n",
" if quantized:\n",
" model = AutoGPTQForCausalLM.from_quantized(model_name, device=\"cpu\", use_safetensors=True)\n",
" else:\n",
" model = AutoModelForCausalLM.from_pretrained(model_name)\n",
"\n",
" tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
" return model, tokenizer"
]
},
{
"cell_type": "code",
"execution_count": 63,
"id": "153e9ff5",
"metadata": {},
"outputs": [],
"source": [
"modelA, tokenizerA = grab_model(\"gpt2\")\n",
"modelB, tokenizerB = grab_model(\"EleutherAI/gpt-neo-125M\")\n",
"\n",
"# modelA, tokenizerA = grab_model(\"EleutherAI/gpt-neo-125M-4bit\", quantized=True)\n",
"# modelB, tokenizerB = grab_model(\"iproskurina/opt-125m-GPTQ-4bit-g128\", quantized=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "1da291ed",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"modelA.config.hidden_size == modelB.config.hidden_size "
]
},
{
"cell_type": "code",
"execution_count": 113,
"id": "c62b2f41",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(False)\n",
"tensor(22.2842, grad_fn=<MaxBackward1>)\n",
"tensor(11.5013, grad_fn=<MeanBackward0>)\n"
]
}
],
"source": [
"print(torch.isnan(modelB.get_input_embeddings().weight).any())\n",
"print(torch.norm(modelB.get_input_embeddings().weight, dim=1).max())\n",
"print(torch.norm(modelB.get_input_embeddings().weight, dim=1).mean())"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "2b9893a3",
"metadata": {},
"outputs": [],
"source": [
"def check_orthogonal(R):\n",
" I = torch.eye(R.size(0), device=R.device)\n",
" delta = torch.norm(R.T @ R - I)\n",
" print(f\"Delta: {delta:.6e}\")\n",
" "
]
},
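{
"cell_type": "markdown",
"id": "b7c3e901",
"metadata": {},
"source": [
"A quick sanity check of `check_orthogonal` (the sizes are arbitrary): a QR-derived matrix should give a tiny delta, while a plain random matrix should not."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b7c3e902",
"metadata": {},
"outputs": [],
"source": [
"# Q from a QR decomposition is orthogonal; a raw randn matrix is not\n",
"Q, _ = torch.linalg.qr(torch.randn(768, 768))\n",
"check_orthogonal(Q)  # delta near float32 precision\n",
"check_orthogonal(torch.randn(768, 768))  # large delta"
]
},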
{
"cell_type": "code",
"execution_count": null,
"id": "e1a54c24",
"metadata": {},
"outputs": [],
"source": [
"# use proscrustes:\n",
"def procrustes(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:\n",
" # A_centered = A - A.mean(dim=0, keepdim=True)\n",
" # B_centered = B - B.mean(dim=0, keepdim=True)\n",
"\n",
" #M = B_centered.T @ A_centered\n",
" M = B.T @ A\n",
" # find optimal rotation with svd\n",
" U, _, Vt = torch.linalg.svd(M)\n",
"\n",
" # get rotation matrix that aligns B to A\n",
" R = U @ Vt\n",
"\n",
" check_orthogonal(R)\n",
" \n",
" return R # return rotated tensor\n"
]
},
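{
"cell_type": "markdown",
"id": "c4d5a610",
"metadata": {},
"source": [
"A sanity check of `procrustes` on synthetic data (the dimensions are made up): if $B$ is an exact rotation of $A$, the recovered matrix should map $B$ back onto $A$."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c4d5a611",
"metadata": {},
"outputs": [],
"source": [
"torch.manual_seed(0)\n",
"\n",
"A = torch.randn(200, 64)\n",
"Q, _ = torch.linalg.qr(torch.randn(64, 64))\n",
"B = A @ Q  # B is A expressed in a rotated basis\n",
"\n",
"R_test = procrustes(A, B)\n",
"print(torch.allclose(B @ R_test, A, atol=1e-4))  # should print True"
]
},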
{
"cell_type": "code",
"execution_count": null,
"id": "ff93495e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Delta: 6.706436e-05\n",
"torch.Size([1024, 768])\n"
]
}
],
"source": [
"emb1 = modelA.get_input_embeddings().weight\n",
"emb2 = modelB.get_input_embeddings().weight\n",
"\n",
"# get rotation matrix\n",
"R = procrustes(emb2, emb1)\n",
"emb1_R = emb1 @ R\n",
"\n",
"new_embedding = torch.nn.Embedding.from_pretrained(emb1_R)\n",
"\n",
"modelA.set_input_embeddings(new_embedding)\n",
"modelA.lm_head.weight = new_embedding.weight\n",
"\n",
"# def rotate_weight(W, R):\n",
"# if W.shape[1] == R.shape[0]:\n",
"# return W @ R\n",
"# if W.shape[0] == R.shape[0]:\n",
"# return R.T @ W\n",
"\n",
"# now fix the other layers by conjugating:\n",
"# for block in modelA.transformer.h:\n",
"# for M in [block.attn.c_attn, block.mlp.c_fc]:\n",
"# W = M.weight.data\n",
"# W[:] = R.T @ W\n",
"# for M in [block.attn.c_proj, block.mlp.c_proj]:\n",
"# W = M.weight.data\n",
"# W[:] = R.T @ W @ R\n",
"\n",
"def split_rotate_concat(W):\n",
" parts1 = [x for x in W.split(768, dim=1)]\n",
" for i, v in enumerate(parts1):\n",
" parts2 = [x for x in v.split(768, dim=0)]\n",
" for j, w in enumerate(parts2):\n",
" parts2[j] = R.T @ w @ R\n",
" parts1[i] = torch.cat(parts2, dim=0)\n",
" return torch.cat(parts1, dim=1)\n",
"\n",
"\n",
"def rotate_layernorm(ln):\n",
" ln.weight.data[:] = ln.weight.data @ R\n",
" ln.bias.data[:] = ln.bias.data @ R\n",
"\n",
"for block in modelA.transformer.h:\n",
" # print(block.attn.c_attn.weight.data.shape)\n",
" # print(block.mlp.c_fc.weight.data.shape)\n",
" # print(block.attn.c_proj.weight.data.shape)\n",
" # print(block.mlp.c_proj.weight.data.shape)\n",
" # block.attn.c_attn.weight.data[:] = split_rotate_concat(block.attn.c_attn.weight.data.T).T\n",
" # block.mlp.c_fc.weight.data[:] = split_rotate_concat(block.mlp.c_fc.weight.data.T).T\n",
" block.attn.c_attn.weight.data[:] = split_rotate_concat(block.attn.c_attn.weight.data)\n",
" block.mlp.c_fc.weight.data[:] = split_rotate_concat(block.mlp.c_fc.weight.data)\n",
" block.attn.c_proj.weight.data[:] = split_rotate_concat(block.attn.c_proj.weight.data)\n",
" block.mlp.c_proj.weight.data[:] = split_rotate_concat(block.mlp.c_proj.weight.data)\n",
" rotate_layernorm(block.ln_1)\n",
" rotate_layernorm(block.ln_2)\n",
"\n",
"rotate_layernorm(modelA.transformer.ln_f)\n",
"\n",
"\n",
"print(modelA.transformer.wpe.weight.data.shape)\n",
"modelA.transformer.wpe.weight.data[:] = modelA.transformer.wpe.weight.data @ R\n",
"\n",
" # for name in ['c_attn', 'c_proj']:\n",
" # W = getattr(block.attn, name).weight.data\n",
" # W[:] = R.T @ W @ R\n",
" # w1 = block.mlp.c_fc.weight.data\n",
" # w2 = block.mlp.c_proj.weight.data\n",
" # w1[:] = R.T @ W1 @ R\n",
" # w2[:] = R.T @ W2 @ R\n",
" \n"
]
},
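{
"cell_type": "markdown",
"id": "d9e8f712",
"metadata": {},
"source": [
"One hedged guess at why the rotated model degenerates into gibberish (see the generation below): LayerNorm subtracts a per-token mean over the hidden dimension, and that step does not commute with an arbitrary rotation, so rotating the LayerNorm affine parameters alone cannot make a block equivariant. A quick numerical check (shapes arbitrary):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d9e8f713",
"metadata": {},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"\n",
"torch.manual_seed(0)\n",
"d = 768\n",
"x = torch.randn(4, d)\n",
"Q, _ = torch.linalg.qr(torch.randn(d, d))\n",
"\n",
"# normalize-then-rotate vs. rotate-then-normalize\n",
"ln_then_rot = F.layer_norm(x, (d,)) @ Q\n",
"rot_then_ln = F.layer_norm(x @ Q, (d,))\n",
"\n",
"print(torch.allclose(ln_then_rot, rot_then_ln, atol=1e-4))  # False: LN is not rotation-equivariant\n",
"print((ln_then_rot - rot_then_ln).abs().max())"
]
},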
{
"cell_type": "code",
"execution_count": 66,
"id": "d8d9d612",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Device set to use cpu\n",
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[{'generated_text': 'Hello, how are you?orm Coulormormorm Coulorm Coul Coulorm Coul Coulinion Coulorm Coulonomousonomous Coulonomousonomousonomousonomousonomousormonomous Coulorm Coulonomousonomousonomous Coulonomousonomous Coulonomousonomousonomous Coulonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomous Coulonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomous Coulonomousonomousonomousonomousonomousonomousonomous Coulonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomous Amenonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomous Coulonomous Coulonomousonomousonomousonomousonomousonomousonomoushered…] Coulonomousonomousonomous Amenomniaifulonomousonomouskeleyifulonomous Amenomniaifulhered Amenkeleyomniastad Coulonomousifulifulomniaifulomniaifulomniaifulifulifulomniaifulomnia…]hered…]ifulomniaifulifulomniastadkeleyomniaifulifulomniaifulomniaifulomniakeleyomniaomniaomnia Coulomniaifulomnia Coulifulomnia Coul Coulkeleyomniastad Coulomnia Coulkeleyomnia Coulkeleyomnia Coulkeleyomniaomnia Coulkeleyomniaomniaomniaomniastadomniaomniaomniaomnia Coulkeleyonomousomnia Coulomniaomniaomnia Coulkeleyomnia Coulomniaomniaomniaomnia Coulomniaomniakeleyomniakeleyomniakeleyomniaomniaomniaomniakeleystadkeleyomniakeleyomniaomnia'}]\n"
]
}
],
"source": [
"# use model\n",
"pipe = pipeline(\"text-generation\", model=modelA, tokenizer=tokenizerB)\n",
"print(pipe(\"Hello, how are you?\"))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}