feat: add login endpoint
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import psycopg2
|
import psycopg2
|
||||||
|
from psycopg2.extras import RealDictCursor
|
||||||
|
|
||||||
|
|
||||||
class PostgresConnector:
|
class PostgresConnector:
|
||||||
@@ -18,19 +19,34 @@ class PostgresConnector:
|
|||||||
self.connection.autocommit = False
|
self.connection.autocommit = False
|
||||||
|
|
||||||
def execute(self, query, params=None, fetch=False):
|
def execute(self, query, params=None, fetch=False):
|
||||||
with self.connection.cursor() as cursor:
|
with self.connection.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||||
cursor.execute(query, params)
|
cursor.execute(query, params)
|
||||||
|
|
||||||
if fetch:
|
if fetch:
|
||||||
return cursor.fetchall()
|
return cursor.fetchall()
|
||||||
|
|
||||||
self.connection.commit()
|
self.connection.commit()
|
||||||
|
|
||||||
def executemany(self, query, param_list):
|
def executemany(self, query, param_list):
|
||||||
with self.connection.cursor() as cursor:
|
with self.connection.cursor(cursor_factory=RealDictCursor) as cursor:
|
||||||
cursor.executemany(query, param_list)
|
cursor.executemany(query, param_list)
|
||||||
self.connection.commit()
|
self.connection.commit()
|
||||||
|
|
||||||
|
def save_user(self, username, email, password_hash):
|
||||||
|
query = """
|
||||||
|
INSERT INTO users (username, email, password_hash)
|
||||||
|
VALUES (%s, %s, %s)
|
||||||
|
"""
|
||||||
|
self.execute(query, (username, email, password_hash))
|
||||||
|
|
||||||
|
def get_user_by_username(self, username) -> dict:
|
||||||
|
query = "SELECT id, username, email, password_hash FROM users WHERE username = %s"
|
||||||
|
result = self.execute(query, (username,), fetch=True)
|
||||||
|
return result[0] if result else None
|
||||||
|
|
||||||
|
def get_user_by_email(self, email) -> dict:
|
||||||
|
query = "SELECT id, username, email, password_hash FROM users WHERE email = %s"
|
||||||
|
result = self.execute(query, (email,), fetch=True)
|
||||||
|
return result[0] if result else None
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
if self.connection:
|
if self.connection:
|
||||||
self.connection.close()
|
self.connection.close()
|
||||||
@@ -13,6 +13,7 @@ from flask_jwt_extended import (
|
|||||||
|
|
||||||
from server.stat_gen import StatGen
|
from server.stat_gen import StatGen
|
||||||
from db.database import PostgresConnector
|
from db.database import PostgresConnector
|
||||||
|
from server.auth import AuthManager
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import traceback
|
import traceback
|
||||||
@@ -34,28 +35,57 @@ app.config["JWT_ACCESS_TOKEN_EXPIRES"] = jwt_access_token_expires
|
|||||||
|
|
||||||
bcrypt = Bcrypt(app)
|
bcrypt = Bcrypt(app)
|
||||||
jwt = JWTManager(app)
|
jwt = JWTManager(app)
|
||||||
|
auth_manager = AuthManager(db, bcrypt)
|
||||||
|
|
||||||
# Global State
|
# Global State
|
||||||
posts_df = pd.read_json('small.jsonl', lines=True)
|
# posts_df = pd.read_json('small.jsonl', lines=True)
|
||||||
with open("topic_buckets.json", "r", encoding="utf-8") as f:
|
# with open("topic_buckets.json", "r", encoding="utf-8") as f:
|
||||||
domain_topics = json.load(f)
|
# domain_topics = json.load(f)
|
||||||
stat_obj = StatGen(posts_df, domain_topics)
|
# stat_obj = StatGen(posts_df, domain_topics)
|
||||||
|
stat_obj = None
|
||||||
|
|
||||||
@app.route('/register', methods=['POST'])
|
@app.route('/register', methods=['POST'])
|
||||||
def register_user():
|
def register_user():
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
|
|
||||||
|
if not data or "username" not in data or "email" not in data or "password" not in data:
|
||||||
|
return jsonify({"error": "Missing username, email, or password"}), 400
|
||||||
|
|
||||||
|
username = data["username"]
|
||||||
|
email = data["email"]
|
||||||
|
password = data["password"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
auth_manager.register_user(username, email, password)
|
||||||
|
except ValueError as e:
|
||||||
|
return jsonify({"error": str(e)}), 400
|
||||||
|
except Exception as e:
|
||||||
|
print(traceback.format_exc())
|
||||||
|
return jsonify({"error": f"An unexpected error occurred: {str(e)}"}), 500
|
||||||
|
|
||||||
|
print(f"Registered new user: {username}")
|
||||||
|
return jsonify({"message": f"User '{username}' registered successfully"}), 200
|
||||||
|
|
||||||
|
@app.route('/login', methods=['POST'])
|
||||||
|
def login_user():
|
||||||
|
data = request.get_json()
|
||||||
|
|
||||||
if not data or "username" not in data or "password" not in data:
|
if not data or "username" not in data or "password" not in data:
|
||||||
return jsonify({"error": "Missing username or password"}), 400
|
return jsonify({"error": "Missing username or password"}), 400
|
||||||
|
|
||||||
username = data["username"]
|
username = data["username"]
|
||||||
hashed_password = bcrypt.generate_password_hash(
|
password = data["password"]
|
||||||
data["password"]
|
|
||||||
).decode("utf-8")
|
|
||||||
|
|
||||||
|
try:
|
||||||
print(f"Registered new user: {username}")
|
user = auth_manager.authenticate_user(username, password)
|
||||||
return jsonify({"message": f"User '{username}' registered successfully"}), 200
|
if user:
|
||||||
|
access_token = create_access_token(identity=user['id'])
|
||||||
|
return jsonify({"access_token": access_token}), 200
|
||||||
|
else:
|
||||||
|
return jsonify({"error": "Invalid username or password"}), 401
|
||||||
|
except Exception as e:
|
||||||
|
print(traceback.format_exc())
|
||||||
|
return jsonify({"error": f"An unexpected error occurred: {str(e)}"}), 500
|
||||||
|
|
||||||
@app.route('/upload', methods=['POST'])
|
@app.route('/upload', methods=['POST'])
|
||||||
def upload_data():
|
def upload_data():
|
||||||
|
|||||||
@@ -1,12 +1,24 @@
|
|||||||
from db.database import PostgresConnector
|
from db.database import PostgresConnector
|
||||||
|
from flask_bcrypt import Bcrypt
|
||||||
|
|
||||||
class AuthManager:
|
class AuthManager:
|
||||||
def __init__(self, db: PostgresConnector, bcrypt):
|
def __init__(self, db: PostgresConnector, bcrypt: Bcrypt):
|
||||||
self.db = db
|
self.db = db
|
||||||
self.bcrypt = bcrypt
|
self.bcrypt = bcrypt
|
||||||
|
|
||||||
def register_user(self, username, password):
|
def register_user(self, username, email, password):
|
||||||
# Hash the password
|
|
||||||
hashed_password = self.bcrypt.generate_password_hash(password).decode("utf-8")
|
hashed_password = self.bcrypt.generate_password_hash(password).decode("utf-8")
|
||||||
# Save the user to the database
|
|
||||||
self.db.save_user(username, hashed_password)
|
if self.db.get_user_by_email(email):
|
||||||
|
raise ValueError("Email already registered")
|
||||||
|
|
||||||
|
if self.db.get_user_by_username(username):
|
||||||
|
raise ValueError("Username already taken")
|
||||||
|
|
||||||
|
self.db.save_user(username, email, hashed_password)
|
||||||
|
|
||||||
|
def authenticate_user(self, username, password):
|
||||||
|
user = self.db.get_user_by_username(username)
|
||||||
|
if user and self.bcrypt.check_password_hash(user['password_hash'], password):
|
||||||
|
return user
|
||||||
|
return None
|
||||||
Reference in New Issue
Block a user