{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e28fb85c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install gputil\n",
    "# %pip install setuptools\n",
    "# %pip install transformers\n",
    "# %pip install torch\n",
    "\n",
    "# %pip install auto-gptq #==0.4.0"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10e0a35a",
   "metadata": {},
   "source": [
    "What happens if you rotate an entire LLM? Will it still work if you consistently rotate all trained matrices?\n",
    "\n",
    "Doing this is very specific to the internal representations of a particular LLM. Different models have very different internal layers and representations. Layers may have different shapes or may be concatenated (such as the fused query/key/value matrices).\n",
    "\n",
    "Should all matrices be rotated, and which should be conjugated?\n",
    "\n",
    "This notebook only offers some base code; it is still far removed from the right approach."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0667e71a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/mick/pycharmprojects/Frankenstein/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import GPUtil\n",
    "\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline\n",
    "import torch\n",
    "# from auto_gptq import AutoGPTQForCausalLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0273f299",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No GPU detected on this system.\n"
     ]
    }
   ],
   "source": [
    "gpus = GPUtil.getGPUs()\n",
    "if not gpus:\n",
    "    print(\"No GPU detected on this system.\")\n",
    "else:\n",
    "    for gpu in gpus:\n",
    "        print(f\"GPU Name: {gpu.name}\")\n",
    "        print(f\"Total VRAM: {gpu.memoryTotal} MB\")\n",
    "        print(f\"Free VRAM: {gpu.memoryFree} MB\")\n",
    "        print(f\"Used VRAM: {gpu.memoryUsed} MB\")\n",
    "        print(\"-\" * 40)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "67d7e006",
   "metadata": {},
   "outputs": [],
   "source": [
    "def grab_model(model_name, quantized=False):\n",
    "    \"\"\"Load a causal language model and its tokenizer by name.\n",
    "\n",
    "    Args:\n",
    "        model_name: Name passed to ``from_pretrained`` / ``from_quantized``, e.g. \"gpt2\".\n",
    "        quantized: If True, load a GPTQ checkpoint on CPU through auto-gptq, which is an\n",
    "            optional dependency (see the commented-out install cell).\n",
    "\n",
    "    Returns:\n",
    "        A ``(model, tokenizer)`` tuple.\n",
    "    \"\"\"\n",
    "    if quantized:\n",
    "        # Imported lazily: auto-gptq is optional and its top-level import is commented out,\n",
    "        # so without this line the quantized branch raised NameError.\n",
    "        from auto_gptq import AutoGPTQForCausalLM\n",
    "\n",
    "        model = AutoGPTQForCausalLM.from_quantized(model_name, device=\"cpu\", use_safetensors=True)\n",
    "    else:\n",
    "        model = AutoModelForCausalLM.from_pretrained(model_name)\n",
    "\n",
    "    tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "    return model, tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "153e9ff5",
   "metadata": {},
   "outputs": [],
   "source": [
    "modelA, tokenizerA = grab_model(\"gpt2\")\n",
    "modelB, tokenizerB = grab_model(\"EleutherAI/gpt-neo-125M\")\n",
    "\n",
    "# modelA, tokenizerA = grab_model(\"EleutherAI/gpt-neo-125M-4bit\", quantized=True)\n",
    "# modelB, tokenizerB = grab_model(\"iproskurina/opt-125m-GPTQ-4bit-g128\", quantized=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1da291ed",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "modelA.config.hidden_size == modelB.config.hidden_size "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c62b2f41",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(False)\n",
      "tensor(22.2842, grad_fn=)\n",
      "tensor(11.5013, grad_fn=)\n"
     ]
    }
   ],
   "source": [
    "print(torch.isnan(modelB.get_input_embeddings().weight).any())\n",
    "print(torch.norm(modelB.get_input_embeddings().weight, dim=1).max())\n",
    "print(torch.norm(modelB.get_input_embeddings().weight, dim=1).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2b9893a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_orthogonal(R):\n",
    "    \"\"\"Print and return ||R^T R - I|| (Frobenius norm), i.e. how far R is from orthogonal.\n",
    "\n",
    "    A value close to zero means R is numerically orthogonal.\n",
    "    \"\"\"\n",
    "    I = torch.eye(R.size(0), device=R.device, dtype=R.dtype)\n",
    "    delta = torch.norm(R.T @ R - I)\n",
    "    print(f\"Delta: {delta:.6e}\")\n",
    "    return delta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e1a54c24",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Orthogonal Procrustes alignment:\n",
    "def procrustes(A: torch.Tensor, B: torch.Tensor, center: bool = False) -> torch.Tensor:\n",
    "    \"\"\"Return the orthogonal matrix R that minimises ||B @ R - A|| (Frobenius norm).\n",
    "\n",
    "    Rows of A and B are paired one-to-one (in this notebook: the same token id in two\n",
    "    embedding tables), and both must have the same number of columns ``dim``.\n",
    "\n",
    "    Args:\n",
    "        A: Target matrix, shape (n, dim).\n",
    "        B: Matrix to be rotated, shape (n, dim).\n",
    "        center: If True, subtract each matrix's column means before fitting so only a\n",
    "            rotation (no translation) is estimated. Defaults to False (original behaviour).\n",
    "\n",
    "    Returns:\n",
    "        The (dim, dim) rotation matrix R -- not the rotated tensor; use ``B @ R`` for that.\n",
    "    \"\"\"\n",
    "    if center:\n",
    "        A = A - A.mean(dim=0, keepdim=True)\n",
    "        B = B - B.mean(dim=0, keepdim=True)\n",
    "\n",
    "    # With B^T A = U S Vt, the optimum is R = U @ Vt (closest orthogonal matrix to B^T A).\n",
    "    M = B.T @ A\n",
    "    U, _, Vt = torch.linalg.svd(M)\n",
    "    R = U @ Vt\n",
    "\n",
    "    check_orthogonal(R)\n",
    "    return R"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ff93495e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Delta: 6.706436e-05\n",
      "torch.Size([1024, 768])\n"
     ]
    }
   ],
   "source": [
    "emb1 = modelA.get_input_embeddings().weight\n",
    "emb2 = modelB.get_input_embeddings().weight\n",
    "\n",
    "# This cell rotates modelA in place, so a second run would apply R twice.\n",
    "# Refuse to re-run; reload modelA with grab_model() first.\n",
    "if getattr(modelA, \"_frankenstein_rotated\", False):\n",
    "    raise RuntimeError(\"modelA is already rotated; reload it with grab_model() before re-running this cell.\")\n",
    "\n",
    "# get rotation matrix R such that emb1 @ R ~ emb2 (modelA's embedding space aligned onto modelB's)\n",
    "R = procrustes(emb2, emb1)\n",
    "emb1_R = emb1 @ R\n",
    "\n",
    "# everything below mutates modelA in place\n",
    "modelA._frankenstein_rotated = True\n",
    "\n",
    "new_embedding = torch.nn.Embedding.from_pretrained(emb1_R)\n",
    "\n",
    "modelA.set_input_embeddings(new_embedding)\n",
    "modelA.lm_head.weight = new_embedding.weight  # keep the output head tied to the rotated input embedding\n",
    "\n",
    "# def rotate_weight(W, R):\n",
    "#     if W.shape[1] == R.shape[0]:\n",
    "#         return W @ R\n",
    "#     if W.shape[0] == R.shape[0]:\n",
    "#         return R.T @ W\n",
    "\n",
    "# now fix the other layers by conjugating:\n",
    "# for block in modelA.transformer.h:\n",
    "#     for M in [block.attn.c_attn, block.mlp.c_fc]:\n",
    "#         W = M.weight.data\n",
    "#         W[:] = R.T @ W\n",
    "#     for M in [block.attn.c_proj, block.mlp.c_proj]:\n",
    "#         W = M.weight.data\n",
    "#         W[:] = R.T @ W @ R\n",
    "\n",
    "# NOTE(review): the tile conjugation below is not a consistent change of basis.\n",
    "# Assuming HF GPT-2's Conv1D computes x @ W + b with W shaped (in, out) (verify against\n",
    "# transformers' Conv1D), rotating the residual stream x -> x @ R is compensated by\n",
    "#   - layers reading the stream (attn.c_attn, mlp.c_fc):  W -> R.T @ W\n",
    "#   - layers writing the stream (attn.c_proj, mlp.c_proj): W -> W @ R, b -> b @ R\n",
    "# Conjugating every 768x768 tile also rotates the q/k/v and MLP-hidden bases, and the\n",
    "# c_proj biases are left unrotated. LayerNorm (mean-centering + elementwise gain) is not\n",
    "# rotation-equivariant either, so its parameters would have to be folded into the\n",
    "# adjacent linear layers before a pure weight rotation could preserve the model's outputs.\n",
    "\n",
    "def split_rotate_concat(W, rotation=None, block_size=None):\n",
    "    \"\"\"Conjugate every (block_size x block_size) tile of W: tile -> rotation.T @ tile @ rotation.\n",
    "\n",
    "    Args:\n",
    "        W: 2-D tensor whose dimensions are both multiples of block_size.\n",
    "        rotation: Square rotation matrix; defaults to the notebook-global ``R``.\n",
    "        block_size: Tile edge length; defaults to ``rotation.shape[0]`` (768 for these models).\n",
    "\n",
    "    Returns:\n",
    "        A new tensor with the same shape as W.\n",
    "    \"\"\"\n",
    "    if rotation is None:\n",
    "        rotation = R\n",
    "    if block_size is None:\n",
    "        block_size = rotation.shape[0]\n",
    "\n",
    "    column_strips = [\n",
    "        torch.cat([rotation.T @ tile @ rotation for tile in strip.split(block_size, dim=0)], dim=0)\n",
    "        for strip in W.split(block_size, dim=1)\n",
    "    ]\n",
    "    return torch.cat(column_strips, dim=1)\n",
    "\n",
    "\n",
    "def rotate_layernorm(ln, rotation=None):\n",
    "    \"\"\"Rotate a LayerNorm's weight and bias vectors in place: v -> v @ rotation.\n",
    "\n",
    "    ``rotation`` defaults to the notebook-global ``R``.\n",
    "    NOTE(review): LayerNorm applies its weight elementwise, so rotating it as a vector is\n",
    "    not equivalent to rotating the layer's output.\n",
    "    \"\"\"\n",
    "    if rotation is None:\n",
    "        rotation = R\n",
    "    ln.weight.data[:] = ln.weight.data @ rotation\n",
    "    ln.bias.data[:] = ln.bias.data @ rotation\n",
    "\n",
    "\n",
    "for block in modelA.transformer.h:\n",
    "    # print(block.attn.c_attn.weight.data.shape)\n",
    "    # print(block.mlp.c_fc.weight.data.shape)\n",
    "    # print(block.attn.c_proj.weight.data.shape)\n",
    "    # print(block.mlp.c_proj.weight.data.shape)\n",
    "    # block.attn.c_attn.weight.data[:] = split_rotate_concat(block.attn.c_attn.weight.data.T).T\n",
    "    # block.mlp.c_fc.weight.data[:] = split_rotate_concat(block.mlp.c_fc.weight.data.T).T\n",
    "    block.attn.c_attn.weight.data[:] = split_rotate_concat(block.attn.c_attn.weight.data)\n",
    "    block.mlp.c_fc.weight.data[:] = split_rotate_concat(block.mlp.c_fc.weight.data)\n",
    "    block.attn.c_proj.weight.data[:] = split_rotate_concat(block.attn.c_proj.weight.data)\n",
    "    block.mlp.c_proj.weight.data[:] = split_rotate_concat(block.mlp.c_proj.weight.data)\n",
    "    rotate_layernorm(block.ln_1)\n",
    "    rotate_layernorm(block.ln_2)\n",
    "\n",
    "rotate_layernorm(modelA.transformer.ln_f)\n",
    "\n",
    "# wpe is presumably added to the token embeddings, so it gets the same rotation\n",
    "print(modelA.transformer.wpe.weight.data.shape)\n",
    "modelA.transformer.wpe.weight.data[:] = modelA.transformer.wpe.weight.data @ R"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d8d9d612",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Device set to use cpu\n",
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'generated_text': 'Hello, how are you?orm Coulorm Coulormorm Coulão Coulorm Coulorm Coulorm Coulorm Coulorm Coulonomousonomousonomousonomous Coulonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomous Coulonomousonomousonomousonomousonomousonomous Coulonomousonomousonomousonomousonomous Coulonomousonomousonomousonomousonomousonomousonomousonomousonomous Coulonomousonomousonomousonomousonomousonomous Coulonomousonomousonomousonomousonomousonomousonomousonomousonomousonomous Coulonomousonomousonomous Coulonomousonomousonomousonomousonomous Coulonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousidableonomousonomousonomousato Coulonomousonomousonomous Coulonomousonomousonomouskeleyonomous Coulorm Amenomniaifulstad Amenonomousonomous Amenstad Amenomnia Amenomniaifulomniaomniaormstadifulifulomnia Coulonomousifulomniaomniaomniaifulomniaomnia Coulifulomniahered Coul Amenomniakeleyomniaomniastadifulomnia Amenomniaomniaomniakeleyomniaomniaomniastad…]omniaomnia Coulkeleyomniaomniaomnia Coulomniakeleyomnia Coulomniaomniaomniaifulomniaomniaomniakeleyomniaomniaomniaomniaomniaomniaomniaomniastadomniaomniaomnia CoulkeleyomniaomniaomniaomniaomniaomniaomniaomniaomniaomniaomniaomniaomniaomniaomniaomniaNRSormkeleyomniaomniaomniaomniaomniaomniaomniaomniaomniaomniaomniaomniaomnia'}]\n"
     ]
    }
   ],
   "source": [
    "# use model: the rotated GPT-2 weights (modelA) are paired with modelB's tokenizer\n",
    "pipe = pipeline(\"text-generation\", model=modelA, tokenizer=tokenizerB)\n",
    "print(pipe(\"Hello, how are you?\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}