Automatic Scraping of dataset options #9

Merged
dylan merged 36 commits from feat/automatic-scraping-datasets into main 2026-03-14 21:58:49 +00:00
2 changed files with 61 additions and 18 deletions
Showing only changes of commit 6ec47256d0 - Show all commits

View File

@@ -19,19 +19,18 @@ from server.exceptions import NotAuthorisedException, NonExistentDatasetExceptio
from server.db.database import PostgresConnector
from server.core.auth import AuthManager
from server.core.datasets import DatasetManager
from server.utils import get_request_filters
from server.utils import get_request_filters, split_limit, get_env
from server.queue.tasks import process_dataset
from server.connectors.registry import get_connector_metadata
from server.connectors.registry import get_available_connectors, get_connector_metadata
app = Flask(__name__)
# Env Variables
load_dotenv()
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:5173")
jwt_secret_key = os.getenv("JWT_SECRET_KEY", "super-secret-change-this")
jwt_access_token_expires = int(
os.getenv("JWT_ACCESS_TOKEN_EXPIRES", 1200)
) # Default to 20 minutes
max_fetch_limit = int(get_env("MAX_FETCH_LIMIT"))
frontend_url = get_env("FRONTEND_URL")
jwt_secret_key = get_env("JWT_SECRET_KEY")
jwt_access_token_expires = int(os.getenv("JWT_ACCESS_TOKEN_EXPIRES", 1200)) # Default to 20 minutes
# Flask Configuration
CORS(app, resources={r"/*": {"origins": frontend_url}})
@@ -45,7 +44,8 @@ db = PostgresConnector()
auth_manager = AuthManager(db, bcrypt)
dataset_manager = DatasetManager(db)
stat_gen = StatGen()
connectors = get_available_connectors()
default_topic_list = json.load(open("server/topics.json"))
@app.route("/register", methods=["POST"])
def register_user():
@@ -122,8 +122,50 @@ def scrape_data():
if "sources" not in request.form:
return jsonify({"error": "Data source names are required."}), 400
sources = request.form.get("sources")
user_id = int(get_jwt_identity())
sources = request.form.getlist("sources")
limit = int(request.form.get("limit", max_fetch_limit))
dataset_name = request.form.get("name", "").strip()
search = request.form.get("search")
category = request.form.get("category")
print(sources)
if limit > max_fetch_limit:
        return jsonify({"error": f"Due to API limitations, we cannot receive more than {max_fetch_limit} posts"}), 400
for source in sources:
if source not in connectors.keys():
return jsonify({"error": "Source must exist"}), 400
limits = split_limit(limit, len(sources))
per_source = dict(zip(sources, limits))
try:
posts = []
for source_name, source_limit in per_source.items():
connector = connectors[source_name]()
posts.extend(connector.get_new_posts_by_search(
search=search,
category=category,
post_limit=source_limit,
comment_limit=source_limit
))
dataset_id = dataset_manager.save_dataset_info(user_id, dataset_name, {})
process_dataset.delay(dataset_id, [p.to_dict() for p in posts], default_topic_list)
return jsonify(
{
"message": "Dataset queued for processing",
"dataset_id": dataset_id,
"status": "processing",
}
), 202
except Exception:
print(traceback.format_exc())
return jsonify({"error": "An unexpected error occurred"}), 500
@app.route("/datasets/upload", methods=["POST"])
@jwt_required()

View File

@@ -19,32 +19,33 @@ class RedditAPI(BaseConnector):
# Public Methods #
def get_new_posts_by_search(self,
search: str,
subreddit: str,
limit: int
category: str,
post_limit: int,
comment_limit: int
) -> list[Post]:
if not search:
return self._get_new_subreddit_posts(subreddit, limit=limit)
return self._get_new_subreddit_posts(category, limit=post_limit)
params = {
'q': search,
'limit': limit,
'limit': post_limit,
'restrict_sr': 'on',
'sort': 'new'
}
logger.info(f"Searching subreddit '{subreddit}' for '{search}' with limit {limit}")
url = f"r/{subreddit}/search.json"
logger.info(f"Searching subreddit '{category}' for '{search}' with limit {post_limit}")
url = f"r/{category}/search.json"
posts = []
while len(posts) < limit:
batch_limit = min(100, limit - len(posts))
while len(posts) < post_limit:
batch_limit = min(100, post_limit - len(posts))
params['limit'] = batch_limit
data = self._fetch_post_overviews(url, params)
batch_posts = self._parse_posts(data)
logger.debug(f"Fetched {len(batch_posts)} posts from search in subreddit {subreddit}")
logger.debug(f"Fetched {len(batch_posts)} posts from search in subreddit {category}")
if not batch_posts:
break