diff --git a/server/app.py b/server/app.py
index 3efa6eb..550537f 100644
--- a/server/app.py
+++ b/server/app.py
@@ -124,31 +124,33 @@ def get_dataset_sources():
 @app.route("/datasets/scrape", methods=["POST"])
 @jwt_required()
 def scrape_data():
-    if "sources" not in request.form:
-        return jsonify({"error": "Data source names are required."}), 400
+    data = request.get_json()
+
+    if not data or "sources" not in data:
+        return jsonify({"error": "Sources must be provided"}), 400
 
     user_id = int(get_jwt_identity())
-    sources = request.form.getlist("sources")
-    limit = int(request.form.get("limit", max_fetch_limit))
+    dataset_name = data.get("name", "").strip()
+    source_configs = data["sources"]
 
-    dataset_name = request.form.get("name", "").strip()
-    search = request.form.get("search")
-    category = request.form.get("category")
+    if not isinstance(source_configs, list) or len(source_configs) == 0:
+        return jsonify({"error": "Sources must be a non-empty list"}), 400
 
-    if limit > max_fetch_limit:
-        return jsonify({"error": f"Due to API limitations, we cannot receive more than ${max_fetch_limit} posts"}), 400
-
-    for source in sources:
-        if source not in connectors.keys():
-            return jsonify({"error": "Source must exist"}), 400
-
-    limits = split_limit(limit, len(sources))
-    per_source = dict(zip(sources, limits))
+    # Light validation: every source needs a connector name; limits are coerced to int
+    for source in source_configs:
+        if "name" not in source:
+            return jsonify({"error": "Each source must contain a name"}), 400
+        if "limit" in source:
+            source["limit"] = int(source["limit"])
+
     dataset_id = dataset_manager.save_dataset_info(user_id, dataset_name, default_topic_list)
-    dataset_manager.set_dataset_status(dataset_id, "fetching", f"Data is being fetched from {str(sources)}")
+    source_names = ", ".join(source["name"] for source in source_configs)
+    dataset_manager.set_dataset_status(
+        dataset_id, "fetching", f"Data is being fetched from {source_names}"
+    )
 
     try:
-        fetch_and_process_dataset.delay(dataset_id, per_source, search, category, default_topic_list)
+        fetch_and_process_dataset.delay(dataset_id, source_configs, default_topic_list)
 
         return jsonify(
             {
diff --git a/server/queue/tasks.py b/server/queue/tasks.py
index 7feaf9a..fd5237f 100644
--- a/server/queue/tasks.py
+++ b/server/queue/tasks.py
@@ -1,5 +1,5 @@
 import pandas as pd
-import json
+import logging
 
 from server.queue.celery_app import celery
 from server.analysis.enrichment import DatasetEnrichment
@@ -7,6 +7,8 @@ from server.db.database import PostgresConnector
 from server.core.datasets import DatasetManager
 from server.connectors.registry import get_available_connectors
 
+logger = logging.getLogger(__name__)
+
 @celery.task(bind=True, max_retries=3)
 def process_dataset(self, dataset_id: int, posts: list, topics: dict):
     db = PostgresConnector()
@@ -26,9 +28,7 @@ def process_dataset(self, dataset_id: int, posts: list, topics: dict):
 @celery.task(bind=True, max_retries=3)
 def fetch_and_process_dataset(self,
                               dataset_id: int,
-                              per_source: dict[str, int],
-                              search: str,
-                              category: str,
+                              source_info: list[dict],
                               topics: dict):
     connectors = get_available_connectors()
     db = PostgresConnector()
@@ -36,13 +36,18 @@ def fetch_and_process_dataset(self,
     posts = []
 
     try:
-        for source_name, source_limit in per_source.items():
-            connector = connectors[source_name]()
+        for metadata in source_info:
+            name = metadata["name"]
+            search = metadata.get("search")
+            category = metadata.get("category")
+            limit = metadata.get("limit", 100)
+
+            connector = connectors[name]()
             raw_posts = connector.get_new_posts_by_search(
                 search=search,
                 category=category,
-                post_limit=source_limit,
-                comment_limit=source_limit
+                post_limit=limit,
+                comment_limit=limit
             )
             posts.extend(post.to_dict() for post in raw_posts)
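
Note: a minimal sketch of the reworked request shape, assuming a local deployment at http://localhost:5000, a valid JWT, and connector names "reddit" and "youtube" (all hypothetical here). Each entry in "sources" is its own config dict, which is what lets different connectors carry their own search, category, and limit:

import requests

BASE_URL = "http://localhost:5000"  # hypothetical deployment URL
TOKEN = "<jwt-access-token>"        # obtained from the login endpoint

# One config dict per connector; "search", "category", and "limit" are
# optional, and "limit" falls back to 100 inside fetch_and_process_dataset.
payload = {
    "name": "climate-week",
    "sources": [
        {"name": "reddit", "search": "climate", "limit": 50},
        {"name": "youtube", "category": "news"},
    ],
}

resp = requests.post(
    f"{BASE_URL}/datasets/scrape",
    json=payload,  # sent as a JSON body, matching request.get_json() on the server
    headers={"Authorization": f"Bearer {TOKEN}"},
)
print(resp.status_code, resp.json())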