feat: add multi-label classifier and topic bucket file
This commit is contained in:
@@ -13,28 +13,30 @@ CORS(app, resources={r"/*": {"origins": "http://localhost:5173"}})
|
|||||||
# Global State
# Default corpus loaded at startup; replaced wholesale by /upload.
posts_df = pd.read_json('posts.jsonl', lines=True)
comments_df = pd.read_json('comments.jsonl', lines=True)

# Fix: the previous bare open() leaked the file handle; read the topic
# bucket list (one label per line) with an explicit encoding and close it.
with open("topic_buckets.txt", encoding="utf-8") as f:
    domain_topics = f.read().splitlines()

stat_obj = StatGen(posts_df, comments_df, domain_topics)
||||||
@app.route('/upload', methods=['POST'])
def upload_data():
    """Replace the global StatGen with freshly uploaded data.

    Expects multipart form data:
      - files "posts" and "comments": JSONL uploads
      - form field "topics": raw newline-separated topic labels (text,
        not an uploaded file)

    Returns 200 with an event count on success, 400 on validation or
    parse failure.
    """
    if "posts" not in request.files or "comments" not in request.files or "topics" not in request.form:
        return jsonify({"error": "Missing required files or form data"}), 400

    post_file = request.files["posts"]
    comment_file = request.files["comments"]
    topic_text = request.form["topics"]

    if post_file.filename == "" or comment_file.filename == "" or topic_text == "":
        return jsonify({"error": "Empty filename"}), 400

    # Only the two uploaded files get an extension check. "topics" is raw
    # form text, so the previous `topic_file.endswith('.txt')` test checked
    # the topic *content* for a filename suffix and rejected every
    # legitimate topic list.
    if not post_file.filename.endswith('.jsonl') or not comment_file.filename.endswith('.jsonl'):
        return jsonify({"error": "Invalid file type. Only .jsonl files are allowed."}), 400

    try:
        global stat_obj

        posts_df = pd.read_json(post_file, lines=True)
        comments_df = pd.read_json(comment_file, lines=True)
        stat_obj = StatGen(posts_df, comments_df, topic_text.splitlines())
        return jsonify({"message": "File uploaded successfully", "event_count": len(stat_obj.df)}), 200
    except ValueError as e:
        return jsonify({"error": f"Failed to read JSONL file: {str(e)}"}), 400
|
|||||||
import torch
import pandas as pd
import numpy as np

from transformers import pipeline
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Shared embedding model for topic assignment.
# Fix: fall back to the CPU when CUDA is unavailable. The previous
# `device=0 if torch.cuda.is_available() else 1` selected GPU index 1 on
# machines with no GPU at all, which fails at model-load time.
model = SentenceTransformer("all-MiniLM-L6-v2", device="cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
|
def add_emotion_cols(
|
||||||
|
df: pd.DataFrame,
|
||||||
|
content_col: str
|
||||||
|
) -> None:
|
||||||
emotion_classifier = pipeline(
|
emotion_classifier = pipeline(
|
||||||
"text-classification",
|
"text-classification",
|
||||||
model="j-hartmann/emotion-english-distilroberta-base",
|
model="j-hartmann/emotion-english-distilroberta-base",
|
||||||
@@ -31,16 +36,32 @@ def add_emotion_cols(df: pd.DataFrame, content_col: str) -> None:
|
|||||||
for row in results
|
for row in results
|
||||||
]
|
]
|
||||||
|
|
||||||
def add_topic_col(
    df: pd.DataFrame,
    content_col: str,
    domain_topics: list[str],
    confidence_threshold: float = 0.15,
) -> pd.DataFrame:
    """Assign each row the closest domain topic by embedding similarity.

    Adds two columns to *df* in place:
      - "topic": the best-matching entry from *domain_topics*, or "Misc"
        when the best cosine similarity is below *confidence_threshold*.
      - "topic_confidence": the best cosine-similarity score per row.

    Returns *df* for convenience (the mutation happens in place; the old
    `-> None` annotation contradicted the `return df` statement).
    """
    topic_embeddings = model.encode(
        domain_topics,
        normalize_embeddings=True,
    )

    # fillna("") restores the previous behaviour that this commit dropped:
    # without it, missing content becomes the literal string "nan" and is
    # embedded/matched as if it were real text.
    texts = df[content_col].fillna("").astype(str).tolist()
    text_embeddings = model.encode(
        texts,
        normalize_embeddings=True,
    )

    # Cosine similarity of every text against every candidate topic.
    sims = cosine_similarity(text_embeddings, topic_embeddings)

    # Best-matching topic per row, plus its score.
    best_idx = sims.argmax(axis=1)
    best_score = sims.max(axis=1)

    df["topic"] = [domain_topics[i] for i in best_idx]
    df["topic_confidence"] = best_score

    # Low-confidence matches are bucketed under "Misc".
    df.loc[df["topic_confidence"] < confidence_threshold, "topic"] = "Misc"

    return df
||||||
@@ -21,12 +21,13 @@ nltk.download('stopwords')
|
|||||||
EXCLUDE_WORDS = set(stopwords.words('english')) | DOMAIN_STOPWORDS
|
EXCLUDE_WORDS = set(stopwords.words('english')) | DOMAIN_STOPWORDS
|
||||||
|
|
||||||
class StatGen:
|
class StatGen:
|
||||||
def __init__(self, posts_df: pd.DataFrame, comments_df: pd.DataFrame, domain_topics: list) -> None:
    """Combine posts and comments into one event frame with derived columns.

    posts_df / comments_df: raw event frames (not modified).
    domain_topics: topic labels forwarded to topic assignment later.
    """
    # Fix: work on copies — the previous code wrote "type"/"parent_id"
    # straight into the caller's DataFrames as a side effect.
    posts_df = posts_df.copy()
    comments_df = comments_df.copy()

    posts_df["type"] = "post"
    posts_df["parent_id"] = None

    comments_df["type"] = "comment"
    # .get() tolerates a comments file without a post_id column.
    comments_df["parent_id"] = comments_df.get("post_id")

    self.domain_topics = domain_topics

    self.df = pd.concat([posts_df, comments_df])
    self._add_extra_cols(self.df)
|
||||||
@@ -41,7 +42,7 @@ class StatGen:
|
|||||||
df["weekday"] = df["dt"].dt.day_name()
|
df["weekday"] = df["dt"].dt.day_name()
|
||||||
|
|
||||||
add_emotion_cols(df, "content")
|
add_emotion_cols(df, "content")
|
||||||
add_topic_col(df, "content")
|
add_topic_col(df, "content", self.domain_topics)
|
||||||
|
|
||||||
def _tokenize(self, text: str):
|
def _tokenize(self, text: str):
|
||||||
tokens = re.findall(r"\b[a-z]{3,}\b", text)
|
tokens = re.findall(r"\b[a-z]{3,}\b", text)
|
||||||
|
|||||||
Reference in New Issue
Block a user