"""Emotion scoring for text columns of a pandas DataFrame."""

import torch
import pandas as pd

from transformers import pipeline

# Built once at import time and shared by every call to add_emotion_cols.
# top_k=None is relied on below: each result row is expected to be a list of
# {"label": ..., "score": ...} dicts rather than a single best prediction.
emotion_classifier = pipeline(
    "text-classification",
    model="j-hartmann/emotion-english-distilroberta-base",
    top_k=None,
    truncation=True,
    # Use the first CUDA device when one is available, otherwise pass -1.
    device=0 if torch.cuda.is_available() else -1,
)


def add_emotion_cols(df: pd.DataFrame, content_col: str) -> None:
    """Add one ``emotion_<label>`` score column per classifier label to *df*.

    The frame is modified in place. Every value in ``content_col`` is
    converted to ``str`` and cut to its first 512 characters before being
    sent to ``emotion_classifier`` in batches of 64.

    Args:
        df: Frame to annotate. A frame with no rows is left unchanged,
            because the label set is only known from classifier output.
        content_col: Name of the column in *df* that holds the text.

    Returns:
        None. The new columns are written directly onto *df*.

    Raises:
        KeyError: If ``content_col`` is not a column of *df*, or if a result
            row lacks a score for one of the labels seen in the first row.
    """
    texts = df[content_col].astype(str).str.slice(0, 512).tolist()
    if not texts:
        # Nothing to classify; indexing results[0] below would fail.
        return

    results = emotion_classifier(texts, batch_size=64)

    # Column order follows the label order of the first result row.
    labels = [item["label"] for item in results[0]]

    # Index each row's scores by label once instead of rescanning the row
    # for every label.
    scores_by_row = [
        {item["label"]: item["score"] for item in row}
        for row in results
    ]

    for label in labels:
        # A plain list is assigned positionally, so a non-unique index
        # (e.g. after pd.concat) does not affect row alignment.
        df[f"emotion_{label}"] = [scores[label] for scores in scores_by_row]