MAJOR: Restructured backend Flask application moved all non-routes into utils, renamed routes to not prefix get, created middleware.py to replace utils.py within blueprints

This commit is contained in:
JustIceO7
2025-02-08 14:35:46 +00:00
parent ae623eee0d
commit e6b8ad9b9e
16 changed files with 631 additions and 550 deletions

View File

@@ -1,8 +1,7 @@
from flask import Flask from flask import Flask
from flask_session import Session from flask_session import Session
from flask_cors import CORS from flask_cors import CORS
from blueprints.utils import logged_in_user from blueprints.middleware import logged_in_user, register_error_handlers
from blueprints.errorhandlers import register_error_handlers
# from flask_wtf.csrf import CSRFProtect, generate_csrf # from flask_wtf.csrf import CSRFProtect, generate_csrf
from blueprints.authentication import auth_bp from blueprints.authentication import auth_bp

View File

@@ -1,5 +1,5 @@
from flask import Blueprint from flask import Blueprint
from blueprints.utils import admin_required from blueprints.middleware import admin_required
admin_bp = Blueprint("admin", __name__) admin_bp = Blueprint("admin", __name__)

View File

@@ -2,9 +2,10 @@ from flask import Blueprint, session, request, jsonify
from werkzeug.security import generate_password_hash, check_password_hash from werkzeug.security import generate_password_hash, check_password_hash
from flask_cors import cross_origin from flask_cors import cross_origin
from database.database import Database from database.database import Database
from blueprints.utils import login_required, sanitizer from blueprints.middleware import login_required
from blueprints.email import send_email from utils.email import send_email
from blueprints.user import get_user_id from utils.user_utils import get_user_id
from utils.utils import sanitize
from secrets import token_hex from secrets import token_hex
auth_bp = Blueprint("auth", __name__) auth_bp = Blueprint("auth", __name__)
@@ -37,9 +38,9 @@ def signup():
# Sanitize the inputs - helps to prevent SQL injection # Sanitize the inputs - helps to prevent SQL injection
try: try:
username = sanitizer(username, "username") username = sanitize(username, "username")
email = sanitizer(email, "email") email = sanitize(email, "email")
password = sanitizer(password, "password") password = sanitize(password, "password")
except ValueError as e: except ValueError as e:
error_fields = get_error_fields([username, email, password]) error_fields = get_error_fields([username, email, password])
return jsonify({ return jsonify({
@@ -144,8 +145,8 @@ def login():
# Sanitize the inputs - helps to prevent SQL injection # Sanitize the inputs - helps to prevent SQL injection
try: try:
username = sanitizer(username, "username") username = sanitize(username, "username")
password = sanitizer(password, "password") password = sanitize(password, "password")
except ValueError as e: except ValueError as e:
return jsonify({ return jsonify({
"account_created": False, "account_created": False,

View File

@@ -3,7 +3,7 @@ from database.database import Database
from .socket import socketio from .socket import socketio
from flask_socketio import emit, join_room, leave_room from flask_socketio import emit, join_room, leave_room
from datetime import datetime from datetime import datetime
from blueprints.user import get_user_id from utils.user_utils import get_user_id
chat_bp = Blueprint("chat", __name__) chat_bp = Blueprint("chat", __name__)

View File

@@ -1,14 +0,0 @@
import logging
def register_error_handlers(app):
error_responses = {
400: "Bad Request",
403: "Forbidden",
404: "Not Found",
500: "Internal Server Error"
}
for code, message in error_responses.items():
@app.errorhandler(code)
def handle_error(error, message=message, code=code):
logging.error(f"Error {code}: {str(error)}")
return {"error": message}, code

View File

@@ -0,0 +1,46 @@
from flask import redirect, url_for, request, g, session
from functools import wraps
import logging
def logged_in_user():
"""
Validator to make sure a user is logged in.
"""
g.user = session.get("username", None)
g.admin = session.get("username", None)
def login_required(view):
"""
Add at start of routes where users need to be logged in to access.
"""
@wraps(view)
def wrapped_view(*args, **kwargs):
if g.user is None:
return redirect(url_for("login", next=request.url))
return view(*args, **kwargs)
return wrapped_view
def admin_required(view):
"""
Add at start of routes where admins need to be logged in to access.
"""
@wraps(view)
def wrapped_view(*args, **kwargs):
if g.admin != "admin":
return redirect(url_for("login", next=request.url))
return view(*args, **kwargs)
return wrapped_view
def register_error_handlers(app):
error_responses = {
400: "Bad Request",
403: "Forbidden",
404: "Not Found",
500: "Internal Server Error"
}
for code, message in error_responses.items():
@app.errorhandler(code)
def handle_error(error, message=message, code=code):
logging.error(f"Error {code}: {str(error)}")
return {"error": message}, code

View File

@@ -3,7 +3,7 @@ from database.database import Database
search_bp = Blueprint("search", __name__) search_bp = Blueprint("search", __name__)
@search_bp.route("/search/<str:query>", methods=["GET", "POST"]) @search_bp.route("/search/<string:query>", methods=["GET", "POST"])
def search_results(query: str): def search_results(query: str):
""" """
Return the most similar search results Return the most similar search results
@@ -46,7 +46,7 @@ def search_results(query: str):
return jsonify({"categories": categories, "users": users, "streams": streams}) return jsonify({"categories": categories, "users": users, "streams": streams})
@search_bp.route("/search/categories/<str:query>", methods=["GET", "POST"]) @search_bp.route("/search/categories/<string:query>", methods=["GET", "POST"])
def search_categories(query: str): def search_categories(query: str):
# Create the connection to the database # Create the connection to the database
db = Database() db = Database()
@@ -63,7 +63,7 @@ def search_categories(query: str):
return jsonify({"categories": categories}) return jsonify({"categories": categories})
@search_bp.route("/search/users/<str:query>", methods=["GET", "POST"]) @search_bp.route("/search/users/<string:query>", methods=["GET", "POST"])
def search_users(query: str): def search_users(query: str):
# Create the connection to the database # Create the connection to the database
db = Database() db = Database()
@@ -81,7 +81,7 @@ def search_users(query: str):
return jsonify({"users": users}) return jsonify({"users": users})
@search_bp.route("/search/streams/<str:query>", methods=["GET", "POST"]) @search_bp.route("/search/streams/<string:query>", methods=["GET", "POST"])
def search_streams(query: str): def search_streams(query: str):
# Create the connection to the database # Create the connection to the database
db = Database() db = Database()

View File

@@ -1,11 +1,11 @@
from flask import Blueprint, session, jsonify, request, redirect from flask import Blueprint, session, jsonify, request, redirect
from utils.stream_utils import * from utils.stream_utils import *
from blueprints.user import get_user_id from utils.recommendation_utils import *
from blueprints.utils import login_required from utils.user_utils import get_user_id
from blueprints.middleware import login_required
from database.database import Database from database.database import Database
from datetime import datetime from datetime import datetime
from celery_tasks import update_thumbnail from celery_tasks import update_thumbnail
from typing import List, Optional
stream_bp = Blueprint("stream", __name__) stream_bp = Blueprint("stream", __name__)
@@ -15,7 +15,7 @@ THUMBNAIL_GENERATION_INTERVAL = 180
## Stream Routes ## Stream Routes
@stream_bp.route('/streams/popular/<int:no_streams>') @stream_bp.route('/streams/popular/<int:no_streams>')
def get_popular_streams(no_streams) -> list[dict]: def popular_streams(no_streams) -> list[dict]:
""" """
Returns a list of streams live now with the highest viewers Returns a list of streams live now with the highest viewers
""" """
@@ -28,48 +28,23 @@ def get_popular_streams(no_streams) -> list[dict]:
no_streams = MAX_STREAMS no_streams = MAX_STREAMS
# Get the highest viewed streams # Get the highest viewed streams
with Database() as db: streams = get_highest_view_streams(no_streams)
streams = db.fetchall("""
SELECT u.user_id, username, title, num_viewers, category_name
FROM streams
JOIN users u ON streams.user_id = u.user_id
JOIN categories ON streams.category_id = categories.category_id
ORDER BY num_viewers DESC
LIMIT ?;
""", (no_streams,))
return jsonify(streams) return jsonify(streams)
@stream_bp.route('/streams/popular/<string:category_name>') @stream_bp.route('/streams/popular/<string:category_name>')
def get_popular_streams_by_category(category_name) -> list[dict]: def popular_streams_by_category(category_name) -> list[dict]:
""" """
Returns a list of streams live now with the highest viewers in a given category Returns a list of streams live now with the highest viewers in a given category
""" """
with Database() as db: category_id = get_category_id(category_name)
data = db.fetchone("""
SELECT category_id
FROM categories
WHERE category_name = ?;
""", (category_name,))
category_id = data["category_id"] if data else None
streams = db.fetchall("""
SELECT u.user_id, title, username, num_viewers, c.category_name
FROM streams s
JOIN users u ON s.user_id = u.user_id
JOIN categories c ON s.category_id = c.category_id
WHERE c.category_id = ?
ORDER BY num_viewers DESC
LIMIT 25
""", (category_id,))
streams = get_streams_based_on_category(category_id)
return jsonify(streams) return jsonify(streams)
@login_required @login_required
@stream_bp.route('/streams/recommended') @stream_bp.route('/streams/recommended')
def get_recommended_streams() -> list[dict]: def recommended_streams() -> list[dict]:
""" """
Queries DB to get a list of recommended streams using an algorithm Queries DB to get a list of recommended streams using an algorithm
""" """
@@ -77,61 +52,22 @@ def get_recommended_streams() -> list[dict]:
user_id = session.get("user_id") user_id = session.get("user_id")
# Get the user's most popular categories # Get the user's most popular categories
with Database() as db: category = get_user_preferred_category(user_id)
category = db.fetchone(""" streams = get_streams_based_on_category(category)
SELECT category_id
FROM user_preferences
WHERE user_id = ?
ORDER BY favourability DESC
LIMIT 1
""", (user_id,))
category_id = category["category_id"] if category else None
streams = db.fetchall("""
SELECT u.user_id, title, username, num_viewers, c.category_name
FROM streams s
JOIN users u ON s.user_id = u.user_id
JOIN categories c ON s.category_id = c.category_id
WHERE c.category_id = ?
ORDER BY num_viewers DESC
LIMIT 25
""", (category_id,))
return streams return streams
@stream_bp.route('/streams/<int:streamer_id>/data') @stream_bp.route('/streams/<int:streamer_id>/data')
def get_stream_data(streamer_id): def stream_data(streamer_id):
""" """
Returns a streamer's current stream data Returns a streamer's current stream data
""" """
with Database() as db:
most_recent_stream = db.fetchone("""
SELECT s.user_id, u.username, s.title, s.start_time, s.num_viewers, c.category_name
FROM streams AS s
JOIN categories AS c ON s.category_id = c.category_id
JOIN users AS u ON s.user_id = u.user_id
WHERE u.user_id = ?
""", (streamer_id,))
return jsonify(most_recent_stream) return jsonify(get_current_stream_data(streamer_id))
def get_stream_tags(user_id: int) -> Optional[List[str]]:
"""
Given a stream return tags associated with the user's stream
"""
with Database() as db:
tags = db.fetchall("""
SELECT tag_name
FROM tags
JOIN stream_tags ON tags.tag_id = stream_tags.tag_id
WHERE user_id = ?;
""", (user_id,))
return tags
## Category Routes ## Category Routes
@stream_bp.route('/categories/popular/<int:no_categories>') @stream_bp.route('/categories/popular/<int:no_categories>')
def get_popular_categories(no_categories) -> list[dict]: def popular_categories(no_categories) -> list[dict]:
""" """
Returns a list of most popular categories Returns a list of most popular categories
""" """
@@ -141,62 +77,34 @@ def get_popular_categories(no_categories) -> list[dict]:
elif no_categories > 100: elif no_categories > 100:
no_categories = 100 no_categories = 100
with Database() as db: category_data = get_highest_view_categories(no_categories)
category_data = db.fetchall("""
SELECT categories.category_id, categories.category_name, SUM(streams.num_viewers) AS num_viewers
FROM streams
JOIN categories ON streams.category_id = categories.category_id
GROUP BY categories.category_name
ORDER BY SUM(streams.num_viewers) DESC
LIMIT ?;
""", (no_categories,))
return jsonify(category_data) return jsonify(category_data)
@login_required @login_required
@stream_bp.route('/categories/recommended') @stream_bp.route('/categories/recommended')
def get_recommended_categories() -> list | list[dict]: def recommended_categories() -> list | list[dict]:
""" """
Queries DB to get a list of recommended categories for the user Queries DB to get a list of recommended categories for the user
""" """
user_id = session.get("user_id") user_id = session.get("user_id")
categories = get_user_category_recommendations(user_id)
with Database() as db:
categories = db.fetchall("""
SELECT categories.category_id, categories.category_name
FROM categories
JOIN user_preferences ON categories.category_id = user_preferences.category_id
WHERE user_id = ?
ORDER BY favourability DESC
LIMIT 5
""", (user_id,))
return jsonify(categories) return jsonify(categories)
@login_required @login_required
@stream_bp.route('/categories/following') @stream_bp.route('/categories/following')
def get_following_categories_streams(): def following_categories_streams():
""" """
Returns popular streams in categories which the user followed Returns popular streams in categories which the user followed
""" """
with Database() as db:
streams = db.fetchall("""
SELECT u.user_id, title, u.username, num_viewers, category_name
FROM streams
JOIN users u ON streams.user_id = u.user_id
JOIN categories ON streams.category_id = categories.category_id
WHERE categories.category_id IN (SELECT category_id FROM followed_categories WHERE user_id = ?)
ORDER BY num_viewers DESC
LIMIT 25;
""", (session.get("user_id"),))
streams = get_followed_categories_recommendations(session.get('user_id'))
return jsonify(streams) return jsonify(streams)
## User Routes ## User Routes
@stream_bp.route('/user/<string:username>/status') @stream_bp.route('/user/<string:username>/status')
def get_user_live_status(username): def user_live_status(username):
""" """
Returns a streamer's status, if they are live or not and their most recent stream (as a vod) (their current stream if live) Returns a streamer's status, if they are live or not and their most recent stream (as a vod) (their current stream if live)
""" """
@@ -204,9 +112,7 @@ def get_user_live_status(username):
# Check if streamer is live and get their most recent vod # Check if streamer is live and get their most recent vod
is_live = True if get_streamer_live_status(user_id)['is_live'] else False is_live = True if get_streamer_live_status(user_id)['is_live'] else False
most_recent_vod = get_latest_vod(user_id)
with Database() as db:
most_recent_vod = db.fetchone("""SELECT * FROM vods WHERE user_id = ? ORDER BY vod_id DESC LIMIT 1;""", (user_id,))
# If there is no most recent vod, set it to None # If there is no most recent vod, set it to None
if not most_recent_vod: if not most_recent_vod:
@@ -222,30 +128,14 @@ def get_user_live_status(username):
## VOD Routes ## VOD Routes
@stream_bp.route('/vods/<string:username>') @stream_bp.route('/vods/<string:username>')
def get_vods(username): def vods(username):
""" """
Returns a JSON of all the vods of a streamer Returns a JSON of all the vods of a streamer
""" """
user_id = get_user_id(username) user_id = get_user_id(username)
vods = get_user_vods(user_id)
with Database() as db:
vods = db.fetchall("""SELECT * FROM vods WHERE user_id = ?;""", (user_id,))
return jsonify(vods) return jsonify(vods)
def get_vod_tags(vod_id: int):
"""
Given a vod return tags associated with the vod
"""
with Database() as db:
tags = db.fetchall("""
SELECT tag_name
FROM tags
JOIN vod_tags ON tags.tag_id = vod_tags.tag_id
WHERE vod_id = ?;
""", (vod_id,))
return tags
## RTMP Server Routes ## RTMP Server Routes
@stream_bp.route("/publish_stream", methods=["POST"]) @stream_bp.route("/publish_stream", methods=["POST"])
@@ -294,30 +184,3 @@ def end_stream():
db.execute("""DELETE FROM streams WHERE user_id = ?""", (user_info["user_id"],)) db.execute("""DELETE FROM streams WHERE user_id = ?""", (user_info["user_id"],))
return "Stream ended", 200 return "Stream ended", 200
def transfer_stream_to_vod(user_id: int):
"""
Deletes stream from stream table and moves it to VoD table
TODO: Add functionaliy to save stream permanently
"""
with Database() as db:
stream = db.fetchone("""
SELECT * FROM streams WHERE user_id = ?;
""", (user_id,))
if not stream:
return None
## TODO: calculate length in seconds, currently using temp value
db.execute("""
INSERT INTO vods (user_id, title, datetime, category_id, length, views)
VALUES (?, ?, ?, ?, ?, ?);
""", (stream["user_id"], stream["title"], stream["datetime"], stream["category_id"], 10, stream["num_viewers"]))
db.execute("""
DELETE FROM streams WHERE user_id = ?;
""", (user_id,))
return True

View File

@@ -1,7 +1,8 @@
from flask import Blueprint, jsonify, session from flask import Blueprint, jsonify, session
from utils.user_utils import * from utils.user_utils import *
from blueprints.utils import login_required from utils.auth import *
from blueprints.email import send_email, forgot_password_body from blueprints.middleware import login_required
from utils.email import send_email, forgot_password_body
import redis import redis
redis_url = "redis://redis:6379/1" redis_url = "redis://redis:6379/1"
@@ -10,7 +11,7 @@ r = redis.from_url(redis_url, decode_responses=True)
user_bp = Blueprint("user", __name__) user_bp = Blueprint("user", __name__)
@user_bp.route('/user/<string:username>') @user_bp.route('/user/<string:username>')
def get_user_data(username: str): def user_data(username: str):
""" """
Returns a given user's data Returns a given user's data
""" """
@@ -20,93 +21,7 @@ def get_user_data(username: str):
data = get_user(user_id) data = get_user(user_id)
return jsonify(data) return jsonify(data)
def get_user_id(username: str) -> Optional[int]:
"""
Returns user_id associated with given username
"""
with Database() as db:
data = db.fetchone("""
SELECT user_id
FROM users
WHERE username = ?
""", (username,))
return data['user_id'] if data else None
def get_username(user_id: str) -> Optional[str]:
"""
Returns username associated with given user_id
"""
with Database() as db:
data = db.fetchone("""
SELECT username
FROM user
WHERE user_id = ?
""", (user_id,))
return data['username'] if data else None
def get_email(user_id: int) -> Optional[str]:
with Database() as db:
email = db.fetchone("""
SELECT email
FROM users
WHERE user_id = ?
""", (user_id,))
return email["email"] if email else None
def get_session_info_email(email: str) -> dict:
"""
Returns username and user_id given email
"""
with Database() as db:
session_info = db.fetchone("""
SELECT user_id, username
FROM user
WHERE email = ?
""", (email,))
return session_info
def is_user_partner(user_id: int) -> bool:
"""
Returns True if user is a partner, else False
"""
with Database() as db:
data = db.fetchone("""
SELECT is_partnered
FROM users
WHERE user_id = ?
""", (user_id,))
return bool(data)
def get_user(user_id: int) -> Optional[dict]:
"""
Returns information about a user from user_id
"""
with Database() as db:
data = db.fetchone("""
SELECT user_id, username, bio, num_followers, is_partnered, is_live FROM users
WHERE user_id = ?;
""", (user_id,))
return data
## Subscription Routes ## Subscription Routes
def is_subscribed(user_id: int, subscribed_to_id: int) -> bool:
"""
Returns True if user is subscribed to a streamer, else False
"""
with Database() as db:
result = db.fetchone("""
SELECT *
FROM subscribes
WHERE user_id = ?
AND subscribed_id = ?
AND expires > ?;
""", (user_id, subscribed_to_id, datetime.now()))
print(result)
if result:
return True
return False
@login_required @login_required
@user_bp.route('/user/subscription/<int:subscribed_id>') @user_bp.route('/user/subscription/<int:subscribed_id>')
def user_subscribed(subscribed_id: int): def user_subscribed(subscribed_id: int):
@@ -124,20 +39,9 @@ def user_subscription_expiration(subscribed_id: int):
""" """
Returns remaining time until subscription expiration Returns remaining time until subscription expiration
""" """
with Database() as db:
data = db.fetchone("""
SELECT expires
FROM subscribes
WHERE user_id = ?
AND subscribed_id = ?
AND expires > ?
""", (session.get("user_id"), subscribed_id, datetime.now()))
if data: user_id = session.get("user_id")
expiration_date = data["expires"] remaining_time = subscription_expiration(user_id, subscribed_id)
remaining_time = (parser.parse(expiration_date) - datetime.now()).seconds
else:
remaining_time = 0
return jsonify({"remaining_time": remaining_time}) return jsonify({"remaining_time": remaining_time})
@@ -163,7 +67,7 @@ def follow_user(target_user_id: int):
return follow(user_id, target_user_id) return follow(user_id, target_user_id)
@login_required @login_required
@user_bp.route('/user/unfollow/<int:target_user_id>') @user_bp.route('/user/unfollow/<string:target_user_id>')
def unfollow_user(target_user_id: int): def unfollow_user(target_user_id: int):
""" """
Unfollows a user Unfollows a user
@@ -173,40 +77,18 @@ def unfollow_user(target_user_id: int):
@login_required @login_required
@user_bp.route('/user/following') @user_bp.route('/user/following')
def get_followed_streamers(): def followed_streamers():
""" """
Queries DB to get a list of followed streamers Queries DB to get a list of followed streamers
""" """
user_id = session.get('user_id') user_id = session.get('user_id')
with Database() as db: live_following_streams = get_followed_streamers(user_id)
followed_streamers = db.fetchall(""" return live_following_streams
SELECT user_id, username
FROM users
WHERE user_id IN (SELECT followed_id FROM follows WHERE user_id = ?);
""", (user_id,))
return followed_streamers
def get_followed_live_streams(user_id: int) -> Optional[List[dict]]:
"""
Searches for streamers who the user followed which are currently live
Returns a list of live streams with the streamer's user id, stream title, and number of viewers
"""
with Database() as db:
live_streams = db.fetchall("""
SELECT users.user_id, streams.title, streams.num_viewers, users.username
FROM streams JOIN users
ON streams.user_id = users.user_id
WHERE users.user_id IN
(SELECT followed_id FROM follows WHERE user_id = ?)
AND users.is_live = 1;
""", (user_id,))
return live_streams
## Login Routes ## Login Routes
@user_bp.route('/user/login_status') @user_bp.route('/user/login_status')
def get_login_status(): def login_status():
""" """
Returns whether the user is logged in or not Returns whether the user is logged in or not
""" """

View File

45
web_server/utils/auth.py Normal file
View File

@@ -0,0 +1,45 @@
from database.database import Database
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
from typing import Optional
from dotenv import load_dotenv
from os import getenv
from werkzeug.security import generate_password_hash
load_dotenv()
serializer = URLSafeTimedSerializer(getenv("AUTH_SECRET_KEY"))
def generate_token(email, salt_value) -> str:
"""
Creates a token for password reset
"""
token = serializer.dumps(email, salt=salt_value)
return token
def verify_token(token: str, salt_value) -> Optional[str]:
"""
Given a token, verifies and decodes it into an email
"""
try:
email = serializer.loads(token, salt=salt_value, max_age=3600)
return email
except SignatureExpired:
# Token expired
print("Token has expired", flush=True)
return None
except BadSignature:
# Invalid token
print("Token is invalid", flush=True)
return None
def reset_password(new_password: str, email: str) -> bool:
"""
Given email and new password reset the password for a given user
"""
with Database() as db:
db.execute("""
UPDATE users
SET password = ?
WHERE email = ?
""", (generate_password_hash(new_password), email))
return True

View File

@@ -2,9 +2,8 @@ import smtplib
from email.mime.text import MIMEText from email.mime.text import MIMEText
from os import getenv from os import getenv
from random import randrange
from dotenv import load_dotenv from dotenv import load_dotenv
from utils.user_utils import generate_token from utils.auth import generate_token
from secrets import token_hex from secrets import token_hex
import redis import redis

View File

@@ -0,0 +1,97 @@
from database.database import Database
from typing import Optional, List
def get_user_preferred_category(user_id: int) -> Optional[int]:
"""
Queries user_preferences database to find users favourite streaming category and returns the category
"""
with Database() as db:
category = db.fetchone("""
SELECT category_id
FROM user_preferences
WHERE user_id = ?
ORDER BY favourability DESC
LIMIT 1
""", (user_id,))
return category["category_id"] if category else None
def get_followed_categories_recommendations(user_id: int, no_streams: int = 4) -> Optional[List[dict]]:
"""
Returns top streams given a user's category following
"""
with Database() as db:
streams = db.fetchall("""
SELECT u.user_id, title, u.username, num_viewers, category_name
FROM streams
JOIN users u ON streams.user_id = u.user_id
JOIN categories ON streams.category_id = categories.category_id
WHERE categories.category_id IN (SELECT category_id FROM followed_categories WHERE user_id = ?)
ORDER BY num_viewers DESC
LIMIT ?;
""", (user_id, no_streams))
return streams
def get_streams_based_on_category(category_id: int, no_streams: int = 4) -> Optional[List[dict]]:
"""
Queries stream database to get top most viewed streams based on given category
"""
with Database() as db:
streams = db.fetchall("""
SELECT u.user_id, title, username, num_viewers, c.category_name
FROM streams s
JOIN users u ON s.user_id = u.user_id
JOIN categories c ON s.category_id = c.category_id
WHERE c.category_id = ?
ORDER BY num_viewers DESC
LIMIT ?
""", (category_id, no_streams))
return streams
def get_highest_view_streams(no_streams: int = 4) -> Optional[List[dict]]:
"""
Return a list of live streams by number of viewers
"""
with Database() as db:
data = db.fetchall("""
SELECT u.user_id, username, title, num_viewers, category_name
FROM streams
JOIN users u ON streams.user_id = u.user_id
JOIN categories ON streams.category_id = categories.category_id
ORDER BY num_viewers DESC
LIMIT ?;
""", (no_streams,))
return data
def get_highest_view_categories(no_categories: int = 4) -> Optional[List[dict]]:
"""
Returns a list of top most popular categories
"""
with Database() as db:
categories = db.fetchall("""
SELECT categories.category_id, categories.category_name, SUM(streams.num_viewers) AS num_viewers
FROM streams
JOIN categories ON streams.category_id = categories.category_id
GROUP BY categories.category_name
ORDER BY SUM(streams.num_viewers) DESC
LIMIT ?;
""", (no_categories,))
return categories
def get_user_category_recommendations(user_id: int) -> Optional[List[dict]]:
"""
Queries user_preferences database to find users top 5 favourite streaming category and returns the category
"""
with Database() as db:
categories = db.fetchall("""
SELECT categories.category_id, categories.category_name
FROM categories
JOIN user_preferences ON categories.category_id = user_preferences.category_id
WHERE user_id = ?
ORDER BY favourability DESC
LIMIT 5
""", (user_id,))
return categories

View File

@@ -1,5 +1,7 @@
from database.database import Database from database.database import Database
from typing import Optional
import os, subprocess import os, subprocess
from typing import Optional, List
def get_streamer_live_status(user_id: int): def get_streamer_live_status(user_id: int):
""" """
@@ -14,6 +16,72 @@ def get_streamer_live_status(user_id: int):
return is_live return is_live
def get_followed_live_streams(user_id: int) -> Optional[List[dict]]:
"""
Searches for streamers who the user followed which are currently live
Returns a list of live streams with the streamer's user id, stream title, and number of viewers
"""
with Database() as db:
live_streams = db.fetchall("""
SELECT users.user_id, streams.title, streams.num_viewers, users.username
FROM streams JOIN users
ON streams.user_id = users.user_id
WHERE users.user_id IN
(SELECT followed_id FROM follows WHERE user_id = ?)
AND users.is_live = 1;
""", (user_id,))
return live_streams
def get_current_stream_data(user_id: int) -> Optional[dict]:
"""
Returns data of the most recent stream by a streamer
"""
with Database() as db:
most_recent_stream = db.fetchone("""
SELECT s.user_id, u.username, s.title, s.start_time, s.num_viewers, c.category_name
FROM streams AS s
JOIN categories AS c ON s.category_id = c.category_id
JOIN users AS u ON s.user_id = u.user_id
WHERE u.user_id = ?
""", (user_id,))
return most_recent_stream
def get_category_id(category_name: str) -> Optional[int]:
"""
Returns the category_id given a category name
"""
with Database() as db:
data = db.fetchone("""
SELECT category_id
FROM categories
WHERE category_name = ?;
""", (category_name,))
return data['category_id'] if data else None
def get_vod(vod_id: int) -> dict:
"""
Returns data of a streamers vod
"""
with Database() as db:
vod = db.fetchone("""SELECT * FROM vods WHERE vod_id = ?;""", (vod_id,))
return vod
def get_latest_vod(user_id: int):
"""
Returns data of the most recent stream by a streamer
"""
with Database() as db:
latest_vod = db.fetchone("""SELECT * FROM vods WHERE user_id = ? ORDER BY vod_id DESC LIMIT 1;""", (user_id,))
return latest_vod
def get_user_vods(user_id: int):
"""
Returns data of all vods by a streamer
"""
with Database() as db:
vods = db.fetchall("""SELECT * FROM vods WHERE user_id = ?;""", (user_id,))
return vods
def generate_thumbnail(user_id: int) -> None: def generate_thumbnail(user_id: int) -> None:
""" """
@@ -42,3 +110,55 @@ def generate_thumbnail(user_id: int) -> None:
subprocess.run(thumbnail_command) subprocess.run(thumbnail_command)
def get_stream_tags(user_id: int) -> Optional[List[str]]:
"""
Given a stream return tags associated with the user's stream
"""
with Database() as db:
tags = db.fetchall("""
SELECT tag_name
FROM tags
JOIN stream_tags ON tags.tag_id = stream_tags.tag_id
WHERE user_id = ?;
""", (user_id,))
return tags
def get_vod_tags(vod_id: int):
"""
Given a vod return tags associated with the vod
"""
with Database() as db:
tags = db.fetchall("""
SELECT tag_name
FROM tags
JOIN vod_tags ON tags.tag_id = vod_tags.tag_id
WHERE vod_id = ?;
""", (vod_id,))
return tags
def transfer_stream_to_vod(user_id: int):
"""
Deletes stream from stream table and moves it to VoD table
TODO: Add functionaliy to save stream permanently
"""
with Database() as db:
stream = db.fetchone("""
SELECT * FROM streams WHERE user_id = ?;
""", (user_id,))
if not stream:
return None
## TODO: calculate length in seconds, currently using temp value
db.execute("""
INSERT INTO vods (user_id, title, datetime, category_id, length, views)
VALUES (?, ?, ?, ?, ?, ?);
""", (stream["user_id"], stream["title"], stream["datetime"], stream["category_id"], 10, stream["num_viewers"]))
db.execute("""
DELETE FROM streams WHERE user_id = ?;
""", (user_id,))
return True

View File

@@ -1,14 +1,72 @@
from database.database import Database from database.database import Database
from typing import Optional, List from typing import Optional, List
from datetime import datetime from datetime import datetime
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
from os import getenv
from werkzeug.security import generate_password_hash
from dateutil import parser from dateutil import parser
from dotenv import load_dotenv
load_dotenv()
serializer = URLSafeTimedSerializer(getenv("AUTH_SECRET_KEY")) def get_user_id(username: str) -> Optional[int]:
"""
Returns user_id associated with given username
"""
with Database() as db:
data = db.fetchone("""
SELECT user_id
FROM users
WHERE username = ?
""", (username,))
return data['user_id'] if data else None
def get_username(user_id: str) -> Optional[str]:
"""
Returns username associated with given user_id
"""
with Database() as db:
data = db.fetchone("""
SELECT username
FROM user
WHERE user_id = ?
""", (user_id,))
return data['username'] if data else None
def get_session_info_email(email: str) -> dict:
"""
Returns username and user_id given email
"""
with Database as db:
session_info = db.fetchone("""
SELECT user_id, username
FROM user
WHERE email = ?
""", (email,))
return session_info
def is_user_partner(user_id: int) -> bool:
"""
Returns True if user is a partner, else False
"""
with Database() as db:
data = db.fetchone("""
SELECT is_partnered
FROM users
WHERE user_id = ?
""", (user_id,))
return bool(data)
def is_subscribed(user_id: int, subscribed_to_id: int) -> bool:
"""
Returns True if user is subscribed to a streamer, else False
"""
with Database() as db:
result = db.fetchone("""
SELECT *
FROM subscribes
WHERE user_id = ?
AND subscribed_id = ?
AND expires > ?;
""", (user_id, subscribed_to_id, datetime.now()))
print(result)
if result:
return True
return False
def is_following(user_id: int, followed_id: int) -> bool: def is_following(user_id: int, followed_id: int) -> bool:
""" """
@@ -51,40 +109,56 @@ def unfollow(user_id: int, followed_id: int):
""", (user_id, followed_id)) """, (user_id, followed_id))
return {"success": True} return {"success": True}
def generate_token(email, salt_value) -> str:
"""
Creates a token for password reset
"""
token = serializer.dumps(email, salt=salt_value)
return token
def verify_token(token: str, salt_value) -> Optional[str]: def subscription_expiration(user_id: int, subscribed_id: int) -> int:
""" """
Given a token, verifies and decodes it into an email Returns the amount of time left until user subscription to a streamer ends
"""
try:
email = serializer.loads(token, salt=salt_value, max_age=3600)
return email
except SignatureExpired:
# Token expired
print("Token has expired", flush=True)
return None
except BadSignature:
# Invalid token
print("Token is invalid", flush=True)
return None
def reset_password(new_password: str, email: str) -> bool:
"""
Given email and new password reset the password for a given user
""" """
with Database() as db: with Database() as db:
db.execute(""" data = db.fetchone("""
UPDATE users SELECT expires
SET password = ? FROM subscribes
WHERE email = ? WHERE user_id = ?
""", (generate_password_hash(new_password), email)) AND subscribed_id = ?
AND expires > ?
""", (user_id, subscribed_id, datetime.now()))
return True if data:
expiration_date = data["expires"]
remaining_time = (parser.parse(expiration_date) - datetime.now()).seconds
return remaining_time
return 0
def get_email(user_id: int) -> Optional[str]:
with Database() as db:
email = db.fetchone("""
SELECT email
FROM users
WHERE user_id = ?
""", (user_id,))
return email["email"] if email else None
def get_followed_streamers(user_id: int) -> Optional[List[dict]]:
"""
Returns a list of streamers who the user follows
"""
with Database() as db:
followed_streamers = db.fetchall("""
SELECT user_id, username
FROM users
WHERE user_id IN (SELECT followed_id FROM follows WHERE user_id = ?);
""", (user_id,))
return followed_streamers
def get_user(user_id: int) -> Optional[dict]:
"""
Returns information about a user from user_id
"""
with Database() as db:
data = db.fetchone("""
SELECT user_id, username, bio, num_followers, is_partnered, is_live FROM users
WHERE user_id = ?;
""", (user_id,))
return data

View File

@@ -1,39 +1,43 @@
from flask import redirect, url_for, request, g, session
from functools import wraps
from re import match
from database.database import Database from database.database import Database
from typing import Optional, List from typing import Optional, List
from re import match
def logged_in_user(): def get_all_categories() -> Optional[List[dict]]:
""" """
Validator to make sure a user is logged in. Returns all possible streaming categories
""" """
g.user = session.get("username", None) with Database() as db:
g.admin = session.get("username", None) all_categories = db.fetchall("SELECT * FROM categories")
def login_required(view): return all_categories
"""
Add at start of routes where users need to be logged in to access.
"""
@wraps(view)
def wrapped_view(*args, **kwargs):
if g.user is None:
return redirect(url_for("login", next=request.url))
return view(*args, **kwargs)
return wrapped_view
def admin_required(view): def get_all_tags() -> Optional[List[dict]]:
""" """
Add at start of routes where admins need to be logged in to access. Returns all possible streaming tags
""" """
@wraps(view) with Database() as db:
def wrapped_view(*args, **kwargs): all_tags = db.fetchall("SELECT * FROM tags")
if g.admin != "admin":
return redirect(url_for("login", next=request.url))
return view(*args, **kwargs)
return wrapped_view
def sanitizer(user_input: str, input_type="username") -> str: return all_tags
def get_most_popular_category() -> Optional[List[dict]]:
"""
Returns the most popular category based on live stream viewers
"""
with Database() as db:
category = db.fetchone("""
SELECT categories.category_id, categories.category_name
FROM streams
JOIN categories ON streams.category_id = categories.category_id
WHERE streams.isLive = 1
GROUP BY categories.category_name
ORDER BY SUM(streams.num_viewers) DESC
LIMIT 1;
""")
return category
def sanitize(user_input: str, input_type="username") -> str:
""" """
Sanitizes user input based on the specified input type. Sanitizes user input based on the specified input type.
@@ -69,38 +73,3 @@ def sanitizer(user_input: str, input_type="username") -> str:
raise ValueError("Unaccepted character or length in input") raise ValueError("Unaccepted character or length in input")
return sanitised_input return sanitised_input
def categories() -> Optional[List[dict]]:
"""
Returns all possible streaming categories
"""
with Database() as db:
all_categories = db.fetchall("SELECT * FROM categories")
return all_categories
def tags() -> Optional[List[dict]]:
"""
Returns all possible streaming tags
"""
with Database() as db:
all_tags = db.fetchall("SELECT * FROM tags")
return all_tags
def most_popular_category() -> Optional[List[dict]]:
"""
Returns the most popular category based on live stream viewers
"""
with Database() as db:
category = db.fetchone("""
SELECT categories.category_id, categories.category_name
FROM streams
JOIN categories ON streams.category_id = categories.category_id
WHERE streams.isLive = 1
GROUP BY categories.category_name
ORDER BY SUM(streams.num_viewers) DESC
LIMIT 1;
""")
return category