feat: add descriptions to topics to improve accuracy
Also upgraded the sentence-embedding model from all-MiniLM-L6-v2 to the more powerful all-mpnet-base-v2.
This commit is contained in:
@@ -1,13 +1,12 @@
|
||||
import torch
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
from transformers import pipeline
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
|
||||
|
||||
model = SentenceTransformer("all-MiniLM-L6-v2", device=0 if torch.cuda.is_available() else 1)
|
||||
# Shared sentence-embedding model, loaded once at import time.
# Select the device with explicit strings: a bare integer is treated by torch
# as an accelerator (CUDA) ordinal, so an integer fallback such as 1 does not
# select the CPU and fails on hosts without a GPU.
model = SentenceTransformer(
    "all-mpnet-base-v2",
    device="cuda:0" if torch.cuda.is_available() else "cpu",
)
|
||||
|
||||
def add_emotion_cols(
|
||||
df: pd.DataFrame,
|
||||
@@ -38,16 +37,28 @@ def add_emotion_cols(
|
||||
|
||||
def add_topic_col(
|
||||
df: pd.DataFrame,
|
||||
title_col: str,
|
||||
content_col: str,
|
||||
domain_topics: list[str],
|
||||
confidence_threshold: float = 0.15
|
||||
domain_topics: dict,
|
||||
confidence_threshold: float = 0.20
|
||||
) -> None:
|
||||
|
||||
topic_labels = list(domain_topics.keys())
|
||||
topic_texts = list(domain_topics.values())
|
||||
|
||||
topic_embeddings = model.encode(
|
||||
domain_topics,
|
||||
topic_texts,
|
||||
normalize_embeddings=True,
|
||||
)
|
||||
|
||||
texts = df[content_col].astype(str).tolist()
|
||||
titles = df[title_col].fillna("").astype(str)
|
||||
contents = df[content_col].fillna("").astype(str)
|
||||
|
||||
texts = [
|
||||
f"{title}. {content}" if title else content
|
||||
for title, content in zip(titles, contents)
|
||||
]
|
||||
|
||||
text_embeddings = model.encode(
|
||||
texts,
|
||||
normalize_embeddings=True,
|
||||
@@ -60,8 +71,9 @@ def add_topic_col(
|
||||
best_idx = sims.argmax(axis=1)
|
||||
best_score = sims.max(axis=1)
|
||||
|
||||
df["topic"] = [domain_topics[i] for i in best_idx]
|
||||
df["topic"] = [topic_labels[i] for i in best_idx]
|
||||
df["topic_confidence"] = best_score
|
||||
|
||||
df.loc[df["topic_confidence"] < confidence_threshold, "topic"] = "Misc"
|
||||
|
||||
return df
|
||||
Reference in New Issue
Block a user