Source code for wimarka.tasks.error_detection

import torch
from wimarka.utils.model import load_model
from wimarka.utils.torch import get_device, move_model_to_device
import re

def split_words(text: str):
    """Split *text* into word and punctuation tokens.

    Each maximal run of word characters becomes one token, and every other
    non-whitespace character becomes a token of its own; whitespace is
    dropped.

    Returns ``[""]`` (a list holding one empty string) when *text* is not a
    ``str``.
    """
    if isinstance(text, str):
        return re.findall(r"\w+|[^\w\s]", text, flags=re.UNICODE)
    return [""]
def tokenize_with_spans(text: str):
    """Tokenize *text* like ``split_words`` but keep character offsets.

    Returns a list of ``(token, start, end)`` tuples, where
    ``text[start:end] == token``, in left-to-right order. Non-``str`` input
    yields an empty list.
    """
    if isinstance(text, str):
        token_re = re.compile(r"\w+|[^\w\s]", flags=re.UNICODE)
        return [(match.group(0), match.start(), match.end()) for match in token_re.finditer(text)]
    return []
def format_tagged_sentence_using_spans(original_text: str, token_spans, labels):
    """Wrap labelled token runs of *original_text* in ``[LABEL]...[/LABEL]`` tags.

    Args:
        original_text: The text that the span offsets index into.
        token_spans: ``(token, start, end)`` tuples, assumed to be in
            left-to-right, non-overlapping order (as produced by
            ``tokenize_with_spans``).
        labels: One label per token; ``"O"`` means "not tagged".
            Consecutive tokens sharing the same non-``"O"`` label are merged
            into a single tagged run, which also covers any text (e.g.
            whitespace) lying between those tokens.

    Only the first ``min(len(token_spans), len(labels))`` tokens are used.

    Returns:
        The tagged text, or *original_text* unchanged when no token is tagged.
    """
    count = min(len(token_spans), len(labels))

    # Collect (start, end, label) for each maximal run of identical labels.
    runs = []
    i = 0
    while i < count:
        label = labels[i]
        if label == "O":
            i += 1
            continue
        j = i
        while j + 1 < count and labels[j + 1] == label:
            j += 1
        runs.append((token_spans[i][1], token_spans[j][2], label))
        i = j + 1

    if not runs:
        return original_text

    # Build the output left to right so that a run's closing tag always comes
    # before the next run's opening tag, even when the two runs touch (e.g.
    # "word," where "word" and "," carry different labels). Joining pieces
    # once also avoids re-slicing the whole string for every inserted tag.
    pieces = []
    cursor = 0
    for start, end, label in runs:
        pieces.append(original_text[cursor:start])
        pieces.append(f"[{label}]")
        pieces.append(original_text[start:end])
        pieces.append(f"[/{label}]")
        cursor = end
    pieces.append(original_text[cursor:])
    return "".join(pieces)
def error_detection(source_sentence, target_sentence):
    """Tag error spans in *target_sentence* given *source_sentence*.

    Loads the ``"error_detection"`` model through ``load_model`` and encodes
    the (source words, target words) pair as pre-split words (truncated and
    padded to 128 tokens). Each target word takes the argmax label of its
    first sub-token. Target words that receive no prediction (for example
    because truncation cut them off) are labelled ``"O"``.

    Returns:
        *target_sentence* with labelled runs wrapped in
        ``[LABEL]...[/LABEL]`` tags by ``format_tagged_sentence_using_spans``.
    """
    label_list = ["O", "MI_ST", "MI_SE", "MA_ST", "MA_SE"]

    # NOTE(review): the model is loaded and moved to the device on every
    # call; consider caching it if this function is called in a loop.
    model, tokenizer = load_model("error_detection")
    device = get_device()
    model = move_model_to_device(model, device)
    model.eval()

    src_words = split_words(source_sentence)
    tgt_token_spans = tokenize_with_spans(target_sentence)
    tgt_words = [token for token, _, _ in tgt_token_spans]

    enc = tokenizer(
        [src_words],
        [tgt_words],
        is_split_into_words=True,
        truncation=True,
        padding="max_length",
        max_length=128,
        return_tensors="pt",
    )
    input_ids = enc["input_ids"].to(device)
    attention_mask = enc["attention_mask"].to(device)

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    preds = torch.argmax(logits, dim=-1).cpu().numpy()[0]

    word_ids = enc.word_ids(batch_index=0)
    seq_ids = enc.sequence_ids(batch_index=0)

    # Keep only tokens from the second input of the pair (the target words,
    # sequence id 1); special tokens have a word id of None. The first
    # sub-token seen for a word decides that word's label.
    num_words = len(tgt_words)
    word_label_ids = [None] * num_words
    for token_idx, (word_idx, seq_id) in enumerate(zip(word_ids, seq_ids)):
        if seq_id != 1 or word_idx is None or not 0 <= word_idx < num_words:
            continue
        if word_label_ids[word_idx] is None:
            word_label_ids[word_idx] = int(preds[token_idx])

    # Unlabelled words, or ids outside the known label set, fall back to "O".
    final_labels = [
        label_list[lid] if lid is not None and 0 <= lid < len(label_list) else "O"
        for lid in word_label_ids
    ]
    return format_tagged_sentence_using_spans(target_sentence, tgt_token_spans, final_labels)