{ "cells": [ { "cell_type": "code", "execution_count": 53, "id": "e28fb85c", "metadata": {}, "outputs": [], "source": [ "# %pip install gputil\n", "# %pip install setuptools\n", "# %pip install transformers\n", "# %pip install torch\n", "\n", "# %pip install auto-gptq #==0.4.0" ] },
{ "cell_type": "markdown", "id": "10e0a35a", "metadata": {}, "source": [ "What happens if you rotate the input embedding and then rotate it back just before the first activation function of the first MLP block?\n", "\n", "In principle that should work, but GPT-2 ties its input and output embeddings, so they have to be untied first; otherwise rotating the input embedding would rotate the output head as well.\n", "\n", "**Caveat:** the first MLP is not the first thing applied to the embedding. In each GPT-2 block the hidden state first passes through `ln_1` (a LayerNorm: it subtracts the per-token mean and applies an elementwise scale and bias, so it does not commute with an arbitrary rotation) and the attention sub-block, and the residual stream carries the rotated vectors into every later layer. Undoing the rotation only inside `h[0].mlp.c_fc` therefore cannot restore the original computation.\n" ] },
{ "cell_type": "code", "execution_count": 54, "id": "0667e71a", "metadata": {}, "outputs": [], "source": [ "import GPUtil\n", "\n", "from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline\n", "import torch\n", "# auto-gptq is optional; grab_model imports it lazily when quantized=True.\n", "# from auto_gptq import AutoGPTQForCausalLM" ] },
{ "cell_type": "code", "execution_count": 55, "id": "0273f299", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "No GPU detected on this system.\n" ] } ], "source": [ "gpus = GPUtil.getGPUs()\n", "if not gpus:\n", "    print(\"No GPU detected on this system.\")\n", "else:\n", "    for gpu in gpus:\n", "        print(f\"GPU Name: {gpu.name}\")\n", "        print(f\"Total VRAM: {gpu.memoryTotal} MB\")\n", "        print(f\"Free VRAM: {gpu.memoryFree} MB\")\n", "        print(f\"Used VRAM: {gpu.memoryUsed} MB\")\n", "        print(\"-\" * 40)" ] },
{ "cell_type": "code", "execution_count": 56, "id": "67d7e006", "metadata": {}, "outputs": [], "source": [ "def grab_model(model_name, quantized=False):\n", "    \"\"\"Load a causal language model and its tokenizer from the Hugging Face Hub.\n", "\n", "    Args:\n", "        model_name: Hub model id, e.g. \"gpt2\".\n", "        quantized: if True, load a GPTQ-quantized checkpoint on CPU with auto-gptq.\n", "\n", "    Returns:\n", "        A (model, tokenizer) tuple.\n", "    \"\"\"\n", "    if quantized:\n", "        # auto-gptq is an optional dependency (its top-level import is commented out),\n", "        # so it is imported only on this code path.\n", "        from auto_gptq import AutoGPTQForCausalLM\n", "\n", "        model = AutoGPTQForCausalLM.from_quantized(model_name, device=\"cpu\", use_safetensors=True)\n", "    else:\n", "        model = AutoModelForCausalLM.from_pretrained(model_name)\n", "\n", "    tokenizer = AutoTokenizer.from_pretrained(model_name)\n", "    return model, tokenizer" ] },
{ "cell_type": "code", "execution_count": 57, "id": "153e9ff5", "metadata": {}, "outputs": [], "source": [ 
"modelA, tokenizerA = grab_model(\"gpt2\")\n", "modelB, tokenizerB = grab_model(\"EleutherAI/gpt-neo-125M\")\n", "\n", "# modelA, tokenizerA = grab_model(\"EleutherAI/gpt-neo-125M-4bit\", quantized=True)\n", "# modelB, tokenizerB = grab_model(\"iproskurina/opt-125m-GPTQ-4bit-g128\", quantized=True)" ] },
{ "cell_type": "code", "execution_count": 58, "id": "1da291ed", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 58, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The Procrustes alignment below needs both embedding tables to have the same width.\n", "modelA.config.hidden_size == modelB.config.hidden_size" ] },
{ "cell_type": "code", "execution_count": 59, "id": "c62b2f41", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(False)\n", "tensor(22.2842, grad_fn=)\n", "tensor(11.5013, grad_fn=)\n" ] } ], "source": [ "# Sanity-check GPT-Neo's token embedding: any NaNs, then the max and mean L2 norm of its rows.\n", "wte_B = modelB.get_input_embeddings().weight\n", "row_norms_B = torch.norm(wte_B, dim=1)\n", "print(torch.isnan(wte_B).any())\n", "print(row_norms_B.max())\n", "print(row_norms_B.mean())" ] },
{ "cell_type": "code", "execution_count": 60, "id": "2b9893a3", "metadata": {}, "outputs": [], "source": [ "def check_orthogonal(R):\n", "    \"\"\"Print and return ||R^T R - I||_F, i.e. how far the square matrix R is from orthogonal.\"\"\"\n", "    # Build the identity in R's dtype/device so the subtraction does not mix precisions.\n", "    I = torch.eye(R.size(0), device=R.device, dtype=R.dtype)\n", "    delta = torch.norm(R.T @ R - I)\n", "    print(f\"Delta: {delta:.6e}\")\n", "    return delta.item()" ] },
{ "cell_type": "code", "execution_count": 61, "id": "e1a54c24", "metadata": {}, "outputs": [], "source": [ "# Orthogonal Procrustes alignment\n", "def procrustes(A: torch.Tensor, B: torch.Tensor, center: bool = False) -> torch.Tensor:\n", "    \"\"\"Find the orthogonal matrix R that minimises ||B @ R - A||_F.\n", "\n", "    Args:\n", "        A: target matrix of shape (n, d); row i is paired with row i of B.\n", "        B: source matrix of the same shape (n, d).\n", "        center: if True, subtract the column means of A and B first, so R aligns the\n", "            centred point clouds (a translation between them is not modelled by R).\n", "\n", "    Returns:\n", "        The (d, d) orthogonal matrix R, computed without tracking gradients. R may be\n", "        a reflection (det(R) == -1) rather than a proper rotation.\n", "\n", "    Raises:\n", "        ValueError: if A and B are not 2-D tensors of the same shape.\n", "    \"\"\"\n", "    if A.dim() != 2 or A.shape != B.shape:\n", "        raise ValueError(f\"procrustes expects two 2-D tensors of equal shape, got {tuple(A.shape)} and {tuple(B.shape)}\")\n", "\n", "    # The alignment is a fixed preprocessing step, so keep the SVD out of the autograd graph.\n", "    with torch.no_grad():\n", "        if center:\n", "            A = A - A.mean(dim=0, keepdim=True)\n", "            B = B - B.mean(dim=0, keepdim=True)\n", "        M = B.T @ A\n", "        # With M = U S Vt, the minimiser over orthogonal R is U @ Vt (the singular values are discarded).\n", "        U, _, Vt = torch.linalg.svd(M)\n", "        R = U @ Vt\n", "\n", "    check_orthogonal(R)\n", "    return R\n" ] },
{ "cell_type": "code", "execution_count": 62, "id": "fedd4d04", "metadata": 
{}, "outputs": [ { "data": { "text/plain": [ "torch.Size([768, 3072])" ] }, "execution_count": 62, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# c_fc is a transformers Conv1D: it computes x @ W + b, so W is stored as (in_features, out_features).\n", "modelA.transformer.h[0].mlp.c_fc.weight.shape" ] },
{ "cell_type": "code", "execution_count": 63, "id": "ff93495e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Delta: 6.706436e-05\n", "torch.Size([1024, 768])\n" ] } ], "source": [ "# Align GPT-2's token embedding (emb1) to GPT-Neo's (emb2): R minimises ||emb1 @ R - emb2||_F.\n", "# Rows are paired by token id, which assumes both checkpoints share the GPT-2 BPE vocabulary.\n", "emb1 = modelA.get_input_embeddings().weight\n", "emb2 = modelB.get_input_embeddings().weight\n", "R = procrustes(emb2, emb1)\n", "\n", "# Unrotated copy of GPT-2's embedding, used for the untied output head below.\n", "original_embedding = emb1.detach().clone()\n", "\n", "with torch.no_grad():\n", "    # Input side: token and position embeddings are summed, so rotate both to make the\n", "    # residual stream entering block 0 equal to (original hidden state) @ R.\n", "    emb1_R = emb1 @ R\n", "    modelA.set_input_embeddings(torch.nn.Embedding.from_pretrained(emb1_R))\n", "    print(modelA.transformer.wpe.weight.shape)\n", "    modelA.transformer.wpe.weight.copy_(modelA.transformer.wpe.weight @ R)\n", "\n", "    # Untie the output head so it keeps the unrotated embedding. Also clear the config flag;\n", "    # otherwise a later tie_weights() call (e.g. when the model is saved and reloaded) would\n", "    # tie lm_head to the rotated input embedding again.\n", "    modelA.lm_head = torch.nn.Linear(modelA.config.n_embd, modelA.config.vocab_size, bias=False)\n", "    modelA.lm_head.weight = torch.nn.Parameter(original_embedding)\n", "    modelA.config.tie_word_embeddings = False\n", "\n", "    # Undo the rotation at the input of the first MLP. c_fc computes x @ W + b, so for x' = x @ R\n", "    # the weight must become R.T @ W (R is orthogonal, so x @ R @ R.T @ W == x @ W).\n", "    # Note: (W.T @ R.T).T equals R @ W, which would apply R a second time instead of undoing it.\n", "    c_fc = modelA.transformer.h[0].mlp.c_fc\n", "    c_fc.weight.copy_(R.T @ c_fc.weight)\n", "\n", "# NOTE: ln_1, the attention sub-block and the residual stream of every block still operate on\n", "# rotated activations, so modelA is not expected to be equivalent to the original GPT-2." ] },
{ "cell_type": "code", "execution_count": 64, "id": "a8ef9109", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[True, True, True, ..., True, True, True],\n", " [True, True, True, ..., True, True, True],\n", " [True, True, True, ..., True, True, True],\n", " ...,\n", " [True, True, True, ..., True, True, True],\n", " [True, True, True, ..., True, True, True],\n", " [True, True, True, ..., True, True, True]])" ] }, "execution_count": 
64, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Elementwise all True is expected: emb1 is still the ORIGINAL (pre-rotation) wte parameter, and lm_head\n", "# was given an unrotated copy of its values. This does not mean lm_head is tied to the new, rotated\n", "# input embedding; the identity check in the next cell shows they are separate tensors.\n", "modelA.lm_head.weight == emb1" ] },
{ "cell_type": "code", "execution_count": 65, "id": "848d2fc2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "False" ] }, "execution_count": 65, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Compare the weight tensors themselves: `wte` is an nn.Embedding module, so `wte is lm_head.weight`\n", "# would be False even if the weights were tied.\n", "modelA.transformer.wte.weight is modelA.lm_head.weight" ] },
{ "cell_type": "code", "execution_count": 66, "id": "33888de4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(890.5013)\n", "tensor(890.4819, grad_fn=)\n" ] } ], "source": [ "# An orthogonal R preserves the Frobenius norm (||E @ R||_F == ||E||_F), so the rotated input embedding\n", "# and the unrotated head agree up to float error (R is only approximately orthogonal). Equal norms are\n", "# consistent with a rotation but prove neither the rotation nor the untying; the identity check above\n", "# is what shows the two tensors are separate.\n", "print(torch.norm(modelA.transformer.wte.weight))\n", "print(torch.norm(modelA.lm_head.weight))" ] },
{ "cell_type": "code", "execution_count": 67, "id": "d8d9d612", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Device set to use cpu\n", "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[{'generated_text': 'Hello, how are you?.,, that..,. and\\n\\n the. of...,. this.., and. and.,.\\'the.!. and (...,... the, for, to...,......,.. and,, if your the more.,.., the,., and.., that or.,..,., the the..\\'of the an...... the or.. or and to... to., the.,. the and.. do. to [ that. of that the and,.. who that..,...,... for. for or,.. with the,.. a.,. in a.. (,\\n..... and. or the... the-. the). to, the.. that \",.... you. the on. or.-- of.,. and are on..:\\n of:. and. that that or.,,,.. of, for,,., for the,\\n or,... 
and, the\\n'}]\n" ] } ], "source": [ "# tokenizerA and tokenizerB share the GPT-2 BPE vocabulary (gpt-neo-125M reuses GPT-2's tokenizer),\n", "# so pair modelA with its own tokenizer for clarity.\n", "pipe = pipeline(\"text-generation\", model=modelA, tokenizer=tokenizerA)\n", "print(pipe(\"Hello, how are you?\"))" ] } ],
"metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 }