Use TinyLlama (TinyLlama-1.1B-Chat-v1.0) instead of distilgpt2, which gives more interesting results

This commit is contained in:
WillJeynes
2026-04-10 17:34:36 +01:00
parent 7665292db4
commit 2417efbeca
2 changed files with 11 additions and 6 deletions
+10 -5
View File
@@ -35,7 +35,7 @@ toy_instr_data = [
# Example: print first few # Example: print first few
print(toy_instr_data[:3]) print(toy_instr_data[:3])
tok_gpt = AutoTokenizer.from_pretrained("distilgpt2") tok_gpt = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
tok_gpt.pad_token = tok_gpt.eos_token tok_gpt.pad_token = tok_gpt.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer=tok_gpt, mlm=False) data_collator = DataCollatorForLanguageModeling(tokenizer=tok_gpt, mlm=False)
@@ -82,15 +82,20 @@ if bnb_available:
quant_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4") quant_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
quant_kwargs["device_map"] = {"": 0} # specify device map quant_kwargs["device_map"] = {"": 0} # specify device map
base_lm = AutoModelForCausalLM.from_pretrained("distilgpt2", **quant_kwargs) base_lm = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", **quant_kwargs)
lora_cfg = LoraConfig( lora_cfg = LoraConfig(
task_type=TaskType.CAUSAL_LM, task_type=TaskType.CAUSAL_LM,
r=8, # rank r=8,
lora_alpha=32, lora_alpha=32,
lora_dropout=0.05, lora_dropout=0.05,
target_modules=["c_attn", "c_proj", "c_fc"], target_modules=[
fan_in_fan_out=True, "q_proj",
"k_proj",
"v_proj",
"o_proj"
]
) )
lora_model = get_peft_model(base_lm, lora_cfg) lora_model = get_peft_model(base_lm, lora_cfg)
+1 -1
View File
@@ -5,7 +5,7 @@ from peft import PeftModel
# ----------------------------- # -----------------------------
# Config # Config
# ----------------------------- # -----------------------------
BASE_MODEL_NAME = "distilgpt2" BASE_MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
ADAPTER_PATH = "./ft_lora_adapter" ADAPTER_PATH = "./ft_lora_adapter"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" DEVICE = "cuda" if torch.cuda.is_available() else "cpu"