diff --git a/server/stat_gen.py b/server/stat_gen.py index e781580..980990e 100644 --- a/server/stat_gen.py +++ b/server/stat_gen.py @@ -2,11 +2,19 @@ import pandas as pd import re import nltk import datetime +import torch from nltk.corpus import stopwords from collections import Counter +from transformers import pipeline -from pprint import pprint +emotion_classifier = pipeline( + "text-classification", + model="j-hartmann/emotion-english-distilroberta-base", + top_k=None, + truncation=True, + device=0 if torch.cuda.is_available() else -1 +) DOMAIN_STOPWORDS = { "www", "https", "http", @@ -30,17 +38,37 @@ class StatGen: comments_df["parent_id"] = comments_df.get("post_id") self.df = pd.concat([posts_df, comments_df]) - self._add_date_cols(self.df) + self._add_extra_cols(self.df) + self._add_emotion_cols(self.df) self.original_df = self.df.copy(deep=True) ## Private Methods - def _add_date_cols(self, df: pd.DataFrame) -> None: + def _add_extra_cols(self, df: pd.DataFrame) -> None: df['date'] = pd.to_datetime(df['timestamp'], unit='s').dt.date df["dt"] = pd.to_datetime(df["timestamp"], unit="s", utc=True) df["hour"] = df["dt"].dt.hour df["weekday"] = df["dt"].dt.day_name() + def _add_emotion_cols(self, df: pd.DataFrame) -> None: + texts = df["content"].astype(str).str.slice(0, 512).tolist() + + results = emotion_classifier( + texts, + batch_size=64 + ) + + labels = [r["label"] for r in results[0]] + + for label in labels: + df[f"emotion_{label}"] = [ + next(item["score"] for item in row if item["label"] == label) + for row in results + ] + + # strongest emotion per row (much more meaningful than sums) + df["emotion_intensity"] = df.filter(like="emotion_").max(axis=1) + def _tokenize(self, text: str): tokens = re.findall(r"\b[a-z]{3,}\b", text) return [t for t in tokens if t not in EXCLUDE_WORDS]