Add downloading from hugging face
This commit is contained in:
@@ -3,10 +3,31 @@ from sentence_transformers import SentenceTransformer
|
||||
from fastapi import FastAPI
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import hf_hub_download
|
||||
import os
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
MODEL_PATH = "logreg_classifier.pt"
|
||||
HF_REPO_ID = "WillJeynes/LLMsForDisinformationAnalysis-Regression"
|
||||
MODEL_FILENAME = "logreg_classifier.pt"
|
||||
CACHE_DIR = "./model_cache"
|
||||
|
||||
|
||||
def load_checkpoint(repo_id: str, filename: str, cache_dir: str) -> dict:
|
||||
local_path = os.path.join(cache_dir, filename)
|
||||
if not os.path.exists(local_path):
|
||||
print(f"Downloading {filename} from {repo_id}...")
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
downloaded = hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename=filename,
|
||||
local_dir=cache_dir,
|
||||
)
|
||||
print(f"Saved to {downloaded}")
|
||||
else:
|
||||
print(f"Using cached model at {local_path}")
|
||||
return torch.load(local_path, map_location="cpu")
|
||||
|
||||
|
||||
class LogisticNet(nn.Module):
|
||||
def __init__(self, input_dim: int, hidden_dim: int, num_classes: int, dropout: float):
|
||||
@@ -23,7 +44,7 @@ class LogisticNet(nn.Module):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
checkpoint = torch.load(MODEL_PATH, map_location="cpu")
|
||||
checkpoint = load_checkpoint(HF_REPO_ID, MODEL_FILENAME, CACHE_DIR)
|
||||
|
||||
encoder = SentenceTransformer(checkpoint["embedding_model"])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user