Make the model less overfitting. Make it harder for an event to be classed as "perfect"
This commit is contained in:
@@ -9,8 +9,8 @@ export const robertaMetrics: GraphNode<typeof MessagesState> = async (state) =>
|
|||||||
const result = await evaluateWithRoberta({answer})
|
const result = await evaluateWithRoberta({answer})
|
||||||
|
|
||||||
let score = 0;
|
let score = 0;
|
||||||
if (result.validProb > result.invalidProb) {
|
if (result.validProb > (result.invalidProb+0.4)) {
|
||||||
score = 0.7 + ((result.validProb - result.invalidProb)*0.3);
|
score = 0.7 + ((result.validProb - (result.invalidProb+0.4))*0.3);
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
results/
|
results/
|
||||||
roberta_classifier/
|
roberta_classifier/
|
||||||
roberta_classifier*/
|
roberta_classifier*/
|
||||||
|
output*
|
||||||
|
|
||||||
# -- THEIRS --
|
# -- THEIRS --
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
|
|||||||
@@ -120,7 +120,7 @@ def main():
|
|||||||
print("Current device:", torch.cuda.current_device() if torch.cuda.is_available() else "CPU")
|
print("Current device:", torch.cuda.current_device() if torch.cuda.is_available() else "CPU")
|
||||||
texts, labels = load_dataset_from_csv("../../data/classify.csv")
|
texts, labels = load_dataset_from_csv("../../data/classify.csv")
|
||||||
|
|
||||||
tokenizer = RobertaTokenizer.from_pretrained(model_name)
|
tokenizer = RobertaTokenizer.from_pretrained(model_name, hidden_dropout_prob=0.2,attention_probs_dropout_prob=0.2)
|
||||||
model = RobertaForSequenceClassification.from_pretrained(
|
model = RobertaForSequenceClassification.from_pretrained(
|
||||||
model_name,
|
model_name,
|
||||||
num_labels=NUM_CLASSES
|
num_labels=NUM_CLASSES
|
||||||
@@ -129,7 +129,7 @@ def main():
|
|||||||
for param in model.roberta.parameters():
|
for param in model.roberta.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
|
||||||
for param in model.roberta.encoder.layer[-3:].parameters():
|
for param in model.roberta.encoder.layer[-6:].parameters():
|
||||||
param.requires_grad = True
|
param.requires_grad = True
|
||||||
|
|
||||||
print("Dataset size:", len(texts))
|
print("Dataset size:", len(texts))
|
||||||
@@ -185,9 +185,9 @@ def main():
|
|||||||
|
|
||||||
training_args = TrainingArguments(
|
training_args = TrainingArguments(
|
||||||
output_dir="./results",
|
output_dir="./results",
|
||||||
learning_rate=1e-5,
|
learning_rate=2e-5,
|
||||||
per_device_train_batch_size=8,
|
per_device_train_batch_size=32,
|
||||||
num_train_epochs=15,
|
num_train_epochs=5,
|
||||||
weight_decay=0.01,
|
weight_decay=0.01,
|
||||||
load_best_model_at_end=True,
|
load_best_model_at_end=True,
|
||||||
eval_strategy="epoch",
|
eval_strategy="epoch",
|
||||||
|
|||||||
@@ -118,8 +118,14 @@ def render():
|
|||||||
total = sum(confidence_counter.values())
|
total = sum(confidence_counter.values())
|
||||||
correct = confidence_counter["Correct-PERFECT"] + confidence_counter["Correct-FINE"] + confidence_counter["Correct-FALSE"]
|
correct = confidence_counter["Correct-PERFECT"] + confidence_counter["Correct-FINE"] + confidence_counter["Correct-FALSE"]
|
||||||
|
|
||||||
|
goodkept = confidence_counter["Correct-PERFECT"] + confidence_counter["Correct-FINE"]
|
||||||
|
allkept = confidence_counter["Correct-PERFECT"] + confidence_counter["Correct-FINE"] + confidence_counter["Over-confident"]
|
||||||
|
|
||||||
|
|
||||||
corr_percent = (correct / total) * 100
|
corr_percent = (correct / total) * 100
|
||||||
|
kept_percent = (goodkept / allkept) * 100
|
||||||
st.markdown(f"**Correct: {corr_percent:.2f}% ({correct}/{total})**")
|
st.markdown(f"**Correct: {corr_percent:.2f}% ({correct}/{total})**")
|
||||||
|
st.markdown(f"**Kept: {kept_percent:.2f}% ({goodkept}/{allkept})**")
|
||||||
st.markdown(f"Duplicates: {dup_counter}")
|
st.markdown(f"Duplicates: {dup_counter}")
|
||||||
st.pyplot(fig, width=500)
|
st.pyplot(fig, width=500)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user