From 5ea71023b55922cc554365344b430cd6fa5d0df6 Mon Sep 17 00:00:00 2001 From: Dylan De Faoite Date: Mon, 2 Mar 2026 18:29:09 +0000 Subject: [PATCH] refactor: move query parameter extraction function out of flask app --- server/app.py | 66 +++++++--------------------------------------- server/stat_gen.py | 16 ++++++----- server/utils.py | 50 +++++++++++++++++++++++++++++++++++ 3 files changed, 69 insertions(+), 63 deletions(-) create mode 100644 server/utils.py diff --git a/server/app.py b/server/app.py index 64f4f53..eef9d7a 100644 --- a/server/app.py +++ b/server/app.py @@ -16,6 +16,7 @@ from server.stat_gen import StatGen from server.dataset_processor import DatasetProcessor from db.database import PostgresConnector from server.auth import AuthManager +from server.utils import get_request_filters, parse_datetime_filter import pandas as pd import traceback @@ -43,56 +44,6 @@ auth_manager = AuthManager(db, bcrypt) stat_gen = StatGen() - -def _parse_datetime_filter(value): - if not value: - return None - - try: - return datetime.datetime.fromisoformat(value) - except ValueError: - try: - return datetime.datetime.fromtimestamp(float(value)) - except ValueError as err: - raise ValueError( - "Date filters must be ISO-8601 strings or Unix timestamps" - ) from err - - -def _get_request_filters() -> dict: - filters = {} - - search_query = request.args.get("search_query") or request.args.get("query") - if search_query: - filters["search_query"] = search_query - - start_date = _parse_datetime_filter( - request.args.get("start_date") or request.args.get("start") - ) - if start_date: - filters["start_date"] = start_date - - end_date = _parse_datetime_filter( - request.args.get("end_date") or request.args.get("end") - ) - if end_date: - filters["end_date"] = end_date - - data_sources = request.args.getlist("data_sources") - if not data_sources: - data_sources = request.args.getlist("sources") - - if len(data_sources) == 1 and "," in data_sources[0]: - data_sources = [ - source.strip() for source in data_sources[0].split(",") if source.strip() - ] - - if data_sources: - filters["data_sources"] = data_sources - - return filters - - @app.route("/register", methods=["POST"]) def register_user(): data = request.get_json() @@ -212,7 +163,8 @@ def get_dataset(dataset_id): if dataset_content.empty: return jsonify({"error": "Dataset content not found"}), 404 - return jsonify(dataset_content.to_dict(orient="records")), 200 + filters = get_request_filters() + return jsonify(stat_gen.filter_dataset(dataset_content, filters)), 200 @app.route("/dataset//content", methods=["GET"]) @@ -226,7 +178,7 @@ def content_endpoint(dataset_id): dataset_content = db.get_dataset_content(dataset_id) try: - filters = _get_request_filters() + filters = get_request_filters() return jsonify(stat_gen.get_content_analysis(dataset_content, filters)), 200 except ValueError as e: return jsonify({"error": f"Malformed or missing data: {str(e)}"}), 400 @@ -247,7 +199,7 @@ def get_summary(dataset_id): dataset_content = db.get_dataset_content(dataset_id) try: - filters = _get_request_filters() + filters = get_request_filters() return jsonify(stat_gen.summary(dataset_content, filters)), 200 except ValueError as e: return jsonify({"error": f"Malformed or missing data: {str(e)}"}), 400 @@ -268,7 +220,7 @@ def get_time_analysis(dataset_id): dataset_content = db.get_dataset_content(dataset_id) try: - filters = _get_request_filters() + filters = get_request_filters() return jsonify(stat_gen.get_time_analysis(dataset_content, filters)), 200 except ValueError as e: return jsonify({"error": f"Malformed or missing data: {str(e)}"}), 400 @@ -289,7 +241,7 @@ def get_user_analysis(dataset_id): dataset_content = db.get_dataset_content(dataset_id) try: - filters = _get_request_filters() + filters = get_request_filters() return jsonify(stat_gen.get_user_analysis(dataset_content, filters)), 200 except ValueError as e: return jsonify({"error": f"Malformed or missing data: {str(e)}"}), 400 @@ -310,7 +262,7 @@ def get_cultural_analysis(dataset_id): dataset_content = db.get_dataset_content(dataset_id) try: - filters = _get_request_filters() + filters = get_request_filters() return jsonify(stat_gen.get_cultural_analysis(dataset_content, filters)), 200 except ValueError as e: return jsonify({"error": f"Malformed or missing data: {str(e)}"}), 400 @@ -331,7 +283,7 @@ def get_interaction_analysis(dataset_id): dataset_content = db.get_dataset_content(dataset_id) try: - filters = _get_request_filters() + filters = get_request_filters() return jsonify( stat_gen.get_interactional_analysis(dataset_content, filters) ), 200 diff --git a/server/stat_gen.py b/server/stat_gen.py index cb8dbaa..2ea5ac1 100644 --- a/server/stat_gen.py +++ b/server/stat_gen.py @@ -55,13 +55,15 @@ class StatGen: if search_query: mask = ( filtered_df["content"].str.contains(search_query, case=False, na=False) - | filtered_df["author"] - .str.contains(search_query, case=False, na=False) - .fillna(False) - | filtered_df["title"] - .str.contains(search_query, case=False, na=False, regex=False) - .fillna(False) + | filtered_df["author"].str.contains(search_query, case=False, na=False) ) + + # Only include title if the column exists + if "title" in filtered_df.columns: + mask = mask | filtered_df["title"].str.contains( + search_query, case=False, na=False, regex=False + ) + filtered_df = filtered_df[mask] if start_date_filter: @@ -76,6 +78,8 @@ class StatGen: return filtered_df ## Public Methods + def filter_dataset(self, df: pd.DataFrame, filters: dict | None = None) -> dict: + return self._prepare_filtered_df(df, filters).to_dict(orient="records") def get_time_analysis(self, df: pd.DataFrame, filters: dict | None = None) -> dict: filtered_df = self._prepare_filtered_df(df, filters) diff --git a/server/utils.py b/server/utils.py new file mode 100644 index 0000000..078d1e7 --- /dev/null +++ b/server/utils.py @@ -0,0 +1,50 @@ +import datetime +from flask import request + +def parse_datetime_filter(value): + if not value: + return None + + try: + return datetime.datetime.fromisoformat(value) + except ValueError: + try: + return datetime.datetime.fromtimestamp(float(value)) + except ValueError as err: + raise ValueError( + "Date filters must be ISO-8601 strings or Unix timestamps" + ) from err + + +def get_request_filters() -> dict: + filters = {} + + search_query = request.args.get("search_query") or request.args.get("query") + if search_query: + filters["search_query"] = search_query + + start_date = parse_datetime_filter( + request.args.get("start_date") or request.args.get("start") + ) + if start_date: + filters["start_date"] = start_date + + end_date = parse_datetime_filter( + request.args.get("end_date") or request.args.get("end") + ) + if end_date: + filters["end_date"] = end_date + + data_sources = request.args.getlist("data_sources") + if not data_sources: + data_sources = request.args.getlist("sources") + + if len(data_sources) == 1 and "," in data_sources[0]: + data_sources = [ + source.strip() for source in data_sources[0].split(",") if source.strip() + ] + + if data_sources: + filters["data_sources"] = data_sources + + return filters \ No newline at end of file