perf: use GPU for topic AI & move model init into functions

Moving model initialisation into the functions themselves means each model is unloaded from memory once its function completes, which avoids out-of-memory (OOM) errors.
This commit is contained in:
2026-02-05 19:11:51 +00:00
parent ed70035fd4
commit 4abbd0643e

View File

@@ -3,18 +3,19 @@ import pandas as pd
from transformers import pipeline from transformers import pipeline
from keybert import KeyBERT from keybert import KeyBERT
from sentence_transformers import SentenceTransformer
kw_model = KeyBERT(model="all-MiniLM-L6-v2") sentence_model = SentenceTransformer("all-MiniLM-L6-v2", device="cuda")
emotion_classifier = pipeline( def add_emotion_cols(df: pd.DataFrame, content_col: str) -> None:
"text-classification", emotion_classifier = pipeline(
model="j-hartmann/emotion-english-distilroberta-base", "text-classification",
top_k=None, model="j-hartmann/emotion-english-distilroberta-base",
truncation=True, top_k=None,
device=0 if torch.cuda.is_available() else -1 truncation=True,
) device=0 if torch.cuda.is_available() else -1
)
def add_emotion_cols(df: pd.Dataframe, content_col: str) -> None:
texts = df[content_col].astype(str).str.slice(0, 512).tolist() texts = df[content_col].astype(str).str.slice(0, 512).tolist()
results = emotion_classifier( results = emotion_classifier(
@@ -30,17 +31,16 @@ def add_emotion_cols(df: pd.Dataframe, content_col: str) -> None:
for row in results for row in results
] ]
def add_topic_col(df: pd.DataFrame, content_col: str, top_n: int = 3) -> None: def add_topic_col(df: pd.DataFrame, content_col: str):
topics = [] kw_model = KeyBERT(model=sentence_model)
for text in df["content"].astype(str): texts = df[content_col].fillna("").astype(str).tolist()
keywords = kw_model.extract_keywords(
text,
keyphrase_ngram_range=(1, 3),
stop_words="english",
top_n=top_n
)
topics.append([kw for kw, _ in keywords]) raw_results = kw_model.extract_keywords(
texts,
keyphrase_ngram_range=(1, 1),
stop_words='english',
top_n=1
)
df["topics"] = topics df['theme'] = [res[0][0] if len(res) > 0 else None for res in raw_results]