Implement ensemble into final model structure
This commit is contained in:
@@ -3,21 +3,15 @@ from fastapi import FastAPI
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import os
|
||||
|
||||
# Embedding model
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
# Roberta
|
||||
from transformers import RobertaTokenizer, RobertaForSequenceClassification
|
||||
|
||||
# Flan (seq2seq)
|
||||
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
############################################
|
||||
# ----------- REQUEST SCHEMA ---------------
|
||||
# SCHEMA
|
||||
############################################
|
||||
|
||||
class EvalRequest(BaseModel):
|
||||
@@ -26,7 +20,7 @@ class EvalRequest(BaseModel):
|
||||
|
||||
|
||||
############################################
|
||||
# ----------- LOGREG MODEL -----------------
|
||||
# REGRESSION MODEL
|
||||
############################################
|
||||
|
||||
HF_REPO_ID = "WillJeynes/LLMsForDisinformationAnalysis-Regression"
|
||||
@@ -72,7 +66,7 @@ logreg_model.eval()
|
||||
|
||||
|
||||
############################################
|
||||
# ----------- ROBERTA MODEL ----------------
|
||||
# ROBERTA
|
||||
############################################
|
||||
|
||||
ROBERTA_PATH = "WillJeynes/LLMsForDisinformationAnalysis"
|
||||
@@ -83,7 +77,7 @@ roberta_model.eval()
|
||||
|
||||
|
||||
############################################
|
||||
# ----------- FLAN MODEL -------------------
|
||||
# FLAN
|
||||
############################################
|
||||
|
||||
FLAN_PATH = "WillJeynes/LLMsForDisinformationAnalysis-Flan"
|
||||
@@ -126,7 +120,7 @@ def parse_generated_label(generated: str):
|
||||
|
||||
|
||||
############################################
|
||||
# ----------- MAIN ENDPOINT ---------------
|
||||
# ENDPOINT
|
||||
############################################
|
||||
|
||||
@app.post("/evaluate")
|
||||
|
||||
Reference in New Issue
Block a user