diff --git a/web_server/blueprints/authentication.py b/web_server/blueprints/authentication.py index 469124a..d10509e 100644 --- a/web_server/blueprints/authentication.py +++ b/web_server/blueprints/authentication.py @@ -26,7 +26,7 @@ def signup(): # Validation - ensure all fields exist, users cannot have an empty field if not all([username, email, password]): - error_fields = get_error_fields([username, email, password]), #!←← find the error_fields, to highlight them in red to the user on the frontend + error_fields = get_error_fields([username, email, password]) #!←← find the error_fields, to highlight them in red to the user on the frontend return jsonify({ "account_created": False, "error_fields": error_fields, @@ -48,25 +48,25 @@ def signup(): # Create a connection to the database db = Database() - cursor = db.create_connection() + db.create_connection() try: # Check for duplicate email/username, no two users can have the same - dup_email = cursor.execute( + dup_email = db.fetchone( "SELECT * FROM users WHERE email = ?", (email,) - ).fetchone() + ) - dup_username = cursor.execute( + dup_username = db.fetchone( "SELECT * FROM users WHERE username = ?", (username,) - ).fetchone() + ) if dup_email is not None: return jsonify({ "account_created": False, "error_fields": ["email"], - "message": "Email already taken" + "message": f"Email already taken: {email}" }), 400 if dup_username is not None: @@ -77,7 +77,7 @@ def signup(): }), 400 # Create new user once input is validated - cursor.execute( + db.execute( """INSERT INTO users (username, password, email, num_followers, stream_key, is_partnered, bio, current_stream_title, current_selected_category_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", @@ -151,14 +151,14 @@ def login(): # Create a connection to the database db = Database() - cursor = db.create_connection() + db.create_connection() try: # Check if user exists, only existing users can be logged in - user = cursor.execute( + user = db.fetchone( "SELECT * FROM users WHERE username = ?", (username,) - ).fetchone() + ) if not user: return jsonify({ @@ -210,7 +210,4 @@ def logout() -> dict: def get_error_fields(values: list): fields = ["username", "email", "password"] - for x in fields: - if not values[fields.index(x)]: - fields.remove(x) - return fields \ No newline at end of file + return [fields[i] for i, v in enumerate(values) if not v] \ No newline at end of file diff --git a/web_server/blueprints/chat.py b/web_server/blueprints/chat.py index a090e7a..31c5211 100644 --- a/web_server/blueprints/chat.py +++ b/web_server/blueprints/chat.py @@ -51,10 +51,10 @@ def get_past_chat(stream_id: int): # Connect to the database db = Database() - cursor = db.create_connection() + db.create_connection() # fetched in format: [(chatter_id, message, time_sent)] - all_chats = cursor.execute(""" + all_chats = db.fetchall(""" SELECT * FROM ( SELECT chatter_id, message, time_sent @@ -63,7 +63,7 @@ def get_past_chat(stream_id: int): ORDER BY time_sent DESC LIMIT 50 ) - ORDER BY time_sent ASC;""", (stream_id,)).fetchall() + ORDER BY time_sent ASC;""", (stream_id,)) db.close_connection() # Create JSON output of chat_history to pass through NGINX proxy @@ -103,8 +103,8 @@ def send_chat(data) -> None: def save_chat(chatter_id, stream_id, message): """Save the chat to the database""" db = Database() - cursor = db.create_connection() - cursor.execute(""" + db.create_connection() + db.execute(""" INSERT INTO chat (chatter_id, stream_id, message) VALUES (?, ?, ?);""", (chatter_id, stream_id, message)) db.commit_data() diff --git a/web_server/blueprints/streams.py b/web_server/blueprints/streams.py index e9d1f17..d19389d 100644 --- a/web_server/blueprints/streams.py +++ b/web_server/blueprints/streams.py @@ -34,7 +34,7 @@ def get_recommended_streams() -> list[dict]: Queries DB to get a list of recommended streams using an algorithm """ - user_id = session.get("user_id") + user_id = session.get("username") category = user_recommendation_category(user_id) streams = recommendations_based_on_category(category) return jsonify(streams) @@ -180,7 +180,7 @@ def publish_stream(): 1, datetime.now(), 1)) - + db.commit_data() return redirect(f"/{user_info['username']}") @@ -198,5 +198,6 @@ def end_stream(): # Set stream to not live db.execute("""UPDATE streams SET isLive = 0 WHERE user_id = ? AND isLive = 1""", (user_info["user_id"],)) + db.commit_data() return "Stream ended", 200 diff --git a/web_server/database/database.py b/web_server/database/database.py index a617752..3517e46 100644 --- a/web_server/database/database.py +++ b/web_server/database/database.py @@ -4,69 +4,50 @@ import os class Database: def __init__(self) -> None: self._db = os.path.join(os.path.abspath(os.path.dirname(__file__)), "app.db") + self._conn = None self.cursor = None - def create_connection(self) -> sqlite3.Cursor: - conn = sqlite3.connect(self._db) - conn.row_factory = sqlite3.Row - self._conn = conn - self.cursor = conn.cursor() - return self.cursor - - def fetchall(self, query: str, parameters=None) -> list[dict]: - if parameters: - self.cursor.execute(query, parameters) - else: - self.cursor.execute(query) + def create_connection(self) -> None: + """Create a database connection if not already established.""" + if self._conn is None: + self._conn = sqlite3.connect(self._db) + self._conn.row_factory = sqlite3.Row + self.cursor = self._conn.cursor() + def close_connection(self) -> None: + """Close the database connection.""" + if self._conn: + self._conn.close() + self._conn = None + self.cursor = None + + def fetchall(self, query: str, parameters=None) -> list[dict]: + """Fetch all records from the database.""" + self.create_connection() + self.cursor.execute(query, parameters or ()) result = self.cursor.fetchall() return self.convert_to_list_dict(result) - - def fetchone(self, query: str, parameters=None) -> list[dict]: - if parameters: - self.cursor.execute(query, parameters) - else: - self.cursor.execute(query) + def fetchone(self, query: str, parameters=None) -> dict | None: + """Fetch one record from the database.""" + self.create_connection() + self.cursor.execute(query, parameters or ()) result = self.cursor.fetchone() - return self.convert_to_list_dict(result) + return self.convert_to_list_dict(result) if result else None def execute(self, query: str, parameters=None) -> None: - """ - Executes a command (e.g., INSERT, UPDATE, DELETE) and commits the changes. - """ + """Execute an INSERT, UPDATE, or DELETE command and commit changes.""" + self.create_connection() try: - if parameters: - self.cursor.execute(query, parameters) - else: - self.cursor.execute(query) - self.commit_data() - except Exception as e: - print(f"Error executing command: {e}") + self.cursor.execute(query, parameters or ()) + self._conn.commit() + except sqlite3.DatabaseError as e: + print(f"Database error: {e}") raise def convert_to_list_dict(self, result): - """ - Converts a query result to a list of dictionaries - """ - # Get the column names from the cursor - columns = [description[0] for description in self.cursor.description] - + """Convert query result to a list of dictionaries.""" if not result: - # for empty result return [] - elif isinstance(result, sqlite3.Row): - # for fetchone - return dict(zip(columns, result)) - else: - # for fetchall or fetchmany - return [dict(zip(columns, row)) for row in result] - - def commit_data(self): - try: - self._conn.commit() - except Exception as e: - print(e) - - def close_connection(self) -> None: - self._conn.close() \ No newline at end of file + columns = [desc[0] for desc in self.cursor.description] + return [dict(zip(columns, row)) for row in result] if isinstance(result, list) else dict(zip(columns, result)) diff --git a/web_server/database/streaming.sql b/web_server/database/streaming.sql index e675990..4c2cad2 100644 --- a/web_server/database/streaming.sql +++ b/web_server/database/streaming.sql @@ -48,4 +48,4 @@ CREATE TABLE streams category_id NOT NULL, FOREIGN KEY (category_id) REFERENCES categories(category_id) ON DELETE CASCADE, FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE -); \ No newline at end of file +); diff --git a/web_server/utils/recommendation_utils.py b/web_server/utils/recommendation_utils.py index a772256..23325c0 100644 --- a/web_server/utils/recommendation_utils.py +++ b/web_server/utils/recommendation_utils.py @@ -6,10 +6,10 @@ def user_recommendation_category(user_id: int) -> Optional[int]: Queries user_preferences database to find users favourite streaming category and returns the category """ db = Database() - cursor = db.create_connection() + db.create_connection() - data = cursor.execute( - "SELECT category_id FROM user_preferences WHERE user_id = ? ORDER BY favourability DESC LIMIT 1", (user_id,)).fetchone() + data = db.fetchone( + "SELECT category_id FROM user_preferences WHERE user_id = ? ORDER BY favourability DESC LIMIT 1", (user_id,)) return data[0] def followed_categories_recommendations(user_id: int): diff --git a/web_server/utils/stream_utils.py b/web_server/utils/stream_utils.py index 51578bc..1efaa75 100644 --- a/web_server/utils/stream_utils.py +++ b/web_server/utils/stream_utils.py @@ -8,15 +8,15 @@ def streamer_live_status(user_id: int) -> bool: Returns boolean on whether the given streamer is live """ db = Database() - cursor = db.create_connection() - return bool(cursor.execute("SELECT 1 FROM streams WHERE user_id = ? AND isLive = 1 ORDER BY stream_id DESC", (user_id,)).fetchone()) + db.create_connection() + return bool(db.fetchone("SELECT 1 FROM streams WHERE user_id = ? AND isLive = 1 ORDER BY stream_id DESC", (user_id,))) def followed_live_streams(user_id: int) -> list[dict]: """ Searches for streamers who the user followed which are currently live """ db = Database() - cursor = db.create_connection() + db.create_connection() live_streams = db.fetchall(""" SELECT user_id, stream_id, title, num_viewers @@ -60,7 +60,7 @@ def user_stream(user_id: int, stream_id: int) -> dict: Returns data of a streamers selected stream """ db = Database() - cursor = db.create_connection() + db.create_connection() stream = db.fetchone("SELECT * FROM streams WHERE user_id = ? AND stream_id = ?", (user_id,stream_id)) return stream \ No newline at end of file diff --git a/web_server/utils/user_utils.py b/web_server/utils/user_utils.py index 277dec0..854daa6 100644 --- a/web_server/utils/user_utils.py +++ b/web_server/utils/user_utils.py @@ -14,13 +14,13 @@ def get_user_id(username: str) -> Optional[int]: Returns user_id associated with given username """ db = Database() - cursor = db.create_connection() + db.create_connection() try: - data = cursor.execute( + data = db.fetchone( "SELECT user_id FROM users WHERE username = ?", (username,) - ).fetchone() + ) return data[0] if data else None except Exception as e: print(f"Error: {e}") @@ -31,13 +31,13 @@ def get_username(user_id: str) -> Optional[str]: Returns username associated with given user_id """ db = Database() - cursor = db.create_connection() + db.create_connection() try: - data = cursor.execute( + data = db.fetchone( "SELECT username FROM user WHERE user_id = ?", (user_id,) - ).fetchone() + ) return data[0] if data else None except Exception as e: print(f"Error: {e}") @@ -48,13 +48,13 @@ def is_user_partner(user_id: int) -> bool: Returns True if user is a partner, else False """ db = Database() - cursor = db.create_connection() + db.create_connection() try: - data = cursor.execute( + data = db.fetchone( "SELECT is_partnered FROM users WHERE user_id = ?", (user_id,) - ).fetchone() + ) return bool(data) except Exception as e: print(f"Error: {e}") @@ -65,13 +65,13 @@ def is_subscribed(user_id: int, streamer_id: int) -> bool: Returns True if user is subscribed to a streamer, else False """ db = Database() - cursor = db.create_connection() + db.create_connection() try: - result = cursor.execute( + result = db.fetchone( "SELECT 1 FROM subscribes WHERE user_id = ? AND streamer_id = ? AND expires > ?", (user_id, streamer_id, datetime.now()) - ).fetchone() + ) return bool(result) except Exception as e: print(f"Error: {e}") @@ -79,13 +79,13 @@ def is_subscribed(user_id: int, streamer_id: int) -> bool: def is_following(user_id: int, followed_id: int) -> bool: db = Database() - cursor = db.create_connection() + db.create_connection() try: - result = cursor.execute( + result = db.fetchone( "SELECT 1 FROM follows WHERE user_id = ? AND followed_id = ?", (user_id, followed_id) - ).fetchone() + ) return bool(result) except Exception as e: print(f"Error: {e}") @@ -96,11 +96,11 @@ def subscription_expiration(user_id: int, subscribed_id: int) -> int: Returns the amount of time left until user subscription to a streamer ends """ db = Database() - cursor = db.create_connection() + db.create_connection() remaining_time = 0 try: - data = cursor.execute( - "SELECT expires from subscriptions WHERE user_id = ? AND subscribed_id = ? AND expires > since", (user_id,subscribed_id)).fetchone() + data = db.fetchone( + "SELECT expires from subscriptions WHERE user_id = ? AND subscribed_id = ? AND expires > since", (user_id,subscribed_id)) if data: expiration_date = data[0] @@ -130,7 +130,7 @@ def reset_password(new_password: str, email: str): try: db.execute("UPDATE users SET password = ? WHERE email = ?", (generate_password_hash(new_password), email)) - db.commit() + db.commit_data() return True except Exception as e: print(f"Error: {e}") diff --git a/web_server/utils/utils.py b/web_server/utils/utils.py index 2fe6fa1..f036b0b 100644 --- a/web_server/utils/utils.py +++ b/web_server/utils/utils.py @@ -5,8 +5,8 @@ def categories(): Returns all possible streaming categories """ db = Database() - cursor = db.create_connection() - all_categories = cursor.execute("SELECT * FROM categories").fetchall() + db.create_connection() + all_categories = db.fetchall("SELECT * FROM categories") return all_categories def tags(): @@ -14,8 +14,8 @@ def tags(): Returns all possible streaming tags """ db = Database() - cursor = db.create_connection() - all_tags = cursor.execute("SELECT * FROM tags").fetchall() + db.create_connection() + all_tags = db.fetchall("SELECT * FROM tags") return all_tags def most_popular_category(): @@ -23,9 +23,9 @@ def most_popular_category(): Returns the most popular category based on live stream viewers """ db = Database() - cursor = db.create_connection() + db.create_connection() - category = cursor.execute(""" + category = db.fetchone(""" SELECT categories.category_id, categories.category_name FROM streams JOIN categories ON streams.category_id = categories.category_id @@ -33,7 +33,7 @@ def most_popular_category(): GROUP BY categories.category_name ORDER BY SUM(streams.num_viewers) DESC LIMIT 1; - """).fetchone() + """) return category