feat: add descriptions to topics to improve accuracy

Also upgraded to a more powerful embedding model
This commit is contained in:
2026-02-08 15:10:11 +00:00
parent b019885b2f
commit a9d63c7041
4 changed files with 27 additions and 13 deletions

2
.gitignore vendored
View File

@@ -2,9 +2,9 @@
__pycache__/ __pycache__/
*.pyc *.pyc
*.jsonl *.jsonl
*.json
*.code-workspace *.code-workspace
.env .env
topic_buckets.txt
# React App Vite # React App Vite
node_modules/ node_modules/

View File

@@ -4,6 +4,7 @@ from server.stat_gen import StatGen
import pandas as pd import pandas as pd
import traceback import traceback
import json
app = Flask(__name__) app = Flask(__name__)
@@ -13,7 +14,8 @@ CORS(app, resources={r"/*": {"origins": "http://localhost:5173"}})
# Global State # Global State
posts_df = pd.read_json('posts.jsonl', lines=True) posts_df = pd.read_json('posts.jsonl', lines=True)
comments_df = pd.read_json('comments.jsonl', lines=True) comments_df = pd.read_json('comments.jsonl', lines=True)
domain_topics = open("topic_buckets.txt").read().splitlines() with open("topic_buckets.json", "r", encoding="utf-8") as f:
domain_topics = json.load(f)
stat_obj = StatGen(posts_df, comments_df, domain_topics) stat_obj = StatGen(posts_df, comments_df, domain_topics)
@app.route('/upload', methods=['POST']) @app.route('/upload', methods=['POST'])
@@ -28,15 +30,15 @@ def upload_data():
if post_file.filename == "" or comment_file.filename == "" or topic_file == "": if post_file.filename == "" or comment_file.filename == "" or topic_file == "":
return jsonify({"error": "Empty filename"}), 400 return jsonify({"error": "Empty filename"}), 400
if not post_file.filename.endswith('.jsonl') or not comment_file.filename.endswith('.jsonl') or not topic_file.endswith('.txt'): if not post_file.filename.endswith('.jsonl') or not comment_file.filename.endswith('.jsonl') or not topic_file.endswith('.json'):
return jsonify({"error": "Invalid file type. Only .jsonl and .txt files are allowed."}), 400 return jsonify({"error": "Invalid file type. Only .jsonl and .json files are allowed."}), 400
try: try:
global stat_obj global stat_obj
posts_df = pd.read_json(post_file, lines=True) posts_df = pd.read_json(post_file, lines=True)
comments_df = pd.read_json(comment_file, lines=True) comments_df = pd.read_json(comment_file, lines=True)
stat_obj = StatGen(posts_df, comments_df, topic_file.splitlines()) stat_obj = StatGen(posts_df, comments_df, json.load(topic_file))
return jsonify({"message": "File uploaded successfully", "event_count": len(stat_obj.df)}), 200 return jsonify({"message": "File uploaded successfully", "event_count": len(stat_obj.df)}), 200
except ValueError as e: except ValueError as e:
return jsonify({"error": f"Failed to read JSONL file: {str(e)}"}), 400 return jsonify({"error": f"Failed to read JSONL file: {str(e)}"}), 400

View File

@@ -1,13 +1,12 @@
import torch import torch
import pandas as pd import pandas as pd
import numpy as np
from transformers import pipeline from transformers import pipeline
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity from sklearn.metrics.pairwise import cosine_similarity
model = SentenceTransformer("all-MiniLM-L6-v2", device=0 if torch.cuda.is_available() else 1) model = SentenceTransformer("all-mpnet-base-v2", device=0 if torch.cuda.is_available() else 1)
def add_emotion_cols( def add_emotion_cols(
df: pd.DataFrame, df: pd.DataFrame,
@@ -38,16 +37,28 @@ def add_emotion_cols(
def add_topic_col( def add_topic_col(
df: pd.DataFrame, df: pd.DataFrame,
title_col: str,
content_col: str, content_col: str,
domain_topics: list[str], domain_topics: dict,
confidence_threshold: float = 0.15 confidence_threshold: float = 0.20
) -> None: ) -> None:
topic_labels = list(domain_topics.keys())
topic_texts = list(domain_topics.values())
topic_embeddings = model.encode( topic_embeddings = model.encode(
domain_topics, topic_texts,
normalize_embeddings=True, normalize_embeddings=True,
) )
texts = df[content_col].astype(str).tolist() titles = df[title_col].fillna("").astype(str)
contents = df[content_col].fillna("").astype(str)
texts = [
f"{title}. {content}" if title else content
for title, content in zip(titles, contents)
]
text_embeddings = model.encode( text_embeddings = model.encode(
texts, texts,
normalize_embeddings=True, normalize_embeddings=True,
@@ -60,8 +71,9 @@ def add_topic_col(
best_idx = sims.argmax(axis=1) best_idx = sims.argmax(axis=1)
best_score = sims.max(axis=1) best_score = sims.max(axis=1)
df["topic"] = [domain_topics[i] for i in best_idx] df["topic"] = [topic_labels[i] for i in best_idx]
df["topic_confidence"] = best_score df["topic_confidence"] = best_score
df.loc[df["topic_confidence"] < confidence_threshold, "topic"] = "Misc" df.loc[df["topic_confidence"] < confidence_threshold, "topic"] = "Misc"
return df return df

View File

@@ -42,7 +42,7 @@ class StatGen:
df["weekday"] = df["dt"].dt.day_name() df["weekday"] = df["dt"].dt.day_name()
add_emotion_cols(df, "content") add_emotion_cols(df, "content")
add_topic_col(df, "content", self.domain_topics) add_topic_col(df, "title", "content", self.domain_topics)
def _tokenize(self, text: str): def _tokenize(self, text: str):
tokens = re.findall(r"\b[a-z]{3,}\b", text) tokens = re.findall(r"\b[a-z]{3,}\b", text)