diff --git a/server/app.py b/server/app.py
index e5a8037..d896ac2 100644
--- a/server/app.py
+++ b/server/app.py
@@ -19,19 +19,18 @@ from server.exceptions import NotAuthorisedException, NonExistentDatasetExceptio
 from server.db.database import PostgresConnector
 from server.core.auth import AuthManager
 from server.core.datasets import DatasetManager
-from server.utils import get_request_filters
+from server.utils import get_request_filters, split_limit, get_env
 from server.queue.tasks import process_dataset
-from server.connectors.registry import get_connector_metadata
+from server.connectors.registry import get_available_connectors, get_connector_metadata
 
 app = Flask(__name__)
 
 # Env Variables
 load_dotenv()
-frontend_url = os.getenv("FRONTEND_URL", "http://localhost:5173")
-jwt_secret_key = os.getenv("JWT_SECRET_KEY", "super-secret-change-this")
-jwt_access_token_expires = int(
-    os.getenv("JWT_ACCESS_TOKEN_EXPIRES", 1200)
-) # Default to 20 minutes
+max_fetch_limit = int(get_env("MAX_FETCH_LIMIT"))
+frontend_url = get_env("FRONTEND_URL")
+jwt_secret_key = get_env("JWT_SECRET_KEY")
+jwt_access_token_expires = int(os.getenv("JWT_ACCESS_TOKEN_EXPIRES", 1200)) # Default to 20 minutes
 
 # Flask Configuration
 CORS(app, resources={r"/*": {"origins": frontend_url}})
@@ -45,7 +44,8 @@ db = PostgresConnector()
 auth_manager = AuthManager(db, bcrypt)
 dataset_manager = DatasetManager(db)
 stat_gen = StatGen()
-
+connectors = get_available_connectors()
+default_topic_list = json.load(open("server/topics.json"))
 
 @app.route("/register", methods=["POST"])
 def register_user():
@@ -122,8 +122,50 @@ def scrape_data():
     if "sources" not in request.form:
         return jsonify({"error": "Data source names are required."}), 400
 
-    sources = request.form.get("sources")
+    user_id = int(get_jwt_identity())
+    sources = request.form.getlist("sources")
+    limit = int(request.form.get("limit", max_fetch_limit))
+    dataset_name = request.form.get("name", "").strip()
+    search = request.form.get("search")
+    category = request.form.get("category")
+
+    print(sources)
+
+    if limit > max_fetch_limit:
+        return jsonify({"error": f"Due to API limitations, we cannot receive more than {max_fetch_limit} posts"}), 400
+
+    for source in sources:
+        if source not in connectors.keys():
+            return jsonify({"error": "Source must exist"}), 400
+
+    limits = split_limit(limit, len(sources))
+    per_source = dict(zip(sources, limits))
+
+    try:
+        posts = []
+        for source_name, source_limit in per_source.items():
+            connector = connectors[source_name]()
+            posts.extend(connector.get_new_posts_by_search(
+                search=search,
+                category=category,
+                post_limit=source_limit,
+                comment_limit=source_limit
+            ))
+
+        dataset_id = dataset_manager.save_dataset_info(user_id, dataset_name, {})
+        process_dataset.delay(dataset_id, [p.to_dict() for p in posts], default_topic_list)
+
+        return jsonify(
+            {
+                "message": "Dataset queued for processing",
+                "dataset_id": dataset_id,
+                "status": "processing",
+            }
+        ), 202
+    except Exception:
+        print(traceback.format_exc())
+        return jsonify({"error": "An unexpected error occurred"}), 500
 
 
 @app.route("/datasets/upload", methods=["POST"])
 @jwt_required()
diff --git a/server/connectors/reddit_api.py b/server/connectors/reddit_api.py
index 2107ded..444326a 100644
--- a/server/connectors/reddit_api.py
+++ b/server/connectors/reddit_api.py
@@ -19,32 +19,33 @@ class RedditAPI(BaseConnector):
 
     # Public Methods #
 
     def get_new_posts_by_search(self, search: str,
-                                subreddit: str,
-                                limit: int
+                                category: str,
+                                post_limit: int,
+                                comment_limit: int
                                 ) -> list[Post]:
         if not search:
-            return self._get_new_subreddit_posts(subreddit, limit=limit)
+            return self._get_new_subreddit_posts(category, limit=post_limit)
 
         params = {
             'q': search,
-            'limit': limit,
+            'limit': post_limit,
             'restrict_sr': 'on',
             'sort': 'new'
         }
 
-        logger.info(f"Searching subreddit '{subreddit}' for '{search}' with limit {limit}")
-        url = f"r/{subreddit}/search.json"
+        logger.info(f"Searching subreddit '{category}' for '{search}' with limit {post_limit}")
+        url = f"r/{category}/search.json"
 
         posts = []
-        while len(posts) < limit:
-            batch_limit = min(100, limit - len(posts))
+        while len(posts) < post_limit:
+            batch_limit = min(100, post_limit - len(posts))
             params['limit'] = batch_limit
 
             data = self._fetch_post_overviews(url, params)
             batch_posts = self._parse_posts(data)
-            logger.debug(f"Fetched {len(batch_posts)} posts from search in subreddit {subreddit}")
+            logger.debug(f"Fetched {len(batch_posts)} posts from search in subreddit {category}")
 
             if not batch_posts:
                 break