diff --git a/server/app.py b/server/app.py
index e85adb0..48e757a 100644
--- a/server/app.py
+++ b/server/app.py
@@ -13,28 +13,34 @@ CORS(app, resources={r"/*": {"origins": "http://localhost:5173"}})
 
 # Global State
 posts_df = pd.read_json('posts.jsonl', lines=True)
 comments_df = pd.read_json('comments.jsonl', lines=True)
-stat_obj = StatGen(posts_df, comments_df)
+# Load the default topic buckets once at startup; "with" closes the handle.
+with open("topic_buckets.txt") as f:
+    domain_topics = f.read().splitlines()
+stat_obj = StatGen(posts_df, comments_df, domain_topics)
 
 @app.route('/upload', methods=['POST'])
 def upload_data():
-    if "posts" not in request.files or "comments" not in request.files:
-        return jsonify({"error": "Missing posts or comments file"}), 400
+    if "posts" not in request.files or "comments" not in request.files or "topics" not in request.form:
+        return jsonify({"error": "Missing required files or form data"}), 400
 
     post_file = request.files["posts"]
     comment_file = request.files["comments"]
+    # "topics" arrives as raw form text (one topic per line), not a file upload.
+    topics_text = request.form["topics"]
 
-    if post_file.filename == "" or comment_file.filename == "":
-        return jsonify({"error": "Empty filename"}), 400
+    if post_file.filename == "" or comment_file.filename == "" or topics_text.strip() == "":
+        return jsonify({"error": "Empty filename or topics"}), 400
 
-    if not post_file.filename.endswith('.jsonl') or not comment_file.filename.endswith('.jsonl'):
-        return jsonify({"error": "Invalid file type. Only .jsonl files are allowed."}), 400
+    # Form text has no filename, so only the two uploads get an extension check.
+    if not post_file.filename.endswith('.jsonl') or not comment_file.filename.endswith('.jsonl'):
+        return jsonify({"error": "Invalid file type. Only .jsonl files are allowed."}), 400
 
     try:
         global stat_obj
         posts_df = pd.read_json(post_file, lines=True)
         comments_df = pd.read_json(comment_file, lines=True)
-        stat_obj = StatGen(posts_df, comments_df)
+        stat_obj = StatGen(posts_df, comments_df, topics_text.splitlines())
         return jsonify({"message": "File uploaded successfully", "event_count": len(stat_obj.df)}), 200
     except ValueError as e:
         return jsonify({"error": f"Failed to read JSONL file: {str(e)}"}), 400
diff --git a/server/nlp.py b/server/nlp.py
index 494c30e..d5a04f3 100644
--- a/server/nlp.py
+++ b/server/nlp.py
@@ -1,12 +1,13 @@
 import torch
 import pandas as pd
 from transformers import pipeline
-from keybert import KeyBERT
 from sentence_transformers import SentenceTransformer
+from sklearn.metrics.pairwise import cosine_similarity
 
-sentence_model = SentenceTransformer("all-MiniLM-L6-v2", device="cuda")
+# SentenceTransformer expects a device *string*; an int is read as a CUDA
+# ordinal (1 would mean "second GPU"), so fall back to "cpu" explicitly.
+sentence_model = SentenceTransformer("all-MiniLM-L6-v2", device="cuda" if torch.cuda.is_available() else "cpu")
 
 def add_emotion_cols(df: pd.DataFrame, content_col: str) -> None:
     emotion_classifier = pipeline(
         "text-classification",
         model="j-hartmann/emotion-english-distilroberta-base",
@@ -31,15 +32,38 @@ def add_emotion_cols(df: pd.DataFrame, content_col: str) -> None:
         for row in results
     ]
 
-def add_topic_col(df: pd.DataFrame, content_col: str):
-    kw_model = KeyBERT(model=sentence_model)
-
-    texts = df[content_col].fillna("").astype(str).tolist()
-
-    raw_results = kw_model.extract_keywords(
-        texts,
-        keyphrase_ngram_range=(1, 1),
-        stop_words='english',
-        top_n=1
+def add_topic_col(
+    df: pd.DataFrame,
+    content_col: str,
+    domain_topics: list[str],
+    confidence_threshold: float = 0.15,
+) -> None:
+    """Label each row with the closest domain topic by embedding similarity.
+
+    Mutates df in place (adds "topic" and "topic_confidence" columns) and
+    returns None, matching add_emotion_cols. Rows whose best cosine
+    similarity falls below confidence_threshold are bucketed as "Misc".
+    """
+    # NOTE(review): output column renamed 'theme' -> 'topic'; confirm all
+    # downstream consumers were updated accordingly.
+    topic_embeddings = sentence_model.encode(
+        domain_topics,
+        normalize_embeddings=True,
     )
-    df['theme'] = [res[0][0] if len(res) > 0 else None for res in raw_results]
\ No newline at end of file
+    # fillna("") keeps missing content from embedding as the string "nan".
+    texts = df[content_col].fillna("").astype(str).tolist()
+    text_embeddings = sentence_model.encode(
+        texts,
+        normalize_embeddings=True,
+    )
+
+    # Cosine similarity of every text against every candidate topic.
+    sims = cosine_similarity(text_embeddings, topic_embeddings)
+
+    # Best-matching topic and its confidence per row.
+    best_idx = sims.argmax(axis=1)
+    best_score = sims.max(axis=1)
+
+    df["topic"] = [domain_topics[i] for i in best_idx]
+    df["topic_confidence"] = best_score
+    df.loc[df["topic_confidence"] < confidence_threshold, "topic"] = "Misc"
diff --git a/server/stat_gen.py b/server/stat_gen.py
index 9612d68..d2beed6 100644
--- a/server/stat_gen.py
+++ b/server/stat_gen.py
@@ -21,12 +21,13 @@ nltk.download('stopwords')
 
 EXCLUDE_WORDS = set(stopwords.words('english')) | DOMAIN_STOPWORDS
 
 class StatGen:
-    def __init__(self, posts_df: pd.DataFrame, comments_df: pd.DataFrame) -> None:
+    def __init__(self, posts_df: pd.DataFrame, comments_df: pd.DataFrame, domain_topics: list[str]) -> None:
         posts_df["type"] = "post"
         posts_df["parent_id"] = None
         comments_df["type"] = "comment"
         comments_df["parent_id"] = comments_df.get("post_id")
 
+        self.domain_topics = domain_topics
         self.df = pd.concat([posts_df, comments_df])
         self._add_extra_cols(self.df)
@@ -41,7 +42,7 @@ class StatGen:
         df["weekday"] = df["dt"].dt.day_name()
 
         add_emotion_cols(df, "content")
-        add_topic_col(df, "content")
+        add_topic_col(df, "content", self.domain_topics)
 
     def _tokenize(self, text: str):
         tokens = re.findall(r"\b[a-z]{3,}\b", text)