feat: add descriptions to topics to improve accuracy
Also upgraded to more powerful model
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -2,9 +2,9 @@
|
|||||||
__pycache__/
|
__pycache__/
|
||||||
*.pyc
|
*.pyc
|
||||||
*.jsonl
|
*.jsonl
|
||||||
|
*.json
|
||||||
*.code-workspace
|
*.code-workspace
|
||||||
.env
|
.env
|
||||||
topic_buckets.txt
|
|
||||||
|
|
||||||
# React App Vite
|
# React App Vite
|
||||||
node_modules/
|
node_modules/
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from server.stat_gen import StatGen
|
|||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import traceback
|
import traceback
|
||||||
|
import json
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
|
|
||||||
@@ -13,7 +14,8 @@ CORS(app, resources={r"/*": {"origins": "http://localhost:5173"}})
|
|||||||
# Global State
|
# Global State
|
||||||
posts_df = pd.read_json('posts.jsonl', lines=True)
|
posts_df = pd.read_json('posts.jsonl', lines=True)
|
||||||
comments_df = pd.read_json('comments.jsonl', lines=True)
|
comments_df = pd.read_json('comments.jsonl', lines=True)
|
||||||
domain_topics = open("topic_buckets.txt").read().splitlines()
|
with open("topic_buckets.json", "r", encoding="utf-8") as f:
|
||||||
|
domain_topics = json.load(f)
|
||||||
stat_obj = StatGen(posts_df, comments_df, domain_topics)
|
stat_obj = StatGen(posts_df, comments_df, domain_topics)
|
||||||
|
|
||||||
@app.route('/upload', methods=['POST'])
|
@app.route('/upload', methods=['POST'])
|
||||||
@@ -28,15 +30,15 @@ def upload_data():
|
|||||||
if post_file.filename == "" or comment_file.filename == "" or topic_file == "":
|
if post_file.filename == "" or comment_file.filename == "" or topic_file == "":
|
||||||
return jsonify({"error": "Empty filename"}), 400
|
return jsonify({"error": "Empty filename"}), 400
|
||||||
|
|
||||||
if not post_file.filename.endswith('.jsonl') or not comment_file.filename.endswith('.jsonl') or not topic_file.endswith('.txt'):
|
if not post_file.filename.endswith('.jsonl') or not comment_file.filename.endswith('.jsonl') or not topic_file.endswith('.json'):
|
||||||
return jsonify({"error": "Invalid file type. Only .jsonl and .txt files are allowed."}), 400
|
return jsonify({"error": "Invalid file type. Only .jsonl and .json files are allowed."}), 400
|
||||||
|
|
||||||
try:
|
try:
|
||||||
global stat_obj
|
global stat_obj
|
||||||
|
|
||||||
posts_df = pd.read_json(post_file, lines=True)
|
posts_df = pd.read_json(post_file, lines=True)
|
||||||
comments_df = pd.read_json(comment_file, lines=True)
|
comments_df = pd.read_json(comment_file, lines=True)
|
||||||
stat_obj = StatGen(posts_df, comments_df, topic_file.splitlines())
|
stat_obj = StatGen(posts_df, comments_df, json.load(topic_file))
|
||||||
return jsonify({"message": "File uploaded successfully", "event_count": len(stat_obj.df)}), 200
|
return jsonify({"message": "File uploaded successfully", "event_count": len(stat_obj.df)}), 200
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return jsonify({"error": f"Failed to read JSONL file: {str(e)}"}), 400
|
return jsonify({"error": f"Failed to read JSONL file: {str(e)}"}), 400
|
||||||
|
|||||||
@@ -1,13 +1,12 @@
|
|||||||
import torch
|
import torch
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from transformers import pipeline
|
from transformers import pipeline
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
from sklearn.metrics.pairwise import cosine_similarity
|
from sklearn.metrics.pairwise import cosine_similarity
|
||||||
|
|
||||||
|
|
||||||
model = SentenceTransformer("all-MiniLM-L6-v2", device=0 if torch.cuda.is_available() else 1)
|
model = SentenceTransformer("all-mpnet-base-v2", device=0 if torch.cuda.is_available() else 1)
|
||||||
|
|
||||||
def add_emotion_cols(
|
def add_emotion_cols(
|
||||||
df: pd.DataFrame,
|
df: pd.DataFrame,
|
||||||
@@ -38,16 +37,28 @@ def add_emotion_cols(
|
|||||||
|
|
||||||
def add_topic_col(
|
def add_topic_col(
|
||||||
df: pd.DataFrame,
|
df: pd.DataFrame,
|
||||||
|
title_col: str,
|
||||||
content_col: str,
|
content_col: str,
|
||||||
domain_topics: list[str],
|
domain_topics: dict,
|
||||||
confidence_threshold: float = 0.15
|
confidence_threshold: float = 0.20
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
|
topic_labels = list(domain_topics.keys())
|
||||||
|
topic_texts = list(domain_topics.values())
|
||||||
|
|
||||||
topic_embeddings = model.encode(
|
topic_embeddings = model.encode(
|
||||||
domain_topics,
|
topic_texts,
|
||||||
normalize_embeddings=True,
|
normalize_embeddings=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
texts = df[content_col].astype(str).tolist()
|
titles = df[title_col].fillna("").astype(str)
|
||||||
|
contents = df[content_col].fillna("").astype(str)
|
||||||
|
|
||||||
|
texts = [
|
||||||
|
f"{title}. {content}" if title else content
|
||||||
|
for title, content in zip(titles, contents)
|
||||||
|
]
|
||||||
|
|
||||||
text_embeddings = model.encode(
|
text_embeddings = model.encode(
|
||||||
texts,
|
texts,
|
||||||
normalize_embeddings=True,
|
normalize_embeddings=True,
|
||||||
@@ -60,8 +71,9 @@ def add_topic_col(
|
|||||||
best_idx = sims.argmax(axis=1)
|
best_idx = sims.argmax(axis=1)
|
||||||
best_score = sims.max(axis=1)
|
best_score = sims.max(axis=1)
|
||||||
|
|
||||||
df["topic"] = [domain_topics[i] for i in best_idx]
|
df["topic"] = [topic_labels[i] for i in best_idx]
|
||||||
df["topic_confidence"] = best_score
|
df["topic_confidence"] = best_score
|
||||||
|
|
||||||
df.loc[df["topic_confidence"] < confidence_threshold, "topic"] = "Misc"
|
df.loc[df["topic_confidence"] < confidence_threshold, "topic"] = "Misc"
|
||||||
|
|
||||||
return df
|
return df
|
||||||
@@ -42,7 +42,7 @@ class StatGen:
|
|||||||
df["weekday"] = df["dt"].dt.day_name()
|
df["weekday"] = df["dt"].dt.day_name()
|
||||||
|
|
||||||
add_emotion_cols(df, "content")
|
add_emotion_cols(df, "content")
|
||||||
add_topic_col(df, "content", self.domain_topics)
|
add_topic_col(df, "title", "content", self.domain_topics)
|
||||||
|
|
||||||
def _tokenize(self, text: str):
|
def _tokenize(self, text: str):
|
||||||
tokens = re.findall(r"\b[a-z]{3,}\b", text)
|
tokens = re.findall(r"\b[a-z]{3,}\b", text)
|
||||||
|
|||||||
Reference in New Issue
Block a user