Fix/pylint cleanup (#8)
Some checks are pending
CI / build (3.10) (push) Waiting to run
CI / build (3.8) (push) Waiting to run
CI / build (3.9) (push) Waiting to run

* Fix pylint warnings across all 24 Python files in web_server

- Add module, class, and function docstrings (C0114, C0115, C0116)
- Fix import ordering: stdlib before third-party before local (C0411)
- Replace wildcard imports with explicit named imports (W0401)
- Remove trailing whitespace and add missing final newlines (C0303, C0304)
- Replace dict() with dict literals (R1735)
- Remove unused imports and variables (W0611, W0612)
- Narrow broad Exception catches to specific exceptions (W0718)
- Replace f-string logging with lazy % formatting (W1203)
- Fix variable naming: UPPER_CASE for constants, snake_case for locals (C0103)
- Add pylint disable comments for necessary global statements (W0603)
- Fix no-else-return, simplifiable-if-expression, singleton-comparison
- Fix bad indentation in stripe.py (W0311)
- Add encoding="utf-8" to open() calls (W1514)
- Add check=True to subprocess.run() calls (W1510)
- Register Celery task modules via conf.include

* Update `package-lock.json` add peer dependencies
This commit is contained in:
Christopher Ahern
2026-02-07 20:57:28 +00:00
committed by GitHub
parent fed1a2f288
commit 2758be8680
25 changed files with 680 additions and 419 deletions

View File

@@ -1,12 +1,12 @@
{
"name": "frontend",
"version": "0.5.0",
"version": "1.0.0",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "frontend",
"version": "0.5.0",
"version": "1.0.0",
"dependencies": {
"@stripe/react-stripe-js": "^3.1.1",
"@stripe/stripe-js": "^5.5.0",
@@ -97,6 +97,7 @@
"integrity": "sha512-i1SLeK+DzNnQ3LL/CswPCa/E5u4lh1k6IAEphON8F+cXt0t9euTshDru0q7/IqMa1PMPz5RnHuHscF8/ZJsStg==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@ampproject/remapping": "^2.2.0",
"@babel/code-frame": "^7.26.0",
@@ -1393,6 +1394,7 @@
"resolved": "https://registry.npmjs.org/@stripe/stripe-js/-/stripe-js-5.5.0.tgz",
"integrity": "sha512-lkfjyAd34aeMpTKKcEVfy8IUyEsjuAT3t9EXr5yZDtdIUncnZpedl/xLV16Dkd4z+fQwixScsCCDxSMNtBOgpQ==",
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12.16"
}
@@ -1482,6 +1484,7 @@
"integrity": "sha512-t4yC+vtgnkYjNSKlFx1jkAhH8LgTo2N/7Qvi83kdEaUtMDiwpbLAktKDaAMlRcJ5eSxZkH74eEGt1ky31d7kfQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@types/prop-types": "*",
"csstype": "^3.0.2"
@@ -1539,6 +1542,7 @@
"integrity": "sha512-gKXG7A5HMyjDIedBi6bUrDcun8GIjnI8qOwVLiY3rx6T/sHP/19XLJOnIq/FgQvWLHja5JN/LSE7eklNBr612g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@typescript-eslint/scope-manager": "8.20.0",
"@typescript-eslint/types": "8.20.0",
@@ -1805,6 +1809,7 @@
"integrity": "sha512-cl669nCJTZBsL97OF4kUQm5g5hC2uihk0NxY3WENAC0TYdILVkAyHymAntgxGkl7K+t0cXIrH5siy5S4XkFycA==",
"dev": true,
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -2018,6 +2023,7 @@
}
],
"license": "MIT",
"peer": true,
"dependencies": {
"caniuse-lite": "^1.0.30001688",
"electron-to-chromium": "^1.5.73",
@@ -2051,9 +2057,9 @@
}
},
"node_modules/caniuse-lite": {
"version": "1.0.30001695",
"resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001695.tgz",
"integrity": "sha512-vHyLade6wTgI2u1ec3WQBxv+2BrTERV28UXQu9LO6lZ9pYeMk34vjXFLOxo1A4UBA8XTL4njRQZdno/yYaSmWw==",
"version": "1.0.30001769",
"resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001769.tgz",
"integrity": "sha512-BCfFL1sHijQlBGWBMuJyhZUhzo7wer5sVj9hqekB/7xn0Ypy+pER/edCYQm4exbXj4WiySGp40P8UuTh6w1srg==",
"dev": true,
"funding": [
{
@@ -2386,6 +2392,7 @@
"integrity": "sha512-+waTfRWQlSbpt3KWE+CjrPPYnbq9kfZIYUqapc0uBXyjTp8aYXZDsUH16m39Ryq3NjAVP4tjuF7KaukeqoCoaA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@eslint-community/eslint-utils": "^4.2.0",
"@eslint-community/regexpp": "^4.12.1",
@@ -3534,6 +3541,7 @@
}
],
"license": "MIT",
"peer": true,
"dependencies": {
"nanoid": "^3.3.8",
"picocolors": "^1.1.1",
@@ -3723,6 +3731,7 @@
"resolved": "https://registry.npmjs.org/react/-/react-18.3.1.tgz",
"integrity": "sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"loose-envify": "^1.1.0"
},
@@ -3741,6 +3750,7 @@
"resolved": "https://registry.npmjs.org/react-dom/-/react-dom-18.3.1.tgz",
"integrity": "sha512-5m4nQKp+rZRb09LNH59GM4BxTh9251/ylbKIbpe7TpGxfJ+9kv6BLkLBXIjjspbgbnIBNqlI23tRnTWT0snUIw==",
"license": "MIT",
"peer": true,
"dependencies": {
"loose-envify": "^1.1.0",
"scheduler": "^0.23.2"
@@ -4233,6 +4243,7 @@
"resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-3.4.17.tgz",
"integrity": "sha512-w33E2aCvSDP0tW9RZuNXadXlkHXqFzSkQew/aIa2i/Sj8fThxwovwlXHSPXTbAHwEIhBFXAedUhP2tueAKP8Og==",
"license": "MIT",
"peer": true,
"dependencies": {
"@alloc/quick-lru": "^5.2.0",
"arg": "^5.0.2",
@@ -4351,6 +4362,7 @@
"integrity": "sha512-hjcS1mhfuyi4WW8IWtjP7brDrG2cuDZukyrYrSauoXGNgx0S7zceP07adYkJycEr56BOUTNPzbInooiN3fn1qw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -4443,6 +4455,7 @@
"resolved": "https://registry.npmjs.org/video.js/-/video.js-8.21.0.tgz",
"integrity": "sha512-zcwerRb257QAuWfi8NH9yEX7vrGKFthjfcONmOQ4lxFRpDAbAi+u5LAjCjMWqhJda6zEmxkgdDpOMW3Y21QpXA==",
"license": "Apache-2.0",
"peer": true,
"dependencies": {
"@babel/runtime": "^7.12.5",
"@videojs/http-streaming": "^3.16.2",
@@ -4495,6 +4508,7 @@
"integrity": "sha512-4VL9mQPKoHy4+FE0NnRE/kbY51TOfaknxAjt3fJbGJxhIpBZiqVzlZDEesWWsuREXHwNdAoOFZ9MkPEVXczHwg==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.24.2",
"postcss": "^8.4.49",

View File

@@ -1,3 +1,7 @@
"""Flask application factory and blueprint registration."""
from os import getenv
from flask import Flask
from flask_session import Session
from flask_cors import CORS
@@ -13,10 +17,8 @@ from blueprints.oauth import oauth_bp, init_oauth
from blueprints.socket import socketio
from blueprints.search_bar import search_bp
from celery import Celery
from celery_tasks import celery_init_app
from os import getenv
def create_app():
"""
@@ -34,16 +36,15 @@ def create_app():
app.config['GOOGLE_CLIENT_SECRET'] = getenv("GOOGLE_CLIENT_SECRET")
app.config["SESSION_COOKIE_HTTPONLY"] = True
app.config.from_mapping(
CELERY=dict(
broker_url="redis://redis:6379/0",
result_backend="redis://redis:6379/0",
task_ignore_result=True,
),
CELERY={
"broker_url": "redis://redis:6379/0",
"result_backend": "redis://redis:6379/0",
"task_ignore_result": True,
},
)
app.config.from_prefixed_env()
celery = celery_init_app(app)
celery_init_app(app)
#! ↓↓↓ For development purposes only - Allow cross-origin requests for the frontend
CORS(app, supports_credentials=True)

View File

@@ -1,6 +1,8 @@
"""Admin blueprint for user management operations."""
from flask import Blueprint, session
from utils.utils import sanitize
from utils.admin_utils import *
from utils.admin_utils import check_if_admin, check_if_user_exists, ban_user
admin_bp = Blueprint("admin", __name__)
@@ -18,10 +20,10 @@ def admin_delete_user(banned_user):
# Check if the user is an admin
username = session.get("username")
is_admin = check_if_admin(username)
# Check if the user exists
user_exists = check_if_user_exists(banned_user)
# If the user is an admin, try to delete the account
if is_admin and user_exists:
ban_user(banned_user)
ban_user(banned_user)

View File

@@ -1,3 +1,8 @@
"""Authentication blueprint for user signup, login, and logout."""
import logging
from secrets import token_hex
from flask import Blueprint, session, request, jsonify
from werkzeug.security import generate_password_hash, check_password_hash
from flask_cors import cross_origin
@@ -5,18 +10,21 @@ from database.database import Database
from blueprints.middleware import login_required
from utils.user_utils import get_user_id
from utils.utils import sanitize
from secrets import token_hex
from utils.path_manager import PathManager
auth_bp = Blueprint("auth", __name__)
path_manager = PathManager()
logger = logging.getLogger(__name__)
@auth_bp.route("/signup", methods=["POST"])
@cross_origin(supports_credentials=True)
def signup():
"""
Route that allows a user to sign up by providing a `username`, `email` and `password`.
Route that allows a user to sign up by providing a
`username`, `email` and `password`.
"""
# ensure a JSON request is made to contact this route
if not request.is_json:
@@ -30,19 +38,21 @@ def signup():
# Validation - ensure all fields exist, users cannot have an empty field
if not all([username, email, password]):
error_fields = get_error_fields([username, email, password]) #!←← find the error_fields, to highlight them in red to the user on the frontend
error_fields = get_error_fields(
[username, email, password]
)
return jsonify({
"account_created": False,
"error_fields": error_fields,
"message": "Missing required fields"
}), 400
# Sanitize the inputs - helps to prevent SQL injection
try:
username = sanitize(username, "username")
email = sanitize(email, "email")
password = sanitize(password, "password")
except ValueError as e:
except ValueError:
error_fields = get_error_fields([username, email, password])
return jsonify({
"account_created": False,
@@ -81,7 +91,7 @@ def signup():
# Create new user once input is validated
db.execute(
"""INSERT INTO users
"""INSERT INTO users
(username, password, email, stream_key)
VALUES (?, ?, ?, ?)""",
(
@@ -100,16 +110,16 @@ def signup():
"message": "Account created successfully"
}), 201
except Exception as e:
print(f"Error during signup: {e}") # Log the error
except (ValueError, TypeError, KeyError) as exc:
logger.error("Error during signup: %s", exc)
return jsonify({
"account_created": False,
"message": "Server error occurred: " + str(e)
"message": "Server error occurred: " + str(exc)
}), 500
finally:
db.close_connection()
@auth_bp.route("/login", methods=["POST"])
@cross_origin(supports_credentials=True)
@@ -127,24 +137,24 @@ def login():
username = data.get('username')
password = data.get('password')
# Validation - ensure all fields exist, users cannot have an empty field
# Validation - ensure all fields exist, users cannot have an empty field
if not all([username, password]):
return jsonify({
"logged_in": False,
"message": "Missing required fields"
}), 400
# Sanitize the inputs - helps to prevent SQL injection
try:
username = sanitize(username, "username")
password = sanitize(password, "password")
except ValueError as e:
except ValueError:
return jsonify({
"account_created": False,
"error_fields": ["username", "password"],
"message": "Invalid input received"
}), 400
# Create a connection to the database
db = Database()
@@ -169,7 +179,7 @@ def login():
"error_fields": ["username", "password"],
"message": "Invalid username or password"
}), 401
# Add user directories for stream data in case they don't exist
path_manager.create_user(username)
@@ -177,7 +187,10 @@ def login():
session.clear()
session["username"] = username
session["user_id"] = get_user_id(username)
print(f"Logged in as {username}. session: {session.get('username')}. user_id: {session.get('user_id')}", flush=True)
logger.info(
"Logged in as %s. session: %s. user_id: %s",
username, session.get('username'), session.get('user_id')
)
# User has been logged in, let frontend know that
return jsonify({
@@ -186,8 +199,8 @@ def login():
"username": username
}), 200
except Exception as e:
print(f"Error during login: {e}") # Log the error
except (ValueError, TypeError, KeyError) as exc:
logger.error("Error during login: %s", exc)
return jsonify({
"logged_in": False,
"message": "Server error occurred"
@@ -202,31 +215,41 @@ def login():
def logout() -> dict:
"""
Log out and clear the users session.
If the user is currently streaming, end their stream first.
Can only be accessed by a logged in user.
"""
from database.database import Database
from utils.stream_utils import end_user_stream
# Check if user is currently streaming
user_id = session.get("user_id")
username = session.get("username")
with Database() as db:
is_streaming = db.fetchone("""SELECT is_live FROM users WHERE user_id = ?""", (user_id,))
is_streaming = db.fetchone(
"""SELECT is_live FROM users WHERE user_id = ?""",
(user_id,)
)
if is_streaming and is_streaming.get("is_live") == 1:
# Get the user's stream key
stream_key_info = db.fetchone("""SELECT stream_key FROM users WHERE user_id = ?""", (user_id,))
stream_key = stream_key_info.get("stream_key") if stream_key_info else None
stream_key_info = db.fetchone(
"""SELECT stream_key FROM users WHERE user_id = ?""",
(user_id,)
)
stream_key = (
stream_key_info.get("stream_key") if stream_key_info
else None
)
if stream_key:
# End the stream
end_user_stream(stream_key, user_id, username)
session.clear()
return {"logged_in": False}
def get_error_fields(values: list):
"""Return field names for empty values."""
fields = ["username", "email", "password"]
return [fields[i] for i, v in enumerate(values) if not v]
return [fields[i] for i, v in enumerate(values) if not v]

View File

@@ -1,14 +1,17 @@
"""Chat blueprint for WebSocket-based real-time messaging."""
import json
from datetime import datetime
from flask import Blueprint, jsonify
from database.database import Database
from .socket import socketio
from flask_socketio import emit, join_room, leave_room
from datetime import datetime
from utils.user_utils import get_user_id, is_subscribed
from utils.user_utils import is_subscribed
import redis
import json
redis_url = "redis://redis:6379/1"
r = redis.from_url(redis_url, decode_responses=True)
REDIS_URL = "redis://redis:6379/1"
r = redis.from_url(REDIS_URL, decode_responses=True)
chat_bp = Blueprint("chat", __name__)
@socketio.on("connect")
@@ -32,7 +35,7 @@ def handle_join(data) -> None:
join_room(stream_id)
num_viewers = len(list(socketio.server.manager.get_participants("/", stream_id)))
update_viewers(stream_id, num_viewers)
emit("status",
emit("status",
{
"message": f"Welcome to the chat, stream_id: {stream_id}",
"num_viewers": num_viewers
@@ -53,7 +56,7 @@ def handle_leave(data) -> None:
remove_favourability_entry(str(data["user_id"]), str(stream_id))
num_viewers = len(list(socketio.server.manager.get_participants("/", stream_id)))
update_viewers(stream_id, num_viewers)
emit("status",
emit("status",
{
"message": f"Welcome to the chat, stream_id: {stream_id}",
"num_viewers": num_viewers
@@ -78,10 +81,10 @@ def get_past_chat(stream_id: int):
all_chats = db.fetchall("""
SELECT user_id, username, message, time_sent, is_subscribed
FROM (
SELECT
SELECT
u.user_id,
u.username,
c.message,
u.username,
c.message,
c.time_sent,
CASE
WHEN s.user_id IS NOT NULL AND s.expires > CURRENT_TIMESTAMP THEN 1
@@ -101,8 +104,8 @@ def get_past_chat(stream_id: int):
# Create JSON output of chat_history to pass through NGINX proxy
chat_history = [{"chatter_id": chat["user_id"],
"chatter_username": chat["username"],
"message": chat["message"],
"chatter_username": chat["username"],
"message": chat["message"],
"time_sent": chat["time_sent"],
"is_subscribed": bool(chat["is_subscribed"])} for chat in all_chats]
print(chat_history)
@@ -125,7 +128,13 @@ def send_chat(data) -> None:
# Input validation - chatter is logged in, message is not empty, stream exists
if not all([chatter_name, message, stream_id]):
emit("error", {"error": f"Unable to send a chat. The following info was given: chatter_name={chatter_name}, message={message}, stream_id={stream_id}"}, broadcast=False)
emit("error", {
"error": (
f"Unable to send a chat. The following info was given: "
f"chatter_name={chatter_name}, message={message}, "
f"stream_id={stream_id}"
)
}, broadcast=False)
return
subscribed = is_subscribed(chatter_id, stream_id)
# Send the chat message to the client so it can be displayed
@@ -161,9 +170,8 @@ def update_viewers(user_id, num_viewers):
SET num_viewers = ?
WHERE user_id = ?;
""", (num_viewers, user_id))
db.close_connection
#TODO: Make sure that users entry within Redis is removed if they disconnect from socket
db.close_connection()
def add_favourability_entry(user_id, stream_id):
"""
Adds entry to Redis that user is watching a streamer
@@ -183,7 +191,7 @@ def add_favourability_entry(user_id, stream_id):
else:
# Creates new entry for user and stream
current_viewers[user_id] = [stream_id]
r.hset("current_viewers", "viewers", json.dumps(current_viewers))
def remove_favourability_entry(user_id, stream_id):
@@ -202,9 +210,9 @@ def remove_favourability_entry(user_id, stream_id):
if user_id in current_viewers:
# Removes specific stream from user
current_viewers[user_id] = [stream for stream in current_viewers[user_id] if stream != stream_id]
# If user is no longer watching any streams
if not current_viewers[user_id]:
del current_viewers[user_id]
r.hset("current_viewers", "viewers", json.dumps(current_viewers))
r.hset("current_viewers", "viewers", json.dumps(current_viewers))

View File

@@ -1,7 +1,10 @@
from flask import redirect, g, session
from functools import wraps
"""Authentication middleware and error handler registration."""
import logging
from functools import wraps
from os import getenv
from flask import redirect, g, session
from dotenv import load_dotenv
from database.database import Database
@@ -57,5 +60,5 @@ def register_error_handlers(app):
for code, message in error_responses.items():
@app.errorhandler(code)
def handle_error(error, message=message, code=code):
logging.error(f"Error {code}: {str(error)}")
return {"error": message}, code
logging.error("Error %d: %s", code, str(error))
return {"error": message}, code

View File

@@ -1,15 +1,18 @@
"""OAuth blueprint for Google authentication."""
from os import getenv
from secrets import token_hex, token_urlsafe
from random import randint
from authlib.integrations.flask_client import OAuth, OAuthError
from flask import Blueprint, jsonify, session, redirect, request
from blueprints.user import get_session_info_email
from database.database import Database
from dotenv import load_dotenv
from secrets import token_hex, token_urlsafe
from random import randint
from utils.path_manager import PathManager
oauth_bp = Blueprint("oauth", __name__)
google = None
_google = None
load_dotenv()
url_api = getenv("VITE_API_URL")
@@ -23,8 +26,8 @@ def init_oauth(app):
Initialise the OAuth functionality.
"""
oauth = OAuth(app)
global google
google = oauth.register(
global _google # pylint: disable=global-statement
_google = oauth.register(
'google',
client_id=app.config['GOOGLE_CLIENT_ID'],
client_secret=app.config['GOOGLE_CLIENT_SECRET'],
@@ -50,11 +53,11 @@ def login_google():
session["nonce"] = token_urlsafe(16)
session["state"] = token_urlsafe(32)
session["origin"] = request.args.get("next")
# Make sure session is saved before redirect
session.modified = True
return google.authorize_redirect(
return _google.authorize_redirect(
redirect_uri=f'{url}/api/google_auth',
nonce=session['nonce'],
state=session['state']
@@ -70,23 +73,27 @@ def google_auth():
# Check state parameter before authorizing
returned_state = request.args.get('state')
stored_state = session.get('state')
if not stored_state or stored_state != returned_state:
print(f"State mismatch: stored={stored_state}, returned={returned_state}", flush=True)
print(
f"State mismatch: stored={stored_state}, "
f"returned={returned_state}", flush=True
)
return jsonify({
'error': f"mismatching_state: CSRF Warning! State not equal in request and response.",
'error': "mismatching_state: CSRF Warning! "
"State not equal in request and response.",
'message': 'Authentication failed'
}), 400
# State matched, proceed with token authorization
token = google.authorize_access_token()
token = _google.authorize_access_token()
# Verify nonce
nonce = session.get('nonce')
if not nonce:
return jsonify({'error': 'Missing nonce in session'}), 400
user = google.parse_id_token(token, nonce=nonce)
user = _google.parse_id_token(token, nonce=nonce)
# Check if email exists to login else create a database entry
user_email = user.get("email")
@@ -108,7 +115,7 @@ def google_auth():
break
db.execute(
"""INSERT INTO users
"""INSERT INTO users
(username, email, stream_key)
VALUES (?, ?, ?)""",
(
@@ -124,16 +131,19 @@ def google_auth():
origin = session.get("origin", f"{url.replace('/api', '')}")
username = user_data["username"]
user_id = user_data["user_id"]
# Clear session and set new data
session.clear()
session["username"] = username
session["user_id"] = user_id
# Ensure session is saved
session.modified = True
print(f"session: {session.get('username')}. user_id: {session.get('user_id')}", flush=True)
print(
f"session: {session.get('username')}. "
f"user_id: {session.get('user_id')}", flush=True
)
return redirect(origin)
@@ -144,9 +154,9 @@ def google_auth():
'error': str(e)
}), 400
except Exception as e:
except (ValueError, TypeError, KeyError) as e:
print(f"Unexpected Error: {str(e)}", flush=True)
return jsonify({
'message': 'An unexpected error occurred',
'error': str(e)
}), 500
}), 500

View File

@@ -1,3 +1,5 @@
"""Search bar blueprint for querying categories, users, streams, and VODs."""
from flask import Blueprint, jsonify, request
from database.database import Database
from utils.utils import sanitize
@@ -20,7 +22,7 @@ def rank_results(query, result):
# Assign a score based on the level of the match
if query in result:
return 0
elif all(c in charset for c in query):
if all(c in charset for c in query):
return 1
return 2
@@ -50,7 +52,7 @@ def search_results():
res_dict.append(c)
categories = sorted(res_dict, key=lambda d: d["score"])
categories = categories[:4]
# 3 users
res_dict = []
users = db.fetchall("SELECT user_id, username, is_live FROM users")
@@ -63,7 +65,7 @@ def search_results():
users = sorted(res_dict, key=lambda d: d["score"])
users = users[:4]
# 3 streams
# 3 streams
res_dict = []
streams = db.fetchall("""SELECT s.user_id, s.title, s.num_viewers, c.category_name, u.username
FROM streams AS s
@@ -71,7 +73,7 @@ def search_results():
INNER JOIN users AS u ON s.user_id = u.user_id
INNER JOIN categories AS c ON s.category_id = c.category_id
""")
for s in streams:
key = s.get("title")
score = rank_results(query.lower(), key.lower())
@@ -83,7 +85,7 @@ def search_results():
# 3 VODs
res_dict = []
vods = db.fetchall("""SELECT v.vod_id, v.title, u.user_id, u.username
vods = db.fetchall("""SELECT v.vod_id, v.title, u.user_id, u.username
FROM vods as v JOIN users as u
ON v.user_id = u.user_id""")
for v in vods:
@@ -98,5 +100,5 @@ def search_results():
db.close_connection()
print(query, streams, users, categories, vods, flush=True)
return jsonify({"streams": streams, "categories": categories, "users": users, "vods": vods})
return jsonify({"streams": streams, "categories": categories, "users": users, "vods": vods})

View File

@@ -1,10 +1,12 @@
"""WebSocket configuration using Flask-SocketIO."""
from flask_socketio import SocketIO
socketio = SocketIO(
cors_allowed_origins="*",
cors_allowed_origins="*",
async_mode='gevent',
logger=False, # Reduce logging
engineio_logger=False, # Reduce logging
ping_timeout=5000,
ping_interval=25000
)
)

View File

@@ -1,15 +1,24 @@
"""Stream and VOD management blueprint."""
import json
from datetime import datetime
from flask import Blueprint, session, jsonify, request, redirect
from utils.stream_utils import *
from utils.recommendation_utils import *
from utils.stream_utils import (
get_category_id, get_current_stream_data, get_streamer_live_status,
get_latest_vod, get_vod, get_user_vods, end_user_stream
)
from utils.recommendation_utils import (
get_user_preferred_category, get_highest_view_streams,
get_streams_based_on_category, get_highest_view_categories,
get_user_category_recommendations, get_followed_categories_recommendations
)
from utils.user_utils import get_user_id
from blueprints.middleware import login_required
from database.database import Database
from datetime import datetime
from celery_tasks.streaming import update_thumbnail, combine_ts_stream
from dateutil import parser
from celery_tasks.streaming import update_thumbnail
from utils.path_manager import PathManager
from PIL import Image
import json
stream_bp = Blueprint("stream", __name__)
@@ -28,12 +37,12 @@ def popular_streams(no_streams) -> list[dict]:
Returns a list of streams live now with the highest viewers
"""
# Limit the number of streams to MAX_STREAMS
MAX_STREAMS = 100
# Limit the number of streams to max_streams
max_streams = 100
if no_streams < 1:
return jsonify([])
elif no_streams > MAX_STREAMS:
no_streams = MAX_STREAMS
if no_streams > max_streams:
no_streams = max_streams
# Get the highest viewed streams
streams = get_highest_view_streams(no_streams)
@@ -101,7 +110,7 @@ def popular_categories(no_categories=4, offset=0) -> list[dict]:
# Limit the number of categories to 100
if no_categories < 1:
return jsonify([])
elif no_categories > 100:
if no_categories > 100:
no_categories = 100
category_data = get_highest_view_categories(no_categories, offset)
@@ -135,7 +144,8 @@ def following_categories_streams():
@stream_bp.route('/user/<string:username>/status')
def user_live_status(username):
"""
Returns a streamer's status, if they are live or not and their most recent stream (as a vod) (their current stream if live)
Returns a streamer's status, if they are live or not and their
most recent stream (as a vod) (their current stream if live)
Returns:
{
"is_live": bool,
@@ -146,7 +156,7 @@ def user_live_status(username):
user_id = get_user_id(username)
# Check if streamer is live and get their most recent vod
is_live = True if get_streamer_live_status(user_id)['is_live'] else False
is_live = bool(get_streamer_live_status(user_id)['is_live'])
most_recent_vod = get_latest_vod(user_id)
# If there is no most recent vod, set it to None
@@ -171,25 +181,24 @@ def user_live_status_direct(username):
"""
user_id = get_user_id(username)
is_live = True if get_streamer_live_status(user_id)['is_live'] else False
is_live = bool(get_streamer_live_status(user_id)['is_live'])
if is_live:
return 'ok', 200
else:
return 'not live', 404
return 'not live', 404
# VOD Routes
@stream_bp.route('/vods/<int:vod_id>')
def vod(vod_id):
def single_vod(vod_id):
"""
Returns details about a specific vod
"""
vod = get_vod(vod_id)
return jsonify(vod)
vod_data = get_vod(vod_id)
return jsonify(vod_data)
@stream_bp.route('/vods/<string:username>')
def vods(username):
def user_vods(username):
"""
Returns a JSON of all the vods of a streamer
Returns:
@@ -204,11 +213,11 @@ def vods(username):
"views": int
}
]
"""
user_id = get_user_id(username)
vods = get_user_vods(user_id)
return jsonify(vods)
vod_list = get_user_vods(user_id)
return jsonify(vod_list)
@stream_bp.route('/vods/all')
def get_all_vods():
@@ -216,9 +225,14 @@ def get_all_vods():
Returns data of all VODs by all streamers in a JSON-compatible format
"""
with Database() as db:
vods = db.fetchall("""SELECT vods.*, username, category_name FROM vods JOIN users ON vods.user_id = users.user_id JOIN categories ON vods.category_id = categories.category_id;""")
return jsonify(vods)
all_vods = db.fetchall(
"""SELECT vods.*, username, category_name
FROM vods
JOIN users ON vods.user_id = users.user_id
JOIN categories ON vods.category_id = categories.category_id;"""
)
return jsonify(all_vods)
# RTMP Server Routes
@@ -232,9 +246,10 @@ def init_stream():
with Database() as db:
# Check if valid stream key and user is allowed to stream
user_info = db.fetchone("""SELECT user_id, username, is_live
FROM users
WHERE stream_key = ?""", (stream_key,))
user_info = db.fetchone(
"""SELECT user_id, username, is_live
FROM users
WHERE stream_key = ?""", (stream_key,))
# No user found from stream key
if not user_info:
@@ -280,25 +295,27 @@ def publish_stream():
username = None
with Database() as db:
user_info = db.fetchone("""SELECT user_id, username, is_live
FROM users
WHERE stream_key = ?""", (stream_key,))
user_info = db.fetchone(
"""SELECT user_id, username, is_live
FROM users
WHERE stream_key = ?""", (stream_key,))
if not user_info or user_info.get("is_live"):
print(
"Unauthorized. No user found from Stream key or user is already streaming.", flush=True)
"Unauthorized. No user found from Stream key "
"or user is already streaming.", flush=True)
return "Unauthorized", 403
user_id = user_info.get("user_id")
username = user_info.get("username")
# Insert stream into database
db.execute("""INSERT INTO streams (user_id, title, start_time, num_viewers, category_id)
VALUES (?, ?, ?, ?, ?)""", (user_id,
stream_title,
datetime.now(),
0,
get_category_id(stream_category)))
db.execute(
"""INSERT INTO streams
(user_id, title, start_time, num_viewers, category_id)
VALUES (?, ?, ?, ?, ?)""",
(user_id, stream_title, datetime.now(), 0,
get_category_id(stream_category)))
# Set user as streaming
db.execute("""UPDATE users SET is_live = 1 WHERE user_id = ?""",
@@ -306,10 +323,12 @@ def publish_stream():
# Update thumbnail periodically only if a custom thumbnail is not provided
if not stream_thumbnail:
update_thumbnail.apply_async((user_id,
path_manager.get_stream_file_path(username),
path_manager.get_current_stream_thumbnail_file_path(username),
THUMBNAIL_GENERATION_INTERVAL), countdown=10)
update_thumbnail.apply_async(
(user_id,
path_manager.get_stream_file_path(username),
path_manager.get_current_stream_thumbnail_file_path(username),
THUMBNAIL_GENERATION_INTERVAL),
countdown=10)
return "OK", 200
@@ -332,31 +351,36 @@ def update_stream():
with Database() as db:
user_info = db.fetchone("""SELECT user_id, username, is_live
FROM users
WHERE stream_key = ?""", (stream_key,))
user_info = db.fetchone(
"""SELECT user_id, username, is_live
FROM users
WHERE stream_key = ?""", (stream_key,))
if not user_info or not user_info.get("is_live"):
print(
"Unauthorized - No user found from stream key or user is not streaming", flush=True)
"Unauthorized - No user found from stream key "
"or user is not streaming", flush=True)
return "Unauthorized", 403
user_id = user_info.get("user_id")
username = user_info.get("username")
# TODO: Add update to thumbnail here
db.execute("""UPDATE streams
SET title = ?, category_id = ?
WHERE user_id = ?""", (stream_title, get_category_id(stream_category), user_id))
db.execute(
"""UPDATE streams
SET title = ?, category_id = ?
WHERE user_id = ?""",
(stream_title, get_category_id(stream_category), user_id))
print("GOT: " + stream_thumbnail, flush=True)
if stream_thumbnail:
# Set custom thumbnail status to true
db.execute("""UPDATE streams
SET custom_thumbnail = ?
WHERE user_id = ?""", (True, user_id))
db.execute(
"""UPDATE streams
SET custom_thumbnail = ?
WHERE user_id = ?""", (True, user_id))
# Get thumbnail path
thumbnail_path = path_manager.get_current_stream_thumbnail_file_path(username)
@@ -390,26 +414,27 @@ def end_stream():
# Get user info from stream key
with Database() as db:
user_info = db.fetchone("""SELECT user_id, username
FROM users
WHERE stream_key = ?""", (stream_key,))
user_info = db.fetchone(
"""SELECT user_id, username
FROM users
WHERE stream_key = ?""", (stream_key,))
# Return unauthorized if no user found
if not user_info:
print("Unauthorized - No user found from stream key", flush=True)
return "Unauthorized", 403
# Get user info
user_id = user_info["user_id"]
username = user_info["username"]
# End stream
result, message = end_user_stream(stream_key, user_id, username)
# Return error if stream could not be ended
if not result:
print(f"Error ending stream: {message}", flush=True)
return "Error ending stream", 500
print(f"Stream ended: {message}", flush=True)
return "Stream ended", 200

View File

@@ -1,7 +1,11 @@
"""Stripe payment integration blueprint."""
import os
from flask import Blueprint, request, jsonify, session as s
from blueprints.middleware import login_required
from utils.user_utils import subscribe
import os, stripe
import stripe
stripe_bp = Blueprint("stripe", __name__)
@@ -20,7 +24,7 @@ def create_checkout_session():
# Checks to see who is subscribing to who
user_id = s.get("user_id")
streamer_id = request.args.get("streamer_id")
session = stripe.checkout.Session.create(
checkout_session = stripe.checkout.Session.create(
ui_mode = 'embedded',
payment_method_types=['card'],
line_items=[
@@ -33,19 +37,24 @@ def create_checkout_session():
redirect_on_completion = 'never',
client_reference_id = f"{user_id}-{streamer_id}"
)
except Exception as e:
except (ValueError, TypeError, KeyError) as e:
return str(e), 500
return jsonify(clientSecret=session.client_secret)
return jsonify(clientSecret=checkout_session.client_secret)
@stripe_bp.route('/session-status') # check for payment status
@stripe_bp.route('/session-status') # check for payment status
def session_status():
"""
Used to query payment status
"""
session = stripe.checkout.Session.retrieve(request.args.get('session_id'))
"""
Used to query payment status
"""
checkout_session = stripe.checkout.Session.retrieve(
request.args.get('session_id')
)
return jsonify(status=session.status, customer_email=session.customer_details.email)
return jsonify(
status=checkout_session.status,
customer_email=checkout_session.customer_details.email
)
@stripe_bp.route('/stripe/webhook', methods=['POST'])
def stripe_webhook():
@@ -65,13 +74,20 @@ def stripe_webhook():
raise e
except stripe.error.SignatureVerificationError as e:
raise e
if event['type'] == "checkout.session.completed": # Handles payment success webhook
session = event['data']['object']
product_id = stripe.checkout.Session.list_line_items(session['id'])['data'][0]['price']['product']
# Handles payment success webhook
if event['type'] == "checkout.session.completed":
checkout_session = event['data']['object']
product_id = stripe.checkout.Session.list_line_items(
checkout_session['id']
)['data'][0]['price']['product']
if product_id == subscription:
client_reference_id = session.get("client_reference_id")
user_id, streamer_id = map(int, client_reference_id.split("-"))
client_reference_id = checkout_session.get(
"client_reference_id"
)
user_id, streamer_id = map(
int, client_reference_id.split("-")
)
subscribe(user_id, streamer_id)
return "Success", 200
return "Success", 200

View File

@@ -1,17 +1,28 @@
"""User profile and account management blueprint."""
from flask import Blueprint, jsonify, session, request
from utils.user_utils import *
from utils.auth import *
from utils.user_utils import (
get_user_id, get_user, update_bio, is_subscribed,
subscription_expiration, delete_subscription, is_following,
follow, unfollow, get_followed_streamers, get_followed_categories,
is_following_category, follow_category, unfollow_category,
has_password, get_session_info_email
)
from database.database import Database
from utils.auth import verify_token, reset_password
from utils.utils import get_category_id
from blueprints.middleware import login_required
from utils.email import send_email, forgot_password_body, newsletter_conf, remove_from_newsletter, email_exists
from utils.email import (
send_email, forgot_password_body, newsletter_conf,
remove_from_newsletter, email_exists
)
from utils.path_manager import PathManager
from celery_tasks.streaming import convert_image_to_png
import redis
from PIL import Image
redis_url = "redis://redis:6379/1"
r = redis.from_url(redis_url, decode_responses=True)
REDIS_URL = "redis://redis:6379/1"
r = redis.from_url(REDIS_URL, decode_responses=True)
user_bp = Blueprint("user", __name__)
@@ -24,7 +35,7 @@ def user_data(username: str):
"""
user_id = get_user_id(username)
if not user_id:
jsonify({"error": "User not found from username"}), 404
return jsonify({"error": "User not found from username"}), 404
data = get_user(user_id)
return jsonify(data)
@@ -48,11 +59,11 @@ def user_profile_picture_save():
"""
username = session.get("username")
thumbnail_path = path_manager.get_profile_picture_file_path(username)
# Check if the post request has the file part
if 'image' not in request.files:
return jsonify({"error": "No image found in request"}), 400
# Fetch image, convert to png, and save
image = Image.open(request.files['image'])
image.convert('RGB')
@@ -73,7 +84,7 @@ def user_change_bio():
bio = data.get("bio")
update_bio(user_id, bio)
return jsonify({"status": "Success"}), 200
except Exception as e:
except (ValueError, TypeError, KeyError) as e:
return jsonify({"error": str(e)}), 400
## Subscription Routes
@@ -186,7 +197,7 @@ def user_login_status():
"""
username = session.get("username")
user_id = session.get("user_id")
return jsonify({'status': username is not None,
return jsonify({'status': username is not None,
'username': username,
'user_id': user_id})
@@ -198,13 +209,14 @@ def user_forgot_password(email):
exists = email_exists(email)
password = has_password(email)
# Checks if password exists and is not a Google OAuth account
if(exists and password):
if exists and password:
send_email(email, lambda: forgot_password_body(email))
return email
return jsonify({"error":"Invalid email or not found"}), 404
return jsonify({"error": "Invalid email or not found"}), 404
@user_bp.route("/send_newsletter/<string:email>", methods=["POST"])
def send_newsletter(email):
"""Sends a newsletter confirmation email."""
send_email(email, lambda: newsletter_conf(email))
return email
@@ -226,6 +238,7 @@ def user_reset_password(token, new_password):
@user_bp.route("/user/unsubscribe/<string:token>", methods=["POST"])
def unsubscribe(token):
"""Unsubscribes a user from the newsletter."""
salt = r.get(token)
if salt:
r.delete(token)
@@ -237,4 +250,4 @@ def unsubscribe(token):
if email:
remove_from_newsletter(email)
return jsonify({"message": "unsubscribed from newsletter"}), 200
return jsonify({"error": "Invalid token"}), 400
return jsonify({"error": "Invalid token"}), 400

View File

@@ -1,7 +1,12 @@
from celery import Celery, shared_task, Task
"""Celery configuration and Flask app context setup for async tasks."""
from celery import Celery, Task
def celery_init_app(app) -> Celery:
"""Initialize Celery with Flask application context."""
class FlaskTask(Task):
"""Celery task that runs within Flask app context."""
def __call__(self, *args: object, **kwargs: object) -> object:
with app.app_context():
return self.run(*args, **kwargs)
@@ -14,6 +19,10 @@ def celery_init_app(app) -> Celery:
'schedule': 30.0,
},
}
celery_app.conf.include = [
'celery_tasks.preferences',
'celery_tasks.streaming',
]
celery_app.set_default()
app.extensions["celery"] = celery_app
return celery_app

View File

@@ -1,4 +1,6 @@
"""Celery app initialization with Flask."""
from blueprints import create_app
flask_app = create_app()
celery_app = flask_app.extensions["celery"]
celery_app = flask_app.extensions["celery"]

View File

@@ -1,15 +1,19 @@
"""Scheduled task for updating user preferences based on stream viewing."""
import json
from celery import shared_task
from database.database import Database
import redis
import json
redis_url = "redis://redis:6379/1"
r = redis.from_url(redis_url, decode_responses=True)
REDIS_URL = "redis://redis:6379/1"
r = redis.from_url(REDIS_URL, decode_responses=True)
@shared_task
def user_preferences():
"""
Updates users preferences on different stream categories based on the streams they are currently watching
Updates users preferences on different stream categories
based on the streams they are currently watching
"""
stats = r.hget("current_viewers", "viewers")
# If there are any current viewers
@@ -21,13 +25,19 @@ def user_preferences():
# For each user and stream combination
for stream_id in stream_ids:
# Retrieves category associated with stream
current_category = db.fetchone("""SELECT category_id FROM streams
WHERE user_id = ?
""", (stream_id,))
# If stream is still live then update the user_preferences table to reflect their preferences
current_category = db.fetchone(
"""SELECT category_id FROM streams
WHERE user_id = ?
""", (stream_id,))
# If stream is still live then update the
# user_preferences table to reflect their preferences
if current_category:
db.execute("""INSERT INTO user_preferences (user_id,category_id,favourability)
VALUES (?,?,?)
ON CONFLICT(user_id, category_id)
DO UPDATE SET favourability = favourability + 1
""", (user_id, current_category["category_id"], 1))
db.execute(
"""INSERT INTO user_preferences
(user_id,category_id,favourability)
VALUES (?,?,?)
ON CONFLICT(user_id, category_id)
DO UPDATE SET
favourability = favourability + 1
""", (user_id,
current_category["category_id"], 1))

View File

@@ -1,11 +1,14 @@
from celery import Celery, shared_task, Task
from datetime import datetime
from celery_tasks.preferences import user_preferences
from utils.stream_utils import generate_thumbnail, get_streamer_live_status, get_custom_thumbnail_status, remove_hls_files, get_video_duration
from time import sleep
from os import listdir, remove
from utils.path_manager import PathManager
"""Async tasks for stream thumbnail updates, VOD creation, and image conversion."""
import subprocess
from os import listdir
from celery import shared_task
from utils.stream_utils import (
generate_thumbnail, get_streamer_live_status,
get_custom_thumbnail_status, remove_hls_files, get_video_duration
)
from utils.path_manager import PathManager
path_manager = PathManager()
@@ -16,10 +19,13 @@ def update_thumbnail(user_id, stream_file, thumbnail_file, sleep_time, second_ca
"""
# Check if stream is still live and custom thumbnail has not been set
if get_streamer_live_status(user_id)['is_live'] and not get_custom_thumbnail_status(user_id)['custom_thumbnail']:
if (get_streamer_live_status(user_id)['is_live']
and not get_custom_thumbnail_status(user_id)['custom_thumbnail']):
print("Updating thumbnail...")
generate_thumbnail(stream_file, thumbnail_file)
update_thumbnail.apply_async((user_id, stream_file, thumbnail_file, sleep_time, second_capture), countdown=sleep_time)
update_thumbnail.apply_async(
(user_id, stream_file, thumbnail_file, sleep_time, second_capture),
countdown=sleep_time)
else:
print(f"Stopping thumbnail updates for stream of {user_id}")
@@ -34,12 +40,12 @@ def combine_ts_stream(stream_path, vods_path, vod_file_name, thumbnail_file) ->
ts_files.sort()
# Create temp file listing all ts files
with open(f"{stream_path}/list.txt", "w") as f:
with open(f"{stream_path}/list.txt", "w", encoding="utf-8") as f:
for ts_file in ts_files:
f.write(f"file '{ts_file}'\n")
# Concatenate all ts files into a single vod
vod_command = [
"ffmpeg",
"-f",
@@ -53,7 +59,7 @@ def combine_ts_stream(stream_path, vods_path, vod_file_name, thumbnail_file) ->
vod_file_path
]
subprocess.run(vod_command)
subprocess.run(vod_command, check=True)
# Remove HLS files, even if user is not streaming
remove_hls_files(stream_path)
@@ -78,4 +84,4 @@ def convert_image_to_png(image_path, png_path):
png_path
]
subprocess.run(image_command)
subprocess.run(image_command, check=True)

View File

@@ -1,7 +1,12 @@
"""SQLite database connection management with context manager support."""
import sqlite3
import os
class Database:
"""Database wrapper providing connection management and query execution."""
def __init__(self) -> None:
self._db = os.path.join(os.path.abspath(os.path.dirname(__file__)), "app.db")
self._conn = None
@@ -63,4 +68,6 @@ class Database:
if not result:
return []
columns = [desc[0] for desc in self.cursor.description]
return [dict(zip(columns, row)) for row in result] if isinstance(result, list) else dict(zip(columns, result))
if isinstance(result, list):
return [dict(zip(columns, row)) for row in result]
return dict(zip(columns, result))

View File

@@ -1,8 +1,10 @@
"""Admin utility functions for user management."""
from database.database import Database
def check_if_admin(username):
"""
Returns whether user is admin
Returns whether user is admin
"""
with Database() as db:
is_admin = db.fetchone("""
@@ -10,7 +12,7 @@ def check_if_admin(username):
FROM users
WHERE username = ?;
""", (username,))
return bool(is_admin)
def check_if_user_exists(banned_user):
@@ -19,11 +21,11 @@ def check_if_user_exists(banned_user):
"""
with Database() as db:
user_exists = db.fetchone("""
SELECT user_id
SELECT user_id
FROM users
WHERE username = ?;""",
(banned_user,))
return bool(user_exists)
def ban_user(banned_user):
@@ -33,6 +35,5 @@ def ban_user(banned_user):
with Database() as db:
db.execute("""
DELETE FROM users
WHERE username = ?;""",
(banned_user)
)
WHERE username = ?;""",
(banned_user,))

View File

@@ -1,13 +1,17 @@
"""Token generation and verification for password resets."""
from typing import Optional
from os import getenv
from database.database import Database
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
from typing import Optional
from dotenv import load_dotenv
from os import getenv
from werkzeug.security import generate_password_hash
load_dotenv()
serializer = URLSafeTimedSerializer(getenv("AUTH_SECRET_KEY"))
def generate_token(email, salt_value) -> str:
"""
Creates a token for password reset
@@ -19,7 +23,6 @@ def verify_token(token: str, salt_value) -> Optional[str]:
"""
Given a token, verifies and decodes it into an email
"""
try:
email = serializer.loads(token, salt=salt_value, max_age=3600)
return email
@@ -38,7 +41,7 @@ def reset_password(new_password: str, email: str):
"""
with Database() as db:
db.execute("""
UPDATE users
SET password = ?
UPDATE users
SET password = ?
WHERE email = ?
""", (generate_password_hash(new_password), email))

View File

@@ -1,16 +1,18 @@
"""Email sending utilities for password reset, account confirmation, and newsletters."""
import smtplib
from email.mime.text import MIMEText
from os import getenv
from secrets import token_hex
from dotenv import load_dotenv
from utils.auth import generate_token
from secrets import token_hex
from .user_utils import get_session_info_email
import redis
from database.database import Database
redis_url = "redis://redis:6379/1"
r = redis.from_url(redis_url, decode_responses=True)
REDIS_URL = "redis://redis:6379/1"
r = redis.from_url(REDIS_URL, decode_responses=True)
load_dotenv()
@@ -23,31 +25,31 @@ def send_email(email, func) -> None:
"""
# Setup the sender email details
SMTP_SERVER = "smtp.gmail.com"
SMTP_PORT = 587
SMTP_EMAIL = getenv("EMAIL")
SMTP_PASSWORD = getenv("EMAIL_PASSWORD")
smtp_server = "smtp.gmail.com"
smtp_port = 587
smtp_email = getenv("EMAIL")
smtp_password = getenv("EMAIL_PASSWORD")
# Setup up the receiver details
body, subject = func()
msg = MIMEText(body, "html")
msg["Subject"] = subject
msg["From"] = SMTP_EMAIL
msg["From"] = smtp_email
msg["To"] = email
# Send the email using smtplib
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as smtp:
with smtplib.SMTP(smtp_server, smtp_port) as smtp:
try:
smtp.starttls() # TLS handshake to start the connection
smtp.login(SMTP_EMAIL, SMTP_PASSWORD)
smtp.login(smtp_email, smtp_password)
smtp.ehlo()
smtp.send_message(msg)
except TimeoutError:
print("Server timed out", flush=True)
except Exception as e:
except smtplib.SMTPException as e:
print("Error: ", e, flush=True)
def forgot_password_body(email) -> str:
@@ -57,7 +59,7 @@ def forgot_password_body(email) -> str:
salt = token_hex(32)
token = generate_token(email, salt)
token += "R3sET"
token += "R3sET"
r.setex(token, 3600, salt)
username = (get_session_info_email(email))["username"]
@@ -77,7 +79,7 @@ def forgot_password_body(email) -> str:
</head>
<body>
<div class="container">
<h1>Gander</h1>
<h1>Gander</h1>
<h2>Password Reset Request</h2>
<p>Click the button below to reset your password for your account {username}. This link is valid for 1 hour.</p>
<a href="{full_url}" class="btn">Reset Password</a>
@@ -116,7 +118,7 @@ def confirm_account_creation_body(email) -> str:
</head>
<body>
<div class="container">
<h1>Gander</h1>
<h1>Gander</h1>
<h2>Confirm Account Creation</h2>
<p>Click the button below to create your account. This link is valid for 1 hour.</p>
<a href="{full_url}" class="btn">Create Account</a>
@@ -154,7 +156,7 @@ def newsletter_conf(email):
</head>
<body>
<div class="container">
<h1>Gander</h1>
<h1>Gander</h1>
<h2>Welcome to the Official Gander Newsletter!</h2>
<p>If you are receiving this email, it means that you have been officially added to the Monthly Gander newsletter.</p>
<p>In this newsletter, you will receive updates about: your favourite streamers; important Gander updates; and more!</p>
@@ -169,7 +171,7 @@ def newsletter_conf(email):
user_exists = db.fetchone("""
SELECT *
FROM newsletter
WHERE email = ?;""",
WHERE email = ?;""",
(email,))
print(user_exists, flush=True)
@@ -190,7 +192,7 @@ def add_to_newsletter(email):
INSERT INTO newsletter (email)
VALUES (?);
""", (email,))
def remove_from_newsletter(email):
"""
Remove a person from the newsletter database
@@ -202,7 +204,7 @@ def remove_from_newsletter(email):
DELETE FROM newsletter
WHERE email = ?;
""", (email,))
def email_exists(email):
"""
Returns whether email exists within database
@@ -210,6 +212,6 @@ def email_exists(email):
with Database() as db:
data = db.fetchone("""
SELECT * FROM users
WHERE email = ?
WHERE email = ?
""", (email,))
return bool(data)
return bool(data)

View File

@@ -1,8 +1,11 @@
"""Description: This file contains the PathManager class which is responsible for managing the paths of the stream data."""
"""File system path management for user streams, VODs, and profile pictures."""
import os
class PathManager():
"""Manages paths for user stream data, VODs, and profile pictures."""
def __init__(self) -> None:
self.root_path = "user_data"
self.vod_directory_name = "vods"
@@ -26,7 +29,7 @@ class PathManager():
self._create_directory(os.path.join(self.root_path, username))
vods_path = self.get_vods_path(username)
stream_path = self.get_stream_path(username)
self._create_directory(vods_path)
self._create_directory(stream_path)
@@ -39,25 +42,33 @@ class PathManager():
os.rmdir(user_path)
def get_user_path(self, username):
"""Returns the base path for a user's data directory."""
return os.path.join(self.root_path, username)
def get_vods_path(self, username):
"""Returns the path to a user's VODs directory."""
return os.path.join(self.root_path, username, self.vod_directory_name)
def get_stream_path(self, username):
"""Returns the path to a user's stream directory."""
return os.path.join(self.root_path, username, self.stream_directory_name)
def get_stream_file_path(self, username):
"""Returns the path to a user's stream index file."""
return os.path.join(self.get_stream_path(username), "index.m3u8")
def get_current_stream_thumbnail_file_path(self, username):
"""Returns the path to a user's current stream thumbnail."""
return os.path.join(self.get_stream_path(username), "index.png")
def get_vod_file_path(self, username, vod_id):
"""Returns the path to a specific VOD file."""
return os.path.join(self.get_vods_path(username), f"{vod_id}.mp4")
def get_vod_thumbnail_file_path(self, username, vod_id):
"""Returns the path to a specific VOD thumbnail."""
return os.path.join(self.get_vods_path(username), f"{vod_id}.png")
def get_profile_picture_file_path(self, username):
return os.path.join(self.root_path, username, self.profile_picture_name)
"""Returns the path to a user's profile picture."""
return os.path.join(self.root_path, username, self.profile_picture_name)

View File

@@ -1,32 +1,42 @@
from database.database import Database
"""Personalized stream and category recommendations based on user preferences."""
from typing import Optional, List
from database.database import Database
def get_user_preferred_category(user_id: int) -> Optional[int]:
"""
Queries user_preferences database to find users favourite streaming category and returns the category
Queries user_preferences database to find users favourite
streaming category and returns the category
"""
with Database() as db:
category = db.fetchone("""
SELECT category_id
FROM user_preferences
WHERE user_id = ?
ORDER BY favourability DESC
SELECT category_id
FROM user_preferences
WHERE user_id = ?
ORDER BY favourability DESC
LIMIT 1
""", (user_id,))
return category["category_id"] if category else None
def get_followed_categories_recommendations(user_id: int, no_streams: int = 4) -> Optional[List[dict]]:
def get_followed_categories_recommendations(
user_id: int, no_streams: int = 4
) -> Optional[List[dict]]:
"""
Returns top streams given a user's category following
"""
with Database() as db:
streams = db.fetchall("""
SELECT u.user_id, title, u.username, num_viewers, category_name
FROM streams
SELECT u.user_id, title, u.username, num_viewers,
category_name
FROM streams
JOIN users u ON streams.user_id = u.user_id
JOIN categories ON streams.category_id = categories.category_id
WHERE categories.category_id IN (SELECT category_id FROM followed_categories WHERE user_id = ?)
JOIN categories
ON streams.category_id = categories.category_id
WHERE categories.category_id IN
(SELECT category_id
FROM followed_categories WHERE user_id = ?)
ORDER BY num_viewers DESC
LIMIT ?;
""", (user_id, no_streams))
@@ -40,25 +50,29 @@ def get_followed_your_categories(user_id: int) -> Optional[List[dict]]:
categories = db.fetchall("""
SELECT categories.category_name
FROM categories
JOIN followed_categories
JOIN followed_categories
ON categories.category_id = followed_categories.category_id
WHERE followed_categories.user_id = ?;
""", (user_id,))
return categories
def get_streams_based_on_category(category_id: int, no_streams: int = 4, offset: int = 0) -> Optional[List[dict]]:
def get_streams_based_on_category(
category_id: int, no_streams: int = 4, offset: int = 0
) -> Optional[List[dict]]:
"""
Queries stream database to get top most viewed streams based on given category
Queries stream database to get top most viewed streams
based on given category
"""
with Database() as db:
streams = db.fetchall("""
SELECT u.user_id, title, username, num_viewers, c.category_name
SELECT u.user_id, title, username, num_viewers,
c.category_name
FROM streams s
JOIN users u ON s.user_id = u.user_id
JOIN categories c ON s.category_id = c.category_id
WHERE c.category_id = ?
ORDER BY num_viewers DESC
WHERE c.category_id = ?
ORDER BY num_viewers DESC
LIMIT ? OFFSET ?
""", (category_id, no_streams, offset))
return streams
@@ -70,43 +84,63 @@ def get_highest_view_streams(no_streams: int = 4) -> Optional[List[dict]]:
"""
with Database() as db:
data = db.fetchall("""
SELECT u.user_id, username, title, num_viewers, category_name
FROM streams
SELECT u.user_id, username, title, num_viewers,
category_name
FROM streams
JOIN users u ON streams.user_id = u.user_id
JOIN categories ON streams.category_id = categories.category_id
ORDER BY num_viewers DESC
JOIN categories
ON streams.category_id = categories.category_id
ORDER BY num_viewers DESC
LIMIT ?;
""", (no_streams,))
return data
def get_highest_view_categories(no_categories: int = 4, offset: int = 0) -> Optional[List[dict]]:
def get_highest_view_categories(
no_categories: int = 4, offset: int = 0
) -> Optional[List[dict]]:
"""
Returns a list of top most popular categories given offset
"""
with Database() as db:
categories = db.fetchall("""
SELECT categories.category_id, categories.category_name, COALESCE(SUM(streams.num_viewers), 0) AS num_viewers
SELECT categories.category_id,
categories.category_name,
COALESCE(SUM(streams.num_viewers), 0)
AS num_viewers
FROM categories
LEFT JOIN streams ON streams.category_id = categories.category_id
GROUP BY categories.category_id, categories.category_name
LEFT JOIN streams
ON streams.category_id = categories.category_id
GROUP BY categories.category_id,
categories.category_name
ORDER BY num_viewers DESC
LIMIT ? OFFSET ?;
""", (no_categories, offset))
return categories
def get_user_category_recommendations(user_id = 1, no_categories: int = 4) -> Optional[List[dict]]:
def get_user_category_recommendations(
user_id=1, no_categories: int = 4
) -> Optional[List[dict]]:
"""
Queries user_preferences database to find users top favourite streaming category and returns the category
Queries user_preferences database to find users top favourite
streaming category and returns the category
"""
with Database() as db:
categories = db.fetchall("""
SELECT categories.category_id, categories.category_name, COALESCE(SUM(streams.num_viewers), 0) AS num_viewers
FROM categories
JOIN user_preferences ON categories.category_id = user_preferences.category_id
LEFT JOIN streams ON categories.category_id = streams.category_id
WHERE user_preferences.user_id = ?
GROUP BY categories.category_id, categories.category_name
ORDER BY user_preferences.favourability DESC
SELECT categories.category_id,
categories.category_name,
COALESCE(SUM(streams.num_viewers), 0)
AS num_viewers
FROM categories
JOIN user_preferences
ON categories.category_id
= user_preferences.category_id
LEFT JOIN streams
ON categories.category_id
= streams.category_id
WHERE user_preferences.user_id = ?
GROUP BY categories.category_id,
categories.category_name
ORDER BY user_preferences.favourability DESC
LIMIT ?
""", (user_id, no_categories))
return categories
return categories

View File

@@ -1,9 +1,11 @@
from database.database import Database
from typing import Optional
import os, subprocess
"""Stream data retrieval and management utilities."""
import os
import subprocess
from typing import Optional, List
from time import sleep
from utils.path_manager import PathManager
from database.database import Database
def get_streamer_live_status(user_id: int):
"""
@@ -11,8 +13,8 @@ def get_streamer_live_status(user_id: int):
"""
with Database() as db:
is_live = db.fetchone("""
SELECT is_live
FROM users
SELECT is_live
FROM users
WHERE user_id = ?;
""", (user_id,))
@@ -21,17 +23,19 @@ def get_streamer_live_status(user_id: int):
def get_followed_live_streams(user_id: int) -> Optional[List[dict]]:
"""
Searches for streamers who the user followed which are currently live
Returns a list of live streams with the streamer's user id, stream title, and number of viewers
Returns a list of live streams with the streamer's user id,
stream title, and number of viewers
"""
with Database() as db:
live_streams = db.fetchall("""
SELECT users.user_id, streams.title, streams.num_viewers, users.username
FROM streams JOIN users
ON streams.user_id = users.user_id
WHERE users.user_id IN
(SELECT followed_id FROM follows WHERE user_id = ?)
AND users.is_live = 1;
""", (user_id,))
SELECT users.user_id, streams.title,
streams.num_viewers, users.username
FROM streams JOIN users
ON streams.user_id = users.user_id
WHERE users.user_id IN
(SELECT followed_id FROM follows WHERE user_id = ?)
AND users.is_live = 1;
""", (user_id,))
return live_streams
def get_current_stream_data(user_id: int) -> Optional[dict]:
@@ -40,7 +44,8 @@ def get_current_stream_data(user_id: int) -> Optional[dict]:
"""
with Database() as db:
most_recent_stream = db.fetchone("""
SELECT s.user_id, u.username, s.title, s.start_time, s.num_viewers, c.category_name, c.category_id
SELECT s.user_id, u.username, s.title, s.start_time,
s.num_viewers, c.category_name, c.category_id
FROM streams AS s
JOIN categories AS c ON s.category_id = c.category_id
JOIN users AS u ON s.user_id = u.user_id
@@ -51,78 +56,92 @@ def get_current_stream_data(user_id: int) -> Optional[dict]:
def end_user_stream(stream_key, user_id, username):
"""
Utility function to end a user's stream
Parameters:
stream_key: The stream key of the user
user_id: The ID of the user
username: The username of the user
Returns:
bool: True if stream was ended successfully, False otherwise
"""
from flask import current_app
from datetime import datetime
from dateutil import parser
from celery_tasks.streaming import combine_ts_stream
from utils.path_manager import PathManager
path_manager = PathManager()
pm = PathManager()
print(f"Ending stream for user {username} (ID: {user_id})", flush=True)
if not stream_key or not user_id or not username:
print("Cannot end stream - missing required information", flush=True)
return False
return False, "Missing required information"
try:
# Open database connection
with Database() as db:
# Get stream info
stream_info = db.fetchone("""SELECT *
FROM streams
WHERE user_id = ?""", (user_id,))
stream_info = db.fetchone(
"""SELECT *
FROM streams
WHERE user_id = ?""", (user_id,))
# If user is not streaming, just return
if not stream_info:
print(f"User {username} (ID: {user_id}) is not streaming", flush=True)
print(
f"User {username} (ID: {user_id}) "
"is not streaming", flush=True)
return True, "User is not streaming"
# Remove stream from database
db.execute("""DELETE FROM streams
WHERE user_id = ?""", (user_id,))
db.execute(
"""DELETE FROM streams
WHERE user_id = ?""", (user_id,))
# Move stream to vod table
stream_length = int(
(datetime.now() - parser.parse(stream_info.get("start_time"))).total_seconds())
db.execute("""INSERT INTO vods (user_id, title, datetime, category_id, length, views)
VALUES (?, ?, ?, ?, ?, ?)""", (user_id,
stream_info.get("title"),
stream_info.get("start_time"),
stream_info.get("category_id"),
stream_length,
0))
(datetime.now() - parser.parse(
stream_info.get("start_time")
)).total_seconds())
db.execute(
"""INSERT INTO vods
(user_id, title, datetime, category_id,
length, views)
VALUES (?, ?, ?, ?, ?, ?)""",
(user_id,
stream_info.get("title"),
stream_info.get("start_time"),
stream_info.get("category_id"),
stream_length,
0))
vod_id = db.get_last_insert_id()
# Set user as not streaming
db.execute("""UPDATE users
SET is_live = 0
WHERE user_id = ?""", (user_id,))
db.execute(
"""UPDATE users
SET is_live = 0
WHERE user_id = ?""", (user_id,))
# Queue task to combine TS files into MP4
combine_ts_stream.delay(
path_manager.get_stream_path(username),
path_manager.get_vods_path(username),
pm.get_stream_path(username),
pm.get_vods_path(username),
vod_id,
path_manager.get_vod_thumbnail_file_path(username, vod_id)
pm.get_vod_thumbnail_file_path(username, vod_id)
)
print(f"Stream ended for user {username} (ID: {user_id})", flush=True)
print(
f"Stream ended for user {username} (ID: {user_id})",
flush=True)
return True, "Stream ended successfully"
except Exception as e:
print(f"Error ending stream for user {username}: {str(e)}", flush=True)
return False, f"Error ending stream: {str(e)}"
except (ValueError, TypeError, KeyError) as exc:
print(
f"Error ending stream for user {username}: {str(exc)}",
flush=True)
return False, f"Error ending stream: {str(exc)}"
def get_category_id(category_name: str) -> Optional[int]:
"""
@@ -130,8 +149,8 @@ def get_category_id(category_name: str) -> Optional[int]:
"""
with Database() as db:
data = db.fetchone("""
SELECT category_id
FROM categories
SELECT category_id
FROM categories
WHERE category_name = ?;
""", (category_name,))
return data['category_id'] if data else None
@@ -142,7 +161,7 @@ def get_custom_thumbnail_status(user_id: int) -> Optional[dict]:
"""
with Database() as db:
custom_thumbnail = db.fetchone("""
SELECT custom_thumbnail
SELECT custom_thumbnail
FROM streams
WHERE user_id = ?;
""", (user_id,))
@@ -153,7 +172,13 @@ def get_vod(vod_id: int) -> dict:
Returns data of a streamers vod
"""
with Database() as db:
vod = db.fetchone("""SELECT vods.*, username, category_name FROM vods JOIN users ON vods.user_id = users.user_id JOIN categories ON vods.category_id = categories.category_id WHERE vod_id = ?;""", (vod_id,))
vod = db.fetchone(
"""SELECT vods.*, username, category_name
FROM vods
JOIN users ON vods.user_id = users.user_id
JOIN categories
ON vods.category_id = categories.category_id
WHERE vod_id = ?;""", (vod_id,))
return vod
def get_latest_vod(user_id: int):
@@ -161,7 +186,14 @@ def get_latest_vod(user_id: int):
Returns data of the most recent stream by a streamer
"""
with Database() as db:
latest_vod = db.fetchone("""SELECT vods.*, username, category_name FROM vods JOIN users ON vods.user_id = users.user_id JOIN categories ON vods.category_id = categories.category_id WHERE vods.user_id = ? ORDER BY vod_id DESC;""", (user_id,))
latest_vod = db.fetchone(
"""SELECT vods.*, username, category_name
FROM vods
JOIN users ON vods.user_id = users.user_id
JOIN categories
ON vods.category_id = categories.category_id
WHERE vods.user_id = ?
ORDER BY vod_id DESC;""", (user_id,))
return latest_vod
def get_user_vods(user_id: int):
@@ -169,10 +201,17 @@ def get_user_vods(user_id: int):
Returns data of all vods by a streamer
"""
with Database() as db:
vods = db.fetchall("""SELECT vods.*, username, category_name FROM vods JOIN users ON vods.user_id = users.user_id JOIN categories ON vods.category_id = categories.category_id WHERE vods.user_id = ? ORDER BY vod_id DESC;""", (user_id,))
vods = db.fetchall(
"""SELECT vods.*, username, category_name
FROM vods
JOIN users ON vods.user_id = users.user_id
JOIN categories
ON vods.category_id = categories.category_id
WHERE vods.user_id = ?
ORDER BY vod_id DESC;""", (user_id,))
return vods
def generate_thumbnail(stream_file: str, thumbnail_file: str, second_capture) -> None:
def generate_thumbnail(stream_file: str, thumbnail_file: str, second_capture=0) -> None:
"""
Generates the thumbnail of a stream
"""
@@ -192,10 +231,16 @@ def generate_thumbnail(stream_file: str, thumbnail_file: str, second_capture) ->
]
try:
subprocess.run(thumbnail_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True)
subprocess.run(
thumbnail_command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
check=True)
print(f"Thumbnail generated for {stream_file}")
except subprocess.CalledProcessError as e:
print(f"No information available for {stream_file}, aborting thumbnail generation")
except subprocess.CalledProcessError:
print(
f"No information available for {stream_file}, "
"aborting thumbnail generation")
def remove_hls_files(stream_path: str) -> None:
"""
@@ -221,11 +266,13 @@ def get_video_duration(video_path: str) -> int:
]
try:
video_length = subprocess.check_output(video_length_command).decode("utf-8")
video_length = subprocess.check_output(
video_length_command
).decode("utf-8")
print(f"Video length: {video_length}")
return int(float(video_length))
except subprocess.CalledProcessError as e:
print(f"Error getting video length: {e}")
except subprocess.CalledProcessError:
print("Error getting video length")
return 0
def get_stream_tags(user_id: int) -> Optional[List[str]]:
@@ -234,10 +281,10 @@ def get_stream_tags(user_id: int) -> Optional[List[str]]:
"""
with Database() as db:
tags = db.fetchall("""
SELECT tag_name
SELECT tag_name
FROM tags
JOIN stream_tags ON tags.tag_id = stream_tags.tag_id
WHERE user_id = ?;
WHERE user_id = ?;
""", (user_id,))
return tags
@@ -247,9 +294,9 @@ def get_vod_tags(vod_id: int):
"""
with Database() as db:
tags = db.fetchall("""
SELECT tag_name
SELECT tag_name
FROM tags
JOIN vod_tags ON tags.tag_id = vod_tags.tag_id
WHERE vod_id = ?;
WHERE vod_id = ?;
""", (vod_id,))
return tags
return tags

View File

@@ -1,16 +1,20 @@
from database.database import Database
"""User profile management, following, and subscription utilities."""
from typing import Optional, List
from datetime import datetime, timedelta
from database.database import Database
from dateutil import parser
def get_user_id(username: str) -> Optional[int]:
"""
Returns user_id associated with given username
"""
with Database() as db:
data = db.fetchone("""
SELECT user_id
FROM users
SELECT user_id
FROM users
WHERE username = ?
""", (username,))
return data['user_id'] if data else None
@@ -21,7 +25,7 @@ def get_username(user_id: str) -> Optional[str]:
"""
with Database() as db:
data = db.fetchone("""
SELECT username
SELECT username
FROM users
WHERE user_id = ?
""", (user_id,))
@@ -40,15 +44,16 @@ def update_bio(user_id: int, bio: str):
def has_password(email: str):
"""
Returns if account associated with this email has password, i.e is account from Google OAuth
Returns if account associated with this email has password,
i.e is account from Google OAuth
"""
with Database() as db:
data = db.fetchone("""
SELECT password
FROM users
WHERE email = ?
WHERE email = ?
""", (email,))
return False if data["password"] == None else True
return data["password"] is not None
def get_session_info_email(email: str) -> dict:
"""
@@ -61,15 +66,15 @@ def get_session_info_email(email: str) -> dict:
WHERE email = ?
""", (email,))
return session_info
def is_user_partner(user_id: int) -> bool:
"""
Returns True if user is a partner, else False
"""
with Database() as db:
data = db.fetchone("""
SELECT is_partnered
FROM users
SELECT is_partnered
FROM users
WHERE user_id = ?
""", (user_id,))
return bool(data)
@@ -79,12 +84,12 @@ def is_subscribed(user_id: int, subscribed_to_id: int) -> bool:
with Database() as db:
return bool(db.fetchone(
"""
SELECT 1
FROM subscribes
WHERE user_id = ?
AND subscribed_id = ?
SELECT 1
FROM subscribes
WHERE user_id = ?
AND subscribed_id = ?
AND expires > ?;
""",
""",
(user_id, subscribed_to_id, datetime.now())
))
@@ -94,9 +99,9 @@ def is_following(user_id: int, followed_id: int) -> bool:
"""
with Database() as db:
result = db.fetchone("""
SELECT 1
FROM follows
WHERE user_id = ?
SELECT 1
FROM follows
WHERE user_id = ?
AND followed_id = ?
""", (user_id, followed_id))
return bool(result)
@@ -107,7 +112,7 @@ def follow(user_id: int, followed_id: int):
"""
if is_following(user_id, followed_id):
return {"success": False, "error": "Already following user"}, 400
with Database() as db:
db.execute("""
INSERT INTO follows (user_id, followed_id)
@@ -135,9 +140,9 @@ def is_following_category(user_id: int, category_id: str):
"""
with Database() as db:
result = db.fetchone("""
SELECT 1
FROM followed_categories
WHERE user_id = ?
SELECT 1
FROM followed_categories
WHERE user_id = ?
AND category_id = ?
""", (user_id, category_id))
return bool(result)
@@ -148,7 +153,7 @@ def follow_category(user_id: int, category_id: str):
"""
if is_following_category(user_id, category_id):
return {"success": False, "error": "Already following category"}, 400
with Database() as db:
db.execute("""
INSERT INTO followed_categories (user_id, category_id)
@@ -163,7 +168,7 @@ def unfollow_category(user_id: int, category_id: str):
"""
if not is_following_category(user_id, category_id):
return {"success": False, "error": "Not following category"}, 400
with Database() as db:
db.execute("""
DELETE FROM followed_categories
@@ -179,8 +184,8 @@ def subscribe(user_id: int, streamer_id: int):
# If user is already subscribed then extend the expiration date else create a new entry
with Database() as db:
existing = db.fetchone("""
SELECT expires
FROM subscribes
SELECT expires
FROM subscribes
WHERE user_id = ? AND subscribed_id = ?
""", (user_id, streamer_id))
if existing:
@@ -188,7 +193,7 @@ def subscribe(user_id: int, streamer_id: int):
UPDATE subscribes SET expires = expires + ?
WHERE user_id = ? AND subscribed_id = ?
""", (timedelta(days=30), user_id, streamer_id))
else:
else:
db.execute("""
INSERT INTO subscribes
(user_id, subscribed_id, since, expires)
@@ -212,10 +217,10 @@ def subscription_expiration(user_id: int, subscribed_id: int) -> int:
"""
with Database() as db:
data = db.fetchone("""
SELECT expires
FROM subscribes
WHERE user_id = ?
AND subscribed_id = ?
SELECT expires
FROM subscribes
WHERE user_id = ?
AND subscribed_id = ?
AND expires > ?
""", (user_id, subscribed_id, datetime.now()))
@@ -227,13 +232,14 @@ def subscription_expiration(user_id: int, subscribed_id: int) -> int:
return 0
def get_email(user_id: int) -> Optional[str]:
"""Returns the email address for a given user_id."""
with Database() as db:
email = db.fetchone("""
SELECT email
FROM users
WHERE user_id = ?
""", (user_id,))
return email["email"] if email else None
def get_followed_streamers(user_id: int) -> Optional[List[dict]]:
@@ -289,4 +295,4 @@ def get_user(user_id: int) -> Optional[dict]:
SELECT user_id, username, bio, num_followers, is_partnered, is_live FROM users
WHERE user_id = ?;
""", (user_id,))
return data
return data

View File

@@ -1,14 +1,18 @@
from database.database import Database
"""Input sanitization and validation utilities."""
from typing import Optional, List
from re import match
from database.database import Database
def get_all_categories() -> Optional[List[dict]]:
"""
Returns all possible streaming categories
"""
with Database() as db:
all_categories = db.fetchall("SELECT * FROM categories")
return all_categories
def get_all_tags() -> Optional[List[dict]]:
@@ -17,7 +21,7 @@ def get_all_tags() -> Optional[List[dict]]:
"""
with Database() as db:
all_tags = db.fetchall("SELECT * FROM tags")
return all_tags
def get_most_popular_category() -> Optional[List[dict]]:
@@ -34,7 +38,7 @@ def get_most_popular_category() -> Optional[List[dict]]:
ORDER BY SUM(streams.num_viewers) DESC
LIMIT 1;
""")
return category
def get_category_id(category_name: str):
@@ -53,7 +57,7 @@ def get_category_id(category_name: str):
def sanitize(user_input: str, input_type="default") -> str:
"""
Sanitizes user input based on the specified input type.
`input_type`: The type of input to sanitize (e.g., 'username', 'email', 'password').
"""
# Strip leading and trailing whitespace
@@ -84,10 +88,10 @@ def sanitize(user_input: str, input_type="default") -> str:
}
# Get the validation rules for the specified type
r = rules.get(input_type)
if not r or \
not (r["min_length"] <= len(sanitised_input) <= r["max_length"]) or \
not match(r["pattern"], sanitised_input):
rule = rules.get(input_type)
if (not rule
or not (rule["min_length"] <= len(sanitised_input) <= rule["max_length"])
or not match(rule["pattern"], sanitised_input)):
raise ValueError("Unaccepted character or length in input")
return sanitised_input
return sanitised_input