style: run python linter & prettifier on backend code
This commit is contained in:
@@ -6,6 +6,7 @@ from typing import Any
|
||||
from transformers import pipeline
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
|
||||
class NLP:
|
||||
_topic_models: dict[str, SentenceTransformer] = {}
|
||||
_emotion_classifiers: dict[str, Any] = {}
|
||||
@@ -32,7 +33,7 @@ class NLP:
|
||||
)
|
||||
self.entity_recognizer = self._get_entity_recognizer(
|
||||
self.device_str, self.pipeline_device
|
||||
)
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
if self.use_cuda and "out of memory" in str(exc).lower():
|
||||
torch.cuda.empty_cache()
|
||||
@@ -90,7 +91,7 @@ class NLP:
|
||||
)
|
||||
cls._emotion_classifiers[device_str] = classifier
|
||||
return classifier
|
||||
|
||||
|
||||
@classmethod
|
||||
def _get_entity_recognizer(cls, device_str: str, pipeline_device: int) -> Any:
|
||||
recognizer = cls._entity_recognizers.get(device_str)
|
||||
@@ -207,8 +208,7 @@ class NLP:
|
||||
self.df.drop(columns=existing_drop, inplace=True)
|
||||
|
||||
remaining_emotion_cols = [
|
||||
c for c in self.df.columns
|
||||
if c.startswith("emotion_")
|
||||
c for c in self.df.columns if c.startswith("emotion_")
|
||||
]
|
||||
|
||||
if remaining_emotion_cols:
|
||||
@@ -227,8 +227,6 @@ class NLP:
|
||||
|
||||
self.df[remaining_emotion_cols] = normalized.values
|
||||
|
||||
|
||||
|
||||
def add_topic_col(self, confidence_threshold: float = 0.3) -> None:
|
||||
titles = self.df[self.title_col].fillna("").astype(str)
|
||||
contents = self.df[self.content_col].fillna("").astype(str)
|
||||
@@ -257,7 +255,7 @@ class NLP:
|
||||
self.df.loc[self.df["topic_confidence"] < confidence_threshold, "topic"] = (
|
||||
"Misc"
|
||||
)
|
||||
|
||||
|
||||
def add_ner_cols(self, max_chars: int = 512) -> None:
|
||||
texts = (
|
||||
self.df[self.content_col]
|
||||
@@ -302,8 +300,4 @@ class NLP:
|
||||
|
||||
for label in all_labels:
|
||||
col_name = f"entity_{label}"
|
||||
self.df[col_name] = [
|
||||
d.get(label, 0) for d in entity_count_dicts
|
||||
]
|
||||
|
||||
|
||||
self.df[col_name] = [d.get(label, 0) for d in entity_count_dicts]
|
||||
|
||||
Reference in New Issue
Block a user