{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "e28fb85c", "metadata": {}, "outputs": [], "source": [ "# %pip install gputil\n", "# %pip install setuptools\n", "# %pip install transformers\n", "# %pip install torch\n", "\n", "# %pip install auto-gptq #==0.4.0" ] }, { "cell_type": "code", "execution_count": 1, "id": "0667e71a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/mick/pycharmprojects/Frankenstein/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import GPUtil\n", "\n", "from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline\n", "import torch\n", "# from auto_gptq import AutoGPTQForCausalLM" ] }, { "cell_type": "code", "execution_count": 7, "id": "0273f299", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "No GPU detected on this system.\n" ] } ], "source": [ "gpus = GPUtil.getGPUs()\n", "if not gpus:\n", " print(\"No GPU detected on this system.\")\n", "else:\n", " for gpu in gpus:\n", " print(f\"GPU Name: {gpu.name}\")\n", " print(f\"Total VRAM: {gpu.memoryTotal} MB\")\n", " print(f\"Free VRAM: {gpu.memoryFree} MB\")\n", " print(f\"Used VRAM: {gpu.memoryUsed} MB\")\n", " print(\"-\" * 40)" ] }, { "cell_type": "code", "execution_count": 2, "id": "67d7e006", "metadata": {}, "outputs": [], "source": [ "def grab_model(model_name, quantized = False):\n", " if quantized:\n", " model = AutoGPTQForCausalLM.from_quantized(model_name, device=\"cpu\", use_safetensors=True)\n", " else:\n", " model = AutoModelForCausalLM.from_pretrained(model_name)\n", "\n", " tokenizer = AutoTokenizer.from_pretrained(model_name)\n", " return model, tokenizer" ] }, { "cell_type": "code", "execution_count": 63, "id": "153e9ff5", "metadata": {}, "outputs": [], "source": [ "modelA, tokenizerA = grab_model(\"gpt2\")\n", "modelB, tokenizerB = grab_model(\"EleutherAI/gpt-neo-125M\")\n", "\n", "# modelA, tokenizerA = grab_model(\"EleutherAI/gpt-neo-125M-4bit\", quantized=True)\n", "# modelB, tokenizerB = grab_model(\"iproskurina/opt-125m-GPTQ-4bit-g128\", quantized=True)" ] }, { "cell_type": "code", "execution_count": 5, "id": "1da291ed", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "modelA.config.hidden_size == modelB.config.hidden_size " ] }, { "cell_type": "code", "execution_count": 60, "id": "dcfc2d85", "metadata": {}, "outputs": [], "source": [ "# replace tokenizer:\n", "# modelA.Tokenizer = tokenizerB # optional when not accessing directly \n", "\n", "# replace token embeddings for input and output:\n", "# modelA.set_input_embeddings(modelB.get_input_embeddings())\n", "# modelA.lm_head.weight = modelB.get_input_embeddings().weight\n", "# modelA.resize_token_embeddings(tokenizerB.vocab_size)\n", "\n", "# modelA.transformer.wpe.weight = modelB.transformer.wpe.weight\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "1011d3ad", "metadata": {}, "outputs": [], "source": [ "# emb1 = modelA.get_input_embeddings().weight\n", "# emb2 = modelB.get_input_embeddings().weight\n", "\n", "# print(\"ModelA mean norms:\", torch.norm(emb1, dim=1).mean().item())\n", "# print(\"ModelB mean norms:\", torch.norm(emb2, dim=1).mean().item())\n", "\n", "# scaling_factor = 
{ "cell_type": "code", "execution_count": 113, "id": "c62b2f41", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(False)\n", "tensor(22.2842, grad_fn=<MaxBackward1>)\n", "tensor(11.5013, grad_fn=<MeanBackward0>)\n" ] } ], "source": [ "# sanity-check modelB's embeddings before aligning them\n", "print(torch.isnan(modelB.get_input_embeddings().weight).any())\n", "print(torch.norm(modelB.get_input_embeddings().weight, dim=1).max())\n", "print(torch.norm(modelB.get_input_embeddings().weight, dim=1).mean())" ] },
{ "cell_type": "code", "execution_count": 4, "id": "2b9893a3", "metadata": {}, "outputs": [], "source": [ "def check_orthogonal(R):\n", "    \"\"\"Print ||R^T R - I||; a value near zero means R is orthogonal.\"\"\"\n", "    I = torch.eye(R.size(0), device=R.device)\n", "    delta = torch.norm(R.T @ R - I)\n", "    print(f\"Delta: {delta:.6e}\")" ] },
{ "cell_type": "code", "execution_count": 8, "id": "e1a54c24", "metadata": {}, "outputs": [], "source": [ "# use Procrustes:\n", "def procrustes(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:\n", "    \"\"\"Orthogonal Procrustes: return the R that minimizes ||B @ R - A||.\"\"\"\n", "    # A_centered = A - A.mean(dim=0, keepdim=True)\n", "    # B_centered = B - B.mean(dim=0, keepdim=True)\n", "    # M = B_centered.T @ A_centered\n", "    M = B.T @ A\n", "\n", "    # find the optimal rotation with SVD\n", "    U, _, Vt = torch.linalg.svd(M)\n", "\n", "    # rotation matrix that aligns B to A\n", "    R = U @ Vt\n", "\n", "    check_orthogonal(R)\n", "\n", "    return R  # the rotation matrix (not the rotated tensor)" ] },
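{ "cell_type": "markdown", "id": "5e3b8d20", "metadata": {}, "source": [ "A quick synthetic sanity check of `procrustes` (my own addition, not part of the original run): rotate a random matrix by a known orthogonal matrix `Q` and confirm the function recovers it. `Q` comes from a QR decomposition of a random Gaussian matrix." ] },
{ "cell_type": "code", "execution_count": null, "id": "5e3b8d21", "metadata": {}, "outputs": [], "source": [ "torch.manual_seed(0)\n", "\n", "# random orthogonal matrix via QR decomposition\n", "Q, _ = torch.linalg.qr(torch.randn(768, 768))\n", "\n", "B_toy = torch.randn(1000, 768)\n", "A_toy = B_toy @ Q  # A_toy is B_toy rotated by Q\n", "\n", "# procrustes(A, B) returns the R minimizing ||B @ R - A||, so R should recover Q\n", "R_toy = procrustes(A_toy, B_toy)  # prints a Delta near zero\n", "print(torch.allclose(R_toy, Q, atol=1e-3))" ] },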
{ "cell_type": "code", "execution_count": null, "id": "ff93495e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Delta: 6.706436e-05\n", "torch.Size([1024, 768])\n" ] } ], "source": [ "emb1 = modelA.get_input_embeddings().weight\n", "emb2 = modelB.get_input_embeddings().weight\n", "\n", "# rotation matrix that best aligns modelA's embedding space with modelB's\n", "R = procrustes(emb2, emb1)\n", "emb1_R = emb1 @ R\n", "\n", "new_embedding = torch.nn.Embedding.from_pretrained(emb1_R)\n", "\n", "modelA.set_input_embeddings(new_embedding)\n", "modelA.lm_head.weight = new_embedding.weight\n", "\n", "# def rotate_weight(W, R):\n", "#     if W.shape[1] == R.shape[0]:\n", "#         return W @ R\n", "#     if W.shape[0] == R.shape[0]:\n", "#         return R.T @ W\n", "\n", "# first attempt at fixing the other layers by conjugating:\n", "# for block in modelA.transformer.h:\n", "#     for M in [block.attn.c_attn, block.mlp.c_fc]:\n", "#         W = M.weight.data\n", "#         W[:] = R.T @ W\n", "#     for M in [block.attn.c_proj, block.mlp.c_proj]:\n", "#         W = M.weight.data\n", "#         W[:] = R.T @ W @ R\n", "\n", "def split_rotate_concat(W):\n", "    \"\"\"Conjugate every 768x768 block of W by R. GPT-2's Conv1D stores\n", "    weights as (in_features, out_features), and c_attn / mlp weights are\n", "    wider than the 768-dim hidden size, so split, conjugate each block,\n", "    and re-concatenate. Note: the Conv1D biases are left unrotated.\"\"\"\n", "    parts1 = list(W.split(768, dim=1))\n", "    for i, v in enumerate(parts1):\n", "        parts2 = list(v.split(768, dim=0))\n", "        for j, w in enumerate(parts2):\n", "            parts2[j] = R.T @ w @ R\n", "        parts1[i] = torch.cat(parts2, dim=0)\n", "    return torch.cat(parts1, dim=1)\n", "\n", "\n", "def rotate_layernorm(ln):\n", "    ln.weight.data[:] = ln.weight.data @ R\n", "    ln.bias.data[:] = ln.bias.data @ R\n", "\n", "\n", "for block in modelA.transformer.h:\n", "    # block.attn.c_attn.weight.data[:] = split_rotate_concat(block.attn.c_attn.weight.data.T).T\n", "    # block.mlp.c_fc.weight.data[:] = split_rotate_concat(block.mlp.c_fc.weight.data.T).T\n", "    block.attn.c_attn.weight.data[:] = split_rotate_concat(block.attn.c_attn.weight.data)\n", "    block.mlp.c_fc.weight.data[:] = split_rotate_concat(block.mlp.c_fc.weight.data)\n", "    block.attn.c_proj.weight.data[:] = split_rotate_concat(block.attn.c_proj.weight.data)\n", "    block.mlp.c_proj.weight.data[:] = split_rotate_concat(block.mlp.c_proj.weight.data)\n", "    rotate_layernorm(block.ln_1)\n", "    rotate_layernorm(block.ln_2)\n", "\n", "rotate_layernorm(modelA.transformer.ln_f)\n", "\n", "# rotate the learned position embeddings as well\n", "print(modelA.transformer.wpe.weight.data.shape)\n", "modelA.transformer.wpe.weight.data[:] = modelA.transformer.wpe.weight.data @ R\n", "\n", "# earlier per-layer attempt, superseded by the loop above:\n", "# for block in modelA.transformer.h:\n", "#     for name in ['c_attn', 'c_proj']:\n", "#         W = getattr(block.attn, name).weight.data\n", "#         W[:] = R.T @ W @ R\n", "#     w1 = block.mlp.c_fc.weight.data\n", "#     w2 = block.mlp.c_proj.weight.data\n", "#     w1[:] = R.T @ w1 @ R\n", "#     w2[:] = R.T @ w2 @ R" ] },
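{ "cell_type": "markdown", "id": "6d4c7e30", "metadata": {}, "source": [ "Why the conjugation trick is not enough (my analysis, not from the original notebook): for a bias-free linear map, rotating the residual stream commutes with conjugating the weight, since `(x @ R) @ (R.T @ W @ R) = (x @ W) @ R`. LayerNorm breaks this: it subtracts the per-token mean over features and divides by the per-token standard deviation, and those statistics change under rotation, so no rotation of `ln.weight` / `ln.bias` can reproduce the original computation. A small demonstration with a hypothetical orthogonal `Q`:" ] },
{ "cell_type": "code", "execution_count": null, "id": "6d4c7e31", "metadata": {}, "outputs": [], "source": [ "x = torch.randn(4, 768)\n", "W = torch.randn(768, 768)\n", "Q, _ = torch.linalg.qr(torch.randn(768, 768))  # random orthogonal matrix\n", "\n", "# linear layers commute with rotation under conjugation: expect True\n", "print(torch.allclose((x @ Q) @ (Q.T @ W @ Q), (x @ W) @ Q, rtol=1e-3, atol=1e-3))\n", "\n", "# LayerNorm does not: expect False\n", "ln = torch.nn.LayerNorm(768)\n", "print(torch.allclose(ln(x @ Q), ln(x) @ Q, rtol=1e-3, atol=1e-3))" ] },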
{ "cell_type": "code", "execution_count": null, "id": "9b671b41", "metadata": {}, "outputs": [], "source": [ "# modelA.transformer.wpe.weight = modelB.transformer.wpe.weight" ] },
{ "cell_type": "code", "execution_count": null, "id": "85957357", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 109, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# modelA.lm_head.out_features == tokenizerA.vocab_size" ] },
{ "cell_type": "markdown", "id": "f6b39638", "metadata": {}, "source": [ "Text:" ] },
{ "cell_type": "markdown", "id": "fbfa8d62", "metadata": {}, "source": [ "With it:" ] },
{ "cell_type": "code", "execution_count": null, "id": "998a0ed6", "metadata": {}, "outputs": [], "source": [ "# modelA.lm_head.weight = modelA.get_input_embeddings().weight  # should change nothing: they are the same object\n" ] },
{ "cell_type": "markdown", "id": "aa8b7ca4", "metadata": {}, "source": [ "Text:" ] },
{ "cell_type": "code", "execution_count": 66, "id": "d8d9d612", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Device set to use cpu\n", "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[{'generated_text': 'Hello, how are you?orm Coulormormorm Coulorm Coul Coulorm Coul Coulinion Coulorm Coulonomousonomous Coulonomousonomousonomousonomousonomousormonomous Coulorm Coulonomousonomousonomous Coulonomousonomous Coulonomousonomousonomous Coulonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomous Coulonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomous Coulonomousonomousonomousonomousonomousonomousonomous Coulonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomous Amenonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomousonomous Coulonomous Coulonomousonomousonomousonomousonomousonomousonomoushered…] Coulonomousonomousonomous Amenomniaifulonomousonomouskeleyifulonomous Amenomniaifulhered Amenkeleyomniastad Coulonomousifulifulomniaifulomniaifulomniaifulifulifulomniaifulomnia…]hered…]ifulomniaifulifulomniastadkeleyomniaifulifulomniaifulomniaifulomniakeleyomniaomniaomnia Coulomniaifulomnia Coulifulomnia Coul Coulkeleyomniastad Coulomnia Coulkeleyomnia Coulkeleyomnia Coulkeleyomniaomnia Coulkeleyomniaomniaomniaomniastadomniaomniaomniaomnia Coulkeleyonomousomnia Coulomniaomniaomnia Coulkeleyomnia Coulomniaomniaomniaomnia Coulomniaomniakeleyomniakeleyomniakeleyomniaomniaomniaomniakeleystadkeleyomniakeleyomniaomnia'}]\n" ] } ], "source": [ "# try the modified model; the degenerate output shows the rotation broke it\n", "pipe = pipeline(\"text-generation\", model=modelA, tokenizer=tokenizerB)\n", "print(pipe(\"Hello, how are you?\"))" ] },
{ "cell_type": "code", "execution_count": 26, "id": "79616f5c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# lm_head and the input embeddings share storage (weights are tied)\n", "modelA.lm_head.weight.data_ptr() == modelA.get_input_embeddings().weight.data_ptr()" ] },
{ "cell_type": "code", "execution_count": 27, "id": "7fc76499", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "modelB.lm_head.weight.data_ptr() == modelB.get_input_embeddings().weight.data_ptr()" ] },
{ "cell_type": "code", "execution_count": 10, "id": "7d51d201", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([ 1212,   318,   257,  1332,   290,  1312,  4240,  1521,   262, 11241,\n", "        11341,   389,   262,   976])\n" ] } ], "source": [ "# GPT-2 and GPT-Neo share the same BPE vocabulary, hence identical ids\n", "tok = tokenizerA(\"This is a test and i wonder why the tokenizers are the same\", return_tensors=\"pt\")\n", "print(tok.input_ids[0])" ] },
{ "cell_type": "code", "execution_count": 11, "id": "2e76534a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([ 1212,   318,   257,  1332,   290,  1312,  4240,  1521,   262, 11241,\n", "        11341,   389,   262,   976])\n" ] } ], "source": [ "tok = tokenizerB(\"This is a test and i wonder why the tokenizers are the same\", return_tensors=\"pt\")\n", "print(tok.input_ids[0])" ] }
], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 }