diff --git a/server/stat_gen.py b/server/stat_gen.py index 35abb08..f2f5ae6 100644 --- a/server/stat_gen.py +++ b/server/stat_gen.py @@ -192,8 +192,24 @@ class StatGen: .reset_index(drop=True) ) + # avearge emotion by topic (excluding neutral) + emotion_cols = [ + col for col in self.df.columns + if col.startswith("emotion_") and col != "emotion_neutral" + ] + + avg_emotion_by_topic = ( + self.df[ + (self.df["topic"] != "Misc") + ] + .groupby("topic")[emotion_cols] + .mean() + .reset_index() + ) + return { - "word_frequencies": word_frequencies.to_dict(orient='records') + "word_frequencies": word_frequencies.to_dict(orient='records'), + "average_emotion_by_topic": avg_emotion_by_topic.to_dict(orient='records') } def user_analysis(self) -> dict: