Use mini llama, getting more interesting results
This commit is contained in:
+10
-5
@@ -35,7 +35,7 @@ toy_instr_data = [
|
||||
# Example: print first few
|
||||
print(toy_instr_data[:3])
|
||||
|
||||
tok_gpt = AutoTokenizer.from_pretrained("distilgpt2")
|
||||
tok_gpt = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
|
||||
tok_gpt.pad_token = tok_gpt.eos_token
|
||||
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tok_gpt, mlm=False)
|
||||
@@ -82,15 +82,20 @@ if bnb_available:
|
||||
quant_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
|
||||
quant_kwargs["device_map"] = {"": 0} # specify device map
|
||||
|
||||
base_lm = AutoModelForCausalLM.from_pretrained("distilgpt2", **quant_kwargs)
|
||||
base_lm = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", **quant_kwargs)
|
||||
|
||||
|
||||
lora_cfg = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
r=8, # rank
|
||||
r=8,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
target_modules=["c_attn", "c_proj", "c_fc"],
|
||||
fan_in_fan_out=True,
|
||||
target_modules=[
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
"o_proj"
|
||||
]
|
||||
)
|
||||
|
||||
lora_model = get_peft_model(base_lm, lora_cfg)
|
||||
|
||||
Reference in New Issue
Block a user