import os

# These switches have to be exported BEFORE the third-party imports below.
# huggingface_hub (loaded as a dependency of transformers) evaluates its
# HF_HUB_* environment variables when it is first imported, so assigning
# HF_HUB_DISABLE_SYMLINKS_WARNING after those imports comes too late.
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
# Logging switches intended to quiet the llama.cpp / ggml backend.
# NOTE(review): which component actually reads these three variables is not
# visible from this file -- confirm they are honoured; setting them early is
# harmless either way.
os.environ["LLAMA_LOG_LEVEL"] = "OFF"
os.environ["LLAMA_QUIET"] = "1"
os.environ["GGML_LOG_LEVEL"] = "ERROR"

# Imports deliberately follow the environment setup (hence the E402 waivers).
from llama_cpp import Llama  # noqa: E402
from transformers import AutoTokenizer, DistilBertForTokenClassification  # noqa: E402
from joblib import load  # noqa: E402
from huggingface_hub import hf_hub_download  # noqa: E402
# [docs]  (Sphinx source-link label left over from HTML extraction; not code)
def load_model(model_name: str):
    """Load one of the named pretrained models from the Hugging Face Hub.

    Supported values of ``model_name``:

    * ``"regression"`` -- downloads ``regression.joblib`` from the
      ``WiMarka/Random_Forest_Regression`` repo and returns whatever object
      ``joblib.load`` deserializes from it.
    * ``"explanation"`` -- returns a ``Llama`` instance created from
      ``gemma_explanation_3.gguf`` with ``n_ctx=768``.
    * ``"correction"`` -- returns a ``Llama`` instance created from
      ``base_gemma_it_quantized.gguf`` with ``n_ctx=1024``.
    * ``"error_detection"`` -- returns a ``(model, tokenizer)`` tuple loaded
      from ``WiMarka/DistilBERT_Error_Detection``.

    Note that the return shape differs per model: three names yield a single
    object, while ``"error_detection"`` yields a 2-tuple.

    Args:
        model_name: Identifier of the model to load (see the list above).

    Returns:
        The loaded model object, or ``(model, tokenizer)`` for
        ``"error_detection"``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    if model_name == "regression":
        artifact_path = hf_hub_download(
            repo_id="WiMarka/Random_Forest_Regression",
            filename="regression.joblib",
        )
        # SECURITY: joblib.load is pickle-based and can execute arbitrary
        # code while deserializing; this is only acceptable as long as the
        # Hub repository above is trusted.
        return load(artifact_path)

    if model_name == "explanation":
        # NOTE(review): chat_format is "chatml" although the repo name says
        # Gemma -- presumably matches how the model was fine-tuned; confirm.
        return Llama.from_pretrained(
            repo_id="WiMarka/Gemma_3_12B_IT_INT8_Explanation_Generation",
            filename="gemma_explanation_3.gguf",
            chat_format="chatml",
            n_ctx=768,
            verbose=False,
        )

    if model_name == "correction":
        return Llama.from_pretrained(
            repo_id="WiMarka/Gemma_3_12B_IT_INT8",
            filename="base_gemma_it_quantized.gguf",
            chat_format="chatml",
            n_ctx=1024,
            verbose=False,
        )

    if model_name == "error_detection":
        repo_id = "WiMarka/DistilBERT_Error_Detection"
        tokenizer = AutoTokenizer.from_pretrained(repo_id)
        model = DistilBertForTokenClassification.from_pretrained(repo_id)
        return model, tokenizer

    raise ValueError(
        f"Model name {model_name} is not recognized."
    )