diff --git a/web_server/blueprints/user.py b/web_server/blueprints/user.py index a6d4918..4622b1f 100644 --- a/web_server/blueprints/user.py +++ b/web_server/blueprints/user.py @@ -22,6 +22,17 @@ def user_data(username: str): return jsonify(data) ## Subscription Routes +@login_required +@user_bp.route('/user/subscribe/') +def user_subscribe(streamer_id): + """ + Given a streamer subscribes as user + """ + #TODO: Keep this route secure so only webhooks from Stripe payment can trigger it + user_id = session.get("user_id") + subscribe(user_id, streamer_id) + return jsonify({"status": True}) + @login_required @user_bp.route('/user/subscription/') def user_subscribed(subscribed_id: int): @@ -42,6 +53,9 @@ def user_subscription_expiration(subscribed_id: int): user_id = session.get("user_id") remaining_time = subscription_expiration(user_id, subscribed_id) + # Remove any expired subscriptions from the table + if remaining_time == 0: + delete_subscription(user_id, subscribed_id) return jsonify({"remaining_time": remaining_time}) diff --git a/web_server/utils/user_utils.py b/web_server/utils/user_utils.py index 1fd796f..91de868 100644 --- a/web_server/utils/user_utils.py +++ b/web_server/utils/user_utils.py @@ -1,6 +1,6 @@ from database.database import Database from typing import Optional, List -from datetime import datetime +from datetime import datetime, timedelta from dateutil import parser def get_user_id(username: str) -> Optional[int]: @@ -109,6 +109,39 @@ def unfollow(user_id: int, followed_id: int): """, (user_id, followed_id)) return {"success": True} +def subscribe(user_id: int, streamer_id: int): + """ + Subscribes user_id to streamer_id + """ + # If user is already subscribed then extend the expiration date else create a new entry + with Database() as db: + existing = db.fetchone(""" + SELECT expires + FROM subscribes + WHERE user_id = ? AND subscribed_id = ? + """, (user_id, streamer_id)) + if existing: + db.execute(""" + UPDATE subscribes SET expires = expires + ? + WHERE user_id = ? AND subscribed_id = ? + """, (timedelta(days=30), user_id, streamer_id)) + else: + db.execute(""" + INSERT INTO subscribes + (user_id, subscribed_id, since, expires) + VALUES (?,?,?,?) + """, (user_id, streamer_id, datetime.now(), datetime.now() + timedelta(days=30))) + +def delete_subscription(user_id: int, subscribed_id: int): + """ + Deletes a subscription entry given user_id and streamer_id + """ + with Database() as db: + db.execute(""" + DELETE FROM subscribes + WHERE user_id = ? AND subscribed_id = ? + """, (user_id, subscribed_id)) + def subscription_expiration(user_id: int, subscribed_id: int) -> int: """