
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "e28fb85c",
"metadata": {},
"outputs": [],
"source": [
"# %pip install gputil\n",
"# %pip install setuptools\n",
"# %pip install transformers\n",
"# %pip install torch\n",
"\n",
"# %pip install auto-gptq #==0.4.0"
]
},
{
"cell_type": "markdown",
"id": "9c5573f3",
"metadata": {},
"source": [
"What happens if you rotate the input and output vectors?"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "0667e71a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/mick/pycharmprojects/Frankenstein/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import GPUtil\n",
"\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline\n",
"import torch\n",
"# from auto_gptq import AutoGPTQForCausalLM"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "67d7e006",
"metadata": {},
"outputs": [],
"source": [
"def grab_model(model_name, quantized = False):\n",
" if quantized:\n",
" model = AutoGPTQForCausalLM.from_quantized(model_name, device=\"cpu\", use_safetensors=True)\n",
" else:\n",
" model = AutoModelForCausalLM.from_pretrained(model_name)\n",
"\n",
" tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
" return model, tokenizer"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "153e9ff5",
"metadata": {},
"outputs": [],
"source": [
"modelA, tokenizerA = grab_model(\"gpt2\")\n",
"modelB, tokenizerB = grab_model(\"EleutherAI/gpt-neo-125M\")\n",
"\n",
"# modelA, tokenizerA = grab_model(\"EleutherAI/gpt-neo-125M-4bit\", quantized=True)\n",
"# modelB, tokenizerB = grab_model(\"iproskurina/opt-125m-GPTQ-4bit-g128\", quantized=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "1da291ed",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"modelA.config.hidden_size == modelB.config.hidden_size "
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "c62b2f41",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(False)\n",
"tensor(22.2842, grad_fn=<MaxBackward1>)\n",
"tensor(11.5013, grad_fn=<MeanBackward0>)\n"
]
}
],
"source": [
"print(torch.isnan(modelB.get_input_embeddings().weight).any())\n",
"print(torch.norm(modelB.get_input_embeddings().weight, dim=1).max())\n",
"print(torch.norm(modelB.get_input_embeddings().weight, dim=1).mean())"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "2b9893a3",
"metadata": {},
"outputs": [],
"source": [
"def check_orthogonal(R):\n",
" I = torch.eye(R.size(0), device=R.device)\n",
" delta = torch.norm(R.T @ R - I)\n",
" print(f\"Delta: {delta:.6e}\")\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "e1a54c24",
"metadata": {},
"outputs": [],
"source": [
"# use proscrustes:\n",
"def procrustes(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:\n",
" # A_centered = A - A.mean(dim=0, keepdim=True)\n",
" # B_centered = B - B.mean(dim=0, keepdim=True)\n",
"\n",
" #M = B_centered.T @ A_centered\n",
" M = B.T @ A\n",
" # find optimal rotation with svd\n",
" U, _, Vt = torch.linalg.svd(M)\n",
"\n",
" # get rotation matrix that aligns B to A\n",
" R = U @ Vt\n",
"\n",
" check_orthogonal(R)\n",
" \n",
" return B @ R # return rotated tensor\n",
"\n",
"def get_rotated_matrix(A, B, n = 1000):\n",
" # use only the first n tokens for rotation:\n",
" # return procrustes(A[:n], B[:n])\n",
" return procrustes(A, B)\n",
" \n",
"\n",
"\n"
]
},
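{
"cell_type": "markdown",
"id": "a3f19b20",
"metadata": {},
"source": [
"Orthogonal Procrustes: for $M = B^\\top A = U \\Sigma V^\\top$, the orthogonal $R$ minimizing $\\lVert B R - A \\rVert_F$ is $R = U V^\\top$. As a quick sanity check (an illustrative sketch, not part of the original run), rotate a random matrix by a known orthogonal matrix and confirm that `procrustes` maps the rotated copy back onto the original."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b7c4d2e1",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sanity check: procrustes should undo a known random rotation.\n",
"torch.manual_seed(0)\n",
"X = torch.randn(500, 64)\n",
"Q, _ = torch.linalg.qr(torch.randn(64, 64))  # random orthogonal matrix\n",
"X_rot = X @ Q\n",
"# aligning X_rot back onto X should reproduce X up to numerical error\n",
"X_recovered = procrustes(X, X_rot)\n",
"print(\"max abs error:\", (X_recovered - X).abs().max().item())"
]
},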
{
"cell_type": "code",
"execution_count": 9,
"id": "ff93495e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Delta: 6.706436e-05\n",
"ModelA mean norms: 3.9585366249084473\n",
"ModelB mean norms: 11.50130844116211\n",
"Rotated modelA mean norms: 3.9585366249084473\n",
"new_embedding mean norms: 3.9585366249084473\n"
]
}
],
"source": [
"emb1 = modelA.get_input_embeddings().weight\n",
"emb2 = modelB.get_input_embeddings().weight\n",
"\n",
"emb1_R = get_rotated_matrix(emb2, emb1)\n",
"\n",
"print(\"ModelA mean norms:\", torch.norm(emb1, dim=1).mean().item())\n",
"print(\"ModelB mean norms:\", torch.norm(emb2, dim=1).mean().item())\n",
"print(\"Rotated modelA mean norms:\", torch.norm(emb1_R, dim=1).mean().item())\n",
"\n",
"new_embedding = torch.nn.Embedding.from_pretrained(emb1_R)\n",
"\n",
"print(\"new_embedding mean norms:\", torch.norm(new_embedding.weight, dim=1).mean().item())\n",
"\n",
"modelA.set_input_embeddings(new_embedding)\n",
"modelA.lm_head.weight = new_embedding.weight\n"
]
},
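{
"cell_type": "markdown",
"id": "c9e8f3a2",
"metadata": {},
"source": [
"Because $R$ is orthogonal, the rotation preserves norms and pairwise dot products, which is why the rotated mean norm above matches the original exactly. A small check on a slice of the vocabulary (an illustrative sketch; note that `emb1` still references the original, unrotated GPT-2 matrix, since `set_input_embeddings` replaced the module rather than mutating it):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d2f6a8b4",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check: an orthogonal map leaves the Gram matrix unchanged,\n",
"# so dot products between token embeddings survive the rotation.\n",
"sample = emb1[:64].detach()\n",
"sample_rot = emb1_R[:64].detach()\n",
"print(\"max Gram difference:\", (sample @ sample.T - sample_rot @ sample_rot.T).abs().max().item())"
]
},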
{
"cell_type": "code",
"execution_count": 10,
"id": "85957357",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"modelA.lm_head.out_features == tokenizerA.vocab_size"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "d8d9d612",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Device set to use cpu\n",
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[{'generated_text': 'Hello, how are you?erderderderderderderderderderderderderderderderderderderderderderderderderderderderderderderderderderderderderderderderderderderd cue cue cue cuecue cuecueerd Nicotineerd Nicotineerd Nicotineerd Nicotine cue Nicotine cue Nicotine cue Nicotine cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue cue'}]\n"
]
}
],
"source": [
"# use model\n",
"pipe = pipeline(\"text-generation\", model=modelA, tokenizer=tokenizerB)\n",
"print(pipe(\"Hello, how are you?\"))"
]
},
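{
"cell_type": "markdown",
"id": "e5b1c7d9",
"metadata": {},
"source": [
"The generation collapses into repeated tokens. A plausible reading of this result: the Procrustes rotation aligns the two embedding matrices only in a least-squares sense, while GPT-2's transformer blocks were trained against the original embedding geometry, so feeding them rotated vectors (and reading logits through the rotated head) scrambles the residual stream."
]
},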
{
"cell_type": "code",
"execution_count": null,
"id": "fc72ea8a",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}