diff --git a/web_server/blueprints/user.py b/web_server/blueprints/user.py index 146b735..ac96ac6 100644 --- a/web_server/blueprints/user.py +++ b/web_server/blueprints/user.py @@ -1,6 +1,7 @@ from flask import Blueprint, jsonify, session from utils.user_utils import * from utils.auth import * +from utils.utils import get_category_id from blueprints.middleware import login_required from utils.email import send_email, forgot_password_body import redis @@ -101,6 +102,37 @@ def user_followed_streamers(): live_following_streams = get_followed_streamers(user_id) return live_following_streams +@login_required +@user_bp.route('/user/category/follow/') +def user_follow_category(category_name): + """ + Follows a category + """ + user_id = session.get("user_id") + category_id = get_category_id(category_name) + return follow_category(user_id, category_id) + +@login_required +@user_bp.route('/user/category/unfollow/') +def user_unfollow_category(category_name): + """ + Unfollows a category + """ + user_id = session.get("user_id") + category_id = get_category_id(category_name) + return unfollow_category(user_id, category_id) + +@user_bp.route('/user/category/following/') +def user_category_following(category_name: str): + """ + Checks to see if user is following a category + """ + user_id = session.get("user_id") + category_id = get_category_id(category_name) + if is_following_category(user_id, category_id): + return jsonify({"following": True}) + return jsonify({"following": False}) + ## Login Routes @user_bp.route('/user/login_status') def user_login_status(): diff --git a/web_server/database/app.db b/web_server/database/app.db index 3f252dd..e5f707d 100644 Binary files a/web_server/database/app.db and b/web_server/database/app.db differ diff --git a/web_server/database/testing_data.sql b/web_server/database/testing_data.sql index 93abead..0f78fac 100644 --- a/web_server/database/testing_data.sql +++ b/web_server/database/testing_data.sql @@ -60,11 +60,11 @@ INSERT INTO categories (category_name) VALUES ('Dota 2'), ('Apex Legends'), ('Grand Theft Auto V'), -('The Legend of Zelda: Breath of the Wild'), +('The Legend of Zelda Breath of the Wild'), ('Elden Ring'), ('Red Dead Redemption 2'), ('Cyberpunk 2077'), -('Super Smash Bros. Ultimate'), +('Super Smash Bros Ultimate'), ('Overwatch 2'), ('Genshin Impact'), ('World of Warcraft'), diff --git a/web_server/utils/user_utils.py b/web_server/utils/user_utils.py index 420e36c..3e1e4ee 100644 --- a/web_server/utils/user_utils.py +++ b/web_server/utils/user_utils.py @@ -106,6 +106,49 @@ def unfollow(user_id: int, followed_id: int): """, (user_id, followed_id)) return {"success": True} +def is_following_category(user_id: int, category_id: str): + """ + Checks if user is following category + """ + with Database() as db: + result = db.fetchone(""" + SELECT 1 + FROM followed_categories + WHERE user_id = ? + AND category_id = ? + """, (user_id, category_id)) + return bool(result) + +def follow_category(user_id: int, category_id: str): + """ + Follows category given user_id + """ + if is_following_category(user_id, category_id): + return {"success": False, "error": "Already following category"}, 400 + + with Database() as db: + db.execute(""" + INSERT INTO followed_categories (user_id, category_id) + VALUES(?,?); + """, (user_id, category_id)) + return {"success": True} + + +def unfollow_category(user_id: int, category_id: str): + """ + Unfollows category given user_id + """ + if not is_following_category(user_id, category_id): + return {"success": False, "error": "Not following category"}, 400 + + with Database() as db: + db.execute(""" + DELETE FROM followed_categories + WHERE user_id = ? + AND category_id = ? + """, (user_id, category_id)) + return {"success": True} + def subscribe(user_id: int, streamer_id: int): """ Subscribes user_id to streamer_id diff --git a/web_server/utils/utils.py b/web_server/utils/utils.py index 3c2f833..cee94a2 100644 --- a/web_server/utils/utils.py +++ b/web_server/utils/utils.py @@ -37,6 +37,19 @@ def get_most_popular_category() -> Optional[List[dict]]: return category +def get_category_id(category_name: str): + """ + Returns category_id given category_name + """ + with Database() as db: + category = db.fetchone(""" + SELECT category_id + FROM categories + WHERE category_name = ? + """, (category_name,)) + + return category["category_id"] + def sanitize(user_input: str, input_type="default") -> str: """ Sanitizes user input based on the specified input type.