# File listing metadata: 137 lines, 6.3 KiB, Python
import json
|
|
import sys
|
|
import tempfile
|
|
import unittest
|
|
from pathlib import Path
|
|
from unittest.mock import patch
|
|
|
|
# Directory one level above the directory that contains this file.
ROOT = Path(__file__).resolve().parents[1]
# Source directory that is put on the import path below.
SRC = ROOT / "src"
# Prepend SRC to sys.path only when it is not already listed, so repeated
# imports of this module do not add duplicate entries.
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))
|
|
|
|
import model_eval
|
|
|
|
|
|
class _FakeProcessor:
    """Test double that is patched over ``model_eval.LlamaProcessor``.

    Every method accepts arbitrary arguments and ignores them; no model is
    loaded or run.
    """

    def __init__(self, *args, **kwargs):
        """Accept and discard any construction arguments."""
        del args, kwargs

    def warmup(self, **kwargs):
        """Accept and discard any keyword arguments; always return ``None``."""
        del kwargs
        return None

    def process(self, text, **kwargs):
        """Return ``text`` with surrounding whitespace stripped.

        Extra keyword arguments are accepted and ignored.
        """
        del kwargs
        return text.strip()
|
|
|
|
|
|
class ModelEvalTests(unittest.TestCase):
    """Tests for the ``model_eval`` dataset loaders, builder, and eval runner.

    Each test writes its fixtures into a fresh temporary directory, so nothing
    is left on disk after the test finishes.
    """

    def test_load_eval_dataset_validates_required_fields(self):
        """A well-formed single-row JSONL file loads as exactly one case."""
        with tempfile.TemporaryDirectory() as td:
            dataset = Path(td) / "dataset.jsonl"
            dataset.write_text(
                '{"id":"c1","input_text":"hello","expected_output":"hello"}\n',
                encoding="utf-8",
            )
            cases = model_eval.load_eval_dataset(dataset)
            self.assertEqual(len(cases), 1)
            self.assertEqual(cases[0].case_id, "c1")

    def test_run_model_eval_produces_report(self):
        """``run_model_eval`` yields a version-2 report covering both models.

        ``model_eval.LlamaProcessor`` is replaced with ``_FakeProcessor`` for
        the duration of the run, and the model file is a placeholder text file.
        """
        with tempfile.TemporaryDirectory() as td:
            # Placeholder file so the matrix can point at an existing path.
            model_file = Path(td) / "fake.gguf"
            model_file.write_text("fake", encoding="utf-8")
            dataset = Path(td) / "dataset.jsonl"
            dataset.write_text(
                (
                    '{"id":"c1","input_text":"hello world","expected_output":"hello world"}\n'
                    '{"id":"c2","input_text":"hello","expected_output":"hello"}\n'
                ),
                encoding="utf-8",
            )
            matrix = Path(td) / "matrix.json"
            heuristic_dataset = Path(td) / "heuristics.jsonl"
            # One baseline plus one candidate model; the report is expected
            # to contain an entry for each of them.
            matrix.write_text(
                json.dumps(
                    {
                        "warmup_runs": 0,
                        "measured_runs": 1,
                        "timeout_sec": 30,
                        "baseline_model": {
                            "name": "base",
                            "provider": "local_llama",
                            "model_path": str(model_file),
                            "profile": "default",
                            "param_grid": {"temperature": [0.0]},
                        },
                        "candidate_models": [
                            {
                                "name": "small",
                                "provider": "local_llama",
                                "model_path": str(model_file),
                                "profile": "fast",
                                "param_grid": {
                                    "temperature": [0.0, 0.1],
                                    "max_tokens": [96],
                                },
                            }
                        ],
                    }
                ),
                encoding="utf-8",
            )
            # Two heuristic rows: h1 requires the "cue_correction" rule,
            # h2 forbids it.
            heuristic_dataset.write_text(
                (
                    '{"id":"h1","transcript":"set alarm for 6 i mean 7","words":[{"text":"set","start_s":0.0,"end_s":0.1,"prob":0.9},{"text":"alarm","start_s":0.2,"end_s":0.3,"prob":0.9},{"text":"for","start_s":0.4,"end_s":0.5,"prob":0.9},{"text":"6","start_s":0.6,"end_s":0.7,"prob":0.9},{"text":"i","start_s":0.8,"end_s":0.9,"prob":0.9},{"text":"mean","start_s":1.0,"end_s":1.1,"prob":0.9},{"text":"7","start_s":1.2,"end_s":1.3,"prob":0.9}],"expected_aligned_text":"set alarm for 7","expected":{"applied_min":1,"required_rule_ids":["cue_correction"]},"tags":["i_mean_correction"]}\n'
                    '{"id":"h2","transcript":"write exactly i mean this sincerely","words":[{"text":"write","start_s":0.0,"end_s":0.1,"prob":0.9},{"text":"exactly","start_s":0.2,"end_s":0.3,"prob":0.9},{"text":"i","start_s":0.4,"end_s":0.5,"prob":0.9},{"text":"mean","start_s":0.6,"end_s":0.7,"prob":0.9},{"text":"this","start_s":0.8,"end_s":0.9,"prob":0.9},{"text":"sincerely","start_s":1.0,"end_s":1.1,"prob":0.9}],"expected_aligned_text":"write exactly i mean this sincerely","expected":{"required_rule_ids":[],"forbidden_rule_ids":["cue_correction"]},"tags":["i_mean_literal"]}\n'
                ),
                encoding="utf-8",
            )
            with patch("model_eval.LlamaProcessor", _FakeProcessor):
                report = model_eval.run_model_eval(
                    dataset,
                    matrix,
                    heuristic_dataset_path=heuristic_dataset,
                    heuristic_weight=0.3,
                    report_version=2,
                    verbose=False,
                )

            self.assertEqual(report["report_version"], 2)
            self.assertIn("models", report)
            self.assertEqual(len(report["models"]), 2)
            self.assertIn("winner_recommendation", report)
            self.assertIn("heuristic_eval", report)
            self.assertEqual(report["heuristic_eval"]["cases"], 2)
            self.assertIn("combined_score", report["models"][0]["best_param_set"])
            summary = model_eval.format_model_eval_summary(report)
            self.assertIn("model eval summary", summary)

    def test_load_heuristic_dataset_validates_required_fields(self):
        """A row without an ``expected`` object loads with ``applied_min`` 0."""
        with tempfile.TemporaryDirectory() as td:
            dataset = Path(td) / "heuristics.jsonl"
            dataset.write_text(
                '{"id":"h1","transcript":"hello world","words":[{"text":"hello","start_s":0.0,"end_s":0.1},{"text":"world","start_s":0.2,"end_s":0.3}],"expected_aligned_text":"hello world"}\n',
                encoding="utf-8",
            )
            cases = model_eval.load_heuristic_dataset(dataset)
            self.assertEqual(len(cases), 1)
            self.assertEqual(cases[0].case_id, "h1")
            self.assertEqual(cases[0].expected.applied_min, 0)

    def test_build_heuristic_dataset_generates_words_when_missing(self):
        """A source row lacking ``words`` is written out with generated words."""
        with tempfile.TemporaryDirectory() as td:
            source = Path(td) / "heuristics.raw.jsonl"
            output = Path(td) / "heuristics.jsonl"
            source.write_text(
                '{"id":"h1","transcript":"please send it","expected_aligned_text":"please send it","expected":{"applied_min":0}}\n',
                encoding="utf-8",
            )
            summary = model_eval.build_heuristic_dataset(source, output)
            self.assertEqual(summary["written_rows"], 1)
            self.assertEqual(summary["generated_word_rows"], 1)
            # The built file must be readable by the regular loader.
            loaded = model_eval.load_heuristic_dataset(output)
            self.assertEqual(len(loaded), 1)
            self.assertGreaterEqual(len(loaded[0].words), 1)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's tests only when it is executed directly as a script.
    unittest.main()
|