From 87fccb7e2b15fb3dcf71d44f0a5cc696e3074d50 Mon Sep 17 00:00:00 2001 From: William Jeynes Date: Tue, 24 Mar 2026 13:23:08 +0000 Subject: [PATCH] Add downloading from hugging face --- .../RAGAS_Service/regression_service.py | 25 +++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/supporting/RAGAS_Service/regression_service.py b/supporting/RAGAS_Service/regression_service.py index a387e63..9e53edf 100644 --- a/supporting/RAGAS_Service/regression_service.py +++ b/supporting/RAGAS_Service/regression_service.py @@ -3,10 +3,31 @@ from sentence_transformers import SentenceTransformer from fastapi import FastAPI import torch import torch.nn as nn +from huggingface_hub import hf_hub_download +import os app = FastAPI() -MODEL_PATH = "logreg_classifier.pt" +HF_REPO_ID = "WillJeynes/LLMsForDisinformationAnalysis-Regression" +MODEL_FILENAME = "logreg_classifier.pt" +CACHE_DIR = "./model_cache" + + +def load_checkpoint(repo_id: str, filename: str, cache_dir: str) -> dict: + local_path = os.path.join(cache_dir, filename) + if not os.path.exists(local_path): + print(f"Downloading {filename} from {repo_id}...") + os.makedirs(cache_dir, exist_ok=True) + downloaded = hf_hub_download( + repo_id=repo_id, + filename=filename, + local_dir=cache_dir, + ) + print(f"Saved to {downloaded}") + else: + print(f"Using cached model at {local_path}") + return torch.load(local_path, map_location="cpu") + class LogisticNet(nn.Module): def __init__(self, input_dim: int, hidden_dim: int, num_classes: int, dropout: float): @@ -23,7 +44,7 @@ class LogisticNet(nn.Module): return self.net(x) -checkpoint = torch.load(MODEL_PATH, map_location="cpu") +checkpoint = load_checkpoint(HF_REPO_ID, MODEL_FILENAME, CACHE_DIR) encoder = SentenceTransformer(checkpoint["embedding_model"])