diff --git a/connectors/reddit_api.py b/connectors/reddit_api.py index 11f4e06..8fe5a99 100644 --- a/connectors/reddit_api.py +++ b/connectors/reddit_api.py @@ -28,8 +28,9 @@ class RedditAPI: data = self._fetch_data(url, params) return self._parse_posts(data) - def get_new_subreddit_posts(self, subreddit: str, limit: int = 10) -> list[Post]: + def get_new_subreddit_posts(self, subreddit: str, limit: int = 10) -> tuple[list[Post], list[Comment]]: posts = [] + comments = [] after = None url = f"r/{subreddit}/new.json" @@ -43,27 +44,30 @@ class RedditAPI: } data = self._fetch_data(url, params) - batch = self._parse_posts(data) + batch_posts, batch_comments = self._parse_posts(data) - logger.debug(f"Fetched {len(batch)} new posts from subreddit {subreddit}") + logger.debug(f"Fetched {len(batch_posts)} new posts and {len(batch_comments)} comments from subreddit {subreddit}") - if not batch: + if not batch_posts: break - - posts.extend(batch) + + posts.extend(batch_posts) + comments.extend(batch_comments) after = data['data'].get('after') if not after: break - return posts + return posts, comments def get_user(self, username: str) -> User: data = self._fetch_data(f"user/{username}/about.json", {}) return self._parse_user(data) ## Private Methods ## - def _parse_posts(self, data) -> list[Post]: + def _parse_posts(self, data) -> tuple[list[Post], list[Comment]]: posts = [] + comments = [] + total_num_posts = len(data['data']['children']) current_index = 0 @@ -84,8 +88,9 @@ class RedditAPI: post.upvotes = post_data['ups'] posts.append(post) - return posts - + comments.extend(self._get_post_comments(post.id)) + return posts, comments + def _get_post_comments(self, post_id: str) -> list[Comment]: comments: list[Comment] = [] url = f"comments/{post_id}.json" @@ -111,7 +116,6 @@ class RedditAPI: reply_to=parent_id or comment_info.get('parent_id', None), source=self.source_name ) - comment.upvotes = comment_info.get('ups', 0) comments.append(comment)