{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "e28fb85c",
"metadata": {},
"outputs": [],
"source": [
"# One-time environment setup -- run once in a fresh venv, then leave commented out.\n",
"# %pip install gputil\n",
"# %pip install setuptools\n",
"# %pip install transformers\n",
"# %pip install torch\n",
"\n",
"# %pip install auto-gptq #==0.4.0"
]
},
{
"cell_type": "markdown",
"id": "f16a013c",
"metadata": {},
"source": [
"What happens when you reflect or invert the input and output embeddings?"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0667e71a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/mick/pycharmprojects/Frankenstein/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import GPUtil  # NOTE(review): imported but never used in this notebook -- confirm before removing\n",
"\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline\n",
"import torch\n",
"# from auto_gptq import AutoGPTQForCausalLM  # needed only for grab_model(..., quantized=True)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "67d7e006",
"metadata": {},
"outputs": [],
"source": [
"def grab_model(model_name, quantized=False):\n",
"    \"\"\"Load a causal language model and its tokenizer.\n",
"\n",
"    Args:\n",
"        model_name: Hugging Face model id or local path.\n",
"        quantized: if True, load a GPTQ-quantized checkpoint via\n",
"            AutoGPTQForCausalLM (requires the auto_gptq import in the\n",
"            imports cell to be uncommented).\n",
"\n",
"    Returns:\n",
"        (model, tokenizer) tuple.\n",
"\n",
"    Raises:\n",
"        ImportError: if quantized=True but auto_gptq was not imported.\n",
"    \"\"\"\n",
"    if quantized:\n",
"        # Fail with an actionable message instead of a bare NameError:\n",
"        # the auto_gptq import at the top of the notebook is commented out.\n",
"        if \"AutoGPTQForCausalLM\" not in globals():\n",
"            raise ImportError(\"auto-gptq is not imported; uncomment 'from auto_gptq import AutoGPTQForCausalLM' in the imports cell\")\n",
"        model = AutoGPTQForCausalLM.from_quantized(model_name, device=\"cpu\", use_safetensors=True)\n",
"    else:\n",
"        model = AutoModelForCausalLM.from_pretrained(model_name)\n",
"\n",
"    tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"    return model, tokenizer"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "153e9ff5",
"metadata": {},
"outputs": [],
"source": [
"# Two models with the same hidden size (checked with the equality cell below),\n",
"# so their embedding matrices are dimensionally interchangeable.\n",
"modelA, tokenizerA = grab_model(\"gpt2\")\n",
"modelB, tokenizerB = grab_model(\"EleutherAI/gpt-neo-125M\")\n",
"\n",
"# modelA, tokenizerA = grab_model(\"EleutherAI/gpt-neo-125M-4bit\", quantized=True)\n",
"# modelB, tokenizerB = grab_model(\"iproskurina/opt-125m-GPTQ-4bit-g128\", quantized=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "1da291ed",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"modelA.config.hidden_size == modelB.config.hidden_size "
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "0160d672",
"metadata": {},
"outputs": [],
"source": [
"# Raw input-embedding weight matrices (one row per token id).\n",
"emb1 = modelA.get_input_embeddings().weight\n",
"emb2 = modelB.get_input_embeddings().weight\n",
"# NOTE(review): emb2 is never used later in this notebook -- confirm intent."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ff93495e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ModelA mean norms: 3.9585366249084473\n",
"Rotated modelA mean norms: 3.958536148071289\n",
"new_embedding mean norms: 3.958536148071289\n"
]
}
],
"source": [
"\n",
"# flip the matrix\n",
"emb1_R = emb1.flip(dims=[0]) # .reverse() #[::-1]\n",
"\n",
"print(\"ModelA mean norms:\", torch.norm(emb1, dim=1).mean().item())\n",
"print(\"Rotated modelA mean norms:\", torch.norm(emb1_R, dim=1).mean().item())\n",
"\n",
"new_embedding = torch.nn.Embedding.from_pretrained(emb1_R)\n",
"\n",
"print(\"new_embedding mean norms:\", torch.norm(new_embedding.weight, dim=1).mean().item())\n",
"\n",
"modelA.set_input_embeddings(new_embedding)\n",
"modelA.lm_head.weight = new_embedding.weight\n"
]
},
{
"cell_type": "code",
"execution_count": 109,
"id": "85957357",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 109,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"modelA.lm_head.out_features == tokenizerA.vocab_size"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "d8d9d612",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Device set to use cpu\n",
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[{'generated_text': 'Hello, how are you? Hitman chemicallychev lobeassi composure029capitalist composure exacerbateightingataka harsher Hoy 1886 typew composure curlsidad harsher Babe lobeMach Titus kindred chemicallyRush Intelligent Scare annihilationoblchev harsher Christy Christyansky peppighting typew composure OPEC HitmanEngineDebugchev lobe conceptions partying IGF partying composure 1886 harsherRush castlesGbLESS composure peppightinglations Optical ENTER Tel harsherRush siph composure 1886 chemicallyRushAbysstechnology Rated instructional Scare annihilationchev harsher Christy Christy Leilan repaidevaluate clamp composure peppighting partyingkie partyingRush522 HitmanEngineDebugRushspective629chevAbyss Rated Ada doesnt harsherRush MILL THESE CSI AchievementsCollinschev lobe Ada doesnt harsher Christy Christy feats! kW unjust Ker workaround Hitman Mondays bunnyManufact Mercenary composuregradient OnePlustown grandmaansky upbringingsei781Alternative 465 Kafka partyingMach Trash totem typew grandma733 composure184capitalist Naplesreditary MISS gazedcele composure528 lobe supremacists Hitman DISTRICT Dominic lair harsher Christy Christy Livebushimaruansky upbringingsei amplification SERVikanovych Christy Christyansky Kafkacanon Wanted spears tamp chemically Ker workaround typew 451 annihilation 1889Mach Trash Mondaysprinted]+ RatedPlotcele NETWORK Trash lobe Curve Mercenary composure vitri833 HitmanManufact harsher Christy Christy gazed jihadists typewMach trout baths Trash781 Ker workaroundolded composureWord Lyndon harsher grandma896 lobe decaying annihilation RatedMach FRI PERSON MondaysAlternativeMach FRI redundancy harsher BCC lobe decaying annihilation CONTROLMach ragedcapitalist CRC Helsinki harsher BCC lobe decaying annihilation'}]\n"
]
}
],
"source": [
"# use model\n",
"# Generate with modelA (gpt2, embeddings flipped above) paired with tokenizerB\n",
"# (gpt-neo). NOTE(review): presumably deliberate -- gpt-neo reuses the GPT-2\n",
"# BPE vocabulary, so the token-id spaces should line up; confirm.\n",
"pipe = pipeline(\"text-generation\", model=modelA, tokenizer=tokenizerB)\n",
"print(pipe(\"Hello, how are you?\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fc72ea8a",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}