|
|
|
|
@@ -1,6 +1,10 @@
|
|
|
|
|
import requests
|
|
|
|
|
import logging
|
|
|
|
|
import time
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
|
from requests.auth import HTTPBasicAuth
|
|
|
|
|
|
|
|
|
|
from dto.post import Post
|
|
|
|
|
from dto.user import User
|
|
|
|
|
@@ -9,6 +13,8 @@ from server.connectors.base import BaseConnector
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
CLIENT_ID = os.getenv("REDDIT_CLIENT_ID")
|
|
|
|
|
CLIENT_SECRET = os.getenv("REDDIT_CLIENT_SECRET")
|
|
|
|
|
|
|
|
|
|
class RedditAPI(BaseConnector):
|
|
|
|
|
source_name: str = "reddit"
|
|
|
|
|
@@ -18,6 +24,8 @@ class RedditAPI(BaseConnector):
|
|
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
|
self.url = "https://www.reddit.com/"
|
|
|
|
|
self.token = None
|
|
|
|
|
self.token_expiry = 0
|
|
|
|
|
|
|
|
|
|
# Public Methods #
|
|
|
|
|
def get_new_posts_by_search(
|
|
|
|
|
@@ -171,9 +179,44 @@ class RedditAPI(BaseConnector):
|
|
|
|
|
user = User(username=user_data["name"], created_utc=user_data["created_utc"])
|
|
|
|
|
user.karma = user_data["total_karma"]
|
|
|
|
|
return user
|
|
|
|
|
|
|
|
|
|
def _get_token(self):
|
|
|
|
|
if self.token and time.time() < self.token_expiry:
|
|
|
|
|
return self.token
|
|
|
|
|
|
|
|
|
|
logger.info("Fetching new Reddit access token...")
|
|
|
|
|
|
|
|
|
|
auth = HTTPBasicAuth(CLIENT_ID, CLIENT_SECRET)
|
|
|
|
|
|
|
|
|
|
data = {
|
|
|
|
|
"grant_type": "client_credentials"
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
headers = {
|
|
|
|
|
"User-Agent": "python:ethnography-college-project:0.1 (by /u/ThisBirchWood)"
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
response = requests.post(
|
|
|
|
|
"https://www.reddit.com/api/v1/access_token",
|
|
|
|
|
auth=auth,
|
|
|
|
|
data=data,
|
|
|
|
|
headers=headers,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
response.raise_for_status()
|
|
|
|
|
token_json = response.json()
|
|
|
|
|
|
|
|
|
|
self.token = token_json["access_token"]
|
|
|
|
|
self.token_expiry = time.time() + token_json["expires_in"] - 60
|
|
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Obtained new Reddit access token (expires in {token_json['expires_in']}s)"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return self.token
|
|
|
|
|
|
|
|
|
|
def _fetch_post_overviews(self, endpoint: str, params: dict) -> dict:
|
|
|
|
|
url = f"{self.url}{endpoint}"
|
|
|
|
|
url = f"https://oauth.reddit.com/{endpoint.lstrip('/')}"
|
|
|
|
|
max_retries = 15
|
|
|
|
|
backoff = 1 # seconds
|
|
|
|
|
|
|
|
|
|
@@ -182,13 +225,14 @@ class RedditAPI(BaseConnector):
|
|
|
|
|
response = requests.get(
|
|
|
|
|
url,
|
|
|
|
|
headers={
|
|
|
|
|
"User-agent": "python:ethnography-college-project:0.1 (by /u/ThisBirchWood)"
|
|
|
|
|
"User-agent": "python:ethnography-college-project:0.1 (by /u/ThisBirchWood)",
|
|
|
|
|
"Authorization": f"Bearer {self._get_token()}",
|
|
|
|
|
},
|
|
|
|
|
params=params,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if response.status_code == 429:
|
|
|
|
|
wait_time = response.headers.get("Retry-After", backoff)
|
|
|
|
|
wait_time = response.headers.get("X-Ratelimit-Reset", backoff)
|
|
|
|
|
|
|
|
|
|
logger.warning(
|
|
|
|
|
f"Rate limited by Reddit API. Retrying in {wait_time} seconds..."
|
|
|
|
|
|