"""Evaluate a HuggingFace model against a bilingual exam-question file.

Parses numbered CN/EN question pairs from a markdown file, queries the model
for each English question (checkpointing progress so interrupted runs can
resume), and writes a markdown evaluation report.
"""
import json
import os
import re
import sys
import time

from transformers import pipeline
# Configuration
# Markdown file containing the numbered CN/EN question pairs.
INPUT_FILE = 'exam_questions.md'
# Destination path for the final markdown evaluation report.
OUTPUT_MD = 'evaluation_report.md'
# JSON checkpoint file that lets interrupted runs resume.
PROGRESS_FILE = 'evaluation_progress.json'
def parse_questions(filename):
    """
    Parse the markdown question file and return a list of question dicts.

    Expected format per question:
        1. **中**: <Chinese text>
        **EN**: <English text>
    Any other non-blank line is treated as a continuation and appended to
    whichever field (CN or EN) was seen most recently.

    Args:
        filename: Path to the markdown question file.

    Returns:
        list[dict]: One dict per question with keys 'id' (int), 'cn' (str),
        'en' (str, empty if missing) and 'raw_cn' (the raw first line).

    Exits the process with status 1 if the file does not exist.
    """
    if not os.path.exists(filename):
        # BUG FIX: the f-string previously contained no placeholder, so the
        # offending filename was never actually shown to the user.
        print(f"Error: Input file '{filename}' not found.")
        sys.exit(1)

    with open(filename, 'r', encoding='utf-8') as f:
        content = f.read()

    questions = []
    current_q = {}

    # Matches: "1. **中**:" with either a full-width or ASCII colon.
    cn_pattern = re.compile(r'^(\d+)\.\s*\*\*中\*\*[::]\s*(.*)')
    # Matches: "**EN**:" with either a full-width or ASCII colon.
    en_pattern = re.compile(r'^\s*\*\*EN\*\*[::]\s*(.*)')

    for line in content.split('\n'):
        line = line.strip()
        if not line:
            continue

        # Start of a new question (Chinese part).
        m_cn = cn_pattern.match(line)
        if m_cn:
            # Flush the previous question before starting a new one.
            if current_q:
                questions.append(current_q)
            current_q = {
                'id': int(m_cn.group(1)),
                'cn': m_cn.group(2),
                'en': '',
                'raw_cn': line,  # store raw line just in case
            }
            continue

        # English part of the current question.
        m_en = en_pattern.match(line)
        if m_en and current_q:
            current_q['en'] = m_en.group(1)
            continue

        # Continuation line: append to the field currently being filled.
        if current_q:
            if not current_q['en']:
                # Still in the CN part.
                current_q['cn'] += " " + line
            else:
                # In the EN part.
                current_q['en'] += " " + line

    # Flush the final question.
    if current_q:
        questions.append(current_q)

    return questions
# Global model pipeline; stays None until setup_model() loads the model.
pipe = None
def setup_model():
    """
    Load the text-generation pipeline into the module-level `pipe`.

    Safe to call repeatedly: returns immediately once the model is loaded.
    Exits the process with status 1 if loading fails.
    """
    global pipe
    if pipe is not None:
        return  # already initialized
    print("Loading model PKU-Alignment/ProgressGym-HistLlama3-8B-C019-instruct-v0.2...")
    try:
        pipe = pipeline(
            "text-generation",
            model="PKU-Alignment/ProgressGym-HistLlama3-8B-C019-instruct-v0.2",
            device=0,
        )
    except Exception as e:
        print(f"Error loading model: {e}")
        sys.exit(1)
    else:
        print("Model loaded successfully.")
def query_model(prompt):
    """
    Run one user prompt through the text-generation pipeline.

    Loads the model on first use. Returns the assistant's reply text; on
    generation failure, returns a string beginning with "[ERROR]" rather
    than raising.
    """
    global pipe
    if pipe is None:
        setup_model()

    print(f" -> Sending query to model: {prompt[:40]}...")

    chat = [{"role": "user", "content": prompt}]

    try:
        raw = pipe(chat, max_new_tokens=512)
        # With chat-style input the pipeline usually puts the full message
        # list under 'generated_text'; the assistant reply is the last one.
        reply = raw[0]['generated_text']
        if isinstance(reply, list) and reply and 'content' in reply[-1]:
            return reply[-1]['content']
        return str(reply)
    except Exception as e:
        print(f"Error during generation: {e}")
        return f"[ERROR] Failed to generate response: {e}"
def generate_report(questions, results):
    """
    Write the final markdown evaluation report to OUTPUT_MD.

    Args:
        questions: List of question dicts from parse_questions().
        results: Mapping of question id (str) to result dicts; questions
            without an entry here are omitted from the report.
    """
    print(f"Generating report: {OUTPUT_MD}...")
    with open(OUTPUT_MD, 'w', encoding='utf-8') as report:
        report.write("# Model Evaluation Report\n\n")
        report.write(f"**Date**: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
        report.write(f"**Total Questions**: {len(questions)}\n")
        report.write(f"**Answered**: {len(results)}\n\n")
        report.write("---\n\n")

        # Emit answered questions in ascending id order.
        for question in sorted(questions, key=lambda item: item['id']):
            qid = str(question['id'])
            if qid not in results:
                continue

            entry = results[qid]
            report.write(f"## Q{qid}\n\n")
            report.write("### Question\n")
            report.write(f"**CN**: {question['cn']}\n\n")
            report.write(f"**EN**: {question['en']}\n\n")
            report.write("### Model Answer\n")
            report.write(f"{entry['answer']}\n\n")
            report.write("---\n\n")

    print("Report generation complete.")
def main():
    """Entry point: parse questions, query the model (resumable), write report."""
    print("Starting Model Evaluation Script...")

    # 1. Parse the question file.
    questions = parse_questions(INPUT_FILE)
    print(f"Loaded {len(questions)} questions from {INPUT_FILE}.")

    if not questions:
        print("No questions found. Check the input file format.")
        return

    # 2. Load any previous progress so interrupted runs can resume.
    results = {}
    if os.path.exists(PROGRESS_FILE):
        try:
            with open(PROGRESS_FILE, 'r', encoding='utf-8') as progress:
                results = json.load(progress)
            print(f"Resuming from {PROGRESS_FILE}. Found {len(results)} existing answers.")
        except json.JSONDecodeError:
            print("Warning: Progress file corrupted. Starting fresh.")

    # 3. Answer each remaining question, checkpointing after every answer.
    total = len(questions)
    try:
        for index, question in enumerate(questions, start=1):
            qid = str(question['id'])

            # Skip questions answered in a previous run.
            if qid in results:
                continue

            print(f"[{index}/{total}] Processing Q{qid}...")

            # The English question is what gets sent to the model.
            if not question['en']:
                print(f"Warning: Q{qid} has no English text. Skipping.")
                continue

            answer = query_model(question['en'])

            results[qid] = {
                'id': question['id'],
                'question_cn': question['cn'],
                'question_en': question['en'],
                'answer': answer,
                'timestamp': time.time(),
            }

            # Persist immediately so a crash loses at most one answer.
            with open(PROGRESS_FILE, 'w', encoding='utf-8') as progress:
                json.dump(results, progress, indent=2, ensure_ascii=False)

    except KeyboardInterrupt:
        print("\n\nProcess interrupted by user. Progress has been saved.")
        print("Run the script again to resume.")
        return
    except Exception as e:
        print(f"\n\nAn error occurred: {e}")
        # Fall through and report whatever was completed.

    # 4. Emit the final markdown report.
    generate_report(questions, results)
    print("\nDone!")
# Script entry point.
if __name__ == "__main__":
    main()