# Source code for wimarka.utils.model

from llama_cpp import Llama
from transformers import AutoTokenizer, DistilBertForTokenClassification
from joblib import load
from huggingface_hub import hf_hub_download
import os

# Environment flags meant to quiet Hugging Face Hub / llama.cpp / ggml console
# output. NOTE(review): names are kept exactly as found; whether each library
# actually honors its variable is not verified here.
os.environ.update(
    {
        "HF_HUB_DISABLE_SYMLINKS_WARNING": "1",
        "LLAMA_LOG_LEVEL": "OFF",
        "LLAMA_QUIET": "1",
        "GGML_LOG_LEVEL": "ERROR",
    }
)

def load_model(model_name: str):
    """Load one of the WiMarka models by its short name.

    Args:
        model_name: One of ``"regression"``, ``"explanation"``,
            ``"correction"`` or ``"error_detection"``.

    Returns:
        Depends on ``model_name``:

        * ``"regression"`` -- the object deserialized by ``joblib.load`` from
          ``regression.joblib`` in the ``WiMarka/Random_Forest_Regression``
          Hub repository.
        * ``"explanation"`` -- a ``Llama`` instance built from
          ``gemma_explanation_3.gguf`` with a 768-token context.
        * ``"correction"`` -- a ``Llama`` instance built from
          ``base_gemma_it_quantized.gguf`` with a 1024-token context.
        * ``"error_detection"`` -- a ``(model, tokenizer)`` tuple loaded from
          ``WiMarka/DistilBERT_Error_Detection``.

    Raises:
        ValueError: If ``model_name`` is not one of the names listed above.
    """
    if model_name == "regression":
        # hf_hub_download returns the local path of the downloaded file.
        local_path = hf_hub_download(
            repo_id="WiMarka/Random_Forest_Regression",
            filename="regression.joblib",
        )
        # NOTE(security): joblib.load unpickles the file, which can execute
        # arbitrary code; this is only safe while the Hub repo is trusted.
        return load(local_path)
    elif model_name == "explanation":
        return Llama.from_pretrained(
            repo_id="WiMarka/Gemma_3_12B_IT_INT8_Explanation_Generation",
            filename="gemma_explanation_3.gguf",
            chat_format="chatml",
            n_ctx=768,
            verbose=False,
        )
    elif model_name == "correction":
        return Llama.from_pretrained(
            repo_id="WiMarka/Gemma_3_12B_IT_INT8",
            filename="base_gemma_it_quantized.gguf",
            chat_format="chatml",
            n_ctx=1024,
            verbose=False,
        )
    elif model_name == "error_detection":
        repo_id = "WiMarka/DistilBERT_Error_Detection"
        tokenizer = AutoTokenizer.from_pretrained(repo_id)
        model = DistilBertForTokenClassification.from_pretrained(repo_id)
        # Model first, tokenizer second: callers unpack in this order.
        return model, tokenizer
    else:
        raise ValueError(
            f"Model name {model_name} is not recognized."
        )