first commit
This commit is contained in:
commit
f9125370af
223
.gitignore
vendored
Normal file
223
.gitignore
vendored
Normal file
@ -0,0 +1,223 @@
|
|||||||
|
# Python
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.so
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
pip-wheel-metadata/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
Pipfile.lock
|
||||||
|
|
||||||
|
# PEP 582
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# Machine Learning / AI specific
|
||||||
|
# Model files (you may want to track some models)
|
||||||
|
*.pkl
|
||||||
|
*.joblib
|
||||||
|
*.h5
|
||||||
|
*.hdf5
|
||||||
|
*.pb
|
||||||
|
*.onnx
|
||||||
|
# Uncomment the next line if you don't want to track your trained models
|
||||||
|
# *.pt
|
||||||
|
|
||||||
|
# Dataset and data files
|
||||||
|
data/
|
||||||
|
datasets/
|
||||||
|
*.csv
|
||||||
|
*.xml
|
||||||
|
*.txt
|
||||||
|
# Image files (uncomment if you have large image datasets)
|
||||||
|
# *.jpg
|
||||||
|
# *.jpeg
|
||||||
|
# *.png
|
||||||
|
# *.gif
|
||||||
|
# *.bmp
|
||||||
|
# *.tiff
|
||||||
|
|
||||||
|
# Logs and outputs
|
||||||
|
logs/
|
||||||
|
outputs/
|
||||||
|
results/
|
||||||
|
checkpoints/
|
||||||
|
runs/
|
||||||
|
wandb/
|
||||||
|
|
||||||
|
# TensorBoard
|
||||||
|
events.out.tfevents.*
|
||||||
|
|
||||||
|
# Jupyter Notebook checkpoints
|
||||||
|
.ipynb_checkpoints/
|
||||||
|
|
||||||
|
# VS Code
|
||||||
|
.vscode/
|
||||||
|
*.code-workspace
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
.idea/
|
||||||
|
*.iws
|
||||||
|
*.iml
|
||||||
|
*.ipr
|
||||||
|
|
||||||
|
# macOS
|
||||||
|
.DS_Store
|
||||||
|
.AppleDouble
|
||||||
|
.LSOverride
|
||||||
|
Icon?
|
||||||
|
._*
|
||||||
|
.DocumentRevisions-V100
|
||||||
|
.fseventsd
|
||||||
|
.Spotlight-V100
|
||||||
|
.TemporaryItems
|
||||||
|
.Trashes
|
||||||
|
.VolumeIcon.icns
|
||||||
|
.com.apple.timemachine.donotpresent
|
||||||
|
.AppleDB
|
||||||
|
.AppleDesktop
|
||||||
|
Network Trash Folder
|
||||||
|
Temporary Items
|
||||||
|
.apdisk
|
||||||
|
|
||||||
|
# Windows
|
||||||
|
Thumbs.db
|
||||||
|
Thumbs.db:encryptable
|
||||||
|
ehthumbs.db
|
||||||
|
ehthumbs_vista.db
|
||||||
|
*.tmp
|
||||||
|
*.temp
|
||||||
|
Desktop.ini
|
||||||
|
$RECYCLE.BIN/
|
||||||
|
*.cab
|
||||||
|
*.msi
|
||||||
|
*.msix
|
||||||
|
*.msm
|
||||||
|
*.msp
|
||||||
|
*.lnk
|
||||||
|
|
||||||
|
# Linux
|
||||||
|
*~
|
||||||
|
.fuse_hidden*
|
||||||
|
.directory
|
||||||
|
.Trash-*
|
||||||
|
.nfs*
|
||||||
|
|
||||||
|
# Temporary files
|
||||||
|
*.tmp
|
||||||
|
*.temp
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
|
||||||
|
# Project specific
|
||||||
|
# Add any project-specific files you want to ignore here
|
||||||
|
test_images/
|
||||||
|
temp/
|
||||||
|
cache/
|
||||||
26
circuit-recognition.service
Normal file
26
circuit-recognition.service
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
[Unit]
|
||||||
|
Description=Circuit Recognition Web Service
|
||||||
|
After=network.target
|
||||||
|
Wants=network.target
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
Type=simple
|
||||||
|
User=root
|
||||||
|
Group=root
|
||||||
|
WorkingDirectory=/home/feie9454/elements_wires_congition
|
||||||
|
Environment=PATH=/home/feie9454/elements_wires_congition/venv/bin
|
||||||
|
ExecStart=/home/feie9454/elements_wires_congition/venv/bin/python main.py
|
||||||
|
Restart=always
|
||||||
|
RestartSec=10
|
||||||
|
StandardOutput=journal
|
||||||
|
StandardError=journal
|
||||||
|
SyslogIdentifier=circuit-recognition
|
||||||
|
|
||||||
|
# 安全设置
|
||||||
|
NoNewPrivileges=true
|
||||||
|
PrivateTmp=true
|
||||||
|
ProtectSystem=strict
|
||||||
|
ReadWritePaths=/home/feie9454/elements_wires_congition
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=multi-user.target
|
||||||
80
main.py
Normal file
80
main.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
from flask import Flask, request, jsonify, send_from_directory, send_file
|
||||||
|
from flask_cors import CORS
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import os
|
||||||
|
from src.pre_func.img_prec import img_recognition
|
||||||
|
|
||||||
|
# 设置静态文件目录
|
||||||
|
app = Flask(__name__, static_folder='src/web', static_url_path='/static')
|
||||||
|
CORS(app)
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/')
|
||||||
|
def index():
|
||||||
|
"""主页路由,提供index.html"""
|
||||||
|
return send_from_directory(app.static_folder, 'index.html')
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/web')
|
||||||
|
def web_index():
|
||||||
|
"""Web应用入口"""
|
||||||
|
return send_from_directory(app.static_folder, 'index.html')
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/web/<path:filename>')
|
||||||
|
def web_static(filename):
|
||||||
|
"""提供web目录下的静态文件"""
|
||||||
|
return send_from_directory(app.static_folder, filename)
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/static/<path:filename>')
|
||||||
|
def static_files(filename):
|
||||||
|
"""提供静态文件服务"""
|
||||||
|
return send_from_directory(app.static_folder, filename)
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/api/status')
|
||||||
|
def api_status():
|
||||||
|
"""API状态检查"""
|
||||||
|
return jsonify({
|
||||||
|
'status': 'running',
|
||||||
|
'message': 'Elements Wires Recognition API is running',
|
||||||
|
'endpoints': {
|
||||||
|
'POST /process_image': 'Image recognition endpoint',
|
||||||
|
'GET /': 'Main web interface',
|
||||||
|
'GET /web': 'Web application',
|
||||||
|
'GET /api/status': 'API status check'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/process_image', methods=['POST'])
|
||||||
|
def process_image():
|
||||||
|
if 'image' not in request.files:
|
||||||
|
return jsonify({'error': 'No image part in the request'}), 400
|
||||||
|
|
||||||
|
file = request.files['image']
|
||||||
|
if file.filename == '':
|
||||||
|
return jsonify({'error': 'No selected image'}), 400
|
||||||
|
|
||||||
|
file_bytes = np.frombuffer(file.read(), np.uint8)
|
||||||
|
img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
|
||||||
|
|
||||||
|
result = img_recognition(img)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
print("🚀 Starting Elements Wires Recognition Web Server...")
|
||||||
|
print("📁 Static files served from: src/web/")
|
||||||
|
print("🌐 Web interface available at:")
|
||||||
|
print(" - http://localhost:8000/")
|
||||||
|
print(" - http://0.0.0.0:8000/")
|
||||||
|
print("🔧 API endpoints:")
|
||||||
|
print(" - POST /process_image (for image recognition)")
|
||||||
|
print(" - GET /web/<filename> (for static files)")
|
||||||
|
print(" - GET /static/<filename> (for static files)")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
app.run(host='0.0.0.0', debug=True, port=12399)
|
||||||
63
manage_service.sh
Normal file
63
manage_service.sh
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Circuit Recognition Service Management Script
|
||||||
|
# 使用方法: ./manage_service.sh [install|start|stop|restart|status|logs|uninstall]
|
||||||
|
|
||||||
|
SERVICE_NAME="circuit-recognition"
|
||||||
|
SERVICE_FILE="/home/feie9454/elements_wires_congition/circuit-recognition.service"
|
||||||
|
SYSTEM_SERVICE_PATH="/etc/systemd/system/circuit-recognition.service"
|
||||||
|
|
||||||
|
case "$1" in
|
||||||
|
install)
|
||||||
|
echo "安装服务..."
|
||||||
|
sudo cp "$SERVICE_FILE" "$SYSTEM_SERVICE_PATH"
|
||||||
|
sudo systemctl daemon-reload
|
||||||
|
sudo systemctl enable "$SERVICE_NAME"
|
||||||
|
echo "✅ 服务已安装并设置为开机自启"
|
||||||
|
echo "使用 './manage_service.sh start' 启动服务"
|
||||||
|
;;
|
||||||
|
start)
|
||||||
|
echo "启动服务..."
|
||||||
|
sudo systemctl start "$SERVICE_NAME"
|
||||||
|
echo "✅ 服务已启动"
|
||||||
|
;;
|
||||||
|
stop)
|
||||||
|
echo "停止服务..."
|
||||||
|
sudo systemctl stop "$SERVICE_NAME"
|
||||||
|
echo "✅ 服务已停止"
|
||||||
|
;;
|
||||||
|
restart)
|
||||||
|
echo "重启服务..."
|
||||||
|
sudo systemctl restart "$SERVICE_NAME"
|
||||||
|
echo "✅ 服务已重启"
|
||||||
|
;;
|
||||||
|
status)
|
||||||
|
echo "服务状态:"
|
||||||
|
sudo systemctl status "$SERVICE_NAME"
|
||||||
|
;;
|
||||||
|
logs)
|
||||||
|
echo "查看服务日志:"
|
||||||
|
sudo journalctl -u "$SERVICE_NAME" -f
|
||||||
|
;;
|
||||||
|
uninstall)
|
||||||
|
echo "卸载服务..."
|
||||||
|
sudo systemctl stop "$SERVICE_NAME"
|
||||||
|
sudo systemctl disable "$SERVICE_NAME"
|
||||||
|
sudo rm -f "$SYSTEM_SERVICE_PATH"
|
||||||
|
sudo systemctl daemon-reload
|
||||||
|
echo "✅ 服务已卸载"
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "使用方法: $0 {install|start|stop|restart|status|logs|uninstall}"
|
||||||
|
echo ""
|
||||||
|
echo "命令说明:"
|
||||||
|
echo " install - 安装服务并设置开机自启"
|
||||||
|
echo " start - 启动服务"
|
||||||
|
echo " stop - 停止服务"
|
||||||
|
echo " restart - 重启服务"
|
||||||
|
echo " status - 查看服务状态"
|
||||||
|
echo " logs - 实时查看服务日志"
|
||||||
|
echo " uninstall - 卸载服务"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
15
pyproject.toml
Normal file
15
pyproject.toml
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
[project]
|
||||||
|
name = "circuit-recognition"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Add your description here"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.12"
|
||||||
|
dependencies = [
|
||||||
|
"albumentations>=2.0.8",
|
||||||
|
"flask>=3.1.1",
|
||||||
|
"flask-cors>=6.0.1",
|
||||||
|
"numpy>=2.3.0",
|
||||||
|
"opencv-python>=4.11.0.86",
|
||||||
|
"scikit-image>=0.25.2",
|
||||||
|
"ultralytics>=8.3.156",
|
||||||
|
]
|
||||||
1
src/__init__.py
Normal file
1
src/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
# 这个文件使 src 目录成为一个 Python 包
|
||||||
BIN
src/best_model/best9.17.pt
Normal file
BIN
src/best_model/best9.17.pt
Normal file
Binary file not shown.
BIN
src/best_model/unet_wire_red.pth
Normal file
BIN
src/best_model/unet_wire_red.pth
Normal file
Binary file not shown.
1
src/pre_func/__init__.py
Normal file
1
src/pre_func/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
# 这个文件使 pre_func 目录成为一个 Python 包
|
||||||
53
src/pre_func/color.py
Normal file
53
src/pre_func/color.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from sklearn.cluster import KMeans
|
||||||
|
from tensorflow.keras.models import load_model
|
||||||
|
|
||||||
|
# =============================
|
||||||
|
# 1. 加载训练好的 U-Net 模型
|
||||||
|
# =============================
|
||||||
|
model = load_model("/home/gqw/unet_ai/elements_wires_congition/src/pre_func/unet_wire.pth", compile=False)
|
||||||
|
|
||||||
|
# 读入测试图像
|
||||||
|
img = cv2.imread("/home/gqw/unet_ai/test/IMG_0059.jpeg")
|
||||||
|
h, w, _ = img.shape
|
||||||
|
input_img = cv2.resize(img, (256, 256)) / 255.0
|
||||||
|
input_img = np.expand_dims(input_img, axis=0)
|
||||||
|
|
||||||
|
# =============================
|
||||||
|
# 2. U-Net 分割预测(得到二值 mask)
|
||||||
|
# =============================
|
||||||
|
pred = model.predict(input_img)[0, :, :, 0]
|
||||||
|
mask = (pred > 0.5).astype(np.uint8) # 阈值化
|
||||||
|
mask = cv2.resize(mask, (w, h))
|
||||||
|
|
||||||
|
# =============================
|
||||||
|
# 3. 提取导线区域的像素颜色
|
||||||
|
# =============================
|
||||||
|
wire_pixels = img[mask == 1] # N x 3 (RGB)
|
||||||
|
|
||||||
|
# KMeans 聚类(假设有3种颜色的导线)
|
||||||
|
n_colors = 3
|
||||||
|
kmeans = KMeans(n_clusters=n_colors, random_state=0).fit(wire_pixels)
|
||||||
|
labels = kmeans.labels_
|
||||||
|
|
||||||
|
# =============================
|
||||||
|
# 4. 生成每个颜色的 mask
|
||||||
|
# =============================
|
||||||
|
colored_masks = []
|
||||||
|
for i in range(n_colors):
|
||||||
|
color_mask = np.zeros(mask.shape, dtype=np.uint8)
|
||||||
|
color_mask[mask == 1] = (labels == i).astype(np.uint8)
|
||||||
|
colored_masks.append(color_mask)
|
||||||
|
|
||||||
|
# =============================
|
||||||
|
# 5. 可视化结果
|
||||||
|
# =============================
|
||||||
|
result = img.copy()
|
||||||
|
colors = [(0,0,255), (0,255,0), (255,0,0)] # 每个聚类显示的颜色
|
||||||
|
for i, cmask in enumerate(colored_masks):
|
||||||
|
result[cmask == 1] = colors[i % len(colors)]
|
||||||
|
|
||||||
|
cv2.imwrite("segmented_by_color.jpg", result)
|
||||||
|
|
||||||
|
print("输出完成:segmented_by_color.jpg")
|
||||||
5
src/pre_func/config.py
Normal file
5
src/pre_func/config.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
# 模型与路径配置
|
||||||
|
YOLO_MODEL_PATH = '/home/gqw/unet_ai/elements_wires_congition/src/best_model/best9.17.pt' # 从infer.py迁移
|
||||||
|
DETECTION_CONFIDENCE = 0.5 # 从infer.py迁移(可根据需要调整)
|
||||||
|
OUTPUT_BASE_DIR = '/home/gqw/unet_ai/elements_wires_congition/src/yolo/detect_results' # 从infer.py迁移
|
||||||
|
WIRE_CONNECT_THRESHOLD = 50 # 导线端点与连接点的最大允许距离(像素)
|
||||||
83
src/pre_func/ele_recog.py
Normal file
83
src/pre_func/ele_recog.py
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
import cv2
|
||||||
|
import os
|
||||||
|
from ultralytics import YOLO
|
||||||
|
from config import YOLO_MODEL_PATH, DETECTION_CONFIDENCE
|
||||||
|
from utils import ensure_output_dir # 导入工具函数
|
||||||
|
|
||||||
|
|
||||||
|
def elements_recognition(img):
|
||||||
|
# 加载模型
|
||||||
|
model = YOLO(YOLO_MODEL_PATH)
|
||||||
|
original = img
|
||||||
|
# 保持原缩放逻辑(与导线识别尺寸一致)
|
||||||
|
img = cv2.resize(original, (1000, int(original.shape[0] * 1000 / original.shape[1])))
|
||||||
|
# 推理预测
|
||||||
|
results = model(img, conf=DETECTION_CONFIDENCE)[0]
|
||||||
|
components = []
|
||||||
|
|
||||||
|
# 创建只含YOLO识别结果的图像
|
||||||
|
yolo_img = img.copy()
|
||||||
|
|
||||||
|
# 标签映射(与img_prec.py保持一致)
|
||||||
|
labels = {
|
||||||
|
"微安电流表": "ammeter",
|
||||||
|
"黑盒电流表": "ammeter",
|
||||||
|
"待测表头": "voltmeter",
|
||||||
|
"黑盒电压表": "voltmeter",
|
||||||
|
"电阻箱": "resistance_box",
|
||||||
|
"滑动变阻器": "sliding_rheostat",
|
||||||
|
"单刀双掷开关": "switch",
|
||||||
|
"单刀单掷开关": "switch",
|
||||||
|
"灯泡": "light_bulb",
|
||||||
|
"电源": "power_supply",
|
||||||
|
"电池电源": "battery",
|
||||||
|
"电阻": "resistor",
|
||||||
|
"螺线管": "inductor",
|
||||||
|
"电容": "capacitor"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 为每种元件类型定义专属颜色 (B, G, R)
|
||||||
|
color_map = {
|
||||||
|
"ammeter": (0, 255, 0), # 电流表:绿色
|
||||||
|
"voltmeter": (255, 0, 0), # 电压表:蓝色
|
||||||
|
"resistance_box": (0, 255, 255), # 电阻箱:黄色
|
||||||
|
"sliding_rheostat": (128, 0, 128), # 滑动变阻器:紫色
|
||||||
|
"switch": (255, 165, 0), # 开关:橙色
|
||||||
|
"light_bulb": (255, 255, 0), # 灯泡:青色
|
||||||
|
"power_supply": (0, 0, 255), # 电源:红色
|
||||||
|
"battery": (128, 128, 0), # 电池:橄榄绿
|
||||||
|
"resistor": (0, 128, 128), # 电阻:深青
|
||||||
|
"inductor": (128, 0, 0), # 电感(螺线管):深红色
|
||||||
|
"capacitor": (0, 0, 128) # 电容:深蓝色
|
||||||
|
}
|
||||||
|
|
||||||
|
# 遍历检测结果并绘制
|
||||||
|
for box in results.boxes:
|
||||||
|
if box.conf < DETECTION_CONFIDENCE:
|
||||||
|
continue
|
||||||
|
cls = int(box.cls[0])
|
||||||
|
original_label = model.names[cls]
|
||||||
|
# 映射为显示标签
|
||||||
|
display_label = labels.get(original_label, original_label)
|
||||||
|
x1, y1, x2, y2 = map(int, box.xyxy[0])
|
||||||
|
|
||||||
|
components.append({
|
||||||
|
"label": original_label,
|
||||||
|
"bbox": [x1, y1, x2, y2]
|
||||||
|
})
|
||||||
|
|
||||||
|
# 获取当前元件的颜色(默认灰色)
|
||||||
|
box_color = color_map.get(display_label, (128, 128, 128))
|
||||||
|
|
||||||
|
# 在yolo_img上绘制
|
||||||
|
cv2.rectangle(yolo_img, (x1, y1), (x2, y2), box_color, 2)
|
||||||
|
cv2.putText(yolo_img, display_label, (x1, y1 - 8),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX, 0.5, box_color, 2)
|
||||||
|
|
||||||
|
# 保存YOLO识别结果图像
|
||||||
|
output_dir = ensure_output_dir()
|
||||||
|
save_path = os.path.join(output_dir, "yolo_recognition.png")
|
||||||
|
cv2.imwrite(save_path, yolo_img)
|
||||||
|
print(f"YOLO识别结果已保存到: {save_path}")
|
||||||
|
|
||||||
|
return components
|
||||||
484
src/pre_func/img_prec.py
Normal file
484
src/pre_func/img_prec.py
Normal file
@ -0,0 +1,484 @@
|
|||||||
|
import datetime
|
||||||
|
import wires_recog
|
||||||
|
import ele_recog
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import base64
|
||||||
|
import os
|
||||||
|
from utils import ensure_output_dir # 导入工具函数
|
||||||
|
|
||||||
|
# 从 elements.ts 转换过来的 Python 结构(包含电流表、电压表独立定义)
|
||||||
|
elements_ts = [
|
||||||
|
# 文本框
|
||||||
|
{
|
||||||
|
"key": "text_box",
|
||||||
|
"name": "文本框",
|
||||||
|
"defaultSize": 0,
|
||||||
|
"connectionPoints": [],
|
||||||
|
"propertySchemas": [
|
||||||
|
{"key": "text", "label": "文本", "type": "text", "default": "双击编辑"},
|
||||||
|
{"key": "fontSize", "label": "字体大小", "type": "number", "unit": "px", "default": 24},
|
||||||
|
{"key": "color", "label": "颜色", "type": "color", "default": "#111827"}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
# 电池
|
||||||
|
{
|
||||||
|
"key": "battery",
|
||||||
|
"name": "电池",
|
||||||
|
"defaultSize": 110,
|
||||||
|
"connectionPoints": [
|
||||||
|
{"x": 0.95, "y": 0.5, "name": "正极"},
|
||||||
|
{"x": 0.05, "y": 0.5, "name": "负极"}
|
||||||
|
],
|
||||||
|
"propertySchemas": [
|
||||||
|
{"key": "voltage", "label": "电压", "type": "number", "unit": "V", "default": 3},
|
||||||
|
{"key": "internalResistance", "label": "内阻", "type": "number", "unit": "Ω", "default": 0.5}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
# 电源
|
||||||
|
{
|
||||||
|
"key": "power_supply",
|
||||||
|
"name": "电源",
|
||||||
|
"defaultSize": 180,
|
||||||
|
"connectionPoints": [
|
||||||
|
{"x": 0.8, "y": 0.8, "name": "+"},
|
||||||
|
{"x": 0.85, "y": 0.2, "name": "-"}
|
||||||
|
],
|
||||||
|
"propertySchemas": [
|
||||||
|
{"key": "voltage", "label": "电压", "type": "number", "unit": "V", "default": 5},
|
||||||
|
{"key": "internalResistance", "label": "内阻", "type": "number", "unit": "Ω", "default": 0.1}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
# 开关
|
||||||
|
{
|
||||||
|
"key": "switch",
|
||||||
|
"name": "开关",
|
||||||
|
"defaultSize": 130,
|
||||||
|
"connectionPoints": [
|
||||||
|
{"x": 0.05, "y": 0.7, "name": "A"},
|
||||||
|
{"x": 0.95, "y": 0.7, "name": "B"}
|
||||||
|
],
|
||||||
|
"stateImages": {"on": "switchOn", "off": "switchOff"},
|
||||||
|
"propertySchemas": [
|
||||||
|
{"key": "state", "label": "状态", "type": "select", "options": ["off", "on"], "default": "off"}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
# 灯泡
|
||||||
|
{
|
||||||
|
"key": "light_bulb",
|
||||||
|
"name": "灯泡",
|
||||||
|
"defaultSize": 80,
|
||||||
|
"connectionPoints": [
|
||||||
|
{"x": 0.5, "y": 1, "name": "A"},
|
||||||
|
{"x": 0.72, "y": 0.8, "name": "B"}
|
||||||
|
],
|
||||||
|
"propertySchemas": [
|
||||||
|
{"key": "resistance", "label": "电阻", "type": "number", "unit": "Ω", "default": 20}
|
||||||
|
],
|
||||||
|
"stateImages": {"on": "lightBulbGrow", "off": "lightBulb"}
|
||||||
|
},
|
||||||
|
# 滑动变阻器
|
||||||
|
{
|
||||||
|
"key": "sliding_rheostat",
|
||||||
|
"name": "滑动变阻器",
|
||||||
|
"defaultSize": 180,
|
||||||
|
"connectionPoints": [
|
||||||
|
{"x": 0.15, "y": 0.28, "name": "A"},
|
||||||
|
{"x": 0.85, "y": 0.76, "name": "D"}
|
||||||
|
],
|
||||||
|
"propertySchemas": [
|
||||||
|
{"key": "maxResistance", "label": "最大电阻", "type": "number", "unit": "Ω", "default": 100},
|
||||||
|
{"key": "position", "label": "滑块位置", "type": "number", "default": 0.5}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
# 普通电阻
|
||||||
|
{
|
||||||
|
"key": "resistor",
|
||||||
|
"name": "电阻",
|
||||||
|
"defaultSize": 110,
|
||||||
|
"connectionPoints": [
|
||||||
|
{"x": 0.05, "y": 0.5, "name": "A"},
|
||||||
|
{"x": 0.95, "y": 0.5, "name": "B"}
|
||||||
|
],
|
||||||
|
"propertySchemas": [
|
||||||
|
{"key": "resistance", "label": "电阻", "type": "number", "unit": "Ω", "default": 5}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
# 电阻箱
|
||||||
|
{
|
||||||
|
"key": "resistance_box",
|
||||||
|
"name": "电阻箱",
|
||||||
|
"defaultSize": 180,
|
||||||
|
"connectionPoints": [
|
||||||
|
{"x": 0.1, "y": 0.4, "name": "A"},
|
||||||
|
{"x": 0.8, "y": 0.2, "name": "B"}
|
||||||
|
],
|
||||||
|
"propertySchemas": [
|
||||||
|
{"key": "resistance", "label": "电阻", "type": "number", "unit": "Ω", "default": 100}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
# 电流表(独立定义,替代原meter的电流表预设)
|
||||||
|
{
|
||||||
|
"key": "ammeter",
|
||||||
|
"name": "电流表",
|
||||||
|
"defaultSize": 150,
|
||||||
|
"connectionPoints": [
|
||||||
|
{"x": 0.25, "y": 0.8, "name": "正极"},
|
||||||
|
{"x": 0.85, "y": 0.8, "name": "负极"}
|
||||||
|
],
|
||||||
|
"preset": [
|
||||||
|
{"name": "微安电流表", "propertyValues": {"resistance": 0.05, "renderFunc": "(i) => `${(i * 1000000).toFixed(0)} uA`"}},
|
||||||
|
{"name": "毫安电流表", "propertyValues": {"resistance": 0.5, "renderFunc": "(i) => `${(i * 1000).toFixed(1)} mA`"}},
|
||||||
|
{"name": "电流表", "propertyValues": {"resistance": 0.1, "renderFunc": "(i) => `${(i).toFixed(2)} A`"}}
|
||||||
|
],
|
||||||
|
"propertySchemas": [
|
||||||
|
{"key": "resistance", "label": "电阻", "type": "number", "unit": "Ω", "default": 0.1},
|
||||||
|
{"key": "renderFunc", "label": "显示函数", "type": "text", "default": "(i) => `${(i).toFixed(2)} A`"}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
# 电压表(独立定义,替代原meter的电压表预设)
|
||||||
|
{
|
||||||
|
"key": "voltmeter",
|
||||||
|
"name": "电压表",
|
||||||
|
"defaultSize": 150,
|
||||||
|
"connectionPoints": [
|
||||||
|
{"x": 0.25, "y": 1, "name": "正极"},
|
||||||
|
{"x": 0.45, "y": 1, "name": "负极"}
|
||||||
|
],
|
||||||
|
"preset": [
|
||||||
|
{"name": "电压表", "propertyValues": {"resistance": 10000000, "renderFunc": "(i, r) => `${(i * r).toFixed(2)} V`"}}
|
||||||
|
],
|
||||||
|
"propertySchemas": [
|
||||||
|
{"key": "resistance", "label": "电阻", "type": "number", "unit": "Ω", "default": 10000000},
|
||||||
|
{"key": "renderFunc", "label": "显示函数", "type": "text", "default": "(i, r) => `${(i * r).toFixed(2)} V`"}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
# 电容
|
||||||
|
{
|
||||||
|
"key": "capacitor",
|
||||||
|
"name": "电容",
|
||||||
|
"defaultSize": 120,
|
||||||
|
"connectionPoints": [
|
||||||
|
{"x": 0.4, "y": 0.9, "name": "正极"},
|
||||||
|
{"x": 0.6, "y": 0.96, "name": "负极"}
|
||||||
|
],
|
||||||
|
"propertySchemas": [
|
||||||
|
{"key": "capacitance", "label": "电容", "type": "number", "unit": "F", "default": 0.000001}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
# 电感
|
||||||
|
{
|
||||||
|
"key": "inductor",
|
||||||
|
"name": "电感",
|
||||||
|
"defaultSize": 100,
|
||||||
|
"connectionPoints": [
|
||||||
|
{"x": 0.05, "y": 0.5, "name": "A"},
|
||||||
|
{"x": 0.95, "y": 0.5, "name": "B"}
|
||||||
|
],
|
||||||
|
"propertySchemas": [
|
||||||
|
{"key": "inductance", "label": "电感", "type": "number", "unit": "H", "default": 0.001}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
def generate_json(wires, components):
|
||||||
|
# 标签映射(保持不变)
|
||||||
|
labels = {
|
||||||
|
"微安电流表": "ammeter", "待测表头(电流表)": "ammeter", "黑盒电流表": "ammeter",
|
||||||
|
"待测表头": "voltmeter", "黑盒电压表": "voltmeter",
|
||||||
|
"电阻箱": "resistance_box", "滑动变阻器": "sliding_rheostat",
|
||||||
|
"单刀双掷开关": "switch", "单刀单掷开关": "switch", "灯泡": "light_bulb",
|
||||||
|
"电源": "power_supply", "电池电源": "battery", "电阻": "resistor",
|
||||||
|
"螺线管": "inductor", "电容": "capacitor"
|
||||||
|
}
|
||||||
|
|
||||||
|
instances = []
|
||||||
|
element_count = 0
|
||||||
|
# 新增:存储每个元件的连接点实际坐标(格式:{元件id: [{"x": 实际x, "y": 实际y, "cpIndex": 连接点索引}, ...]})
|
||||||
|
component_connection_points = {}
|
||||||
|
|
||||||
|
# 处理元件(新增:计算连接点实际坐标)
|
||||||
|
for comp in components:
|
||||||
|
bbox = comp["bbox"]
|
||||||
|
original_label = comp["label"]
|
||||||
|
label = labels.get(original_label, original_label)
|
||||||
|
|
||||||
|
x = (bbox[0] + bbox[2]) / 2 # 元件中心x
|
||||||
|
y = (bbox[1] + bbox[3]) / 2 # 元件中心y
|
||||||
|
comp_id = str(element_count + 1) # 当前元件id
|
||||||
|
|
||||||
|
# 获取默认尺寸和连接点定义
|
||||||
|
default_size = 100
|
||||||
|
conn_points = []
|
||||||
|
for elem in elements_ts:
|
||||||
|
if elem["key"] == label:
|
||||||
|
default_size = elem.get("defaultSize", 100)
|
||||||
|
conn_points = elem.get("connectionPoints", [])
|
||||||
|
break
|
||||||
|
|
||||||
|
# 计算连接点实际坐标(基于边界框)
|
||||||
|
comp_width = bbox[2] - bbox[0]
|
||||||
|
comp_height = bbox[3] - bbox[1]
|
||||||
|
actual_conn_points = []
|
||||||
|
for cp_idx, cp in enumerate(conn_points):
|
||||||
|
# 连接点在图像中的实际坐标(相对边界框的比例转换)
|
||||||
|
actual_x = bbox[0] + cp["x"] * comp_width
|
||||||
|
actual_y = bbox[1] + cp["y"] * comp_height
|
||||||
|
actual_conn_points.append({
|
||||||
|
"x": actual_x,
|
||||||
|
"y": actual_y,
|
||||||
|
"cpIndex": cp_idx
|
||||||
|
})
|
||||||
|
# 存储当前元件的连接点信息
|
||||||
|
component_connection_points[comp_id] = actual_conn_points
|
||||||
|
|
||||||
|
# 元件实例(保持不变)
|
||||||
|
instance = {
|
||||||
|
"id": comp_id,
|
||||||
|
"key": label,
|
||||||
|
"x": x,
|
||||||
|
"y": y,
|
||||||
|
"size": default_size,
|
||||||
|
"rotation": 0,
|
||||||
|
"props": {
|
||||||
|
"__connections": {},
|
||||||
|
"originalLabel": original_label
|
||||||
|
}
|
||||||
|
}
|
||||||
|
# ...(省略元件属性赋值,保持不变)
|
||||||
|
instances.append(instance)
|
||||||
|
element_count += 1
|
||||||
|
|
||||||
|
# 处理导线(基于坐标匹配连接点)
|
||||||
|
wire_id = element_count + 1
|
||||||
|
for wire in wires:
|
||||||
|
# 导线起点和终点的实际坐标(来自wires_recog的检测结果)
|
||||||
|
wire_start = (wire["start"]["x"], wire["start"]["y"])
|
||||||
|
wire_end = (wire["end"]["x"], wire["end"]["y"])
|
||||||
|
|
||||||
|
# 找到距离导线起点最近的元件连接点
|
||||||
|
min_dist_start = float('inf')
|
||||||
|
start_match = None # 格式:(元件id, 连接点索引)
|
||||||
|
for comp_id, cp_list in component_connection_points.items():
|
||||||
|
for cp in cp_list:
|
||||||
|
# 计算欧氏距离
|
||||||
|
dist = np.hypot(cp["x"] - wire_start[0], cp["y"] - wire_start[1])
|
||||||
|
if dist < min_dist_start:
|
||||||
|
min_dist_start = dist
|
||||||
|
start_match = (comp_id, cp["cpIndex"])
|
||||||
|
|
||||||
|
# 找到距离导线终点最近的元件连接点
|
||||||
|
min_dist_end = float('inf')
|
||||||
|
end_match = None
|
||||||
|
for comp_id, cp_list in component_connection_points.items():
|
||||||
|
for cp in cp_list:
|
||||||
|
dist = np.hypot(cp["x"] - wire_end[0], cp["y"] - wire_end[1])
|
||||||
|
if dist < min_dist_end:
|
||||||
|
min_dist_end = dist
|
||||||
|
end_match = (comp_id, cp["cpIndex"])
|
||||||
|
|
||||||
|
# 处理无匹配的情况(避免报错)
|
||||||
|
if not start_match or not end_match:
|
||||||
|
# 若没有匹配到,默认连接第一个元件(兼容旧逻辑,但尽量避免)
|
||||||
|
start_match = (instances[0]["id"], 0) if instances else (str(wire_id), 0)
|
||||||
|
end_match = (instances[-1]["id"], 0) if instances else (str(wire_id + 1), 0)
|
||||||
|
|
||||||
|
# 添加导线实例(使用匹配到的连接点)
|
||||||
|
instances.append({
|
||||||
|
"id": str(wire_id),
|
||||||
|
"key": "wire",
|
||||||
|
"x": (wire_start[0] + wire_end[0]) / 2,
|
||||||
|
"y": (wire_start[1] + wire_end[1]) / 2,
|
||||||
|
"size": 1,
|
||||||
|
"rotation": 0,
|
||||||
|
"props": {
|
||||||
|
"__connections": {
|
||||||
|
"0": [{"instId": start_match[0], "cpIndex": start_match[1]}],
|
||||||
|
"1": [{"instId": end_match[0], "cpIndex": end_match[1]}]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
wire_id += 1
|
||||||
|
|
||||||
|
# 生成最终JSON(保持不变)
|
||||||
|
return {
|
||||||
|
"version": 2,
|
||||||
|
"world": {
|
||||||
|
"scale": 1,
|
||||||
|
"translateX": -2089,
|
||||||
|
"translateY": -2123,
|
||||||
|
"worldSize": 5000,
|
||||||
|
"gridSize": 40
|
||||||
|
},
|
||||||
|
"instances": instances
|
||||||
|
}
|
||||||
|
|
||||||
|
def visualize_wires_and_components(image, results, components):
|
||||||
|
# 确保output目录存在
|
||||||
|
def ensure_output_dir():
|
||||||
|
if not os.path.exists("output"):
|
||||||
|
os.makedirs("output")
|
||||||
|
return "output"
|
||||||
|
|
||||||
|
original = image
|
||||||
|
# 缩放图像(与元件识别尺寸一致,确保坐标匹配)
|
||||||
|
img = cv2.resize(original, (1000, int(original.shape[0] * 1000 / original.shape[1])))
|
||||||
|
if img is None:
|
||||||
|
raise FileNotFoundError("无法读取图像")
|
||||||
|
|
||||||
|
# === 步骤1:建立元件ID映射(适配ammeter/voltmeter)===
|
||||||
|
comp_id_to_info = {}
|
||||||
|
# 可视化用的标签映射(与generate_json保持一致)
|
||||||
|
labels = {
|
||||||
|
"微安电流表": "ammeter",
|
||||||
|
"黑盒电流表": "ammeter",
|
||||||
|
"待测表头": "voltmeter",
|
||||||
|
"黑盒电压表": "voltmeter",
|
||||||
|
"电阻箱": "resistance_box",
|
||||||
|
"滑动变阻器": "sliding_rheostat",
|
||||||
|
"单刀双掷开关": "switch",
|
||||||
|
"单刀单掷开关": "switch",
|
||||||
|
"灯泡": "light_bulb",
|
||||||
|
"电源": "power_supply",
|
||||||
|
"电池电源": "battery",
|
||||||
|
"电阻": "resistor",
|
||||||
|
"螺线管": "inductor",
|
||||||
|
"电容": "capacitor"
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx, comp in enumerate(components):
|
||||||
|
comp_id = str(idx + 1)
|
||||||
|
bbox = comp["bbox"]
|
||||||
|
original_label = comp["label"]
|
||||||
|
comp_type = labels.get(original_label, original_label)
|
||||||
|
|
||||||
|
# 从elements_ts获取当前元件的连接点(兼容ammeter/voltmeter)
|
||||||
|
conn_points = []
|
||||||
|
for elem in elements_ts:
|
||||||
|
if elem["key"] == comp_type:
|
||||||
|
conn_points = elem["connectionPoints"]
|
||||||
|
break
|
||||||
|
|
||||||
|
# 计算连接点实际像素坐标
|
||||||
|
comp_width = bbox[2] - bbox[0]
|
||||||
|
comp_height = bbox[3] - bbox[1]
|
||||||
|
points = []
|
||||||
|
for p in conn_points:
|
||||||
|
x = bbox[0] + (p["x"] * comp_width)
|
||||||
|
y = bbox[1] + (p["y"] * comp_height)
|
||||||
|
points.append({"x": int(x), "y": int(y), "name": p["name"]})
|
||||||
|
|
||||||
|
# 保存元件信息(用于后续绘制)
|
||||||
|
comp_id_to_info[comp_id] = {
|
||||||
|
"bbox": bbox,
|
||||||
|
"conn_points": points,
|
||||||
|
"type": comp_type,
|
||||||
|
"displayLabel": comp_type # 显示ammeter/voltmeter,明确区分
|
||||||
|
}
|
||||||
|
|
||||||
|
# === 步骤2:绘制元件连接点(红色圆点)===
|
||||||
|
for comp_info in comp_id_to_info.values():
|
||||||
|
for p in comp_info["conn_points"]:
|
||||||
|
cv2.circle(img, (p["x"], p["y"]), 6, (0, 0, 255), -1)
|
||||||
|
cv2.putText(img, p["name"], (p["x"]+5, p["y"]-5),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 255), 1)
|
||||||
|
|
||||||
|
# === 步骤3:绘制导线(黄色线,区分起点/终点)===
|
||||||
|
for instance in results["instances"]:
|
||||||
|
if instance["key"] != "wire":
|
||||||
|
continue
|
||||||
|
conn_info = instance["props"]["__connections"]
|
||||||
|
if "0" not in conn_info or "1" not in conn_info:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 获取导线连接的元件和连接点
|
||||||
|
start_conn = conn_info["0"][0]
|
||||||
|
end_conn = conn_info["1"][0]
|
||||||
|
start_comp_id = start_conn["instId"]
|
||||||
|
end_comp_id = end_conn["instId"]
|
||||||
|
start_cp_idx = start_conn["cpIndex"]
|
||||||
|
end_cp_idx = end_conn["cpIndex"]
|
||||||
|
|
||||||
|
# 跳过不存在的元件
|
||||||
|
if start_comp_id not in comp_id_to_info or end_comp_id not in comp_id_to_info:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 绘制导线
|
||||||
|
start_point = comp_id_to_info[start_comp_id]["conn_points"][start_cp_idx]
|
||||||
|
end_point = comp_id_to_info[end_comp_id]["conn_points"][end_cp_idx]
|
||||||
|
cv2.line(img, (start_point["x"], start_point["y"]),
|
||||||
|
(end_point["x"], end_point["y"]), (0, 255, 255), 2)
|
||||||
|
cv2.circle(img, (start_point["x"], start_point["y"]), 6, (0, 0, 255), -1) # 起点红
|
||||||
|
cv2.circle(img, (end_point["x"], end_point["y"]), 6, (255, 0, 0), -1) # 终点蓝
|
||||||
|
cv2.putText(img, "start", (start_point["x"]+5, start_point["y"]-5),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 255), 1)
|
||||||
|
cv2.putText(img, "end", (end_point["x"]+5, end_point["y"]-5),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 0, 0), 1)
|
||||||
|
|
||||||
|
# === 步骤4:绘制元件框和标签(区分电流表/电压表)===
|
||||||
|
for idx, comp in enumerate(components):
|
||||||
|
comp_id = str(idx + 1)
|
||||||
|
if comp_id not in comp_id_to_info:
|
||||||
|
continue
|
||||||
|
comp_info = comp_id_to_info[comp_id]
|
||||||
|
bbox = comp_info["bbox"]
|
||||||
|
display_label = comp_info["displayLabel"] # 显示ammeter/voltmeter
|
||||||
|
|
||||||
|
# 绘制边界框(电流表绿色,电压表蓝色,其他默认绿色)
|
||||||
|
if display_label == "ammeter":
|
||||||
|
box_color = (0, 255, 0) # 电流表绿色
|
||||||
|
elif display_label == "voltmeter":
|
||||||
|
box_color = (255, 0, 0) # 电压表蓝色
|
||||||
|
else:
|
||||||
|
box_color = (0, 128, 0) # 其他元件深绿色
|
||||||
|
|
||||||
|
cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[2], bbox[3]), box_color, 2)
|
||||||
|
# 绘制标签(与框同色)
|
||||||
|
cv2.putText(img, display_label, (bbox[0], bbox[1] - 8),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX, 0.5, box_color, 2)
|
||||||
|
|
||||||
|
# 保存并返回可视化结果
|
||||||
|
output_dir = ensure_output_dir()
|
||||||
|
save_path = os.path.join(output_dir, "recognized_result.jpg")
|
||||||
|
cv2.imwrite(save_path, img)
|
||||||
|
print(f"可视化结果已保存到: {save_path}")
|
||||||
|
|
||||||
|
# 转换为base64
|
||||||
|
_, encode_img = cv2.imencode('.jpg', img)
|
||||||
|
img_base64 = base64.b64encode(encode_img).decode('utf-8')
|
||||||
|
return img_base64
|
||||||
|
|
||||||
|
|
||||||
|
def img_recognition(img):
|
||||||
|
# 1. 检测导线
|
||||||
|
wires = wires_recog.detect_wires_and_endpoints(img)
|
||||||
|
# 2. 识别元件(模型输出原始标签)
|
||||||
|
elements = ele_recog.elements_recognition(img)
|
||||||
|
# 3. 生成区分后的JSON(ammeter/voltmeter)
|
||||||
|
results = generate_json(wires, elements)
|
||||||
|
# 4. 可视化(区分显示电流表/电压表)
|
||||||
|
results_img = visualize_wires_and_components(img, results, elements)
|
||||||
|
|
||||||
|
# 构建返回结果
|
||||||
|
request = {
|
||||||
|
"success": True,
|
||||||
|
"recognizedImage": f"data:image/jpeg;base64,{results_img}",
|
||||||
|
"circuitData": results,
|
||||||
|
"components": elements
|
||||||
|
}
|
||||||
|
return request
|
||||||
|
|
||||||
|
|
||||||
|
# 测试入口(可选)
|
||||||
|
# if __name__ == '__main__':
|
||||||
|
# import time
|
||||||
|
# start = time.perf_counter()
|
||||||
|
# imgs_path = [r"你的测试图路径.jpg"]
|
||||||
|
# img = cv2.imread(imgs_path[0])
|
||||||
|
# if img is None:
|
||||||
|
# raise FileNotFoundError(f"无法读取图像: {imgs_path[0]}")
|
||||||
|
# result = img_recognition(img)
|
||||||
|
# end = time.perf_counter()
|
||||||
|
# print(f"处理耗时: {end - start:.2f}s")
|
||||||
|
# print(f"识别元件数: {len(result['components'])}")
|
||||||
|
# print(f"电路实例数(元件+导线): {len(result['circuitData']['instances'])}")
|
||||||
31
src/pre_func/run.py
Normal file
31
src/pre_func/run.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
import cv2
|
||||||
|
import os # 新增os模块导入
|
||||||
|
from img_prec import img_recognition
|
||||||
|
from utils import ensure_output_dir # 导入确保输出目录存在的工具函数
|
||||||
|
|
||||||
|
# 1. 读取图像(替换为你的图像路径)
|
||||||
|
image_path = "love.jpg"
|
||||||
|
img = cv2.imread(image_path)
|
||||||
|
|
||||||
|
if img is None:
|
||||||
|
print(f"无法读取图像: {image_path}")
|
||||||
|
else:
|
||||||
|
# 确保output目录存在
|
||||||
|
output_dir = ensure_output_dir()
|
||||||
|
|
||||||
|
# 执行识别
|
||||||
|
result = img_recognition(img)
|
||||||
|
|
||||||
|
# 输出识别结果
|
||||||
|
print("识别是否成功:", result["success"])
|
||||||
|
print("识别到的仪器数量:", len(result["components"]))
|
||||||
|
|
||||||
|
# 修复导线数量统计:从instances中筛选key为"wire"的实例
|
||||||
|
wire_count = sum(1 for item in result["circuitData"]["instances"] if item["key"] == "wire")
|
||||||
|
print("识别到的导线数量:", wire_count)
|
||||||
|
|
||||||
|
# 保存结果为JSON到output文件夹
|
||||||
|
import json
|
||||||
|
with open(os.path.join(output_dir, "recognition_result.json"), "w", encoding="utf-8") as f:
|
||||||
|
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||||
|
print(f"识别结果已保存到 {os.path.join(output_dir, 'recognition_result.json')}")
|
||||||
130
src/pre_func/train_unet.py
Normal file
130
src/pre_func/train_unet.py
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
import os, cv2, glob, numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
import torch, torch.nn as nn
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
import albumentations as A
|
||||||
|
from albumentations.pytorch import ToTensorV2
|
||||||
|
|
||||||
|
# ---------- Dataset ----------
|
||||||
|
class WireDataset(Dataset):
|
||||||
|
def __init__(self, img_dir, mask_dir, train=True):
|
||||||
|
self.img_paths = sorted(glob.glob(os.path.join(img_dir, "*")))
|
||||||
|
self.mask_paths = sorted(glob.glob(os.path.join(mask_dir, "*")))
|
||||||
|
aug_list = [
|
||||||
|
A.LongestMaxSize(max_size=1024),
|
||||||
|
A.PadIfNeeded(1024, 1024, border_mode=cv2.BORDER_CONSTANT, value=0),
|
||||||
|
]
|
||||||
|
if train:
|
||||||
|
aug_list += [
|
||||||
|
A.RandomCrop(512, 512),
|
||||||
|
A.HorizontalFlip(p=0.5),
|
||||||
|
A.VerticalFlip(p=0.2),
|
||||||
|
A.RandomBrightnessContrast(0.1, 0.1, p=0.5),
|
||||||
|
A.GaussianBlur(blur_limit=3, p=0.2),
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
aug_list += [A.CenterCrop(512, 512)]
|
||||||
|
self.tf = A.Compose(aug_list + [A.Normalize(), ToTensorV2()])
|
||||||
|
|
||||||
|
def __len__(self): return len(self.img_paths)
|
||||||
|
|
||||||
|
def __getitem__(self, i):
|
||||||
|
img = cv2.imread(self.img_paths[i], cv2.IMREAD_COLOR)
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||||
|
msk = cv2.imread(self.mask_paths[i], cv2.IMREAD_GRAYSCALE)
|
||||||
|
msk = (msk <= 127).astype(np.float32)
|
||||||
|
out = self.tf(image=img, mask=msk)
|
||||||
|
return out["image"], out["mask"].unsqueeze(0)
|
||||||
|
|
||||||
|
# ---------- U-Net ----------
|
||||||
|
class DoubleConv(nn.Module):
|
||||||
|
def __init__(self, in_ch, out_ch):
|
||||||
|
super().__init__()
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
|
||||||
|
nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
|
||||||
|
)
|
||||||
|
def forward(self, x): return self.net(x)
|
||||||
|
|
||||||
|
class UNet(nn.Module):
|
||||||
|
def __init__(self, in_ch=3, out_ch=1, base=64):
|
||||||
|
super().__init__()
|
||||||
|
self.d1 = DoubleConv(in_ch, base)
|
||||||
|
self.d2 = DoubleConv(base, base*2)
|
||||||
|
self.d3 = DoubleConv(base*2, base*4)
|
||||||
|
self.d4 = DoubleConv(base*4, base*8)
|
||||||
|
self.bottom = DoubleConv(base*8, base*16)
|
||||||
|
self.pool = nn.MaxPool2d(2)
|
||||||
|
self.up4 = nn.ConvTranspose2d(base*16, base*8, 2, 2)
|
||||||
|
self.u4 = DoubleConv(base*16, base*8)
|
||||||
|
self.up3 = nn.ConvTranspose2d(base*8, base*4, 2, 2)
|
||||||
|
self.u3 = DoubleConv(base*8, base*4)
|
||||||
|
self.up2 = nn.ConvTranspose2d(base*4, base*2, 2, 2)
|
||||||
|
self.u2 = DoubleConv(base*4, base*2)
|
||||||
|
self.up1 = nn.ConvTranspose2d(base*2, base, 2, 2)
|
||||||
|
self.u1 = DoubleConv(base*2, base)
|
||||||
|
self.out = nn.Conv2d(base, out_ch, 1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
c1 = self.d1(x); p1 = self.pool(c1)
|
||||||
|
c2 = self.d2(p1); p2 = self.pool(c2)
|
||||||
|
c3 = self.d3(p2); p3 = self.pool(c3)
|
||||||
|
c4 = self.d4(p3); p4 = self.pool(c4)
|
||||||
|
cb = self.bottom(p4)
|
||||||
|
x = self.up4(cb); x = torch.cat([x, c4], 1); x = self.u4(x)
|
||||||
|
x = self.up3(x); x = torch.cat([x, c3], 1); x = self.u3(x)
|
||||||
|
x = self.up2(x); x = torch.cat([x, c2], 1); x = self.u2(x)
|
||||||
|
x = self.up1(x); x = torch.cat([x, c1], 1); x = self.u1(x)
|
||||||
|
return self.out(x)
|
||||||
|
|
||||||
|
# ---------- Loss(BCE + Dice) ----------
|
||||||
|
class DiceLoss(nn.Module):
|
||||||
|
def __init__(self, eps=1e-6): super().__init__(); self.eps=eps
|
||||||
|
def forward(self, logits, targets):
|
||||||
|
probs = torch.sigmoid(logits)
|
||||||
|
num = 2*(probs*targets).sum(dim=(2,3)) + self.eps
|
||||||
|
den = (probs+targets).sum(dim=(2,3)) + self.eps
|
||||||
|
return 1 - (num/den).mean()
|
||||||
|
|
||||||
|
def bce_dice_loss(logits, targets):
|
||||||
|
bce = nn.functional.binary_cross_entropy_with_logits(logits, targets)
|
||||||
|
dice = DiceLoss()(logits, targets)
|
||||||
|
return bce*0.5 + dice*0.5
|
||||||
|
|
||||||
|
# ---------- Train ----------
|
||||||
|
def train():
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
print(f"使用设备: {device}") # 添加这行
|
||||||
|
net = UNet().to(device)
|
||||||
|
opt = torch.optim.AdamW(net.parameters(), lr=1e-3, weight_decay=1e-4)
|
||||||
|
|
||||||
|
# 根据你的文件路径修改这里
|
||||||
|
train_set = WireDataset("/home/gqw/unet_ai/elements_wires_congition/src/pre_func/data_red/train/images", "/home/gqw/unet_ai/elements_wires_congition/src/pre_func/data_red/train/masks", train=True)
|
||||||
|
val_set = WireDataset("/home/gqw/unet_ai/elements_wires_congition/src/pre_func/data_red/val/images", "/home/gqw/unet_ai/elements_wires_congition/src/pre_func/data_red/val/masks", train=False)
|
||||||
|
|
||||||
|
train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=4, pin_memory=True)
|
||||||
|
val_loader = DataLoader(val_set, batch_size=4, shuffle=False, num_workers=2)
|
||||||
|
|
||||||
|
best = 1e9
|
||||||
|
for epoch in range(100):
|
||||||
|
net.train(); loss_sum=0
|
||||||
|
for img, msk in tqdm(train_loader, desc=f"Epoch {epoch}"):
|
||||||
|
img, msk = img.to(device), msk.to(device)
|
||||||
|
opt.zero_grad(); logits = net(img); loss = bce_dice_loss(logits, msk)
|
||||||
|
loss.backward(); opt.step(); loss_sum += loss.item()
|
||||||
|
# val
|
||||||
|
net.eval(); iou_sum=0; npx=0
|
||||||
|
with torch.no_grad():
|
||||||
|
for img, msk in val_loader:
|
||||||
|
img, msk = img.to(device), msk.to(device)
|
||||||
|
pr = (torch.sigmoid(net(img))>0.5).float()
|
||||||
|
inter = (pr*msk).sum(); union = (pr+msk - pr*msk).sum()
|
||||||
|
iou = (inter+1e-6)/(union+1e-6); iou_sum += iou.item(); npx+=1
|
||||||
|
val_iou = iou_sum/max(npx,1)
|
||||||
|
print(f"epoch {epoch} train_loss={loss_sum/len(train_loader):.4f} val_iou={val_iou:.4f}")
|
||||||
|
# 简单保存
|
||||||
|
if val_iou < best: best = val_iou
|
||||||
|
torch.save(net.state_dict(), "unet_wire_red.pth")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
train()
|
||||||
7
src/pre_func/utils.py
Normal file
7
src/pre_func/utils.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
def ensure_output_dir():
|
||||||
|
"""确保output文件夹存在,不存在则创建"""
|
||||||
|
if not os.path.exists("output"):
|
||||||
|
os.makedirs("output")
|
||||||
|
return "output"
|
||||||
130
src/pre_func/wires_recog.py
Normal file
130
src/pre_func/wires_recog.py
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from torchvision import transforms
|
||||||
|
from skimage.morphology import skeletonize
|
||||||
|
from train_unet import UNet
|
||||||
|
from utils import ensure_output_dir
|
||||||
|
|
||||||
|
# 初始化UNet模型
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
model = UNet().to(device)
|
||||||
|
model.load_state_dict(torch.load("src/best_model/unet_wire_red.pth"))
|
||||||
|
model.eval() # 设置为推理模式
|
||||||
|
|
||||||
|
def preprocess_image(image, target_size=(512, 512)):
|
||||||
|
"""预处理图像以适应UNet模型输入"""
|
||||||
|
img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||||
|
img = cv2.resize(img, target_size)
|
||||||
|
img = img.astype(np.float32) / 255.0
|
||||||
|
mean = np.array([0.485, 0.456, 0.406])
|
||||||
|
std = np.array([0.229, 0.224, 0.225])
|
||||||
|
img = (img - mean) / std
|
||||||
|
img = np.transpose(img, (2, 0, 1))
|
||||||
|
img = torch.tensor(img, dtype=torch.float32)
|
||||||
|
img = img.unsqueeze(0).to(device)
|
||||||
|
return img
|
||||||
|
|
||||||
|
# 修改后的函数:添加手动掩码路径参数
|
||||||
|
def get_unet_mask(image, manual_mask_path="src/pre_func/wire_recognition_input.png"):
|
||||||
|
"""获取导线掩码,支持使用手动输入的掩码图片"""
|
||||||
|
original_size = image.shape[:2]
|
||||||
|
|
||||||
|
# 如果提供了手动掩码路径,则直接读取该图片作为掩码
|
||||||
|
if manual_mask_path and os.path.exists(manual_mask_path):
|
||||||
|
# 读取手动掩码并转换为二值图像
|
||||||
|
mask = cv2.imread(manual_mask_path, cv2.IMREAD_GRAYSCALE)
|
||||||
|
# 调整掩码尺寸以匹配原始图像
|
||||||
|
mask = cv2.resize(mask, (original_size[1], original_size[0]))
|
||||||
|
# 确保掩码是二值图像(0和255)
|
||||||
|
_, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
|
||||||
|
return mask
|
||||||
|
# 否则使用UNet模型生成掩码
|
||||||
|
else:
|
||||||
|
input_tensor = preprocess_image(image)
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(input_tensor)
|
||||||
|
output = torch.sigmoid(output)
|
||||||
|
predicted_mask = (output > 0.3).cpu().numpy().squeeze()
|
||||||
|
mask = cv2.resize(predicted_mask.astype(np.uint8),
|
||||||
|
(original_size[1], original_size[0])) * 255
|
||||||
|
cleaned_mask = mask
|
||||||
|
return cleaned_mask
|
||||||
|
|
||||||
|
def get_skeleton(binary_mask):
|
||||||
|
skel = skeletonize(binary_mask // 255).astype(np.uint8) * 255
|
||||||
|
return skel
|
||||||
|
|
||||||
|
def find_endpoints(skel_img):
|
||||||
|
endpoints = []
|
||||||
|
h, w = skel_img.shape
|
||||||
|
for y in range(1, h - 1):
|
||||||
|
for x in range(1, w - 1):
|
||||||
|
if skel_img[y, x] == 255:
|
||||||
|
patch = skel_img[y - 1:y + 2, x - 1:x + 2]
|
||||||
|
if cv2.countNonZero(patch) == 2:
|
||||||
|
endpoints.append((x, y))
|
||||||
|
return endpoints
|
||||||
|
|
||||||
|
def detect_wires_and_endpoints(image):
|
||||||
|
"""使用UNet进行导线分割,然后检测端点(支持传入图像数组)"""
|
||||||
|
original = image
|
||||||
|
if original is None:
|
||||||
|
raise ValueError("传入的图像为空,请检查图像是否有效")
|
||||||
|
|
||||||
|
# 调整图像大小(与仪器识别保持一致的尺寸)
|
||||||
|
img = cv2.resize(original, (1000, int(original.shape[0] * 1000 / original.shape[1])))
|
||||||
|
result_img = img.copy() # 只包含导线相关绘制,不含元件框
|
||||||
|
|
||||||
|
# 使用UNet获取导线掩码
|
||||||
|
mask_total = get_unet_mask(img)
|
||||||
|
|
||||||
|
# 保存UNet生成的掩码
|
||||||
|
output_dir = ensure_output_dir()
|
||||||
|
cv2.imwrite(os.path.join(output_dir, "unet_wire_mask.png"), mask_total)
|
||||||
|
|
||||||
|
# 骨架化处理
|
||||||
|
skeleton = get_skeleton(mask_total)
|
||||||
|
|
||||||
|
# 查找连通组件(不同的导线)
|
||||||
|
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(skeleton, connectivity=8)
|
||||||
|
all_wires = []
|
||||||
|
|
||||||
|
# 处理每个导线
|
||||||
|
for i in range(1, num_labels):
|
||||||
|
wire_mask = (labels == i).astype(np.uint8) * 255
|
||||||
|
pixel_count = cv2.countNonZero(wire_mask)
|
||||||
|
|
||||||
|
if pixel_count < 100:
|
||||||
|
continue
|
||||||
|
|
||||||
|
endpoints = find_endpoints(wire_mask)
|
||||||
|
if len(endpoints) >= 2:
|
||||||
|
start, end = endpoints[0], endpoints[-1]
|
||||||
|
|
||||||
|
# 仅绘制导线相关标记
|
||||||
|
cv2.circle(result_img, start, 6, (0, 0, 255), -1)
|
||||||
|
cv2.circle(result_img, end, 6, (255, 0, 0), -1)
|
||||||
|
cv2.line(result_img, start, end, (0, 255, 255), 2)
|
||||||
|
|
||||||
|
wire_data = {
|
||||||
|
"start": {"x": int(start[0]), "y": int(start[1])},
|
||||||
|
"end": {"x": int(end[0]), "y": int(end[1])},
|
||||||
|
}
|
||||||
|
all_wires.append(wire_data)
|
||||||
|
|
||||||
|
# 保存导线检测结果(不含元件框)
|
||||||
|
cv2.imwrite(os.path.join(output_dir, "wire_detection_result.png"), result_img)
|
||||||
|
|
||||||
|
return all_wires
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
image_path = "/home/gqw/unet_ai/love.jpg"
|
||||||
|
img = cv2.imread(image_path)
|
||||||
|
if img is None:
|
||||||
|
raise FileNotFoundError(f"无法读取图像: {image_path}")
|
||||||
|
wires = detect_wires_and_endpoints(img)
|
||||||
|
print("检测到的导线数量:", len(wires))
|
||||||
|
for i, wire in enumerate(wires):
|
||||||
|
print(f"导线 #{i+1}: 起点 {wire['start']}, 终点 {wire['end']}")
|
||||||
82
src/pre_func/yanzheng.py
Normal file
82
src/pre_func/yanzheng.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
import json
|
||||||
|
import base64
|
||||||
|
import os
|
||||||
|
from io import BytesIO
|
||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
|
||||||
|
# 确保output文件夹存在
|
||||||
|
def ensure_output_dir():
|
||||||
|
output_dir = "output"
|
||||||
|
if not os.path.exists(output_dir):
|
||||||
|
os.makedirs(output_dir)
|
||||||
|
return output_dir
|
||||||
|
|
||||||
|
# 1. 读取JSON文件(请替换为你的文件路径)
|
||||||
|
json_path = "/home/gqw/unet_ai/output/recognition_result.json"
|
||||||
|
with open(json_path, "r", encoding="utf-8") as f:
|
||||||
|
result = json.load(f)
|
||||||
|
|
||||||
|
# 2. 从JSON中提取原始图片(假设图片在最外层的recognizedImage字段,根据实际结构调整)
|
||||||
|
try:
|
||||||
|
base64_str = result["recognizedImage"].split(",")[1] # 去掉前缀
|
||||||
|
image_bytes = base64.b64decode(base64_str)
|
||||||
|
img = Image.open(BytesIO(image_bytes))
|
||||||
|
except KeyError:
|
||||||
|
# 若JSON中没有图片,可手动指定本地图片路径(请替换为你的原图路径)
|
||||||
|
print("未找到recognizedImage字段,使用本地图片...")
|
||||||
|
img = Image.open("/home/gqw/unet_ai/test/5.jpg") # 替换为你的原图路径
|
||||||
|
|
||||||
|
# 3. 准备绘图工具
|
||||||
|
draw = ImageDraw.Draw(img)
|
||||||
|
try:
|
||||||
|
# 尝试加载系统字体(可替换为你的字体路径)
|
||||||
|
font = ImageFont.truetype("simhei.ttf", 18) # 支持中文的字体
|
||||||
|
except:
|
||||||
|
font = ImageFont.load_default()
|
||||||
|
print("警告:未找到中文字体,可能无法正常显示中文标签")
|
||||||
|
|
||||||
|
# 4. 标注components中的仪器(核心标注逻辑)
|
||||||
|
# components结构:[{"label": "滑动变阻器", "bbox": [x1, y1, x2, y2]}, ...]
|
||||||
|
instrument_color = (255, 0, 0) # 红色框标注仪器
|
||||||
|
for idx, comp in enumerate(result["components"]):
|
||||||
|
# 提取标签和边界框
|
||||||
|
label = comp["label"] # 仪器类型(如"滑动变阻器")
|
||||||
|
bbox = comp["bbox"] # 边界框坐标 [x1, y1, x2, y2]
|
||||||
|
x1, y1, x2, y2 = bbox
|
||||||
|
|
||||||
|
# 绘制矩形框
|
||||||
|
draw.rectangle([x1, y1, x2, y2], outline=instrument_color, width=3)
|
||||||
|
|
||||||
|
# 绘制标签(避免文字超出图片范围)
|
||||||
|
text_x = x1 if x1 + 150 < img.width else img.width - 150
|
||||||
|
text_y = y1 - 30 if y1 > 30 else y1 + 10
|
||||||
|
draw.text(
|
||||||
|
(text_x, text_y),
|
||||||
|
f"{idx+1}. {label}",
|
||||||
|
fill=instrument_color,
|
||||||
|
font=font
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. (可选)标注导线(从circuitData.instances中提取wire类型)
|
||||||
|
wire_color = (0, 0, 255) # 蓝色框标注导线
|
||||||
|
wires = [inst for inst in result["circuitData"]["instances"] if inst["key"] == "wire"]
|
||||||
|
for idx, wire in enumerate(wires):
|
||||||
|
# 导线的位置信息(x,y为中心点,这里简化为小矩形)
|
||||||
|
x, y = wire["x"], wire["y"]
|
||||||
|
size = 20 # 导线标注框大小(可调整)
|
||||||
|
x1, y1 = x - size, y - size
|
||||||
|
x2, y2 = x + size, y + size
|
||||||
|
|
||||||
|
draw.rectangle([x1, y1, x2, y2], outline=wire_color, width=2)
|
||||||
|
draw.text(
|
||||||
|
(x1, y1 - 20),
|
||||||
|
f"导线{idx+1}",
|
||||||
|
fill=wire_color,
|
||||||
|
font=font
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. 保存标注后的图片到output文件夹
|
||||||
|
output_dir = ensure_output_dir()
|
||||||
|
output_path = os.path.join(output_dir, "circuit_with_annotations.jpg")
|
||||||
|
img.save(output_path)
|
||||||
|
print(f"标注完成!结果已保存到:{output_path}")
|
||||||
BIN
src/yolo/Arial.Unicode.ttf
Normal file
BIN
src/yolo/Arial.Unicode.ttf
Normal file
Binary file not shown.
24
src/yolo/config.yaml
Normal file
24
src/yolo/config.yaml
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
# data/config.yaml
|
||||||
|
path: ./dataset # 根目录(prepare_dataset.py 输出)
|
||||||
|
train: /home/gqw/unet_ai/elements_wires_congition/src/yolo/dataset/images/train
|
||||||
|
val: /home/gqw/unet_ai/elements_wires_congition/src/yolo/dataset/images/val
|
||||||
|
|
||||||
|
# nc 必须精确等于 names 列表的长度
|
||||||
|
nc: 10
|
||||||
|
|
||||||
|
names:
|
||||||
|
"0": "微安电流表" #躺着的
|
||||||
|
"1": "待测表头" #站着的
|
||||||
|
"2": "电阻箱"
|
||||||
|
"3": "滑动变阻器"
|
||||||
|
"4": "单刀双掷开关"
|
||||||
|
"5": "电源"
|
||||||
|
# <- 从 original_names.json 中拷过来的前若干项(顺序不要变)
|
||||||
|
"6": "单刀单掷开关" # 代码中存在该标签映射
|
||||||
|
"7": "灯泡" # 代码中存在该标签映射
|
||||||
|
"8": "电池电源" # 代码中存在该标签映射(对应battery)
|
||||||
|
"9": "电阻" # 代码中存在该标签映射
|
||||||
|
# "10": "螺线管" # 代码中存在该标签映射(对应电感)
|
||||||
|
# "11": "电容" # 代码中存在该标签映射
|
||||||
|
|
||||||
|
|
||||||
26
src/yolo/infer.py
Normal file
26
src/yolo/infer.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
from ultralytics import YOLO
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
# 完全固定所有路径(根据你的实际情况修改以下路径)
|
||||||
|
WEIGHTS_PATH = '/home/gqw/unet_ai/elements_wires_congition/src/best_model/best9.16.pt' # 模型权重路径
|
||||||
|
IMAGE_PATH = '/home/gqw/unet_ai/test/love.jpg' # 待检测图片路径
|
||||||
|
OUTPUT_PATH = '/home/gqw/unet_ai/elements_wires_congition/src/yolo/detect_results' # 结果保存路径
|
||||||
|
CONFIDENCE = 0.25 # 置信度阈值
|
||||||
|
SAVE_RESULTS = True # 是否保存检测结果(True/False)
|
||||||
|
|
||||||
|
# 加载模型
|
||||||
|
model = YOLO(WEIGHTS_PATH)
|
||||||
|
|
||||||
|
# 执行检测
|
||||||
|
results = model(
|
||||||
|
IMAGE_PATH,
|
||||||
|
conf=CONFIDENCE,
|
||||||
|
save=SAVE_RESULTS,
|
||||||
|
project=OUTPUT_PATH,
|
||||||
|
name='predictions'
|
||||||
|
)
|
||||||
|
|
||||||
|
# 打印结果信息
|
||||||
|
print(f"检测图片:{IMAGE_PATH}")
|
||||||
|
print(f"检测完成,结果已保存至:{OUTPUT_PATH}/predictions")
|
||||||
|
|
||||||
15
src/yolo/inspect_model_names.py
Normal file
15
src/yolo/inspect_model_names.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
# inspect_model_names.py
|
||||||
|
from ultralytics import YOLO
|
||||||
|
import argparse, json
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--weights', '-w', default='/home/gqw/unet_ai/elements_wires_congition/src/best_model/model_2.pt')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
m = YOLO(args.weights)
|
||||||
|
print("模型里的类别字典 (.names):")
|
||||||
|
print(m.names) # dict: {0: 'class0', 1: 'class1', ...}
|
||||||
|
# 可输出为文件
|
||||||
|
with open('original_names.json', 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(m.names, f, ensure_ascii=False, indent=2)
|
||||||
|
print("保存 original_names.json")
|
||||||
116
src/yolo/prepare_dataset.py
Normal file
116
src/yolo/prepare_dataset.py
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
# prepare_dataset.py
|
||||||
|
import os, shutil, random, argparse, json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
def copy_and_prefix(src_images, src_labels, out_images, out_labels, prefix):
|
||||||
|
os.makedirs(out_images, exist_ok=True)
|
||||||
|
os.makedirs(out_labels, exist_ok=True)
|
||||||
|
for img_path in Path(src_images).glob('*'):
|
||||||
|
if img_path.suffix.lower() not in ['.jpg', '.jpeg', '.png', '.bmp']:
|
||||||
|
continue
|
||||||
|
base = img_path.name
|
||||||
|
new_img = os.path.join(out_images, f"{prefix}_{base}")
|
||||||
|
shutil.copy2(img_path, new_img)
|
||||||
|
# label
|
||||||
|
label_src = os.path.join(src_labels, img_path.with_suffix('.txt').name)
|
||||||
|
label_dst = os.path.join(out_labels, f"{prefix}_{img_path.with_suffix('.txt').name}")
|
||||||
|
if os.path.exists(label_src):
|
||||||
|
shutil.copy2(label_src, label_dst)
|
||||||
|
else:
|
||||||
|
# create empty label file (no objects)
|
||||||
|
open(label_dst, 'w', encoding='utf-8').close()
|
||||||
|
|
||||||
|
def remap_label_file(label_path, mapping):
|
||||||
|
# mapping: { old_id (str) : new_id (int) }
|
||||||
|
if not os.path.exists(label_path):
|
||||||
|
return
|
||||||
|
lines = []
|
||||||
|
with open(label_path, 'r', encoding='utf-8') as f:
|
||||||
|
for l in f:
|
||||||
|
if not l.strip(): continue
|
||||||
|
parts = l.strip().split()
|
||||||
|
cls = parts[0]
|
||||||
|
if cls in mapping:
|
||||||
|
parts[0] = str(mapping[cls])
|
||||||
|
else:
|
||||||
|
# if not in mapping, we keep it or raise
|
||||||
|
parts[0] = parts[0] # keep as is
|
||||||
|
lines.append(' '.join(parts))
|
||||||
|
with open(label_path, 'w', encoding='utf-8') as f:
|
||||||
|
f.write('\n'.join(lines))
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--sources', nargs='+', required=True,
|
||||||
|
help='一组数据源,每个源是 IMAGE_DIR:LABEL_DIR,例如 /data/old/images:/data/old/labels')
|
||||||
|
parser.add_argument('--out', required=True, help='输出根目录,例如 ./dataset')
|
||||||
|
parser.add_argument('--val', type=float, default=0.2, help='验证集比例(0~1)')
|
||||||
|
parser.add_argument('--seed', type=int, default=42)
|
||||||
|
parser.add_argument('--remap', default=None,
|
||||||
|
help='可选的 JSON 文件,格式 {"old_id":"new_id", ...}(old_id作为字符串)')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
random.seed(args.seed)
|
||||||
|
out = Path(args.out)
|
||||||
|
imgs_all = out / 'images'
|
||||||
|
labels_all = out / 'labels'
|
||||||
|
imgs_all.mkdir(parents=True, exist_ok=True)
|
||||||
|
labels_all.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# 临时收集目录
|
||||||
|
temp_images = out / 'temp_images'
|
||||||
|
temp_labels = out / 'temp_labels'
|
||||||
|
if temp_images.exists(): shutil.rmtree(temp_images)
|
||||||
|
if temp_labels.exists(): shutil.rmtree(temp_labels)
|
||||||
|
temp_images.mkdir()
|
||||||
|
temp_labels.mkdir()
|
||||||
|
|
||||||
|
# copy所有源,每个源加前缀避免文件名冲突
|
||||||
|
for i, src in enumerate(args.sources):
|
||||||
|
if ':' not in src:
|
||||||
|
raise ValueError("source 格式应为 IMAGE_DIR:LABEL_DIR")
|
||||||
|
img_dir, lbl_dir = src.split(':', 1)
|
||||||
|
prefix = f"src{i}"
|
||||||
|
copy_and_prefix(img_dir, lbl_dir, str(temp_images), str(temp_labels), prefix)
|
||||||
|
|
||||||
|
# 如果需要 remap
|
||||||
|
mapping = {}
|
||||||
|
if args.remap:
|
||||||
|
with open(args.remap, 'r', encoding='utf-8') as f:
|
||||||
|
raw = json.load(f)
|
||||||
|
# keys expected as strings
|
||||||
|
mapping = {str(k): int(v) for k, v in raw.items()}
|
||||||
|
|
||||||
|
# 如果 remap 存在,则对每个临时标签文件进行 remap
|
||||||
|
if mapping:
|
||||||
|
for lf in temp_labels.glob('*.txt'):
|
||||||
|
remap_label_file(str(lf), mapping)
|
||||||
|
|
||||||
|
# 收集所有样本名
|
||||||
|
samples = [p.name for p in temp_images.glob('*') if p.suffix.lower() in ['.jpg','.jpeg','.png']]
|
||||||
|
random.shuffle(samples)
|
||||||
|
|
||||||
|
val_n = int(len(samples) * args.val)
|
||||||
|
val_samples = set(samples[:val_n])
|
||||||
|
train_samples = set(samples[val_n:])
|
||||||
|
|
||||||
|
# 创建最终目录结构
|
||||||
|
for split, sset in [('train', train_samples), ('val', val_samples)]:
|
||||||
|
img_out = out / 'images' / split
|
||||||
|
lbl_out = out / 'labels' / split
|
||||||
|
img_out.mkdir(parents=True, exist_ok=True)
|
||||||
|
lbl_out.mkdir(parents=True, exist_ok=True)
|
||||||
|
for nm in sset:
|
||||||
|
shutil.move(str(temp_images / nm), str(img_out / nm))
|
||||||
|
lbl_name = Path(nm).with_suffix('.txt').name
|
||||||
|
shutil.move(str(temp_labels / lbl_name), str(lbl_out / lbl_name))
|
||||||
|
|
||||||
|
# cleanup temp if any remaining
|
||||||
|
if temp_images.exists(): shutil.rmtree(temp_images, ignore_errors=True)
|
||||||
|
if temp_labels.exists(): shutil.rmtree(temp_labels, ignore_errors=True)
|
||||||
|
|
||||||
|
print(f"完成。输出目录: {out}")
|
||||||
|
print(f"train: {len(train_samples)}, val: {len(val_samples)}")
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
35
src/yolo/train_yolo.py
Normal file
35
src/yolo/train_yolo.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
# train_yolo.py
|
||||||
|
import argparse
|
||||||
|
from ultralytics import YOLO
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--weights', '-w',
|
||||||
|
default='/home/gqw/unet_ai/elements_wires_congition/src/best_model/best.pt',
|
||||||
|
help='预训练权重路径')
|
||||||
|
parser.add_argument('--data', '-d',
|
||||||
|
default='/home/gqw/unet_ai/elements_wires_congition/src/yolo/config.yaml',
|
||||||
|
help='data yaml 路径')
|
||||||
|
parser.add_argument('--epochs', type=int, default=200)
|
||||||
|
parser.add_argument('--imgsz', type=int, default=640)
|
||||||
|
parser.add_argument('--batch', type=int, default=16)
|
||||||
|
parser.add_argument('--device', default=0, help='GPU 设备号或 "cpu"')
|
||||||
|
parser.add_argument('--freeze', action='store_true', help='是否冻结backbone')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# 加载预训练模型
|
||||||
|
model = YOLO(args.weights)
|
||||||
|
print("开始训练,参数:", args)
|
||||||
|
model.train(
|
||||||
|
data=args.data,
|
||||||
|
epochs=args.epochs,
|
||||||
|
imgsz=args.imgsz,
|
||||||
|
batch=args.batch,
|
||||||
|
device=args.device,
|
||||||
|
amp=False # 新增此行:禁用AMP检查,避免加载yolo11n.pt
|
||||||
|
)
|
||||||
|
print("训练结束,权重保存在 runs/detect/ 下")
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
|
||||||
Loading…
x
Reference in New Issue
Block a user