Use mini llama, getting more interesting results

This commit is contained in:
WillJeynes
2026-04-10 17:34:36 +01:00
parent 7665292db4
commit 2417efbeca
2 changed files with 11 additions and 6 deletions
+10 -5
View File
@@ -35,7 +35,7 @@ toy_instr_data = [
# Example: print first few
print(toy_instr_data[:3])
tok_gpt = AutoTokenizer.from_pretrained("distilgpt2")
tok_gpt = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
tok_gpt.pad_token = tok_gpt.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer=tok_gpt, mlm=False)
@@ -82,15 +82,20 @@ if bnb_available:
quant_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
quant_kwargs["device_map"] = {"": 0} # specify device map
base_lm = AutoModelForCausalLM.from_pretrained("distilgpt2", **quant_kwargs)
base_lm = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", **quant_kwargs)
lora_cfg = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=8, # rank
r=8,
lora_alpha=32,
lora_dropout=0.05,
target_modules=["c_attn", "c_proj", "c_fc"],
fan_in_fan_out=True,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj"
]
)
lora_model = get_peft_model(base_lm, lora_cfg)