Refactor DB classes and management #5

Merged
dylan merged 7 commits from refactor/db-class into main 2026-03-03 11:17:50 +00:00
5 changed files with 151 additions and 135 deletions
Showing only changes of commit 4bec0dd32c - Show all commits

View File

@@ -16,14 +16,14 @@ from server.enrichment import DatasetEnrichment
from server.exceptions import NotAuthorisedException, NotExistentDatasetException from server.exceptions import NotAuthorisedException, NotExistentDatasetException
from server.db.database import PostgresConnector from server.db.database import PostgresConnector
from server.auth import AuthManager from server.auth import AuthManager
from server.utils import get_request_filters, get_dataset_and_validate from server.datasets import DatasetManager
from server.utils import get_request_filters
import pandas as pd import pandas as pd
import traceback import traceback
import json import json
app = Flask(__name__) app = Flask(__name__)
db = PostgresConnector()
# Env Variables # Env Variables
load_dotenv() load_dotenv()
@@ -40,11 +40,12 @@ app.config["JWT_ACCESS_TOKEN_EXPIRES"] = jwt_access_token_expires
bcrypt = Bcrypt(app) bcrypt = Bcrypt(app)
jwt = JWTManager(app) jwt = JWTManager(app)
db = PostgresConnector()
auth_manager = AuthManager(db, bcrypt) auth_manager = AuthManager(db, bcrypt)
dataset_manager = DatasetManager(db)
stat_gen = StatGen() stat_gen = StatGen()
@app.route("/register", methods=["POST"]) @app.route("/register", methods=["POST"])
def register_user(): def register_user():
data = request.get_json() data = request.get_json()
@@ -132,10 +133,8 @@ def upload_data():
processor = DatasetEnrichment(posts_df, topics) processor = DatasetEnrichment(posts_df, topics)
enriched_df = processor.enrich() enriched_df = processor.enrich()
dataset_id = db.save_dataset_info( dataset_id = dataset_manager.save_dataset_info(current_user, f"dataset_{current_user}", topics)
current_user, f"dataset_{current_user}", topics dataset_manager.save_dataset_content(dataset_id, enriched_df)
)
db.save_dataset_content(dataset_id, enriched_df)
return jsonify( return jsonify(
{ {
@@ -154,7 +153,8 @@ def upload_data():
@jwt_required() @jwt_required()
def get_dataset(dataset_id): def get_dataset(dataset_id):
try: try:
dataset_content = get_dataset_and_validate(dataset_id, db) user_id = get_jwt_identity()
dataset_content = dataset_manager.get_dataset_and_validate(dataset_id, int(user_id))
filters = get_request_filters() filters = get_request_filters()
filtered_dataset = stat_gen.filter_dataset(dataset_content, filters) filtered_dataset = stat_gen.filter_dataset(dataset_content, filters)
return jsonify(filtered_dataset), 200 return jsonify(filtered_dataset), 200
@@ -171,7 +171,8 @@ def get_dataset(dataset_id):
@jwt_required() @jwt_required()
def content_endpoint(dataset_id): def content_endpoint(dataset_id):
try: try:
dataset_content = get_dataset_and_validate(dataset_id, db) user_id = get_jwt_identity()
dataset_content = dataset_manager.get_dataset_and_validate(dataset_id, int(user_id))
filters = get_request_filters() filters = get_request_filters()
return jsonify(stat_gen.get_content_analysis(dataset_content, filters)), 200 return jsonify(stat_gen.get_content_analysis(dataset_content, filters)), 200
except NotAuthorisedException: except NotAuthorisedException:
@@ -187,7 +188,8 @@ def content_endpoint(dataset_id):
@jwt_required() @jwt_required()
def get_summary(dataset_id): def get_summary(dataset_id):
try: try:
dataset_content = get_dataset_and_validate(dataset_id, db) user_id = get_jwt_identity()
dataset_content = dataset_manager.get_dataset_and_validate(dataset_id, int(user_id))
filters = get_request_filters() filters = get_request_filters()
return jsonify(stat_gen.summary(dataset_content, filters)), 200 return jsonify(stat_gen.summary(dataset_content, filters)), 200
except NotAuthorisedException: except NotAuthorisedException:
@@ -203,7 +205,8 @@ def get_summary(dataset_id):
@jwt_required() @jwt_required()
def get_time_analysis(dataset_id): def get_time_analysis(dataset_id):
try: try:
dataset_content = get_dataset_and_validate(dataset_id, db) user_id = get_jwt_identity()
dataset_content = dataset_manager.get_dataset_and_validate(dataset_id, int(user_id))
filters = get_request_filters() filters = get_request_filters()
return jsonify(stat_gen.get_time_analysis(dataset_content, filters)), 200 return jsonify(stat_gen.get_time_analysis(dataset_content, filters)), 200
except NotAuthorisedException: except NotAuthorisedException:
@@ -219,7 +222,8 @@ def get_time_analysis(dataset_id):
@jwt_required() @jwt_required()
def get_user_analysis(dataset_id): def get_user_analysis(dataset_id):
try: try:
dataset_content = get_dataset_and_validate(dataset_id, db) user_id = get_jwt_identity()
dataset_content = dataset_manager.get_dataset_and_validate(dataset_id, int(user_id))
filters = get_request_filters() filters = get_request_filters()
return jsonify(stat_gen.get_user_analysis(dataset_content, filters)), 200 return jsonify(stat_gen.get_user_analysis(dataset_content, filters)), 200
except NotAuthorisedException: except NotAuthorisedException:
@@ -235,7 +239,8 @@ def get_user_analysis(dataset_id):
@jwt_required() @jwt_required()
def get_cultural_analysis(dataset_id): def get_cultural_analysis(dataset_id):
try: try:
dataset_content = get_dataset_and_validate(dataset_id, db) user_id = get_jwt_identity()
dataset_content = dataset_manager.get_dataset_and_validate(dataset_id, int(user_id))
filters = get_request_filters() filters = get_request_filters()
return jsonify(stat_gen.get_cultural_analysis(dataset_content, filters)), 200 return jsonify(stat_gen.get_cultural_analysis(dataset_content, filters)), 200
except NotAuthorisedException: except NotAuthorisedException:
@@ -251,7 +256,8 @@ def get_cultural_analysis(dataset_id):
@jwt_required() @jwt_required()
def get_interaction_analysis(dataset_id): def get_interaction_analysis(dataset_id):
try: try:
dataset_content = get_dataset_and_validate(dataset_id, db) user_id = get_jwt_identity()
dataset_content = dataset_manager.get_dataset_and_validate(dataset_id, int(user_id))
filters = get_request_filters() filters = get_request_filters()
return jsonify(stat_gen.get_interactional_analysis(dataset_content, filters)), 200 return jsonify(stat_gen.get_interactional_analysis(dataset_content, filters)), 200
except NotAuthorisedException: except NotAuthorisedException:

View File

@@ -6,19 +6,28 @@ class AuthManager:
self.db = db self.db = db
self.bcrypt = bcrypt self.bcrypt = bcrypt
# private
def _save_user(self, username, email, password_hash):
    """Insert one row into ``users`` holding the given name, email and hash.

    The statement is parameterised; the three values are passed to
    ``self.db.execute`` separately from the SQL text. Nothing is returned.
    """
    insert_sql = """
        INSERT INTO users (username, email, password_hash)
        VALUES (%s, %s, %s)
    """
    row_values = (username, email, password_hash)
    self.db.execute(insert_sql, row_values)
# public
def register_user(self, username, email, password): def register_user(self, username, email, password):
hashed_password = self.bcrypt.generate_password_hash(password).decode("utf-8") hashed_password = self.bcrypt.generate_password_hash(password).decode("utf-8")
if self.db.get_user_by_email(email): if self.get_user_by_email(email):
raise ValueError("Email already registered") raise ValueError("Email already registered")
if self.db.get_user_by_username(username): if self.get_user_by_username(username):
raise ValueError("Username already taken") raise ValueError("Username already taken")
self.db.save_user(username, email, hashed_password) self._save_user(username, email, hashed_password)
def authenticate_user(self, username, password): def authenticate_user(self, username, password):
user = self.db.get_user_by_username(username) user = self.get_user_by_username(username)
if user and self.bcrypt.check_password_hash(user['password_hash'], password): if user and self.bcrypt.check_password_hash(user['password_hash'], password):
return user return user
return None return None
@@ -27,3 +36,13 @@ class AuthManager:
query = "SELECT id, username, email FROM users WHERE id = %s" query = "SELECT id, username, email FROM users WHERE id = %s"
result = self.db.execute(query, (user_id,), fetch=True) result = self.db.execute(query, (user_id,), fetch=True)
return result[0] if result else None return result[0] if result else None
def get_user_by_username(self, username) -> dict:
    """Look up a user by exact username.

    Selects ``id``, ``username``, ``email`` and ``password_hash`` from
    ``users``. Returns the first row the connector hands back, or ``None``
    when the result is empty.
    """
    rows = self.db.execute(
        "SELECT id, username, email, password_hash FROM users WHERE username = %s",
        (username,),
        fetch=True,
    )
    if not rows:
        return None
    return rows[0]
def get_user_by_email(self, email) -> dict:
    """Look up a user by exact email address.

    Selects ``id``, ``username``, ``email`` and ``password_hash`` from
    ``users``. Returns the first row the connector hands back, or ``None``
    when the result is empty.
    """
    rows = self.db.execute(
        "SELECT id, username, email, password_hash FROM users WHERE email = %s",
        (email,),
        fetch=True,
    )
    if not rows:
        return None
    return rows[0]

104
server/datasets.py Normal file
View File

@@ -0,0 +1,104 @@
import pandas as pd
from server.db.database import PostgresConnector
from psycopg2.extras import Json
from server.exceptions import NotAuthorisedException
from flask_jwt_extended import get_jwt_identity
class DatasetManager:
    """Read and write datasets and their event rows through a DB connector.

    The only calls made on the connector are
    ``execute(query, params, fetch=True)`` and ``execute_batch(query, values)``.
    """

    def __init__(self, db: "PostgresConnector"):
        self.db = db

    def get_dataset_and_validate(self, dataset_id: int, user_id: int) -> pd.DataFrame:
        """Return the events of ``dataset_id`` after an ownership check.

        Args:
            dataset_id: id of the ``datasets`` row to load.
            user_id: id the row's ``user_id`` column must equal.

        Raises:
            NotExistentDatasetException: no ``datasets`` row has this id.
            NotAuthorisedException: the row's ``user_id`` differs from
                ``user_id``.
        """
        dataset_info = self.get_dataset_info(dataset_id)
        if dataset_info is None:
            # get_dataset_info yields None for an unknown id; without this
            # guard the ownership check below fails with AttributeError.
            # Imported locally because this module's top-level imports only
            # bring in NotAuthorisedException.
            from server.exceptions import NotExistentDatasetException

            raise NotExistentDatasetException(f"Dataset {dataset_id} does not exist")

        if dataset_info.get("user_id") != user_id:
            raise NotAuthorisedException("This user is not authorised to access this dataset")

        return self.get_dataset_content(dataset_id)

    def get_dataset_content(self, dataset_id: int) -> pd.DataFrame:
        """Return every ``events`` row of the dataset as a DataFrame."""
        query = "SELECT * FROM events WHERE dataset_id = %s"
        result = self.db.execute(query, (dataset_id,), fetch=True)
        return pd.DataFrame(result)

    def get_dataset_info(self, dataset_id: int) -> "dict | None":
        """Return the ``datasets`` row with this id, or ``None`` if there is none."""
        query = "SELECT * FROM datasets WHERE id = %s"
        result = self.db.execute(query, (dataset_id,), fetch=True)
        return result[0] if result else None

    def save_dataset_info(self, user_id: int, dataset_name: str, topics: dict) -> "int | None":
        """Insert a ``datasets`` row and return its generated id.

        ``topics`` is wrapped in ``Json`` before being passed as a parameter.
        Returns ``None`` when the connector returns no row.
        """
        query = """
            INSERT INTO datasets (user_id, name, topics)
            VALUES (%s, %s, %s)
            RETURNING id
        """
        result = self.db.execute(query, (user_id, dataset_name, Json(topics)), fetch=True)
        return result[0]["id"] if result else None

    @staticmethod
    def _event_row(dataset_id: int, row) -> tuple:
        """Build the parameter tuple for one ``events`` INSERT, in column order.

        Columns read with ``row[...]`` must be present (``KeyError``
        otherwise); those read with ``row.get(...)`` fall back to ``None``.
        """
        ner_entities = row.get("ner_entities")
        return (
            dataset_id,
            row["type"],
            row["parent_id"],
            row["author"],
            row["content"],
            row["timestamp"],
            row["date"],
            row["dt"],
            row["hour"],
            row["weekday"],
            row.get("reply_to"),
            row["source"],
            row.get("topic"),
            row.get("topic_confidence"),
            # Only truthy entity payloads are stored; anything else becomes NULL.
            Json(ner_entities) if ner_entities else None,
            row.get("emotion_anger"),
            row.get("emotion_disgust"),
            row.get("emotion_fear"),
            row.get("emotion_joy"),
            row.get("emotion_sadness"),
        )

    def save_dataset_content(self, dataset_id: int, event_data: pd.DataFrame):
        """Batch-insert every row of ``event_data`` into ``events``.

        Does nothing when ``event_data`` is empty. Each row is stored under
        ``dataset_id``; the frame itself is not expected to carry that column.
        """
        if event_data.empty:
            return

        query = """
            INSERT INTO events (
                dataset_id,
                type,
                parent_id,
                author,
                content,
                timestamp,
                date,
                dt,
                hour,
                weekday,
                reply_to,
                source,
                topic,
                topic_confidence,
                ner_entities,
                emotion_anger,
                emotion_disgust,
                emotion_fear,
                emotion_joy,
                emotion_sadness
            )
            VALUES (
                %s, %s, %s, %s, %s,
                %s, %s, %s, %s, %s,
                %s, %s, %s, %s, %s,
                %s, %s, %s, %s, %s
            )
        """

        values = [self._event_row(dataset_id, row) for _, row in event_data.iterrows()]
        self.db.execute_batch(query, values)

View File

@@ -27,112 +27,13 @@ class PostgresConnector:
return cursor.fetchall() return cursor.fetchall()
self.connection.commit() self.connection.commit()
def executemany(self, query, param_list) -> list: def execute_batch(self, query, values):
with self.connection.cursor(cursor_factory=RealDictCursor) as cursor:
cursor.executemany(query, param_list)
self.connection.commit()
## User Management Methods
def save_user(self, username, email, password_hash):
query = """
INSERT INTO users (username, email, password_hash)
VALUES (%s, %s, %s)
"""
self.execute(query, (username, email, password_hash))
def get_user_by_username(self, username) -> dict:
query = "SELECT id, username, email, password_hash FROM users WHERE username = %s"
result = self.execute(query, (username,), fetch=True)
return result[0] if result else None
def get_user_by_email(self, email) -> dict:
query = "SELECT id, username, email, password_hash FROM users WHERE email = %s"
result = self.execute(query, (email,), fetch=True)
return result[0] if result else None
# Dataset Management Methods
def save_dataset_info(self, user_id: int, dataset_name: str, topics: dict) -> int:
query = """
INSERT INTO datasets (user_id, name, topics)
VALUES (%s, %s, %s)
RETURNING id
"""
result = self.execute(query, (user_id, dataset_name, Json(topics)), fetch=True)
return result[0]["id"] if result else None
def save_dataset_content(self, dataset_id: int, event_data: pd.DataFrame):
query = """
INSERT INTO events (
dataset_id,
type,
parent_id,
author,
content,
timestamp,
date,
dt,
hour,
weekday,
reply_to,
source,
topic,
topic_confidence,
ner_entities,
emotion_anger,
emotion_disgust,
emotion_fear,
emotion_joy,
emotion_sadness
)
VALUES (
%s, %s, %s, %s, %s,
%s, %s, %s, %s, %s,
%s, %s, %s, %s, %s,
%s, %s, %s, %s, %s
)
"""
values = []
for _, row in event_data.iterrows():
values.append((
dataset_id,
row["type"],
row["parent_id"],
row["author"],
row["content"],
row["timestamp"],
row["date"],
row["dt"],
row["hour"],
row["weekday"],
row.get("reply_to"),
row["source"],
row.get("topic"),
row.get("topic_confidence"),
Json(row["ner_entities"]) if row.get("ner_entities") else None,
row.get("emotion_anger"),
row.get("emotion_disgust"),
row.get("emotion_fear"),
row.get("emotion_joy"),
row.get("emotion_sadness"),
))
with self.connection.cursor(cursor_factory=RealDictCursor) as cursor: with self.connection.cursor(cursor_factory=RealDictCursor) as cursor:
execute_batch(cursor, query, values) execute_batch(cursor, query, values)
self.connection.commit() self.connection.commit()
def get_dataset_content(self, dataset_id: int) -> pd.DataFrame:
query = "SELECT * FROM events WHERE dataset_id = %s"
result = self.execute(query, (dataset_id,), fetch=True)
return pd.DataFrame(result)
def get_dataset_info(self, dataset_id: int) -> dict:
query = "SELECT * FROM datasets WHERE id = %s"
result = self.execute(query, (dataset_id,), fetch=True)
return result[0] if result else None
## User Management Methods
def close(self): def close(self):
if self.connection: if self.connection:
self.connection.close() self.connection.close()

View File

@@ -1,10 +1,5 @@
import datetime import datetime
import pandas as pd
from flask import request from flask import request
from flask_jwt_extended import get_jwt_identity
from server.db.database import PostgresConnector
from server.exceptions import NotAuthorisedException
def parse_datetime_filter(value): def parse_datetime_filter(value):
if not value: if not value:
@@ -53,12 +48,3 @@ def get_request_filters() -> dict:
filters["data_sources"] = data_sources filters["data_sources"] = data_sources
return filters return filters
def get_dataset_and_validate(dataset_id: int, db: PostgresConnector) -> pd.DataFrame:
current_user = get_jwt_identity()
dataset = db.get_dataset_info(dataset_id)
if dataset.get("user_id") != int(current_user):
raise NotAuthorisedException("This user is not authorised to access this dataset")
return db.get_dataset_content(dataset_id)