From de61e7653f40fc14429c2b4e8e63bde205dff708 Mon Sep 17 00:00:00 2001 From: Dylan De Faoite Date: Sat, 4 Apr 2026 12:26:54 +0100 Subject: [PATCH] perf(connector): add reddit API authentication to speed up fetching This aligns better with ethics and massively increases rate limits. --- example.env | 2 ++ server/connectors/reddit_api.py | 48 +++++++++++++++++++++++++++++++-- 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/example.env b/example.env index 5e62ce6..e718a5a 100644 --- a/example.env +++ b/example.env @@ -1,5 +1,7 @@ # API Keys YOUTUBE_API_KEY= +REDDIT_CLIENT_ID= +REDDIT_CLIENT_SECRET= # Database POSTGRES_USER= diff --git a/server/connectors/reddit_api.py b/server/connectors/reddit_api.py index 90a08d8..dc07daf 100644 --- a/server/connectors/reddit_api.py +++ b/server/connectors/reddit_api.py @@ -1,6 +1,10 @@ import requests import logging import time +import os + +from dotenv import load_dotenv +from requests.auth import HTTPBasicAuth from dto.post import Post from dto.user import User @@ -9,6 +13,8 @@ from server.connectors.base import BaseConnector logger = logging.getLogger(__name__) +CLIENT_ID = os.getenv("REDDIT_CLIENT_ID") +CLIENT_SECRET = os.getenv("REDDIT_CLIENT_SECRET") class RedditAPI(BaseConnector): source_name: str = "reddit" @@ -18,6 +24,8 @@ class RedditAPI(BaseConnector): def __init__(self): self.url = "https://www.reddit.com/" + self.token = None + self.token_expiry = 0 # Public Methods # def get_new_posts_by_search( @@ -171,9 +179,44 @@ class RedditAPI(BaseConnector): user = User(username=user_data["name"], created_utc=user_data["created_utc"]) user.karma = user_data["total_karma"] return user + + def _get_token(self): + if self.token and time.time() < self.token_expiry: + return self.token + + logger.info("Fetching new Reddit access token...") + + auth = HTTPBasicAuth(CLIENT_ID, CLIENT_SECRET) + + data = { + "grant_type": "client_credentials" + } + + headers = { + "User-Agent": "python:ethnography-college-project:0.1 (by /u/ThisBirchWood)" + } + + response = requests.post( + "https://www.reddit.com/api/v1/access_token", + auth=auth, + data=data, + headers=headers, + ) + + response.raise_for_status() + token_json = response.json() + + self.token = token_json["access_token"] + self.token_expiry = time.time() + token_json["expires_in"] - 60 + + logger.info( + f"Obtained new Reddit access token (expires in {token_json['expires_in']}s)" + ) + + return self.token def _fetch_post_overviews(self, endpoint: str, params: dict) -> dict: - url = f"{self.url}{endpoint}" + url = f"https://oauth.reddit.com/{endpoint.lstrip('/')}" max_retries = 15 backoff = 1 # seconds @@ -182,7 +225,8 @@ class RedditAPI(BaseConnector): response = requests.get( url, headers={ - "User-agent": "python:ethnography-college-project:0.1 (by /u/ThisBirchWood)" + "User-agent": "python:ethnography-college-project:0.1 (by /u/ThisBirchWood)", + "Authorization": f"Bearer {self._get_token()}", }, params=params, )