diff --git a/frontend/src/pages/Stats.tsx b/frontend/src/pages/Stats.tsx index af1bed4..53deb07 100644 --- a/frontend/src/pages/Stats.tsx +++ b/frontend/src/pages/Stats.tsx @@ -67,7 +67,7 @@ const StatPage = () => { }; const onSearch = () => { - const query = inputRef.current?.value ?? ""; // read input only on click + const query = inputRef.current?.value ?? ""; axios.post("http://localhost:5000/filter/search", { query: query }) diff --git a/server/app.py b/server/app.py index 15dc625..c926c3c 100644 --- a/server/app.py +++ b/server/app.py @@ -115,6 +115,27 @@ def filter_time(): except Exception: return jsonify({"error": "Invalid datetime format"}), 400 +@app.route('/filter/sources', methods=["POST"]) +def filter_sources(): + if stat_obj is None: + return jsonify({"error": "No data uploaded"}), 400 + + data = request.get_json(silent=True) + if not data: + return jsonify({"error": "Invalid or missing JSON body"}), 400 + + if "sources" not in data: + return jsonify({"error": "Ensure sources hash map is in 'sources' key"}), 400 + + try: + filtered_df = stat_obj.filter_data_sources(data["sources"]) + return jsonify(filtered_df), 200 + except ValueError: + return jsonify({"error": "Please enable at least one data source"}), 400 + except Exception as e: + return jsonify({"error": "An unexpected server error occured: " + str(e)}), 500 + + @app.route('/filter/reset', methods=["GET"]) def reset_dataset(): if stat_obj is None: diff --git a/server/stat_gen.py b/server/stat_gen.py index eaf4b3f..4960b20 100644 --- a/server/stat_gen.py +++ b/server/stat_gen.py @@ -125,7 +125,7 @@ class StatGen: "word_frequencies": word_frequencies.to_dict(orient='records') } - def search(self, search_query: str) -> pd.DataFrame: + def search(self, search_query: str) -> dict: self.df = self.df[ self.df["content"].str.contains(search_query) ] @@ -135,7 +135,7 @@ class StatGen: "data": self.df.to_dict(orient="records") } - def set_time_range(self, start: datetime.datetime, end: datetime.datetime): + def set_time_range(self, start: datetime.datetime, end: datetime.datetime) -> dict: self.df = self.df[ (self.df["dt"] >= start) & (self.df["dt"] <= end) @@ -146,6 +146,23 @@ class StatGen: "data": self.df.to_dict(orient="records") } + """ + Input is a hash map (source_name: str -> enabled: bool) + """ + def filter_data_sources(self, data_sources: dict) -> dict: + enabled_sources = [src for src, enabled in data_sources.items() if enabled] + + if not enabled_sources: + raise ValueError("Please choose at least one data source") + + self.df = self.df[self.df["source"].isin(enabled_sources)] + + return { + "rows": len(self.df), + "data": self.df.to_dict(orient="records") + } + + def reset_dataset(self) -> None: self.df = self.original_df.copy(deep=True)