Fix/pylint cleanup (#8)
Some checks are pending
CI / build (3.10) (push) Waiting to run
CI / build (3.8) (push) Waiting to run
CI / build (3.9) (push) Waiting to run

* Fix pylint warnings across all 24 Python files in web_server

- Add module, class, and function docstrings (C0114, C0115, C0116)
- Fix import ordering: stdlib before third-party before local (C0411)
- Replace wildcard imports with explicit named imports (W0401)
- Remove trailing whitespace and add missing final newlines (C0303, C0304)
- Replace dict() with dict literals (R1735)
- Remove unused imports and variables (W0611, W0612)
- Narrow broad Exception catches to specific exceptions (W0718)
- Replace f-string logging with lazy % formatting (W1203)
- Fix variable naming: UPPER_CASE for constants, snake_case for locals (C0103)
- Add pylint disable comments for necessary global statements (W0603)
- Fix no-else-return, simplifiable-if-expression, singleton-comparison
- Fix bad indentation in stripe.py (W0311)
- Add encoding="utf-8" to open() calls (W1514)
- Add check=True to subprocess.run() calls (W1510)
- Register Celery task modules via conf.include

* Update `package-lock.json` add peer dependencies
This commit is contained in:
Christopher Ahern
2026-02-07 20:57:28 +00:00
committed by GitHub
parent fed1a2f288
commit 2758be8680
25 changed files with 680 additions and 419 deletions

View File

@@ -1,8 +1,10 @@
"""Admin utility functions for user management."""
from database.database import Database
def check_if_admin(username):
"""
Returns whether user is admin
Returns whether user is admin
"""
with Database() as db:
is_admin = db.fetchone("""
@@ -10,7 +12,7 @@ def check_if_admin(username):
FROM users
WHERE username = ?;
""", (username,))
return bool(is_admin)
def check_if_user_exists(banned_user):
@@ -19,11 +21,11 @@ def check_if_user_exists(banned_user):
"""
with Database() as db:
user_exists = db.fetchone("""
SELECT user_id
SELECT user_id
FROM users
WHERE username = ?;""",
(banned_user,))
return bool(user_exists)
def ban_user(banned_user):
@@ -33,6 +35,5 @@ def ban_user(banned_user):
with Database() as db:
db.execute("""
DELETE FROM users
WHERE username = ?;""",
(banned_user)
)
WHERE username = ?;""",
(banned_user,))

View File

@@ -1,13 +1,17 @@
"""Token generation and verification for password resets."""
from typing import Optional
from os import getenv
from database.database import Database
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
from typing import Optional
from dotenv import load_dotenv
from os import getenv
from werkzeug.security import generate_password_hash
load_dotenv()
serializer = URLSafeTimedSerializer(getenv("AUTH_SECRET_KEY"))
def generate_token(email, salt_value) -> str:
"""
Creates a token for password reset
@@ -19,7 +23,6 @@ def verify_token(token: str, salt_value) -> Optional[str]:
"""
Given a token, verifies and decodes it into an email
"""
try:
email = serializer.loads(token, salt=salt_value, max_age=3600)
return email
@@ -38,7 +41,7 @@ def reset_password(new_password: str, email: str):
"""
with Database() as db:
db.execute("""
UPDATE users
SET password = ?
UPDATE users
SET password = ?
WHERE email = ?
""", (generate_password_hash(new_password), email))

View File

@@ -1,16 +1,18 @@
"""Email sending utilities for password reset, account confirmation, and newsletters."""
import smtplib
from email.mime.text import MIMEText
from os import getenv
from secrets import token_hex
from dotenv import load_dotenv
from utils.auth import generate_token
from secrets import token_hex
from .user_utils import get_session_info_email
import redis
from database.database import Database
redis_url = "redis://redis:6379/1"
r = redis.from_url(redis_url, decode_responses=True)
REDIS_URL = "redis://redis:6379/1"
r = redis.from_url(REDIS_URL, decode_responses=True)
load_dotenv()
@@ -23,31 +25,31 @@ def send_email(email, func) -> None:
"""
# Setup the sender email details
SMTP_SERVER = "smtp.gmail.com"
SMTP_PORT = 587
SMTP_EMAIL = getenv("EMAIL")
SMTP_PASSWORD = getenv("EMAIL_PASSWORD")
smtp_server = "smtp.gmail.com"
smtp_port = 587
smtp_email = getenv("EMAIL")
smtp_password = getenv("EMAIL_PASSWORD")
# Setup up the receiver details
body, subject = func()
msg = MIMEText(body, "html")
msg["Subject"] = subject
msg["From"] = SMTP_EMAIL
msg["From"] = smtp_email
msg["To"] = email
# Send the email using smtplib
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as smtp:
with smtplib.SMTP(smtp_server, smtp_port) as smtp:
try:
smtp.starttls() # TLS handshake to start the connection
smtp.login(SMTP_EMAIL, SMTP_PASSWORD)
smtp.login(smtp_email, smtp_password)
smtp.ehlo()
smtp.send_message(msg)
except TimeoutError:
print("Server timed out", flush=True)
except Exception as e:
except smtplib.SMTPException as e:
print("Error: ", e, flush=True)
def forgot_password_body(email) -> str:
@@ -57,7 +59,7 @@ def forgot_password_body(email) -> str:
salt = token_hex(32)
token = generate_token(email, salt)
token += "R3sET"
token += "R3sET"
r.setex(token, 3600, salt)
username = (get_session_info_email(email))["username"]
@@ -77,7 +79,7 @@ def forgot_password_body(email) -> str:
</head>
<body>
<div class="container">
<h1>Gander</h1>
<h1>Gander</h1>
<h2>Password Reset Request</h2>
<p>Click the button below to reset your password for your account {username}. This link is valid for 1 hour.</p>
<a href="{full_url}" class="btn">Reset Password</a>
@@ -116,7 +118,7 @@ def confirm_account_creation_body(email) -> str:
</head>
<body>
<div class="container">
<h1>Gander</h1>
<h1>Gander</h1>
<h2>Confirm Account Creation</h2>
<p>Click the button below to create your account. This link is valid for 1 hour.</p>
<a href="{full_url}" class="btn">Create Account</a>
@@ -154,7 +156,7 @@ def newsletter_conf(email):
</head>
<body>
<div class="container">
<h1>Gander</h1>
<h1>Gander</h1>
<h2>Welcome to the Official Gander Newsletter!</h2>
<p>If you are receiving this email, it means that you have been officially added to the Monthly Gander newsletter.</p>
<p>In this newsletter, you will receive updates about: your favourite streamers; important Gander updates; and more!</p>
@@ -169,7 +171,7 @@ def newsletter_conf(email):
user_exists = db.fetchone("""
SELECT *
FROM newsletter
WHERE email = ?;""",
WHERE email = ?;""",
(email,))
print(user_exists, flush=True)
@@ -190,7 +192,7 @@ def add_to_newsletter(email):
INSERT INTO newsletter (email)
VALUES (?);
""", (email,))
def remove_from_newsletter(email):
"""
Remove a person from the newsletter database
@@ -202,7 +204,7 @@ def remove_from_newsletter(email):
DELETE FROM newsletter
WHERE email = ?;
""", (email,))
def email_exists(email):
"""
Returns whether email exists within database
@@ -210,6 +212,6 @@ def email_exists(email):
with Database() as db:
data = db.fetchone("""
SELECT * FROM users
WHERE email = ?
WHERE email = ?
""", (email,))
return bool(data)
return bool(data)

View File

@@ -1,8 +1,11 @@
"""Description: This file contains the PathManager class which is responsible for managing the paths of the stream data."""
"""File system path management for user streams, VODs, and profile pictures."""
import os
class PathManager():
"""Manages paths for user stream data, VODs, and profile pictures."""
def __init__(self) -> None:
self.root_path = "user_data"
self.vod_directory_name = "vods"
@@ -26,7 +29,7 @@ class PathManager():
self._create_directory(os.path.join(self.root_path, username))
vods_path = self.get_vods_path(username)
stream_path = self.get_stream_path(username)
self._create_directory(vods_path)
self._create_directory(stream_path)
@@ -39,25 +42,33 @@ class PathManager():
os.rmdir(user_path)
def get_user_path(self, username):
"""Returns the base path for a user's data directory."""
return os.path.join(self.root_path, username)
def get_vods_path(self, username):
"""Returns the path to a user's VODs directory."""
return os.path.join(self.root_path, username, self.vod_directory_name)
def get_stream_path(self, username):
"""Returns the path to a user's stream directory."""
return os.path.join(self.root_path, username, self.stream_directory_name)
def get_stream_file_path(self, username):
"""Returns the path to a user's stream index file."""
return os.path.join(self.get_stream_path(username), "index.m3u8")
def get_current_stream_thumbnail_file_path(self, username):
"""Returns the path to a user's current stream thumbnail."""
return os.path.join(self.get_stream_path(username), "index.png")
def get_vod_file_path(self, username, vod_id):
"""Returns the path to a specific VOD file."""
return os.path.join(self.get_vods_path(username), f"{vod_id}.mp4")
def get_vod_thumbnail_file_path(self, username, vod_id):
"""Returns the path to a specific VOD thumbnail."""
return os.path.join(self.get_vods_path(username), f"{vod_id}.png")
def get_profile_picture_file_path(self, username):
return os.path.join(self.root_path, username, self.profile_picture_name)
"""Returns the path to a user's profile picture."""
return os.path.join(self.root_path, username, self.profile_picture_name)

View File

@@ -1,32 +1,42 @@
from database.database import Database
"""Personalized stream and category recommendations based on user preferences."""
from typing import Optional, List
from database.database import Database
def get_user_preferred_category(user_id: int) -> Optional[int]:
"""
Queries user_preferences database to find users favourite streaming category and returns the category
Queries user_preferences database to find users favourite
streaming category and returns the category
"""
with Database() as db:
category = db.fetchone("""
SELECT category_id
FROM user_preferences
WHERE user_id = ?
ORDER BY favourability DESC
SELECT category_id
FROM user_preferences
WHERE user_id = ?
ORDER BY favourability DESC
LIMIT 1
""", (user_id,))
return category["category_id"] if category else None
def get_followed_categories_recommendations(user_id: int, no_streams: int = 4) -> Optional[List[dict]]:
def get_followed_categories_recommendations(
user_id: int, no_streams: int = 4
) -> Optional[List[dict]]:
"""
Returns top streams given a user's category following
"""
with Database() as db:
streams = db.fetchall("""
SELECT u.user_id, title, u.username, num_viewers, category_name
FROM streams
SELECT u.user_id, title, u.username, num_viewers,
category_name
FROM streams
JOIN users u ON streams.user_id = u.user_id
JOIN categories ON streams.category_id = categories.category_id
WHERE categories.category_id IN (SELECT category_id FROM followed_categories WHERE user_id = ?)
JOIN categories
ON streams.category_id = categories.category_id
WHERE categories.category_id IN
(SELECT category_id
FROM followed_categories WHERE user_id = ?)
ORDER BY num_viewers DESC
LIMIT ?;
""", (user_id, no_streams))
@@ -40,25 +50,29 @@ def get_followed_your_categories(user_id: int) -> Optional[List[dict]]:
categories = db.fetchall("""
SELECT categories.category_name
FROM categories
JOIN followed_categories
JOIN followed_categories
ON categories.category_id = followed_categories.category_id
WHERE followed_categories.user_id = ?;
""", (user_id,))
return categories
def get_streams_based_on_category(category_id: int, no_streams: int = 4, offset: int = 0) -> Optional[List[dict]]:
def get_streams_based_on_category(
category_id: int, no_streams: int = 4, offset: int = 0
) -> Optional[List[dict]]:
"""
Queries stream database to get top most viewed streams based on given category
Queries stream database to get top most viewed streams
based on given category
"""
with Database() as db:
streams = db.fetchall("""
SELECT u.user_id, title, username, num_viewers, c.category_name
SELECT u.user_id, title, username, num_viewers,
c.category_name
FROM streams s
JOIN users u ON s.user_id = u.user_id
JOIN categories c ON s.category_id = c.category_id
WHERE c.category_id = ?
ORDER BY num_viewers DESC
WHERE c.category_id = ?
ORDER BY num_viewers DESC
LIMIT ? OFFSET ?
""", (category_id, no_streams, offset))
return streams
@@ -70,43 +84,63 @@ def get_highest_view_streams(no_streams: int = 4) -> Optional[List[dict]]:
"""
with Database() as db:
data = db.fetchall("""
SELECT u.user_id, username, title, num_viewers, category_name
FROM streams
SELECT u.user_id, username, title, num_viewers,
category_name
FROM streams
JOIN users u ON streams.user_id = u.user_id
JOIN categories ON streams.category_id = categories.category_id
ORDER BY num_viewers DESC
JOIN categories
ON streams.category_id = categories.category_id
ORDER BY num_viewers DESC
LIMIT ?;
""", (no_streams,))
return data
def get_highest_view_categories(no_categories: int = 4, offset: int = 0) -> Optional[List[dict]]:
def get_highest_view_categories(
no_categories: int = 4, offset: int = 0
) -> Optional[List[dict]]:
"""
Returns a list of top most popular categories given offset
"""
with Database() as db:
categories = db.fetchall("""
SELECT categories.category_id, categories.category_name, COALESCE(SUM(streams.num_viewers), 0) AS num_viewers
SELECT categories.category_id,
categories.category_name,
COALESCE(SUM(streams.num_viewers), 0)
AS num_viewers
FROM categories
LEFT JOIN streams ON streams.category_id = categories.category_id
GROUP BY categories.category_id, categories.category_name
LEFT JOIN streams
ON streams.category_id = categories.category_id
GROUP BY categories.category_id,
categories.category_name
ORDER BY num_viewers DESC
LIMIT ? OFFSET ?;
""", (no_categories, offset))
return categories
def get_user_category_recommendations(user_id = 1, no_categories: int = 4) -> Optional[List[dict]]:
def get_user_category_recommendations(
user_id=1, no_categories: int = 4
) -> Optional[List[dict]]:
"""
Queries user_preferences database to find users top favourite streaming category and returns the category
Queries user_preferences database to find users top favourite
streaming category and returns the category
"""
with Database() as db:
categories = db.fetchall("""
SELECT categories.category_id, categories.category_name, COALESCE(SUM(streams.num_viewers), 0) AS num_viewers
FROM categories
JOIN user_preferences ON categories.category_id = user_preferences.category_id
LEFT JOIN streams ON categories.category_id = streams.category_id
WHERE user_preferences.user_id = ?
GROUP BY categories.category_id, categories.category_name
ORDER BY user_preferences.favourability DESC
SELECT categories.category_id,
categories.category_name,
COALESCE(SUM(streams.num_viewers), 0)
AS num_viewers
FROM categories
JOIN user_preferences
ON categories.category_id
= user_preferences.category_id
LEFT JOIN streams
ON categories.category_id
= streams.category_id
WHERE user_preferences.user_id = ?
GROUP BY categories.category_id,
categories.category_name
ORDER BY user_preferences.favourability DESC
LIMIT ?
""", (user_id, no_categories))
return categories
return categories

View File

@@ -1,9 +1,11 @@
from database.database import Database
from typing import Optional
import os, subprocess
"""Stream data retrieval and management utilities."""
import os
import subprocess
from typing import Optional, List
from time import sleep
from utils.path_manager import PathManager
from database.database import Database
def get_streamer_live_status(user_id: int):
"""
@@ -11,8 +13,8 @@ def get_streamer_live_status(user_id: int):
"""
with Database() as db:
is_live = db.fetchone("""
SELECT is_live
FROM users
SELECT is_live
FROM users
WHERE user_id = ?;
""", (user_id,))
@@ -21,17 +23,19 @@ def get_streamer_live_status(user_id: int):
def get_followed_live_streams(user_id: int) -> Optional[List[dict]]:
"""
Searches for streamers who the user followed which are currently live
Returns a list of live streams with the streamer's user id, stream title, and number of viewers
Returns a list of live streams with the streamer's user id,
stream title, and number of viewers
"""
with Database() as db:
live_streams = db.fetchall("""
SELECT users.user_id, streams.title, streams.num_viewers, users.username
FROM streams JOIN users
ON streams.user_id = users.user_id
WHERE users.user_id IN
(SELECT followed_id FROM follows WHERE user_id = ?)
AND users.is_live = 1;
""", (user_id,))
SELECT users.user_id, streams.title,
streams.num_viewers, users.username
FROM streams JOIN users
ON streams.user_id = users.user_id
WHERE users.user_id IN
(SELECT followed_id FROM follows WHERE user_id = ?)
AND users.is_live = 1;
""", (user_id,))
return live_streams
def get_current_stream_data(user_id: int) -> Optional[dict]:
@@ -40,7 +44,8 @@ def get_current_stream_data(user_id: int) -> Optional[dict]:
"""
with Database() as db:
most_recent_stream = db.fetchone("""
SELECT s.user_id, u.username, s.title, s.start_time, s.num_viewers, c.category_name, c.category_id
SELECT s.user_id, u.username, s.title, s.start_time,
s.num_viewers, c.category_name, c.category_id
FROM streams AS s
JOIN categories AS c ON s.category_id = c.category_id
JOIN users AS u ON s.user_id = u.user_id
@@ -51,78 +56,92 @@ def get_current_stream_data(user_id: int) -> Optional[dict]:
def end_user_stream(stream_key, user_id, username):
"""
Utility function to end a user's stream
Parameters:
stream_key: The stream key of the user
user_id: The ID of the user
username: The username of the user
Returns:
bool: True if stream was ended successfully, False otherwise
"""
from flask import current_app
from datetime import datetime
from dateutil import parser
from celery_tasks.streaming import combine_ts_stream
from utils.path_manager import PathManager
path_manager = PathManager()
pm = PathManager()
print(f"Ending stream for user {username} (ID: {user_id})", flush=True)
if not stream_key or not user_id or not username:
print("Cannot end stream - missing required information", flush=True)
return False
return False, "Missing required information"
try:
# Open database connection
with Database() as db:
# Get stream info
stream_info = db.fetchone("""SELECT *
FROM streams
WHERE user_id = ?""", (user_id,))
stream_info = db.fetchone(
"""SELECT *
FROM streams
WHERE user_id = ?""", (user_id,))
# If user is not streaming, just return
if not stream_info:
print(f"User {username} (ID: {user_id}) is not streaming", flush=True)
print(
f"User {username} (ID: {user_id}) "
"is not streaming", flush=True)
return True, "User is not streaming"
# Remove stream from database
db.execute("""DELETE FROM streams
WHERE user_id = ?""", (user_id,))
db.execute(
"""DELETE FROM streams
WHERE user_id = ?""", (user_id,))
# Move stream to vod table
stream_length = int(
(datetime.now() - parser.parse(stream_info.get("start_time"))).total_seconds())
db.execute("""INSERT INTO vods (user_id, title, datetime, category_id, length, views)
VALUES (?, ?, ?, ?, ?, ?)""", (user_id,
stream_info.get("title"),
stream_info.get("start_time"),
stream_info.get("category_id"),
stream_length,
0))
(datetime.now() - parser.parse(
stream_info.get("start_time")
)).total_seconds())
db.execute(
"""INSERT INTO vods
(user_id, title, datetime, category_id,
length, views)
VALUES (?, ?, ?, ?, ?, ?)""",
(user_id,
stream_info.get("title"),
stream_info.get("start_time"),
stream_info.get("category_id"),
stream_length,
0))
vod_id = db.get_last_insert_id()
# Set user as not streaming
db.execute("""UPDATE users
SET is_live = 0
WHERE user_id = ?""", (user_id,))
db.execute(
"""UPDATE users
SET is_live = 0
WHERE user_id = ?""", (user_id,))
# Queue task to combine TS files into MP4
combine_ts_stream.delay(
path_manager.get_stream_path(username),
path_manager.get_vods_path(username),
pm.get_stream_path(username),
pm.get_vods_path(username),
vod_id,
path_manager.get_vod_thumbnail_file_path(username, vod_id)
pm.get_vod_thumbnail_file_path(username, vod_id)
)
print(f"Stream ended for user {username} (ID: {user_id})", flush=True)
print(
f"Stream ended for user {username} (ID: {user_id})",
flush=True)
return True, "Stream ended successfully"
except Exception as e:
print(f"Error ending stream for user {username}: {str(e)}", flush=True)
return False, f"Error ending stream: {str(e)}"
except (ValueError, TypeError, KeyError) as exc:
print(
f"Error ending stream for user {username}: {str(exc)}",
flush=True)
return False, f"Error ending stream: {str(exc)}"
def get_category_id(category_name: str) -> Optional[int]:
"""
@@ -130,8 +149,8 @@ def get_category_id(category_name: str) -> Optional[int]:
"""
with Database() as db:
data = db.fetchone("""
SELECT category_id
FROM categories
SELECT category_id
FROM categories
WHERE category_name = ?;
""", (category_name,))
return data['category_id'] if data else None
@@ -142,7 +161,7 @@ def get_custom_thumbnail_status(user_id: int) -> Optional[dict]:
"""
with Database() as db:
custom_thumbnail = db.fetchone("""
SELECT custom_thumbnail
SELECT custom_thumbnail
FROM streams
WHERE user_id = ?;
""", (user_id,))
@@ -153,7 +172,13 @@ def get_vod(vod_id: int) -> dict:
Returns data of a streamers vod
"""
with Database() as db:
vod = db.fetchone("""SELECT vods.*, username, category_name FROM vods JOIN users ON vods.user_id = users.user_id JOIN categories ON vods.category_id = categories.category_id WHERE vod_id = ?;""", (vod_id,))
vod = db.fetchone(
"""SELECT vods.*, username, category_name
FROM vods
JOIN users ON vods.user_id = users.user_id
JOIN categories
ON vods.category_id = categories.category_id
WHERE vod_id = ?;""", (vod_id,))
return vod
def get_latest_vod(user_id: int):
@@ -161,7 +186,14 @@ def get_latest_vod(user_id: int):
Returns data of the most recent stream by a streamer
"""
with Database() as db:
latest_vod = db.fetchone("""SELECT vods.*, username, category_name FROM vods JOIN users ON vods.user_id = users.user_id JOIN categories ON vods.category_id = categories.category_id WHERE vods.user_id = ? ORDER BY vod_id DESC;""", (user_id,))
latest_vod = db.fetchone(
"""SELECT vods.*, username, category_name
FROM vods
JOIN users ON vods.user_id = users.user_id
JOIN categories
ON vods.category_id = categories.category_id
WHERE vods.user_id = ?
ORDER BY vod_id DESC;""", (user_id,))
return latest_vod
def get_user_vods(user_id: int):
@@ -169,10 +201,17 @@ def get_user_vods(user_id: int):
Returns data of all vods by a streamer
"""
with Database() as db:
vods = db.fetchall("""SELECT vods.*, username, category_name FROM vods JOIN users ON vods.user_id = users.user_id JOIN categories ON vods.category_id = categories.category_id WHERE vods.user_id = ? ORDER BY vod_id DESC;""", (user_id,))
vods = db.fetchall(
"""SELECT vods.*, username, category_name
FROM vods
JOIN users ON vods.user_id = users.user_id
JOIN categories
ON vods.category_id = categories.category_id
WHERE vods.user_id = ?
ORDER BY vod_id DESC;""", (user_id,))
return vods
def generate_thumbnail(stream_file: str, thumbnail_file: str, second_capture) -> None:
def generate_thumbnail(stream_file: str, thumbnail_file: str, second_capture=0) -> None:
"""
Generates the thumbnail of a stream
"""
@@ -192,10 +231,16 @@ def generate_thumbnail(stream_file: str, thumbnail_file: str, second_capture) ->
]
try:
subprocess.run(thumbnail_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True)
subprocess.run(
thumbnail_command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
check=True)
print(f"Thumbnail generated for {stream_file}")
except subprocess.CalledProcessError as e:
print(f"No information available for {stream_file}, aborting thumbnail generation")
except subprocess.CalledProcessError:
print(
f"No information available for {stream_file}, "
"aborting thumbnail generation")
def remove_hls_files(stream_path: str) -> None:
"""
@@ -221,11 +266,13 @@ def get_video_duration(video_path: str) -> int:
]
try:
video_length = subprocess.check_output(video_length_command).decode("utf-8")
video_length = subprocess.check_output(
video_length_command
).decode("utf-8")
print(f"Video length: {video_length}")
return int(float(video_length))
except subprocess.CalledProcessError as e:
print(f"Error getting video length: {e}")
except subprocess.CalledProcessError:
print("Error getting video length")
return 0
def get_stream_tags(user_id: int) -> Optional[List[str]]:
@@ -234,10 +281,10 @@ def get_stream_tags(user_id: int) -> Optional[List[str]]:
"""
with Database() as db:
tags = db.fetchall("""
SELECT tag_name
SELECT tag_name
FROM tags
JOIN stream_tags ON tags.tag_id = stream_tags.tag_id
WHERE user_id = ?;
WHERE user_id = ?;
""", (user_id,))
return tags
@@ -247,9 +294,9 @@ def get_vod_tags(vod_id: int):
"""
with Database() as db:
tags = db.fetchall("""
SELECT tag_name
SELECT tag_name
FROM tags
JOIN vod_tags ON tags.tag_id = vod_tags.tag_id
WHERE vod_id = ?;
WHERE vod_id = ?;
""", (vod_id,))
return tags
return tags

View File

@@ -1,16 +1,20 @@
from database.database import Database
"""User profile management, following, and subscription utilities."""
from typing import Optional, List
from datetime import datetime, timedelta
from database.database import Database
from dateutil import parser
def get_user_id(username: str) -> Optional[int]:
"""
Returns user_id associated with given username
"""
with Database() as db:
data = db.fetchone("""
SELECT user_id
FROM users
SELECT user_id
FROM users
WHERE username = ?
""", (username,))
return data['user_id'] if data else None
@@ -21,7 +25,7 @@ def get_username(user_id: str) -> Optional[str]:
"""
with Database() as db:
data = db.fetchone("""
SELECT username
SELECT username
FROM users
WHERE user_id = ?
""", (user_id,))
@@ -40,15 +44,16 @@ def update_bio(user_id: int, bio: str):
def has_password(email: str):
"""
Returns if account associated with this email has password, i.e is account from Google OAuth
Returns if account associated with this email has password,
i.e is account from Google OAuth
"""
with Database() as db:
data = db.fetchone("""
SELECT password
FROM users
WHERE email = ?
WHERE email = ?
""", (email,))
return False if data["password"] == None else True
return data["password"] is not None
def get_session_info_email(email: str) -> dict:
"""
@@ -61,15 +66,15 @@ def get_session_info_email(email: str) -> dict:
WHERE email = ?
""", (email,))
return session_info
def is_user_partner(user_id: int) -> bool:
"""
Returns True if user is a partner, else False
"""
with Database() as db:
data = db.fetchone("""
SELECT is_partnered
FROM users
SELECT is_partnered
FROM users
WHERE user_id = ?
""", (user_id,))
return bool(data)
@@ -79,12 +84,12 @@ def is_subscribed(user_id: int, subscribed_to_id: int) -> bool:
with Database() as db:
return bool(db.fetchone(
"""
SELECT 1
FROM subscribes
WHERE user_id = ?
AND subscribed_id = ?
SELECT 1
FROM subscribes
WHERE user_id = ?
AND subscribed_id = ?
AND expires > ?;
""",
""",
(user_id, subscribed_to_id, datetime.now())
))
@@ -94,9 +99,9 @@ def is_following(user_id: int, followed_id: int) -> bool:
"""
with Database() as db:
result = db.fetchone("""
SELECT 1
FROM follows
WHERE user_id = ?
SELECT 1
FROM follows
WHERE user_id = ?
AND followed_id = ?
""", (user_id, followed_id))
return bool(result)
@@ -107,7 +112,7 @@ def follow(user_id: int, followed_id: int):
"""
if is_following(user_id, followed_id):
return {"success": False, "error": "Already following user"}, 400
with Database() as db:
db.execute("""
INSERT INTO follows (user_id, followed_id)
@@ -135,9 +140,9 @@ def is_following_category(user_id: int, category_id: str):
"""
with Database() as db:
result = db.fetchone("""
SELECT 1
FROM followed_categories
WHERE user_id = ?
SELECT 1
FROM followed_categories
WHERE user_id = ?
AND category_id = ?
""", (user_id, category_id))
return bool(result)
@@ -148,7 +153,7 @@ def follow_category(user_id: int, category_id: str):
"""
if is_following_category(user_id, category_id):
return {"success": False, "error": "Already following category"}, 400
with Database() as db:
db.execute("""
INSERT INTO followed_categories (user_id, category_id)
@@ -163,7 +168,7 @@ def unfollow_category(user_id: int, category_id: str):
"""
if not is_following_category(user_id, category_id):
return {"success": False, "error": "Not following category"}, 400
with Database() as db:
db.execute("""
DELETE FROM followed_categories
@@ -179,8 +184,8 @@ def subscribe(user_id: int, streamer_id: int):
# If user is already subscribed then extend the expiration date else create a new entry
with Database() as db:
existing = db.fetchone("""
SELECT expires
FROM subscribes
SELECT expires
FROM subscribes
WHERE user_id = ? AND subscribed_id = ?
""", (user_id, streamer_id))
if existing:
@@ -188,7 +193,7 @@ def subscribe(user_id: int, streamer_id: int):
UPDATE subscribes SET expires = expires + ?
WHERE user_id = ? AND subscribed_id = ?
""", (timedelta(days=30), user_id, streamer_id))
else:
else:
db.execute("""
INSERT INTO subscribes
(user_id, subscribed_id, since, expires)
@@ -212,10 +217,10 @@ def subscription_expiration(user_id: int, subscribed_id: int) -> int:
"""
with Database() as db:
data = db.fetchone("""
SELECT expires
FROM subscribes
WHERE user_id = ?
AND subscribed_id = ?
SELECT expires
FROM subscribes
WHERE user_id = ?
AND subscribed_id = ?
AND expires > ?
""", (user_id, subscribed_id, datetime.now()))
@@ -227,13 +232,14 @@ def subscription_expiration(user_id: int, subscribed_id: int) -> int:
return 0
def get_email(user_id: int) -> Optional[str]:
"""Returns the email address for a given user_id."""
with Database() as db:
email = db.fetchone("""
SELECT email
FROM users
WHERE user_id = ?
""", (user_id,))
return email["email"] if email else None
def get_followed_streamers(user_id: int) -> Optional[List[dict]]:
@@ -289,4 +295,4 @@ def get_user(user_id: int) -> Optional[dict]:
SELECT user_id, username, bio, num_followers, is_partnered, is_live FROM users
WHERE user_id = ?;
""", (user_id,))
return data
return data

View File

@@ -1,14 +1,18 @@
from database.database import Database
"""Input sanitization and validation utilities."""
from typing import Optional, List
from re import match
from database.database import Database
def get_all_categories() -> Optional[List[dict]]:
"""
Returns all possible streaming categories
"""
with Database() as db:
all_categories = db.fetchall("SELECT * FROM categories")
return all_categories
def get_all_tags() -> Optional[List[dict]]:
@@ -17,7 +21,7 @@ def get_all_tags() -> Optional[List[dict]]:
"""
with Database() as db:
all_tags = db.fetchall("SELECT * FROM tags")
return all_tags
def get_most_popular_category() -> Optional[List[dict]]:
@@ -34,7 +38,7 @@ def get_most_popular_category() -> Optional[List[dict]]:
ORDER BY SUM(streams.num_viewers) DESC
LIMIT 1;
""")
return category
def get_category_id(category_name: str):
@@ -53,7 +57,7 @@ def get_category_id(category_name: str):
def sanitize(user_input: str, input_type="default") -> str:
"""
Sanitizes user input based on the specified input type.
`input_type`: The type of input to sanitize (e.g., 'username', 'email', 'password').
"""
# Strip leading and trailing whitespace
@@ -84,10 +88,10 @@ def sanitize(user_input: str, input_type="default") -> str:
}
# Get the validation rules for the specified type
r = rules.get(input_type)
if not r or \
not (r["min_length"] <= len(sanitised_input) <= r["max_length"]) or \
not match(r["pattern"], sanitised_input):
rule = rules.get(input_type)
if (not rule
or not (rule["min_length"] <= len(sanitised_input) <= rule["max_length"])
or not match(rule["pattern"], sanitised_input)):
raise ValueError("Unaccepted character or length in input")
return sanitised_input
return sanitised_input