diff --git a/web_server/blueprints/authentication.py b/web_server/blueprints/authentication.py index 3d35e92..dfe15e5 100644 --- a/web_server/blueprints/authentication.py +++ b/web_server/blueprints/authentication.py @@ -48,7 +48,6 @@ def signup(): # Create a connection to the database db = Database() - db.create_connection() try: # Check for duplicate email/username, no two users can have the same @@ -150,7 +149,6 @@ def login(): # Create a connection to the database db = Database() - db.create_connection() try: # Check if user exists, only existing users can be logged in diff --git a/web_server/blueprints/chat.py b/web_server/blueprints/chat.py index d113b54..00087be 100644 --- a/web_server/blueprints/chat.py +++ b/web_server/blueprints/chat.py @@ -51,7 +51,6 @@ def get_past_chat(stream_id: int): # Connect to the database db = Database() - db.create_connection() # fetched in format: [(chatter_id, message, time_sent)] all_chats = db.fetchall(""" @@ -103,7 +102,6 @@ def send_chat(data) -> None: def save_chat(chatter_id, stream_id, message): """Save the chat to the database""" db = Database() - db.create_connection() db.execute(""" INSERT INTO chat (chatter_id, stream_id, message) VALUES (?, ?, ?);""", (chatter_id, stream_id, message)) diff --git a/web_server/blueprints/email.py b/web_server/blueprints/email.py index f581dd3..f8838c9 100644 --- a/web_server/blueprints/email.py +++ b/web_server/blueprints/email.py @@ -1,5 +1,4 @@ from flask import Blueprint, session -from database.db_context import get_db import smtplib from email.mime.text import MIMEText from os import getenv @@ -20,7 +19,6 @@ def send_email() -> None: # Get the users email address db = get_db() - db.create_connection() user_email = db.fetchone(""" SELECT email FROM users diff --git a/web_server/blueprints/streams.py b/web_server/blueprints/streams.py index 9afa255..6ca9c63 100644 --- a/web_server/blueprints/streams.py +++ b/web_server/blueprints/streams.py @@ -8,7 +8,13 @@ from utils.stream_utils import ( ) from utils.user_utils import get_user_id from blueprints.utils import login_required -from utils.recommendation_utils import default_recommendations, recommendations_based_on_category, user_recommendation_category, followed_categories_recommendations +from utils.recommendation_utils import ( + default_recommendations, + recommendations_based_on_category, + user_recommendation_category, + followed_categories_recommendations, + category_recommendations +) from utils.utils import most_popular_category from database.database import Database from datetime import datetime @@ -42,12 +48,11 @@ def get_recommended_streams() -> list[dict]: @stream_bp.route('/get_categories') def get_categories() -> list[dict]: """ - Returns a list of most watched categories + Returns a list of top 5 most popular categories """ - category_data = most_popular_category() - streams = recommendations_based_on_category(category_data['category_id']) - return jsonify(streams) + category_data = category_recommendations() + return jsonify(category_data) @login_required @stream_bp.route('/get_recommended_categories') @@ -60,7 +65,6 @@ def get_recommended_categories() -> list | list[dict]: user_id = get_user_id(username) db = Database() - db.create_connection() categories = db.fetchall("""SELECT categories.category_id, categories.category_name, favourability FROM categories, user_preferences WHERE user_id = ? AND categories.category_id = user_preferences.category_id, @@ -166,7 +170,6 @@ def publish_stream(): # Check if stream key is valid db = Database() - db.create_connection() user_info = db.fetchone("""SELECT user_id, username, current_stream_title, current_selected_category_id FROM users WHERE stream_key = ?""", (stream_key,)) @@ -190,7 +193,6 @@ def end_stream(): Ends a stream """ db = Database() - db.create_connection() user_info = db.fetchone("""SELECT user_id FROM users WHERE stream_key = ?""", (request.form.get("name"),)) if not user_info: diff --git a/web_server/database/database.py b/web_server/database/database.py index 3517e46..3248d69 100644 --- a/web_server/database/database.py +++ b/web_server/database/database.py @@ -6,6 +6,15 @@ class Database: self._db = os.path.join(os.path.abspath(os.path.dirname(__file__)), "app.db") self._conn = None self.cursor = None + self.create_connection() + + def __enter__(self): + """Returns db on using with clause""" + return self + + def __exit__(self, exc_type, exc_value, traceback): + """Closes db connection after with clause""" + self.close_connection() def create_connection(self) -> None: """Create a database connection if not already established.""" diff --git a/web_server/utils/recommendation_utils.py b/web_server/utils/recommendation_utils.py index 85cb460..092dbc7 100644 --- a/web_server/utils/recommendation_utils.py +++ b/web_server/utils/recommendation_utils.py @@ -5,28 +5,28 @@ def user_recommendation_category(user_id: int) -> Optional[int]: """ Queries user_preferences database to find users favourite streaming category and returns the category """ - db = Database() - db.create_connection() - - data = db.fetchone( - "SELECT category_id FROM user_preferences WHERE user_id = ? ORDER BY favourability DESC LIMIT 1", (user_id,)) - db.close_connection() + with Database() as db: + data = db.fetchone(""" + SELECT category_id + FROM user_preferences + WHERE user_id = ? + ORDER BY favourability DESC + LIMIT 1 + """, (user_id,)) return data def followed_categories_recommendations(user_id: int): """ Returns top 25 streams given a users category following """ - db = Database() - db.create_connection() - - categories = db.fetchall(""" - SELECT users.user_id, title, username, num_viewers, category_name - FROM streams - WHERE category_id IN (SELECT category_id FROM categories WHERE user_id = ?) - ORDER BY num_viewers DESC - LIMIT 25; """, (user_id,)) - db.close_connection() + with Database() as db: + categories = db.fetchall(""" + SELECT users.user_id, title, username, num_viewers, category_name + FROM streams + WHERE category_id IN (SELECT category_id FROM categories WHERE user_id = ?) + ORDER BY num_viewers DESC + LIMIT 25; + """, (user_id,)) return categories def recommendations_based_on_category(category_id: int) -> Optional[List[Tuple[int, str, int]]]: @@ -34,18 +34,16 @@ def recommendations_based_on_category(category_id: int) -> Optional[List[Tuple[i Queries stream database to get top 25 most viewed streams based on given category and returns (user_id, title, username, num_viewers, category_name) """ - db = Database() - db.create_connection() - - data = db.fetchall(""" - SELECT users.user_id, title, username, num_viewers, category_name - FROM streams - JOIN users ON users.user_id = streams.user_id - JOIN categories ON streams.category_id = categories.category_id - WHERE categories.category_id = ? - ORDER BY num_viewers DESC - LIMIT 25""", (category_id,)) - db.close_connection() + with Database() as db: + data = db.fetchall(""" + SELECT users.user_id, title, username, num_viewers, category_name + FROM streams + JOIN users ON users.user_id = streams.user_id + JOIN categories ON streams.category_id = categories.category_id + WHERE categories.category_id = ? + ORDER BY num_viewers DESC + LIMIT 25 + """, (category_id,)) return data def default_recommendations(): @@ -53,16 +51,30 @@ def default_recommendations(): Return a list of 25 recommended live streams by number of viewers (user_id, title, username, num_viewers, category_name) """ - db = Database() - db.create_connection() - data = db.fetchall(""" - SELECT users.user_id, title, username, num_viewers, category_name - FROM streams - JOIN users ON users.user_id = streams.user_id - JOIN categories ON streams.category_id = categories.category_id - ORDER BY num_viewers DESC - LIMIT 25; - """) - db.close_connection() + with Database() as db: + data = db.fetchall(""" + SELECT users.user_id, title, username, num_viewers, category_name + FROM streams + JOIN users ON users.user_id = streams.user_id + JOIN categories ON streams.category_id = categories.category_id + ORDER BY num_viewers DESC + LIMIT 25; + """) return data +def category_recommendations(): + """ + Returns a list of the top 5 most popular categories + """ + with Database() as db: + categories = db.fetchall(""" + SELECT categories.category_id, categories.category_name + FROM streams + JOIN categories ON streams.category_id = categories.category_id + WHERE streams.isLive = 1 + GROUP BY categories.category_name + ORDER BY SUM(streams.num_viewers) DESC + LIMIT 5; + """) + return categories + diff --git a/web_server/utils/stream_utils.py b/web_server/utils/stream_utils.py index 3e5f6a9..270b9c8 100644 --- a/web_server/utils/stream_utils.py +++ b/web_server/utils/stream_utils.py @@ -7,66 +7,60 @@ def streamer_live_status(user_id: int) -> bool: """ Returns boolean on whether the given streamer is live """ - db = Database() - db.create_connection() - is_live = db.fetchone("""SELECT isLive FROM streams WHERE user_id = ?""", (user_id,)) - db.close_connection() + with Database() as db: + is_live = db.fetchone(""" + SELECT isLive + FROM streams + WHERE user_id = ? + """, (user_id,)) return is_live def followed_live_streams(user_id: int) -> list[dict]: """ Searches for streamers who the user followed which are currently live """ - db = Database() - db.create_connection() - - live_streams = db.fetchall(""" - SELECT user_id, stream_id, title, num_viewers - FROM streams - WHERE user_id IN (SELECT followed_id FROM follows WHERE user_id = ?) - AND stream_id = (SELECT MAX(stream_id) FROM streams WHERE user_id = streams.user_id) - AND isLive = 1; + with Database() as db: + live_streams = db.fetchall(""" + SELECT user_id, stream_id, title, num_viewers + FROM streams + WHERE user_id IN (SELECT followed_id FROM follows WHERE user_id = ?) + AND stream_id = (SELECT MAX(stream_id) FROM streams WHERE user_id = streams.user_id) + AND isLive = 1; """, (user_id,)) - db.close_connection() - return live_streams def followed_streamers(user_id: int) -> list[dict]: """ Returns a list of streamers who the user follows """ - db = Database() - db.create_connection() - - followed_streamers = db.fetchall(""" - SELECT user_id, username - FROM users - WHERE user_id IN (SELECT followed_id FROM follows WHERE user_id = ?); + with Database() as db: + followed_streamers = db.fetchall(""" + SELECT user_id, username + FROM users + WHERE user_id IN (SELECT followed_id FROM follows WHERE user_id = ?); """, (user_id,)) - - db.close_connection() return followed_streamers def streamer_most_recent_stream(user_id: int) -> dict: """ Returns data of the most recent stream by a streamer """ - db = Database() - db.create_connection() - most_recent_stream = db.fetchone("""SELECT * FROM streams WHERE - user_id = ? AND - stream_id = (SELECT MAX(stream_id) FROM - streams WHERE user_id = ?)""", (user_id, user_id)) - db.close_connection() + with Database() as db: + most_recent_stream = db.fetchone(""" + SELECT * FROM streams + WHERE user_id = ? + AND stream_id = (SELECT MAX(stream_id) FROM streams WHERE user_id = ?) + """, (user_id, user_id)) return most_recent_stream def user_stream(user_id: int, stream_id: int) -> dict: """ Returns data of a streamers selected stream """ - db = Database() - db.create_connection() - stream = db.fetchone("SELECT * FROM streams WHERE user_id = ? AND stream_id = ?", (user_id, stream_id)) - db.close_connection() - + with Database() as db: + stream = db.fetchone(""" + SELECT * FROM streams + WHERE user_id = ? + AND stream_id = ? + """, (user_id, stream_id)) return stream \ No newline at end of file diff --git a/web_server/utils/user_utils.py b/web_server/utils/user_utils.py index 5d56d56..9b7009e 100644 --- a/web_server/utils/user_utils.py +++ b/web_server/utils/user_utils.py @@ -13,138 +13,101 @@ def get_user_id(username: str) -> int: """ Returns user_id associated with given username """ - db = Database() - db.create_connection() - - try: - data = db.fetchone( - "SELECT user_id FROM users WHERE username = ?", - (username,) - ) - return data['user_id'] if data else None - except Exception as e: - print(f"Error: {e}") - return None - finally: - db.close_connection() + with Database() as db: + data = db.fetchone(""" + SELECT user_id + FROM users + WHERE username = ? + """, (username,)) + return data['user_id'] if data else None def get_username(user_id: str) -> Optional[str]: """ Returns username associated with given user_id """ - db = Database() - db.create_connection() - - try: - data = db.fetchone( - "SELECT username FROM user WHERE user_id = ?", - (user_id,) - ) - return data[0] if data else None - except Exception as e: - print(f"Error: {e}") - return None - finally: - db.close_connection() + with Database() as db: + data = db.fetchone(""" + SELECT username + FROM user + WHERE user_id = ? + """, (user_id,)) + return data['username'] if data else None def is_user_partner(user_id: int) -> bool: """ Returns True if user is a partner, else False """ - db = Database() - db.create_connection() - - try: - data = db.fetchone( - "SELECT is_partnered FROM users WHERE user_id = ?", - (user_id,) - ) - return bool(data) - except Exception as e: - print(f"Error: {e}") - return False - finally: - db.close_connection() + with Database() as db: + data = db.fetchone(""" + SELECT is_partnered + FROM users + WHERE user_id = ? + """, (user_id,)) + return bool(data) def is_subscribed(user_id: int, streamer_id: int) -> bool: """ Returns True if user is subscribed to a streamer, else False """ - db = Database() - db.create_connection() - - try: - result = db.fetchone( - "SELECT 1 FROM subscribes WHERE user_id = ? AND streamer_id = ? AND expires > ?", - (user_id, streamer_id, datetime.now()) - ) - return bool(result) - except Exception as e: - print(f"Error: {e}") - return False - finally: - db.close_connection() + with Database() as db: + result = db.fetchone(""" + SELECT 1 + FROM subscribes + WHERE user_id = ? + AND streamer_id = ? + AND expires > ? + """, (user_id, streamer_id, datetime.now())) + return bool(result) def is_following(user_id: int, followed_id: int) -> bool: - db = Database() - db.create_connection() - - try: - result = db.fetchone( - "SELECT 1 FROM follows WHERE user_id = ? AND followed_id = ?", - (user_id, followed_id) - ) - return bool(result) - except Exception as e: - print(f"Error: {e}") - return False - finally: - db.close_connection() + """ + Returns where a user is following another + """ + with Database() as db: + result = db.fetchone(""" + SELECT 1 + FROM follows + WHERE user_id = ? + AND followed_id = ? + """, (user_id, followed_id)) + return bool(result) def subscription_expiration(user_id: int, subscribed_id: int) -> int: """ Returns the amount of time left until user subscription to a streamer ends """ - db = Database() - db.create_connection() - remaining_time = 0 - try: - data = db.fetchone( - "SELECT expires from subscriptions WHERE user_id = ? AND subscribed_id = ? AND expires > since", (user_id,subscribed_id)) - if data: - expiration_date = data["expires"] + with Database() as db: + data = db.fetchone(""" + SELECT expires + FROM subscriptions + WHERE user_id = ? + AND subscribed_id = ? + AND expires > ? + """, (user_id, subscribed_id, datetime.now())) - remaining_time = (expiration_date - datetime.now()).seconds - except Exception as e: - print(f"Error: {e}") - finally: - db.close_connection() + if data: + expiration_date = data["expires"] + remaining_time = (expiration_date - datetime.now()).seconds + return remaining_time - return remaining_time + return 0 def verify_token(token: str): """ Given a token verifies token and decodes the token into an email """ - try: - email = serializer.loads(token, salt='1', max_age=3600) - return email - except Exception as e: - print(f"Error: {e}") - return False + email = serializer.loads(token, salt='1', max_age=3600) + return email if email else False def reset_password(new_password: str, email: str): """ Given email and new password reset the password for a given user """ - db = Database() - db.create_connection() - - try: - db.execute("UPDATE users SET password = ? WHERE email = ?", (generate_password_hash(new_password), email)) - return True - except Exception as e: - print(f"Error: {e}") - return False - finally: - db.close_connection() \ No newline at end of file + with Database() as db: + db.execute(""" + UPDATE users + SET password = ? + WHERE email = ? + """, (generate_password_hash(new_password), email)) + + return True \ No newline at end of file diff --git a/web_server/utils/utils.py b/web_server/utils/utils.py index f036b0b..f2012ce 100644 --- a/web_server/utils/utils.py +++ b/web_server/utils/utils.py @@ -4,36 +4,34 @@ def categories(): """ Returns all possible streaming categories """ - db = Database() - db.create_connection() - all_categories = db.fetchall("SELECT * FROM categories") + with Database() as db: + all_categories = db.fetchall("SELECT * FROM categories") + return all_categories def tags(): """ Returns all possible streaming tags """ - db = Database() - db.create_connection() - all_tags = db.fetchall("SELECT * FROM tags") + with Database() as db: + all_tags = db.fetchall("SELECT * FROM tags") + return all_tags def most_popular_category(): """ Returns the most popular category based on live stream viewers """ - db = Database() - db.create_connection() - - category = db.fetchone(""" - SELECT categories.category_id, categories.category_name - FROM streams - JOIN categories ON streams.category_id = categories.category_id - WHERE streams.isLive = 1 - GROUP BY categories.category_name - ORDER BY SUM(streams.num_viewers) DESC - LIMIT 1; - """) - + with Database() as db: + category = db.fetchone(""" + SELECT categories.category_id, categories.category_name + FROM streams + JOIN categories ON streams.category_id = categories.category_id + WHERE streams.isLive = 1 + GROUP BY categories.category_name + ORDER BY SUM(streams.num_viewers) DESC + LIMIT 1; + """) + return category