MAJOR: Restructured backend Flask application moved all non-routes into utils, renamed routes to not prefix get, created middleware.py to replace utils.py within blueprints

This commit is contained in:
JustIceO7
2025-02-08 14:35:46 +00:00
parent ae623eee0d
commit e6b8ad9b9e
16 changed files with 631 additions and 550 deletions

View File

@@ -1,8 +1,7 @@
from flask import Flask from flask import Flask
from flask_session import Session from flask_session import Session
from flask_cors import CORS from flask_cors import CORS
from blueprints.utils import logged_in_user from blueprints.middleware import logged_in_user, register_error_handlers
from blueprints.errorhandlers import register_error_handlers
# from flask_wtf.csrf import CSRFProtect, generate_csrf # from flask_wtf.csrf import CSRFProtect, generate_csrf
from blueprints.authentication import auth_bp from blueprints.authentication import auth_bp

View File

@@ -1,5 +1,5 @@
from flask import Blueprint from flask import Blueprint
from blueprints.utils import admin_required from blueprints.middleware import admin_required
admin_bp = Blueprint("admin", __name__) admin_bp = Blueprint("admin", __name__)

View File

@@ -2,9 +2,10 @@ from flask import Blueprint, session, request, jsonify
from werkzeug.security import generate_password_hash, check_password_hash from werkzeug.security import generate_password_hash, check_password_hash
from flask_cors import cross_origin from flask_cors import cross_origin
from database.database import Database from database.database import Database
from blueprints.utils import login_required, sanitizer from blueprints.middleware import login_required
from blueprints.email import send_email from utils.email import send_email
from blueprints.user import get_user_id from utils.user_utils import get_user_id
from utils.utils import sanitize
from secrets import token_hex from secrets import token_hex
auth_bp = Blueprint("auth", __name__) auth_bp = Blueprint("auth", __name__)
@@ -37,9 +38,9 @@ def signup():
# Sanitize the inputs - helps to prevent SQL injection # Sanitize the inputs - helps to prevent SQL injection
try: try:
username = sanitizer(username, "username") username = sanitize(username, "username")
email = sanitizer(email, "email") email = sanitize(email, "email")
password = sanitizer(password, "password") password = sanitize(password, "password")
except ValueError as e: except ValueError as e:
error_fields = get_error_fields([username, email, password]) error_fields = get_error_fields([username, email, password])
return jsonify({ return jsonify({
@@ -144,8 +145,8 @@ def login():
# Sanitize the inputs - helps to prevent SQL injection # Sanitize the inputs - helps to prevent SQL injection
try: try:
username = sanitizer(username, "username") username = sanitize(username, "username")
password = sanitizer(password, "password") password = sanitize(password, "password")
except ValueError as e: except ValueError as e:
return jsonify({ return jsonify({
"account_created": False, "account_created": False,

View File

@@ -3,7 +3,7 @@ from database.database import Database
from .socket import socketio from .socket import socketio
from flask_socketio import emit, join_room, leave_room from flask_socketio import emit, join_room, leave_room
from datetime import datetime from datetime import datetime
from blueprints.user import get_user_id from utils.user_utils import get_user_id
chat_bp = Blueprint("chat", __name__) chat_bp = Blueprint("chat", __name__)
@@ -117,4 +117,4 @@ def save_chat(chatter_id, stream_id, message):
db.execute(""" db.execute("""
INSERT INTO chat (chatter_id, stream_id, message) INSERT INTO chat (chatter_id, stream_id, message)
VALUES (?, ?, ?);""", (chatter_id, stream_id, message)) VALUES (?, ?, ?);""", (chatter_id, stream_id, message))
db.close_connection() db.close_connection()

View File

@@ -1,14 +0,0 @@
import logging
def register_error_handlers(app):
error_responses = {
400: "Bad Request",
403: "Forbidden",
404: "Not Found",
500: "Internal Server Error"
}
for code, message in error_responses.items():
@app.errorhandler(code)
def handle_error(error, message=message, code=code):
logging.error(f"Error {code}: {str(error)}")
return {"error": message}, code

View File

@@ -0,0 +1,46 @@
from flask import redirect, url_for, request, g, session
from functools import wraps
import logging
def logged_in_user():
"""
Validator to make sure a user is logged in.
"""
g.user = session.get("username", None)
g.admin = session.get("username", None)
def login_required(view):
"""
Add at start of routes where users need to be logged in to access.
"""
@wraps(view)
def wrapped_view(*args, **kwargs):
if g.user is None:
return redirect(url_for("login", next=request.url))
return view(*args, **kwargs)
return wrapped_view
def admin_required(view):
"""
Add at start of routes where admins need to be logged in to access.
"""
@wraps(view)
def wrapped_view(*args, **kwargs):
if g.admin != "admin":
return redirect(url_for("login", next=request.url))
return view(*args, **kwargs)
return wrapped_view
def register_error_handlers(app):
error_responses = {
400: "Bad Request",
403: "Forbidden",
404: "Not Found",
500: "Internal Server Error"
}
for code, message in error_responses.items():
@app.errorhandler(code)
def handle_error(error, message=message, code=code):
logging.error(f"Error {code}: {str(error)}")
return {"error": message}, code

View File

@@ -3,7 +3,7 @@ from database.database import Database
search_bp = Blueprint("search", __name__) search_bp = Blueprint("search", __name__)
@search_bp.route("/search/<str:query>", methods=["GET", "POST"]) @search_bp.route("/search/<string:query>", methods=["GET", "POST"])
def search_results(query: str): def search_results(query: str):
""" """
Return the most similar search results Return the most similar search results
@@ -46,7 +46,7 @@ def search_results(query: str):
return jsonify({"categories": categories, "users": users, "streams": streams}) return jsonify({"categories": categories, "users": users, "streams": streams})
@search_bp.route("/search/categories/<str:query>", methods=["GET", "POST"]) @search_bp.route("/search/categories/<string:query>", methods=["GET", "POST"])
def search_categories(query: str): def search_categories(query: str):
# Create the connection to the database # Create the connection to the database
db = Database() db = Database()
@@ -63,7 +63,7 @@ def search_categories(query: str):
return jsonify({"categories": categories}) return jsonify({"categories": categories})
@search_bp.route("/search/users/<str:query>", methods=["GET", "POST"]) @search_bp.route("/search/users/<string:query>", methods=["GET", "POST"])
def search_users(query: str): def search_users(query: str):
# Create the connection to the database # Create the connection to the database
db = Database() db = Database()
@@ -81,7 +81,7 @@ def search_users(query: str):
return jsonify({"users": users}) return jsonify({"users": users})
@search_bp.route("/search/streams/<str:query>", methods=["GET", "POST"]) @search_bp.route("/search/streams/<string:query>", methods=["GET", "POST"])
def search_streams(query: str): def search_streams(query: str):
# Create the connection to the database # Create the connection to the database
db = Database() db = Database()

View File

@@ -1,11 +1,11 @@
from flask import Blueprint, session, jsonify, request, redirect from flask import Blueprint, session, jsonify, request, redirect
from utils.stream_utils import * from utils.stream_utils import *
from blueprints.user import get_user_id from utils.recommendation_utils import *
from blueprints.utils import login_required from utils.user_utils import get_user_id
from blueprints.middleware import login_required
from database.database import Database from database.database import Database
from datetime import datetime from datetime import datetime
from celery_tasks import update_thumbnail from celery_tasks import update_thumbnail
from typing import List, Optional
stream_bp = Blueprint("stream", __name__) stream_bp = Blueprint("stream", __name__)
@@ -15,7 +15,7 @@ THUMBNAIL_GENERATION_INTERVAL = 180
## Stream Routes ## Stream Routes
@stream_bp.route('/streams/popular/<int:no_streams>') @stream_bp.route('/streams/popular/<int:no_streams>')
def get_popular_streams(no_streams) -> list[dict]: def popular_streams(no_streams) -> list[dict]:
""" """
Returns a list of streams live now with the highest viewers Returns a list of streams live now with the highest viewers
""" """
@@ -28,48 +28,23 @@ def get_popular_streams(no_streams) -> list[dict]:
no_streams = MAX_STREAMS no_streams = MAX_STREAMS
# Get the highest viewed streams # Get the highest viewed streams
with Database() as db: streams = get_highest_view_streams(no_streams)
streams = db.fetchall("""
SELECT u.user_id, username, title, num_viewers, category_name
FROM streams
JOIN users u ON streams.user_id = u.user_id
JOIN categories ON streams.category_id = categories.category_id
ORDER BY num_viewers DESC
LIMIT ?;
""", (no_streams,))
return jsonify(streams) return jsonify(streams)
@stream_bp.route('/streams/popular/<string:category_name>') @stream_bp.route('/streams/popular/<string:category_name>')
def get_popular_streams_by_category(category_name) -> list[dict]: def popular_streams_by_category(category_name) -> list[dict]:
""" """
Returns a list of streams live now with the highest viewers in a given category Returns a list of streams live now with the highest viewers in a given category
""" """
with Database() as db: category_id = get_category_id(category_name)
data = db.fetchone("""
SELECT category_id
FROM categories
WHERE category_name = ?;
""", (category_name,))
category_id = data["category_id"] if data else None
streams = db.fetchall("""
SELECT u.user_id, title, username, num_viewers, c.category_name
FROM streams s
JOIN users u ON s.user_id = u.user_id
JOIN categories c ON s.category_id = c.category_id
WHERE c.category_id = ?
ORDER BY num_viewers DESC
LIMIT 25
""", (category_id,))
streams = get_streams_based_on_category(category_id)
return jsonify(streams) return jsonify(streams)
@login_required @login_required
@stream_bp.route('/streams/recommended') @stream_bp.route('/streams/recommended')
def get_recommended_streams() -> list[dict]: def recommended_streams() -> list[dict]:
""" """
Queries DB to get a list of recommended streams using an algorithm Queries DB to get a list of recommended streams using an algorithm
""" """
@@ -77,61 +52,22 @@ def get_recommended_streams() -> list[dict]:
user_id = session.get("user_id") user_id = session.get("user_id")
# Get the user's most popular categories # Get the user's most popular categories
with Database() as db: category = get_user_preferred_category(user_id)
category = db.fetchone(""" streams = get_streams_based_on_category(category)
SELECT category_id
FROM user_preferences
WHERE user_id = ?
ORDER BY favourability DESC
LIMIT 1
""", (user_id,))
category_id = category["category_id"] if category else None
streams = db.fetchall("""
SELECT u.user_id, title, username, num_viewers, c.category_name
FROM streams s
JOIN users u ON s.user_id = u.user_id
JOIN categories c ON s.category_id = c.category_id
WHERE c.category_id = ?
ORDER BY num_viewers DESC
LIMIT 25
""", (category_id,))
return streams return streams
@stream_bp.route('/streams/<int:streamer_id>/data') @stream_bp.route('/streams/<int:streamer_id>/data')
def get_stream_data(streamer_id): def stream_data(streamer_id):
""" """
Returns a streamer's current stream data Returns a streamer's current stream data
""" """
with Database() as db:
most_recent_stream = db.fetchone(""" return jsonify(get_current_stream_data(streamer_id))
SELECT s.user_id, u.username, s.title, s.start_time, s.num_viewers, c.category_name
FROM streams AS s
JOIN categories AS c ON s.category_id = c.category_id
JOIN users AS u ON s.user_id = u.user_id
WHERE u.user_id = ?
""", (streamer_id,))
return jsonify(most_recent_stream)
def get_stream_tags(user_id: int) -> Optional[List[str]]:
"""
Given a stream return tags associated with the user's stream
"""
with Database() as db:
tags = db.fetchall("""
SELECT tag_name
FROM tags
JOIN stream_tags ON tags.tag_id = stream_tags.tag_id
WHERE user_id = ?;
""", (user_id,))
return tags
## Category Routes ## Category Routes
@stream_bp.route('/categories/popular/<int:no_categories>') @stream_bp.route('/categories/popular/<int:no_categories>')
def get_popular_categories(no_categories) -> list[dict]: def popular_categories(no_categories) -> list[dict]:
""" """
Returns a list of most popular categories Returns a list of most popular categories
""" """
@@ -141,62 +77,34 @@ def get_popular_categories(no_categories) -> list[dict]:
elif no_categories > 100: elif no_categories > 100:
no_categories = 100 no_categories = 100
with Database() as db: category_data = get_highest_view_categories(no_categories)
category_data = db.fetchall("""
SELECT categories.category_id, categories.category_name, SUM(streams.num_viewers) AS num_viewers
FROM streams
JOIN categories ON streams.category_id = categories.category_id
GROUP BY categories.category_name
ORDER BY SUM(streams.num_viewers) DESC
LIMIT ?;
""", (no_categories,))
return jsonify(category_data) return jsonify(category_data)
@login_required @login_required
@stream_bp.route('/categories/recommended') @stream_bp.route('/categories/recommended')
def get_recommended_categories() -> list | list[dict]: def recommended_categories() -> list | list[dict]:
""" """
Queries DB to get a list of recommended categories for the user Queries DB to get a list of recommended categories for the user
""" """
user_id = session.get("user_id") user_id = session.get("user_id")
categories = get_user_category_recommendations(user_id)
with Database() as db:
categories = db.fetchall("""
SELECT categories.category_id, categories.category_name
FROM categories
JOIN user_preferences ON categories.category_id = user_preferences.category_id
WHERE user_id = ?
ORDER BY favourability DESC
LIMIT 5
""", (user_id,))
return jsonify(categories) return jsonify(categories)
@login_required @login_required
@stream_bp.route('/categories/following') @stream_bp.route('/categories/following')
def get_following_categories_streams(): def following_categories_streams():
""" """
Returns popular streams in categories which the user followed Returns popular streams in categories which the user followed
""" """
with Database() as db:
streams = db.fetchall("""
SELECT u.user_id, title, u.username, num_viewers, category_name
FROM streams
JOIN users u ON streams.user_id = u.user_id
JOIN categories ON streams.category_id = categories.category_id
WHERE categories.category_id IN (SELECT category_id FROM followed_categories WHERE user_id = ?)
ORDER BY num_viewers DESC
LIMIT 25;
""", (session.get("user_id"),))
streams = get_followed_categories_recommendations(session.get('user_id'))
return jsonify(streams) return jsonify(streams)
## User Routes ## User Routes
@stream_bp.route('/user/<string:username>/status') @stream_bp.route('/user/<string:username>/status')
def get_user_live_status(username): def user_live_status(username):
""" """
Returns a streamer's status, if they are live or not and their most recent stream (as a vod) (their current stream if live) Returns a streamer's status, if they are live or not and their most recent stream (as a vod) (their current stream if live)
""" """
@@ -204,9 +112,7 @@ def get_user_live_status(username):
# Check if streamer is live and get their most recent vod # Check if streamer is live and get their most recent vod
is_live = True if get_streamer_live_status(user_id)['is_live'] else False is_live = True if get_streamer_live_status(user_id)['is_live'] else False
most_recent_vod = get_latest_vod(user_id)
with Database() as db:
most_recent_vod = db.fetchone("""SELECT * FROM vods WHERE user_id = ? ORDER BY vod_id DESC LIMIT 1;""", (user_id,))
# If there is no most recent vod, set it to None # If there is no most recent vod, set it to None
if not most_recent_vod: if not most_recent_vod:
@@ -222,30 +128,14 @@ def get_user_live_status(username):
## VOD Routes ## VOD Routes
@stream_bp.route('/vods/<string:username>') @stream_bp.route('/vods/<string:username>')
def get_vods(username): def vods(username):
""" """
Returns a JSON of all the vods of a streamer Returns a JSON of all the vods of a streamer
""" """
user_id = get_user_id(username) user_id = get_user_id(username)
vods = get_user_vods(user_id)
with Database() as db:
vods = db.fetchall("""SELECT * FROM vods WHERE user_id = ?;""", (user_id,))
return jsonify(vods) return jsonify(vods)
def get_vod_tags(vod_id: int):
"""
Given a vod return tags associated with the vod
"""
with Database() as db:
tags = db.fetchall("""
SELECT tag_name
FROM tags
JOIN vod_tags ON tags.tag_id = vod_tags.tag_id
WHERE vod_id = ?;
""", (vod_id,))
return tags
## RTMP Server Routes ## RTMP Server Routes
@stream_bp.route("/publish_stream", methods=["POST"]) @stream_bp.route("/publish_stream", methods=["POST"])
@@ -293,31 +183,4 @@ def end_stream():
# Remove stream from database # Remove stream from database
db.execute("""DELETE FROM streams WHERE user_id = ?""", (user_info["user_id"],)) db.execute("""DELETE FROM streams WHERE user_id = ?""", (user_info["user_id"],))
return "Stream ended", 200 return "Stream ended", 200
def transfer_stream_to_vod(user_id: int):
"""
Deletes stream from stream table and moves it to VoD table
TODO: Add functionaliy to save stream permanently
"""
with Database() as db:
stream = db.fetchone("""
SELECT * FROM streams WHERE user_id = ?;
""", (user_id,))
if not stream:
return None
## TODO: calculate length in seconds, currently using temp value
db.execute("""
INSERT INTO vods (user_id, title, datetime, category_id, length, views)
VALUES (?, ?, ?, ?, ?, ?);
""", (stream["user_id"], stream["title"], stream["datetime"], stream["category_id"], 10, stream["num_viewers"]))
db.execute("""
DELETE FROM streams WHERE user_id = ?;
""", (user_id,))
return True

View File

@@ -1,7 +1,8 @@
from flask import Blueprint, jsonify, session from flask import Blueprint, jsonify, session
from utils.user_utils import * from utils.user_utils import *
from blueprints.utils import login_required from utils.auth import *
from blueprints.email import send_email, forgot_password_body from blueprints.middleware import login_required
from utils.email import send_email, forgot_password_body
import redis import redis
redis_url = "redis://redis:6379/1" redis_url = "redis://redis:6379/1"
@@ -10,7 +11,7 @@ r = redis.from_url(redis_url, decode_responses=True)
user_bp = Blueprint("user", __name__) user_bp = Blueprint("user", __name__)
@user_bp.route('/user/<string:username>') @user_bp.route('/user/<string:username>')
def get_user_data(username: str): def user_data(username: str):
""" """
Returns a given user's data Returns a given user's data
""" """
@@ -20,93 +21,7 @@ def get_user_data(username: str):
data = get_user(user_id) data = get_user(user_id)
return jsonify(data) return jsonify(data)
def get_user_id(username: str) -> Optional[int]:
"""
Returns user_id associated with given username
"""
with Database() as db:
data = db.fetchone("""
SELECT user_id
FROM users
WHERE username = ?
""", (username,))
return data['user_id'] if data else None
def get_username(user_id: str) -> Optional[str]:
"""
Returns username associated with given user_id
"""
with Database() as db:
data = db.fetchone("""
SELECT username
FROM user
WHERE user_id = ?
""", (user_id,))
return data['username'] if data else None
def get_email(user_id: int) -> Optional[str]:
with Database() as db:
email = db.fetchone("""
SELECT email
FROM users
WHERE user_id = ?
""", (user_id,))
return email["email"] if email else None
def get_session_info_email(email: str) -> dict:
"""
Returns username and user_id given email
"""
with Database() as db:
session_info = db.fetchone("""
SELECT user_id, username
FROM user
WHERE email = ?
""", (email,))
return session_info
def is_user_partner(user_id: int) -> bool:
"""
Returns True if user is a partner, else False
"""
with Database() as db:
data = db.fetchone("""
SELECT is_partnered
FROM users
WHERE user_id = ?
""", (user_id,))
return bool(data)
def get_user(user_id: int) -> Optional[dict]:
"""
Returns information about a user from user_id
"""
with Database() as db:
data = db.fetchone("""
SELECT user_id, username, bio, num_followers, is_partnered, is_live FROM users
WHERE user_id = ?;
""", (user_id,))
return data
## Subscription Routes ## Subscription Routes
def is_subscribed(user_id: int, subscribed_to_id: int) -> bool:
"""
Returns True if user is subscribed to a streamer, else False
"""
with Database() as db:
result = db.fetchone("""
SELECT *
FROM subscribes
WHERE user_id = ?
AND subscribed_id = ?
AND expires > ?;
""", (user_id, subscribed_to_id, datetime.now()))
print(result)
if result:
return True
return False
@login_required @login_required
@user_bp.route('/user/subscription/<int:subscribed_id>') @user_bp.route('/user/subscription/<int:subscribed_id>')
def user_subscribed(subscribed_id: int): def user_subscribed(subscribed_id: int):
@@ -124,21 +39,10 @@ def user_subscription_expiration(subscribed_id: int):
""" """
Returns remaining time until subscription expiration Returns remaining time until subscription expiration
""" """
with Database() as db:
data = db.fetchone("""
SELECT expires
FROM subscribes
WHERE user_id = ?
AND subscribed_id = ?
AND expires > ?
""", (session.get("user_id"), subscribed_id, datetime.now()))
if data: user_id = session.get("user_id")
expiration_date = data["expires"] remaining_time = subscription_expiration(user_id, subscribed_id)
remaining_time = (parser.parse(expiration_date) - datetime.now()).seconds
else:
remaining_time = 0
return jsonify({"remaining_time": remaining_time}) return jsonify({"remaining_time": remaining_time})
## Follow Routes ## Follow Routes
@@ -163,7 +67,7 @@ def follow_user(target_user_id: int):
return follow(user_id, target_user_id) return follow(user_id, target_user_id)
@login_required @login_required
@user_bp.route('/user/unfollow/<int:target_user_id>') @user_bp.route('/user/unfollow/<string:target_user_id>')
def unfollow_user(target_user_id: int): def unfollow_user(target_user_id: int):
""" """
Unfollows a user Unfollows a user
@@ -173,40 +77,18 @@ def unfollow_user(target_user_id: int):
@login_required @login_required
@user_bp.route('/user/following') @user_bp.route('/user/following')
def get_followed_streamers(): def followed_streamers():
""" """
Queries DB to get a list of followed streamers Queries DB to get a list of followed streamers
""" """
user_id = session.get('user_id') user_id = session.get('user_id')
with Database() as db: live_following_streams = get_followed_streamers(user_id)
followed_streamers = db.fetchall(""" return live_following_streams
SELECT user_id, username
FROM users
WHERE user_id IN (SELECT followed_id FROM follows WHERE user_id = ?);
""", (user_id,))
return followed_streamers
def get_followed_live_streams(user_id: int) -> Optional[List[dict]]:
"""
Searches for streamers who the user followed which are currently live
Returns a list of live streams with the streamer's user id, stream title, and number of viewers
"""
with Database() as db:
live_streams = db.fetchall("""
SELECT users.user_id, streams.title, streams.num_viewers, users.username
FROM streams JOIN users
ON streams.user_id = users.user_id
WHERE users.user_id IN
(SELECT followed_id FROM follows WHERE user_id = ?)
AND users.is_live = 1;
""", (user_id,))
return live_streams
## Login Routes ## Login Routes
@user_bp.route('/user/login_status') @user_bp.route('/user/login_status')
def get_login_status(): def login_status():
""" """
Returns whether the user is logged in or not Returns whether the user is logged in or not
""" """

View File

45
web_server/utils/auth.py Normal file
View File

@@ -0,0 +1,45 @@
from database.database import Database
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
from typing import Optional
from dotenv import load_dotenv
from os import getenv
from werkzeug.security import generate_password_hash
load_dotenv()
serializer = URLSafeTimedSerializer(getenv("AUTH_SECRET_KEY"))
def generate_token(email, salt_value) -> str:
"""
Creates a token for password reset
"""
token = serializer.dumps(email, salt=salt_value)
return token
def verify_token(token: str, salt_value) -> Optional[str]:
"""
Given a token, verifies and decodes it into an email
"""
try:
email = serializer.loads(token, salt=salt_value, max_age=3600)
return email
except SignatureExpired:
# Token expired
print("Token has expired", flush=True)
return None
except BadSignature:
# Invalid token
print("Token is invalid", flush=True)
return None
def reset_password(new_password: str, email: str) -> bool:
"""
Given email and new password reset the password for a given user
"""
with Database() as db:
db.execute("""
UPDATE users
SET password = ?
WHERE email = ?
""", (generate_password_hash(new_password), email))
return True

View File

@@ -1,84 +1,83 @@
import smtplib import smtplib
from email.mime.text import MIMEText from email.mime.text import MIMEText
from os import getenv from os import getenv
from random import randrange from dotenv import load_dotenv
from dotenv import load_dotenv from utils.auth import generate_token
from utils.user_utils import generate_token from secrets import token_hex
from secrets import token_hex import redis
import redis
redis_url = "redis://redis:6379/1"
redis_url = "redis://redis:6379/1" r = redis.from_url(redis_url, decode_responses=True)
r = redis.from_url(redis_url, decode_responses=True)
load_dotenv()
load_dotenv()
def send_email(email, func) -> None:
def send_email(email, func) -> None: """
""" Send a verification email to the user.
Send a verification email to the user. """
"""
# Setup the sender email details
# Setup the sender email details SMTP_SERVER = "smtp.gmail.com"
SMTP_SERVER = "smtp.gmail.com" SMTP_PORT = 587
SMTP_PORT = 587 SMTP_EMAIL = getenv("EMAIL")
SMTP_EMAIL = getenv("EMAIL") SMTP_PASSWORD = getenv("EMAIL_PASSWORD")
SMTP_PASSWORD = getenv("EMAIL_PASSWORD")
# Setup up the receiver details
# Setup up the receiver details body = func()
body = func()
msg = MIMEText(body, "html")
msg = MIMEText(body, "html") msg["Subject"] = "Reset Gander Login"
msg["Subject"] = "Reset Gander Login" msg["From"] = SMTP_EMAIL
msg["From"] = SMTP_EMAIL msg["To"] = email
msg["To"] = email
# Send the email using smtplib
# Send the email using smtplib with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as smtp:
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as smtp: try:
try: smtp.starttls() # TLS handshake to start the connection
smtp.starttls() # TLS handshake to start the connection smtp.login(SMTP_EMAIL, SMTP_PASSWORD)
smtp.login(SMTP_EMAIL, SMTP_PASSWORD) smtp.ehlo()
smtp.ehlo() smtp.send_message(msg)
smtp.send_message(msg)
except TimeoutError:
except TimeoutError: print("Server timed out", flush=True)
print("Server timed out", flush=True)
except Exception as e:
except Exception as e: print("Error: ", e, flush=True)
print("Error: ", e, flush=True)
def forgot_password_body(email):
def forgot_password_body(email): """
""" Handles the creation of the email body for resetting password
Handles the creation of the email body for resetting password """
""" salt = token_hex(32)
salt = token_hex(32)
token = generate_token(email, salt)
token = generate_token(email, salt) url = getenv("VITE_API_URL")
url = getenv("VITE_API_URL") r.setex(token, 3600, salt)
r.setex(token, 3600, salt)
full_url = url + "/reset_password/" + token
full_url = url + "/reset_password/" + token content = f"""
content = f""" <html>
<html> <head>
<head> <meta charset="UTF-8">
<meta charset="UTF-8"> <title>Password Reset</title>
<title>Password Reset</title> <style>
<style> body {{ font-family: Arial, sans-serif; background-color: #f4f4f4; padding: 20px; text-align: center; }}
body {{ font-family: Arial, sans-serif; background-color: #f4f4f4; padding: 20px; text-align: center; }} .container {{ max-width: 400px; background: white; padding: 20px; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1); }}
.container {{ max-width: 400px; background: white; padding: 20px; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1); }} .btn {{ display: inline-block; padding: 10px 20px; color: white; background-color: #007bff; text-decoration: none; border-radius: 5px; }}
.btn {{ display: inline-block; padding: 10px 20px; color: white; background-color: #007bff; text-decoration: none; border-radius: 5px; }} .btn:hover {{ background-color: #0056b3; }}
.btn:hover {{ background-color: #0056b3; }} p {{ color: #333; }}
p {{ color: #333; }} </style>
</style> </head>
</head> <body>
<body> <div class="container">
<div class="container"> <h1>Gander</h1>
<h1>Gander</h1> <h2>Password Reset Request</h2>
<h2>Password Reset Request</h2> <p>Click the button below to reset your password. This link is valid for 1 hour.</p>
<p>Click the button below to reset your password. This link is valid for 1 hour.</p> <a href="{full_url}" class="btn">Reset Password</a>
<a href="{full_url}" class="btn">Reset Password</a> </div>
</div> </body>
</body> </html>
</html> """
"""
return content return content

View File

@@ -0,0 +1,97 @@
from database.database import Database
from typing import Optional, List
def get_user_preferred_category(user_id: int) -> Optional[int]:
"""
Queries user_preferences database to find users favourite streaming category and returns the category
"""
with Database() as db:
category = db.fetchone("""
SELECT category_id
FROM user_preferences
WHERE user_id = ?
ORDER BY favourability DESC
LIMIT 1
""", (user_id,))
return category["category_id"] if category else None
def get_followed_categories_recommendations(user_id: int, no_streams: int = 4) -> Optional[List[dict]]:
"""
Returns top streams given a user's category following
"""
with Database() as db:
streams = db.fetchall("""
SELECT u.user_id, title, u.username, num_viewers, category_name
FROM streams
JOIN users u ON streams.user_id = u.user_id
JOIN categories ON streams.category_id = categories.category_id
WHERE categories.category_id IN (SELECT category_id FROM followed_categories WHERE user_id = ?)
ORDER BY num_viewers DESC
LIMIT ?;
""", (user_id, no_streams))
return streams
def get_streams_based_on_category(category_id: int, no_streams: int = 4) -> Optional[List[dict]]:
"""
Queries stream database to get top most viewed streams based on given category
"""
with Database() as db:
streams = db.fetchall("""
SELECT u.user_id, title, username, num_viewers, c.category_name
FROM streams s
JOIN users u ON s.user_id = u.user_id
JOIN categories c ON s.category_id = c.category_id
WHERE c.category_id = ?
ORDER BY num_viewers DESC
LIMIT ?
""", (category_id, no_streams))
return streams
def get_highest_view_streams(no_streams: int = 4) -> Optional[List[dict]]:
"""
Return a list of live streams by number of viewers
"""
with Database() as db:
data = db.fetchall("""
SELECT u.user_id, username, title, num_viewers, category_name
FROM streams
JOIN users u ON streams.user_id = u.user_id
JOIN categories ON streams.category_id = categories.category_id
ORDER BY num_viewers DESC
LIMIT ?;
""", (no_streams,))
return data
def get_highest_view_categories(no_categories: int = 4) -> Optional[List[dict]]:
"""
Returns a list of top most popular categories
"""
with Database() as db:
categories = db.fetchall("""
SELECT categories.category_id, categories.category_name, SUM(streams.num_viewers) AS num_viewers
FROM streams
JOIN categories ON streams.category_id = categories.category_id
GROUP BY categories.category_name
ORDER BY SUM(streams.num_viewers) DESC
LIMIT ?;
""", (no_categories,))
return categories
def get_user_category_recommendations(user_id: int) -> Optional[List[dict]]:
"""
Queries user_preferences database to find users top 5 favourite streaming category and returns the category
"""
with Database() as db:
categories = db.fetchall("""
SELECT categories.category_id, categories.category_name
FROM categories
JOIN user_preferences ON categories.category_id = user_preferences.category_id
WHERE user_id = ?
ORDER BY favourability DESC
LIMIT 5
""", (user_id,))
return categories

View File

@@ -1,5 +1,7 @@
from database.database import Database from database.database import Database
from typing import Optional
import os, subprocess import os, subprocess
from typing import Optional, List
def get_streamer_live_status(user_id: int): def get_streamer_live_status(user_id: int):
""" """
@@ -14,6 +16,72 @@ def get_streamer_live_status(user_id: int):
return is_live return is_live
def get_followed_live_streams(user_id: int) -> Optional[List[dict]]:
"""
Searches for streamers who the user followed which are currently live
Returns a list of live streams with the streamer's user id, stream title, and number of viewers
"""
with Database() as db:
live_streams = db.fetchall("""
SELECT users.user_id, streams.title, streams.num_viewers, users.username
FROM streams JOIN users
ON streams.user_id = users.user_id
WHERE users.user_id IN
(SELECT followed_id FROM follows WHERE user_id = ?)
AND users.is_live = 1;
""", (user_id,))
return live_streams
def get_current_stream_data(user_id: int) -> Optional[dict]:
"""
Returns data of the most recent stream by a streamer
"""
with Database() as db:
most_recent_stream = db.fetchone("""
SELECT s.user_id, u.username, s.title, s.start_time, s.num_viewers, c.category_name
FROM streams AS s
JOIN categories AS c ON s.category_id = c.category_id
JOIN users AS u ON s.user_id = u.user_id
WHERE u.user_id = ?
""", (user_id,))
return most_recent_stream
def get_category_id(category_name: str) -> Optional[int]:
"""
Returns the category_id given a category name
"""
with Database() as db:
data = db.fetchone("""
SELECT category_id
FROM categories
WHERE category_name = ?;
""", (category_name,))
return data['category_id'] if data else None
def get_vod(vod_id: int) -> dict:
"""
Returns data of a streamers vod
"""
with Database() as db:
vod = db.fetchone("""SELECT * FROM vods WHERE vod_id = ?;""", (vod_id,))
return vod
def get_latest_vod(user_id: int):
"""
Returns data of the most recent stream by a streamer
"""
with Database() as db:
latest_vod = db.fetchone("""SELECT * FROM vods WHERE user_id = ? ORDER BY vod_id DESC LIMIT 1;""", (user_id,))
return latest_vod
def get_user_vods(user_id: int):
"""
Returns data of all vods by a streamer
"""
with Database() as db:
vods = db.fetchall("""SELECT * FROM vods WHERE user_id = ?;""", (user_id,))
return vods
def generate_thumbnail(user_id: int) -> None: def generate_thumbnail(user_id: int) -> None:
""" """
@@ -42,3 +110,55 @@ def generate_thumbnail(user_id: int) -> None:
subprocess.run(thumbnail_command) subprocess.run(thumbnail_command)
def get_stream_tags(user_id: int) -> Optional[List[str]]:
"""
Given a stream return tags associated with the user's stream
"""
with Database() as db:
tags = db.fetchall("""
SELECT tag_name
FROM tags
JOIN stream_tags ON tags.tag_id = stream_tags.tag_id
WHERE user_id = ?;
""", (user_id,))
return tags
def get_vod_tags(vod_id: int):
"""
Given a vod return tags associated with the vod
"""
with Database() as db:
tags = db.fetchall("""
SELECT tag_name
FROM tags
JOIN vod_tags ON tags.tag_id = vod_tags.tag_id
WHERE vod_id = ?;
""", (vod_id,))
return tags
def transfer_stream_to_vod(user_id: int):
"""
Deletes stream from stream table and moves it to VoD table
TODO: Add functionaliy to save stream permanently
"""
with Database() as db:
stream = db.fetchone("""
SELECT * FROM streams WHERE user_id = ?;
""", (user_id,))
if not stream:
return None
## TODO: calculate length in seconds, currently using temp value
db.execute("""
INSERT INTO vods (user_id, title, datetime, category_id, length, views)
VALUES (?, ?, ?, ?, ?, ?);
""", (stream["user_id"], stream["title"], stream["datetime"], stream["category_id"], 10, stream["num_viewers"]))
db.execute("""
DELETE FROM streams WHERE user_id = ?;
""", (user_id,))
return True

View File

@@ -1,14 +1,72 @@
from database.database import Database from database.database import Database
from typing import Optional, List from typing import Optional, List
from datetime import datetime from datetime import datetime
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
from os import getenv
from werkzeug.security import generate_password_hash
from dateutil import parser from dateutil import parser
from dotenv import load_dotenv
load_dotenv()
serializer = URLSafeTimedSerializer(getenv("AUTH_SECRET_KEY")) def get_user_id(username: str) -> Optional[int]:
"""
Returns user_id associated with given username
"""
with Database() as db:
data = db.fetchone("""
SELECT user_id
FROM users
WHERE username = ?
""", (username,))
return data['user_id'] if data else None
def get_username(user_id: str) -> Optional[str]:
"""
Returns username associated with given user_id
"""
with Database() as db:
data = db.fetchone("""
SELECT username
FROM user
WHERE user_id = ?
""", (user_id,))
return data['username'] if data else None
def get_session_info_email(email: str) -> dict:
"""
Returns username and user_id given email
"""
with Database as db:
session_info = db.fetchone("""
SELECT user_id, username
FROM user
WHERE email = ?
""", (email,))
return session_info
def is_user_partner(user_id: int) -> bool:
"""
Returns True if user is a partner, else False
"""
with Database() as db:
data = db.fetchone("""
SELECT is_partnered
FROM users
WHERE user_id = ?
""", (user_id,))
return bool(data)
def is_subscribed(user_id: int, subscribed_to_id: int) -> bool:
"""
Returns True if user is subscribed to a streamer, else False
"""
with Database() as db:
result = db.fetchone("""
SELECT *
FROM subscribes
WHERE user_id = ?
AND subscribed_id = ?
AND expires > ?;
""", (user_id, subscribed_to_id, datetime.now()))
print(result)
if result:
return True
return False
def is_following(user_id: int, followed_id: int) -> bool: def is_following(user_id: int, followed_id: int) -> bool:
""" """
@@ -51,40 +109,56 @@ def unfollow(user_id: int, followed_id: int):
""", (user_id, followed_id)) """, (user_id, followed_id))
return {"success": True} return {"success": True}
def generate_token(email, salt_value) -> str:
"""
Creates a token for password reset
"""
token = serializer.dumps(email, salt=salt_value)
return token
def verify_token(token: str, salt_value) -> Optional[str]: def subscription_expiration(user_id: int, subscribed_id: int) -> int:
""" """
Given a token, verifies and decodes it into an email Returns the amount of time left until user subscription to a streamer ends
"""
try:
email = serializer.loads(token, salt=salt_value, max_age=3600)
return email
except SignatureExpired:
# Token expired
print("Token has expired", flush=True)
return None
except BadSignature:
# Invalid token
print("Token is invalid", flush=True)
return None
def reset_password(new_password: str, email: str) -> bool:
"""
Given email and new password reset the password for a given user
""" """
with Database() as db: with Database() as db:
db.execute(""" data = db.fetchone("""
UPDATE users SELECT expires
SET password = ? FROM subscribes
WHERE email = ? WHERE user_id = ?
""", (generate_password_hash(new_password), email)) AND subscribed_id = ?
AND expires > ?
return True """, (user_id, subscribed_id, datetime.now()))
if data:
expiration_date = data["expires"]
remaining_time = (parser.parse(expiration_date) - datetime.now()).seconds
return remaining_time
return 0
def get_email(user_id: int) -> Optional[str]:
with Database() as db:
email = db.fetchone("""
SELECT email
FROM users
WHERE user_id = ?
""", (user_id,))
return email["email"] if email else None
def get_followed_streamers(user_id: int) -> Optional[List[dict]]:
"""
Returns a list of streamers who the user follows
"""
with Database() as db:
followed_streamers = db.fetchall("""
SELECT user_id, username
FROM users
WHERE user_id IN (SELECT followed_id FROM follows WHERE user_id = ?);
""", (user_id,))
return followed_streamers
def get_user(user_id: int) -> Optional[dict]:
"""
Returns information about a user from user_id
"""
with Database() as db:
data = db.fetchone("""
SELECT user_id, username, bio, num_followers, is_partnered, is_live FROM users
WHERE user_id = ?;
""", (user_id,))
return data

View File

@@ -1,106 +1,75 @@
from flask import redirect, url_for, request, g, session from database.database import Database
from functools import wraps from typing import Optional, List
from re import match from re import match
from database.database import Database
from typing import Optional, List def get_all_categories() -> Optional[List[dict]]:
"""
def logged_in_user(): Returns all possible streaming categories
""" """
Validator to make sure a user is logged in. with Database() as db:
""" all_categories = db.fetchall("SELECT * FROM categories")
g.user = session.get("username", None)
g.admin = session.get("username", None) return all_categories
def login_required(view): def get_all_tags() -> Optional[List[dict]]:
""" """
Add at start of routes where users need to be logged in to access. Returns all possible streaming tags
""" """
@wraps(view) with Database() as db:
def wrapped_view(*args, **kwargs): all_tags = db.fetchall("SELECT * FROM tags")
if g.user is None:
return redirect(url_for("login", next=request.url)) return all_tags
return view(*args, **kwargs)
return wrapped_view def get_most_popular_category() -> Optional[List[dict]]:
"""
def admin_required(view): Returns the most popular category based on live stream viewers
""" """
Add at start of routes where admins need to be logged in to access. with Database() as db:
""" category = db.fetchone("""
@wraps(view) SELECT categories.category_id, categories.category_name
def wrapped_view(*args, **kwargs): FROM streams
if g.admin != "admin": JOIN categories ON streams.category_id = categories.category_id
return redirect(url_for("login", next=request.url)) WHERE streams.isLive = 1
return view(*args, **kwargs) GROUP BY categories.category_name
return wrapped_view ORDER BY SUM(streams.num_viewers) DESC
LIMIT 1;
def sanitizer(user_input: str, input_type="username") -> str: """)
"""
Sanitizes user input based on the specified input type. return category
`input_type`: The type of input to sanitize (e.g., 'username', 'email', 'password'). def sanitize(user_input: str, input_type="username") -> str:
""" """
# Strip leading and trailing whitespace Sanitizes user input based on the specified input type.
sanitised_input = user_input.strip()
`input_type`: The type of input to sanitize (e.g., 'username', 'email', 'password').
# Define allowed patterns and length constraints for each type """
rules = { # Strip leading and trailing whitespace
"username": { sanitised_input = user_input.strip()
"pattern": r"^[a-zA-Z0-9_]+$", # Alphanumeric + underscores
"min_length": 3, # Define allowed patterns and length constraints for each type
"max_length": 50, rules = {
}, "username": {
"email": { "pattern": r"^[a-zA-Z0-9_]+$", # Alphanumeric + underscores
"pattern": r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$", # Standard email regex "min_length": 3,
"min_length": 5, "max_length": 50,
"max_length": 128, },
}, "email": {
"password": { "pattern": r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$", # Standard email regex
"pattern": r"^[\S]+$", # Non-whitespace characters only "min_length": 5,
"min_length": 8, "max_length": 128,
"max_length": 256, },
}, "password": {
} "pattern": r"^[\S]+$", # Non-whitespace characters only
"min_length": 8,
# Get the validation rules for the specified type "max_length": 256,
r = rules.get(input_type) },
if not r or \ }
not (r["min_length"] <= len(sanitised_input) <= r["max_length"]) or \
not match(r["pattern"], sanitised_input): # Get the validation rules for the specified type
raise ValueError("Unaccepted character or length in input") r = rules.get(input_type)
if not r or \
return sanitised_input not (r["min_length"] <= len(sanitised_input) <= r["max_length"]) or \
not match(r["pattern"], sanitised_input):
def categories() -> Optional[List[dict]]: raise ValueError("Unaccepted character or length in input")
"""
Returns all possible streaming categories return sanitised_input
"""
with Database() as db:
all_categories = db.fetchall("SELECT * FROM categories")
return all_categories
def tags() -> Optional[List[dict]]:
"""
Returns all possible streaming tags
"""
with Database() as db:
all_tags = db.fetchall("SELECT * FROM tags")
return all_tags
def most_popular_category() -> Optional[List[dict]]:
"""
Returns the most popular category based on live stream viewers
"""
with Database() as db:
category = db.fetchone("""
SELECT categories.category_id, categories.category_name
FROM streams
JOIN categories ON streams.category_id = categories.category_id
WHERE streams.isLive = 1
GROUP BY categories.category_name
ORDER BY SUM(streams.num_viewers) DESC
LIMIT 1;
""")
return category