Automatic Scraping of dataset options #9

Merged
dylan merged 36 commits from feat/automatic-scraping-datasets into main 2026-03-14 21:58:49 +00:00
2 changed files with 61 additions and 18 deletions
Showing only changes of commit 6ec47256d0 - Show all commits

View File

@@ -19,19 +19,18 @@ from server.exceptions import NotAuthorisedException, NonExistentDatasetExceptio
from server.db.database import PostgresConnector from server.db.database import PostgresConnector
from server.core.auth import AuthManager from server.core.auth import AuthManager
from server.core.datasets import DatasetManager from server.core.datasets import DatasetManager
from server.utils import get_request_filters from server.utils import get_request_filters, split_limit, get_env
from server.queue.tasks import process_dataset from server.queue.tasks import process_dataset
from server.connectors.registry import get_connector_metadata from server.connectors.registry import get_available_connectors, get_connector_metadata
app = Flask(__name__) app = Flask(__name__)
# Env Variables # Env Variables
load_dotenv() load_dotenv()
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:5173") max_fetch_limit = int(get_env("MAX_FETCH_LIMIT"))
jwt_secret_key = os.getenv("JWT_SECRET_KEY", "super-secret-change-this") frontend_url = get_env("FRONTEND_URL")
jwt_access_token_expires = int( jwt_secret_key = get_env("JWT_SECRET_KEY")
os.getenv("JWT_ACCESS_TOKEN_EXPIRES", 1200) jwt_access_token_expires = int(os.getenv("JWT_ACCESS_TOKEN_EXPIRES", 1200)) # Default to 20 minutes
) # Default to 20 minutes
# Flask Configuration # Flask Configuration
CORS(app, resources={r"/*": {"origins": frontend_url}}) CORS(app, resources={r"/*": {"origins": frontend_url}})
@@ -45,7 +44,8 @@ db = PostgresConnector()
auth_manager = AuthManager(db, bcrypt) auth_manager = AuthManager(db, bcrypt)
dataset_manager = DatasetManager(db) dataset_manager = DatasetManager(db)
stat_gen = StatGen() stat_gen = StatGen()
connectors = get_available_connectors()
default_topic_list = json.load(open("server/topics.json"))
@app.route("/register", methods=["POST"]) @app.route("/register", methods=["POST"])
def register_user(): def register_user():
@@ -122,8 +122,50 @@ def scrape_data():
if "sources" not in request.form: if "sources" not in request.form:
return jsonify({"error": "Data source names are required."}), 400 return jsonify({"error": "Data source names are required."}), 400
sources = request.form.get("sources") user_id = int(get_jwt_identity())
sources = request.form.getlist("sources")
limit = int(request.form.get("limit", max_fetch_limit))
dataset_name = request.form.get("name", "").strip()
search = request.form.get("search")
category = request.form.get("category")
print(sources)
if limit > max_fetch_limit:
        return jsonify({"error": f"Due to API limitations, we cannot retrieve more than {max_fetch_limit} posts"}), 400
for source in sources:
if source not in connectors.keys():
return jsonify({"error": "Source must exist"}), 400
limits = split_limit(limit, len(sources))
per_source = dict(zip(sources, limits))
try:
posts = []
for source_name, source_limit in per_source.items():
connector = connectors[source_name]()
posts.extend(connector.get_new_posts_by_search(
search=search,
category=category,
post_limit=source_limit,
comment_limit=source_limit
))
dataset_id = dataset_manager.save_dataset_info(user_id, dataset_name, {})
process_dataset.delay(dataset_id, [p.to_dict() for p in posts], default_topic_list)
return jsonify(
{
"message": "Dataset queued for processing",
"dataset_id": dataset_id,
"status": "processing",
}
), 202
except Exception:
print(traceback.format_exc())
return jsonify({"error": "An unexpected error occurred"}), 500
@app.route("/datasets/upload", methods=["POST"]) @app.route("/datasets/upload", methods=["POST"])
@jwt_required() @jwt_required()

View File

@@ -19,32 +19,33 @@ class RedditAPI(BaseConnector):
# Public Methods # # Public Methods #
def get_new_posts_by_search(self, def get_new_posts_by_search(self,
search: str, search: str,
subreddit: str, category: str,
limit: int post_limit: int,
comment_limit: int
) -> list[Post]: ) -> list[Post]:
if not search: if not search:
return self._get_new_subreddit_posts(subreddit, limit=limit) return self._get_new_subreddit_posts(category, limit=post_limit)
params = { params = {
'q': search, 'q': search,
'limit': limit, 'limit': post_limit,
'restrict_sr': 'on', 'restrict_sr': 'on',
'sort': 'new' 'sort': 'new'
} }
logger.info(f"Searching subreddit '{subreddit}' for '{search}' with limit {limit}") logger.info(f"Searching subreddit '{category}' for '{search}' with limit {post_limit}")
url = f"r/{subreddit}/search.json" url = f"r/{category}/search.json"
posts = [] posts = []
while len(posts) < limit: while len(posts) < post_limit:
batch_limit = min(100, limit - len(posts)) batch_limit = min(100, post_limit - len(posts))
params['limit'] = batch_limit params['limit'] = batch_limit
data = self._fetch_post_overviews(url, params) data = self._fetch_post_overviews(url, params)
batch_posts = self._parse_posts(data) batch_posts = self._parse_posts(data)
logger.debug(f"Fetched {len(batch_posts)} posts from search in subreddit {subreddit}") logger.debug(f"Fetched {len(batch_posts)} posts from search in subreddit {category}")
if not batch_posts: if not batch_posts:
break break