Compare commits
4 Commits
17bd4702b2
...
a65c4a461c
| Author | SHA1 | Date | |
|---|---|---|---|
| a65c4a461c | |||
| 15704a0782 | |||
| 6ec47256d0 | |||
| 2572664e26 |
@@ -19,19 +19,18 @@ from server.exceptions import NotAuthorisedException, NonExistentDatasetExceptio
|
|||||||
from server.db.database import PostgresConnector
|
from server.db.database import PostgresConnector
|
||||||
from server.core.auth import AuthManager
|
from server.core.auth import AuthManager
|
||||||
from server.core.datasets import DatasetManager
|
from server.core.datasets import DatasetManager
|
||||||
from server.utils import get_request_filters
|
from server.utils import get_request_filters, split_limit, get_env
|
||||||
from server.queue.tasks import process_dataset
|
from server.queue.tasks import process_dataset, fetch_and_process_dataset
|
||||||
from server.connectors.registry import get_connector_metadata
|
from server.connectors.registry import get_available_connectors, get_connector_metadata
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
|
|
||||||
# Env Variables
|
# Env Variables
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:5173")
|
max_fetch_limit = int(get_env("MAX_FETCH_LIMIT"))
|
||||||
jwt_secret_key = os.getenv("JWT_SECRET_KEY", "super-secret-change-this")
|
frontend_url = get_env("FRONTEND_URL")
|
||||||
jwt_access_token_expires = int(
|
jwt_secret_key = get_env("JWT_SECRET_KEY")
|
||||||
os.getenv("JWT_ACCESS_TOKEN_EXPIRES", 1200)
|
jwt_access_token_expires = int(os.getenv("JWT_ACCESS_TOKEN_EXPIRES", 1200)) # Default to 20 minutes
|
||||||
) # Default to 20 minutes
|
|
||||||
|
|
||||||
# Flask Configuration
|
# Flask Configuration
|
||||||
CORS(app, resources={r"/*": {"origins": frontend_url}})
|
CORS(app, resources={r"/*": {"origins": frontend_url}})
|
||||||
@@ -45,7 +44,8 @@ db = PostgresConnector()
|
|||||||
auth_manager = AuthManager(db, bcrypt)
|
auth_manager = AuthManager(db, bcrypt)
|
||||||
dataset_manager = DatasetManager(db)
|
dataset_manager = DatasetManager(db)
|
||||||
stat_gen = StatGen()
|
stat_gen = StatGen()
|
||||||
|
connectors = get_available_connectors()
|
||||||
|
# Load the default topic list once at import time. The context manager closes
# the file handle deterministically instead of leaving it to the garbage
# collector, and the explicit encoding avoids depending on the platform locale.
with open("server/topics.json", encoding="utf-8") as topics_file:
    default_topic_list = json.load(topics_file)
|
||||||
|
|
||||||
@app.route("/register", methods=["POST"])
|
@app.route("/register", methods=["POST"])
|
||||||
def register_user():
|
def register_user():
|
||||||
@@ -122,8 +122,39 @@ def scrape_data():
|
|||||||
if "sources" not in request.form:
|
if "sources" not in request.form:
|
||||||
return jsonify({"error": "Data source names are required."}), 400
|
return jsonify({"error": "Data source names are required."}), 400
|
||||||
|
|
||||||
sources = request.form.get("sources")
|
user_id = int(get_jwt_identity())
|
||||||
|
sources = request.form.getlist("sources")
|
||||||
|
limit = int(request.form.get("limit", max_fetch_limit))
|
||||||
|
|
||||||
|
dataset_name = request.form.get("name", "").strip()
|
||||||
|
search = request.form.get("search")
|
||||||
|
category = request.form.get("category")
|
||||||
|
|
||||||
|
if limit > max_fetch_limit:
|
||||||
|
return jsonify({"error": f"Due to API limitations, we cannot receive more than ${max_fetch_limit} posts"}), 400
|
||||||
|
|
||||||
|
for source in sources:
|
||||||
|
if source not in connectors.keys():
|
||||||
|
return jsonify({"error": "Source must exist"}), 400
|
||||||
|
|
||||||
|
limits = split_limit(limit, len(sources))
|
||||||
|
per_source = dict(zip(sources, limits))
|
||||||
|
dataset_id = dataset_manager.save_dataset_info(user_id, dataset_name, default_topic_list)
|
||||||
|
dataset_manager.set_dataset_status(dataset_id, "fetching", f"Data is being fetched from {str(sources)}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
fetch_and_process_dataset.delay(dataset_id, per_source, search, category, default_topic_list)
|
||||||
|
|
||||||
|
return jsonify(
|
||||||
|
{
|
||||||
|
"message": "Dataset queued for processing",
|
||||||
|
"dataset_id": dataset_id,
|
||||||
|
"status": "processing",
|
||||||
|
}
|
||||||
|
), 202
|
||||||
|
except Exception:
|
||||||
|
print(traceback.format_exc())
|
||||||
|
return jsonify({"error": "An unexpected error occurred"}), 500
|
||||||
|
|
||||||
@app.route("/datasets/upload", methods=["POST"])
|
@app.route("/datasets/upload", methods=["POST"])
|
||||||
@jwt_required()
|
@jwt_required()
|
||||||
|
|||||||
@@ -19,32 +19,33 @@ class RedditAPI(BaseConnector):
|
|||||||
# Public Methods #
|
# Public Methods #
|
||||||
def get_new_posts_by_search(self,
|
def get_new_posts_by_search(self,
|
||||||
search: str,
|
search: str,
|
||||||
subreddit: str,
|
category: str,
|
||||||
limit: int
|
post_limit: int,
|
||||||
|
comment_limit: int
|
||||||
) -> list[Post]:
|
) -> list[Post]:
|
||||||
|
|
||||||
if not search:
|
if not search:
|
||||||
return self._get_new_subreddit_posts(subreddit, limit=limit)
|
return self._get_new_subreddit_posts(category, limit=post_limit)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'q': search,
|
'q': search,
|
||||||
'limit': limit,
|
'limit': post_limit,
|
||||||
'restrict_sr': 'on',
|
'restrict_sr': 'on',
|
||||||
'sort': 'new'
|
'sort': 'new'
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f"Searching subreddit '{subreddit}' for '{search}' with limit {limit}")
|
logger.info(f"Searching subreddit '{category}' for '{search}' with limit {post_limit}")
|
||||||
url = f"r/{subreddit}/search.json"
|
url = f"r/{category}/search.json"
|
||||||
posts = []
|
posts = []
|
||||||
|
|
||||||
while len(posts) < limit:
|
while len(posts) < post_limit:
|
||||||
batch_limit = min(100, limit - len(posts))
|
batch_limit = min(100, post_limit - len(posts))
|
||||||
params['limit'] = batch_limit
|
params['limit'] = batch_limit
|
||||||
|
|
||||||
data = self._fetch_post_overviews(url, params)
|
data = self._fetch_post_overviews(url, params)
|
||||||
batch_posts = self._parse_posts(data)
|
batch_posts = self._parse_posts(data)
|
||||||
|
|
||||||
logger.debug(f"Fetched {len(batch_posts)} posts from search in subreddit {subreddit}")
|
logger.debug(f"Fetched {len(batch_posts)} posts from search in subreddit {category}")
|
||||||
|
|
||||||
if not batch_posts:
|
if not batch_posts:
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -114,7 +114,7 @@ class DatasetManager:
|
|||||||
self.db.execute_batch(query, values)
|
self.db.execute_batch(query, values)
|
||||||
|
|
||||||
def set_dataset_status(self, dataset_id: int, status: str, status_message: str | None = None):
|
def set_dataset_status(self, dataset_id: int, status: str, status_message: str | None = None):
|
||||||
if status not in ["processing", "complete", "error"]:
|
if status not in ["fetching", "processing", "complete", "error"]:
|
||||||
raise ValueError("Invalid status")
|
raise ValueError("Invalid status")
|
||||||
|
|
||||||
query = """
|
query = """
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ CREATE TABLE datasets (
|
|||||||
|
|
||||||
-- Enforce valid states
|
-- Enforce valid states
|
||||||
CONSTRAINT datasets_status_check
|
CONSTRAINT datasets_status_check
|
||||||
CHECK (status IN ('processing', 'complete', 'error'))
|
CHECK (status IN ('fetching', 'processing', 'complete', 'error'))
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE events (
|
CREATE TABLE events (
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from server.queue.celery_app import celery
|
|||||||
from server.analysis.enrichment import DatasetEnrichment
|
from server.analysis.enrichment import DatasetEnrichment
|
||||||
from server.db.database import PostgresConnector
|
from server.db.database import PostgresConnector
|
||||||
from server.core.datasets import DatasetManager
|
from server.core.datasets import DatasetManager
|
||||||
|
from server.connectors.registry import get_available_connectors
|
||||||
|
|
||||||
@celery.task(bind=True, max_retries=3)
|
@celery.task(bind=True, max_retries=3)
|
||||||
def process_dataset(self, dataset_id: int, posts: list, topics: dict):
|
def process_dataset(self, dataset_id: int, posts: list, topics: dict):
|
||||||
@@ -20,3 +21,29 @@ def process_dataset(self, dataset_id: int, posts: list, topics: dict):
|
|||||||
dataset_manager.set_dataset_status(dataset_id, "complete", "NLP Processing Completed Successfully")
|
dataset_manager.set_dataset_status(dataset_id, "complete", "NLP Processing Completed Successfully")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
dataset_manager.set_dataset_status(dataset_id, "error", f"An error occurred: {e}")
|
dataset_manager.set_dataset_status(dataset_id, "error", f"An error occurred: {e}")
|
||||||
|
|
||||||
|
@celery.task(bind=True, max_retries=3)
def fetch_and_process_dataset(self,
                              dataset_id: int,
                              per_source: dict[str, int],
                              search: str,
                              category: str,
                              topics: dict):
    """Fetch posts from every requested source, then queue them for processing.

    Args:
        dataset_id: Dataset whose status is set to "error" if anything below fails.
        per_source: Maps a connector name to the number of posts to request from it.
        search: Search term forwarded unchanged to each connector.
        category: Category forwarded unchanged to each connector.
        topics: Topic definitions passed through to ``process_dataset``.
    """
    connectors = get_available_connectors()
    db = PostgresConnector()
    dataset_manager = DatasetManager(db)
    posts = []

    try:
        for source_name, source_limit in per_source.items():
            # Splitting a small overall limit across several sources can leave
            # some of them with a share of 0; there is nothing to fetch for
            # those, so do not build a connector or issue a request.
            if source_limit <= 0:
                continue

            # An unknown source name raises KeyError here and is reported
            # through the except branch below.
            connector = connectors[source_name]()
            posts.extend(connector.get_new_posts_by_search(
                search=search,
                category=category,
                post_limit=source_limit,
                comment_limit=source_limit
            ))

        # Hand the fetched posts to the processing task as plain dicts.
        process_dataset.delay(dataset_id, [p.to_dict() for p in posts], topics)
    except Exception as e:
        # Record the failure on the dataset rather than letting the task die silently.
        dataset_manager.set_dataset_status(dataset_id, "error", f"An error occurred: {e}")
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
import datetime
|
import datetime
|
||||||
|
import os
|
||||||
from flask import request
|
from flask import request
|
||||||
|
|
||||||
def parse_datetime_filter(value):
|
def parse_datetime_filter(value):
|
||||||
@@ -52,3 +53,9 @@ def get_request_filters() -> dict:
|
|||||||
def split_limit(limit: int, n: int) -> list[int]:
    """Split ``limit`` into ``n`` integer shares that differ by at most one.

    The first ``limit % n`` shares receive the extra unit, so the shares sum
    to ``limit``.

    Args:
        limit: Total amount to distribute.
        n: Number of shares to produce.

    Returns:
        A list of ``n`` shares, or an empty list when ``n`` is zero or negative.
    """
    # divmod would raise ZeroDivisionError for n == 0; there are no shares to
    # hand out in that case (negative n already produced an empty list).
    if n <= 0:
        return []

    base, remainder = divmod(limit, n)
    return [base + (1 if i < remainder else 0) for i in range(n)]
|
||||||
|
|
||||||
|
def get_env(name: str) -> str:
    """Return the value of a required environment variable.

    Args:
        name: Name of the environment variable to read.

    Returns:
        The variable's value as a non-empty string.

    Raises:
        RuntimeError: If the variable is unset or set to an empty string.
    """
    value = os.getenv(name)
    # An empty string is treated the same as an unset variable.
    if not value:
        raise RuntimeError(f"Missing required environment variable: {name}")
    return value
|
||||||
|
|||||||
Reference in New Issue
Block a user