Testing code for DeBERTa; needs to be run on a GPU

This commit is contained in:
William Jeynes
2026-03-22 16:55:21 +00:00
parent c69730df6b
commit bff5423f3d
+23 -9
View File
@@ -1,6 +1,6 @@
from sklearn.utils import compute_class_weight from sklearn.utils import compute_class_weight
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from transformers import RobertaTokenizer, RobertaForSequenceClassification, Trainer, TrainingArguments from transformers import RobertaTokenizer, RobertaForSequenceClassification, Trainer, TrainingArguments, AutoTokenizer, AutoModelForSequenceClassification
import torch import torch
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
@@ -10,7 +10,7 @@ import csv
import numpy as np import numpy as np
NUM_CLASSES = 3 NUM_CLASSES = 3
model_name = "roberta-base" model_name = "microsoft/deberta-v3-base"
LABEL_PRIORITY = [ LABEL_PRIORITY = [
("PERFECT", 0), ("PERFECT", 0),
@@ -29,12 +29,19 @@ class WeightedTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
    """Compute a class-weighted cross-entropy loss for one batch.

    Args:
        model: Callable invoked as ``model(**inputs)``; its result must
            expose the logits through ``.get("logits")``.
        inputs: Mapping of keyword arguments for the model. The ``"labels"``
            entry is read as the loss target and is still forwarded to the
            model together with the other entries.
        return_outputs: When true, return ``(loss, outputs)`` instead of
            just ``loss``.
        **kwargs: Accepted for call compatibility and ignored.

    Returns:
        The loss tensor, or a ``(loss, outputs)`` tuple when
        ``return_outputs`` is true.
    """
    labels = inputs.get("labels")
    outputs = model(**inputs)
    logits = outputs.get("logits")

    # CrossEntropyLoss needs its weight tensor on the same device and in the
    # same dtype as the logits, so align both in one conversion.
    # NOTE(review): self.class_weights is assigned outside this method --
    # presumably a 1-D tensor with one weight per class; confirm at the caller.
    class_weights = self.class_weights.to(
        device=logits.device, dtype=logits.dtype
    )
    loss = CrossEntropyLoss(weight=class_weights)(logits, labels)

    return (loss, outputs) if return_outputs else loss
def label_to_int(extra_info: str) -> int: def label_to_int(extra_info: str) -> int:
@@ -114,22 +121,28 @@ def compute_metrics(eval_pred):
} }
def main(): def main():
torch.multiprocessing.set_start_method('fork') # torch.multiprocessing.set_start_method('fork')
print("CUDA available:", torch.cuda.is_available()) print("CUDA available:", torch.cuda.is_available())
print("CUDA device count:", torch.cuda.device_count()) print("CUDA device count:", torch.cuda.device_count())
print("Current device:", torch.cuda.current_device() if torch.cuda.is_available() else "CPU") print("Current device:", torch.cuda.current_device() if torch.cuda.is_available() else "CPU")
texts, labels = load_dataset_from_csv("../../data/classify.csv") texts, labels = load_dataset_from_csv("../../data/classify.csv")
tokenizer = RobertaTokenizer.from_pretrained(model_name, hidden_dropout_prob=0.2,attention_probs_dropout_prob=0.2) # tokenizer = RobertaTokenizer.from_pretrained(model_name, hidden_dropout_prob=0.2,attention_probs_dropout_prob=0.2)
model = RobertaForSequenceClassification.from_pretrained( # model = RobertaForSequenceClassification.from_pretrained(
# model_name,
# num_labels=NUM_CLASSES
# )
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
model_name, model_name,
num_labels=NUM_CLASSES num_labels=NUM_CLASSES
) )
for param in model.roberta.parameters(): for param in model.deberta.parameters():
param.requires_grad = False param.requires_grad = False
for param in model.roberta.encoder.layer[-6:].parameters(): for param in model.deberta.encoder.layer[-1:].parameters():
param.requires_grad = True param.requires_grad = True
print("Dataset size:", len(texts)) print("Dataset size:", len(texts))
@@ -173,6 +186,7 @@ def main():
self.labels = labels self.labels = labels
def __getitem__(self, idx): def __getitem__(self, idx):
print(f"Loading item {idx}")
item = { item = {
key: torch.tensor(val[idx]) key: torch.tensor(val[idx])
for key, val in self.encodings.items() for key, val in self.encodings.items()