first commit

This commit is contained in:
feie9456 2026-01-05 15:48:54 +08:00
commit f9125370af
26 changed files with 2927 additions and 0 deletions

223
.gitignore vendored Normal file
View File

@ -0,0 +1,223 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
Pipfile.lock
# PEP 582
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# Machine Learning / AI specific
# Model files (you may want to track some models)
*.pkl
*.joblib
*.h5
*.hdf5
*.pb
*.onnx
# Uncomment the next line if you don't want to track your trained models
# *.pt
# Dataset and data files
data/
datasets/
*.csv
*.xml
*.txt
# Image files (uncomment if you have large image datasets)
# *.jpg
# *.jpeg
# *.png
# *.gif
# *.bmp
# *.tiff
# Logs and outputs
logs/
outputs/
results/
checkpoints/
runs/
wandb/
# TensorBoard
events.out.tfevents.*
# Jupyter Notebook checkpoints
.ipynb_checkpoints/
# VS Code
.vscode/
*.code-workspace
# PyCharm
.idea/
*.iws
*.iml
*.ipr
# macOS
.DS_Store
.AppleDouble
.LSOverride
Icon?
._*
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk
# Windows
Thumbs.db
Thumbs.db:encryptable
ehthumbs.db
ehthumbs_vista.db
*.tmp
*.temp
Desktop.ini
$RECYCLE.BIN/
*.cab
*.msi
*.msix
*.msm
*.msp
*.lnk
# Linux
*~
.fuse_hidden*
.directory
.Trash-*
.nfs*
# Temporary files
*.tmp
*.temp
*.swp
*.swo
*~
# Project specific
# Add any project-specific files you want to ignore here
test_images/
temp/
cache/

2
README.md Normal file
View File

@ -0,0 +1,2 @@
stand 待测表头
lie 微安电流表

View File

@ -0,0 +1,26 @@
[Unit]
Description=Circuit Recognition Web Service
After=network.target
Wants=network.target
[Service]
Type=simple
User=root
Group=root
WorkingDirectory=/home/feie9454/elements_wires_congition
Environment=PATH=/home/feie9454/elements_wires_congition/venv/bin
ExecStart=/home/feie9454/elements_wires_congition/venv/bin/python main.py
Restart=always
RestartSec=10
StandardOutput=journal
StandardError=journal
SyslogIdentifier=circuit-recognition
# 安全设置
NoNewPrivileges=true
PrivateTmp=true
ProtectSystem=strict
ReadWritePaths=/home/feie9454/elements_wires_congition
[Install]
WantedBy=multi-user.target

80
main.py Normal file
View File

@ -0,0 +1,80 @@
from flask import Flask, request, jsonify, send_from_directory, send_file
from flask_cors import CORS
import numpy as np
import cv2
import os
from src.pre_func.img_prec import img_recognition
# 设置静态文件目录
# Serve the bundled front-end from src/web under the /static URL prefix.
app = Flask(__name__, static_folder='src/web', static_url_path='/static')
# Allow cross-origin requests so the web UI can call the API from any host.
CORS(app)
@app.route('/')
def index():
    """Serve the single-page app entry point (index.html) at the site root."""
    return send_from_directory(app.static_folder, 'index.html')
@app.route('/web')
def web_index():
    """Alternate entry point for the web application (same index.html)."""
    return send_from_directory(app.static_folder, 'index.html')
@app.route('/web/<path:filename>')
def web_static(filename):
    """Serve static assets from the web directory under the /web prefix."""
    return send_from_directory(app.static_folder, filename)
@app.route('/static/<path:filename>')
def static_files(filename):
    """Serve static assets under the explicit /static prefix."""
    return send_from_directory(app.static_folder, filename)
@app.route('/api/status')
def api_status():
    """Liveness check: report service status and list available endpoints."""
    payload = {
        'status': 'running',
        'message': 'Elements Wires Recognition API is running',
        'endpoints': {
            'POST /process_image': 'Image recognition endpoint',
            'GET /': 'Main web interface',
            'GET /web': 'Web application',
            'GET /api/status': 'API status check'
        }
    }
    return jsonify(payload)
@app.route('/process_image', methods=['POST'])
def process_image():
    """Decode an uploaded image and run circuit recognition on it.

    Expects a multipart/form-data POST with the file under the 'image' key.
    Returns the recognition result dict (serialized to JSON by Flask), or a
    JSON error with HTTP 400 when the upload is missing, empty, or not a
    decodable image.
    """
    if 'image' not in request.files:
        return jsonify({'error': 'No image part in the request'}), 400
    file = request.files['image']
    if file.filename == '':
        return jsonify({'error': 'No selected image'}), 400
    file_bytes = np.frombuffer(file.read(), np.uint8)
    img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    # cv2.imdecode returns None (it does not raise) for corrupt or
    # unsupported data; without this guard img_recognition would crash.
    if img is None:
        return jsonify({'error': 'Uploaded file is not a valid image'}), 400
    result = img_recognition(img)
    return result
if __name__ == '__main__':
    # Single source of truth for the port: the original banner advertised
    # port 8000 while the server actually bound 12399.
    port = 12399
    print("🚀 Starting Elements Wires Recognition Web Server...")
    print("📁 Static files served from: src/web/")
    print("🌐 Web interface available at:")
    print(f"   - http://localhost:{port}/")
    print(f"   - http://0.0.0.0:{port}/")
    print("🔧 API endpoints:")
    print("   - POST /process_image (for image recognition)")
    print("   - GET /web/<filename> (for static files)")
    print("   - GET /static/<filename> (for static files)")
    print("=" * 50)
    # NOTE(review): debug=True enables the Werkzeug debugger/reloader; the
    # systemd unit runs this in production — consider debug=False there.
    app.run(host='0.0.0.0', debug=True, port=port)

63
manage_service.sh Normal file
View File

@ -0,0 +1,63 @@
#!/bin/bash
# Circuit Recognition Service Management Script
# Usage: ./manage_service.sh [install|start|stop|restart|status|logs|uninstall]
# Thin wrapper over systemctl/journalctl for the circuit-recognition unit;
# every action except the usage message requires sudo.

SERVICE_NAME="circuit-recognition"
# Unit file shipped with the project, copied into systemd's directory on install.
SERVICE_FILE="/home/feie9454/elements_wires_congition/circuit-recognition.service"
SYSTEM_SERVICE_PATH="/etc/systemd/system/circuit-recognition.service"

case "$1" in
    install)
        # Copy the unit file, reload systemd, and enable start-on-boot.
        echo "安装服务..."
        sudo cp "$SERVICE_FILE" "$SYSTEM_SERVICE_PATH"
        sudo systemctl daemon-reload
        sudo systemctl enable "$SERVICE_NAME"
        echo "✅ 服务已安装并设置为开机自启"
        echo "使用 './manage_service.sh start' 启动服务"
        ;;
    start)
        echo "启动服务..."
        sudo systemctl start "$SERVICE_NAME"
        echo "✅ 服务已启动"
        ;;
    stop)
        echo "停止服务..."
        sudo systemctl stop "$SERVICE_NAME"
        echo "✅ 服务已停止"
        ;;
    restart)
        echo "重启服务..."
        sudo systemctl restart "$SERVICE_NAME"
        echo "✅ 服务已重启"
        ;;
    status)
        echo "服务状态:"
        sudo systemctl status "$SERVICE_NAME"
        ;;
    logs)
        # -f follows the journal; exit with Ctrl-C.
        echo "查看服务日志:"
        sudo journalctl -u "$SERVICE_NAME" -f
        ;;
    uninstall)
        # Stop, disable, remove the unit file, and reload systemd.
        echo "卸载服务..."
        sudo systemctl stop "$SERVICE_NAME"
        sudo systemctl disable "$SERVICE_NAME"
        sudo rm -f "$SYSTEM_SERVICE_PATH"
        sudo systemctl daemon-reload
        echo "✅ 服务已卸载"
        ;;
    *)
        # Unknown or missing sub-command: print usage and fail.
        echo "使用方法: $0 {install|start|stop|restart|status|logs|uninstall}"
        echo ""
        echo "命令说明:"
        echo "  install   - 安装服务并设置开机自启"
        echo "  start     - 启动服务"
        echo "  stop      - 停止服务"
        echo "  restart   - 重启服务"
        echo "  status    - 查看服务状态"
        echo "  logs      - 实时查看服务日志"
        echo "  uninstall - 卸载服务"
        exit 1
        ;;
esac

15
pyproject.toml Normal file
View File

@ -0,0 +1,15 @@
[project]
name = "circuit-recognition"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"albumentations>=2.0.8",
"flask>=3.1.1",
"flask-cors>=6.0.1",
"numpy>=2.3.0",
"opencv-python>=4.11.0.86",
"scikit-image>=0.25.2",
"ultralytics>=8.3.156",
]

1
src/__init__.py Normal file
View File

@ -0,0 +1 @@
# 这个文件使 src 目录成为一个 Python 包

BIN
src/best_model/best9.17.pt Normal file

Binary file not shown.

Binary file not shown.

1
src/pre_func/__init__.py Normal file
View File

@ -0,0 +1 @@
# 这个文件使 pre_func 目录成为一个 Python 包

53
src/pre_func/color.py Normal file
View File

@ -0,0 +1,53 @@
import cv2
import numpy as np
from sklearn.cluster import KMeans
from tensorflow.keras.models import load_model
# =============================
# 1. Load the trained U-Net model
# =============================
# NOTE(review): load_model is Keras, but the file has a PyTorch-style .pth
# extension (and train_unet.py in this repo saves PyTorch state dicts) —
# this call will fail unless the file really is a Keras model. Confirm.
model = load_model("/home/gqw/unet_ai/elements_wires_congition/src/pre_func/unet_wire.pth", compile=False)
# Read the test image (OpenCV loads in BGR channel order).
img = cv2.imread("/home/gqw/unet_ai/test/IMG_0059.jpeg")
h, w, _ = img.shape
# Resize to the model's input resolution and scale pixel values to [0, 1].
input_img = cv2.resize(img, (256, 256)) / 255.0
input_img = np.expand_dims(input_img, axis=0)  # add batch dimension
# =============================
# 2. U-Net segmentation prediction (binary mask)
# =============================
pred = model.predict(input_img)[0, :, :, 0]
mask = (pred > 0.5).astype(np.uint8)  # threshold probabilities to {0, 1}
mask = cv2.resize(mask, (w, h))  # back to the original image size
# =============================
# 3. Collect pixel colors inside the wire region
# =============================
# N x 3 array of wire pixels (BGR — the original "RGB" note was misleading).
wire_pixels = img[mask == 1]
# KMeans clustering, assuming there are 3 wire colors.
n_colors = 3
kmeans = KMeans(n_clusters=n_colors, random_state=0).fit(wire_pixels)
labels = kmeans.labels_
# =============================
# 4. Build one mask per color cluster
# =============================
colored_masks = []
for i in range(n_colors):
    color_mask = np.zeros(mask.shape, dtype=np.uint8)
    color_mask[mask == 1] = (labels == i).astype(np.uint8)
    colored_masks.append(color_mask)
# =============================
# 5. Visualize the result
# =============================
result = img.copy()
colors = [(0, 0, 255), (0, 255, 0), (255, 0, 0)]  # display color per cluster (BGR)
for i, cmask in enumerate(colored_masks):
    result[cmask == 1] = colors[i % len(colors)]
cv2.imwrite("segmented_by_color.jpg", result)
print("输出完成segmented_by_color.jpg")

5
src/pre_func/config.py Normal file
View File

@ -0,0 +1,5 @@
# Model and path configuration (absolute paths from the training host —
# adjust when deploying elsewhere).
YOLO_MODEL_PATH = '/home/gqw/unet_ai/elements_wires_congition/src/best_model/best9.17.pt'  # migrated from infer.py
DETECTION_CONFIDENCE = 0.5  # migrated from infer.py; tune as needed
OUTPUT_BASE_DIR = '/home/gqw/unet_ai/elements_wires_congition/src/yolo/detect_results'  # migrated from infer.py
WIRE_CONNECT_THRESHOLD = 50  # max allowed distance (px) between a wire endpoint and a connection point

83
src/pre_func/ele_recog.py Normal file
View File

@ -0,0 +1,83 @@
import cv2
import os
from ultralytics import YOLO
from config import YOLO_MODEL_PATH, DETECTION_CONFIDENCE
from utils import ensure_output_dir # 导入工具函数
def elements_recognition(img):
    """Detect circuit components in *img* with the YOLO model.

    The input is resized to 1000 px width (matching the wire-recognition
    pipeline so both share one coordinate frame), run through YOLO, and an
    annotated copy is written to <output>/yolo_recognition.png.

    Args:
        img: BGR image (numpy array) as loaded by OpenCV.

    Returns:
        List of {"label": original model class name, "bbox": [x1, y1, x2, y2]}
        dicts in resized-image pixel coordinates.
    """
    # Load the model lazily and cache it on the function object: reloading
    # YOLO weights from disk on every request is expensive and unnecessary.
    model = getattr(elements_recognition, "_model", None)
    if model is None:
        model = YOLO(YOLO_MODEL_PATH)
        elements_recognition._model = model
    original = img
    # Same scaling as wire recognition so detections line up spatially.
    img = cv2.resize(original, (1000, int(original.shape[0] * 1000 / original.shape[1])))
    results = model(img, conf=DETECTION_CONFIDENCE)[0]
    components = []
    # Annotated copy containing only the YOLO results.
    yolo_img = img.copy()
    # Model class name -> canonical element key (kept in sync with img_prec.py).
    labels = {
        "微安电流表": "ammeter",
        "黑盒电流表": "ammeter",
        "待测表头": "voltmeter",
        "黑盒电压表": "voltmeter",
        "电阻箱": "resistance_box",
        "滑动变阻器": "sliding_rheostat",
        "单刀双掷开关": "switch",
        "单刀单掷开关": "switch",
        "灯泡": "light_bulb",
        "电源": "power_supply",
        "电池电源": "battery",
        "电阻": "resistor",
        "螺线管": "inductor",
        "电容": "capacitor"
    }
    # Dedicated (B, G, R) color per component type for the annotated image.
    color_map = {
        "ammeter": (0, 255, 0),           # ammeter: green
        "voltmeter": (255, 0, 0),         # voltmeter: blue
        "resistance_box": (0, 255, 255),  # resistance box: yellow
        "sliding_rheostat": (128, 0, 128),  # sliding rheostat: purple
        "switch": (255, 165, 0),          # switch: orange
        "light_bulb": (255, 255, 0),      # light bulb: cyan
        "power_supply": (0, 0, 255),      # power supply: red
        "battery": (128, 128, 0),         # battery: olive
        "resistor": (0, 128, 128),        # resistor: teal
        "inductor": (128, 0, 0),          # inductor (solenoid): dark red
        "capacitor": (0, 0, 128)          # capacitor: dark blue
    }
    for box in results.boxes:
        # model() already filters by conf; keep the guard for safety.
        if box.conf < DETECTION_CONFIDENCE:
            continue
        cls = int(box.cls[0])
        original_label = model.names[cls]
        display_label = labels.get(original_label, original_label)
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        components.append({
            "label": original_label,
            "bbox": [x1, y1, x2, y2]
        })
        # Default to grey for unmapped labels.
        box_color = color_map.get(display_label, (128, 128, 128))
        cv2.rectangle(yolo_img, (x1, y1), (x2, y2), box_color, 2)
        cv2.putText(yolo_img, display_label, (x1, y1 - 8),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, box_color, 2)
    # Persist the annotated detection image for debugging.
    output_dir = ensure_output_dir()
    save_path = os.path.join(output_dir, "yolo_recognition.png")
    cv2.imwrite(save_path, yolo_img)
    print(f"YOLO识别结果已保存到: {save_path}")
    return components

484
src/pre_func/img_prec.py Normal file
View File

@ -0,0 +1,484 @@
import datetime
import wires_recog
import ele_recog
import cv2
import numpy as np
import base64
import os
from utils import ensure_output_dir # 导入工具函数
# 从 elements.ts 转换过来的 Python 结构(包含电流表、电压表独立定义)
# Component catalogue converted from elements.ts (ammeter and voltmeter are
# defined independently instead of sharing one generic "meter" entry).
# Each entry describes a placeable circuit element:
#   key              canonical identifier used throughout the pipeline
#   name             display name shown in the UI (Chinese, runtime data)
#   defaultSize      default rendered size in pixels
#   connectionPoints terminal positions as fractions of the bounding box
#   propertySchemas  editable properties (type/unit/default)
elements_ts = [
    # Text box
    {
        "key": "text_box",
        "name": "文本框",
        "defaultSize": 0,
        "connectionPoints": [],
        "propertySchemas": [
            {"key": "text", "label": "文本", "type": "text", "default": "双击编辑"},
            {"key": "fontSize", "label": "字体大小", "type": "number", "unit": "px", "default": 24},
            {"key": "color", "label": "颜色", "type": "color", "default": "#111827"}
        ]
    },
    # Battery
    {
        "key": "battery",
        "name": "电池",
        "defaultSize": 110,
        "connectionPoints": [
            {"x": 0.95, "y": 0.5, "name": "正极"},
            {"x": 0.05, "y": 0.5, "name": "负极"}
        ],
        "propertySchemas": [
            {"key": "voltage", "label": "电压", "type": "number", "unit": "V", "default": 3},
            {"key": "internalResistance", "label": "内阻", "type": "number", "unit": "Ω", "default": 0.5}
        ]
    },
    # Power supply
    {
        "key": "power_supply",
        "name": "电源",
        "defaultSize": 180,
        "connectionPoints": [
            {"x": 0.8, "y": 0.8, "name": "+"},
            {"x": 0.85, "y": 0.2, "name": "-"}
        ],
        "propertySchemas": [
            {"key": "voltage", "label": "电压", "type": "number", "unit": "V", "default": 5},
            {"key": "internalResistance", "label": "内阻", "type": "number", "unit": "Ω", "default": 0.1}
        ]
    },
    # Switch
    {
        "key": "switch",
        "name": "开关",
        "defaultSize": 130,
        "connectionPoints": [
            {"x": 0.05, "y": 0.7, "name": "A"},
            {"x": 0.95, "y": 0.7, "name": "B"}
        ],
        "stateImages": {"on": "switchOn", "off": "switchOff"},
        "propertySchemas": [
            {"key": "state", "label": "状态", "type": "select", "options": ["off", "on"], "default": "off"}
        ]
    },
    # Light bulb
    {
        "key": "light_bulb",
        "name": "灯泡",
        "defaultSize": 80,
        "connectionPoints": [
            {"x": 0.5, "y": 1, "name": "A"},
            {"x": 0.72, "y": 0.8, "name": "B"}
        ],
        "propertySchemas": [
            {"key": "resistance", "label": "电阻", "type": "number", "unit": "Ω", "default": 20}
        ],
        "stateImages": {"on": "lightBulbGrow", "off": "lightBulb"}
    },
    # Sliding rheostat
    {
        "key": "sliding_rheostat",
        "name": "滑动变阻器",
        "defaultSize": 180,
        "connectionPoints": [
            {"x": 0.15, "y": 0.28, "name": "A"},
            {"x": 0.85, "y": 0.76, "name": "D"}
        ],
        "propertySchemas": [
            {"key": "maxResistance", "label": "最大电阻", "type": "number", "unit": "Ω", "default": 100},
            {"key": "position", "label": "滑块位置", "type": "number", "default": 0.5}
        ]
    },
    # Plain resistor
    {
        "key": "resistor",
        "name": "电阻",
        "defaultSize": 110,
        "connectionPoints": [
            {"x": 0.05, "y": 0.5, "name": "A"},
            {"x": 0.95, "y": 0.5, "name": "B"}
        ],
        "propertySchemas": [
            {"key": "resistance", "label": "电阻", "type": "number", "unit": "Ω", "default": 5}
        ]
    },
    # Resistance box
    {
        "key": "resistance_box",
        "name": "电阻箱",
        "defaultSize": 180,
        "connectionPoints": [
            {"x": 0.1, "y": 0.4, "name": "A"},
            {"x": 0.8, "y": 0.2, "name": "B"}
        ],
        "propertySchemas": [
            {"key": "resistance", "label": "电阻", "type": "number", "unit": "Ω", "default": 100}
        ]
    },
    # Ammeter — independent definition (replaces the old "meter" ammeter preset)
    {
        "key": "ammeter",
        "name": "电流表",
        "defaultSize": 150,
        "connectionPoints": [
            {"x": 0.25, "y": 0.8, "name": "正极"},
            {"x": 0.85, "y": 0.8, "name": "负极"}
        ],
        "preset": [
            {"name": "微安电流表", "propertyValues": {"resistance": 0.05, "renderFunc": "(i) => `${(i * 1000000).toFixed(0)} uA`"}},
            {"name": "毫安电流表", "propertyValues": {"resistance": 0.5, "renderFunc": "(i) => `${(i * 1000).toFixed(1)} mA`"}},
            {"name": "电流表", "propertyValues": {"resistance": 0.1, "renderFunc": "(i) => `${(i).toFixed(2)} A`"}}
        ],
        "propertySchemas": [
            {"key": "resistance", "label": "电阻", "type": "number", "unit": "Ω", "default": 0.1},
            {"key": "renderFunc", "label": "显示函数", "type": "text", "default": "(i) => `${(i).toFixed(2)} A`"}
        ]
    },
    # Voltmeter — independent definition (replaces the old "meter" voltmeter preset)
    {
        "key": "voltmeter",
        "name": "电压表",
        "defaultSize": 150,
        "connectionPoints": [
            {"x": 0.25, "y": 1, "name": "正极"},
            {"x": 0.45, "y": 1, "name": "负极"}
        ],
        "preset": [
            {"name": "电压表", "propertyValues": {"resistance": 10000000, "renderFunc": "(i, r) => `${(i * r).toFixed(2)} V`"}}
        ],
        "propertySchemas": [
            {"key": "resistance", "label": "电阻", "type": "number", "unit": "Ω", "default": 10000000},
            {"key": "renderFunc", "label": "显示函数", "type": "text", "default": "(i, r) => `${(i * r).toFixed(2)} V`"}
        ]
    },
    # Capacitor
    {
        "key": "capacitor",
        "name": "电容",
        "defaultSize": 120,
        "connectionPoints": [
            {"x": 0.4, "y": 0.9, "name": "正极"},
            {"x": 0.6, "y": 0.96, "name": "负极"}
        ],
        "propertySchemas": [
            {"key": "capacitance", "label": "电容", "type": "number", "unit": "F", "default": 0.000001}
        ]
    },
    # Inductor
    {
        "key": "inductor",
        "name": "电感",
        "defaultSize": 100,
        "connectionPoints": [
            {"x": 0.05, "y": 0.5, "name": "A"},
            {"x": 0.95, "y": 0.5, "name": "B"}
        ],
        "propertySchemas": [
            {"key": "inductance", "label": "电感", "type": "number", "unit": "H", "default": 0.001}
        ]
    }
]
def generate_json(wires, components):
    """Build the front-end circuit JSON (version 2) from detection results.

    Args:
        wires: list of {"start": {"x", "y"}, "end": {"x", "y"}} dicts from
            the wire detector, in resized-image pixel coordinates.
        components: list of {"label", "bbox": [x1, y1, x2, y2]} dicts from
            the YOLO component detector, same coordinate frame.

    Returns:
        dict with "version", "world" (fixed viewport constants) and
        "instances": component instances first, then wire instances whose
        endpoints are snapped to the nearest component connection points.
    """
    # Model class name -> canonical element key (unchanged mapping).
    labels = {
        "微安电流表": "ammeter", "待测表头(电流表)": "ammeter", "黑盒电流表": "ammeter",
        "待测表头": "voltmeter", "黑盒电压表": "voltmeter",
        "电阻箱": "resistance_box", "滑动变阻器": "sliding_rheostat",
        "单刀双掷开关": "switch", "单刀单掷开关": "switch", "灯泡": "light_bulb",
        "电源": "power_supply", "电池电源": "battery", "电阻": "resistor",
        "螺线管": "inductor", "电容": "capacitor"
    }
    instances = []
    element_count = 0
    # Absolute pixel coordinates of every component's connection points:
    # {component id: [{"x": px, "y": px, "cpIndex": index}, ...]}
    component_connection_points = {}
    # Components: emit an instance and compute each connection point's
    # absolute position from the bounding box.
    for comp in components:
        bbox = comp["bbox"]
        original_label = comp["label"]
        label = labels.get(original_label, original_label)
        x = (bbox[0] + bbox[2]) / 2  # component centre x
        y = (bbox[1] + bbox[3]) / 2  # component centre y
        comp_id = str(element_count + 1)  # id of this component (1-based)
        # Default size and connection-point template from the catalogue.
        default_size = 100
        conn_points = []
        for elem in elements_ts:
            if elem["key"] == label:
                default_size = elem.get("defaultSize", 100)
                conn_points = elem.get("connectionPoints", [])
                break
        # Convert fractional connection points to absolute pixels
        # (fractions are relative to the bounding box).
        comp_width = bbox[2] - bbox[0]
        comp_height = bbox[3] - bbox[1]
        actual_conn_points = []
        for cp_idx, cp in enumerate(conn_points):
            actual_x = bbox[0] + cp["x"] * comp_width
            actual_y = bbox[1] + cp["y"] * comp_height
            actual_conn_points.append({
                "x": actual_x,
                "y": actual_y,
                "cpIndex": cp_idx
            })
        # Remember this component's connection points for wire matching.
        component_connection_points[comp_id] = actual_conn_points
        # Component instance record.
        instance = {
            "id": comp_id,
            "key": label,
            "x": x,
            "y": y,
            "size": default_size,
            "rotation": 0,
            "props": {
                "__connections": {},
                "originalLabel": original_label
            }
        }
        # ... (per-component property assignment omitted, unchanged)
        instances.append(instance)
        element_count += 1
    # Wires: snap each endpoint to the nearest component connection point.
    wire_id = element_count + 1
    for wire in wires:
        # Wire endpoints come from the wires_recog detection result.
        wire_start = (wire["start"]["x"], wire["start"]["y"])
        wire_end = (wire["end"]["x"], wire["end"]["y"])
        # Nearest connection point to the wire's start.
        min_dist_start = float('inf')
        start_match = None  # (component id, connection-point index)
        for comp_id, cp_list in component_connection_points.items():
            for cp in cp_list:
                # Euclidean distance
                dist = np.hypot(cp["x"] - wire_start[0], cp["y"] - wire_start[1])
                if dist < min_dist_start:
                    min_dist_start = dist
                    start_match = (comp_id, cp["cpIndex"])
        # Nearest connection point to the wire's end.
        min_dist_end = float('inf')
        end_match = None
        for comp_id, cp_list in component_connection_points.items():
            for cp in cp_list:
                dist = np.hypot(cp["x"] - wire_end[0], cp["y"] - wire_end[1])
                if dist < min_dist_end:
                    min_dist_end = dist
                    end_match = (comp_id, cp["cpIndex"])
        # Fallback when nothing matched (avoids KeyErrors downstream):
        # connect to the first/last component — legacy behavior, best avoided.
        if not start_match or not end_match:
            start_match = (instances[0]["id"], 0) if instances else (str(wire_id), 0)
            end_match = (instances[-1]["id"], 0) if instances else (str(wire_id + 1), 0)
        # Wire instance referencing the matched connection points.
        instances.append({
            "id": str(wire_id),
            "key": "wire",
            "x": (wire_start[0] + wire_end[0]) / 2,
            "y": (wire_start[1] + wire_end[1]) / 2,
            "size": 1,
            "rotation": 0,
            "props": {
                "__connections": {
                    "0": [{"instId": start_match[0], "cpIndex": start_match[1]}],
                    "1": [{"instId": end_match[0], "cpIndex": end_match[1]}]
                }
            }
        })
        wire_id += 1
    # Final JSON; "world" constants mirror the front-end viewport defaults.
    return {
        "version": 2,
        "world": {
            "scale": 1,
            "translateX": -2089,
            "translateY": -2123,
            "worldSize": 5000,
            "gridSize": 40
        },
        "instances": instances
    }
def visualize_wires_and_components(image, results, components):
    """Render connection points, wires, and component boxes onto *image*.

    Args:
        image: original BGR image; it is resized to 1000 px width to match
            the coordinate frame used by detection.
        results: circuit JSON produced by generate_json (wires are read
            from its "instances").
        components: raw component list ({"label", "bbox"}) from YOLO.

    Returns:
        The annotated image encoded as a base64 JPEG string; a copy is also
        written to <output>/recognized_result.jpg.
    """
    # Local helper shadowing the module-level import of the same name —
    # both create/return the "output" directory, so behavior is identical.
    def ensure_output_dir():
        if not os.path.exists("output"):
            os.makedirs("output")
        return "output"
    original = image
    # Resize to the same frame as component recognition so coordinates match.
    img = cv2.resize(original, (1000, int(original.shape[0] * 1000 / original.shape[1])))
    if img is None:
        raise FileNotFoundError("无法读取图像")
    # === Step 1: build component-id map (handles ammeter/voltmeter) ===
    comp_id_to_info = {}
    # Label mapping for visualization — kept in sync with generate_json.
    labels = {
        "微安电流表": "ammeter",
        "黑盒电流表": "ammeter",
        "待测表头": "voltmeter",
        "黑盒电压表": "voltmeter",
        "电阻箱": "resistance_box",
        "滑动变阻器": "sliding_rheostat",
        "单刀双掷开关": "switch",
        "单刀单掷开关": "switch",
        "灯泡": "light_bulb",
        "电源": "power_supply",
        "电池电源": "battery",
        "电阻": "resistor",
        "螺线管": "inductor",
        "电容": "capacitor"
    }
    for idx, comp in enumerate(components):
        comp_id = str(idx + 1)
        bbox = comp["bbox"]
        original_label = comp["label"]
        comp_type = labels.get(original_label, original_label)
        # Connection-point template from elements_ts (covers ammeter/voltmeter).
        conn_points = []
        for elem in elements_ts:
            if elem["key"] == comp_type:
                conn_points = elem["connectionPoints"]
                break
        # Absolute pixel coordinates of the connection points.
        comp_width = bbox[2] - bbox[0]
        comp_height = bbox[3] - bbox[1]
        points = []
        for p in conn_points:
            x = bbox[0] + (p["x"] * comp_width)
            y = bbox[1] + (p["y"] * comp_height)
            points.append({"x": int(x), "y": int(y), "name": p["name"]})
        # Cache per-component info for the drawing steps below.
        comp_id_to_info[comp_id] = {
            "bbox": bbox,
            "conn_points": points,
            "type": comp_type,
            "displayLabel": comp_type  # shows ammeter/voltmeter explicitly
        }
    # === Step 2: draw connection points (red dots with names) ===
    for comp_info in comp_id_to_info.values():
        for p in comp_info["conn_points"]:
            cv2.circle(img, (p["x"], p["y"]), 6, (0, 0, 255), -1)
            cv2.putText(img, p["name"], (p["x"]+5, p["y"]-5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 255), 1)
    # === Step 3: draw wires (yellow lines; start/end markers differ) ===
    for instance in results["instances"]:
        if instance["key"] != "wire":
            continue
        conn_info = instance["props"]["__connections"]
        if "0" not in conn_info or "1" not in conn_info:
            continue
        # Resolve which component/connection point each wire end attaches to.
        start_conn = conn_info["0"][0]
        end_conn = conn_info["1"][0]
        start_comp_id = start_conn["instId"]
        end_comp_id = end_conn["instId"]
        start_cp_idx = start_conn["cpIndex"]
        end_cp_idx = end_conn["cpIndex"]
        # Skip wires referencing components we don't know about.
        if start_comp_id not in comp_id_to_info or end_comp_id not in comp_id_to_info:
            continue
        # Draw the wire between the two resolved connection points.
        start_point = comp_id_to_info[start_comp_id]["conn_points"][start_cp_idx]
        end_point = comp_id_to_info[end_comp_id]["conn_points"][end_cp_idx]
        cv2.line(img, (start_point["x"], start_point["y"]),
                 (end_point["x"], end_point["y"]), (0, 255, 255), 2)
        cv2.circle(img, (start_point["x"], start_point["y"]), 6, (0, 0, 255), -1)  # start: red
        cv2.circle(img, (end_point["x"], end_point["y"]), 6, (255, 0, 0), -1)  # end: blue
        cv2.putText(img, "start", (start_point["x"]+5, start_point["y"]-5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 255), 1)
        cv2.putText(img, "end", (end_point["x"]+5, end_point["y"]-5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 0, 0), 1)
    # === Step 4: draw component boxes/labels (ammeter vs voltmeter colors) ===
    for idx, comp in enumerate(components):
        comp_id = str(idx + 1)
        if comp_id not in comp_id_to_info:
            continue
        comp_info = comp_id_to_info[comp_id]
        bbox = comp_info["bbox"]
        display_label = comp_info["displayLabel"]  # e.g. ammeter/voltmeter
        # Box color: ammeter green, voltmeter blue, everything else dark green.
        if display_label == "ammeter":
            box_color = (0, 255, 0)
        elif display_label == "voltmeter":
            box_color = (255, 0, 0)
        else:
            box_color = (0, 128, 0)
        cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[2], bbox[3]), box_color, 2)
        # Label drawn above the box in the same color.
        cv2.putText(img, display_label, (bbox[0], bbox[1] - 8),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, box_color, 2)
    # Persist and return the visualization.
    output_dir = ensure_output_dir()
    save_path = os.path.join(output_dir, "recognized_result.jpg")
    cv2.imwrite(save_path, img)
    print(f"可视化结果已保存到: {save_path}")
    # Encode to base64 for the JSON API response.
    _, encode_img = cv2.imencode('.jpg', img)
    img_base64 = base64.b64encode(encode_img).decode('utf-8')
    return img_base64
def img_recognition(img):
    """Run the full recognition pipeline on a BGR image.

    Steps: detect wires, detect components, assemble the circuit JSON, and
    render an annotated preview. Returns the API response dict with the
    preview embedded as a base64 JPEG data URL.
    """
    detected_wires = wires_recog.detect_wires_and_endpoints(img)
    detected_elements = ele_recog.elements_recognition(img)
    circuit = generate_json(detected_wires, detected_elements)
    rendered = visualize_wires_and_components(img, circuit, detected_elements)
    return {
        "success": True,
        "recognizedImage": f"data:image/jpeg;base64,{rendered}",
        "circuitData": circuit,
        "components": detected_elements,
    }
# 测试入口(可选)
# if __name__ == '__main__':
# import time
# start = time.perf_counter()
# imgs_path = [r"你的测试图路径.jpg"]
# img = cv2.imread(imgs_path[0])
# if img is None:
# raise FileNotFoundError(f"无法读取图像: {imgs_path[0]}")
# result = img_recognition(img)
# end = time.perf_counter()
# print(f"处理耗时: {end - start:.2f}s")
# print(f"识别元件数: {len(result['components'])}")
# print(f"电路实例数(元件+导线): {len(result['circuitData']['instances'])}")

31
src/pre_func/run.py Normal file
View File

@ -0,0 +1,31 @@
import cv2
import os # 新增os模块导入
from img_prec import img_recognition
from utils import ensure_output_dir # 导入确保输出目录存在的工具函数
import json  # hoisted: the original imported json mid-script

# 1. Path of the image to analyse (replace with your own).
image_path = "love.jpg"
img = cv2.imread(image_path)
if img is None:
    # cv2.imread returns None instead of raising for unreadable files.
    print(f"无法读取图像: {image_path}")
else:
    # Make sure the output directory exists before anything writes to it.
    output_dir = ensure_output_dir()
    # Run the full recognition pipeline.
    result = img_recognition(img)
    print("识别是否成功:", result["success"])
    print("识别到的仪器数量:", len(result["components"]))
    # Wire count: filter instances whose key is "wire".
    wire_count = sum(1 for item in result["circuitData"]["instances"] if item["key"] == "wire")
    print("识别到的导线数量:", wire_count)
    # Persist the result as JSON in the output folder.
    result_path = os.path.join(output_dir, "recognition_result.json")
    with open(result_path, "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, indent=2)
    print(f"识别结果已保存到 {result_path}")

130
src/pre_func/train_unet.py Normal file
View File

@ -0,0 +1,130 @@
import os, cv2, glob, numpy as np
from tqdm import tqdm
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
# ---------- Dataset ----------
class WireDataset(Dataset):
    """Paired image/mask dataset for wire segmentation.

    Images and masks are matched by sorted filename order. Masks are
    binarized so that dark pixels (<=127) become foreground (1.0) —
    assumes wires are drawn dark on a light mask; TODO confirm polarity.
    """
    def __init__(self, img_dir, mask_dir, train=True):
        self.img_paths = sorted(glob.glob(os.path.join(img_dir, "*")))
        self.mask_paths = sorted(glob.glob(os.path.join(mask_dir, "*")))
        # Sorted-order pairing silently misaligns every sample if the two
        # directories differ in size — fail fast instead.
        if len(self.img_paths) != len(self.mask_paths):
            raise ValueError(
                f"image/mask count mismatch: {len(self.img_paths)} images "
                f"vs {len(self.mask_paths)} masks")
        aug_list = [
            A.LongestMaxSize(max_size=1024),
            A.PadIfNeeded(1024, 1024, border_mode=cv2.BORDER_CONSTANT, value=0),
        ]
        if train:
            # Training-time augmentations only.
            aug_list += [
                A.RandomCrop(512, 512),
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.2),
                A.RandomBrightnessContrast(0.1, 0.1, p=0.5),
                A.GaussianBlur(blur_limit=3, p=0.2),
            ]
        else:
            aug_list += [A.CenterCrop(512, 512)]
        self.tf = A.Compose(aug_list + [A.Normalize(), ToTensorV2()])

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, i):
        img = cv2.imread(self.img_paths[i], cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread returns None (no exception) on unreadable files.
            raise FileNotFoundError(f"cannot read image: {self.img_paths[i]}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        msk = cv2.imread(self.mask_paths[i], cv2.IMREAD_GRAYSCALE)
        if msk is None:
            raise FileNotFoundError(f"cannot read mask: {self.mask_paths[i]}")
        # Dark pixels (<=127) are wire foreground.
        msk = (msk <= 127).astype(np.float32)
        out = self.tf(image=img, mask=msk)
        # Returns (C,H,W) image tensor and (1,H,W) float mask tensor.
        return out["image"], out["mask"].unsqueeze(0)
# ---------- U-Net ----------
class DoubleConv(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
)
def forward(self, x): return self.net(x)
class UNet(nn.Module):
    """Standard 4-level U-Net with configurable base channel width.

    Input:  (N, in_ch, H, W) with H and W divisible by 16 (four 2x poolings).
    Output: (N, out_ch, H, W) raw logits — apply sigmoid externally.
    """
    def __init__(self, in_ch=3, out_ch=1, base=64):
        super().__init__()
        # Encoder: channel width doubles at each level.
        self.d1 = DoubleConv(in_ch, base)
        self.d2 = DoubleConv(base, base*2)
        self.d3 = DoubleConv(base*2, base*4)
        self.d4 = DoubleConv(base*4, base*8)
        self.bottom = DoubleConv(base*8, base*16)
        self.pool = nn.MaxPool2d(2)
        # Decoder: transposed-conv upsampling, then DoubleConv over the
        # concatenation of the upsampled features and the encoder skip.
        self.up4 = nn.ConvTranspose2d(base*16, base*8, 2, 2)
        self.u4 = DoubleConv(base*16, base*8)
        self.up3 = nn.ConvTranspose2d(base*8, base*4, 2, 2)
        self.u3 = DoubleConv(base*8, base*4)
        self.up2 = nn.ConvTranspose2d(base*4, base*2, 2, 2)
        self.u2 = DoubleConv(base*4, base*2)
        self.up1 = nn.ConvTranspose2d(base*2, base, 2, 2)
        self.u1 = DoubleConv(base*2, base)
        self.out = nn.Conv2d(base, out_ch, 1)
    def forward(self, x):
        # Encoder path; pre-pool activations are kept for skip connections.
        c1 = self.d1(x); p1 = self.pool(c1)
        c2 = self.d2(p1); p2 = self.pool(c2)
        c3 = self.d3(p2); p3 = self.pool(c3)
        c4 = self.d4(p3); p4 = self.pool(c4)
        cb = self.bottom(p4)
        # Decoder path: upsample, concatenate skip, fuse.
        x = self.up4(cb); x = torch.cat([x, c4], 1); x = self.u4(x)
        x = self.up3(x); x = torch.cat([x, c3], 1); x = self.u3(x)
        x = self.up2(x); x = torch.cat([x, c2], 1); x = self.u2(x)
        x = self.up1(x); x = torch.cat([x, c1], 1); x = self.u1(x)
        return self.out(x)
# ---------- LossBCE + Dice ----------
class DiceLoss(nn.Module):
    """Soft Dice loss on sigmoid probabilities: 1 - mean Dice coefficient."""

    def __init__(self, eps=1e-6):
        super().__init__()
        # Epsilon stabilizes the ratio when both prediction and target are empty.
        self.eps = eps

    def forward(self, logits, targets):
        probs = torch.sigmoid(logits)
        intersection = (probs * targets).sum(dim=(2, 3))
        num = 2 * intersection + self.eps
        den = (probs + targets).sum(dim=(2, 3)) + self.eps
        return 1 - (num / den).mean()
def bce_dice_loss(logits, targets):
    """Equal-weight combination of BCE-with-logits and soft Dice loss."""
    bce_term = nn.functional.binary_cross_entropy_with_logits(logits, targets)
    dice_term = DiceLoss()(logits, targets)
    return bce_term * 0.5 + dice_term * 0.5
# ---------- Train ----------
def train():
    """Train the UNet wire segmenter and keep the best-IoU checkpoint.

    Trains for 100 epochs with AdamW and BCE+Dice loss, validating with
    mean batch IoU at threshold 0.5; saves unet_wire_red.pth whenever the
    validation IoU improves.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"使用设备: {device}")
    net = UNet().to(device)
    opt = torch.optim.AdamW(net.parameters(), lr=1e-3, weight_decay=1e-4)
    # Adjust these paths to your dataset layout.
    train_set = WireDataset("/home/gqw/unet_ai/elements_wires_congition/src/pre_func/data_red/train/images", "/home/gqw/unet_ai/elements_wires_congition/src/pre_func/data_red/train/masks", train=True)
    val_set = WireDataset("/home/gqw/unet_ai/elements_wires_congition/src/pre_func/data_red/val/images", "/home/gqw/unet_ai/elements_wires_congition/src/pre_func/data_red/val/masks", train=False)
    train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=4, shuffle=False, num_workers=2)
    # IoU is "higher is better" — start below any achievable value.
    best = -1.0
    for epoch in range(100):
        net.train()
        loss_sum = 0
        for img, msk in tqdm(train_loader, desc=f"Epoch {epoch}"):
            img, msk = img.to(device), msk.to(device)
            opt.zero_grad()
            logits = net(img)
            loss = bce_dice_loss(logits, msk)
            loss.backward()
            opt.step()
            loss_sum += loss.item()
        # Validation: mean IoU over batches at probability threshold 0.5.
        net.eval()
        iou_sum = 0
        npx = 0
        with torch.no_grad():
            for img, msk in val_loader:
                img, msk = img.to(device), msk.to(device)
                pr = (torch.sigmoid(net(img)) > 0.5).float()
                inter = (pr * msk).sum()
                union = (pr + msk - pr * msk).sum()
                iou = (inter + 1e-6) / (union + 1e-6)
                iou_sum += iou.item()
                npx += 1
        val_iou = iou_sum / max(npx, 1)
        print(f"epoch {epoch} train_loss={loss_sum/len(train_loader):.4f} val_iou={val_iou:.4f}")
        # BUG FIX: the original tracked the *minimum* IoU (best = 1e9,
        # `val_iou < best`) and then saved unconditionally every epoch, so
        # the checkpoint was always the last epoch's weights. Save only
        # when the validation IoU actually improves.
        if val_iou > best:
            best = val_iou
            torch.save(net.state_dict(), "unet_wire_red.pth")

7
src/pre_func/utils.py Normal file
View File

@ -0,0 +1,7 @@
import os
def ensure_output_dir():
    """Ensure the "output" folder exists (create it if missing) and return its path.

    Uses os.makedirs(exist_ok=True) so concurrent callers cannot race
    between the existence check and the creation.
    """
    os.makedirs("output", exist_ok=True)
    return "output"

130
src/pre_func/wires_recog.py Normal file
View File

@ -0,0 +1,130 @@
import cv2
import numpy as np
import os
import torch
from torchvision import transforms
from skimage.morphology import skeletonize
from train_unet import UNet
from utils import ensure_output_dir
# 初始化UNet模型
# Initialize the UNet segmenter once at import time so every request reuses it.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = UNet().to(device)
# map_location keeps GPU-saved checkpoints loadable on CPU-only hosts;
# without it torch.load fails when CUDA is unavailable.
model.load_state_dict(torch.load("src/best_model/unet_wire_red.pth", map_location=device))
model.eval()  # inference mode: freezes BatchNorm statistics, disables dropout
def preprocess_image(image, target_size=(512, 512)):
    """Convert a BGR image into a normalized NCHW float tensor on *device*."""
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    rgb = cv2.resize(rgb, target_size)
    scaled = rgb.astype(np.float32) / 255.0
    # ImageNet statistics, matching A.Normalize() used during training.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    normalized = (scaled - mean) / std
    chw = np.transpose(normalized, (2, 0, 1))  # HWC -> CHW
    tensor = torch.tensor(chw, dtype=torch.float32)
    return tensor.unsqueeze(0).to(device)  # add batch dimension
# 修改后的函数:添加手动掩码路径参数
# Modified: accepts an optional manual-mask path parameter.
def get_unet_mask(image, manual_mask_path="src/pre_func/wire_recognition_input.png"):
    """Return a 0/255 wire mask for *image* at the image's original size.

    If *manual_mask_path* exists on disk, that file is used verbatim
    (resized + thresholded) INSTEAD of the UNet prediction.
    NOTE(review): because the default path is hard-coded, merely having
    that file present silently overrides the model for every request —
    confirm this is intended.
    """
    original_size = image.shape[:2]  # (height, width)
    # Manual override: read the provided mask image directly.
    if manual_mask_path and os.path.exists(manual_mask_path):
        mask = cv2.imread(manual_mask_path, cv2.IMREAD_GRAYSCALE)
        # Match the original image size (cv2.resize takes (width, height)).
        mask = cv2.resize(mask, (original_size[1], original_size[0]))
        # Force a strictly binary mask (values 0 and 255 only).
        _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
        return mask
    # Otherwise run the UNet to predict the mask.
    else:
        input_tensor = preprocess_image(image)
        with torch.no_grad():
            output = model(input_tensor)
            output = torch.sigmoid(output)
            # Threshold 0.3 on the probabilities; squeeze batch/channel dims.
            predicted_mask = (output > 0.3).cpu().numpy().squeeze()
        # Upscale the 0/1 mask to the original size, then scale to 0/255.
        mask = cv2.resize(predicted_mask.astype(np.uint8),
                          (original_size[1], original_size[0])) * 255
        cleaned_mask = mask
        return cleaned_mask
def get_skeleton(binary_mask):
    """Thin a 0/255 binary mask to a one-pixel-wide skeleton (also 0/255)."""
    binary = binary_mask // 255  # skeletonize expects a 0/1 image
    return skeletonize(binary).astype(np.uint8) * 255
def find_endpoints(skel_img):
    """Return the endpoints of a 0/255 skeleton image as [(x, y), ...].

    An endpoint is a pixel of value 255 whose 3x3 neighbourhood contains
    exactly two nonzero pixels (itself plus one neighbour). Border pixels
    are ignored, matching the original scan range, and points are returned
    in row-major (top-to-bottom, left-to-right) order.

    Implementation note: the original per-pixel Python double loop was
    O(H*W) interpreter-level work; this shifted-sum version does the same
    neighbourhood count with nine vectorized numpy additions.
    """
    on = (skel_img != 0).astype(np.int32)
    h, w = on.shape
    if h < 3 or w < 3:
        # Too small to contain any interior pixel.
        return []
    # Nonzero count of every 3x3 window centred on an interior pixel.
    window = np.zeros((h - 2, w - 2), dtype=np.int32)
    for dy in range(3):
        for dx in range(3):
            window += on[dy:dy + h - 2, dx:dx + w - 2]
    # Endpoint: centre pixel == 255 and window count == 2.
    center_on = skel_img[1:h - 1, 1:w - 1] == 255
    ys, xs = np.nonzero(center_on & (window == 2))
    # +1 converts interior-grid indices back to full-image coordinates.
    return [(int(x) + 1, int(y) + 1) for y, x in zip(ys, xs)]
def detect_wires_and_endpoints(image):
    """Segment wires with UNet, then locate each wire's two endpoints.

    Accepts an in-memory BGR image array. Side effects: writes
    ``unet_wire_mask.png`` and ``wire_detection_result.png`` into the
    output directory.

    Returns a list of dicts ``{"start": {"x", "y"}, "end": {"x", "y"}}``
    with coordinates in the resized (width=1000) image space.

    Raises ValueError when `image` is None.
    """
    original = image
    if original is None:
        raise ValueError("传入的图像为空,请检查图像是否有效")
    # Resize to width 1000 (kept consistent with the instrument-recognition stage)
    img = cv2.resize(original, (1000, int(original.shape[0] * 1000 / original.shape[1])))
    result_img = img.copy()  # wire annotations only, no component boxes
    # Obtain the wire mask from UNet (or the manual mask, if one exists)
    mask_total = get_unet_mask(img)
    # Save the generated mask for inspection
    output_dir = ensure_output_dir()
    cv2.imwrite(os.path.join(output_dir, "unet_wire_mask.png"), mask_total)
    # Thin the mask to a one-pixel skeleton
    skeleton = get_skeleton(mask_total)
    # Split the skeleton into connected components (individual wires)
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(skeleton, connectivity=8)
    all_wires = []
    # Process each wire component (label 0 is background)
    for i in range(1, num_labels):
        wire_mask = (labels == i).astype(np.uint8) * 255
        pixel_count = cv2.countNonZero(wire_mask)
        # Discard tiny components as noise — 100 px looks empirical; TODO confirm
        if pixel_count < 100:
            continue
        endpoints = find_endpoints(wire_mask)
        if len(endpoints) >= 2:
            # NOTE(review): takes first/last endpoint in scan order — assumes
            # these are the wire's true extremities; branched skeletons with
            # >2 endpoints may be mislabeled. Verify on real data.
            start, end = endpoints[0], endpoints[-1]
            # Draw only wire-related markers (red start, blue end, yellow chord)
            cv2.circle(result_img, start, 6, (0, 0, 255), -1)
            cv2.circle(result_img, end, 6, (255, 0, 0), -1)
            cv2.line(result_img, start, end, (0, 255, 255), 2)
            wire_data = {
                "start": {"x": int(start[0]), "y": int(start[1])},
                "end": {"x": int(end[0]), "y": int(end[1])},
            }
            all_wires.append(wire_data)
    # Save the wire-detection visualization (no component boxes)
    cv2.imwrite(os.path.join(output_dir, "wire_detection_result.png"), result_img)
    return all_wires
if __name__ == "__main__":
    # Standalone demo: run wire detection on a fixed sample image.
    image_path = "/home/gqw/unet_ai/love.jpg"
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"无法读取图像: {image_path}")
    wires = detect_wires_and_endpoints(img)
    print("检测到的导线数量:", len(wires))
    for idx, wire in enumerate(wires, start=1):
        print(f"导线 #{idx}: 起点 {wire['start']}, 终点 {wire['end']}")

82
src/pre_func/yanzheng.py Normal file
View File

@ -0,0 +1,82 @@
import json
import base64
import os
from io import BytesIO
from PIL import Image, ImageDraw, ImageFont
# Ensure the output folder exists
def ensure_output_dir():
    """Create the relative "output" directory if needed and return its path."""
    output_dir = "output"
    # exist_ok avoids the race between an existence check and creation
    os.makedirs(output_dir, exist_ok=True)
    return output_dir
# 1. Load the recognition-result JSON (replace with your own file path)
json_path = "/home/gqw/unet_ai/output/recognition_result.json"
with open(json_path, "r", encoding="utf-8") as f:
    result = json.load(f)
# 2. Extract the original image from the JSON (expects a data-URL in the
#    top-level "recognizedImage" field; adjust to the actual structure)
try:
    base64_str = result["recognizedImage"].split(",")[1]  # strip the data-URL prefix
    image_bytes = base64.b64decode(base64_str)
    img = Image.open(BytesIO(image_bytes))
except KeyError:
    # Fall back to a local image when the JSON carries no embedded picture
    print("未找到recognizedImage字段使用本地图片...")
    img = Image.open("/home/gqw/unet_ai/test/5.jpg")  # replace with your original image path
# 3. Prepare drawing tools
draw = ImageDraw.Draw(img)
try:
    # Try to load a CJK-capable system font (replace with your font path)
    font = ImageFont.truetype("simhei.ttf", 18)
except OSError:
    # Narrowed from a bare `except:` — ImageFont.truetype raises OSError
    # when the font file cannot be found or read.
    font = ImageFont.load_default()
    print("警告:未找到中文字体,可能无法正常显示中文标签")
# 4. Annotate the instruments listed in `components` (core annotation logic)
# components structure: [{"label": "滑动变阻器", "bbox": [x1, y1, x2, y2]}, ...]
instrument_color = (255, 0, 0)  # red boxes for instruments
for idx, comp in enumerate(result["components"]):
    # Extract the label and bounding box
    label = comp["label"]  # instrument type, e.g. "滑动变阻器"
    bbox = comp["bbox"]  # bounding-box coordinates [x1, y1, x2, y2]
    x1, y1, x2, y2 = bbox
    # Draw the rectangle
    draw.rectangle([x1, y1, x2, y2], outline=instrument_color, width=3)
    # Keep the text label inside the image bounds
    text_x = x1 if x1 + 150 < img.width else img.width - 150
    text_y = y1 - 30 if y1 > 30 else y1 + 10
    draw.text(
        (text_x, text_y),
        f"{idx+1}. {label}",
        fill=instrument_color,
        font=font
    )
# 5. Optionally annotate wires (entries with key == "wire" in circuitData.instances)
wire_color = (0, 0, 255)  # blue boxes for wires
wires = [inst for inst in result["circuitData"]["instances"] if inst["key"] == "wire"]
for idx, wire in enumerate(wires):
    # A wire's (x, y) is its centre point; draw a small square around it
    x, y = wire["x"], wire["y"]
    size = 20  # half-size of the wire marker box (tweakable)
    x1, y1 = x - size, y - size
    x2, y2 = x + size, y + size
    draw.rectangle([x1, y1, x2, y2], outline=wire_color, width=2)
    draw.text(
        (x1, y1 - 20),
        f"导线{idx+1}",
        fill=wire_color,
        font=font
    )
# 6. Save the annotated image into the output folder
output_dir = ensure_output_dir()
output_path = os.path.join(output_dir, "circuit_with_annotations.jpg")
img.save(output_path)
print(f"标注完成!结果已保存到:{output_path}")

BIN
src/yolo/Arial.Unicode.ttf Normal file

Binary file not shown.

24
src/yolo/config.yaml Normal file
View File

@ -0,0 +1,24 @@
# data/config.yaml
path: ./dataset # 根目录prepare_dataset.py 输出)
train: /home/gqw/unet_ai/elements_wires_congition/src/yolo/dataset/images/train
val: /home/gqw/unet_ai/elements_wires_congition/src/yolo/dataset/images/val
# nc 必须精确等于 names 列表的长度
nc: 10
names:
"0": "微安电流表" #躺着的
"1": "待测表头" #站着的
"2": "电阻箱"
"3": "滑动变阻器"
"4": "单刀双掷开关"
"5": "电源"
# <- 从 original_names.json 中拷过来的前若干项(顺序不要变)
"6": "单刀单掷开关" # 代码中存在该标签映射
"7": "灯泡" # 代码中存在该标签映射
"8": "电池电源" # 代码中存在该标签映射对应battery
"9": "电阻" # 代码中存在该标签映射
# "10": "螺线管" # 代码中存在该标签映射(对应电感)
# "11": "电容" # 代码中存在该标签映射

26
src/yolo/infer.py Normal file
View File

@ -0,0 +1,26 @@
from ultralytics import YOLO
import cv2
# All paths below are hard-coded — edit them for your environment
WEIGHTS_PATH = '/home/gqw/unet_ai/elements_wires_congition/src/best_model/best9.16.pt'  # model weights path
IMAGE_PATH = '/home/gqw/unet_ai/test/love.jpg'  # image to run detection on
OUTPUT_PATH = '/home/gqw/unet_ai/elements_wires_congition/src/yolo/detect_results'  # results directory
CONFIDENCE = 0.25  # confidence threshold
SAVE_RESULTS = True  # whether to save detection results (True/False)
# NOTE(review): `cv2` is imported but unused in this script — presumably
# left over from an earlier revision; confirm before removing.
# Load the model
model = YOLO(WEIGHTS_PATH)
# Run detection
results = model(
    IMAGE_PATH,
    conf=CONFIDENCE,
    save=SAVE_RESULTS,
    project=OUTPUT_PATH,
    name='predictions'
)
# Print a summary of what was run and where the output went
print(f"检测图片:{IMAGE_PATH}")
print(f"检测完成,结果已保存至:{OUTPUT_PATH}/predictions")

View File

@ -0,0 +1,15 @@
# inspect_model_names.py
# Utility: print a YOLO checkpoint's class-id -> name mapping and dump it to JSON.
from ultralytics import YOLO
import argparse, json
parser = argparse.ArgumentParser()
parser.add_argument('--weights', '-w', default='/home/gqw/unet_ai/elements_wires_congition/src/best_model/model_2.pt')
args = parser.parse_args()
m = YOLO(args.weights)
print("模型里的类别字典 (.names):")
print(m.names)  # dict: {0: 'class0', 1: 'class1', ...}
# Also dump the mapping to a file for later reuse (e.g. by prepare_dataset.py)
with open('original_names.json', 'w', encoding='utf-8') as f:
    json.dump(m.names, f, ensure_ascii=False, indent=2)
print("保存 original_names.json")

116
src/yolo/prepare_dataset.py Normal file
View File

@ -0,0 +1,116 @@
# prepare_dataset.py
import os, shutil, random, argparse, json
from pathlib import Path
def copy_and_prefix(src_images, src_labels, out_images, out_labels, prefix):
    """Copy all images from src_images into out_images, renamed ``prefix_...``.

    The matching YOLO label file (same stem, ``.txt``) is copied alongside;
    when no label exists an empty file is created (image with no objects).
    """
    os.makedirs(out_images, exist_ok=True)
    os.makedirs(out_labels, exist_ok=True)
    allowed_suffixes = ['.jpg', '.jpeg', '.png', '.bmp']
    for img_path in Path(src_images).glob('*'):
        if img_path.suffix.lower() not in allowed_suffixes:
            continue
        dst_img = os.path.join(out_images, f"{prefix}_{img_path.name}")
        shutil.copy2(img_path, dst_img)
        # Label file shares the image's stem with a .txt suffix
        txt_name = img_path.with_suffix('.txt').name
        src_lbl = os.path.join(src_labels, txt_name)
        dst_lbl = os.path.join(out_labels, f"{prefix}_{txt_name}")
        if os.path.exists(src_lbl):
            shutil.copy2(src_lbl, dst_lbl)
        else:
            # No objects for this image: create an empty label file
            open(dst_lbl, 'w', encoding='utf-8').close()
def remap_label_file(label_path, mapping):
    """Rewrite class IDs of a YOLO label file in place.

    mapping: {old_id (str): new_id (int)}. Class IDs absent from the mapping
    are deliberately left unchanged. Missing files are skipped silently and
    blank lines are dropped; the file is rewritten without a trailing newline.
    """
    if not os.path.exists(label_path):
        return
    lines = []
    with open(label_path, 'r', encoding='utf-8') as f:
        for l in f:
            if not l.strip():
                continue
            parts = l.strip().split()
            cls = parts[0]
            if cls in mapping:
                parts[0] = str(mapping[cls])
            # else: keep the original class id as-is (removed a dead
            # self-assignment branch that did the same thing)
            lines.append(' '.join(parts))
    with open(label_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(lines))
def main():
    """Merge multiple image/label sources into a single YOLO dataset.

    Copies every source with a unique ``srcN_`` prefix, optionally remaps
    class IDs via a JSON mapping, then shuffles and splits the pool into
    train/val folders under ``--out``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--sources', nargs='+', required=True,
                        help='一组数据源,每个源是 IMAGE_DIR:LABEL_DIR例如 /data/old/images:/data/old/labels')
    parser.add_argument('--out', required=True, help='输出根目录,例如 ./dataset')
    parser.add_argument('--val', type=float, default=0.2, help='验证集比例(0~1)')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--remap', default=None,
                        help='可选的 JSON 文件,格式 {"old_id":"new_id", ...}old_id作为字符串')
    args = parser.parse_args()
    random.seed(args.seed)
    out = Path(args.out)
    imgs_all = out / 'images'
    labels_all = out / 'labels'
    imgs_all.mkdir(parents=True, exist_ok=True)
    labels_all.mkdir(parents=True, exist_ok=True)
    # Temporary staging directories, recreated fresh on every run
    temp_images = out / 'temp_images'
    temp_labels = out / 'temp_labels'
    if temp_images.exists(): shutil.rmtree(temp_images)
    if temp_labels.exists(): shutil.rmtree(temp_labels)
    temp_images.mkdir()
    temp_labels.mkdir()
    # Copy every source, prefixing file names to avoid collisions
    for i, src in enumerate(args.sources):
        if ':' not in src:
            raise ValueError("source 格式应为 IMAGE_DIR:LABEL_DIR")
        img_dir, lbl_dir = src.split(':', 1)
        prefix = f"src{i}"
        copy_and_prefix(img_dir, lbl_dir, str(temp_images), str(temp_labels), prefix)
    # Optional class-ID remapping
    mapping = {}
    if args.remap:
        with open(args.remap, 'r', encoding='utf-8') as f:
            raw = json.load(f)
        # keys expected as strings
        mapping = {str(k): int(v) for k, v in raw.items()}
    # If a remap was supplied, rewrite every staged label file
    if mapping:
        for lf in temp_labels.glob('*.txt'):
            remap_label_file(str(lf), mapping)
    # Collect all sample names.
    # FIX: include '.bmp' — copy_and_prefix accepts .bmp images, but they were
    # missing from this list and so were silently deleted with the temp dirs.
    samples = [p.name for p in temp_images.glob('*')
               if p.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']]
    random.shuffle(samples)
    val_n = int(len(samples) * args.val)
    val_samples = set(samples[:val_n])
    train_samples = set(samples[val_n:])
    # Build the final directory layout and move files into place
    for split, sset in [('train', train_samples), ('val', val_samples)]:
        img_out = out / 'images' / split
        lbl_out = out / 'labels' / split
        img_out.mkdir(parents=True, exist_ok=True)
        lbl_out.mkdir(parents=True, exist_ok=True)
        for nm in sset:
            shutil.move(str(temp_images / nm), str(img_out / nm))
            lbl_name = Path(nm).with_suffix('.txt').name
            shutil.move(str(temp_labels / lbl_name), str(lbl_out / lbl_name))
    # Clean up any leftover staging files
    if temp_images.exists(): shutil.rmtree(temp_images, ignore_errors=True)
    if temp_labels.exists(): shutil.rmtree(temp_labels, ignore_errors=True)
    print(f"完成。输出目录: {out}")
    print(f"train: {len(train_samples)}, val: {len(val_samples)}")
if __name__ == '__main__':
    main()

35
src/yolo/train_yolo.py Normal file
View File

@ -0,0 +1,35 @@
# train_yolo.py
# Fine-tune a YOLO model from pretrained weights, configured via CLI arguments.
import argparse
from ultralytics import YOLO
def main():
    """Parse CLI options, load the pretrained model, and run training."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', '-w',
                        default='/home/gqw/unet_ai/elements_wires_congition/src/best_model/best.pt',
                        help='预训练权重路径')
    parser.add_argument('--data', '-d',
                        default='/home/gqw/unet_ai/elements_wires_congition/src/yolo/config.yaml',
                        help='data yaml 路径')
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--imgsz', type=int, default=640)
    parser.add_argument('--batch', type=int, default=16)
    parser.add_argument('--device', default=0, help='GPU 设备号或 "cpu"')
    # NOTE(review): --freeze is parsed but never passed to model.train() —
    # confirm whether backbone freezing was meant to be wired up.
    parser.add_argument('--freeze', action='store_true', help='是否冻结backbone')
    args = parser.parse_args()
    # Load the pretrained model
    model = YOLO(args.weights)
    print("开始训练,参数:", args)
    model.train(
        data=args.data,
        epochs=args.epochs,
        imgsz=args.imgsz,
        batch=args.batch,
        device=args.device,
        amp=False  # disable the AMP check to avoid it auto-loading yolo11n.pt
    )
    print("训练结束,权重保存在 runs/detect/ 下")
if __name__ == '__main__':
    main()

1295
uv.lock generated Normal file

File diff suppressed because it is too large Load Diff