ProgressGym-eval/evaluate_model.py
2025-12-19 10:46:28 +08:00

222 lines
6.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import re
import json
import time
import os
import sys
from transformers import pipeline
# Configuration
INPUT_FILE = 'exam_questions.md'            # bilingual (CN/EN) question list, markdown format
OUTPUT_MD = 'evaluation_report.md'          # final human-readable report written by generate_report()
PROGRESS_FILE = 'evaluation_progress.json'  # per-question answers saved after each query; enables resume
def parse_questions(filename):
    """
    Parse the bilingual exam markdown file into a list of question dicts.

    Expected format per question:
        1. **中**: <Chinese text>
        **EN**: <English text>
    The trailing colon after **中**/**EN** is optional, and either part may
    continue over multiple lines; continuation lines are appended with a
    single space to whichever part is currently open.

    Args:
        filename: Path to the markdown question file.

    Returns:
        list[dict]: One dict per question with keys
            'id' (int), 'cn' (str), 'en' (str), 'raw_cn' (str, the raw
            stripped first line, kept for debugging).

    Exits the process with status 1 if the file does not exist.
    """
    if not os.path.exists(filename):
        # Bug fix: previously printed a hard-coded placeholder instead of
        # the actual path, making the error useless for diagnosis.
        print(f"Error: Input file '{filename}' not found.")
        sys.exit(1)

    with open(filename, 'r', encoding='utf-8') as f:
        content = f.read()

    questions = []
    current_q = {}

    # Matches "1. **中**" or "1. **中**:" — colon optional, as documented.
    # (Bug fix: the old pattern `[:]` *required* the colon, contradicting
    # the comment; questions without one fell through to the
    # continuation-line branch.)
    cn_pattern = re.compile(r'^(\d+)\.\s*\*\*中\*\*:?\s*(.*)')
    # Matches "**EN**" or "**EN**:".
    en_pattern = re.compile(r'^\s*\*\*EN\*\*:?\s*(.*)')

    for line in content.split('\n'):
        line = line.strip()
        if not line:
            continue

        # Start of a new Chinese question?
        m_cn = cn_pattern.match(line)
        if m_cn:
            if current_q:  # flush the previous question before starting anew
                questions.append(current_q)
            current_q = {
                'id': int(m_cn.group(1)),
                'cn': m_cn.group(2),
                'en': '',
                'raw_cn': line,  # store raw line just in case
            }
            continue

        # English counterpart of the current question?
        m_en = en_pattern.match(line)
        if m_en and current_q:
            current_q['en'] = m_en.group(1)
            continue

        # Otherwise it's a continuation line: append to the CN part until
        # the EN part has started, then append to EN.
        if current_q:
            if not current_q['en']:
                current_q['cn'] += " " + line
            else:
                current_q['en'] += " " + line

    # Flush the final question.
    if current_q:
        questions.append(current_q)
    return questions
# Global model pipeline, lazily initialized by setup_model().
pipe = None


def setup_model():
    """Load the HF text-generation pipeline once; exit the process on failure."""
    global pipe
    if pipe is not None:
        # Already loaded — nothing to do.
        return
    print("Loading model PKU-Alignment/ProgressGym-HistLlama3-8B-C019-instruct-v0.2...")
    try:
        pipe = pipeline("text-generation", model="PKU-Alignment/ProgressGym-HistLlama3-8B-C019-instruct-v0.2", device=0)
    except Exception as exc:
        print(f"Error loading model: {exc}")
        sys.exit(1)
    print("Model loaded successfully.")
def query_model(prompt):
    """
    Send *prompt* to the loaded model pipeline and return the reply text.

    Lazily loads the model on first use. On any generation failure, returns
    an "[ERROR] ..." string rather than raising.
    """
    global pipe
    if pipe is None:
        setup_model()
    print(f" -> Sending query to model: {prompt[:40]}...")
    chat = [{"role": "user", "content": prompt}]
    try:
        generation = pipe(chat, max_new_tokens=512)
        # With chat-style input the pipeline returns the conversation as a
        # list of messages; the assistant's reply is the final entry.
        reply = generation[0]['generated_text']
        if isinstance(reply, list) and len(reply) > 0 and 'content' in reply[-1]:
            return reply[-1]['content']
        # Fallback: stringify whatever shape the pipeline produced.
        return str(reply)
    except Exception as exc:
        print(f"Error during generation: {exc}")
        return f"[ERROR] Failed to generate response: {exc}"
def generate_report(questions, results):
    """
    Write the final Markdown report to OUTPUT_MD.

    Args:
        questions: List of question dicts from parse_questions().
        results: Mapping of str(question id) -> result dict with an
            'answer' key; questions absent from it are skipped.
    """
    print(f"Generating report: {OUTPUT_MD}...")
    # Emit questions in ID order regardless of input ordering.
    ordered = sorted(questions, key=lambda item: item['id'])
    with open(OUTPUT_MD, 'w', encoding='utf-8') as report:
        report.write("# Model Evaluation Report\n\n")
        report.write(f"**Date**: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
        report.write(f"**Total Questions**: {len(questions)}\n")
        report.write(f"**Answered**: {len(results)}\n\n")
        report.write("---\n\n")
        for entry in ordered:
            key = str(entry['id'])
            if key not in results:
                continue  # never answered (e.g. run interrupted)
            record = results[key]
            report.write(f"## Q{key}\n\n")
            report.write("### Question\n")
            report.write(f"**CN**: {entry['cn']}\n\n")
            report.write(f"**EN**: {entry['en']}\n\n")
            report.write("### Model Answer\n")
            report.write(f"{record['answer']}\n\n")
            report.write("---\n\n")
    print("Report generation complete.")
def _save_progress(results):
    """Persist all answers to PROGRESS_FILE so a crash loses no work."""
    with open(PROGRESS_FILE, 'w', encoding='utf-8') as fh:
        json.dump(results, fh, indent=2, ensure_ascii=False)


def main():
    """Run the full evaluation: parse questions, resume, query, report."""
    print("Starting Model Evaluation Script...")

    # 1. Parse the question file.
    questions = parse_questions(INPUT_FILE)
    print(f"Loaded {len(questions)} questions from {INPUT_FILE}.")
    if not questions:
        print("No questions found. Check the input file format.")
        return

    # 2. Load any previously saved answers (resume capability).
    results = {}
    if os.path.exists(PROGRESS_FILE):
        try:
            with open(PROGRESS_FILE, 'r', encoding='utf-8') as fh:
                results = json.load(fh)
            print(f"Resuming from {PROGRESS_FILE}. Found {len(results)} existing answers.")
        except json.JSONDecodeError:
            print("Warning: Progress file corrupted. Starting fresh.")

    # 3. Query the model for every unanswered question.
    total = len(questions)
    try:
        for idx, question in enumerate(questions):
            qid = str(question['id'])
            if qid in results:
                continue  # already answered in a previous run
            print(f"[{idx+1}/{total}] Processing Q{qid}...")
            # The English version is what gets sent to the model.
            if not question['en']:
                print(f"Warning: Q{qid} has no English text. Skipping.")
                continue
            answer = query_model(question['en'])
            results[qid] = {
                'id': question['id'],
                'question_cn': question['cn'],
                'question_en': question['en'],
                'answer': answer,
                'timestamp': time.time(),
            }
            # Flush to disk immediately so an interruption loses nothing.
            _save_progress(results)
    except KeyboardInterrupt:
        print("\n\nProcess interrupted by user. Progress has been saved.")
        print("Run the script again to resume.")
        return
    except Exception as exc:
        print(f"\n\nAn error occurred: {exc}")
        # Fall through: still generate a report from what we have.

    # 4. Produce the final Markdown report.
    generate_report(questions, results)
    print("\nDone!")


if __name__ == "__main__":
    main()