Compare commits

..

4 Commits

6 changed files with 87 additions and 21 deletions

View File

@@ -19,19 +19,18 @@ from server.exceptions import NotAuthorisedException, NonExistentDatasetExceptio
from server.db.database import PostgresConnector from server.db.database import PostgresConnector
from server.core.auth import AuthManager from server.core.auth import AuthManager
from server.core.datasets import DatasetManager from server.core.datasets import DatasetManager
from server.utils import get_request_filters from server.utils import get_request_filters, split_limit, get_env
from server.queue.tasks import process_dataset from server.queue.tasks import process_dataset, fetch_and_process_dataset
from server.connectors.registry import get_connector_metadata from server.connectors.registry import get_available_connectors, get_connector_metadata
app = Flask(__name__) app = Flask(__name__)
# Env Variables # Env Variables
load_dotenv() load_dotenv()
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:5173") max_fetch_limit = int(get_env("MAX_FETCH_LIMIT"))
jwt_secret_key = os.getenv("JWT_SECRET_KEY", "super-secret-change-this") frontend_url = get_env("FRONTEND_URL")
jwt_access_token_expires = int( jwt_secret_key = get_env("JWT_SECRET_KEY")
os.getenv("JWT_ACCESS_TOKEN_EXPIRES", 1200) jwt_access_token_expires = int(os.getenv("JWT_ACCESS_TOKEN_EXPIRES", 1200)) # Default to 20 minutes
) # Default to 20 minutes
# Flask Configuration # Flask Configuration
CORS(app, resources={r"/*": {"origins": frontend_url}}) CORS(app, resources={r"/*": {"origins": frontend_url}})
@@ -45,7 +44,8 @@ db = PostgresConnector()
auth_manager = AuthManager(db, bcrypt) auth_manager = AuthManager(db, bcrypt)
dataset_manager = DatasetManager(db) dataset_manager = DatasetManager(db)
stat_gen = StatGen() stat_gen = StatGen()
connectors = get_available_connectors()
default_topic_list = json.load(open("server/topics.json"))
@app.route("/register", methods=["POST"]) @app.route("/register", methods=["POST"])
def register_user(): def register_user():
@@ -122,8 +122,39 @@ def scrape_data():
if "sources" not in request.form: if "sources" not in request.form:
return jsonify({"error": "Data source names are required."}), 400 return jsonify({"error": "Data source names are required."}), 400
sources = request.form.get("sources") user_id = int(get_jwt_identity())
sources = request.form.getlist("sources")
limit = int(request.form.get("limit", max_fetch_limit))
dataset_name = request.form.get("name", "").strip()
search = request.form.get("search")
category = request.form.get("category")
if limit > max_fetch_limit:
return jsonify({"error": f"Due to API limitations, we cannot receive more than ${max_fetch_limit} posts"}), 400
for source in sources:
if source not in connectors.keys():
return jsonify({"error": "Source must exist"}), 400
limits = split_limit(limit, len(sources))
per_source = dict(zip(sources, limits))
dataset_id = dataset_manager.save_dataset_info(user_id, dataset_name, default_topic_list)
dataset_manager.set_dataset_status(dataset_id, "fetching", f"Data is being fetched from {str(sources)}")
try:
fetch_and_process_dataset.delay(dataset_id, per_source, search, category, default_topic_list)
return jsonify(
{
"message": "Dataset queued for processing",
"dataset_id": dataset_id,
"status": "processing",
}
), 202
except Exception:
print(traceback.format_exc())
return jsonify({"error": "An unexpected error occurred"}), 500
@app.route("/datasets/upload", methods=["POST"]) @app.route("/datasets/upload", methods=["POST"])
@jwt_required() @jwt_required()

View File

@@ -19,32 +19,33 @@ class RedditAPI(BaseConnector):
# Public Methods # # Public Methods #
def get_new_posts_by_search(self, def get_new_posts_by_search(self,
search: str, search: str,
subreddit: str, category: str,
limit: int post_limit: int,
comment_limit: int
) -> list[Post]: ) -> list[Post]:
if not search: if not search:
return self._get_new_subreddit_posts(subreddit, limit=limit) return self._get_new_subreddit_posts(category, limit=post_limit)
params = { params = {
'q': search, 'q': search,
'limit': limit, 'limit': post_limit,
'restrict_sr': 'on', 'restrict_sr': 'on',
'sort': 'new' 'sort': 'new'
} }
logger.info(f"Searching subreddit '{subreddit}' for '{search}' with limit {limit}") logger.info(f"Searching subreddit '{category}' for '{search}' with limit {post_limit}")
url = f"r/{subreddit}/search.json" url = f"r/{category}/search.json"
posts = [] posts = []
while len(posts) < limit: while len(posts) < post_limit:
batch_limit = min(100, limit - len(posts)) batch_limit = min(100, post_limit - len(posts))
params['limit'] = batch_limit params['limit'] = batch_limit
data = self._fetch_post_overviews(url, params) data = self._fetch_post_overviews(url, params)
batch_posts = self._parse_posts(data) batch_posts = self._parse_posts(data)
logger.debug(f"Fetched {len(batch_posts)} posts from search in subreddit {subreddit}") logger.debug(f"Fetched {len(batch_posts)} posts from search in subreddit {category}")
if not batch_posts: if not batch_posts:
break break

View File

@@ -114,7 +114,7 @@ class DatasetManager:
self.db.execute_batch(query, values) self.db.execute_batch(query, values)
def set_dataset_status(self, dataset_id: int, status: str, status_message: str | None = None): def set_dataset_status(self, dataset_id: int, status: str, status_message: str | None = None):
if status not in ["processing", "complete", "error"]: if status not in ["fetching", "processing", "complete", "error"]:
raise ValueError("Invalid status") raise ValueError("Invalid status")
query = """ query = """

View File

@@ -23,7 +23,7 @@ CREATE TABLE datasets (
-- Enforce valid states -- Enforce valid states
CONSTRAINT datasets_status_check CONSTRAINT datasets_status_check
CHECK (status IN ('processing', 'complete', 'error')) CHECK (status IN ('fetching', 'processing', 'complete', 'error'))
); );
CREATE TABLE events ( CREATE TABLE events (

View File

@@ -4,6 +4,7 @@ from server.queue.celery_app import celery
from server.analysis.enrichment import DatasetEnrichment from server.analysis.enrichment import DatasetEnrichment
from server.db.database import PostgresConnector from server.db.database import PostgresConnector
from server.core.datasets import DatasetManager from server.core.datasets import DatasetManager
from server.connectors.registry import get_available_connectors
@celery.task(bind=True, max_retries=3) @celery.task(bind=True, max_retries=3)
def process_dataset(self, dataset_id: int, posts: list, topics: dict): def process_dataset(self, dataset_id: int, posts: list, topics: dict):
@@ -18,5 +19,31 @@ def process_dataset(self, dataset_id: int, posts: list, topics: dict):
dataset_manager.save_dataset_content(dataset_id, enriched_df) dataset_manager.save_dataset_content(dataset_id, enriched_df)
dataset_manager.set_dataset_status(dataset_id, "complete", "NLP Processing Completed Successfully") dataset_manager.set_dataset_status(dataset_id, "complete", "NLP Processing Completed Successfully")
except Exception as e:
dataset_manager.set_dataset_status(dataset_id, "error", f"An error occurred: {e}")
@celery.task(bind=True, max_retries=3)
def fetch_and_process_dataset(self,
                              dataset_id: int,
                              per_source: dict[str, int],
                              search: str,
                              category: str,
                              topics: dict):
    """Fetch posts from every requested source, then queue NLP processing.

    Args:
        dataset_id: Row id of the dataset being populated.
        per_source: Mapping of connector name -> number of posts to fetch
            from that connector.
        search: Search term forwarded to each connector.
        category: Category (e.g. subreddit) forwarded to each connector.
        topics: Topic definitions passed straight through to
            ``process_dataset``.

    On any failure the dataset row is marked with status "error" instead of
    raising, so the record never lingers in a transient state.
    """
    registry = get_available_connectors()
    database = PostgresConnector()
    manager = DatasetManager(database)
    collected = []
    try:
        for source_name, quota in per_source.items():
            connector = registry[source_name]()
            fetched = connector.get_new_posts_by_search(
                search=search,
                category=category,
                post_limit=quota,
                comment_limit=quota,
            )
            collected.extend(fetched)
        serialised = [post.to_dict() for post in collected]
        process_dataset.delay(dataset_id, serialised, topics)
    except Exception as e:
        # NOTE(review): task is declared bind=True with max_retries=3 but
        # never calls self.retry — confirm whether transient fetch errors
        # should retry instead of failing the dataset outright.
        manager.set_dataset_status(dataset_id, "error", f"An error occurred: {e}")

View File

@@ -1,4 +1,5 @@
import datetime import datetime
import os
from flask import request from flask import request
def parse_datetime_filter(value): def parse_datetime_filter(value):
@@ -52,3 +53,9 @@ def get_request_filters() -> dict:
def split_limit(limit: int, n: int) -> list[int]:
    """Split an integer ``limit`` into ``n`` near-equal non-negative parts.

    The first ``limit % n`` parts each receive one extra unit, so the parts
    always sum exactly to ``limit``; e.g. ``split_limit(10, 3) -> [4, 3, 3]``.

    Args:
        limit: Total to distribute; must be >= 0.
        n: Number of parts; must be > 0.

    Returns:
        A list of ``n`` ints, in non-increasing order, summing to ``limit``.

    Raises:
        ValueError: If ``n <= 0`` (previously a cryptic ZeroDivisionError for
            n == 0 and a nonsense result for n < 0) or if ``limit < 0``
            (previously produced a list of negative counts).
    """
    if n <= 0:
        raise ValueError(f"Number of parts must be positive, got {n}")
    if limit < 0:
        raise ValueError(f"Limit must be non-negative, got {limit}")
    base, remainder = divmod(limit, n)
    return [base + (1 if i < remainder else 0) for i in range(n)]
def get_env(name: str) -> str:
    """Return the value of environment variable ``name``, failing loudly.

    An unset variable and a variable set to the empty string are treated
    the same way: both raise, since neither is a usable configuration value.

    Raises:
        RuntimeError: If the variable is unset or empty.
    """
    if value := os.getenv(name):
        return value
    raise RuntimeError(f"Missing required environment variable: {name}")