Use mini llama, getting more interesting results
This commit is contained in:
+10
-5
@@ -35,7 +35,7 @@ toy_instr_data = [
|
|||||||
# Example: print first few
|
# Example: print first few
|
||||||
print(toy_instr_data[:3])
|
print(toy_instr_data[:3])
|
||||||
|
|
||||||
tok_gpt = AutoTokenizer.from_pretrained("distilgpt2")
|
tok_gpt = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
|
||||||
tok_gpt.pad_token = tok_gpt.eos_token
|
tok_gpt.pad_token = tok_gpt.eos_token
|
||||||
|
|
||||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tok_gpt, mlm=False)
|
data_collator = DataCollatorForLanguageModeling(tokenizer=tok_gpt, mlm=False)
|
||||||
@@ -82,15 +82,20 @@ if bnb_available:
|
|||||||
quant_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
|
quant_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
|
||||||
quant_kwargs["device_map"] = {"": 0} # specify device map
|
quant_kwargs["device_map"] = {"": 0} # specify device map
|
||||||
|
|
||||||
base_lm = AutoModelForCausalLM.from_pretrained("distilgpt2", **quant_kwargs)
|
base_lm = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", **quant_kwargs)
|
||||||
|
|
||||||
|
|
||||||
lora_cfg = LoraConfig(
|
lora_cfg = LoraConfig(
|
||||||
task_type=TaskType.CAUSAL_LM,
|
task_type=TaskType.CAUSAL_LM,
|
||||||
r=8, # rank
|
r=8,
|
||||||
lora_alpha=32,
|
lora_alpha=32,
|
||||||
lora_dropout=0.05,
|
lora_dropout=0.05,
|
||||||
target_modules=["c_attn", "c_proj", "c_fc"],
|
target_modules=[
|
||||||
fan_in_fan_out=True,
|
"q_proj",
|
||||||
|
"k_proj",
|
||||||
|
"v_proj",
|
||||||
|
"o_proj"
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
lora_model = get_peft_model(base_lm, lora_cfg)
|
lora_model = get_peft_model(base_lm, lora_cfg)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from peft import PeftModel
|
|||||||
# -----------------------------
|
# -----------------------------
|
||||||
# Config
|
# Config
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
BASE_MODEL_NAME = "distilgpt2"
|
BASE_MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
||||||
ADAPTER_PATH = "./ft_lora_adapter"
|
ADAPTER_PATH = "./ft_lora_adapter"
|
||||||
|
|
||||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|||||||
Reference in New Issue
Block a user