Fix/pylint cleanup (#8)
* Fix pylint warnings across all 24 Python files in web_server - Add module, class, and function docstrings (C0114, C0115, C0116) - Fix import ordering: stdlib before third-party before local (C0411) - Replace wildcard imports with explicit named imports (W0401) - Remove trailing whitespace and add missing final newlines (C0303, C0304) - Replace dict() with dict literals (R1735) - Remove unused imports and variables (W0611, W0612) - Narrow broad Exception catches to specific exceptions (W0718) - Replace f-string logging with lazy % formatting (W1203) - Fix variable naming: UPPER_CASE for constants, snake_case for locals (C0103) - Add pylint disable comments for necessary global statements (W0603) - Fix no-else-return, simplifiable-if-expression, singleton-comparison - Fix bad indentation in stripe.py (W0311) - Add encoding="utf-8" to open() calls (W1514) - Add check=True to subprocess.run() calls (W1510) - Register Celery task modules via conf.include * Update `package-lock.json` add peer dependencies
This commit is contained in:
committed by
GitHub
parent
fed1a2f288
commit
2758be8680
24
frontend/package-lock.json
generated
24
frontend/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "frontend",
|
||||
"version": "0.5.0",
|
||||
"version": "1.0.0",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "frontend",
|
||||
"version": "0.5.0",
|
||||
"version": "1.0.0",
|
||||
"dependencies": {
|
||||
"@stripe/react-stripe-js": "^3.1.1",
|
||||
"@stripe/stripe-js": "^5.5.0",
|
||||
@@ -97,6 +97,7 @@
|
||||
"integrity": "sha512-i1SLeK+DzNnQ3LL/CswPCa/E5u4lh1k6IAEphON8F+cXt0t9euTshDru0q7/IqMa1PMPz5RnHuHscF8/ZJsStg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@ampproject/remapping": "^2.2.0",
|
||||
"@babel/code-frame": "^7.26.0",
|
||||
@@ -1393,6 +1394,7 @@
|
||||
"resolved": "https://registry.npmjs.org/@stripe/stripe-js/-/stripe-js-5.5.0.tgz",
|
||||
"integrity": "sha512-lkfjyAd34aeMpTKKcEVfy8IUyEsjuAT3t9EXr5yZDtdIUncnZpedl/xLV16Dkd4z+fQwixScsCCDxSMNtBOgpQ==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12.16"
|
||||
}
|
||||
@@ -1482,6 +1484,7 @@
|
||||
"integrity": "sha512-t4yC+vtgnkYjNSKlFx1jkAhH8LgTo2N/7Qvi83kdEaUtMDiwpbLAktKDaAMlRcJ5eSxZkH74eEGt1ky31d7kfQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@types/prop-types": "*",
|
||||
"csstype": "^3.0.2"
|
||||
@@ -1539,6 +1542,7 @@
|
||||
"integrity": "sha512-gKXG7A5HMyjDIedBi6bUrDcun8GIjnI8qOwVLiY3rx6T/sHP/19XLJOnIq/FgQvWLHja5JN/LSE7eklNBr612g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@typescript-eslint/scope-manager": "8.20.0",
|
||||
"@typescript-eslint/types": "8.20.0",
|
||||
@@ -1805,6 +1809,7 @@
|
||||
"integrity": "sha512-cl669nCJTZBsL97OF4kUQm5g5hC2uihk0NxY3WENAC0TYdILVkAyHymAntgxGkl7K+t0cXIrH5siy5S4XkFycA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"acorn": "bin/acorn"
|
||||
},
|
||||
@@ -2018,6 +2023,7 @@
|
||||
}
|
||||
],
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"caniuse-lite": "^1.0.30001688",
|
||||
"electron-to-chromium": "^1.5.73",
|
||||
@@ -2051,9 +2057,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/caniuse-lite": {
|
||||
"version": "1.0.30001695",
|
||||
"resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001695.tgz",
|
||||
"integrity": "sha512-vHyLade6wTgI2u1ec3WQBxv+2BrTERV28UXQu9LO6lZ9pYeMk34vjXFLOxo1A4UBA8XTL4njRQZdno/yYaSmWw==",
|
||||
"version": "1.0.30001769",
|
||||
"resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001769.tgz",
|
||||
"integrity": "sha512-BCfFL1sHijQlBGWBMuJyhZUhzo7wer5sVj9hqekB/7xn0Ypy+pER/edCYQm4exbXj4WiySGp40P8UuTh6w1srg==",
|
||||
"dev": true,
|
||||
"funding": [
|
||||
{
|
||||
@@ -2386,6 +2392,7 @@
|
||||
"integrity": "sha512-+waTfRWQlSbpt3KWE+CjrPPYnbq9kfZIYUqapc0uBXyjTp8aYXZDsUH16m39Ryq3NjAVP4tjuF7KaukeqoCoaA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@eslint-community/eslint-utils": "^4.2.0",
|
||||
"@eslint-community/regexpp": "^4.12.1",
|
||||
@@ -3534,6 +3541,7 @@
|
||||
}
|
||||
],
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"nanoid": "^3.3.8",
|
||||
"picocolors": "^1.1.1",
|
||||
@@ -3723,6 +3731,7 @@
|
||||
"resolved": "https://registry.npmjs.org/react/-/react-18.3.1.tgz",
|
||||
"integrity": "sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"loose-envify": "^1.1.0"
|
||||
},
|
||||
@@ -3741,6 +3750,7 @@
|
||||
"resolved": "https://registry.npmjs.org/react-dom/-/react-dom-18.3.1.tgz",
|
||||
"integrity": "sha512-5m4nQKp+rZRb09LNH59GM4BxTh9251/ylbKIbpe7TpGxfJ+9kv6BLkLBXIjjspbgbnIBNqlI23tRnTWT0snUIw==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"loose-envify": "^1.1.0",
|
||||
"scheduler": "^0.23.2"
|
||||
@@ -4233,6 +4243,7 @@
|
||||
"resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-3.4.17.tgz",
|
||||
"integrity": "sha512-w33E2aCvSDP0tW9RZuNXadXlkHXqFzSkQew/aIa2i/Sj8fThxwovwlXHSPXTbAHwEIhBFXAedUhP2tueAKP8Og==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@alloc/quick-lru": "^5.2.0",
|
||||
"arg": "^5.0.2",
|
||||
@@ -4351,6 +4362,7 @@
|
||||
"integrity": "sha512-hjcS1mhfuyi4WW8IWtjP7brDrG2cuDZukyrYrSauoXGNgx0S7zceP07adYkJycEr56BOUTNPzbInooiN3fn1qw==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"tsc": "bin/tsc",
|
||||
"tsserver": "bin/tsserver"
|
||||
@@ -4443,6 +4455,7 @@
|
||||
"resolved": "https://registry.npmjs.org/video.js/-/video.js-8.21.0.tgz",
|
||||
"integrity": "sha512-zcwerRb257QAuWfi8NH9yEX7vrGKFthjfcONmOQ4lxFRpDAbAi+u5LAjCjMWqhJda6zEmxkgdDpOMW3Y21QpXA==",
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@babel/runtime": "^7.12.5",
|
||||
"@videojs/http-streaming": "^3.16.2",
|
||||
@@ -4495,6 +4508,7 @@
|
||||
"integrity": "sha512-4VL9mQPKoHy4+FE0NnRE/kbY51TOfaknxAjt3fJbGJxhIpBZiqVzlZDEesWWsuREXHwNdAoOFZ9MkPEVXczHwg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"esbuild": "^0.24.2",
|
||||
"postcss": "^8.4.49",
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
"""Flask application factory and blueprint registration."""
|
||||
|
||||
from os import getenv
|
||||
|
||||
from flask import Flask
|
||||
from flask_session import Session
|
||||
from flask_cors import CORS
|
||||
@@ -13,10 +17,8 @@ from blueprints.oauth import oauth_bp, init_oauth
|
||||
from blueprints.socket import socketio
|
||||
from blueprints.search_bar import search_bp
|
||||
|
||||
from celery import Celery
|
||||
from celery_tasks import celery_init_app
|
||||
|
||||
from os import getenv
|
||||
|
||||
def create_app():
|
||||
"""
|
||||
@@ -34,16 +36,15 @@ def create_app():
|
||||
app.config['GOOGLE_CLIENT_SECRET'] = getenv("GOOGLE_CLIENT_SECRET")
|
||||
app.config["SESSION_COOKIE_HTTPONLY"] = True
|
||||
|
||||
|
||||
app.config.from_mapping(
|
||||
CELERY=dict(
|
||||
broker_url="redis://redis:6379/0",
|
||||
result_backend="redis://redis:6379/0",
|
||||
task_ignore_result=True,
|
||||
),
|
||||
CELERY={
|
||||
"broker_url": "redis://redis:6379/0",
|
||||
"result_backend": "redis://redis:6379/0",
|
||||
"task_ignore_result": True,
|
||||
},
|
||||
)
|
||||
app.config.from_prefixed_env()
|
||||
celery = celery_init_app(app)
|
||||
celery_init_app(app)
|
||||
|
||||
#! ↓↓↓ For development purposes only - Allow cross-origin requests for the frontend
|
||||
CORS(app, supports_credentials=True)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Admin blueprint for user management operations."""
|
||||
|
||||
from flask import Blueprint, session
|
||||
from utils.utils import sanitize
|
||||
from utils.admin_utils import *
|
||||
from utils.admin_utils import check_if_admin, check_if_user_exists, ban_user
|
||||
|
||||
admin_bp = Blueprint("admin", __name__)
|
||||
|
||||
@@ -18,10 +20,10 @@ def admin_delete_user(banned_user):
|
||||
# Check if the user is an admin
|
||||
username = session.get("username")
|
||||
is_admin = check_if_admin(username)
|
||||
|
||||
|
||||
# Check if the user exists
|
||||
user_exists = check_if_user_exists(banned_user)
|
||||
|
||||
# If the user is an admin, try to delete the account
|
||||
if is_admin and user_exists:
|
||||
ban_user(banned_user)
|
||||
ban_user(banned_user)
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
"""Authentication blueprint for user signup, login, and logout."""
|
||||
|
||||
import logging
|
||||
from secrets import token_hex
|
||||
|
||||
from flask import Blueprint, session, request, jsonify
|
||||
from werkzeug.security import generate_password_hash, check_password_hash
|
||||
from flask_cors import cross_origin
|
||||
@@ -5,18 +10,21 @@ from database.database import Database
|
||||
from blueprints.middleware import login_required
|
||||
from utils.user_utils import get_user_id
|
||||
from utils.utils import sanitize
|
||||
from secrets import token_hex
|
||||
from utils.path_manager import PathManager
|
||||
|
||||
auth_bp = Blueprint("auth", __name__)
|
||||
|
||||
path_manager = PathManager()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@auth_bp.route("/signup", methods=["POST"])
|
||||
@cross_origin(supports_credentials=True)
|
||||
def signup():
|
||||
"""
|
||||
Route that allows a user to sign up by providing a `username`, `email` and `password`.
|
||||
Route that allows a user to sign up by providing a
|
||||
`username`, `email` and `password`.
|
||||
"""
|
||||
# ensure a JSON request is made to contact this route
|
||||
if not request.is_json:
|
||||
@@ -30,19 +38,21 @@ def signup():
|
||||
|
||||
# Validation - ensure all fields exist, users cannot have an empty field
|
||||
if not all([username, email, password]):
|
||||
error_fields = get_error_fields([username, email, password]) #!←← find the error_fields, to highlight them in red to the user on the frontend
|
||||
error_fields = get_error_fields(
|
||||
[username, email, password]
|
||||
)
|
||||
return jsonify({
|
||||
"account_created": False,
|
||||
"error_fields": error_fields,
|
||||
"message": "Missing required fields"
|
||||
}), 400
|
||||
|
||||
|
||||
# Sanitize the inputs - helps to prevent SQL injection
|
||||
try:
|
||||
username = sanitize(username, "username")
|
||||
email = sanitize(email, "email")
|
||||
password = sanitize(password, "password")
|
||||
except ValueError as e:
|
||||
except ValueError:
|
||||
error_fields = get_error_fields([username, email, password])
|
||||
return jsonify({
|
||||
"account_created": False,
|
||||
@@ -81,7 +91,7 @@ def signup():
|
||||
|
||||
# Create new user once input is validated
|
||||
db.execute(
|
||||
"""INSERT INTO users
|
||||
"""INSERT INTO users
|
||||
(username, password, email, stream_key)
|
||||
VALUES (?, ?, ?, ?)""",
|
||||
(
|
||||
@@ -100,16 +110,16 @@ def signup():
|
||||
"message": "Account created successfully"
|
||||
}), 201
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during signup: {e}") # Log the error
|
||||
except (ValueError, TypeError, KeyError) as exc:
|
||||
logger.error("Error during signup: %s", exc)
|
||||
return jsonify({
|
||||
"account_created": False,
|
||||
"message": "Server error occurred: " + str(e)
|
||||
"message": "Server error occurred: " + str(exc)
|
||||
}), 500
|
||||
|
||||
finally:
|
||||
db.close_connection()
|
||||
|
||||
|
||||
|
||||
@auth_bp.route("/login", methods=["POST"])
|
||||
@cross_origin(supports_credentials=True)
|
||||
@@ -127,24 +137,24 @@ def login():
|
||||
username = data.get('username')
|
||||
password = data.get('password')
|
||||
|
||||
# Validation - ensure all fields exist, users cannot have an empty field
|
||||
# Validation - ensure all fields exist, users cannot have an empty field
|
||||
if not all([username, password]):
|
||||
return jsonify({
|
||||
"logged_in": False,
|
||||
"message": "Missing required fields"
|
||||
}), 400
|
||||
|
||||
|
||||
# Sanitize the inputs - helps to prevent SQL injection
|
||||
try:
|
||||
username = sanitize(username, "username")
|
||||
password = sanitize(password, "password")
|
||||
except ValueError as e:
|
||||
except ValueError:
|
||||
return jsonify({
|
||||
"account_created": False,
|
||||
"error_fields": ["username", "password"],
|
||||
"message": "Invalid input received"
|
||||
}), 400
|
||||
|
||||
|
||||
# Create a connection to the database
|
||||
db = Database()
|
||||
|
||||
@@ -169,7 +179,7 @@ def login():
|
||||
"error_fields": ["username", "password"],
|
||||
"message": "Invalid username or password"
|
||||
}), 401
|
||||
|
||||
|
||||
# Add user directories for stream data in case they don't exist
|
||||
path_manager.create_user(username)
|
||||
|
||||
@@ -177,7 +187,10 @@ def login():
|
||||
session.clear()
|
||||
session["username"] = username
|
||||
session["user_id"] = get_user_id(username)
|
||||
print(f"Logged in as {username}. session: {session.get('username')}. user_id: {session.get('user_id')}", flush=True)
|
||||
logger.info(
|
||||
"Logged in as %s. session: %s. user_id: %s",
|
||||
username, session.get('username'), session.get('user_id')
|
||||
)
|
||||
|
||||
# User has been logged in, let frontend know that
|
||||
return jsonify({
|
||||
@@ -186,8 +199,8 @@ def login():
|
||||
"username": username
|
||||
}), 200
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during login: {e}") # Log the error
|
||||
except (ValueError, TypeError, KeyError) as exc:
|
||||
logger.error("Error during login: %s", exc)
|
||||
return jsonify({
|
||||
"logged_in": False,
|
||||
"message": "Server error occurred"
|
||||
@@ -202,31 +215,41 @@ def login():
|
||||
def logout() -> dict:
|
||||
"""
|
||||
Log out and clear the users session.
|
||||
|
||||
|
||||
If the user is currently streaming, end their stream first.
|
||||
Can only be accessed by a logged in user.
|
||||
"""
|
||||
from database.database import Database
|
||||
from utils.stream_utils import end_user_stream
|
||||
|
||||
|
||||
# Check if user is currently streaming
|
||||
user_id = session.get("user_id")
|
||||
username = session.get("username")
|
||||
|
||||
|
||||
with Database() as db:
|
||||
is_streaming = db.fetchone("""SELECT is_live FROM users WHERE user_id = ?""", (user_id,))
|
||||
|
||||
is_streaming = db.fetchone(
|
||||
"""SELECT is_live FROM users WHERE user_id = ?""",
|
||||
(user_id,)
|
||||
)
|
||||
|
||||
if is_streaming and is_streaming.get("is_live") == 1:
|
||||
# Get the user's stream key
|
||||
stream_key_info = db.fetchone("""SELECT stream_key FROM users WHERE user_id = ?""", (user_id,))
|
||||
stream_key = stream_key_info.get("stream_key") if stream_key_info else None
|
||||
|
||||
stream_key_info = db.fetchone(
|
||||
"""SELECT stream_key FROM users WHERE user_id = ?""",
|
||||
(user_id,)
|
||||
)
|
||||
stream_key = (
|
||||
stream_key_info.get("stream_key") if stream_key_info
|
||||
else None
|
||||
)
|
||||
|
||||
if stream_key:
|
||||
# End the stream
|
||||
end_user_stream(stream_key, user_id, username)
|
||||
session.clear()
|
||||
return {"logged_in": False}
|
||||
|
||||
|
||||
def get_error_fields(values: list):
|
||||
"""Return field names for empty values."""
|
||||
fields = ["username", "email", "password"]
|
||||
return [fields[i] for i, v in enumerate(values) if not v]
|
||||
return [fields[i] for i, v in enumerate(values) if not v]
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
"""Chat blueprint for WebSocket-based real-time messaging."""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from flask import Blueprint, jsonify
|
||||
from database.database import Database
|
||||
from .socket import socketio
|
||||
from flask_socketio import emit, join_room, leave_room
|
||||
from datetime import datetime
|
||||
from utils.user_utils import get_user_id, is_subscribed
|
||||
from utils.user_utils import is_subscribed
|
||||
import redis
|
||||
import json
|
||||
|
||||
redis_url = "redis://redis:6379/1"
|
||||
r = redis.from_url(redis_url, decode_responses=True)
|
||||
REDIS_URL = "redis://redis:6379/1"
|
||||
r = redis.from_url(REDIS_URL, decode_responses=True)
|
||||
chat_bp = Blueprint("chat", __name__)
|
||||
|
||||
@socketio.on("connect")
|
||||
@@ -32,7 +35,7 @@ def handle_join(data) -> None:
|
||||
join_room(stream_id)
|
||||
num_viewers = len(list(socketio.server.manager.get_participants("/", stream_id)))
|
||||
update_viewers(stream_id, num_viewers)
|
||||
emit("status",
|
||||
emit("status",
|
||||
{
|
||||
"message": f"Welcome to the chat, stream_id: {stream_id}",
|
||||
"num_viewers": num_viewers
|
||||
@@ -53,7 +56,7 @@ def handle_leave(data) -> None:
|
||||
remove_favourability_entry(str(data["user_id"]), str(stream_id))
|
||||
num_viewers = len(list(socketio.server.manager.get_participants("/", stream_id)))
|
||||
update_viewers(stream_id, num_viewers)
|
||||
emit("status",
|
||||
emit("status",
|
||||
{
|
||||
"message": f"Welcome to the chat, stream_id: {stream_id}",
|
||||
"num_viewers": num_viewers
|
||||
@@ -78,10 +81,10 @@ def get_past_chat(stream_id: int):
|
||||
all_chats = db.fetchall("""
|
||||
SELECT user_id, username, message, time_sent, is_subscribed
|
||||
FROM (
|
||||
SELECT
|
||||
SELECT
|
||||
u.user_id,
|
||||
u.username,
|
||||
c.message,
|
||||
u.username,
|
||||
c.message,
|
||||
c.time_sent,
|
||||
CASE
|
||||
WHEN s.user_id IS NOT NULL AND s.expires > CURRENT_TIMESTAMP THEN 1
|
||||
@@ -101,8 +104,8 @@ def get_past_chat(stream_id: int):
|
||||
|
||||
# Create JSON output of chat_history to pass through NGINX proxy
|
||||
chat_history = [{"chatter_id": chat["user_id"],
|
||||
"chatter_username": chat["username"],
|
||||
"message": chat["message"],
|
||||
"chatter_username": chat["username"],
|
||||
"message": chat["message"],
|
||||
"time_sent": chat["time_sent"],
|
||||
"is_subscribed": bool(chat["is_subscribed"])} for chat in all_chats]
|
||||
print(chat_history)
|
||||
@@ -125,7 +128,13 @@ def send_chat(data) -> None:
|
||||
|
||||
# Input validation - chatter is logged in, message is not empty, stream exists
|
||||
if not all([chatter_name, message, stream_id]):
|
||||
emit("error", {"error": f"Unable to send a chat. The following info was given: chatter_name={chatter_name}, message={message}, stream_id={stream_id}"}, broadcast=False)
|
||||
emit("error", {
|
||||
"error": (
|
||||
f"Unable to send a chat. The following info was given: "
|
||||
f"chatter_name={chatter_name}, message={message}, "
|
||||
f"stream_id={stream_id}"
|
||||
)
|
||||
}, broadcast=False)
|
||||
return
|
||||
subscribed = is_subscribed(chatter_id, stream_id)
|
||||
# Send the chat message to the client so it can be displayed
|
||||
@@ -161,9 +170,8 @@ def update_viewers(user_id, num_viewers):
|
||||
SET num_viewers = ?
|
||||
WHERE user_id = ?;
|
||||
""", (num_viewers, user_id))
|
||||
db.close_connection
|
||||
|
||||
#TODO: Make sure that users entry within Redis is removed if they disconnect from socket
|
||||
db.close_connection()
|
||||
|
||||
def add_favourability_entry(user_id, stream_id):
|
||||
"""
|
||||
Adds entry to Redis that user is watching a streamer
|
||||
@@ -183,7 +191,7 @@ def add_favourability_entry(user_id, stream_id):
|
||||
else:
|
||||
# Creates new entry for user and stream
|
||||
current_viewers[user_id] = [stream_id]
|
||||
|
||||
|
||||
r.hset("current_viewers", "viewers", json.dumps(current_viewers))
|
||||
|
||||
def remove_favourability_entry(user_id, stream_id):
|
||||
@@ -202,9 +210,9 @@ def remove_favourability_entry(user_id, stream_id):
|
||||
if user_id in current_viewers:
|
||||
# Removes specific stream from user
|
||||
current_viewers[user_id] = [stream for stream in current_viewers[user_id] if stream != stream_id]
|
||||
|
||||
|
||||
# If user is no longer watching any streams
|
||||
if not current_viewers[user_id]:
|
||||
del current_viewers[user_id]
|
||||
|
||||
r.hset("current_viewers", "viewers", json.dumps(current_viewers))
|
||||
r.hset("current_viewers", "viewers", json.dumps(current_viewers))
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from flask import redirect, g, session
|
||||
from functools import wraps
|
||||
"""Authentication middleware and error handler registration."""
|
||||
|
||||
import logging
|
||||
from functools import wraps
|
||||
from os import getenv
|
||||
|
||||
from flask import redirect, g, session
|
||||
from dotenv import load_dotenv
|
||||
from database.database import Database
|
||||
|
||||
@@ -57,5 +60,5 @@ def register_error_handlers(app):
|
||||
for code, message in error_responses.items():
|
||||
@app.errorhandler(code)
|
||||
def handle_error(error, message=message, code=code):
|
||||
logging.error(f"Error {code}: {str(error)}")
|
||||
return {"error": message}, code
|
||||
logging.error("Error %d: %s", code, str(error))
|
||||
return {"error": message}, code
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
"""OAuth blueprint for Google authentication."""
|
||||
|
||||
from os import getenv
|
||||
from secrets import token_hex, token_urlsafe
|
||||
from random import randint
|
||||
|
||||
from authlib.integrations.flask_client import OAuth, OAuthError
|
||||
from flask import Blueprint, jsonify, session, redirect, request
|
||||
from blueprints.user import get_session_info_email
|
||||
from database.database import Database
|
||||
from dotenv import load_dotenv
|
||||
from secrets import token_hex, token_urlsafe
|
||||
from random import randint
|
||||
from utils.path_manager import PathManager
|
||||
|
||||
oauth_bp = Blueprint("oauth", __name__)
|
||||
google = None
|
||||
_google = None
|
||||
|
||||
load_dotenv()
|
||||
url_api = getenv("VITE_API_URL")
|
||||
@@ -23,8 +26,8 @@ def init_oauth(app):
|
||||
Initialise the OAuth functionality.
|
||||
"""
|
||||
oauth = OAuth(app)
|
||||
global google
|
||||
google = oauth.register(
|
||||
global _google # pylint: disable=global-statement
|
||||
_google = oauth.register(
|
||||
'google',
|
||||
client_id=app.config['GOOGLE_CLIENT_ID'],
|
||||
client_secret=app.config['GOOGLE_CLIENT_SECRET'],
|
||||
@@ -50,11 +53,11 @@ def login_google():
|
||||
session["nonce"] = token_urlsafe(16)
|
||||
session["state"] = token_urlsafe(32)
|
||||
session["origin"] = request.args.get("next")
|
||||
|
||||
|
||||
# Make sure session is saved before redirect
|
||||
session.modified = True
|
||||
|
||||
return google.authorize_redirect(
|
||||
|
||||
return _google.authorize_redirect(
|
||||
redirect_uri=f'{url}/api/google_auth',
|
||||
nonce=session['nonce'],
|
||||
state=session['state']
|
||||
@@ -70,23 +73,27 @@ def google_auth():
|
||||
# Check state parameter before authorizing
|
||||
returned_state = request.args.get('state')
|
||||
stored_state = session.get('state')
|
||||
|
||||
|
||||
if not stored_state or stored_state != returned_state:
|
||||
print(f"State mismatch: stored={stored_state}, returned={returned_state}", flush=True)
|
||||
print(
|
||||
f"State mismatch: stored={stored_state}, "
|
||||
f"returned={returned_state}", flush=True
|
||||
)
|
||||
return jsonify({
|
||||
'error': f"mismatching_state: CSRF Warning! State not equal in request and response.",
|
||||
'error': "mismatching_state: CSRF Warning! "
|
||||
"State not equal in request and response.",
|
||||
'message': 'Authentication failed'
|
||||
}), 400
|
||||
|
||||
|
||||
# State matched, proceed with token authorization
|
||||
token = google.authorize_access_token()
|
||||
token = _google.authorize_access_token()
|
||||
|
||||
# Verify nonce
|
||||
nonce = session.get('nonce')
|
||||
if not nonce:
|
||||
return jsonify({'error': 'Missing nonce in session'}), 400
|
||||
|
||||
user = google.parse_id_token(token, nonce=nonce)
|
||||
user = _google.parse_id_token(token, nonce=nonce)
|
||||
|
||||
# Check if email exists to login else create a database entry
|
||||
user_email = user.get("email")
|
||||
@@ -108,7 +115,7 @@ def google_auth():
|
||||
break
|
||||
|
||||
db.execute(
|
||||
"""INSERT INTO users
|
||||
"""INSERT INTO users
|
||||
(username, email, stream_key)
|
||||
VALUES (?, ?, ?)""",
|
||||
(
|
||||
@@ -124,16 +131,19 @@ def google_auth():
|
||||
origin = session.get("origin", f"{url.replace('/api', '')}")
|
||||
username = user_data["username"]
|
||||
user_id = user_data["user_id"]
|
||||
|
||||
|
||||
# Clear session and set new data
|
||||
session.clear()
|
||||
session["username"] = username
|
||||
session["user_id"] = user_id
|
||||
|
||||
|
||||
# Ensure session is saved
|
||||
session.modified = True
|
||||
|
||||
print(f"session: {session.get('username')}. user_id: {session.get('user_id')}", flush=True)
|
||||
|
||||
print(
|
||||
f"session: {session.get('username')}. "
|
||||
f"user_id: {session.get('user_id')}", flush=True
|
||||
)
|
||||
|
||||
return redirect(origin)
|
||||
|
||||
@@ -144,9 +154,9 @@ def google_auth():
|
||||
'error': str(e)
|
||||
}), 400
|
||||
|
||||
except Exception as e:
|
||||
except (ValueError, TypeError, KeyError) as e:
|
||||
print(f"Unexpected Error: {str(e)}", flush=True)
|
||||
return jsonify({
|
||||
'message': 'An unexpected error occurred',
|
||||
'error': str(e)
|
||||
}), 500
|
||||
}), 500
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Search bar blueprint for querying categories, users, streams, and VODs."""
|
||||
|
||||
from flask import Blueprint, jsonify, request
|
||||
from database.database import Database
|
||||
from utils.utils import sanitize
|
||||
@@ -20,7 +22,7 @@ def rank_results(query, result):
|
||||
# Assign a score based on the level of the match
|
||||
if query in result:
|
||||
return 0
|
||||
elif all(c in charset for c in query):
|
||||
if all(c in charset for c in query):
|
||||
return 1
|
||||
return 2
|
||||
|
||||
@@ -50,7 +52,7 @@ def search_results():
|
||||
res_dict.append(c)
|
||||
categories = sorted(res_dict, key=lambda d: d["score"])
|
||||
categories = categories[:4]
|
||||
|
||||
|
||||
# 3 users
|
||||
res_dict = []
|
||||
users = db.fetchall("SELECT user_id, username, is_live FROM users")
|
||||
@@ -63,7 +65,7 @@ def search_results():
|
||||
users = sorted(res_dict, key=lambda d: d["score"])
|
||||
users = users[:4]
|
||||
|
||||
# 3 streams
|
||||
# 3 streams
|
||||
res_dict = []
|
||||
streams = db.fetchall("""SELECT s.user_id, s.title, s.num_viewers, c.category_name, u.username
|
||||
FROM streams AS s
|
||||
@@ -71,7 +73,7 @@ def search_results():
|
||||
INNER JOIN users AS u ON s.user_id = u.user_id
|
||||
INNER JOIN categories AS c ON s.category_id = c.category_id
|
||||
""")
|
||||
|
||||
|
||||
for s in streams:
|
||||
key = s.get("title")
|
||||
score = rank_results(query.lower(), key.lower())
|
||||
@@ -83,7 +85,7 @@ def search_results():
|
||||
|
||||
# 3 VODs
|
||||
res_dict = []
|
||||
vods = db.fetchall("""SELECT v.vod_id, v.title, u.user_id, u.username
|
||||
vods = db.fetchall("""SELECT v.vod_id, v.title, u.user_id, u.username
|
||||
FROM vods as v JOIN users as u
|
||||
ON v.user_id = u.user_id""")
|
||||
for v in vods:
|
||||
@@ -98,5 +100,5 @@ def search_results():
|
||||
db.close_connection()
|
||||
|
||||
print(query, streams, users, categories, vods, flush=True)
|
||||
|
||||
return jsonify({"streams": streams, "categories": categories, "users": users, "vods": vods})
|
||||
|
||||
return jsonify({"streams": streams, "categories": categories, "users": users, "vods": vods})
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
"""WebSocket configuration using Flask-SocketIO."""
|
||||
|
||||
from flask_socketio import SocketIO
|
||||
|
||||
socketio = SocketIO(
|
||||
cors_allowed_origins="*",
|
||||
cors_allowed_origins="*",
|
||||
async_mode='gevent',
|
||||
logger=False, # Reduce logging
|
||||
engineio_logger=False, # Reduce logging
|
||||
ping_timeout=5000,
|
||||
ping_interval=25000
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,15 +1,24 @@
|
||||
"""Stream and VOD management blueprint."""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from flask import Blueprint, session, jsonify, request, redirect
|
||||
from utils.stream_utils import *
|
||||
from utils.recommendation_utils import *
|
||||
from utils.stream_utils import (
|
||||
get_category_id, get_current_stream_data, get_streamer_live_status,
|
||||
get_latest_vod, get_vod, get_user_vods, end_user_stream
|
||||
)
|
||||
from utils.recommendation_utils import (
|
||||
get_user_preferred_category, get_highest_view_streams,
|
||||
get_streams_based_on_category, get_highest_view_categories,
|
||||
get_user_category_recommendations, get_followed_categories_recommendations
|
||||
)
|
||||
from utils.user_utils import get_user_id
|
||||
from blueprints.middleware import login_required
|
||||
from database.database import Database
|
||||
from datetime import datetime
|
||||
from celery_tasks.streaming import update_thumbnail, combine_ts_stream
|
||||
from dateutil import parser
|
||||
from celery_tasks.streaming import update_thumbnail
|
||||
from utils.path_manager import PathManager
|
||||
from PIL import Image
|
||||
import json
|
||||
|
||||
stream_bp = Blueprint("stream", __name__)
|
||||
|
||||
@@ -28,12 +37,12 @@ def popular_streams(no_streams) -> list[dict]:
|
||||
Returns a list of streams live now with the highest viewers
|
||||
"""
|
||||
|
||||
# Limit the number of streams to MAX_STREAMS
|
||||
MAX_STREAMS = 100
|
||||
# Limit the number of streams to max_streams
|
||||
max_streams = 100
|
||||
if no_streams < 1:
|
||||
return jsonify([])
|
||||
elif no_streams > MAX_STREAMS:
|
||||
no_streams = MAX_STREAMS
|
||||
if no_streams > max_streams:
|
||||
no_streams = max_streams
|
||||
|
||||
# Get the highest viewed streams
|
||||
streams = get_highest_view_streams(no_streams)
|
||||
@@ -101,7 +110,7 @@ def popular_categories(no_categories=4, offset=0) -> list[dict]:
|
||||
# Limit the number of categories to 100
|
||||
if no_categories < 1:
|
||||
return jsonify([])
|
||||
elif no_categories > 100:
|
||||
if no_categories > 100:
|
||||
no_categories = 100
|
||||
|
||||
category_data = get_highest_view_categories(no_categories, offset)
|
||||
@@ -135,7 +144,8 @@ def following_categories_streams():
|
||||
@stream_bp.route('/user/<string:username>/status')
|
||||
def user_live_status(username):
|
||||
"""
|
||||
Returns a streamer's status, if they are live or not and their most recent stream (as a vod) (their current stream if live)
|
||||
Returns a streamer's status, if they are live or not and their
|
||||
most recent stream (as a vod) (their current stream if live)
|
||||
Returns:
|
||||
{
|
||||
"is_live": bool,
|
||||
@@ -146,7 +156,7 @@ def user_live_status(username):
|
||||
user_id = get_user_id(username)
|
||||
|
||||
# Check if streamer is live and get their most recent vod
|
||||
is_live = True if get_streamer_live_status(user_id)['is_live'] else False
|
||||
is_live = bool(get_streamer_live_status(user_id)['is_live'])
|
||||
most_recent_vod = get_latest_vod(user_id)
|
||||
|
||||
# If there is no most recent vod, set it to None
|
||||
@@ -171,25 +181,24 @@ def user_live_status_direct(username):
|
||||
"""
|
||||
|
||||
user_id = get_user_id(username)
|
||||
is_live = True if get_streamer_live_status(user_id)['is_live'] else False
|
||||
is_live = bool(get_streamer_live_status(user_id)['is_live'])
|
||||
|
||||
if is_live:
|
||||
return 'ok', 200
|
||||
else:
|
||||
return 'not live', 404
|
||||
return 'not live', 404
|
||||
|
||||
|
||||
# VOD Routes
|
||||
@stream_bp.route('/vods/<int:vod_id>')
|
||||
def vod(vod_id):
|
||||
def single_vod(vod_id):
|
||||
"""
|
||||
Returns details about a specific vod
|
||||
"""
|
||||
vod = get_vod(vod_id)
|
||||
return jsonify(vod)
|
||||
vod_data = get_vod(vod_id)
|
||||
return jsonify(vod_data)
|
||||
|
||||
@stream_bp.route('/vods/<string:username>')
|
||||
def vods(username):
|
||||
def user_vods(username):
|
||||
"""
|
||||
Returns a JSON of all the vods of a streamer
|
||||
Returns:
|
||||
@@ -204,11 +213,11 @@ def vods(username):
|
||||
"views": int
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
"""
|
||||
user_id = get_user_id(username)
|
||||
vods = get_user_vods(user_id)
|
||||
return jsonify(vods)
|
||||
vod_list = get_user_vods(user_id)
|
||||
return jsonify(vod_list)
|
||||
|
||||
@stream_bp.route('/vods/all')
|
||||
def get_all_vods():
|
||||
@@ -216,9 +225,14 @@ def get_all_vods():
|
||||
Returns data of all VODs by all streamers in a JSON-compatible format
|
||||
"""
|
||||
with Database() as db:
|
||||
vods = db.fetchall("""SELECT vods.*, username, category_name FROM vods JOIN users ON vods.user_id = users.user_id JOIN categories ON vods.category_id = categories.category_id;""")
|
||||
|
||||
return jsonify(vods)
|
||||
all_vods = db.fetchall(
|
||||
"""SELECT vods.*, username, category_name
|
||||
FROM vods
|
||||
JOIN users ON vods.user_id = users.user_id
|
||||
JOIN categories ON vods.category_id = categories.category_id;"""
|
||||
)
|
||||
|
||||
return jsonify(all_vods)
|
||||
|
||||
# RTMP Server Routes
|
||||
|
||||
@@ -232,9 +246,10 @@ def init_stream():
|
||||
|
||||
with Database() as db:
|
||||
# Check if valid stream key and user is allowed to stream
|
||||
user_info = db.fetchone("""SELECT user_id, username, is_live
|
||||
FROM users
|
||||
WHERE stream_key = ?""", (stream_key,))
|
||||
user_info = db.fetchone(
|
||||
"""SELECT user_id, username, is_live
|
||||
FROM users
|
||||
WHERE stream_key = ?""", (stream_key,))
|
||||
|
||||
# No user found from stream key
|
||||
if not user_info:
|
||||
@@ -280,25 +295,27 @@ def publish_stream():
|
||||
username = None
|
||||
|
||||
with Database() as db:
|
||||
user_info = db.fetchone("""SELECT user_id, username, is_live
|
||||
FROM users
|
||||
WHERE stream_key = ?""", (stream_key,))
|
||||
user_info = db.fetchone(
|
||||
"""SELECT user_id, username, is_live
|
||||
FROM users
|
||||
WHERE stream_key = ?""", (stream_key,))
|
||||
|
||||
if not user_info or user_info.get("is_live"):
|
||||
print(
|
||||
"Unauthorized. No user found from Stream key or user is already streaming.", flush=True)
|
||||
"Unauthorized. No user found from Stream key "
|
||||
"or user is already streaming.", flush=True)
|
||||
return "Unauthorized", 403
|
||||
|
||||
user_id = user_info.get("user_id")
|
||||
username = user_info.get("username")
|
||||
|
||||
# Insert stream into database
|
||||
db.execute("""INSERT INTO streams (user_id, title, start_time, num_viewers, category_id)
|
||||
VALUES (?, ?, ?, ?, ?)""", (user_id,
|
||||
stream_title,
|
||||
datetime.now(),
|
||||
0,
|
||||
get_category_id(stream_category)))
|
||||
db.execute(
|
||||
"""INSERT INTO streams
|
||||
(user_id, title, start_time, num_viewers, category_id)
|
||||
VALUES (?, ?, ?, ?, ?)""",
|
||||
(user_id, stream_title, datetime.now(), 0,
|
||||
get_category_id(stream_category)))
|
||||
|
||||
# Set user as streaming
|
||||
db.execute("""UPDATE users SET is_live = 1 WHERE user_id = ?""",
|
||||
@@ -306,10 +323,12 @@ def publish_stream():
|
||||
|
||||
# Update thumbnail periodically only if a custom thumbnail is not provided
|
||||
if not stream_thumbnail:
|
||||
update_thumbnail.apply_async((user_id,
|
||||
path_manager.get_stream_file_path(username),
|
||||
path_manager.get_current_stream_thumbnail_file_path(username),
|
||||
THUMBNAIL_GENERATION_INTERVAL), countdown=10)
|
||||
update_thumbnail.apply_async(
|
||||
(user_id,
|
||||
path_manager.get_stream_file_path(username),
|
||||
path_manager.get_current_stream_thumbnail_file_path(username),
|
||||
THUMBNAIL_GENERATION_INTERVAL),
|
||||
countdown=10)
|
||||
|
||||
return "OK", 200
|
||||
|
||||
@@ -332,31 +351,36 @@ def update_stream():
|
||||
|
||||
|
||||
with Database() as db:
|
||||
user_info = db.fetchone("""SELECT user_id, username, is_live
|
||||
FROM users
|
||||
WHERE stream_key = ?""", (stream_key,))
|
||||
user_info = db.fetchone(
|
||||
"""SELECT user_id, username, is_live
|
||||
FROM users
|
||||
WHERE stream_key = ?""", (stream_key,))
|
||||
|
||||
if not user_info or not user_info.get("is_live"):
|
||||
print(
|
||||
"Unauthorized - No user found from stream key or user is not streaming", flush=True)
|
||||
"Unauthorized - No user found from stream key "
|
||||
"or user is not streaming", flush=True)
|
||||
return "Unauthorized", 403
|
||||
|
||||
user_id = user_info.get("user_id")
|
||||
username = user_info.get("username")
|
||||
|
||||
# TODO: Add update to thumbnail here
|
||||
db.execute("""UPDATE streams
|
||||
SET title = ?, category_id = ?
|
||||
WHERE user_id = ?""", (stream_title, get_category_id(stream_category), user_id))
|
||||
|
||||
db.execute(
|
||||
"""UPDATE streams
|
||||
SET title = ?, category_id = ?
|
||||
WHERE user_id = ?""",
|
||||
(stream_title, get_category_id(stream_category), user_id))
|
||||
|
||||
print("GOT: " + stream_thumbnail, flush=True)
|
||||
|
||||
|
||||
if stream_thumbnail:
|
||||
# Set custom thumbnail status to true
|
||||
db.execute("""UPDATE streams
|
||||
SET custom_thumbnail = ?
|
||||
WHERE user_id = ?""", (True, user_id))
|
||||
|
||||
db.execute(
|
||||
"""UPDATE streams
|
||||
SET custom_thumbnail = ?
|
||||
WHERE user_id = ?""", (True, user_id))
|
||||
|
||||
# Get thumbnail path
|
||||
thumbnail_path = path_manager.get_current_stream_thumbnail_file_path(username)
|
||||
|
||||
@@ -390,26 +414,27 @@ def end_stream():
|
||||
|
||||
# Get user info from stream key
|
||||
with Database() as db:
|
||||
user_info = db.fetchone("""SELECT user_id, username
|
||||
FROM users
|
||||
WHERE stream_key = ?""", (stream_key,))
|
||||
|
||||
user_info = db.fetchone(
|
||||
"""SELECT user_id, username
|
||||
FROM users
|
||||
WHERE stream_key = ?""", (stream_key,))
|
||||
|
||||
# Return unauthorized if no user found
|
||||
if not user_info:
|
||||
print("Unauthorized - No user found from stream key", flush=True)
|
||||
return "Unauthorized", 403
|
||||
|
||||
|
||||
# Get user info
|
||||
user_id = user_info["user_id"]
|
||||
username = user_info["username"]
|
||||
|
||||
|
||||
# End stream
|
||||
result, message = end_user_stream(stream_key, user_id, username)
|
||||
|
||||
|
||||
# Return error if stream could not be ended
|
||||
if not result:
|
||||
print(f"Error ending stream: {message}", flush=True)
|
||||
return "Error ending stream", 500
|
||||
|
||||
|
||||
print(f"Stream ended: {message}", flush=True)
|
||||
return "Stream ended", 200
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
"""Stripe payment integration blueprint."""
|
||||
|
||||
import os
|
||||
|
||||
from flask import Blueprint, request, jsonify, session as s
|
||||
from blueprints.middleware import login_required
|
||||
from utils.user_utils import subscribe
|
||||
import os, stripe
|
||||
import stripe
|
||||
|
||||
stripe_bp = Blueprint("stripe", __name__)
|
||||
|
||||
@@ -20,7 +24,7 @@ def create_checkout_session():
|
||||
# Checks to see who is subscribing to who
|
||||
user_id = s.get("user_id")
|
||||
streamer_id = request.args.get("streamer_id")
|
||||
session = stripe.checkout.Session.create(
|
||||
checkout_session = stripe.checkout.Session.create(
|
||||
ui_mode = 'embedded',
|
||||
payment_method_types=['card'],
|
||||
line_items=[
|
||||
@@ -33,19 +37,24 @@ def create_checkout_session():
|
||||
redirect_on_completion = 'never',
|
||||
client_reference_id = f"{user_id}-{streamer_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
except (ValueError, TypeError, KeyError) as e:
|
||||
return str(e), 500
|
||||
|
||||
return jsonify(clientSecret=session.client_secret)
|
||||
return jsonify(clientSecret=checkout_session.client_secret)
|
||||
|
||||
@stripe_bp.route('/session-status') # check for payment status
|
||||
@stripe_bp.route('/session-status') # check for payment status
|
||||
def session_status():
|
||||
"""
|
||||
Used to query payment status
|
||||
"""
|
||||
session = stripe.checkout.Session.retrieve(request.args.get('session_id'))
|
||||
"""
|
||||
Used to query payment status
|
||||
"""
|
||||
checkout_session = stripe.checkout.Session.retrieve(
|
||||
request.args.get('session_id')
|
||||
)
|
||||
|
||||
return jsonify(status=session.status, customer_email=session.customer_details.email)
|
||||
return jsonify(
|
||||
status=checkout_session.status,
|
||||
customer_email=checkout_session.customer_details.email
|
||||
)
|
||||
|
||||
@stripe_bp.route('/stripe/webhook', methods=['POST'])
|
||||
def stripe_webhook():
|
||||
@@ -65,13 +74,20 @@ def stripe_webhook():
|
||||
raise e
|
||||
except stripe.error.SignatureVerificationError as e:
|
||||
raise e
|
||||
|
||||
if event['type'] == "checkout.session.completed": # Handles payment success webhook
|
||||
session = event['data']['object']
|
||||
product_id = stripe.checkout.Session.list_line_items(session['id'])['data'][0]['price']['product']
|
||||
|
||||
# Handles payment success webhook
|
||||
if event['type'] == "checkout.session.completed":
|
||||
checkout_session = event['data']['object']
|
||||
product_id = stripe.checkout.Session.list_line_items(
|
||||
checkout_session['id']
|
||||
)['data'][0]['price']['product']
|
||||
if product_id == subscription:
|
||||
client_reference_id = session.get("client_reference_id")
|
||||
user_id, streamer_id = map(int, client_reference_id.split("-"))
|
||||
client_reference_id = checkout_session.get(
|
||||
"client_reference_id"
|
||||
)
|
||||
user_id, streamer_id = map(
|
||||
int, client_reference_id.split("-")
|
||||
)
|
||||
subscribe(user_id, streamer_id)
|
||||
|
||||
return "Success", 200
|
||||
return "Success", 200
|
||||
|
||||
@@ -1,17 +1,28 @@
|
||||
"""User profile and account management blueprint."""
|
||||
|
||||
from flask import Blueprint, jsonify, session, request
|
||||
from utils.user_utils import *
|
||||
from utils.auth import *
|
||||
from utils.user_utils import (
|
||||
get_user_id, get_user, update_bio, is_subscribed,
|
||||
subscription_expiration, delete_subscription, is_following,
|
||||
follow, unfollow, get_followed_streamers, get_followed_categories,
|
||||
is_following_category, follow_category, unfollow_category,
|
||||
has_password, get_session_info_email
|
||||
)
|
||||
from database.database import Database
|
||||
from utils.auth import verify_token, reset_password
|
||||
from utils.utils import get_category_id
|
||||
from blueprints.middleware import login_required
|
||||
from utils.email import send_email, forgot_password_body, newsletter_conf, remove_from_newsletter, email_exists
|
||||
from utils.email import (
|
||||
send_email, forgot_password_body, newsletter_conf,
|
||||
remove_from_newsletter, email_exists
|
||||
)
|
||||
from utils.path_manager import PathManager
|
||||
from celery_tasks.streaming import convert_image_to_png
|
||||
import redis
|
||||
|
||||
from PIL import Image
|
||||
|
||||
redis_url = "redis://redis:6379/1"
|
||||
r = redis.from_url(redis_url, decode_responses=True)
|
||||
REDIS_URL = "redis://redis:6379/1"
|
||||
r = redis.from_url(REDIS_URL, decode_responses=True)
|
||||
|
||||
user_bp = Blueprint("user", __name__)
|
||||
|
||||
@@ -24,7 +35,7 @@ def user_data(username: str):
|
||||
"""
|
||||
user_id = get_user_id(username)
|
||||
if not user_id:
|
||||
jsonify({"error": "User not found from username"}), 404
|
||||
return jsonify({"error": "User not found from username"}), 404
|
||||
data = get_user(user_id)
|
||||
return jsonify(data)
|
||||
|
||||
@@ -48,11 +59,11 @@ def user_profile_picture_save():
|
||||
"""
|
||||
username = session.get("username")
|
||||
thumbnail_path = path_manager.get_profile_picture_file_path(username)
|
||||
|
||||
|
||||
# Check if the post request has the file part
|
||||
if 'image' not in request.files:
|
||||
return jsonify({"error": "No image found in request"}), 400
|
||||
|
||||
|
||||
# Fetch image, convert to png, and save
|
||||
image = Image.open(request.files['image'])
|
||||
image.convert('RGB')
|
||||
@@ -73,7 +84,7 @@ def user_change_bio():
|
||||
bio = data.get("bio")
|
||||
update_bio(user_id, bio)
|
||||
return jsonify({"status": "Success"}), 200
|
||||
except Exception as e:
|
||||
except (ValueError, TypeError, KeyError) as e:
|
||||
return jsonify({"error": str(e)}), 400
|
||||
|
||||
## Subscription Routes
|
||||
@@ -186,7 +197,7 @@ def user_login_status():
|
||||
"""
|
||||
username = session.get("username")
|
||||
user_id = session.get("user_id")
|
||||
return jsonify({'status': username is not None,
|
||||
return jsonify({'status': username is not None,
|
||||
'username': username,
|
||||
'user_id': user_id})
|
||||
|
||||
@@ -198,13 +209,14 @@ def user_forgot_password(email):
|
||||
exists = email_exists(email)
|
||||
password = has_password(email)
|
||||
# Checks if password exists and is not a Google OAuth account
|
||||
if(exists and password):
|
||||
if exists and password:
|
||||
send_email(email, lambda: forgot_password_body(email))
|
||||
return email
|
||||
return jsonify({"error":"Invalid email or not found"}), 404
|
||||
return jsonify({"error": "Invalid email or not found"}), 404
|
||||
|
||||
@user_bp.route("/send_newsletter/<string:email>", methods=["POST"])
|
||||
def send_newsletter(email):
|
||||
"""Sends a newsletter confirmation email."""
|
||||
send_email(email, lambda: newsletter_conf(email))
|
||||
return email
|
||||
|
||||
@@ -226,6 +238,7 @@ def user_reset_password(token, new_password):
|
||||
|
||||
@user_bp.route("/user/unsubscribe/<string:token>", methods=["POST"])
|
||||
def unsubscribe(token):
|
||||
"""Unsubscribes a user from the newsletter."""
|
||||
salt = r.get(token)
|
||||
if salt:
|
||||
r.delete(token)
|
||||
@@ -237,4 +250,4 @@ def unsubscribe(token):
|
||||
if email:
|
||||
remove_from_newsletter(email)
|
||||
return jsonify({"message": "unsubscribed from newsletter"}), 200
|
||||
return jsonify({"error": "Invalid token"}), 400
|
||||
return jsonify({"error": "Invalid token"}), 400
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
from celery import Celery, shared_task, Task
|
||||
"""Celery configuration and Flask app context setup for async tasks."""
|
||||
|
||||
from celery import Celery, Task
|
||||
|
||||
|
||||
def celery_init_app(app) -> Celery:
|
||||
"""Initialize Celery with Flask application context."""
|
||||
class FlaskTask(Task):
|
||||
"""Celery task that runs within Flask app context."""
|
||||
def __call__(self, *args: object, **kwargs: object) -> object:
|
||||
with app.app_context():
|
||||
return self.run(*args, **kwargs)
|
||||
@@ -14,6 +19,10 @@ def celery_init_app(app) -> Celery:
|
||||
'schedule': 30.0,
|
||||
},
|
||||
}
|
||||
celery_app.conf.include = [
|
||||
'celery_tasks.preferences',
|
||||
'celery_tasks.streaming',
|
||||
]
|
||||
celery_app.set_default()
|
||||
app.extensions["celery"] = celery_app
|
||||
return celery_app
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
"""Celery app initialization with Flask."""
|
||||
|
||||
from blueprints import create_app
|
||||
|
||||
flask_app = create_app()
|
||||
celery_app = flask_app.extensions["celery"]
|
||||
celery_app = flask_app.extensions["celery"]
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
"""Scheduled task for updating user preferences based on stream viewing."""
|
||||
|
||||
import json
|
||||
|
||||
from celery import shared_task
|
||||
from database.database import Database
|
||||
import redis
|
||||
import json
|
||||
|
||||
redis_url = "redis://redis:6379/1"
|
||||
r = redis.from_url(redis_url, decode_responses=True)
|
||||
REDIS_URL = "redis://redis:6379/1"
|
||||
r = redis.from_url(REDIS_URL, decode_responses=True)
|
||||
|
||||
@shared_task
|
||||
def user_preferences():
|
||||
"""
|
||||
Updates users preferences on different stream categories based on the streams they are currently watching
|
||||
Updates users preferences on different stream categories
|
||||
based on the streams they are currently watching
|
||||
"""
|
||||
stats = r.hget("current_viewers", "viewers")
|
||||
# If there are any current viewers
|
||||
@@ -21,13 +25,19 @@ def user_preferences():
|
||||
# For each user and stream combination
|
||||
for stream_id in stream_ids:
|
||||
# Retrieves category associated with stream
|
||||
current_category = db.fetchone("""SELECT category_id FROM streams
|
||||
WHERE user_id = ?
|
||||
""", (stream_id,))
|
||||
# If stream is still live then update the user_preferences table to reflect their preferences
|
||||
current_category = db.fetchone(
|
||||
"""SELECT category_id FROM streams
|
||||
WHERE user_id = ?
|
||||
""", (stream_id,))
|
||||
# If stream is still live then update the
|
||||
# user_preferences table to reflect their preferences
|
||||
if current_category:
|
||||
db.execute("""INSERT INTO user_preferences (user_id,category_id,favourability)
|
||||
VALUES (?,?,?)
|
||||
ON CONFLICT(user_id, category_id)
|
||||
DO UPDATE SET favourability = favourability + 1
|
||||
""", (user_id, current_category["category_id"], 1))
|
||||
db.execute(
|
||||
"""INSERT INTO user_preferences
|
||||
(user_id,category_id,favourability)
|
||||
VALUES (?,?,?)
|
||||
ON CONFLICT(user_id, category_id)
|
||||
DO UPDATE SET
|
||||
favourability = favourability + 1
|
||||
""", (user_id,
|
||||
current_category["category_id"], 1))
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from celery import Celery, shared_task, Task
|
||||
from datetime import datetime
|
||||
from celery_tasks.preferences import user_preferences
|
||||
from utils.stream_utils import generate_thumbnail, get_streamer_live_status, get_custom_thumbnail_status, remove_hls_files, get_video_duration
|
||||
from time import sleep
|
||||
from os import listdir, remove
|
||||
from utils.path_manager import PathManager
|
||||
"""Async tasks for stream thumbnail updates, VOD creation, and image conversion."""
|
||||
|
||||
import subprocess
|
||||
from os import listdir
|
||||
|
||||
from celery import shared_task
|
||||
from utils.stream_utils import (
|
||||
generate_thumbnail, get_streamer_live_status,
|
||||
get_custom_thumbnail_status, remove_hls_files, get_video_duration
|
||||
)
|
||||
from utils.path_manager import PathManager
|
||||
|
||||
path_manager = PathManager()
|
||||
|
||||
@@ -16,10 +19,13 @@ def update_thumbnail(user_id, stream_file, thumbnail_file, sleep_time, second_ca
|
||||
"""
|
||||
|
||||
# Check if stream is still live and custom thumbnail has not been set
|
||||
if get_streamer_live_status(user_id)['is_live'] and not get_custom_thumbnail_status(user_id)['custom_thumbnail']:
|
||||
if (get_streamer_live_status(user_id)['is_live']
|
||||
and not get_custom_thumbnail_status(user_id)['custom_thumbnail']):
|
||||
print("Updating thumbnail...")
|
||||
generate_thumbnail(stream_file, thumbnail_file)
|
||||
update_thumbnail.apply_async((user_id, stream_file, thumbnail_file, sleep_time, second_capture), countdown=sleep_time)
|
||||
update_thumbnail.apply_async(
|
||||
(user_id, stream_file, thumbnail_file, sleep_time, second_capture),
|
||||
countdown=sleep_time)
|
||||
else:
|
||||
print(f"Stopping thumbnail updates for stream of {user_id}")
|
||||
|
||||
@@ -34,12 +40,12 @@ def combine_ts_stream(stream_path, vods_path, vod_file_name, thumbnail_file) ->
|
||||
ts_files.sort()
|
||||
|
||||
# Create temp file listing all ts files
|
||||
with open(f"{stream_path}/list.txt", "w") as f:
|
||||
with open(f"{stream_path}/list.txt", "w", encoding="utf-8") as f:
|
||||
for ts_file in ts_files:
|
||||
f.write(f"file '{ts_file}'\n")
|
||||
|
||||
|
||||
# Concatenate all ts files into a single vod
|
||||
|
||||
|
||||
vod_command = [
|
||||
"ffmpeg",
|
||||
"-f",
|
||||
@@ -53,7 +59,7 @@ def combine_ts_stream(stream_path, vods_path, vod_file_name, thumbnail_file) ->
|
||||
vod_file_path
|
||||
]
|
||||
|
||||
subprocess.run(vod_command)
|
||||
subprocess.run(vod_command, check=True)
|
||||
|
||||
# Remove HLS files, even if user is not streaming
|
||||
remove_hls_files(stream_path)
|
||||
@@ -78,4 +84,4 @@ def convert_image_to_png(image_path, png_path):
|
||||
png_path
|
||||
]
|
||||
|
||||
subprocess.run(image_command)
|
||||
subprocess.run(image_command, check=True)
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
"""SQLite database connection management with context manager support."""
|
||||
|
||||
import sqlite3
|
||||
import os
|
||||
|
||||
|
||||
class Database:
|
||||
"""Database wrapper providing connection management and query execution."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._db = os.path.join(os.path.abspath(os.path.dirname(__file__)), "app.db")
|
||||
self._conn = None
|
||||
@@ -63,4 +68,6 @@ class Database:
|
||||
if not result:
|
||||
return []
|
||||
columns = [desc[0] for desc in self.cursor.description]
|
||||
return [dict(zip(columns, row)) for row in result] if isinstance(result, list) else dict(zip(columns, result))
|
||||
if isinstance(result, list):
|
||||
return [dict(zip(columns, row)) for row in result]
|
||||
return dict(zip(columns, result))
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
"""Admin utility functions for user management."""
|
||||
|
||||
from database.database import Database
|
||||
|
||||
def check_if_admin(username):
|
||||
"""
|
||||
Returns whether user is admin
|
||||
Returns whether user is admin
|
||||
"""
|
||||
with Database() as db:
|
||||
is_admin = db.fetchone("""
|
||||
@@ -10,7 +12,7 @@ def check_if_admin(username):
|
||||
FROM users
|
||||
WHERE username = ?;
|
||||
""", (username,))
|
||||
|
||||
|
||||
return bool(is_admin)
|
||||
|
||||
def check_if_user_exists(banned_user):
|
||||
@@ -19,11 +21,11 @@ def check_if_user_exists(banned_user):
|
||||
"""
|
||||
with Database() as db:
|
||||
user_exists = db.fetchone("""
|
||||
SELECT user_id
|
||||
SELECT user_id
|
||||
FROM users
|
||||
WHERE username = ?;""",
|
||||
(banned_user,))
|
||||
|
||||
|
||||
return bool(user_exists)
|
||||
|
||||
def ban_user(banned_user):
|
||||
@@ -33,6 +35,5 @@ def ban_user(banned_user):
|
||||
with Database() as db:
|
||||
db.execute("""
|
||||
DELETE FROM users
|
||||
WHERE username = ?;""",
|
||||
(banned_user)
|
||||
)
|
||||
WHERE username = ?;""",
|
||||
(banned_user,))
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
"""Token generation and verification for password resets."""
|
||||
|
||||
from typing import Optional
|
||||
from os import getenv
|
||||
|
||||
from database.database import Database
|
||||
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||
from typing import Optional
|
||||
from dotenv import load_dotenv
|
||||
from os import getenv
|
||||
from werkzeug.security import generate_password_hash
|
||||
|
||||
load_dotenv()
|
||||
|
||||
serializer = URLSafeTimedSerializer(getenv("AUTH_SECRET_KEY"))
|
||||
|
||||
def generate_token(email, salt_value) -> str:
|
||||
"""
|
||||
Creates a token for password reset
|
||||
@@ -19,7 +23,6 @@ def verify_token(token: str, salt_value) -> Optional[str]:
|
||||
"""
|
||||
Given a token, verifies and decodes it into an email
|
||||
"""
|
||||
|
||||
try:
|
||||
email = serializer.loads(token, salt=salt_value, max_age=3600)
|
||||
return email
|
||||
@@ -38,7 +41,7 @@ def reset_password(new_password: str, email: str):
|
||||
"""
|
||||
with Database() as db:
|
||||
db.execute("""
|
||||
UPDATE users
|
||||
SET password = ?
|
||||
UPDATE users
|
||||
SET password = ?
|
||||
WHERE email = ?
|
||||
""", (generate_password_hash(new_password), email))
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
"""Email sending utilities for password reset, account confirmation, and newsletters."""
|
||||
|
||||
import smtplib
|
||||
from email.mime.text import MIMEText
|
||||
|
||||
from os import getenv
|
||||
from secrets import token_hex
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from utils.auth import generate_token
|
||||
from secrets import token_hex
|
||||
from .user_utils import get_session_info_email
|
||||
import redis
|
||||
from database.database import Database
|
||||
|
||||
redis_url = "redis://redis:6379/1"
|
||||
r = redis.from_url(redis_url, decode_responses=True)
|
||||
REDIS_URL = "redis://redis:6379/1"
|
||||
r = redis.from_url(REDIS_URL, decode_responses=True)
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@@ -23,31 +25,31 @@ def send_email(email, func) -> None:
|
||||
"""
|
||||
|
||||
# Setup the sender email details
|
||||
SMTP_SERVER = "smtp.gmail.com"
|
||||
SMTP_PORT = 587
|
||||
SMTP_EMAIL = getenv("EMAIL")
|
||||
SMTP_PASSWORD = getenv("EMAIL_PASSWORD")
|
||||
smtp_server = "smtp.gmail.com"
|
||||
smtp_port = 587
|
||||
smtp_email = getenv("EMAIL")
|
||||
smtp_password = getenv("EMAIL_PASSWORD")
|
||||
|
||||
# Setup up the receiver details
|
||||
body, subject = func()
|
||||
|
||||
msg = MIMEText(body, "html")
|
||||
msg["Subject"] = subject
|
||||
msg["From"] = SMTP_EMAIL
|
||||
msg["From"] = smtp_email
|
||||
msg["To"] = email
|
||||
|
||||
# Send the email using smtplib
|
||||
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as smtp:
|
||||
with smtplib.SMTP(smtp_server, smtp_port) as smtp:
|
||||
try:
|
||||
smtp.starttls() # TLS handshake to start the connection
|
||||
smtp.login(SMTP_EMAIL, SMTP_PASSWORD)
|
||||
smtp.login(smtp_email, smtp_password)
|
||||
smtp.ehlo()
|
||||
smtp.send_message(msg)
|
||||
|
||||
|
||||
except TimeoutError:
|
||||
print("Server timed out", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
except smtplib.SMTPException as e:
|
||||
print("Error: ", e, flush=True)
|
||||
|
||||
def forgot_password_body(email) -> str:
|
||||
@@ -57,7 +59,7 @@ def forgot_password_body(email) -> str:
|
||||
salt = token_hex(32)
|
||||
|
||||
token = generate_token(email, salt)
|
||||
token += "R3sET"
|
||||
token += "R3sET"
|
||||
r.setex(token, 3600, salt)
|
||||
username = (get_session_info_email(email))["username"]
|
||||
|
||||
@@ -77,7 +79,7 @@ def forgot_password_body(email) -> str:
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>Gander</h1>
|
||||
<h1>Gander</h1>
|
||||
<h2>Password Reset Request</h2>
|
||||
<p>Click the button below to reset your password for your account {username}. This link is valid for 1 hour.</p>
|
||||
<a href="{full_url}" class="btn">Reset Password</a>
|
||||
@@ -116,7 +118,7 @@ def confirm_account_creation_body(email) -> str:
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>Gander</h1>
|
||||
<h1>Gander</h1>
|
||||
<h2>Confirm Account Creation</h2>
|
||||
<p>Click the button below to create your account. This link is valid for 1 hour.</p>
|
||||
<a href="{full_url}" class="btn">Create Account</a>
|
||||
@@ -154,7 +156,7 @@ def newsletter_conf(email):
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>Gander</h1>
|
||||
<h1>Gander</h1>
|
||||
<h2>Welcome to the Official Gander Newsletter!</h2>
|
||||
<p>If you are receiving this email, it means that you have been officially added to the Monthly Gander newsletter.</p>
|
||||
<p>In this newsletter, you will receive updates about: your favourite streamers; important Gander updates; and more!</p>
|
||||
@@ -169,7 +171,7 @@ def newsletter_conf(email):
|
||||
user_exists = db.fetchone("""
|
||||
SELECT *
|
||||
FROM newsletter
|
||||
WHERE email = ?;""",
|
||||
WHERE email = ?;""",
|
||||
(email,))
|
||||
print(user_exists, flush=True)
|
||||
|
||||
@@ -190,7 +192,7 @@ def add_to_newsletter(email):
|
||||
INSERT INTO newsletter (email)
|
||||
VALUES (?);
|
||||
""", (email,))
|
||||
|
||||
|
||||
def remove_from_newsletter(email):
|
||||
"""
|
||||
Remove a person from the newsletter database
|
||||
@@ -202,7 +204,7 @@ def remove_from_newsletter(email):
|
||||
DELETE FROM newsletter
|
||||
WHERE email = ?;
|
||||
""", (email,))
|
||||
|
||||
|
||||
def email_exists(email):
|
||||
"""
|
||||
Returns whether email exists within database
|
||||
@@ -210,6 +212,6 @@ def email_exists(email):
|
||||
with Database() as db:
|
||||
data = db.fetchone("""
|
||||
SELECT * FROM users
|
||||
WHERE email = ?
|
||||
WHERE email = ?
|
||||
""", (email,))
|
||||
return bool(data)
|
||||
return bool(data)
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
"""Description: This file contains the PathManager class which is responsible for managing the paths of the stream data."""
|
||||
"""File system path management for user streams, VODs, and profile pictures."""
|
||||
|
||||
import os
|
||||
|
||||
|
||||
class PathManager():
|
||||
"""Manages paths for user stream data, VODs, and profile pictures."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.root_path = "user_data"
|
||||
self.vod_directory_name = "vods"
|
||||
@@ -26,7 +29,7 @@ class PathManager():
|
||||
self._create_directory(os.path.join(self.root_path, username))
|
||||
vods_path = self.get_vods_path(username)
|
||||
stream_path = self.get_stream_path(username)
|
||||
|
||||
|
||||
self._create_directory(vods_path)
|
||||
self._create_directory(stream_path)
|
||||
|
||||
@@ -39,25 +42,33 @@ class PathManager():
|
||||
os.rmdir(user_path)
|
||||
|
||||
def get_user_path(self, username):
|
||||
"""Returns the base path for a user's data directory."""
|
||||
return os.path.join(self.root_path, username)
|
||||
|
||||
def get_vods_path(self, username):
|
||||
"""Returns the path to a user's VODs directory."""
|
||||
return os.path.join(self.root_path, username, self.vod_directory_name)
|
||||
|
||||
|
||||
def get_stream_path(self, username):
|
||||
"""Returns the path to a user's stream directory."""
|
||||
return os.path.join(self.root_path, username, self.stream_directory_name)
|
||||
|
||||
|
||||
def get_stream_file_path(self, username):
|
||||
"""Returns the path to a user's stream index file."""
|
||||
return os.path.join(self.get_stream_path(username), "index.m3u8")
|
||||
|
||||
|
||||
def get_current_stream_thumbnail_file_path(self, username):
|
||||
"""Returns the path to a user's current stream thumbnail."""
|
||||
return os.path.join(self.get_stream_path(username), "index.png")
|
||||
|
||||
|
||||
def get_vod_file_path(self, username, vod_id):
|
||||
"""Returns the path to a specific VOD file."""
|
||||
return os.path.join(self.get_vods_path(username), f"{vod_id}.mp4")
|
||||
|
||||
|
||||
def get_vod_thumbnail_file_path(self, username, vod_id):
|
||||
"""Returns the path to a specific VOD thumbnail."""
|
||||
return os.path.join(self.get_vods_path(username), f"{vod_id}.png")
|
||||
|
||||
|
||||
def get_profile_picture_file_path(self, username):
|
||||
return os.path.join(self.root_path, username, self.profile_picture_name)
|
||||
"""Returns the path to a user's profile picture."""
|
||||
return os.path.join(self.root_path, username, self.profile_picture_name)
|
||||
|
||||
@@ -1,32 +1,42 @@
|
||||
from database.database import Database
|
||||
"""Personalized stream and category recommendations based on user preferences."""
|
||||
|
||||
from typing import Optional, List
|
||||
|
||||
from database.database import Database
|
||||
|
||||
def get_user_preferred_category(user_id: int) -> Optional[int]:
|
||||
"""
|
||||
Queries user_preferences database to find users favourite streaming category and returns the category
|
||||
Queries user_preferences database to find users favourite
|
||||
streaming category and returns the category
|
||||
"""
|
||||
with Database() as db:
|
||||
category = db.fetchone("""
|
||||
SELECT category_id
|
||||
FROM user_preferences
|
||||
WHERE user_id = ?
|
||||
ORDER BY favourability DESC
|
||||
SELECT category_id
|
||||
FROM user_preferences
|
||||
WHERE user_id = ?
|
||||
ORDER BY favourability DESC
|
||||
LIMIT 1
|
||||
""", (user_id,))
|
||||
return category["category_id"] if category else None
|
||||
|
||||
|
||||
def get_followed_categories_recommendations(user_id: int, no_streams: int = 4) -> Optional[List[dict]]:
|
||||
def get_followed_categories_recommendations(
|
||||
user_id: int, no_streams: int = 4
|
||||
) -> Optional[List[dict]]:
|
||||
"""
|
||||
Returns top streams given a user's category following
|
||||
"""
|
||||
with Database() as db:
|
||||
streams = db.fetchall("""
|
||||
SELECT u.user_id, title, u.username, num_viewers, category_name
|
||||
FROM streams
|
||||
SELECT u.user_id, title, u.username, num_viewers,
|
||||
category_name
|
||||
FROM streams
|
||||
JOIN users u ON streams.user_id = u.user_id
|
||||
JOIN categories ON streams.category_id = categories.category_id
|
||||
WHERE categories.category_id IN (SELECT category_id FROM followed_categories WHERE user_id = ?)
|
||||
JOIN categories
|
||||
ON streams.category_id = categories.category_id
|
||||
WHERE categories.category_id IN
|
||||
(SELECT category_id
|
||||
FROM followed_categories WHERE user_id = ?)
|
||||
ORDER BY num_viewers DESC
|
||||
LIMIT ?;
|
||||
""", (user_id, no_streams))
|
||||
@@ -40,25 +50,29 @@ def get_followed_your_categories(user_id: int) -> Optional[List[dict]]:
|
||||
categories = db.fetchall("""
|
||||
SELECT categories.category_name
|
||||
FROM categories
|
||||
JOIN followed_categories
|
||||
JOIN followed_categories
|
||||
ON categories.category_id = followed_categories.category_id
|
||||
WHERE followed_categories.user_id = ?;
|
||||
""", (user_id,))
|
||||
return categories
|
||||
|
||||
|
||||
def get_streams_based_on_category(category_id: int, no_streams: int = 4, offset: int = 0) -> Optional[List[dict]]:
|
||||
def get_streams_based_on_category(
|
||||
category_id: int, no_streams: int = 4, offset: int = 0
|
||||
) -> Optional[List[dict]]:
|
||||
"""
|
||||
Queries stream database to get top most viewed streams based on given category
|
||||
Queries stream database to get top most viewed streams
|
||||
based on given category
|
||||
"""
|
||||
with Database() as db:
|
||||
streams = db.fetchall("""
|
||||
SELECT u.user_id, title, username, num_viewers, c.category_name
|
||||
SELECT u.user_id, title, username, num_viewers,
|
||||
c.category_name
|
||||
FROM streams s
|
||||
JOIN users u ON s.user_id = u.user_id
|
||||
JOIN categories c ON s.category_id = c.category_id
|
||||
WHERE c.category_id = ?
|
||||
ORDER BY num_viewers DESC
|
||||
WHERE c.category_id = ?
|
||||
ORDER BY num_viewers DESC
|
||||
LIMIT ? OFFSET ?
|
||||
""", (category_id, no_streams, offset))
|
||||
return streams
|
||||
@@ -70,43 +84,63 @@ def get_highest_view_streams(no_streams: int = 4) -> Optional[List[dict]]:
|
||||
"""
|
||||
with Database() as db:
|
||||
data = db.fetchall("""
|
||||
SELECT u.user_id, username, title, num_viewers, category_name
|
||||
FROM streams
|
||||
SELECT u.user_id, username, title, num_viewers,
|
||||
category_name
|
||||
FROM streams
|
||||
JOIN users u ON streams.user_id = u.user_id
|
||||
JOIN categories ON streams.category_id = categories.category_id
|
||||
ORDER BY num_viewers DESC
|
||||
JOIN categories
|
||||
ON streams.category_id = categories.category_id
|
||||
ORDER BY num_viewers DESC
|
||||
LIMIT ?;
|
||||
""", (no_streams,))
|
||||
return data
|
||||
|
||||
def get_highest_view_categories(no_categories: int = 4, offset: int = 0) -> Optional[List[dict]]:
|
||||
def get_highest_view_categories(
|
||||
no_categories: int = 4, offset: int = 0
|
||||
) -> Optional[List[dict]]:
|
||||
"""
|
||||
Returns a list of top most popular categories given offset
|
||||
"""
|
||||
with Database() as db:
|
||||
categories = db.fetchall("""
|
||||
SELECT categories.category_id, categories.category_name, COALESCE(SUM(streams.num_viewers), 0) AS num_viewers
|
||||
SELECT categories.category_id,
|
||||
categories.category_name,
|
||||
COALESCE(SUM(streams.num_viewers), 0)
|
||||
AS num_viewers
|
||||
FROM categories
|
||||
LEFT JOIN streams ON streams.category_id = categories.category_id
|
||||
GROUP BY categories.category_id, categories.category_name
|
||||
LEFT JOIN streams
|
||||
ON streams.category_id = categories.category_id
|
||||
GROUP BY categories.category_id,
|
||||
categories.category_name
|
||||
ORDER BY num_viewers DESC
|
||||
LIMIT ? OFFSET ?;
|
||||
""", (no_categories, offset))
|
||||
return categories
|
||||
|
||||
def get_user_category_recommendations(user_id = 1, no_categories: int = 4) -> Optional[List[dict]]:
|
||||
def get_user_category_recommendations(
|
||||
user_id=1, no_categories: int = 4
|
||||
) -> Optional[List[dict]]:
|
||||
"""
|
||||
Queries user_preferences database to find users top favourite streaming category and returns the category
|
||||
Queries user_preferences database to find users top favourite
|
||||
streaming category and returns the category
|
||||
"""
|
||||
with Database() as db:
|
||||
categories = db.fetchall("""
|
||||
SELECT categories.category_id, categories.category_name, COALESCE(SUM(streams.num_viewers), 0) AS num_viewers
|
||||
FROM categories
|
||||
JOIN user_preferences ON categories.category_id = user_preferences.category_id
|
||||
LEFT JOIN streams ON categories.category_id = streams.category_id
|
||||
WHERE user_preferences.user_id = ?
|
||||
GROUP BY categories.category_id, categories.category_name
|
||||
ORDER BY user_preferences.favourability DESC
|
||||
SELECT categories.category_id,
|
||||
categories.category_name,
|
||||
COALESCE(SUM(streams.num_viewers), 0)
|
||||
AS num_viewers
|
||||
FROM categories
|
||||
JOIN user_preferences
|
||||
ON categories.category_id
|
||||
= user_preferences.category_id
|
||||
LEFT JOIN streams
|
||||
ON categories.category_id
|
||||
= streams.category_id
|
||||
WHERE user_preferences.user_id = ?
|
||||
GROUP BY categories.category_id,
|
||||
categories.category_name
|
||||
ORDER BY user_preferences.favourability DESC
|
||||
LIMIT ?
|
||||
""", (user_id, no_categories))
|
||||
return categories
|
||||
return categories
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from database.database import Database
|
||||
from typing import Optional
|
||||
import os, subprocess
|
||||
"""Stream data retrieval and management utilities."""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from typing import Optional, List
|
||||
from time import sleep
|
||||
from utils.path_manager import PathManager
|
||||
|
||||
from database.database import Database
|
||||
|
||||
|
||||
def get_streamer_live_status(user_id: int):
|
||||
"""
|
||||
@@ -11,8 +13,8 @@ def get_streamer_live_status(user_id: int):
|
||||
"""
|
||||
with Database() as db:
|
||||
is_live = db.fetchone("""
|
||||
SELECT is_live
|
||||
FROM users
|
||||
SELECT is_live
|
||||
FROM users
|
||||
WHERE user_id = ?;
|
||||
""", (user_id,))
|
||||
|
||||
@@ -21,17 +23,19 @@ def get_streamer_live_status(user_id: int):
|
||||
def get_followed_live_streams(user_id: int) -> Optional[List[dict]]:
|
||||
"""
|
||||
Searches for streamers who the user followed which are currently live
|
||||
Returns a list of live streams with the streamer's user id, stream title, and number of viewers
|
||||
Returns a list of live streams with the streamer's user id,
|
||||
stream title, and number of viewers
|
||||
"""
|
||||
with Database() as db:
|
||||
live_streams = db.fetchall("""
|
||||
SELECT users.user_id, streams.title, streams.num_viewers, users.username
|
||||
FROM streams JOIN users
|
||||
ON streams.user_id = users.user_id
|
||||
WHERE users.user_id IN
|
||||
(SELECT followed_id FROM follows WHERE user_id = ?)
|
||||
AND users.is_live = 1;
|
||||
""", (user_id,))
|
||||
SELECT users.user_id, streams.title,
|
||||
streams.num_viewers, users.username
|
||||
FROM streams JOIN users
|
||||
ON streams.user_id = users.user_id
|
||||
WHERE users.user_id IN
|
||||
(SELECT followed_id FROM follows WHERE user_id = ?)
|
||||
AND users.is_live = 1;
|
||||
""", (user_id,))
|
||||
return live_streams
|
||||
|
||||
def get_current_stream_data(user_id: int) -> Optional[dict]:
|
||||
@@ -40,7 +44,8 @@ def get_current_stream_data(user_id: int) -> Optional[dict]:
|
||||
"""
|
||||
with Database() as db:
|
||||
most_recent_stream = db.fetchone("""
|
||||
SELECT s.user_id, u.username, s.title, s.start_time, s.num_viewers, c.category_name, c.category_id
|
||||
SELECT s.user_id, u.username, s.title, s.start_time,
|
||||
s.num_viewers, c.category_name, c.category_id
|
||||
FROM streams AS s
|
||||
JOIN categories AS c ON s.category_id = c.category_id
|
||||
JOIN users AS u ON s.user_id = u.user_id
|
||||
@@ -51,78 +56,92 @@ def get_current_stream_data(user_id: int) -> Optional[dict]:
|
||||
def end_user_stream(stream_key, user_id, username):
|
||||
"""
|
||||
Utility function to end a user's stream
|
||||
|
||||
|
||||
Parameters:
|
||||
stream_key: The stream key of the user
|
||||
user_id: The ID of the user
|
||||
username: The username of the user
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if stream was ended successfully, False otherwise
|
||||
"""
|
||||
from flask import current_app
|
||||
from datetime import datetime
|
||||
from dateutil import parser
|
||||
from celery_tasks.streaming import combine_ts_stream
|
||||
from utils.path_manager import PathManager
|
||||
|
||||
path_manager = PathManager()
|
||||
|
||||
pm = PathManager()
|
||||
print(f"Ending stream for user {username} (ID: {user_id})", flush=True)
|
||||
|
||||
|
||||
if not stream_key or not user_id or not username:
|
||||
print("Cannot end stream - missing required information", flush=True)
|
||||
return False
|
||||
|
||||
return False, "Missing required information"
|
||||
|
||||
try:
|
||||
# Open database connection
|
||||
with Database() as db:
|
||||
# Get stream info
|
||||
stream_info = db.fetchone("""SELECT *
|
||||
FROM streams
|
||||
WHERE user_id = ?""", (user_id,))
|
||||
|
||||
stream_info = db.fetchone(
|
||||
"""SELECT *
|
||||
FROM streams
|
||||
WHERE user_id = ?""", (user_id,))
|
||||
|
||||
# If user is not streaming, just return
|
||||
if not stream_info:
|
||||
print(f"User {username} (ID: {user_id}) is not streaming", flush=True)
|
||||
print(
|
||||
f"User {username} (ID: {user_id}) "
|
||||
"is not streaming", flush=True)
|
||||
return True, "User is not streaming"
|
||||
|
||||
|
||||
# Remove stream from database
|
||||
db.execute("""DELETE FROM streams
|
||||
WHERE user_id = ?""", (user_id,))
|
||||
|
||||
db.execute(
|
||||
"""DELETE FROM streams
|
||||
WHERE user_id = ?""", (user_id,))
|
||||
|
||||
# Move stream to vod table
|
||||
stream_length = int(
|
||||
(datetime.now() - parser.parse(stream_info.get("start_time"))).total_seconds())
|
||||
|
||||
db.execute("""INSERT INTO vods (user_id, title, datetime, category_id, length, views)
|
||||
VALUES (?, ?, ?, ?, ?, ?)""", (user_id,
|
||||
stream_info.get("title"),
|
||||
stream_info.get("start_time"),
|
||||
stream_info.get("category_id"),
|
||||
stream_length,
|
||||
0))
|
||||
|
||||
(datetime.now() - parser.parse(
|
||||
stream_info.get("start_time")
|
||||
)).total_seconds())
|
||||
|
||||
db.execute(
|
||||
"""INSERT INTO vods
|
||||
(user_id, title, datetime, category_id,
|
||||
length, views)
|
||||
VALUES (?, ?, ?, ?, ?, ?)""",
|
||||
(user_id,
|
||||
stream_info.get("title"),
|
||||
stream_info.get("start_time"),
|
||||
stream_info.get("category_id"),
|
||||
stream_length,
|
||||
0))
|
||||
|
||||
vod_id = db.get_last_insert_id()
|
||||
|
||||
|
||||
# Set user as not streaming
|
||||
db.execute("""UPDATE users
|
||||
SET is_live = 0
|
||||
WHERE user_id = ?""", (user_id,))
|
||||
|
||||
db.execute(
|
||||
"""UPDATE users
|
||||
SET is_live = 0
|
||||
WHERE user_id = ?""", (user_id,))
|
||||
|
||||
# Queue task to combine TS files into MP4
|
||||
combine_ts_stream.delay(
|
||||
path_manager.get_stream_path(username),
|
||||
path_manager.get_vods_path(username),
|
||||
pm.get_stream_path(username),
|
||||
pm.get_vods_path(username),
|
||||
vod_id,
|
||||
path_manager.get_vod_thumbnail_file_path(username, vod_id)
|
||||
pm.get_vod_thumbnail_file_path(username, vod_id)
|
||||
)
|
||||
|
||||
print(f"Stream ended for user {username} (ID: {user_id})", flush=True)
|
||||
|
||||
print(
|
||||
f"Stream ended for user {username} (ID: {user_id})",
|
||||
flush=True)
|
||||
return True, "Stream ended successfully"
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error ending stream for user {username}: {str(e)}", flush=True)
|
||||
return False, f"Error ending stream: {str(e)}"
|
||||
|
||||
except (ValueError, TypeError, KeyError) as exc:
|
||||
print(
|
||||
f"Error ending stream for user {username}: {str(exc)}",
|
||||
flush=True)
|
||||
return False, f"Error ending stream: {str(exc)}"
|
||||
|
||||
def get_category_id(category_name: str) -> Optional[int]:
|
||||
"""
|
||||
@@ -130,8 +149,8 @@ def get_category_id(category_name: str) -> Optional[int]:
|
||||
"""
|
||||
with Database() as db:
|
||||
data = db.fetchone("""
|
||||
SELECT category_id
|
||||
FROM categories
|
||||
SELECT category_id
|
||||
FROM categories
|
||||
WHERE category_name = ?;
|
||||
""", (category_name,))
|
||||
return data['category_id'] if data else None
|
||||
@@ -142,7 +161,7 @@ def get_custom_thumbnail_status(user_id: int) -> Optional[dict]:
|
||||
"""
|
||||
with Database() as db:
|
||||
custom_thumbnail = db.fetchone("""
|
||||
SELECT custom_thumbnail
|
||||
SELECT custom_thumbnail
|
||||
FROM streams
|
||||
WHERE user_id = ?;
|
||||
""", (user_id,))
|
||||
@@ -153,7 +172,13 @@ def get_vod(vod_id: int) -> dict:
|
||||
Returns data of a streamers vod
|
||||
"""
|
||||
with Database() as db:
|
||||
vod = db.fetchone("""SELECT vods.*, username, category_name FROM vods JOIN users ON vods.user_id = users.user_id JOIN categories ON vods.category_id = categories.category_id WHERE vod_id = ?;""", (vod_id,))
|
||||
vod = db.fetchone(
|
||||
"""SELECT vods.*, username, category_name
|
||||
FROM vods
|
||||
JOIN users ON vods.user_id = users.user_id
|
||||
JOIN categories
|
||||
ON vods.category_id = categories.category_id
|
||||
WHERE vod_id = ?;""", (vod_id,))
|
||||
return vod
|
||||
|
||||
def get_latest_vod(user_id: int):
|
||||
@@ -161,7 +186,14 @@ def get_latest_vod(user_id: int):
|
||||
Returns data of the most recent stream by a streamer
|
||||
"""
|
||||
with Database() as db:
|
||||
latest_vod = db.fetchone("""SELECT vods.*, username, category_name FROM vods JOIN users ON vods.user_id = users.user_id JOIN categories ON vods.category_id = categories.category_id WHERE vods.user_id = ? ORDER BY vod_id DESC;""", (user_id,))
|
||||
latest_vod = db.fetchone(
|
||||
"""SELECT vods.*, username, category_name
|
||||
FROM vods
|
||||
JOIN users ON vods.user_id = users.user_id
|
||||
JOIN categories
|
||||
ON vods.category_id = categories.category_id
|
||||
WHERE vods.user_id = ?
|
||||
ORDER BY vod_id DESC;""", (user_id,))
|
||||
return latest_vod
|
||||
|
||||
def get_user_vods(user_id: int):
|
||||
@@ -169,10 +201,17 @@ def get_user_vods(user_id: int):
|
||||
Returns data of all vods by a streamer
|
||||
"""
|
||||
with Database() as db:
|
||||
vods = db.fetchall("""SELECT vods.*, username, category_name FROM vods JOIN users ON vods.user_id = users.user_id JOIN categories ON vods.category_id = categories.category_id WHERE vods.user_id = ? ORDER BY vod_id DESC;""", (user_id,))
|
||||
vods = db.fetchall(
|
||||
"""SELECT vods.*, username, category_name
|
||||
FROM vods
|
||||
JOIN users ON vods.user_id = users.user_id
|
||||
JOIN categories
|
||||
ON vods.category_id = categories.category_id
|
||||
WHERE vods.user_id = ?
|
||||
ORDER BY vod_id DESC;""", (user_id,))
|
||||
return vods
|
||||
|
||||
def generate_thumbnail(stream_file: str, thumbnail_file: str, second_capture) -> None:
|
||||
def generate_thumbnail(stream_file: str, thumbnail_file: str, second_capture=0) -> None:
|
||||
"""
|
||||
Generates the thumbnail of a stream
|
||||
"""
|
||||
@@ -192,10 +231,16 @@ def generate_thumbnail(stream_file: str, thumbnail_file: str, second_capture) ->
|
||||
]
|
||||
|
||||
try:
|
||||
subprocess.run(thumbnail_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True)
|
||||
subprocess.run(
|
||||
thumbnail_command,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
check=True)
|
||||
print(f"Thumbnail generated for {stream_file}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"No information available for {stream_file}, aborting thumbnail generation")
|
||||
except subprocess.CalledProcessError:
|
||||
print(
|
||||
f"No information available for {stream_file}, "
|
||||
"aborting thumbnail generation")
|
||||
|
||||
def remove_hls_files(stream_path: str) -> None:
|
||||
"""
|
||||
@@ -221,11 +266,13 @@ def get_video_duration(video_path: str) -> int:
|
||||
]
|
||||
|
||||
try:
|
||||
video_length = subprocess.check_output(video_length_command).decode("utf-8")
|
||||
video_length = subprocess.check_output(
|
||||
video_length_command
|
||||
).decode("utf-8")
|
||||
print(f"Video length: {video_length}")
|
||||
return int(float(video_length))
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error getting video length: {e}")
|
||||
except subprocess.CalledProcessError:
|
||||
print("Error getting video length")
|
||||
return 0
|
||||
|
||||
def get_stream_tags(user_id: int) -> Optional[List[str]]:
|
||||
@@ -234,10 +281,10 @@ def get_stream_tags(user_id: int) -> Optional[List[str]]:
|
||||
"""
|
||||
with Database() as db:
|
||||
tags = db.fetchall("""
|
||||
SELECT tag_name
|
||||
SELECT tag_name
|
||||
FROM tags
|
||||
JOIN stream_tags ON tags.tag_id = stream_tags.tag_id
|
||||
WHERE user_id = ?;
|
||||
WHERE user_id = ?;
|
||||
""", (user_id,))
|
||||
return tags
|
||||
|
||||
@@ -247,9 +294,9 @@ def get_vod_tags(vod_id: int):
|
||||
"""
|
||||
with Database() as db:
|
||||
tags = db.fetchall("""
|
||||
SELECT tag_name
|
||||
SELECT tag_name
|
||||
FROM tags
|
||||
JOIN vod_tags ON tags.tag_id = vod_tags.tag_id
|
||||
WHERE vod_id = ?;
|
||||
WHERE vod_id = ?;
|
||||
""", (vod_id,))
|
||||
return tags
|
||||
return tags
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
from database.database import Database
|
||||
"""User profile management, following, and subscription utilities."""
|
||||
|
||||
from typing import Optional, List
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from database.database import Database
|
||||
from dateutil import parser
|
||||
|
||||
|
||||
def get_user_id(username: str) -> Optional[int]:
|
||||
"""
|
||||
Returns user_id associated with given username
|
||||
"""
|
||||
with Database() as db:
|
||||
data = db.fetchone("""
|
||||
SELECT user_id
|
||||
FROM users
|
||||
SELECT user_id
|
||||
FROM users
|
||||
WHERE username = ?
|
||||
""", (username,))
|
||||
return data['user_id'] if data else None
|
||||
@@ -21,7 +25,7 @@ def get_username(user_id: str) -> Optional[str]:
|
||||
"""
|
||||
with Database() as db:
|
||||
data = db.fetchone("""
|
||||
SELECT username
|
||||
SELECT username
|
||||
FROM users
|
||||
WHERE user_id = ?
|
||||
""", (user_id,))
|
||||
@@ -40,15 +44,16 @@ def update_bio(user_id: int, bio: str):
|
||||
|
||||
def has_password(email: str):
|
||||
"""
|
||||
Returns if account associated with this email has password, i.e is account from Google OAuth
|
||||
Returns if account associated with this email has password,
|
||||
i.e is account from Google OAuth
|
||||
"""
|
||||
with Database() as db:
|
||||
data = db.fetchone("""
|
||||
SELECT password
|
||||
FROM users
|
||||
WHERE email = ?
|
||||
WHERE email = ?
|
||||
""", (email,))
|
||||
return False if data["password"] == None else True
|
||||
return data["password"] is not None
|
||||
|
||||
def get_session_info_email(email: str) -> dict:
|
||||
"""
|
||||
@@ -61,15 +66,15 @@ def get_session_info_email(email: str) -> dict:
|
||||
WHERE email = ?
|
||||
""", (email,))
|
||||
return session_info
|
||||
|
||||
|
||||
def is_user_partner(user_id: int) -> bool:
|
||||
"""
|
||||
Returns True if user is a partner, else False
|
||||
"""
|
||||
with Database() as db:
|
||||
data = db.fetchone("""
|
||||
SELECT is_partnered
|
||||
FROM users
|
||||
SELECT is_partnered
|
||||
FROM users
|
||||
WHERE user_id = ?
|
||||
""", (user_id,))
|
||||
return bool(data)
|
||||
@@ -79,12 +84,12 @@ def is_subscribed(user_id: int, subscribed_to_id: int) -> bool:
|
||||
with Database() as db:
|
||||
return bool(db.fetchone(
|
||||
"""
|
||||
SELECT 1
|
||||
FROM subscribes
|
||||
WHERE user_id = ?
|
||||
AND subscribed_id = ?
|
||||
SELECT 1
|
||||
FROM subscribes
|
||||
WHERE user_id = ?
|
||||
AND subscribed_id = ?
|
||||
AND expires > ?;
|
||||
""",
|
||||
""",
|
||||
(user_id, subscribed_to_id, datetime.now())
|
||||
))
|
||||
|
||||
@@ -94,9 +99,9 @@ def is_following(user_id: int, followed_id: int) -> bool:
|
||||
"""
|
||||
with Database() as db:
|
||||
result = db.fetchone("""
|
||||
SELECT 1
|
||||
FROM follows
|
||||
WHERE user_id = ?
|
||||
SELECT 1
|
||||
FROM follows
|
||||
WHERE user_id = ?
|
||||
AND followed_id = ?
|
||||
""", (user_id, followed_id))
|
||||
return bool(result)
|
||||
@@ -107,7 +112,7 @@ def follow(user_id: int, followed_id: int):
|
||||
"""
|
||||
if is_following(user_id, followed_id):
|
||||
return {"success": False, "error": "Already following user"}, 400
|
||||
|
||||
|
||||
with Database() as db:
|
||||
db.execute("""
|
||||
INSERT INTO follows (user_id, followed_id)
|
||||
@@ -135,9 +140,9 @@ def is_following_category(user_id: int, category_id: str):
|
||||
"""
|
||||
with Database() as db:
|
||||
result = db.fetchone("""
|
||||
SELECT 1
|
||||
FROM followed_categories
|
||||
WHERE user_id = ?
|
||||
SELECT 1
|
||||
FROM followed_categories
|
||||
WHERE user_id = ?
|
||||
AND category_id = ?
|
||||
""", (user_id, category_id))
|
||||
return bool(result)
|
||||
@@ -148,7 +153,7 @@ def follow_category(user_id: int, category_id: str):
|
||||
"""
|
||||
if is_following_category(user_id, category_id):
|
||||
return {"success": False, "error": "Already following category"}, 400
|
||||
|
||||
|
||||
with Database() as db:
|
||||
db.execute("""
|
||||
INSERT INTO followed_categories (user_id, category_id)
|
||||
@@ -163,7 +168,7 @@ def unfollow_category(user_id: int, category_id: str):
|
||||
"""
|
||||
if not is_following_category(user_id, category_id):
|
||||
return {"success": False, "error": "Not following category"}, 400
|
||||
|
||||
|
||||
with Database() as db:
|
||||
db.execute("""
|
||||
DELETE FROM followed_categories
|
||||
@@ -179,8 +184,8 @@ def subscribe(user_id: int, streamer_id: int):
|
||||
# If user is already subscribed then extend the expiration date else create a new entry
|
||||
with Database() as db:
|
||||
existing = db.fetchone("""
|
||||
SELECT expires
|
||||
FROM subscribes
|
||||
SELECT expires
|
||||
FROM subscribes
|
||||
WHERE user_id = ? AND subscribed_id = ?
|
||||
""", (user_id, streamer_id))
|
||||
if existing:
|
||||
@@ -188,7 +193,7 @@ def subscribe(user_id: int, streamer_id: int):
|
||||
UPDATE subscribes SET expires = expires + ?
|
||||
WHERE user_id = ? AND subscribed_id = ?
|
||||
""", (timedelta(days=30), user_id, streamer_id))
|
||||
else:
|
||||
else:
|
||||
db.execute("""
|
||||
INSERT INTO subscribes
|
||||
(user_id, subscribed_id, since, expires)
|
||||
@@ -212,10 +217,10 @@ def subscription_expiration(user_id: int, subscribed_id: int) -> int:
|
||||
"""
|
||||
with Database() as db:
|
||||
data = db.fetchone("""
|
||||
SELECT expires
|
||||
FROM subscribes
|
||||
WHERE user_id = ?
|
||||
AND subscribed_id = ?
|
||||
SELECT expires
|
||||
FROM subscribes
|
||||
WHERE user_id = ?
|
||||
AND subscribed_id = ?
|
||||
AND expires > ?
|
||||
""", (user_id, subscribed_id, datetime.now()))
|
||||
|
||||
@@ -227,13 +232,14 @@ def subscription_expiration(user_id: int, subscribed_id: int) -> int:
|
||||
return 0
|
||||
|
||||
def get_email(user_id: int) -> Optional[str]:
|
||||
"""Returns the email address for a given user_id."""
|
||||
with Database() as db:
|
||||
email = db.fetchone("""
|
||||
SELECT email
|
||||
FROM users
|
||||
WHERE user_id = ?
|
||||
""", (user_id,))
|
||||
|
||||
|
||||
return email["email"] if email else None
|
||||
|
||||
def get_followed_streamers(user_id: int) -> Optional[List[dict]]:
|
||||
@@ -289,4 +295,4 @@ def get_user(user_id: int) -> Optional[dict]:
|
||||
SELECT user_id, username, bio, num_followers, is_partnered, is_live FROM users
|
||||
WHERE user_id = ?;
|
||||
""", (user_id,))
|
||||
return data
|
||||
return data
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
from database.database import Database
|
||||
"""Input sanitization and validation utilities."""
|
||||
|
||||
from typing import Optional, List
|
||||
from re import match
|
||||
|
||||
from database.database import Database
|
||||
|
||||
|
||||
def get_all_categories() -> Optional[List[dict]]:
|
||||
"""
|
||||
Returns all possible streaming categories
|
||||
"""
|
||||
with Database() as db:
|
||||
all_categories = db.fetchall("SELECT * FROM categories")
|
||||
|
||||
|
||||
return all_categories
|
||||
|
||||
def get_all_tags() -> Optional[List[dict]]:
|
||||
@@ -17,7 +21,7 @@ def get_all_tags() -> Optional[List[dict]]:
|
||||
"""
|
||||
with Database() as db:
|
||||
all_tags = db.fetchall("SELECT * FROM tags")
|
||||
|
||||
|
||||
return all_tags
|
||||
|
||||
def get_most_popular_category() -> Optional[List[dict]]:
|
||||
@@ -34,7 +38,7 @@ def get_most_popular_category() -> Optional[List[dict]]:
|
||||
ORDER BY SUM(streams.num_viewers) DESC
|
||||
LIMIT 1;
|
||||
""")
|
||||
|
||||
|
||||
return category
|
||||
|
||||
def get_category_id(category_name: str):
|
||||
@@ -53,7 +57,7 @@ def get_category_id(category_name: str):
|
||||
def sanitize(user_input: str, input_type="default") -> str:
|
||||
"""
|
||||
Sanitizes user input based on the specified input type.
|
||||
|
||||
|
||||
`input_type`: The type of input to sanitize (e.g., 'username', 'email', 'password').
|
||||
"""
|
||||
# Strip leading and trailing whitespace
|
||||
@@ -84,10 +88,10 @@ def sanitize(user_input: str, input_type="default") -> str:
|
||||
}
|
||||
|
||||
# Get the validation rules for the specified type
|
||||
r = rules.get(input_type)
|
||||
if not r or \
|
||||
not (r["min_length"] <= len(sanitised_input) <= r["max_length"]) or \
|
||||
not match(r["pattern"], sanitised_input):
|
||||
rule = rules.get(input_type)
|
||||
if (not rule
|
||||
or not (rule["min_length"] <= len(sanitised_input) <= rule["max_length"])
|
||||
or not match(rule["pattern"], sanitised_input)):
|
||||
raise ValueError("Unaccepted character or length in input")
|
||||
|
||||
return sanitised_input
|
||||
return sanitised_input
|
||||
|
||||
Reference in New Issue
Block a user