import torch
from wimarka.utils.model import load_model
from wimarka.utils.torch import get_device, move_model_to_device
import re
[docs]
def split_words(text: str):
    """Split *text* into word and punctuation tokens.

    A token is either a maximal run of Unicode word characters (``\\w+``)
    or a single character that is neither a word character nor whitespace.
    Whitespace is discarded.

    Args:
        text: The sentence to split.

    Returns:
        The list of token strings in order of appearance. Non-string input
        yields ``[""]`` (one empty token), not an empty list.
    """
    if isinstance(text, str):
        found = re.finditer(r"\w+|[^\w\s]", text, flags=re.UNICODE)
        return [match.group(0) for match in found]
    return [""]
[docs]
def tokenize_with_spans(text: str):
    """Tokenize *text* and keep each token's character offsets.

    Uses the same word/punctuation pattern as ``split_words``: a run of
    Unicode word characters, or one non-word, non-space character.

    Args:
        text: The sentence to tokenize.

    Returns:
        A list of ``(token, start, end)`` tuples, where ``text[start:end]``
        equals ``token``. Non-string input yields an empty list.
    """
    if isinstance(text, str):
        matches = re.finditer(r"\w+|[^\w\s]", text, flags=re.UNICODE)
        return [(match.group(0), *match.span()) for match in matches]
    return []
[docs]
def _first_subtoken_labels(preds, word_ids, seq_ids, num_words):
    """Collapse per-token predicted label ids into one label id per target word.

    Only tokens of the second sequence (``sequence_id == 1``) are considered.
    Each target word takes the prediction of its first such sub-token. Words
    that have no sub-token in the encoding (e.g. cut off by truncation) stay
    ``None``.
    """
    word_label_ids = [None] * num_words
    for pred, word_idx, seq_id in zip(preds, word_ids, seq_ids):
        if seq_id != 1 or word_idx is None or not 0 <= word_idx < num_words:
            continue
        # Keep only the first sub-token's prediction for each word.
        if word_label_ids[word_idx] is None:
            word_label_ids[word_idx] = int(pred)
    return word_label_ids


def error_detection(source_sentence, target_sentence):
    """Tag each word of *target_sentence* with a predicted error label.

    The source and target sentences are split into words and encoded together
    as a pre-tokenized sentence pair (truncated/padded to 128 tokens). They are
    then run through the ``"error_detection"`` model returned by ``load_model``.
    Each target word receives the label predicted for its first sub-token.
    Words with no sub-token in the encoding, or whose predicted id falls outside
    the known label list, are tagged ``"O"``.

    Args:
        source_sentence: Source-side text. Non-string input is split into a
            single empty word by ``split_words``.
        target_sentence: Target-side text to annotate. Non-string input
            produces no tokens.

    Returns:
        The result of ``format_tagged_sentence_using_spans`` applied to the
        target sentence, its token spans, and the per-word labels.
    """
    label_list = ["O", "MI_ST", "MI_SE", "MA_ST", "MA_SE"]

    # NOTE(review): the model is loaded and moved to the device on every call;
    # callers tagging many sentences may want to cache it.
    model, tokenizer = load_model("error_detection")
    device = get_device()
    model = move_model_to_device(model, device)
    model.eval()

    src_words = split_words(source_sentence)
    tgt_token_spans = tokenize_with_spans(target_sentence)
    tgt_words = [token for token, _start, _end in tgt_token_spans]

    # Encode (source, target) as a pre-split pair so that sequence_ids() can
    # later distinguish target sub-tokens (id 1) from source ones (id 0).
    enc = tokenizer(
        [src_words],
        [tgt_words],
        is_split_into_words=True,
        truncation=True,
        padding="max_length",
        max_length=128,
        return_tensors="pt",
    )
    # NOTE(review): only input_ids and attention_mask are passed to the model.
    # If the tokenizer also returns token_type_ids and the model was trained
    # with them, they should be passed too; confirm against training code.
    input_ids = enc["input_ids"].to(device)
    attention_mask = enc["attention_mask"].to(device)
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    # Predicted label id for every token of the single batch item.
    preds = torch.argmax(logits, dim=-1)[0].cpu().tolist()

    word_label_ids = _first_subtoken_labels(
        preds,
        enc.word_ids(batch_index=0),
        enc.sequence_ids(batch_index=0),
        len(tgt_words),
    )
    final_labels = [
        label_list[lid] if lid is not None and 0 <= lid < len(label_list) else "O"
        for lid in word_label_ids
    ]
    return format_tagged_sentence_using_spans(target_sentence, tgt_token_spans, final_labels)