MAJOR: Reworked database to be easier to use and close automatically to prevent resource leaks

This commit is contained in:
JustIceO7
2025-01-29 17:47:29 +00:00
parent 83b458ed99
commit dfdbe4a7d4
9 changed files with 181 additions and 209 deletions

View File

@@ -48,7 +48,6 @@ def signup():
# Create a connection to the database # Create a connection to the database
db = Database() db = Database()
db.create_connection()
try: try:
# Check for duplicate email/username, no two users can have the same # Check for duplicate email/username, no two users can have the same
@@ -150,7 +149,6 @@ def login():
# Create a connection to the database # Create a connection to the database
db = Database() db = Database()
db.create_connection()
try: try:
# Check if user exists, only existing users can be logged in # Check if user exists, only existing users can be logged in

View File

@@ -51,7 +51,6 @@ def get_past_chat(stream_id: int):
# Connect to the database # Connect to the database
db = Database() db = Database()
db.create_connection()
# fetched in format: [(chatter_id, message, time_sent)] # fetched in format: [(chatter_id, message, time_sent)]
all_chats = db.fetchall(""" all_chats = db.fetchall("""
@@ -103,7 +102,6 @@ def send_chat(data) -> None:
def save_chat(chatter_id, stream_id, message): def save_chat(chatter_id, stream_id, message):
"""Save the chat to the database""" """Save the chat to the database"""
db = Database() db = Database()
db.create_connection()
db.execute(""" db.execute("""
INSERT INTO chat (chatter_id, stream_id, message) INSERT INTO chat (chatter_id, stream_id, message)
VALUES (?, ?, ?);""", (chatter_id, stream_id, message)) VALUES (?, ?, ?);""", (chatter_id, stream_id, message))

View File

@@ -1,5 +1,4 @@
from flask import Blueprint, session from flask import Blueprint, session
from database.db_context import get_db
import smtplib import smtplib
from email.mime.text import MIMEText from email.mime.text import MIMEText
from os import getenv from os import getenv
@@ -20,7 +19,6 @@ def send_email() -> None:
# Get the users email address # Get the users email address
db = get_db() db = get_db()
db.create_connection()
user_email = db.fetchone(""" user_email = db.fetchone("""
SELECT email SELECT email
FROM users FROM users

View File

@@ -8,7 +8,13 @@ from utils.stream_utils import (
) )
from utils.user_utils import get_user_id from utils.user_utils import get_user_id
from blueprints.utils import login_required from blueprints.utils import login_required
from utils.recommendation_utils import default_recommendations, recommendations_based_on_category, user_recommendation_category, followed_categories_recommendations from utils.recommendation_utils import (
default_recommendations,
recommendations_based_on_category,
user_recommendation_category,
followed_categories_recommendations,
category_recommendations
)
from utils.utils import most_popular_category from utils.utils import most_popular_category
from database.database import Database from database.database import Database
from datetime import datetime from datetime import datetime
@@ -42,12 +48,11 @@ def get_recommended_streams() -> list[dict]:
@stream_bp.route('/get_categories') @stream_bp.route('/get_categories')
def get_categories() -> list[dict]: def get_categories() -> list[dict]:
""" """
Returns a list of most watched categories Returns a list of top 5 most popular categories
""" """
category_data = most_popular_category() category_data = category_recommendations()
streams = recommendations_based_on_category(category_data['category_id']) return jsonify(category_data)
return jsonify(streams)
@login_required @login_required
@stream_bp.route('/get_recommended_categories') @stream_bp.route('/get_recommended_categories')
@@ -60,7 +65,6 @@ def get_recommended_categories() -> list | list[dict]:
user_id = get_user_id(username) user_id = get_user_id(username)
db = Database() db = Database()
db.create_connection()
categories = db.fetchall("""SELECT categories.category_id, categories.category_name, favourability categories = db.fetchall("""SELECT categories.category_id, categories.category_name, favourability
FROM categories, user_preferences FROM categories, user_preferences
WHERE user_id = ? AND categories.category_id = user_preferences.category_id, WHERE user_id = ? AND categories.category_id = user_preferences.category_id,
@@ -166,7 +170,6 @@ def publish_stream():
# Check if stream key is valid # Check if stream key is valid
db = Database() db = Database()
db.create_connection()
user_info = db.fetchone("""SELECT user_id, username, current_stream_title, current_selected_category_id user_info = db.fetchone("""SELECT user_id, username, current_stream_title, current_selected_category_id
FROM users FROM users
WHERE stream_key = ?""", (stream_key,)) WHERE stream_key = ?""", (stream_key,))
@@ -190,7 +193,6 @@ def end_stream():
Ends a stream Ends a stream
""" """
db = Database() db = Database()
db.create_connection()
user_info = db.fetchone("""SELECT user_id FROM users WHERE stream_key = ?""", (request.form.get("name"),)) user_info = db.fetchone("""SELECT user_id FROM users WHERE stream_key = ?""", (request.form.get("name"),))
if not user_info: if not user_info:

View File

@@ -6,6 +6,15 @@ class Database:
self._db = os.path.join(os.path.abspath(os.path.dirname(__file__)), "app.db") self._db = os.path.join(os.path.abspath(os.path.dirname(__file__)), "app.db")
self._conn = None self._conn = None
self.cursor = None self.cursor = None
self.create_connection()
def __enter__(self):
"""Returns db on using with clause"""
return self
def __exit__(self, exc_type, exc_value, traceback):
"""Closes db connection after with clause"""
self.close_connection()
def create_connection(self) -> None: def create_connection(self) -> None:
"""Create a database connection if not already established.""" """Create a database connection if not already established."""

View File

@@ -5,28 +5,28 @@ def user_recommendation_category(user_id: int) -> Optional[int]:
""" """
Queries user_preferences database to find users favourite streaming category and returns the category Queries user_preferences database to find users favourite streaming category and returns the category
""" """
db = Database() with Database() as db:
db.create_connection() data = db.fetchone("""
SELECT category_id
data = db.fetchone( FROM user_preferences
"SELECT category_id FROM user_preferences WHERE user_id = ? ORDER BY favourability DESC LIMIT 1", (user_id,)) WHERE user_id = ?
db.close_connection() ORDER BY favourability DESC
LIMIT 1
""", (user_id,))
return data return data
def followed_categories_recommendations(user_id: int): def followed_categories_recommendations(user_id: int):
""" """
Returns top 25 streams given a users category following Returns top 25 streams given a users category following
""" """
db = Database() with Database() as db:
db.create_connection()
categories = db.fetchall(""" categories = db.fetchall("""
SELECT users.user_id, title, username, num_viewers, category_name SELECT users.user_id, title, username, num_viewers, category_name
FROM streams FROM streams
WHERE category_id IN (SELECT category_id FROM categories WHERE user_id = ?) WHERE category_id IN (SELECT category_id FROM categories WHERE user_id = ?)
ORDER BY num_viewers DESC ORDER BY num_viewers DESC
LIMIT 25; """, (user_id,)) LIMIT 25;
db.close_connection() """, (user_id,))
return categories return categories
def recommendations_based_on_category(category_id: int) -> Optional[List[Tuple[int, str, int]]]: def recommendations_based_on_category(category_id: int) -> Optional[List[Tuple[int, str, int]]]:
@@ -34,9 +34,7 @@ def recommendations_based_on_category(category_id: int) -> Optional[List[Tuple[i
Queries stream database to get top 25 most viewed streams based on given category and returns Queries stream database to get top 25 most viewed streams based on given category and returns
(user_id, title, username, num_viewers, category_name) (user_id, title, username, num_viewers, category_name)
""" """
db = Database() with Database() as db:
db.create_connection()
data = db.fetchall(""" data = db.fetchall("""
SELECT users.user_id, title, username, num_viewers, category_name SELECT users.user_id, title, username, num_viewers, category_name
FROM streams FROM streams
@@ -44,8 +42,8 @@ def recommendations_based_on_category(category_id: int) -> Optional[List[Tuple[i
JOIN categories ON streams.category_id = categories.category_id JOIN categories ON streams.category_id = categories.category_id
WHERE categories.category_id = ? WHERE categories.category_id = ?
ORDER BY num_viewers DESC ORDER BY num_viewers DESC
LIMIT 25""", (category_id,)) LIMIT 25
db.close_connection() """, (category_id,))
return data return data
def default_recommendations(): def default_recommendations():
@@ -53,8 +51,7 @@ def default_recommendations():
Return a list of 25 recommended live streams by number of viewers Return a list of 25 recommended live streams by number of viewers
(user_id, title, username, num_viewers, category_name) (user_id, title, username, num_viewers, category_name)
""" """
db = Database() with Database() as db:
db.create_connection()
data = db.fetchall(""" data = db.fetchall("""
SELECT users.user_id, title, username, num_viewers, category_name SELECT users.user_id, title, username, num_viewers, category_name
FROM streams FROM streams
@@ -63,6 +60,21 @@ def default_recommendations():
ORDER BY num_viewers DESC ORDER BY num_viewers DESC
LIMIT 25; LIMIT 25;
""") """)
db.close_connection()
return data return data
def category_recommendations():
"""
Returns a list of the top 5 most popular categories
"""
with Database() as db:
categories = db.fetchall("""
SELECT categories.category_id, categories.category_name
FROM streams
JOIN categories ON streams.category_id = categories.category_id
WHERE streams.isLive = 1
GROUP BY categories.category_name
ORDER BY SUM(streams.num_viewers) DESC
LIMIT 5;
""")
return categories

View File

@@ -7,19 +7,19 @@ def streamer_live_status(user_id: int) -> bool:
""" """
Returns boolean on whether the given streamer is live Returns boolean on whether the given streamer is live
""" """
db = Database() with Database() as db:
db.create_connection() is_live = db.fetchone("""
is_live = db.fetchone("""SELECT isLive FROM streams WHERE user_id = ?""", (user_id,)) SELECT isLive
db.close_connection() FROM streams
WHERE user_id = ?
""", (user_id,))
return is_live return is_live
def followed_live_streams(user_id: int) -> list[dict]: def followed_live_streams(user_id: int) -> list[dict]:
""" """
Searches for streamers who the user followed which are currently live Searches for streamers who the user followed which are currently live
""" """
db = Database() with Database() as db:
db.create_connection()
live_streams = db.fetchall(""" live_streams = db.fetchall("""
SELECT user_id, stream_id, title, num_viewers SELECT user_id, stream_id, title, num_viewers
FROM streams FROM streams
@@ -27,46 +27,40 @@ def followed_live_streams(user_id: int) -> list[dict]:
AND stream_id = (SELECT MAX(stream_id) FROM streams WHERE user_id = streams.user_id) AND stream_id = (SELECT MAX(stream_id) FROM streams WHERE user_id = streams.user_id)
AND isLive = 1; AND isLive = 1;
""", (user_id,)) """, (user_id,))
db.close_connection()
return live_streams return live_streams
def followed_streamers(user_id: int) -> list[dict]: def followed_streamers(user_id: int) -> list[dict]:
""" """
Returns a list of streamers who the user follows Returns a list of streamers who the user follows
""" """
db = Database() with Database() as db:
db.create_connection()
followed_streamers = db.fetchall(""" followed_streamers = db.fetchall("""
SELECT user_id, username SELECT user_id, username
FROM users FROM users
WHERE user_id IN (SELECT followed_id FROM follows WHERE user_id = ?); WHERE user_id IN (SELECT followed_id FROM follows WHERE user_id = ?);
""", (user_id,)) """, (user_id,))
db.close_connection()
return followed_streamers return followed_streamers
def streamer_most_recent_stream(user_id: int) -> dict: def streamer_most_recent_stream(user_id: int) -> dict:
""" """
Returns data of the most recent stream by a streamer Returns data of the most recent stream by a streamer
""" """
db = Database() with Database() as db:
db.create_connection() most_recent_stream = db.fetchone("""
most_recent_stream = db.fetchone("""SELECT * FROM streams WHERE SELECT * FROM streams
user_id = ? AND WHERE user_id = ?
stream_id = (SELECT MAX(stream_id) FROM AND stream_id = (SELECT MAX(stream_id) FROM streams WHERE user_id = ?)
streams WHERE user_id = ?)""", (user_id, user_id)) """, (user_id, user_id))
db.close_connection()
return most_recent_stream return most_recent_stream
def user_stream(user_id: int, stream_id: int) -> dict: def user_stream(user_id: int, stream_id: int) -> dict:
""" """
Returns data of a streamers selected stream Returns data of a streamers selected stream
""" """
db = Database() with Database() as db:
db.create_connection() stream = db.fetchone("""
stream = db.fetchone("SELECT * FROM streams WHERE user_id = ? AND stream_id = ?", (user_id, stream_id)) SELECT * FROM streams
db.close_connection() WHERE user_id = ?
AND stream_id = ?
""", (user_id, stream_id))
return stream return stream

View File

@@ -13,138 +13,101 @@ def get_user_id(username: str) -> int:
""" """
Returns user_id associated with given username Returns user_id associated with given username
""" """
db = Database() with Database() as db:
db.create_connection() data = db.fetchone("""
SELECT user_id
try: FROM users
data = db.fetchone( WHERE username = ?
"SELECT user_id FROM users WHERE username = ?", """, (username,))
(username,)
)
return data['user_id'] if data else None return data['user_id'] if data else None
except Exception as e:
print(f"Error: {e}")
return None
finally:
db.close_connection()
def get_username(user_id: str) -> Optional[str]: def get_username(user_id: str) -> Optional[str]:
""" """
Returns username associated with given user_id Returns username associated with given user_id
""" """
db = Database() with Database() as db:
db.create_connection() data = db.fetchone("""
SELECT username
try: FROM user
data = db.fetchone( WHERE user_id = ?
"SELECT username FROM user WHERE user_id = ?", """, (user_id,))
(user_id,) return data['username'] if data else None
)
return data[0] if data else None
except Exception as e:
print(f"Error: {e}")
return None
finally:
db.close_connection()
def is_user_partner(user_id: int) -> bool: def is_user_partner(user_id: int) -> bool:
""" """
Returns True if user is a partner, else False Returns True if user is a partner, else False
""" """
db = Database() with Database() as db:
db.create_connection() data = db.fetchone("""
SELECT is_partnered
try: FROM users
data = db.fetchone( WHERE user_id = ?
"SELECT is_partnered FROM users WHERE user_id = ?", """, (user_id,))
(user_id,)
)
return bool(data) return bool(data)
except Exception as e:
print(f"Error: {e}")
return False
finally:
db.close_connection()
def is_subscribed(user_id: int, streamer_id: int) -> bool: def is_subscribed(user_id: int, streamer_id: int) -> bool:
""" """
Returns True if user is subscribed to a streamer, else False Returns True if user is subscribed to a streamer, else False
""" """
db = Database() with Database() as db:
db.create_connection() result = db.fetchone("""
SELECT 1
try: FROM subscribes
result = db.fetchone( WHERE user_id = ?
"SELECT 1 FROM subscribes WHERE user_id = ? AND streamer_id = ? AND expires > ?", AND streamer_id = ?
(user_id, streamer_id, datetime.now()) AND expires > ?
) """, (user_id, streamer_id, datetime.now()))
return bool(result) return bool(result)
except Exception as e:
print(f"Error: {e}")
return False
finally:
db.close_connection()
def is_following(user_id: int, followed_id: int) -> bool: def is_following(user_id: int, followed_id: int) -> bool:
db = Database() """
db.create_connection() Returns where a user is following another
"""
try: with Database() as db:
result = db.fetchone( result = db.fetchone("""
"SELECT 1 FROM follows WHERE user_id = ? AND followed_id = ?", SELECT 1
(user_id, followed_id) FROM follows
) WHERE user_id = ?
AND followed_id = ?
""", (user_id, followed_id))
return bool(result) return bool(result)
except Exception as e:
print(f"Error: {e}")
return False
finally:
db.close_connection()
def subscription_expiration(user_id: int, subscribed_id: int) -> int: def subscription_expiration(user_id: int, subscribed_id: int) -> int:
""" """
Returns the amount of time left until user subscription to a streamer ends Returns the amount of time left until user subscription to a streamer ends
""" """
db = Database() with Database() as db:
db.create_connection() data = db.fetchone("""
remaining_time = 0 SELECT expires
try: FROM subscriptions
data = db.fetchone( WHERE user_id = ?
"SELECT expires from subscriptions WHERE user_id = ? AND subscribed_id = ? AND expires > since", (user_id,subscribed_id)) AND subscribed_id = ?
AND expires > ?
""", (user_id, subscribed_id, datetime.now()))
if data: if data:
expiration_date = data["expires"] expiration_date = data["expires"]
remaining_time = (expiration_date - datetime.now()).seconds remaining_time = (expiration_date - datetime.now()).seconds
except Exception as e:
print(f"Error: {e}")
finally:
db.close_connection()
return remaining_time return remaining_time
return 0
def verify_token(token: str): def verify_token(token: str):
""" """
Given a token verifies token and decodes the token into an email Given a token verifies token and decodes the token into an email
""" """
try:
email = serializer.loads(token, salt='1', max_age=3600) email = serializer.loads(token, salt='1', max_age=3600)
return email return email if email else False
except Exception as e:
print(f"Error: {e}")
return False
def reset_password(new_password: str, email: str): def reset_password(new_password: str, email: str):
""" """
Given email and new password reset the password for a given user Given email and new password reset the password for a given user
""" """
db = Database() with Database() as db:
db.create_connection() db.execute("""
UPDATE users
SET password = ?
WHERE email = ?
""", (generate_password_hash(new_password), email))
try:
db.execute("UPDATE users SET password = ? WHERE email = ?", (generate_password_hash(new_password), email))
return True return True
except Exception as e:
print(f"Error: {e}")
return False
finally:
db.close_connection()

View File

@@ -4,27 +4,25 @@ def categories():
""" """
Returns all possible streaming categories Returns all possible streaming categories
""" """
db = Database() with Database() as db:
db.create_connection()
all_categories = db.fetchall("SELECT * FROM categories") all_categories = db.fetchall("SELECT * FROM categories")
return all_categories return all_categories
def tags(): def tags():
""" """
Returns all possible streaming tags Returns all possible streaming tags
""" """
db = Database() with Database() as db:
db.create_connection()
all_tags = db.fetchall("SELECT * FROM tags") all_tags = db.fetchall("SELECT * FROM tags")
return all_tags return all_tags
def most_popular_category(): def most_popular_category():
""" """
Returns the most popular category based on live stream viewers Returns the most popular category based on live stream viewers
""" """
db = Database() with Database() as db:
db.create_connection()
category = db.fetchone(""" category = db.fetchone("""
SELECT categories.category_id, categories.category_name SELECT categories.category_id, categories.category_name
FROM streams FROM streams