feat(nlp): add Named Entity Recognition to dataset
This commit is contained in:
@@ -9,6 +9,7 @@ from sentence_transformers import SentenceTransformer
|
||||
class NLP:
|
||||
_topic_models: dict[str, SentenceTransformer] = {}
|
||||
_emotion_classifiers: dict[str, Any] = {}
|
||||
_entity_recognizers: dict[str, Any] = {}
|
||||
_topic_embedding_cache: dict[tuple[str, ...], np.ndarray] = {}
|
||||
|
||||
def __init__(
|
||||
@@ -29,6 +30,9 @@ class NLP:
|
||||
self.emotion_classifier = self._get_emotion_classifier(
|
||||
self.device_str, self.pipeline_device
|
||||
)
|
||||
self.entity_recognizer = self._get_entity_recognizer(
|
||||
self.device_str, self.pipeline_device
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
if self.use_cuda and "out of memory" in str(exc).lower():
|
||||
torch.cuda.empty_cache()
|
||||
@@ -87,6 +91,27 @@ class NLP:
|
||||
cls._emotion_classifiers[device_str] = classifier
|
||||
return classifier
|
||||
|
||||
@classmethod
|
||||
def _get_entity_recognizer(cls, device_str: str, pipeline_device: int) -> Any:
|
||||
recognizer = cls._entity_recognizers.get(device_str)
|
||||
if recognizer is None:
|
||||
pipeline_kwargs = {
|
||||
"aggregation_strategy": "simple", # merges subwords
|
||||
"device": pipeline_device,
|
||||
}
|
||||
|
||||
if device_str == "cuda":
|
||||
pipeline_kwargs["dtype"] = torch.float16
|
||||
|
||||
recognizer = pipeline(
|
||||
"token-classification",
|
||||
model="dslim/bert-base-NER",
|
||||
**pipeline_kwargs,
|
||||
)
|
||||
cls._entity_recognizers[device_str] = recognizer
|
||||
|
||||
return recognizer
|
||||
|
||||
def _encode_with_backoff(
|
||||
self, texts: list[str], initial_batch_size: int
|
||||
) -> np.ndarray:
|
||||
@@ -129,6 +154,26 @@ class NLP:
|
||||
continue
|
||||
raise
|
||||
|
||||
def _infer_entities_with_backoff(
|
||||
self, texts: list[str], initial_batch_size: int
|
||||
) -> list[list[dict[str, Any]]]:
|
||||
|
||||
batch_size = initial_batch_size
|
||||
|
||||
while True:
|
||||
try:
|
||||
return self.entity_recognizer(texts, batch_size=batch_size)
|
||||
except RuntimeError as exc:
|
||||
if (
|
||||
self.use_cuda
|
||||
and "out of memory" in str(exc).lower()
|
||||
and batch_size > 4
|
||||
):
|
||||
batch_size = max(4, batch_size // 2)
|
||||
torch.cuda.empty_cache()
|
||||
continue
|
||||
raise
|
||||
|
||||
def add_emotion_cols(self) -> None:
|
||||
texts = self.df[self.content_col].astype(str).str.slice(0, 512).tolist()
|
||||
|
||||
@@ -183,3 +228,51 @@ class NLP:
|
||||
self.df.loc[self.df["topic_confidence"] < confidence_threshold, "topic"] = (
|
||||
"Misc"
|
||||
)
|
||||
|
||||
def add_ner_cols(self, max_chars: int = 512) -> None:
|
||||
texts = (
|
||||
self.df[self.content_col]
|
||||
.fillna("")
|
||||
.astype(str)
|
||||
.str.slice(0, max_chars)
|
||||
.tolist()
|
||||
)
|
||||
|
||||
if not texts:
|
||||
self.df["entities"] = []
|
||||
self.df["entity_counts"] = []
|
||||
return
|
||||
|
||||
results = self._infer_entities_with_backoff(texts, 32 if self.use_cuda else 8)
|
||||
|
||||
entity_lists = []
|
||||
entity_count_dicts = []
|
||||
|
||||
for row in results:
|
||||
entities = []
|
||||
counts = {}
|
||||
|
||||
for ent in row:
|
||||
word = ent.get("word")
|
||||
label = ent.get("entity_group")
|
||||
|
||||
if isinstance(word, str) and isinstance(label, str):
|
||||
entities.append({"text": word, "label": label})
|
||||
counts[label] = counts.get(label, 0) + 1
|
||||
|
||||
entity_lists.append(entities)
|
||||
entity_count_dicts.append(counts)
|
||||
|
||||
self.df["entities"] = entity_lists
|
||||
self.df["entity_counts"] = entity_count_dicts
|
||||
|
||||
# Expand label counts into columns
|
||||
all_labels = set()
|
||||
for d in entity_count_dicts:
|
||||
all_labels.update(d.keys())
|
||||
|
||||
for label in all_labels:
|
||||
col_name = f"entity_{label}"
|
||||
self.df[col_name] = [
|
||||
d.get(label, 0) for d in entity_count_dicts
|
||||
]
|
||||
|
||||
@@ -58,6 +58,7 @@ class StatGen:
|
||||
|
||||
self.nlp.add_emotion_cols()
|
||||
self.nlp.add_topic_col()
|
||||
self.nlp.add_ner_cols()
|
||||
|
||||
## Public
|
||||
def time_analysis(self) -> pd.DataFrame:
|
||||
|
||||
Reference in New Issue
Block a user