From 2758be86800972af3307953c10691e16734ef6cc Mon Sep 17 00:00:00 2001 From: Christopher Ahern <115749466+Chris-1010@users.noreply.github.com> Date: Sat, 7 Feb 2026 20:57:28 +0000 Subject: [PATCH] Fix/pylint cleanup (#8) * Fix pylint warnings across all 24 Python files in web_server - Add module, class, and function docstrings (C0114, C0115, C0116) - Fix import ordering: stdlib before third-party before local (C0411) - Replace wildcard imports with explicit named imports (W0401) - Remove trailing whitespace and add missing final newlines (C0303, C0304) - Replace dict() with dict literals (R1735) - Remove unused imports and variables (W0611, W0612) - Narrow broad Exception catches to specific exceptions (W0718) - Replace f-string logging with lazy % formatting (W1203) - Fix variable naming: UPPER_CASE for constants, snake_case for locals (C0103) - Add pylint disable comments for necessary global statements (W0603) - Fix no-else-return, simplifiable-if-expression, singleton-comparison - Fix bad indentation in stripe.py (W0311) - Add encoding="utf-8" to open() calls (W1514) - Add check=True to subprocess.run() calls (W1510) - Register Celery task modules via conf.include * Update `package-lock.json` add peer dependencies --- frontend/package-lock.json | 24 ++- web_server/blueprints/__init__.py | 19 +-- web_server/blueprints/admin.py | 8 +- web_server/blueprints/authentication.py | 79 +++++---- web_server/blueprints/chat.py | 46 +++--- web_server/blueprints/middleware.py | 11 +- web_server/blueprints/oauth.py | 52 +++--- web_server/blueprints/search_bar.py | 16 +- web_server/blueprints/socket.py | 6 +- web_server/blueprints/streams.py | 153 ++++++++++-------- web_server/blueprints/stripe.py | 50 ++++-- web_server/blueprints/user.py | 41 +++-- web_server/celery_tasks/__init__.py | 11 +- web_server/celery_tasks/celery_app.py | 4 +- web_server/celery_tasks/preferences.py | 36 +++-- web_server/celery_tasks/streaming.py | 34 ++-- web_server/database/database.py | 9 +- web_server/utils/admin_utils.py | 15 +- web_server/utils/auth.py | 13 +- web_server/utils/email.py | 46 +++--- web_server/utils/path_manager.py | 29 ++-- web_server/utils/recommendation_utils.py | 104 ++++++++---- web_server/utils/stream_utils.py | 197 ++++++++++++++--------- web_server/utils/user_utils.py | 72 +++++---- web_server/utils/utils.py | 24 +-- 25 files changed, 680 insertions(+), 419 deletions(-) diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 93ebb29..e4d9ddf 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -1,12 +1,12 @@ { "name": "frontend", - "version": "0.5.0", + "version": "1.0.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "frontend", - "version": "0.5.0", + "version": "1.0.0", "dependencies": { "@stripe/react-stripe-js": "^3.1.1", "@stripe/stripe-js": "^5.5.0", @@ -97,6 +97,7 @@ "integrity": "sha512-i1SLeK+DzNnQ3LL/CswPCa/E5u4lh1k6IAEphON8F+cXt0t9euTshDru0q7/IqMa1PMPz5RnHuHscF8/ZJsStg==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@ampproject/remapping": "^2.2.0", "@babel/code-frame": "^7.26.0", @@ -1393,6 +1394,7 @@ "resolved": "https://registry.npmjs.org/@stripe/stripe-js/-/stripe-js-5.5.0.tgz", "integrity": "sha512-lkfjyAd34aeMpTKKcEVfy8IUyEsjuAT3t9EXr5yZDtdIUncnZpedl/xLV16Dkd4z+fQwixScsCCDxSMNtBOgpQ==", "license": "MIT", + "peer": true, "engines": { "node": ">=12.16" } @@ -1482,6 +1484,7 @@ "integrity": "sha512-t4yC+vtgnkYjNSKlFx1jkAhH8LgTo2N/7Qvi83kdEaUtMDiwpbLAktKDaAMlRcJ5eSxZkH74eEGt1ky31d7kfQ==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@types/prop-types": "*", "csstype": "^3.0.2" @@ -1539,6 +1542,7 @@ "integrity": "sha512-gKXG7A5HMyjDIedBi6bUrDcun8GIjnI8qOwVLiY3rx6T/sHP/19XLJOnIq/FgQvWLHja5JN/LSE7eklNBr612g==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@typescript-eslint/scope-manager": "8.20.0", "@typescript-eslint/types": "8.20.0", @@ -1805,6 +1809,7 @@ "integrity": "sha512-cl669nCJTZBsL97OF4kUQm5g5hC2uihk0NxY3WENAC0TYdILVkAyHymAntgxGkl7K+t0cXIrH5siy5S4XkFycA==", "dev": true, "license": "MIT", + "peer": true, "bin": { "acorn": "bin/acorn" }, @@ -2018,6 +2023,7 @@ } ], "license": "MIT", + "peer": true, "dependencies": { "caniuse-lite": "^1.0.30001688", "electron-to-chromium": "^1.5.73", @@ -2051,9 +2057,9 @@ } }, "node_modules/caniuse-lite": { - "version": "1.0.30001695", - "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001695.tgz", - "integrity": "sha512-vHyLade6wTgI2u1ec3WQBxv+2BrTERV28UXQu9LO6lZ9pYeMk34vjXFLOxo1A4UBA8XTL4njRQZdno/yYaSmWw==", + "version": "1.0.30001769", + "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001769.tgz", + "integrity": "sha512-BCfFL1sHijQlBGWBMuJyhZUhzo7wer5sVj9hqekB/7xn0Ypy+pER/edCYQm4exbXj4WiySGp40P8UuTh6w1srg==", "dev": true, "funding": [ { @@ -2386,6 +2392,7 @@ "integrity": "sha512-+waTfRWQlSbpt3KWE+CjrPPYnbq9kfZIYUqapc0uBXyjTp8aYXZDsUH16m39Ryq3NjAVP4tjuF7KaukeqoCoaA==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@eslint-community/eslint-utils": "^4.2.0", "@eslint-community/regexpp": "^4.12.1", @@ -3534,6 +3541,7 @@ } ], "license": "MIT", + "peer": true, "dependencies": { "nanoid": "^3.3.8", "picocolors": "^1.1.1", @@ -3723,6 +3731,7 @@ "resolved": "https://registry.npmjs.org/react/-/react-18.3.1.tgz", "integrity": "sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==", "license": "MIT", + "peer": true, "dependencies": { "loose-envify": "^1.1.0" }, @@ -3741,6 +3750,7 @@ "resolved": "https://registry.npmjs.org/react-dom/-/react-dom-18.3.1.tgz", "integrity": "sha512-5m4nQKp+rZRb09LNH59GM4BxTh9251/ylbKIbpe7TpGxfJ+9kv6BLkLBXIjjspbgbnIBNqlI23tRnTWT0snUIw==", "license": "MIT", + "peer": true, "dependencies": { "loose-envify": "^1.1.0", "scheduler": "^0.23.2" @@ -4233,6 +4243,7 @@ "resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-3.4.17.tgz", "integrity": "sha512-w33E2aCvSDP0tW9RZuNXadXlkHXqFzSkQew/aIa2i/Sj8fThxwovwlXHSPXTbAHwEIhBFXAedUhP2tueAKP8Og==", "license": "MIT", + "peer": true, "dependencies": { "@alloc/quick-lru": "^5.2.0", "arg": "^5.0.2", @@ -4351,6 +4362,7 @@ "integrity": "sha512-hjcS1mhfuyi4WW8IWtjP7brDrG2cuDZukyrYrSauoXGNgx0S7zceP07adYkJycEr56BOUTNPzbInooiN3fn1qw==", "dev": true, "license": "Apache-2.0", + "peer": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" @@ -4443,6 +4455,7 @@ "resolved": "https://registry.npmjs.org/video.js/-/video.js-8.21.0.tgz", "integrity": "sha512-zcwerRb257QAuWfi8NH9yEX7vrGKFthjfcONmOQ4lxFRpDAbAi+u5LAjCjMWqhJda6zEmxkgdDpOMW3Y21QpXA==", "license": "Apache-2.0", + "peer": true, "dependencies": { "@babel/runtime": "^7.12.5", "@videojs/http-streaming": "^3.16.2", @@ -4495,6 +4508,7 @@ "integrity": "sha512-4VL9mQPKoHy4+FE0NnRE/kbY51TOfaknxAjt3fJbGJxhIpBZiqVzlZDEesWWsuREXHwNdAoOFZ9MkPEVXczHwg==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "esbuild": "^0.24.2", "postcss": "^8.4.49", diff --git a/web_server/blueprints/__init__.py b/web_server/blueprints/__init__.py index 96507db..a8aeaa1 100644 --- a/web_server/blueprints/__init__.py +++ b/web_server/blueprints/__init__.py @@ -1,3 +1,7 @@ +"""Flask application factory and blueprint registration.""" + +from os import getenv + from flask import Flask from flask_session import Session from flask_cors import CORS @@ -13,10 +17,8 @@ from blueprints.oauth import oauth_bp, init_oauth from blueprints.socket import socketio from blueprints.search_bar import search_bp -from celery import Celery from celery_tasks import celery_init_app -from os import getenv def create_app(): """ @@ -34,16 +36,15 @@ def create_app(): app.config['GOOGLE_CLIENT_SECRET'] = getenv("GOOGLE_CLIENT_SECRET") app.config["SESSION_COOKIE_HTTPONLY"] = True - app.config.from_mapping( - CELERY=dict( - broker_url="redis://redis:6379/0", - result_backend="redis://redis:6379/0", - task_ignore_result=True, - ), + CELERY={ + "broker_url": "redis://redis:6379/0", + "result_backend": "redis://redis:6379/0", + "task_ignore_result": True, + }, ) app.config.from_prefixed_env() - celery = celery_init_app(app) + celery_init_app(app) #! ↓↓↓ For development purposes only - Allow cross-origin requests for the frontend CORS(app, supports_credentials=True) diff --git a/web_server/blueprints/admin.py b/web_server/blueprints/admin.py index 3493a60..f85b1e1 100644 --- a/web_server/blueprints/admin.py +++ b/web_server/blueprints/admin.py @@ -1,6 +1,8 @@ +"""Admin blueprint for user management operations.""" + from flask import Blueprint, session from utils.utils import sanitize -from utils.admin_utils import * +from utils.admin_utils import check_if_admin, check_if_user_exists, ban_user admin_bp = Blueprint("admin", __name__) @@ -18,10 +20,10 @@ def admin_delete_user(banned_user): # Check if the user is an admin username = session.get("username") is_admin = check_if_admin(username) - + # Check if the user exists user_exists = check_if_user_exists(banned_user) # If the user is an admin, try to delete the account if is_admin and user_exists: - ban_user(banned_user) \ No newline at end of file + ban_user(banned_user) diff --git a/web_server/blueprints/authentication.py b/web_server/blueprints/authentication.py index 40fc0d0..26bfa1e 100644 --- a/web_server/blueprints/authentication.py +++ b/web_server/blueprints/authentication.py @@ -1,3 +1,8 @@ +"""Authentication blueprint for user signup, login, and logout.""" + +import logging +from secrets import token_hex + from flask import Blueprint, session, request, jsonify from werkzeug.security import generate_password_hash, check_password_hash from flask_cors import cross_origin @@ -5,18 +10,21 @@ from database.database import Database from blueprints.middleware import login_required from utils.user_utils import get_user_id from utils.utils import sanitize -from secrets import token_hex from utils.path_manager import PathManager auth_bp = Blueprint("auth", __name__) path_manager = PathManager() +logger = logging.getLogger(__name__) + + @auth_bp.route("/signup", methods=["POST"]) @cross_origin(supports_credentials=True) def signup(): """ - Route that allows a user to sign up by providing a `username`, `email` and `password`. + Route that allows a user to sign up by providing a + `username`, `email` and `password`. """ # ensure a JSON request is made to contact this route if not request.is_json: @@ -30,19 +38,21 @@ def signup(): # Validation - ensure all fields exist, users cannot have an empty field if not all([username, email, password]): - error_fields = get_error_fields([username, email, password]) #!←← find the error_fields, to highlight them in red to the user on the frontend + error_fields = get_error_fields( + [username, email, password] + ) return jsonify({ "account_created": False, "error_fields": error_fields, "message": "Missing required fields" }), 400 - + # Sanitize the inputs - helps to prevent SQL injection try: username = sanitize(username, "username") email = sanitize(email, "email") password = sanitize(password, "password") - except ValueError as e: + except ValueError: error_fields = get_error_fields([username, email, password]) return jsonify({ "account_created": False, @@ -81,7 +91,7 @@ def signup(): # Create new user once input is validated db.execute( - """INSERT INTO users + """INSERT INTO users (username, password, email, stream_key) VALUES (?, ?, ?, ?)""", ( @@ -100,16 +110,16 @@ def signup(): "message": "Account created successfully" }), 201 - except Exception as e: - print(f"Error during signup: {e}") # Log the error + except (ValueError, TypeError, KeyError) as exc: + logger.error("Error during signup: %s", exc) return jsonify({ "account_created": False, - "message": "Server error occurred: " + str(e) + "message": "Server error occurred: " + str(exc) }), 500 finally: db.close_connection() - + @auth_bp.route("/login", methods=["POST"]) @cross_origin(supports_credentials=True) @@ -127,24 +137,24 @@ def login(): username = data.get('username') password = data.get('password') - # Validation - ensure all fields exist, users cannot have an empty field + # Validation - ensure all fields exist, users cannot have an empty field if not all([username, password]): return jsonify({ "logged_in": False, "message": "Missing required fields" }), 400 - + # Sanitize the inputs - helps to prevent SQL injection try: username = sanitize(username, "username") password = sanitize(password, "password") - except ValueError as e: + except ValueError: return jsonify({ "account_created": False, "error_fields": ["username", "password"], "message": "Invalid input received" }), 400 - + # Create a connection to the database db = Database() @@ -169,7 +179,7 @@ def login(): "error_fields": ["username", "password"], "message": "Invalid username or password" }), 401 - + # Add user directories for stream data in case they don't exist path_manager.create_user(username) @@ -177,7 +187,10 @@ def login(): session.clear() session["username"] = username session["user_id"] = get_user_id(username) - print(f"Logged in as {username}. session: {session.get('username')}. user_id: {session.get('user_id')}", flush=True) + logger.info( + "Logged in as %s. session: %s. user_id: %s", + username, session.get('username'), session.get('user_id') + ) # User has been logged in, let frontend know that return jsonify({ @@ -186,8 +199,8 @@ def login(): "username": username }), 200 - except Exception as e: - print(f"Error during login: {e}") # Log the error + except (ValueError, TypeError, KeyError) as exc: + logger.error("Error during login: %s", exc) return jsonify({ "logged_in": False, "message": "Server error occurred" @@ -202,31 +215,41 @@ def login(): def logout() -> dict: """ Log out and clear the users session. - + If the user is currently streaming, end their stream first. Can only be accessed by a logged in user. """ - from database.database import Database from utils.stream_utils import end_user_stream - + # Check if user is currently streaming user_id = session.get("user_id") username = session.get("username") - + with Database() as db: - is_streaming = db.fetchone("""SELECT is_live FROM users WHERE user_id = ?""", (user_id,)) - + is_streaming = db.fetchone( + """SELECT is_live FROM users WHERE user_id = ?""", + (user_id,) + ) + if is_streaming and is_streaming.get("is_live") == 1: # Get the user's stream key - stream_key_info = db.fetchone("""SELECT stream_key FROM users WHERE user_id = ?""", (user_id,)) - stream_key = stream_key_info.get("stream_key") if stream_key_info else None - + stream_key_info = db.fetchone( + """SELECT stream_key FROM users WHERE user_id = ?""", + (user_id,) + ) + stream_key = ( + stream_key_info.get("stream_key") if stream_key_info + else None + ) + if stream_key: # End the stream end_user_stream(stream_key, user_id, username) session.clear() return {"logged_in": False} + def get_error_fields(values: list): + """Return field names for empty values.""" fields = ["username", "email", "password"] - return [fields[i] for i, v in enumerate(values) if not v] \ No newline at end of file + return [fields[i] for i, v in enumerate(values) if not v] diff --git a/web_server/blueprints/chat.py b/web_server/blueprints/chat.py index 426a8c5..4a0bba0 100644 --- a/web_server/blueprints/chat.py +++ b/web_server/blueprints/chat.py @@ -1,14 +1,17 @@ +"""Chat blueprint for WebSocket-based real-time messaging.""" + +import json +from datetime import datetime + from flask import Blueprint, jsonify from database.database import Database from .socket import socketio from flask_socketio import emit, join_room, leave_room -from datetime import datetime -from utils.user_utils import get_user_id, is_subscribed +from utils.user_utils import is_subscribed import redis -import json -redis_url = "redis://redis:6379/1" -r = redis.from_url(redis_url, decode_responses=True) +REDIS_URL = "redis://redis:6379/1" +r = redis.from_url(REDIS_URL, decode_responses=True) chat_bp = Blueprint("chat", __name__) @socketio.on("connect") @@ -32,7 +35,7 @@ def handle_join(data) -> None: join_room(stream_id) num_viewers = len(list(socketio.server.manager.get_participants("/", stream_id))) update_viewers(stream_id, num_viewers) - emit("status", + emit("status", { "message": f"Welcome to the chat, stream_id: {stream_id}", "num_viewers": num_viewers @@ -53,7 +56,7 @@ def handle_leave(data) -> None: remove_favourability_entry(str(data["user_id"]), str(stream_id)) num_viewers = len(list(socketio.server.manager.get_participants("/", stream_id))) update_viewers(stream_id, num_viewers) - emit("status", + emit("status", { "message": f"Welcome to the chat, stream_id: {stream_id}", "num_viewers": num_viewers @@ -78,10 +81,10 @@ def get_past_chat(stream_id: int): all_chats = db.fetchall(""" SELECT user_id, username, message, time_sent, is_subscribed FROM ( - SELECT + SELECT u.user_id, - u.username, - c.message, + u.username, + c.message, c.time_sent, CASE WHEN s.user_id IS NOT NULL AND s.expires > CURRENT_TIMESTAMP THEN 1 @@ -101,8 +104,8 @@ def get_past_chat(stream_id: int): # Create JSON output of chat_history to pass through NGINX proxy chat_history = [{"chatter_id": chat["user_id"], - "chatter_username": chat["username"], - "message": chat["message"], + "chatter_username": chat["username"], + "message": chat["message"], "time_sent": chat["time_sent"], "is_subscribed": bool(chat["is_subscribed"])} for chat in all_chats] print(chat_history) @@ -125,7 +128,13 @@ def send_chat(data) -> None: # Input validation - chatter is logged in, message is not empty, stream exists if not all([chatter_name, message, stream_id]): - emit("error", {"error": f"Unable to send a chat. The following info was given: chatter_name={chatter_name}, message={message}, stream_id={stream_id}"}, broadcast=False) + emit("error", { + "error": ( + f"Unable to send a chat. The following info was given: " + f"chatter_name={chatter_name}, message={message}, " + f"stream_id={stream_id}" + ) + }, broadcast=False) return subscribed = is_subscribed(chatter_id, stream_id) # Send the chat message to the client so it can be displayed @@ -161,9 +170,8 @@ def update_viewers(user_id, num_viewers): SET num_viewers = ? WHERE user_id = ?; """, (num_viewers, user_id)) - db.close_connection - -#TODO: Make sure that users entry within Redis is removed if they disconnect from socket + db.close_connection() + def add_favourability_entry(user_id, stream_id): """ Adds entry to Redis that user is watching a streamer @@ -183,7 +191,7 @@ def add_favourability_entry(user_id, stream_id): else: # Creates new entry for user and stream current_viewers[user_id] = [stream_id] - + r.hset("current_viewers", "viewers", json.dumps(current_viewers)) def remove_favourability_entry(user_id, stream_id): @@ -202,9 +210,9 @@ def remove_favourability_entry(user_id, stream_id): if user_id in current_viewers: # Removes specific stream from user current_viewers[user_id] = [stream for stream in current_viewers[user_id] if stream != stream_id] - + # If user is no longer watching any streams if not current_viewers[user_id]: del current_viewers[user_id] - r.hset("current_viewers", "viewers", json.dumps(current_viewers)) \ No newline at end of file + r.hset("current_viewers", "viewers", json.dumps(current_viewers)) diff --git a/web_server/blueprints/middleware.py b/web_server/blueprints/middleware.py index f89c53e..1a0a37c 100644 --- a/web_server/blueprints/middleware.py +++ b/web_server/blueprints/middleware.py @@ -1,7 +1,10 @@ -from flask import redirect, g, session -from functools import wraps +"""Authentication middleware and error handler registration.""" + import logging +from functools import wraps from os import getenv + +from flask import redirect, g, session from dotenv import load_dotenv from database.database import Database @@ -57,5 +60,5 @@ def register_error_handlers(app): for code, message in error_responses.items(): @app.errorhandler(code) def handle_error(error, message=message, code=code): - logging.error(f"Error {code}: {str(error)}") - return {"error": message}, code \ No newline at end of file + logging.error("Error %d: %s", code, str(error)) + return {"error": message}, code diff --git a/web_server/blueprints/oauth.py b/web_server/blueprints/oauth.py index eaeff9b..0b906c8 100644 --- a/web_server/blueprints/oauth.py +++ b/web_server/blueprints/oauth.py @@ -1,15 +1,18 @@ +"""OAuth blueprint for Google authentication.""" + from os import getenv +from secrets import token_hex, token_urlsafe +from random import randint + from authlib.integrations.flask_client import OAuth, OAuthError from flask import Blueprint, jsonify, session, redirect, request from blueprints.user import get_session_info_email from database.database import Database from dotenv import load_dotenv -from secrets import token_hex, token_urlsafe -from random import randint from utils.path_manager import PathManager oauth_bp = Blueprint("oauth", __name__) -google = None +_google = None load_dotenv() url_api = getenv("VITE_API_URL") @@ -23,8 +26,8 @@ def init_oauth(app): Initialise the OAuth functionality. """ oauth = OAuth(app) - global google - google = oauth.register( + global _google # pylint: disable=global-statement + _google = oauth.register( 'google', client_id=app.config['GOOGLE_CLIENT_ID'], client_secret=app.config['GOOGLE_CLIENT_SECRET'], @@ -50,11 +53,11 @@ def login_google(): session["nonce"] = token_urlsafe(16) session["state"] = token_urlsafe(32) session["origin"] = request.args.get("next") - + # Make sure session is saved before redirect session.modified = True - - return google.authorize_redirect( + + return _google.authorize_redirect( redirect_uri=f'{url}/api/google_auth', nonce=session['nonce'], state=session['state'] @@ -70,23 +73,27 @@ def google_auth(): # Check state parameter before authorizing returned_state = request.args.get('state') stored_state = session.get('state') - + if not stored_state or stored_state != returned_state: - print(f"State mismatch: stored={stored_state}, returned={returned_state}", flush=True) + print( + f"State mismatch: stored={stored_state}, " + f"returned={returned_state}", flush=True + ) return jsonify({ - 'error': f"mismatching_state: CSRF Warning! State not equal in request and response.", + 'error': "mismatching_state: CSRF Warning! " + "State not equal in request and response.", 'message': 'Authentication failed' }), 400 - + # State matched, proceed with token authorization - token = google.authorize_access_token() + token = _google.authorize_access_token() # Verify nonce nonce = session.get('nonce') if not nonce: return jsonify({'error': 'Missing nonce in session'}), 400 - user = google.parse_id_token(token, nonce=nonce) + user = _google.parse_id_token(token, nonce=nonce) # Check if email exists to login else create a database entry user_email = user.get("email") @@ -108,7 +115,7 @@ def google_auth(): break db.execute( - """INSERT INTO users + """INSERT INTO users (username, email, stream_key) VALUES (?, ?, ?)""", ( @@ -124,16 +131,19 @@ def google_auth(): origin = session.get("origin", f"{url.replace('/api', '')}") username = user_data["username"] user_id = user_data["user_id"] - + # Clear session and set new data session.clear() session["username"] = username session["user_id"] = user_id - + # Ensure session is saved session.modified = True - - print(f"session: {session.get('username')}. user_id: {session.get('user_id')}", flush=True) + + print( + f"session: {session.get('username')}. " + f"user_id: {session.get('user_id')}", flush=True + ) return redirect(origin) @@ -144,9 +154,9 @@ def google_auth(): 'error': str(e) }), 400 - except Exception as e: + except (ValueError, TypeError, KeyError) as e: print(f"Unexpected Error: {str(e)}", flush=True) return jsonify({ 'message': 'An unexpected error occurred', 'error': str(e) - }), 500 \ No newline at end of file + }), 500 diff --git a/web_server/blueprints/search_bar.py b/web_server/blueprints/search_bar.py index d80d3e0..00c6ee9 100644 --- a/web_server/blueprints/search_bar.py +++ b/web_server/blueprints/search_bar.py @@ -1,3 +1,5 @@ +"""Search bar blueprint for querying categories, users, streams, and VODs.""" + from flask import Blueprint, jsonify, request from database.database import Database from utils.utils import sanitize @@ -20,7 +22,7 @@ def rank_results(query, result): # Assign a score based on the level of the match if query in result: return 0 - elif all(c in charset for c in query): + if all(c in charset for c in query): return 1 return 2 @@ -50,7 +52,7 @@ def search_results(): res_dict.append(c) categories = sorted(res_dict, key=lambda d: d["score"]) categories = categories[:4] - + # 3 users res_dict = [] users = db.fetchall("SELECT user_id, username, is_live FROM users") @@ -63,7 +65,7 @@ def search_results(): users = sorted(res_dict, key=lambda d: d["score"]) users = users[:4] - # 3 streams + # 3 streams res_dict = [] streams = db.fetchall("""SELECT s.user_id, s.title, s.num_viewers, c.category_name, u.username FROM streams AS s @@ -71,7 +73,7 @@ def search_results(): INNER JOIN users AS u ON s.user_id = u.user_id INNER JOIN categories AS c ON s.category_id = c.category_id """) - + for s in streams: key = s.get("title") score = rank_results(query.lower(), key.lower()) @@ -83,7 +85,7 @@ def search_results(): # 3 VODs res_dict = [] - vods = db.fetchall("""SELECT v.vod_id, v.title, u.user_id, u.username + vods = db.fetchall("""SELECT v.vod_id, v.title, u.user_id, u.username FROM vods as v JOIN users as u ON v.user_id = u.user_id""") for v in vods: @@ -98,5 +100,5 @@ def search_results(): db.close_connection() print(query, streams, users, categories, vods, flush=True) - - return jsonify({"streams": streams, "categories": categories, "users": users, "vods": vods}) \ No newline at end of file + + return jsonify({"streams": streams, "categories": categories, "users": users, "vods": vods}) diff --git a/web_server/blueprints/socket.py b/web_server/blueprints/socket.py index 9932947..8dfe2b6 100644 --- a/web_server/blueprints/socket.py +++ b/web_server/blueprints/socket.py @@ -1,10 +1,12 @@ +"""WebSocket configuration using Flask-SocketIO.""" + from flask_socketio import SocketIO socketio = SocketIO( - cors_allowed_origins="*", + cors_allowed_origins="*", async_mode='gevent', logger=False, # Reduce logging engineio_logger=False, # Reduce logging ping_timeout=5000, ping_interval=25000 -) \ No newline at end of file +) diff --git a/web_server/blueprints/streams.py b/web_server/blueprints/streams.py index 96532fc..782df61 100644 --- a/web_server/blueprints/streams.py +++ b/web_server/blueprints/streams.py @@ -1,15 +1,24 @@ +"""Stream and VOD management blueprint.""" + +import json +from datetime import datetime + from flask import Blueprint, session, jsonify, request, redirect -from utils.stream_utils import * -from utils.recommendation_utils import * +from utils.stream_utils import ( + get_category_id, get_current_stream_data, get_streamer_live_status, + get_latest_vod, get_vod, get_user_vods, end_user_stream +) +from utils.recommendation_utils import ( + get_user_preferred_category, get_highest_view_streams, + get_streams_based_on_category, get_highest_view_categories, + get_user_category_recommendations, get_followed_categories_recommendations +) from utils.user_utils import get_user_id from blueprints.middleware import login_required from database.database import Database -from datetime import datetime -from celery_tasks.streaming import update_thumbnail, combine_ts_stream -from dateutil import parser +from celery_tasks.streaming import update_thumbnail from utils.path_manager import PathManager from PIL import Image -import json stream_bp = Blueprint("stream", __name__) @@ -28,12 +37,12 @@ def popular_streams(no_streams) -> list[dict]: Returns a list of streams live now with the highest viewers """ - # Limit the number of streams to MAX_STREAMS - MAX_STREAMS = 100 + # Limit the number of streams to max_streams + max_streams = 100 if no_streams < 1: return jsonify([]) - elif no_streams > MAX_STREAMS: - no_streams = MAX_STREAMS + if no_streams > max_streams: + no_streams = max_streams # Get the highest viewed streams streams = get_highest_view_streams(no_streams) @@ -101,7 +110,7 @@ def popular_categories(no_categories=4, offset=0) -> list[dict]: # Limit the number of categories to 100 if no_categories < 1: return jsonify([]) - elif no_categories > 100: + if no_categories > 100: no_categories = 100 category_data = get_highest_view_categories(no_categories, offset) @@ -135,7 +144,8 @@ def following_categories_streams(): @stream_bp.route('/user//status') def user_live_status(username): """ - Returns a streamer's status, if they are live or not and their most recent stream (as a vod) (their current stream if live) + Returns a streamer's status, if they are live or not and their + most recent stream (as a vod) (their current stream if live) Returns: { "is_live": bool, @@ -146,7 +156,7 @@ def user_live_status(username): user_id = get_user_id(username) # Check if streamer is live and get their most recent vod - is_live = True if get_streamer_live_status(user_id)['is_live'] else False + is_live = bool(get_streamer_live_status(user_id)['is_live']) most_recent_vod = get_latest_vod(user_id) # If there is no most recent vod, set it to None @@ -171,25 +181,24 @@ def user_live_status_direct(username): """ user_id = get_user_id(username) - is_live = True if get_streamer_live_status(user_id)['is_live'] else False + is_live = bool(get_streamer_live_status(user_id)['is_live']) if is_live: return 'ok', 200 - else: - return 'not live', 404 + return 'not live', 404 # VOD Routes @stream_bp.route('/vods/') -def vod(vod_id): +def single_vod(vod_id): """ Returns details about a specific vod """ - vod = get_vod(vod_id) - return jsonify(vod) + vod_data = get_vod(vod_id) + return jsonify(vod_data) @stream_bp.route('/vods/') -def vods(username): +def user_vods(username): """ Returns a JSON of all the vods of a streamer Returns: @@ -204,11 +213,11 @@ def vods(username): "views": int } ] - + """ user_id = get_user_id(username) - vods = get_user_vods(user_id) - return jsonify(vods) + vod_list = get_user_vods(user_id) + return jsonify(vod_list) @stream_bp.route('/vods/all') def get_all_vods(): @@ -216,9 +225,14 @@ def get_all_vods(): Returns data of all VODs by all streamers in a JSON-compatible format """ with Database() as db: - vods = db.fetchall("""SELECT vods.*, username, category_name FROM vods JOIN users ON vods.user_id = users.user_id JOIN categories ON vods.category_id = categories.category_id;""") - - return jsonify(vods) + all_vods = db.fetchall( + """SELECT vods.*, username, category_name + FROM vods + JOIN users ON vods.user_id = users.user_id + JOIN categories ON vods.category_id = categories.category_id;""" + ) + + return jsonify(all_vods) # RTMP Server Routes @@ -232,9 +246,10 @@ def init_stream(): with Database() as db: # Check if valid stream key and user is allowed to stream - user_info = db.fetchone("""SELECT user_id, username, is_live - FROM users - WHERE stream_key = ?""", (stream_key,)) + user_info = db.fetchone( + """SELECT user_id, username, is_live + FROM users + WHERE stream_key = ?""", (stream_key,)) # No user found from stream key if not user_info: @@ -280,25 +295,27 @@ def publish_stream(): username = None with Database() as db: - user_info = db.fetchone("""SELECT user_id, username, is_live - FROM users - WHERE stream_key = ?""", (stream_key,)) + user_info = db.fetchone( + """SELECT user_id, username, is_live + FROM users + WHERE stream_key = ?""", (stream_key,)) if not user_info or user_info.get("is_live"): print( - "Unauthorized. No user found from Stream key or user is already streaming.", flush=True) + "Unauthorized. No user found from Stream key " + "or user is already streaming.", flush=True) return "Unauthorized", 403 user_id = user_info.get("user_id") username = user_info.get("username") # Insert stream into database - db.execute("""INSERT INTO streams (user_id, title, start_time, num_viewers, category_id) - VALUES (?, ?, ?, ?, ?)""", (user_id, - stream_title, - datetime.now(), - 0, - get_category_id(stream_category))) + db.execute( + """INSERT INTO streams + (user_id, title, start_time, num_viewers, category_id) + VALUES (?, ?, ?, ?, ?)""", + (user_id, stream_title, datetime.now(), 0, + get_category_id(stream_category))) # Set user as streaming db.execute("""UPDATE users SET is_live = 1 WHERE user_id = ?""", @@ -306,10 +323,12 @@ def publish_stream(): # Update thumbnail periodically only if a custom thumbnail is not provided if not stream_thumbnail: - update_thumbnail.apply_async((user_id, - path_manager.get_stream_file_path(username), - path_manager.get_current_stream_thumbnail_file_path(username), - THUMBNAIL_GENERATION_INTERVAL), countdown=10) + update_thumbnail.apply_async( + (user_id, + path_manager.get_stream_file_path(username), + path_manager.get_current_stream_thumbnail_file_path(username), + THUMBNAIL_GENERATION_INTERVAL), + countdown=10) return "OK", 200 @@ -332,31 +351,36 @@ def update_stream(): with Database() as db: - user_info = db.fetchone("""SELECT user_id, username, is_live - FROM users - WHERE stream_key = ?""", (stream_key,)) + user_info = db.fetchone( + """SELECT user_id, username, is_live + FROM users + WHERE stream_key = ?""", (stream_key,)) if not user_info or not user_info.get("is_live"): print( - "Unauthorized - No user found from stream key or user is not streaming", flush=True) + "Unauthorized - No user found from stream key " + "or user is not streaming", flush=True) return "Unauthorized", 403 user_id = user_info.get("user_id") username = user_info.get("username") # TODO: Add update to thumbnail here - db.execute("""UPDATE streams - SET title = ?, category_id = ? - WHERE user_id = ?""", (stream_title, get_category_id(stream_category), user_id)) - + db.execute( + """UPDATE streams + SET title = ?, category_id = ? + WHERE user_id = ?""", + (stream_title, get_category_id(stream_category), user_id)) + print("GOT: " + stream_thumbnail, flush=True) - + if stream_thumbnail: # Set custom thumbnail status to true - db.execute("""UPDATE streams - SET custom_thumbnail = ? - WHERE user_id = ?""", (True, user_id)) - + db.execute( + """UPDATE streams + SET custom_thumbnail = ? + WHERE user_id = ?""", (True, user_id)) + # Get thumbnail path thumbnail_path = path_manager.get_current_stream_thumbnail_file_path(username) @@ -390,26 +414,27 @@ def end_stream(): # Get user info from stream key with Database() as db: - user_info = db.fetchone("""SELECT user_id, username - FROM users - WHERE stream_key = ?""", (stream_key,)) - + user_info = db.fetchone( + """SELECT user_id, username + FROM users + WHERE stream_key = ?""", (stream_key,)) + # Return unauthorized if no user found if not user_info: print("Unauthorized - No user found from stream key", flush=True) return "Unauthorized", 403 - + # Get user info user_id = user_info["user_id"] username = user_info["username"] - + # End stream result, message = end_user_stream(stream_key, user_id, username) - + # Return error if stream could not be ended if not result: print(f"Error ending stream: {message}", flush=True) return "Error ending stream", 500 - + print(f"Stream ended: {message}", flush=True) return "Stream ended", 200 diff --git a/web_server/blueprints/stripe.py b/web_server/blueprints/stripe.py index fad186c..fd22c2a 100644 --- a/web_server/blueprints/stripe.py +++ b/web_server/blueprints/stripe.py @@ -1,7 +1,11 @@ +"""Stripe payment integration blueprint.""" + +import os + from flask import Blueprint, request, jsonify, session as s from blueprints.middleware import login_required from utils.user_utils import subscribe -import os, stripe +import stripe stripe_bp = Blueprint("stripe", __name__) @@ -20,7 +24,7 @@ def create_checkout_session(): # Checks to see who is subscribing to who user_id = s.get("user_id") streamer_id = request.args.get("streamer_id") - session = stripe.checkout.Session.create( + checkout_session = stripe.checkout.Session.create( ui_mode = 'embedded', payment_method_types=['card'], line_items=[ @@ -33,19 +37,24 @@ def create_checkout_session(): redirect_on_completion = 'never', client_reference_id = f"{user_id}-{streamer_id}" ) - except Exception as e: + except (ValueError, TypeError, KeyError) as e: return str(e), 500 - return jsonify(clientSecret=session.client_secret) + return jsonify(clientSecret=checkout_session.client_secret) -@stripe_bp.route('/session-status') # check for payment status +@stripe_bp.route('/session-status') # check for payment status def session_status(): - """ - Used to query payment status - """ - session = stripe.checkout.Session.retrieve(request.args.get('session_id')) + """ + Used to query payment status + """ + checkout_session = stripe.checkout.Session.retrieve( + request.args.get('session_id') + ) - return jsonify(status=session.status, customer_email=session.customer_details.email) + return jsonify( + status=checkout_session.status, + customer_email=checkout_session.customer_details.email + ) @stripe_bp.route('/stripe/webhook', methods=['POST']) def stripe_webhook(): @@ -65,13 +74,20 @@ def stripe_webhook(): raise e except stripe.error.SignatureVerificationError as e: raise e - - if event['type'] == "checkout.session.completed": # Handles payment success webhook - session = event['data']['object'] - product_id = stripe.checkout.Session.list_line_items(session['id'])['data'][0]['price']['product'] + + # Handles payment success webhook + if event['type'] == "checkout.session.completed": + checkout_session = event['data']['object'] + product_id = stripe.checkout.Session.list_line_items( + checkout_session['id'] + )['data'][0]['price']['product'] if product_id == subscription: - client_reference_id = session.get("client_reference_id") - user_id, streamer_id = map(int, client_reference_id.split("-")) + client_reference_id = checkout_session.get( + "client_reference_id" + ) + user_id, streamer_id = map( + int, client_reference_id.split("-") + ) subscribe(user_id, streamer_id) - return "Success", 200 \ No newline at end of file + return "Success", 200 diff --git a/web_server/blueprints/user.py b/web_server/blueprints/user.py index 6cf2d00..27883be 100644 --- a/web_server/blueprints/user.py +++ b/web_server/blueprints/user.py @@ -1,17 +1,28 @@ +"""User profile and account management blueprint.""" + from flask import Blueprint, jsonify, session, request -from utils.user_utils import * -from utils.auth import * +from utils.user_utils import ( + get_user_id, get_user, update_bio, is_subscribed, + subscription_expiration, delete_subscription, is_following, + follow, unfollow, get_followed_streamers, get_followed_categories, + is_following_category, follow_category, unfollow_category, + has_password, get_session_info_email +) +from database.database import Database +from utils.auth import verify_token, reset_password from utils.utils import get_category_id from blueprints.middleware import login_required -from utils.email import send_email, forgot_password_body, newsletter_conf, remove_from_newsletter, email_exists +from utils.email import ( + send_email, forgot_password_body, newsletter_conf, + remove_from_newsletter, email_exists +) from utils.path_manager import PathManager -from celery_tasks.streaming import convert_image_to_png import redis from PIL import Image -redis_url = "redis://redis:6379/1" -r = redis.from_url(redis_url, decode_responses=True) +REDIS_URL = "redis://redis:6379/1" +r = redis.from_url(REDIS_URL, decode_responses=True) user_bp = Blueprint("user", __name__) @@ -24,7 +35,7 @@ def user_data(username: str): """ user_id = get_user_id(username) if not user_id: - jsonify({"error": "User not found from username"}), 404 + return jsonify({"error": "User not found from username"}), 404 data = get_user(user_id) return jsonify(data) @@ -48,11 +59,11 @@ def user_profile_picture_save(): """ username = session.get("username") thumbnail_path = path_manager.get_profile_picture_file_path(username) - + # Check if the post request has the file part if 'image' not in request.files: return jsonify({"error": "No image found in request"}), 400 - + # Fetch image, convert to png, and save image = Image.open(request.files['image']) image.convert('RGB') @@ -73,7 +84,7 @@ def user_change_bio(): bio = data.get("bio") update_bio(user_id, bio) return jsonify({"status": "Success"}), 200 - except Exception as e: + except (ValueError, TypeError, KeyError) as e: return jsonify({"error": str(e)}), 400 ## Subscription Routes @@ -186,7 +197,7 @@ def user_login_status(): """ username = session.get("username") user_id = session.get("user_id") - return jsonify({'status': username is not None, + return jsonify({'status': username is not None, 'username': username, 'user_id': user_id}) @@ -198,13 +209,14 @@ def user_forgot_password(email): exists = email_exists(email) password = has_password(email) # Checks if password exists and is not a Google OAuth account - if(exists and password): + if exists and password: send_email(email, lambda: forgot_password_body(email)) return email - return jsonify({"error":"Invalid email or not found"}), 404 + return jsonify({"error": "Invalid email or not found"}), 404 @user_bp.route("/send_newsletter/", methods=["POST"]) def send_newsletter(email): + """Sends a newsletter confirmation email.""" send_email(email, lambda: newsletter_conf(email)) return email @@ -226,6 +238,7 @@ def user_reset_password(token, new_password): @user_bp.route("/user/unsubscribe/", methods=["POST"]) def unsubscribe(token): + """Unsubscribes a user from the newsletter.""" salt = r.get(token) if salt: r.delete(token) @@ -237,4 +250,4 @@ def unsubscribe(token): if email: remove_from_newsletter(email) return jsonify({"message": "unsubscribed from newsletter"}), 200 - return jsonify({"error": "Invalid token"}), 400 \ No newline at end of file + return jsonify({"error": "Invalid token"}), 400 diff --git a/web_server/celery_tasks/__init__.py b/web_server/celery_tasks/__init__.py index 9ea551f..eb3e840 100644 --- a/web_server/celery_tasks/__init__.py +++ b/web_server/celery_tasks/__init__.py @@ -1,7 +1,12 @@ -from celery import Celery, shared_task, Task +"""Celery configuration and Flask app context setup for async tasks.""" + +from celery import Celery, Task + def celery_init_app(app) -> Celery: + """Initialize Celery with Flask application context.""" class FlaskTask(Task): + """Celery task that runs within Flask app context.""" def __call__(self, *args: object, **kwargs: object) -> object: with app.app_context(): return self.run(*args, **kwargs) @@ -14,6 +19,10 @@ def celery_init_app(app) -> Celery: 'schedule': 30.0, }, } + celery_app.conf.include = [ + 'celery_tasks.preferences', + 'celery_tasks.streaming', + ] celery_app.set_default() app.extensions["celery"] = celery_app return celery_app diff --git a/web_server/celery_tasks/celery_app.py b/web_server/celery_tasks/celery_app.py index 36c1c79..b8d80b8 100644 --- a/web_server/celery_tasks/celery_app.py +++ b/web_server/celery_tasks/celery_app.py @@ -1,4 +1,6 @@ +"""Celery app initialization with Flask.""" + from blueprints import create_app flask_app = create_app() -celery_app = flask_app.extensions["celery"] \ No newline at end of file +celery_app = flask_app.extensions["celery"] diff --git a/web_server/celery_tasks/preferences.py b/web_server/celery_tasks/preferences.py index 0d0d10d..cea23f2 100644 --- a/web_server/celery_tasks/preferences.py +++ b/web_server/celery_tasks/preferences.py @@ -1,15 +1,19 @@ +"""Scheduled task for updating user preferences based on stream viewing.""" + +import json + from celery import shared_task from database.database import Database import redis -import json -redis_url = "redis://redis:6379/1" -r = redis.from_url(redis_url, decode_responses=True) +REDIS_URL = "redis://redis:6379/1" +r = redis.from_url(REDIS_URL, decode_responses=True) @shared_task def user_preferences(): """ - Updates users preferences on different stream categories based on the streams they are currently watching + Updates users preferences on different stream categories + based on the streams they are currently watching """ stats = r.hget("current_viewers", "viewers") # If there are any current viewers @@ -21,13 +25,19 @@ def user_preferences(): # For each user and stream combination for stream_id in stream_ids: # Retrieves category associated with stream - current_category = db.fetchone("""SELECT category_id FROM streams - WHERE user_id = ? - """, (stream_id,)) - # If stream is still live then update the user_preferences table to reflect their preferences + current_category = db.fetchone( + """SELECT category_id FROM streams + WHERE user_id = ? + """, (stream_id,)) + # If stream is still live then update the + # user_preferences table to reflect their preferences if current_category: - db.execute("""INSERT INTO user_preferences (user_id,category_id,favourability) - VALUES (?,?,?) - ON CONFLICT(user_id, category_id) - DO UPDATE SET favourability = favourability + 1 - """, (user_id, current_category["category_id"], 1)) \ No newline at end of file + db.execute( + """INSERT INTO user_preferences + (user_id,category_id,favourability) + VALUES (?,?,?) + ON CONFLICT(user_id, category_id) + DO UPDATE SET + favourability = favourability + 1 + """, (user_id, + current_category["category_id"], 1)) diff --git a/web_server/celery_tasks/streaming.py b/web_server/celery_tasks/streaming.py index c1d4e4c..42f4fab 100644 --- a/web_server/celery_tasks/streaming.py +++ b/web_server/celery_tasks/streaming.py @@ -1,11 +1,14 @@ -from celery import Celery, shared_task, Task -from datetime import datetime -from celery_tasks.preferences import user_preferences -from utils.stream_utils import generate_thumbnail, get_streamer_live_status, get_custom_thumbnail_status, remove_hls_files, get_video_duration -from time import sleep -from os import listdir, remove -from utils.path_manager import PathManager +"""Async tasks for stream thumbnail updates, VOD creation, and image conversion.""" + import subprocess +from os import listdir + +from celery import shared_task +from utils.stream_utils import ( + generate_thumbnail, get_streamer_live_status, + get_custom_thumbnail_status, remove_hls_files, get_video_duration +) +from utils.path_manager import PathManager path_manager = PathManager() @@ -16,10 +19,13 @@ def update_thumbnail(user_id, stream_file, thumbnail_file, sleep_time, second_ca """ # Check if stream is still live and custom thumbnail has not been set - if get_streamer_live_status(user_id)['is_live'] and not get_custom_thumbnail_status(user_id)['custom_thumbnail']: + if (get_streamer_live_status(user_id)['is_live'] + and not get_custom_thumbnail_status(user_id)['custom_thumbnail']): print("Updating thumbnail...") generate_thumbnail(stream_file, thumbnail_file) - update_thumbnail.apply_async((user_id, stream_file, thumbnail_file, sleep_time, second_capture), countdown=sleep_time) + update_thumbnail.apply_async( + (user_id, stream_file, thumbnail_file, sleep_time, second_capture), + countdown=sleep_time) else: print(f"Stopping thumbnail updates for stream of {user_id}") @@ -34,12 +40,12 @@ def combine_ts_stream(stream_path, vods_path, vod_file_name, thumbnail_file) -> ts_files.sort() # Create temp file listing all ts files - with open(f"{stream_path}/list.txt", "w") as f: + with open(f"{stream_path}/list.txt", "w", encoding="utf-8") as f: for ts_file in ts_files: f.write(f"file '{ts_file}'\n") - + # Concatenate all ts files into a single vod - + vod_command = [ "ffmpeg", "-f", @@ -53,7 +59,7 @@ def combine_ts_stream(stream_path, vods_path, vod_file_name, thumbnail_file) -> vod_file_path ] - subprocess.run(vod_command) + subprocess.run(vod_command, check=True) # Remove HLS files, even if user is not streaming remove_hls_files(stream_path) @@ -78,4 +84,4 @@ def convert_image_to_png(image_path, png_path): png_path ] - subprocess.run(image_command) \ No newline at end of file + subprocess.run(image_command, check=True) diff --git a/web_server/database/database.py b/web_server/database/database.py index 0f31435..7e2fd5a 100644 --- a/web_server/database/database.py +++ b/web_server/database/database.py @@ -1,7 +1,12 @@ +"""SQLite database connection management with context manager support.""" + import sqlite3 import os + class Database: + """Database wrapper providing connection management and query execution.""" + def __init__(self) -> None: self._db = os.path.join(os.path.abspath(os.path.dirname(__file__)), "app.db") self._conn = None @@ -63,4 +68,6 @@ class Database: if not result: return [] columns = [desc[0] for desc in self.cursor.description] - return [dict(zip(columns, row)) for row in result] if isinstance(result, list) else dict(zip(columns, result)) + if isinstance(result, list): + return [dict(zip(columns, row)) for row in result] + return dict(zip(columns, result)) diff --git a/web_server/utils/admin_utils.py b/web_server/utils/admin_utils.py index 1e6999a..8419307 100644 --- a/web_server/utils/admin_utils.py +++ b/web_server/utils/admin_utils.py @@ -1,8 +1,10 @@ +"""Admin utility functions for user management.""" + from database.database import Database def check_if_admin(username): """ - Returns whether user is admin + Returns whether user is admin """ with Database() as db: is_admin = db.fetchone(""" @@ -10,7 +12,7 @@ def check_if_admin(username): FROM users WHERE username = ?; """, (username,)) - + return bool(is_admin) def check_if_user_exists(banned_user): @@ -19,11 +21,11 @@ def check_if_user_exists(banned_user): """ with Database() as db: user_exists = db.fetchone(""" - SELECT user_id + SELECT user_id FROM users WHERE username = ?;""", (banned_user,)) - + return bool(user_exists) def ban_user(banned_user): @@ -33,6 +35,5 @@ def ban_user(banned_user): with Database() as db: db.execute(""" DELETE FROM users - WHERE username = ?;""", - (banned_user) - ) \ No newline at end of file + WHERE username = ?;""", + (banned_user,)) diff --git a/web_server/utils/auth.py b/web_server/utils/auth.py index ab05beb..e3e1adc 100644 --- a/web_server/utils/auth.py +++ b/web_server/utils/auth.py @@ -1,13 +1,17 @@ +"""Token generation and verification for password resets.""" + +from typing import Optional +from os import getenv + from database.database import Database from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired -from typing import Optional from dotenv import load_dotenv -from os import getenv from werkzeug.security import generate_password_hash load_dotenv() serializer = URLSafeTimedSerializer(getenv("AUTH_SECRET_KEY")) + def generate_token(email, salt_value) -> str: """ Creates a token for password reset @@ -19,7 +23,6 @@ def verify_token(token: str, salt_value) -> Optional[str]: """ Given a token, verifies and decodes it into an email """ - try: email = serializer.loads(token, salt=salt_value, max_age=3600) return email @@ -38,7 +41,7 @@ def reset_password(new_password: str, email: str): """ with Database() as db: db.execute(""" - UPDATE users - SET password = ? + UPDATE users + SET password = ? WHERE email = ? """, (generate_password_hash(new_password), email)) diff --git a/web_server/utils/email.py b/web_server/utils/email.py index 73543f3..0b05a24 100644 --- a/web_server/utils/email.py +++ b/web_server/utils/email.py @@ -1,16 +1,18 @@ +"""Email sending utilities for password reset, account confirmation, and newsletters.""" + import smtplib from email.mime.text import MIMEText - from os import getenv +from secrets import token_hex + from dotenv import load_dotenv from utils.auth import generate_token -from secrets import token_hex from .user_utils import get_session_info_email import redis from database.database import Database -redis_url = "redis://redis:6379/1" -r = redis.from_url(redis_url, decode_responses=True) +REDIS_URL = "redis://redis:6379/1" +r = redis.from_url(REDIS_URL, decode_responses=True) load_dotenv() @@ -23,31 +25,31 @@ def send_email(email, func) -> None: """ # Setup the sender email details - SMTP_SERVER = "smtp.gmail.com" - SMTP_PORT = 587 - SMTP_EMAIL = getenv("EMAIL") - SMTP_PASSWORD = getenv("EMAIL_PASSWORD") + smtp_server = "smtp.gmail.com" + smtp_port = 587 + smtp_email = getenv("EMAIL") + smtp_password = getenv("EMAIL_PASSWORD") # Setup up the receiver details body, subject = func() msg = MIMEText(body, "html") msg["Subject"] = subject - msg["From"] = SMTP_EMAIL + msg["From"] = smtp_email msg["To"] = email # Send the email using smtplib - with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as smtp: + with smtplib.SMTP(smtp_server, smtp_port) as smtp: try: smtp.starttls() # TLS handshake to start the connection - smtp.login(SMTP_EMAIL, SMTP_PASSWORD) + smtp.login(smtp_email, smtp_password) smtp.ehlo() smtp.send_message(msg) - + except TimeoutError: print("Server timed out", flush=True) - except Exception as e: + except smtplib.SMTPException as e: print("Error: ", e, flush=True) def forgot_password_body(email) -> str: @@ -57,7 +59,7 @@ def forgot_password_body(email) -> str: salt = token_hex(32) token = generate_token(email, salt) - token += "R3sET" + token += "R3sET" r.setex(token, 3600, salt) username = (get_session_info_email(email))["username"] @@ -77,7 +79,7 @@ def forgot_password_body(email) -> str:
-

Gander

+

Gander

Password Reset Request

Click the button below to reset your password for your account {username}. This link is valid for 1 hour.

Reset Password @@ -116,7 +118,7 @@ def confirm_account_creation_body(email) -> str:
-

Gander

+

Gander

Confirm Account Creation

Click the button below to create your account. This link is valid for 1 hour.

Create Account @@ -154,7 +156,7 @@ def newsletter_conf(email):
-

Gander

+

Gander

Welcome to the Official Gander Newsletter!

If you are receiving this email, it means that you have been officially added to the Monthly Gander newsletter.

In this newsletter, you will receive updates about: your favourite streamers; important Gander updates; and more!

@@ -169,7 +171,7 @@ def newsletter_conf(email): user_exists = db.fetchone(""" SELECT * FROM newsletter - WHERE email = ?;""", + WHERE email = ?;""", (email,)) print(user_exists, flush=True) @@ -190,7 +192,7 @@ def add_to_newsletter(email): INSERT INTO newsletter (email) VALUES (?); """, (email,)) - + def remove_from_newsletter(email): """ Remove a person from the newsletter database @@ -202,7 +204,7 @@ def remove_from_newsletter(email): DELETE FROM newsletter WHERE email = ?; """, (email,)) - + def email_exists(email): """ Returns whether email exists within database @@ -210,6 +212,6 @@ def email_exists(email): with Database() as db: data = db.fetchone(""" SELECT * FROM users - WHERE email = ? + WHERE email = ? """, (email,)) - return bool(data) \ No newline at end of file + return bool(data) diff --git a/web_server/utils/path_manager.py b/web_server/utils/path_manager.py index efd7e08..2cf8ff8 100644 --- a/web_server/utils/path_manager.py +++ b/web_server/utils/path_manager.py @@ -1,8 +1,11 @@ -"""Description: This file contains the PathManager class which is responsible for managing the paths of the stream data.""" +"""File system path management for user streams, VODs, and profile pictures.""" import os + class PathManager(): + """Manages paths for user stream data, VODs, and profile pictures.""" + def __init__(self) -> None: self.root_path = "user_data" self.vod_directory_name = "vods" @@ -26,7 +29,7 @@ class PathManager(): self._create_directory(os.path.join(self.root_path, username)) vods_path = self.get_vods_path(username) stream_path = self.get_stream_path(username) - + self._create_directory(vods_path) self._create_directory(stream_path) @@ -39,25 +42,33 @@ class PathManager(): os.rmdir(user_path) def get_user_path(self, username): + """Returns the base path for a user's data directory.""" return os.path.join(self.root_path, username) def get_vods_path(self, username): + """Returns the path to a user's VODs directory.""" return os.path.join(self.root_path, username, self.vod_directory_name) - + def get_stream_path(self, username): + """Returns the path to a user's stream directory.""" return os.path.join(self.root_path, username, self.stream_directory_name) - + def get_stream_file_path(self, username): + """Returns the path to a user's stream index file.""" return os.path.join(self.get_stream_path(username), "index.m3u8") - + def get_current_stream_thumbnail_file_path(self, username): + """Returns the path to a user's current stream thumbnail.""" return os.path.join(self.get_stream_path(username), "index.png") - + def get_vod_file_path(self, username, vod_id): + """Returns the path to a specific VOD file.""" return os.path.join(self.get_vods_path(username), f"{vod_id}.mp4") - + def get_vod_thumbnail_file_path(self, username, vod_id): + """Returns the path to a specific VOD thumbnail.""" return os.path.join(self.get_vods_path(username), f"{vod_id}.png") - + def get_profile_picture_file_path(self, username): - return os.path.join(self.root_path, username, self.profile_picture_name) \ No newline at end of file + """Returns the path to a user's profile picture.""" + return os.path.join(self.root_path, username, self.profile_picture_name) diff --git a/web_server/utils/recommendation_utils.py b/web_server/utils/recommendation_utils.py index 4cebff0..17d9d9b 100644 --- a/web_server/utils/recommendation_utils.py +++ b/web_server/utils/recommendation_utils.py @@ -1,32 +1,42 @@ -from database.database import Database +"""Personalized stream and category recommendations based on user preferences.""" + from typing import Optional, List +from database.database import Database + def get_user_preferred_category(user_id: int) -> Optional[int]: """ - Queries user_preferences database to find users favourite streaming category and returns the category + Queries user_preferences database to find users favourite + streaming category and returns the category """ with Database() as db: category = db.fetchone(""" - SELECT category_id - FROM user_preferences - WHERE user_id = ? - ORDER BY favourability DESC + SELECT category_id + FROM user_preferences + WHERE user_id = ? + ORDER BY favourability DESC LIMIT 1 """, (user_id,)) return category["category_id"] if category else None -def get_followed_categories_recommendations(user_id: int, no_streams: int = 4) -> Optional[List[dict]]: +def get_followed_categories_recommendations( + user_id: int, no_streams: int = 4 +) -> Optional[List[dict]]: """ Returns top streams given a user's category following """ with Database() as db: streams = db.fetchall(""" - SELECT u.user_id, title, u.username, num_viewers, category_name - FROM streams + SELECT u.user_id, title, u.username, num_viewers, + category_name + FROM streams JOIN users u ON streams.user_id = u.user_id - JOIN categories ON streams.category_id = categories.category_id - WHERE categories.category_id IN (SELECT category_id FROM followed_categories WHERE user_id = ?) + JOIN categories + ON streams.category_id = categories.category_id + WHERE categories.category_id IN + (SELECT category_id + FROM followed_categories WHERE user_id = ?) ORDER BY num_viewers DESC LIMIT ?; """, (user_id, no_streams)) @@ -40,25 +50,29 @@ def get_followed_your_categories(user_id: int) -> Optional[List[dict]]: categories = db.fetchall(""" SELECT categories.category_name FROM categories - JOIN followed_categories + JOIN followed_categories ON categories.category_id = followed_categories.category_id WHERE followed_categories.user_id = ?; """, (user_id,)) return categories -def get_streams_based_on_category(category_id: int, no_streams: int = 4, offset: int = 0) -> Optional[List[dict]]: +def get_streams_based_on_category( + category_id: int, no_streams: int = 4, offset: int = 0 +) -> Optional[List[dict]]: """ - Queries stream database to get top most viewed streams based on given category + Queries stream database to get top most viewed streams + based on given category """ with Database() as db: streams = db.fetchall(""" - SELECT u.user_id, title, username, num_viewers, c.category_name + SELECT u.user_id, title, username, num_viewers, + c.category_name FROM streams s JOIN users u ON s.user_id = u.user_id JOIN categories c ON s.category_id = c.category_id - WHERE c.category_id = ? - ORDER BY num_viewers DESC + WHERE c.category_id = ? + ORDER BY num_viewers DESC LIMIT ? OFFSET ? """, (category_id, no_streams, offset)) return streams @@ -70,43 +84,63 @@ def get_highest_view_streams(no_streams: int = 4) -> Optional[List[dict]]: """ with Database() as db: data = db.fetchall(""" - SELECT u.user_id, username, title, num_viewers, category_name - FROM streams + SELECT u.user_id, username, title, num_viewers, + category_name + FROM streams JOIN users u ON streams.user_id = u.user_id - JOIN categories ON streams.category_id = categories.category_id - ORDER BY num_viewers DESC + JOIN categories + ON streams.category_id = categories.category_id + ORDER BY num_viewers DESC LIMIT ?; """, (no_streams,)) return data -def get_highest_view_categories(no_categories: int = 4, offset: int = 0) -> Optional[List[dict]]: +def get_highest_view_categories( + no_categories: int = 4, offset: int = 0 +) -> Optional[List[dict]]: """ Returns a list of top most popular categories given offset """ with Database() as db: categories = db.fetchall(""" - SELECT categories.category_id, categories.category_name, COALESCE(SUM(streams.num_viewers), 0) AS num_viewers + SELECT categories.category_id, + categories.category_name, + COALESCE(SUM(streams.num_viewers), 0) + AS num_viewers FROM categories - LEFT JOIN streams ON streams.category_id = categories.category_id - GROUP BY categories.category_id, categories.category_name + LEFT JOIN streams + ON streams.category_id = categories.category_id + GROUP BY categories.category_id, + categories.category_name ORDER BY num_viewers DESC LIMIT ? OFFSET ?; """, (no_categories, offset)) return categories -def get_user_category_recommendations(user_id = 1, no_categories: int = 4) -> Optional[List[dict]]: +def get_user_category_recommendations( + user_id=1, no_categories: int = 4 +) -> Optional[List[dict]]: """ - Queries user_preferences database to find users top favourite streaming category and returns the category + Queries user_preferences database to find users top favourite + streaming category and returns the category """ with Database() as db: categories = db.fetchall(""" - SELECT categories.category_id, categories.category_name, COALESCE(SUM(streams.num_viewers), 0) AS num_viewers - FROM categories - JOIN user_preferences ON categories.category_id = user_preferences.category_id - LEFT JOIN streams ON categories.category_id = streams.category_id - WHERE user_preferences.user_id = ? - GROUP BY categories.category_id, categories.category_name - ORDER BY user_preferences.favourability DESC + SELECT categories.category_id, + categories.category_name, + COALESCE(SUM(streams.num_viewers), 0) + AS num_viewers + FROM categories + JOIN user_preferences + ON categories.category_id + = user_preferences.category_id + LEFT JOIN streams + ON categories.category_id + = streams.category_id + WHERE user_preferences.user_id = ? + GROUP BY categories.category_id, + categories.category_name + ORDER BY user_preferences.favourability DESC LIMIT ? """, (user_id, no_categories)) - return categories \ No newline at end of file + return categories diff --git a/web_server/utils/stream_utils.py b/web_server/utils/stream_utils.py index 1995150..50f0e26 100644 --- a/web_server/utils/stream_utils.py +++ b/web_server/utils/stream_utils.py @@ -1,9 +1,11 @@ -from database.database import Database -from typing import Optional -import os, subprocess +"""Stream data retrieval and management utilities.""" + +import os +import subprocess from typing import Optional, List -from time import sleep -from utils.path_manager import PathManager + +from database.database import Database + def get_streamer_live_status(user_id: int): """ @@ -11,8 +13,8 @@ def get_streamer_live_status(user_id: int): """ with Database() as db: is_live = db.fetchone(""" - SELECT is_live - FROM users + SELECT is_live + FROM users WHERE user_id = ?; """, (user_id,)) @@ -21,17 +23,19 @@ def get_streamer_live_status(user_id: int): def get_followed_live_streams(user_id: int) -> Optional[List[dict]]: """ Searches for streamers who the user followed which are currently live - Returns a list of live streams with the streamer's user id, stream title, and number of viewers + Returns a list of live streams with the streamer's user id, + stream title, and number of viewers """ with Database() as db: live_streams = db.fetchall(""" - SELECT users.user_id, streams.title, streams.num_viewers, users.username - FROM streams JOIN users - ON streams.user_id = users.user_id - WHERE users.user_id IN - (SELECT followed_id FROM follows WHERE user_id = ?) - AND users.is_live = 1; - """, (user_id,)) + SELECT users.user_id, streams.title, + streams.num_viewers, users.username + FROM streams JOIN users + ON streams.user_id = users.user_id + WHERE users.user_id IN + (SELECT followed_id FROM follows WHERE user_id = ?) + AND users.is_live = 1; + """, (user_id,)) return live_streams def get_current_stream_data(user_id: int) -> Optional[dict]: @@ -40,7 +44,8 @@ def get_current_stream_data(user_id: int) -> Optional[dict]: """ with Database() as db: most_recent_stream = db.fetchone(""" - SELECT s.user_id, u.username, s.title, s.start_time, s.num_viewers, c.category_name, c.category_id + SELECT s.user_id, u.username, s.title, s.start_time, + s.num_viewers, c.category_name, c.category_id FROM streams AS s JOIN categories AS c ON s.category_id = c.category_id JOIN users AS u ON s.user_id = u.user_id @@ -51,78 +56,92 @@ def get_current_stream_data(user_id: int) -> Optional[dict]: def end_user_stream(stream_key, user_id, username): """ Utility function to end a user's stream - + Parameters: stream_key: The stream key of the user user_id: The ID of the user username: The username of the user - + Returns: bool: True if stream was ended successfully, False otherwise """ - from flask import current_app from datetime import datetime from dateutil import parser from celery_tasks.streaming import combine_ts_stream from utils.path_manager import PathManager - - path_manager = PathManager() + + pm = PathManager() print(f"Ending stream for user {username} (ID: {user_id})", flush=True) - + if not stream_key or not user_id or not username: print("Cannot end stream - missing required information", flush=True) - return False - + return False, "Missing required information" + try: # Open database connection with Database() as db: # Get stream info - stream_info = db.fetchone("""SELECT * - FROM streams - WHERE user_id = ?""", (user_id,)) - + stream_info = db.fetchone( + """SELECT * + FROM streams + WHERE user_id = ?""", (user_id,)) + # If user is not streaming, just return if not stream_info: - print(f"User {username} (ID: {user_id}) is not streaming", flush=True) + print( + f"User {username} (ID: {user_id}) " + "is not streaming", flush=True) return True, "User is not streaming" - + # Remove stream from database - db.execute("""DELETE FROM streams - WHERE user_id = ?""", (user_id,)) - + db.execute( + """DELETE FROM streams + WHERE user_id = ?""", (user_id,)) + # Move stream to vod table stream_length = int( - (datetime.now() - parser.parse(stream_info.get("start_time"))).total_seconds()) - - db.execute("""INSERT INTO vods (user_id, title, datetime, category_id, length, views) - VALUES (?, ?, ?, ?, ?, ?)""", (user_id, - stream_info.get("title"), - stream_info.get("start_time"), - stream_info.get("category_id"), - stream_length, - 0)) - + (datetime.now() - parser.parse( + stream_info.get("start_time") + )).total_seconds()) + + db.execute( + """INSERT INTO vods + (user_id, title, datetime, category_id, + length, views) + VALUES (?, ?, ?, ?, ?, ?)""", + (user_id, + stream_info.get("title"), + stream_info.get("start_time"), + stream_info.get("category_id"), + stream_length, + 0)) + vod_id = db.get_last_insert_id() - + # Set user as not streaming - db.execute("""UPDATE users - SET is_live = 0 - WHERE user_id = ?""", (user_id,)) - + db.execute( + """UPDATE users + SET is_live = 0 + WHERE user_id = ?""", (user_id,)) + # Queue task to combine TS files into MP4 combine_ts_stream.delay( - path_manager.get_stream_path(username), - path_manager.get_vods_path(username), + pm.get_stream_path(username), + pm.get_vods_path(username), vod_id, - path_manager.get_vod_thumbnail_file_path(username, vod_id) + pm.get_vod_thumbnail_file_path(username, vod_id) ) - - print(f"Stream ended for user {username} (ID: {user_id})", flush=True) + + print( + f"Stream ended for user {username} (ID: {user_id})", + flush=True) return True, "Stream ended successfully" - - except Exception as e: - print(f"Error ending stream for user {username}: {str(e)}", flush=True) - return False, f"Error ending stream: {str(e)}" + + except (ValueError, TypeError, KeyError) as exc: + print( + f"Error ending stream for user {username}: {str(exc)}", + flush=True) + return False, f"Error ending stream: {str(exc)}" def get_category_id(category_name: str) -> Optional[int]: """ @@ -130,8 +149,8 @@ def get_category_id(category_name: str) -> Optional[int]: """ with Database() as db: data = db.fetchone(""" - SELECT category_id - FROM categories + SELECT category_id + FROM categories WHERE category_name = ?; """, (category_name,)) return data['category_id'] if data else None @@ -142,7 +161,7 @@ def get_custom_thumbnail_status(user_id: int) -> Optional[dict]: """ with Database() as db: custom_thumbnail = db.fetchone(""" - SELECT custom_thumbnail + SELECT custom_thumbnail FROM streams WHERE user_id = ?; """, (user_id,)) @@ -153,7 +172,13 @@ def get_vod(vod_id: int) -> dict: Returns data of a streamers vod """ with Database() as db: - vod = db.fetchone("""SELECT vods.*, username, category_name FROM vods JOIN users ON vods.user_id = users.user_id JOIN categories ON vods.category_id = categories.category_id WHERE vod_id = ?;""", (vod_id,)) + vod = db.fetchone( + """SELECT vods.*, username, category_name + FROM vods + JOIN users ON vods.user_id = users.user_id + JOIN categories + ON vods.category_id = categories.category_id + WHERE vod_id = ?;""", (vod_id,)) return vod def get_latest_vod(user_id: int): @@ -161,7 +186,14 @@ def get_latest_vod(user_id: int): Returns data of the most recent stream by a streamer """ with Database() as db: - latest_vod = db.fetchone("""SELECT vods.*, username, category_name FROM vods JOIN users ON vods.user_id = users.user_id JOIN categories ON vods.category_id = categories.category_id WHERE vods.user_id = ? ORDER BY vod_id DESC;""", (user_id,)) + latest_vod = db.fetchone( + """SELECT vods.*, username, category_name + FROM vods + JOIN users ON vods.user_id = users.user_id + JOIN categories + ON vods.category_id = categories.category_id + WHERE vods.user_id = ? + ORDER BY vod_id DESC;""", (user_id,)) return latest_vod def get_user_vods(user_id: int): @@ -169,10 +201,17 @@ def get_user_vods(user_id: int): Returns data of all vods by a streamer """ with Database() as db: - vods = db.fetchall("""SELECT vods.*, username, category_name FROM vods JOIN users ON vods.user_id = users.user_id JOIN categories ON vods.category_id = categories.category_id WHERE vods.user_id = ? ORDER BY vod_id DESC;""", (user_id,)) + vods = db.fetchall( + """SELECT vods.*, username, category_name + FROM vods + JOIN users ON vods.user_id = users.user_id + JOIN categories + ON vods.category_id = categories.category_id + WHERE vods.user_id = ? + ORDER BY vod_id DESC;""", (user_id,)) return vods -def generate_thumbnail(stream_file: str, thumbnail_file: str, second_capture) -> None: +def generate_thumbnail(stream_file: str, thumbnail_file: str, second_capture=0) -> None: """ Generates the thumbnail of a stream """ @@ -192,10 +231,16 @@ def generate_thumbnail(stream_file: str, thumbnail_file: str, second_capture) -> ] try: - subprocess.run(thumbnail_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True) + subprocess.run( + thumbnail_command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=True) print(f"Thumbnail generated for {stream_file}") - except subprocess.CalledProcessError as e: - print(f"No information available for {stream_file}, aborting thumbnail generation") + except subprocess.CalledProcessError: + print( + f"No information available for {stream_file}, " + "aborting thumbnail generation") def remove_hls_files(stream_path: str) -> None: """ @@ -221,11 +266,13 @@ def get_video_duration(video_path: str) -> int: ] try: - video_length = subprocess.check_output(video_length_command).decode("utf-8") + video_length = subprocess.check_output( + video_length_command + ).decode("utf-8") print(f"Video length: {video_length}") return int(float(video_length)) - except subprocess.CalledProcessError as e: - print(f"Error getting video length: {e}") + except subprocess.CalledProcessError: + print("Error getting video length") return 0 def get_stream_tags(user_id: int) -> Optional[List[str]]: @@ -234,10 +281,10 @@ def get_stream_tags(user_id: int) -> Optional[List[str]]: """ with Database() as db: tags = db.fetchall(""" - SELECT tag_name + SELECT tag_name FROM tags JOIN stream_tags ON tags.tag_id = stream_tags.tag_id - WHERE user_id = ?; + WHERE user_id = ?; """, (user_id,)) return tags @@ -247,9 +294,9 @@ def get_vod_tags(vod_id: int): """ with Database() as db: tags = db.fetchall(""" - SELECT tag_name + SELECT tag_name FROM tags JOIN vod_tags ON tags.tag_id = vod_tags.tag_id - WHERE vod_id = ?; + WHERE vod_id = ?; """, (vod_id,)) - return tags \ No newline at end of file + return tags diff --git a/web_server/utils/user_utils.py b/web_server/utils/user_utils.py index a8658a3..e7b37ba 100644 --- a/web_server/utils/user_utils.py +++ b/web_server/utils/user_utils.py @@ -1,16 +1,20 @@ -from database.database import Database +"""User profile management, following, and subscription utilities.""" + from typing import Optional, List from datetime import datetime, timedelta + +from database.database import Database from dateutil import parser + def get_user_id(username: str) -> Optional[int]: """ Returns user_id associated with given username """ with Database() as db: data = db.fetchone(""" - SELECT user_id - FROM users + SELECT user_id + FROM users WHERE username = ? """, (username,)) return data['user_id'] if data else None @@ -21,7 +25,7 @@ def get_username(user_id: str) -> Optional[str]: """ with Database() as db: data = db.fetchone(""" - SELECT username + SELECT username FROM users WHERE user_id = ? """, (user_id,)) @@ -40,15 +44,16 @@ def update_bio(user_id: int, bio: str): def has_password(email: str): """ - Returns if account associated with this email has password, i.e is account from Google OAuth + Returns if account associated with this email has password, + i.e is account from Google OAuth """ with Database() as db: data = db.fetchone(""" SELECT password FROM users - WHERE email = ? + WHERE email = ? """, (email,)) - return False if data["password"] == None else True + return data["password"] is not None def get_session_info_email(email: str) -> dict: """ @@ -61,15 +66,15 @@ def get_session_info_email(email: str) -> dict: WHERE email = ? """, (email,)) return session_info - + def is_user_partner(user_id: int) -> bool: """ Returns True if user is a partner, else False """ with Database() as db: data = db.fetchone(""" - SELECT is_partnered - FROM users + SELECT is_partnered + FROM users WHERE user_id = ? """, (user_id,)) return bool(data) @@ -79,12 +84,12 @@ def is_subscribed(user_id: int, subscribed_to_id: int) -> bool: with Database() as db: return bool(db.fetchone( """ - SELECT 1 - FROM subscribes - WHERE user_id = ? - AND subscribed_id = ? + SELECT 1 + FROM subscribes + WHERE user_id = ? + AND subscribed_id = ? AND expires > ?; - """, + """, (user_id, subscribed_to_id, datetime.now()) )) @@ -94,9 +99,9 @@ def is_following(user_id: int, followed_id: int) -> bool: """ with Database() as db: result = db.fetchone(""" - SELECT 1 - FROM follows - WHERE user_id = ? + SELECT 1 + FROM follows + WHERE user_id = ? AND followed_id = ? """, (user_id, followed_id)) return bool(result) @@ -107,7 +112,7 @@ def follow(user_id: int, followed_id: int): """ if is_following(user_id, followed_id): return {"success": False, "error": "Already following user"}, 400 - + with Database() as db: db.execute(""" INSERT INTO follows (user_id, followed_id) @@ -135,9 +140,9 @@ def is_following_category(user_id: int, category_id: str): """ with Database() as db: result = db.fetchone(""" - SELECT 1 - FROM followed_categories - WHERE user_id = ? + SELECT 1 + FROM followed_categories + WHERE user_id = ? AND category_id = ? """, (user_id, category_id)) return bool(result) @@ -148,7 +153,7 @@ def follow_category(user_id: int, category_id: str): """ if is_following_category(user_id, category_id): return {"success": False, "error": "Already following category"}, 400 - + with Database() as db: db.execute(""" INSERT INTO followed_categories (user_id, category_id) @@ -163,7 +168,7 @@ def unfollow_category(user_id: int, category_id: str): """ if not is_following_category(user_id, category_id): return {"success": False, "error": "Not following category"}, 400 - + with Database() as db: db.execute(""" DELETE FROM followed_categories @@ -179,8 +184,8 @@ def subscribe(user_id: int, streamer_id: int): # If user is already subscribed then extend the expiration date else create a new entry with Database() as db: existing = db.fetchone(""" - SELECT expires - FROM subscribes + SELECT expires + FROM subscribes WHERE user_id = ? AND subscribed_id = ? """, (user_id, streamer_id)) if existing: @@ -188,7 +193,7 @@ def subscribe(user_id: int, streamer_id: int): UPDATE subscribes SET expires = expires + ? WHERE user_id = ? AND subscribed_id = ? """, (timedelta(days=30), user_id, streamer_id)) - else: + else: db.execute(""" INSERT INTO subscribes (user_id, subscribed_id, since, expires) @@ -212,10 +217,10 @@ def subscription_expiration(user_id: int, subscribed_id: int) -> int: """ with Database() as db: data = db.fetchone(""" - SELECT expires - FROM subscribes - WHERE user_id = ? - AND subscribed_id = ? + SELECT expires + FROM subscribes + WHERE user_id = ? + AND subscribed_id = ? AND expires > ? """, (user_id, subscribed_id, datetime.now())) @@ -227,13 +232,14 @@ def subscription_expiration(user_id: int, subscribed_id: int) -> int: return 0 def get_email(user_id: int) -> Optional[str]: + """Returns the email address for a given user_id.""" with Database() as db: email = db.fetchone(""" SELECT email FROM users WHERE user_id = ? """, (user_id,)) - + return email["email"] if email else None def get_followed_streamers(user_id: int) -> Optional[List[dict]]: @@ -289,4 +295,4 @@ def get_user(user_id: int) -> Optional[dict]: SELECT user_id, username, bio, num_followers, is_partnered, is_live FROM users WHERE user_id = ?; """, (user_id,)) - return data \ No newline at end of file + return data diff --git a/web_server/utils/utils.py b/web_server/utils/utils.py index 6bd20b1..314f11b 100644 --- a/web_server/utils/utils.py +++ b/web_server/utils/utils.py @@ -1,14 +1,18 @@ -from database.database import Database +"""Input sanitization and validation utilities.""" + from typing import Optional, List from re import match +from database.database import Database + + def get_all_categories() -> Optional[List[dict]]: """ Returns all possible streaming categories """ with Database() as db: all_categories = db.fetchall("SELECT * FROM categories") - + return all_categories def get_all_tags() -> Optional[List[dict]]: @@ -17,7 +21,7 @@ def get_all_tags() -> Optional[List[dict]]: """ with Database() as db: all_tags = db.fetchall("SELECT * FROM tags") - + return all_tags def get_most_popular_category() -> Optional[List[dict]]: @@ -34,7 +38,7 @@ def get_most_popular_category() -> Optional[List[dict]]: ORDER BY SUM(streams.num_viewers) DESC LIMIT 1; """) - + return category def get_category_id(category_name: str): @@ -53,7 +57,7 @@ def get_category_id(category_name: str): def sanitize(user_input: str, input_type="default") -> str: """ Sanitizes user input based on the specified input type. - + `input_type`: The type of input to sanitize (e.g., 'username', 'email', 'password'). """ # Strip leading and trailing whitespace @@ -84,10 +88,10 @@ def sanitize(user_input: str, input_type="default") -> str: } # Get the validation rules for the specified type - r = rules.get(input_type) - if not r or \ - not (r["min_length"] <= len(sanitised_input) <= r["max_length"]) or \ - not match(r["pattern"], sanitised_input): + rule = rules.get(input_type) + if (not rule + or not (rule["min_length"] <= len(sanitised_input) <= rule["max_length"]) + or not match(rule["pattern"], sanitised_input)): raise ValueError("Unaccepted character or length in input") - return sanitised_input \ No newline at end of file + return sanitised_input