first commit
This commit is contained in:
commit
f9125370af
223
.gitignore
vendored
Normal file
223
.gitignore
vendored
Normal file
@ -0,0 +1,223 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
pip-wheel-metadata/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
Pipfile.lock
|
||||
|
||||
# PEP 582
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# Machine Learning / AI specific
|
||||
# Model files (you may want to track some models)
|
||||
*.pkl
|
||||
*.joblib
|
||||
*.h5
|
||||
*.hdf5
|
||||
*.pb
|
||||
*.onnx
|
||||
# Uncomment the next line if you don't want to track your trained models
|
||||
# *.pt
|
||||
|
||||
# Dataset and data files
|
||||
data/
|
||||
datasets/
|
||||
*.csv
|
||||
*.xml
|
||||
*.txt
|
||||
# Image files (uncomment if you have large image datasets)
|
||||
# *.jpg
|
||||
# *.jpeg
|
||||
# *.png
|
||||
# *.gif
|
||||
# *.bmp
|
||||
# *.tiff
|
||||
|
||||
# Logs and outputs
|
||||
logs/
|
||||
outputs/
|
||||
results/
|
||||
checkpoints/
|
||||
runs/
|
||||
wandb/
|
||||
|
||||
# TensorBoard
|
||||
events.out.tfevents.*
|
||||
|
||||
# Jupyter Notebook checkpoints
|
||||
.ipynb_checkpoints/
|
||||
|
||||
# VS Code
|
||||
.vscode/
|
||||
*.code-workspace
|
||||
|
||||
# PyCharm
|
||||
.idea/
|
||||
*.iws
|
||||
*.iml
|
||||
*.ipr
|
||||
|
||||
# macOS
|
||||
.DS_Store
|
||||
.AppleDouble
|
||||
.LSOverride
|
||||
Icon?
|
||||
._*
|
||||
.DocumentRevisions-V100
|
||||
.fseventsd
|
||||
.Spotlight-V100
|
||||
.TemporaryItems
|
||||
.Trashes
|
||||
.VolumeIcon.icns
|
||||
.com.apple.timemachine.donotpresent
|
||||
.AppleDB
|
||||
.AppleDesktop
|
||||
Network Trash Folder
|
||||
Temporary Items
|
||||
.apdisk
|
||||
|
||||
# Windows
|
||||
Thumbs.db
|
||||
Thumbs.db:encryptable
|
||||
ehthumbs.db
|
||||
ehthumbs_vista.db
|
||||
*.tmp
|
||||
*.temp
|
||||
Desktop.ini
|
||||
$RECYCLE.BIN/
|
||||
*.cab
|
||||
*.msi
|
||||
*.msix
|
||||
*.msm
|
||||
*.msp
|
||||
*.lnk
|
||||
|
||||
# Linux
|
||||
*~
|
||||
.fuse_hidden*
|
||||
.directory
|
||||
.Trash-*
|
||||
.nfs*
|
||||
|
||||
# Temporary files
|
||||
*.tmp
|
||||
*.temp
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# Project specific
|
||||
# Add any project-specific files you want to ignore here
|
||||
test_images/
|
||||
temp/
|
||||
cache/
|
||||
26
circuit-recognition.service
Normal file
26
circuit-recognition.service
Normal file
@ -0,0 +1,26 @@
|
||||
[Unit]
|
||||
Description=Circuit Recognition Web Service
|
||||
After=network.target
|
||||
Wants=network.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
User=root
|
||||
Group=root
|
||||
WorkingDirectory=/home/feie9454/elements_wires_congition
|
||||
Environment=PATH=/home/feie9454/elements_wires_congition/venv/bin
|
||||
ExecStart=/home/feie9454/elements_wires_congition/venv/bin/python main.py
|
||||
Restart=always
|
||||
RestartSec=10
|
||||
StandardOutput=journal
|
||||
StandardError=journal
|
||||
SyslogIdentifier=circuit-recognition
|
||||
|
||||
# 安全设置
|
||||
NoNewPrivileges=true
|
||||
PrivateTmp=true
|
||||
ProtectSystem=strict
|
||||
ReadWritePaths=/home/feie9454/elements_wires_congition
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
80
main.py
Normal file
80
main.py
Normal file
@ -0,0 +1,80 @@
|
||||
from flask import Flask, request, jsonify, send_from_directory, send_file
|
||||
from flask_cors import CORS
|
||||
import numpy as np
|
||||
import cv2
|
||||
import os
|
||||
from src.pre_func.img_prec import img_recognition
|
||||
|
||||
# Static file directory: the front-end lives in src/web and is exposed under /static.
app = Flask(
    __name__,
    static_folder='src/web',
    static_url_path='/static',
)
# Apply flask-cors to the application with its default settings.
CORS(app)
|
||||
|
||||
|
||||
@app.route('/')
def index():
    """Home page route: serve index.html from the static folder."""
    home_page = 'index.html'
    return send_from_directory(app.static_folder, home_page)
|
||||
|
||||
|
||||
@app.route('/web')
def web_index():
    """Web application entry point: serve index.html from the static folder."""
    entry_page = 'index.html'
    return send_from_directory(app.static_folder, entry_page)
|
||||
|
||||
|
||||
@app.route('/web/<path:filename>')
def web_static(filename):
    """Serve the file named by the URL path below /web/ from the static folder."""
    static_root = app.static_folder
    return send_from_directory(static_root, filename)
|
||||
|
||||
|
||||
@app.route('/static/<path:filename>')
def static_files(filename):
    """Serve the file named by the URL path below /static/ from the static folder.

    NOTE(review): the app is created with static_url_path='/static', so Flask
    presumably already registers an equivalent built-in static route -- confirm
    whether this explicit view is still needed.
    """
    static_root = app.static_folder
    return send_from_directory(static_root, filename)
|
||||
|
||||
|
||||
@app.route('/api/status')
def api_status():
    """API status check: report that the service is up and list its endpoints."""
    endpoints = {
        'POST /process_image': 'Image recognition endpoint',
        'GET /': 'Main web interface',
        'GET /web': 'Web application',
        'GET /api/status': 'API status check',
    }
    payload = {
        'status': 'running',
        'message': 'Elements Wires Recognition API is running',
        'endpoints': endpoints,
    }
    return jsonify(payload)
|
||||
|
||||
|
||||
@app.route('/process_image', methods=['POST'])
def process_image():
    """Run circuit recognition on an uploaded image.

    Expects a multipart form upload with the file in the ``image`` field.

    Returns:
        The value produced by ``img_recognition`` on success, or a JSON error
        object with HTTP status 400 when the upload is missing, unnamed,
        empty, or cannot be decoded as an image.
    """
    upload = request.files.get('image')
    if upload is None:
        return jsonify({'error': 'No image part in the request'}), 400

    if upload.filename == '':
        return jsonify({'error': 'No selected image'}), 400

    file_bytes = np.frombuffer(upload.read(), np.uint8)
    # An empty buffer makes cv2.imdecode fail, so reject it up front.
    if file_bytes.size == 0:
        return jsonify({'error': 'Uploaded image is empty'}), 400

    img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    # cv2.imdecode signals undecodable data by returning None rather than raising;
    # without this check the failure would surface deep inside the recognition code.
    if img is None:
        return jsonify({'error': 'Uploaded file is not a valid image'}), 400

    result = img_recognition(img)
    return result
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Single source of truth for the port so the banner cannot drift from app.run().
    port = 12399
    # Debug mode enables the Werkzeug interactive debugger, which lets anyone who
    # can reach the server execute arbitrary code. The server listens on all
    # interfaces, so it is off unless explicitly requested via FLASK_DEBUG.
    debug = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')

    print("🚀 Starting Elements Wires Recognition Web Server...")
    print("📁 Static files served from: src/web/")
    print("🌐 Web interface available at:")
    print(f" - http://localhost:{port}/")
    print(f" - http://0.0.0.0:{port}/")
    print("🔧 API endpoints:")
    print(" - POST /process_image (for image recognition)")
    print(" - GET /web/<filename> (for static files)")
    print(" - GET /static/<filename> (for static files)")
    print("=" * 50)

    app.run(host='0.0.0.0', debug=debug, port=port)
|
||||
63
manage_service.sh
Normal file
63
manage_service.sh
Normal file
@ -0,0 +1,63 @@
|
||||
#!/bin/bash

# Circuit Recognition Service Management Script
# Usage: ./manage_service.sh [install|start|stop|restart|status|logs|uninstall]

SERVICE_NAME="circuit-recognition"
# Locate the unit file next to this script instead of assuming one user's home directory.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
SERVICE_FILE="$SCRIPT_DIR/$SERVICE_NAME.service"
SYSTEM_SERVICE_PATH="/etc/systemd/system/$SERVICE_NAME.service"

case "$1" in
    install)
        echo "安装服务..."
        # Stop at the first failing step so the success message is only printed
        # when the unit was really copied, loaded and enabled.
        sudo cp "$SERVICE_FILE" "$SYSTEM_SERVICE_PATH" || exit 1
        sudo systemctl daemon-reload || exit 1
        sudo systemctl enable "$SERVICE_NAME" || exit 1
        echo "✅ 服务已安装并设置为开机自启"
        echo "使用 './manage_service.sh start' 启动服务"
        ;;
    start)
        echo "启动服务..."
        sudo systemctl start "$SERVICE_NAME" || exit 1
        echo "✅ 服务已启动"
        ;;
    stop)
        echo "停止服务..."
        sudo systemctl stop "$SERVICE_NAME" || exit 1
        echo "✅ 服务已停止"
        ;;
    restart)
        echo "重启服务..."
        sudo systemctl restart "$SERVICE_NAME" || exit 1
        echo "✅ 服务已重启"
        ;;
    status)
        echo "服务状态:"
        sudo systemctl status "$SERVICE_NAME"
        ;;
    logs)
        echo "查看服务日志:"
        # -f follows the journal until interrupted.
        sudo journalctl -u "$SERVICE_NAME" -f
        ;;
    uninstall)
        echo "卸载服务..."
        # Best effort: stop/disable may fail when the service is not running or
        # not installed; removal should still proceed.
        sudo systemctl stop "$SERVICE_NAME"
        sudo systemctl disable "$SERVICE_NAME"
        sudo rm -f "$SYSTEM_SERVICE_PATH" || exit 1
        sudo systemctl daemon-reload || exit 1
        echo "✅ 服务已卸载"
        ;;
    *)
        echo "使用方法: $0 {install|start|stop|restart|status|logs|uninstall}"
        echo ""
        echo "命令说明:"
        echo " install - 安装服务并设置开机自启"
        echo " start - 启动服务"
        echo " stop - 停止服务"
        echo " restart - 重启服务"
        echo " status - 查看服务状态"
        echo " logs - 实时查看服务日志"
        echo " uninstall - 卸载服务"
        exit 1
        ;;
esac
|
||||
15
pyproject.toml
Normal file
15
pyproject.toml
Normal file
@ -0,0 +1,15 @@
|
||||
[project]
|
||||
name = "circuit-recognition"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"albumentations>=2.0.8",
|
||||
"flask>=3.1.1",
|
||||
"flask-cors>=6.0.1",
|
||||
"numpy>=2.3.0",
|
||||
"opencv-python>=4.11.0.86",
|
||||
"scikit-image>=0.25.2",
|
||||
"ultralytics>=8.3.156",
|
||||
]
|
||||
1
src/__init__.py
Normal file
1
src/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# 这个文件使 src 目录成为一个 Python 包
|
||||
BIN
src/best_model/best9.17.pt
Normal file
BIN
src/best_model/best9.17.pt
Normal file
Binary file not shown.
BIN
src/best_model/unet_wire_red.pth
Normal file
BIN
src/best_model/unet_wire_red.pth
Normal file
Binary file not shown.
1
src/pre_func/__init__.py
Normal file
1
src/pre_func/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# 这个文件使 pre_func 目录成为一个 Python 包
|
||||
53
src/pre_func/color.py
Normal file
53
src/pre_func/color.py
Normal file
@ -0,0 +1,53 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from sklearn.cluster import KMeans
|
||||
from tensorflow.keras.models import load_model
|
||||
|
||||
# =============================
# 1. Load the trained U-Net model
# =============================
# NOTE(review): the checkpoint has a PyTorch-style ".pth" extension but is loaded
# with Keras' load_model -- confirm the file really is a Keras model.
model = load_model("/home/gqw/unet_ai/elements_wires_congition/src/pre_func/unet_wire.pth", compile=False)

# Read the test image. cv2.imread returns None (it does not raise) when the file
# is missing or unreadable, so fail here with a clear message.
image_path = "/home/gqw/unet_ai/test/IMG_0059.jpeg"
img = cv2.imread(image_path)
if img is None:
    raise FileNotFoundError(f"Cannot read image: {image_path}")
h, w, _ = img.shape
# Scale to the 256x256 network input, normalise to [0, 1] and add a batch axis.
input_img = cv2.resize(img, (256, 256)) / 255.0
input_img = np.expand_dims(input_img, axis=0)

# =============================
# 2. U-Net segmentation (binary wire mask)
# =============================
pred = model.predict(input_img)[0, :, :, 0]
mask = (pred > 0.5).astype(np.uint8)  # threshold the prediction
mask = cv2.resize(mask, (w, h))  # back to the original image size

# =============================
# 3. Collect the colours of the wire pixels
# =============================
wire_pixels = img[mask == 1]  # N x 3, in OpenCV's BGR channel order

# KMeans clustering (assumes the wires come in 3 colours)
n_colors = 3
# KMeans cannot form more clusters than there are samples; report that case
# explicitly instead of failing inside scikit-learn.
if len(wire_pixels) < n_colors:
    raise ValueError(
        f"Only {len(wire_pixels)} wire pixel(s) found; need at least {n_colors} to cluster"
    )
kmeans = KMeans(n_clusters=n_colors, random_state=0).fit(wire_pixels)
labels = kmeans.labels_

# =============================
# 4. Build one mask per colour cluster
# =============================
colored_masks = []
for i in range(n_colors):
    color_mask = np.zeros(mask.shape, dtype=np.uint8)
    # labels is aligned with the pixels selected by mask == 1.
    color_mask[mask == 1] = (labels == i).astype(np.uint8)
    colored_masks.append(color_mask)

# =============================
# 5. Visualise the result
# =============================
result = img.copy()
colors = [(0, 0, 255), (0, 255, 0), (255, 0, 0)]  # paint colour per cluster (BGR)
for i, cmask in enumerate(colored_masks):
    result[cmask == 1] = colors[i % len(colors)]

cv2.imwrite("segmented_by_color.jpg", result)

print("输出完成:segmented_by_color.jpg")
|
||||
5
src/pre_func/config.py
Normal file
5
src/pre_func/config.py
Normal file
@ -0,0 +1,5 @@
|
||||
# Model and path configuration
import os

# Resolve paths relative to this file (src/pre_func/config.py) so the project
# works from any checkout location instead of one developer's home directory.
_SRC_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

YOLO_MODEL_PATH = os.path.join(_SRC_DIR, 'best_model', 'best9.17.pt')  # migrated from infer.py
DETECTION_CONFIDENCE = 0.5  # migrated from infer.py (tune as needed)
OUTPUT_BASE_DIR = os.path.join(_SRC_DIR, 'yolo', 'detect_results')  # migrated from infer.py
WIRE_CONNECT_THRESHOLD = 50  # max allowed distance between a wire endpoint and a connection point (pixels)
|
||||
83
src/pre_func/ele_recog.py
Normal file
83
src/pre_func/ele_recog.py
Normal file
@ -0,0 +1,83 @@
|
||||
import cv2
|
||||
import os
|
||||
from ultralytics import YOLO
|
||||
from config import YOLO_MODEL_PATH, DETECTION_CONFIDENCE
|
||||
from utils import ensure_output_dir # 导入工具函数
|
||||
|
||||
|
||||
def elements_recognition(img):
    """Detect circuit components in *img* with the YOLO model.

    The image is resized to a width of 1000 px (aspect ratio kept) before
    inference, so the returned boxes are in resized-image pixel coordinates.
    An annotated copy is written as "yolo_recognition.png" into the directory
    returned by ``ensure_output_dir()``.

    Args:
        img: Image array accepted by ``cv2.resize`` (presumably BGR, as decoded
            by OpenCV -- confirm against callers).

    Returns:
        A list of dicts ``{"label": <model class name>, "bbox": [x1, y1, x2, y2]}``.
    """
    # The weights are loaded on every call.
    # NOTE(review): caching the model would be faster, but sharing one model
    # instance across threads may be unsafe -- confirm how this is served first.
    model = YOLO(YOLO_MODEL_PATH)

    # Keep the same scaling as the wire/visualisation code so coordinates line up.
    scaled_height = int(img.shape[0] * 1000 / img.shape[1])
    img = cv2.resize(img, (1000, scaled_height))

    # Run inference; only the first (single-image) result is used.
    results = model(img, conf=DETECTION_CONFIDENCE)[0]
    components = []

    # Image that carries only the YOLO detections.
    yolo_img = img.copy()

    # Model label -> element key (kept in sync with img_prec.generate_json).
    labels = {
        "微安电流表": "ammeter",
        "待测表头(电流表)": "ammeter",
        "黑盒电流表": "ammeter",
        "待测表头": "voltmeter",
        "黑盒电压表": "voltmeter",
        "电阻箱": "resistance_box",
        "滑动变阻器": "sliding_rheostat",
        "单刀双掷开关": "switch",
        "单刀单掷开关": "switch",
        "灯泡": "light_bulb",
        "电源": "power_supply",
        "电池电源": "battery",
        "电阻": "resistor",
        "螺线管": "inductor",
        "电容": "capacitor",
    }

    # Per-element drawing colour, in OpenCV's (B, G, R) order. Several entries
    # were previously written in RGB order and did not match their named colour.
    color_map = {
        "ammeter": (0, 255, 0),           # green
        "voltmeter": (255, 0, 0),         # blue
        "resistance_box": (0, 255, 255),  # yellow
        "sliding_rheostat": (128, 0, 128),  # purple
        "switch": (0, 165, 255),          # orange
        "light_bulb": (255, 255, 0),      # cyan
        "power_supply": (0, 0, 255),      # red
        "battery": (0, 128, 128),         # olive
        "resistor": (128, 128, 0),        # teal
        "inductor": (0, 0, 128),          # dark red
        "capacitor": (128, 0, 0),         # dark blue
    }

    for box in results.boxes:
        if box.conf < DETECTION_CONFIDENCE:
            continue
        cls = int(box.cls[0])
        original_label = model.names[cls]
        # Unmapped labels are shown as-is. NOTE: cv2.putText only renders ASCII,
        # so an unmapped non-ASCII label will not be legible in the image.
        display_label = labels.get(original_label, original_label)
        x1, y1, x2, y2 = map(int, box.xyxy[0])

        components.append({
            "label": original_label,
            "bbox": [x1, y1, x2, y2],
        })

        # Grey for element types without a dedicated colour.
        box_color = color_map.get(display_label, (128, 128, 128))

        cv2.rectangle(yolo_img, (x1, y1), (x2, y2), box_color, 2)
        cv2.putText(yolo_img, display_label, (x1, y1 - 8),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, box_color, 2)

    # Save the annotated detection image.
    output_dir = ensure_output_dir()
    save_path = os.path.join(output_dir, "yolo_recognition.png")
    cv2.imwrite(save_path, yolo_img)
    print(f"YOLO识别结果已保存到: {save_path}")

    return components
|
||||
484
src/pre_func/img_prec.py
Normal file
484
src/pre_func/img_prec.py
Normal file
@ -0,0 +1,484 @@
|
||||
import datetime
|
||||
import wires_recog
|
||||
import ele_recog
|
||||
import cv2
|
||||
import numpy as np
|
||||
import base64
|
||||
import os
|
||||
from utils import ensure_output_dir # 导入工具函数
|
||||
|
||||
# 从 elements.ts 转换过来的 Python 结构(包含电流表、电压表独立定义)
|
||||
elements_ts = [
|
||||
# 文本框
|
||||
{
|
||||
"key": "text_box",
|
||||
"name": "文本框",
|
||||
"defaultSize": 0,
|
||||
"connectionPoints": [],
|
||||
"propertySchemas": [
|
||||
{"key": "text", "label": "文本", "type": "text", "default": "双击编辑"},
|
||||
{"key": "fontSize", "label": "字体大小", "type": "number", "unit": "px", "default": 24},
|
||||
{"key": "color", "label": "颜色", "type": "color", "default": "#111827"}
|
||||
]
|
||||
},
|
||||
# 电池
|
||||
{
|
||||
"key": "battery",
|
||||
"name": "电池",
|
||||
"defaultSize": 110,
|
||||
"connectionPoints": [
|
||||
{"x": 0.95, "y": 0.5, "name": "正极"},
|
||||
{"x": 0.05, "y": 0.5, "name": "负极"}
|
||||
],
|
||||
"propertySchemas": [
|
||||
{"key": "voltage", "label": "电压", "type": "number", "unit": "V", "default": 3},
|
||||
{"key": "internalResistance", "label": "内阻", "type": "number", "unit": "Ω", "default": 0.5}
|
||||
]
|
||||
},
|
||||
# 电源
|
||||
{
|
||||
"key": "power_supply",
|
||||
"name": "电源",
|
||||
"defaultSize": 180,
|
||||
"connectionPoints": [
|
||||
{"x": 0.8, "y": 0.8, "name": "+"},
|
||||
{"x": 0.85, "y": 0.2, "name": "-"}
|
||||
],
|
||||
"propertySchemas": [
|
||||
{"key": "voltage", "label": "电压", "type": "number", "unit": "V", "default": 5},
|
||||
{"key": "internalResistance", "label": "内阻", "type": "number", "unit": "Ω", "default": 0.1}
|
||||
]
|
||||
},
|
||||
# 开关
|
||||
{
|
||||
"key": "switch",
|
||||
"name": "开关",
|
||||
"defaultSize": 130,
|
||||
"connectionPoints": [
|
||||
{"x": 0.05, "y": 0.7, "name": "A"},
|
||||
{"x": 0.95, "y": 0.7, "name": "B"}
|
||||
],
|
||||
"stateImages": {"on": "switchOn", "off": "switchOff"},
|
||||
"propertySchemas": [
|
||||
{"key": "state", "label": "状态", "type": "select", "options": ["off", "on"], "default": "off"}
|
||||
]
|
||||
},
|
||||
# 灯泡
|
||||
{
|
||||
"key": "light_bulb",
|
||||
"name": "灯泡",
|
||||
"defaultSize": 80,
|
||||
"connectionPoints": [
|
||||
{"x": 0.5, "y": 1, "name": "A"},
|
||||
{"x": 0.72, "y": 0.8, "name": "B"}
|
||||
],
|
||||
"propertySchemas": [
|
||||
{"key": "resistance", "label": "电阻", "type": "number", "unit": "Ω", "default": 20}
|
||||
],
|
||||
"stateImages": {"on": "lightBulbGrow", "off": "lightBulb"}
|
||||
},
|
||||
# 滑动变阻器
|
||||
{
|
||||
"key": "sliding_rheostat",
|
||||
"name": "滑动变阻器",
|
||||
"defaultSize": 180,
|
||||
"connectionPoints": [
|
||||
{"x": 0.15, "y": 0.28, "name": "A"},
|
||||
{"x": 0.85, "y": 0.76, "name": "D"}
|
||||
],
|
||||
"propertySchemas": [
|
||||
{"key": "maxResistance", "label": "最大电阻", "type": "number", "unit": "Ω", "default": 100},
|
||||
{"key": "position", "label": "滑块位置", "type": "number", "default": 0.5}
|
||||
]
|
||||
},
|
||||
# 普通电阻
|
||||
{
|
||||
"key": "resistor",
|
||||
"name": "电阻",
|
||||
"defaultSize": 110,
|
||||
"connectionPoints": [
|
||||
{"x": 0.05, "y": 0.5, "name": "A"},
|
||||
{"x": 0.95, "y": 0.5, "name": "B"}
|
||||
],
|
||||
"propertySchemas": [
|
||||
{"key": "resistance", "label": "电阻", "type": "number", "unit": "Ω", "default": 5}
|
||||
]
|
||||
},
|
||||
# 电阻箱
|
||||
{
|
||||
"key": "resistance_box",
|
||||
"name": "电阻箱",
|
||||
"defaultSize": 180,
|
||||
"connectionPoints": [
|
||||
{"x": 0.1, "y": 0.4, "name": "A"},
|
||||
{"x": 0.8, "y": 0.2, "name": "B"}
|
||||
],
|
||||
"propertySchemas": [
|
||||
{"key": "resistance", "label": "电阻", "type": "number", "unit": "Ω", "default": 100}
|
||||
]
|
||||
},
|
||||
# 电流表(独立定义,替代原meter的电流表预设)
|
||||
{
|
||||
"key": "ammeter",
|
||||
"name": "电流表",
|
||||
"defaultSize": 150,
|
||||
"connectionPoints": [
|
||||
{"x": 0.25, "y": 0.8, "name": "正极"},
|
||||
{"x": 0.85, "y": 0.8, "name": "负极"}
|
||||
],
|
||||
"preset": [
|
||||
{"name": "微安电流表", "propertyValues": {"resistance": 0.05, "renderFunc": "(i) => `${(i * 1000000).toFixed(0)} uA`"}},
|
||||
{"name": "毫安电流表", "propertyValues": {"resistance": 0.5, "renderFunc": "(i) => `${(i * 1000).toFixed(1)} mA`"}},
|
||||
{"name": "电流表", "propertyValues": {"resistance": 0.1, "renderFunc": "(i) => `${(i).toFixed(2)} A`"}}
|
||||
],
|
||||
"propertySchemas": [
|
||||
{"key": "resistance", "label": "电阻", "type": "number", "unit": "Ω", "default": 0.1},
|
||||
{"key": "renderFunc", "label": "显示函数", "type": "text", "default": "(i) => `${(i).toFixed(2)} A`"}
|
||||
]
|
||||
},
|
||||
# 电压表(独立定义,替代原meter的电压表预设)
|
||||
{
|
||||
"key": "voltmeter",
|
||||
"name": "电压表",
|
||||
"defaultSize": 150,
|
||||
"connectionPoints": [
|
||||
{"x": 0.25, "y": 1, "name": "正极"},
|
||||
{"x": 0.45, "y": 1, "name": "负极"}
|
||||
],
|
||||
"preset": [
|
||||
{"name": "电压表", "propertyValues": {"resistance": 10000000, "renderFunc": "(i, r) => `${(i * r).toFixed(2)} V`"}}
|
||||
],
|
||||
"propertySchemas": [
|
||||
{"key": "resistance", "label": "电阻", "type": "number", "unit": "Ω", "default": 10000000},
|
||||
{"key": "renderFunc", "label": "显示函数", "type": "text", "default": "(i, r) => `${(i * r).toFixed(2)} V`"}
|
||||
]
|
||||
},
|
||||
# 电容
|
||||
{
|
||||
"key": "capacitor",
|
||||
"name": "电容",
|
||||
"defaultSize": 120,
|
||||
"connectionPoints": [
|
||||
{"x": 0.4, "y": 0.9, "name": "正极"},
|
||||
{"x": 0.6, "y": 0.96, "name": "负极"}
|
||||
],
|
||||
"propertySchemas": [
|
||||
{"key": "capacitance", "label": "电容", "type": "number", "unit": "F", "default": 0.000001}
|
||||
]
|
||||
},
|
||||
# 电感
|
||||
{
|
||||
"key": "inductor",
|
||||
"name": "电感",
|
||||
"defaultSize": 100,
|
||||
"connectionPoints": [
|
||||
{"x": 0.05, "y": 0.5, "name": "A"},
|
||||
{"x": 0.95, "y": 0.5, "name": "B"}
|
||||
],
|
||||
"propertySchemas": [
|
||||
{"key": "inductance", "label": "电感", "type": "number", "unit": "H", "default": 0.001}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
def _element_spec(key):
    """Return the elements_ts entry whose "key" equals *key*, or None if absent."""
    return next((elem for elem in elements_ts if elem["key"] == key), None)


def _nearest_connection_point(point, component_connection_points):
    """Return (component id, connection-point index) nearest to *point*, or None.

    *point* is an (x, y) pair. Ties keep the first candidate encountered, and
    None is returned when no component has any connection point.
    """
    best_match = None
    best_dist = float('inf')
    for comp_id, cp_list in component_connection_points.items():
        for cp in cp_list:
            # Euclidean distance between the wire endpoint and the connection point.
            dist = np.hypot(cp["x"] - point[0], cp["y"] - point[1])
            if dist < best_dist:
                best_dist = dist
                best_match = (comp_id, cp["cpIndex"])
    return best_match


def generate_json(wires, components):
    """Build the circuit description from detected wires and components.

    Args:
        wires: Iterable of dicts with "start" and "end" entries, each a dict
            holding "x" and "y" coordinates.
        components: Iterable of dicts with "label" (model class name) and
            "bbox" ([x1, y1, x2, y2]).

    Returns:
        A dict with "version", "world" and "instances". Components come first
        with ids "1".."N"; wires follow with consecutive ids and are attached
        to the nearest component connection point at each end.
    """
    # Model label -> element key used in elements_ts.
    labels = {
        "微安电流表": "ammeter", "待测表头(电流表)": "ammeter", "黑盒电流表": "ammeter",
        "待测表头": "voltmeter", "黑盒电压表": "voltmeter",
        "电阻箱": "resistance_box", "滑动变阻器": "sliding_rheostat",
        "单刀双掷开关": "switch", "单刀单掷开关": "switch", "灯泡": "light_bulb",
        "电源": "power_supply", "电池电源": "battery", "电阻": "resistor",
        "螺线管": "inductor", "电容": "capacitor"
    }

    instances = []
    # component id -> [{"x": pixel x, "y": pixel y, "cpIndex": index}, ...]
    component_connection_points = {}

    # --- Components -------------------------------------------------------
    for comp_number, comp in enumerate(components, start=1):
        bbox = comp["bbox"]
        original_label = comp["label"]
        label = labels.get(original_label, original_label)
        comp_id = str(comp_number)

        # Unknown element types fall back to size 100 and no connection points.
        spec = _element_spec(label) or {}
        default_size = spec.get("defaultSize", 100)
        conn_points = spec.get("connectionPoints", [])

        # Connection points are stored as fractions of the bounding box.
        comp_width = bbox[2] - bbox[0]
        comp_height = bbox[3] - bbox[1]
        component_connection_points[comp_id] = [
            {
                "x": bbox[0] + cp["x"] * comp_width,
                "y": bbox[1] + cp["y"] * comp_height,
                "cpIndex": cp_idx,
            }
            for cp_idx, cp in enumerate(conn_points)
        ]

        instances.append({
            "id": comp_id,
            "key": label,
            "x": (bbox[0] + bbox[2]) / 2,  # bounding-box centre
            "y": (bbox[1] + bbox[3]) / 2,
            "size": default_size,
            "rotation": 0,
            "props": {
                "__connections": {},
                "originalLabel": original_label
            }
        })

    # --- Wires ------------------------------------------------------------
    wire_id = len(instances) + 1
    for wire in wires:
        wire_start = (wire["start"]["x"], wire["start"]["y"])
        wire_end = (wire["end"]["x"], wire["end"]["y"])

        start_match = _nearest_connection_point(wire_start, component_connection_points)
        end_match = _nearest_connection_point(wire_end, component_connection_points)

        # No connection point exists anywhere: fall back to the first/last
        # instance (legacy behaviour kept for compatibility), or to the wire's
        # own ids when nothing has been emitted yet.
        if not start_match or not end_match:
            start_match = (instances[0]["id"], 0) if instances else (str(wire_id), 0)
            end_match = (instances[-1]["id"], 0) if instances else (str(wire_id + 1), 0)

        instances.append({
            "id": str(wire_id),
            "key": "wire",
            "x": (wire_start[0] + wire_end[0]) / 2,
            "y": (wire_start[1] + wire_end[1]) / 2,
            "size": 1,
            "rotation": 0,
            "props": {
                "__connections": {
                    "0": [{"instId": start_match[0], "cpIndex": start_match[1]}],
                    "1": [{"instId": end_match[0], "cpIndex": end_match[1]}]
                }
            }
        })
        wire_id += 1

    return {
        "version": 2,
        "world": {
            "scale": 1,
            "translateX": -2089,
            "translateY": -2123,
            "worldSize": 5000,
            "gridSize": 40
        },
        "instances": instances
    }
|
||||
|
||||
def visualize_wires_and_components(image, results, components):
    """Draw recognised components and wire connections and return the picture.

    The image is resized to a width of 1000 px (aspect ratio kept) so that it
    matches the coordinate space used for component detection. The annotated
    image is also written to "output/recognized_result.jpg" relative to the
    current working directory.

    Args:
        image: Source image array.
        results: Dict as returned by ``generate_json`` (its "instances" list is read).
        components: List of dicts with "label" and "bbox" ([x1, y1, x2, y2]).

    Returns:
        The annotated image, JPEG-encoded and base64-encoded, as a str.

    Raises:
        FileNotFoundError: If *image* is None.
    """
    # Validate before resizing: cv2.resize itself fails on a missing image, so a
    # check placed after it can never fire.
    if image is None:
        raise FileNotFoundError("无法读取图像")

    img = cv2.resize(image, (1000, int(image.shape[0] * 1000 / image.shape[1])))

    # === Step 1: map component ids to their geometry ===
    # Model label -> element key; kept identical to the map in generate_json so
    # that both functions resolve the same connection points for a component.
    labels = {
        "微安电流表": "ammeter",
        "待测表头(电流表)": "ammeter",
        "黑盒电流表": "ammeter",
        "待测表头": "voltmeter",
        "黑盒电压表": "voltmeter",
        "电阻箱": "resistance_box",
        "滑动变阻器": "sliding_rheostat",
        "单刀双掷开关": "switch",
        "单刀单掷开关": "switch",
        "灯泡": "light_bulb",
        "电源": "power_supply",
        "电池电源": "battery",
        "电阻": "resistor",
        "螺线管": "inductor",
        "电容": "capacitor"
    }

    comp_id_to_info = {}
    for idx, comp in enumerate(components):
        bbox = comp["bbox"]
        comp_type = labels.get(comp["label"], comp["label"])

        # Relative connection points for this element type (none if unknown).
        conn_points = []
        for elem in elements_ts:
            if elem["key"] == comp_type:
                conn_points = elem["connectionPoints"]
                break

        # Convert bbox-relative fractions to integer pixel coordinates.
        comp_width = bbox[2] - bbox[0]
        comp_height = bbox[3] - bbox[1]
        points = [
            {
                "x": int(bbox[0] + (p["x"] * comp_width)),
                "y": int(bbox[1] + (p["y"] * comp_height)),
                "name": p["name"],
            }
            for p in conn_points
        ]

        comp_id_to_info[str(idx + 1)] = {
            "bbox": bbox,
            "conn_points": points,
            "type": comp_type,
            "displayLabel": comp_type
        }

    # === Step 2: draw connection points (red dots) ===
    # NOTE: cv2.putText renders ASCII only, so non-ASCII point names will not
    # be legible in the output image.
    for comp_info in comp_id_to_info.values():
        for p in comp_info["conn_points"]:
            cv2.circle(img, (p["x"], p["y"]), 6, (0, 0, 255), -1)
            cv2.putText(img, p["name"], (p["x"]+5, p["y"]-5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 255), 1)

    # === Step 3: draw wires (yellow line, red start, blue end) ===
    for instance in results["instances"]:
        if instance["key"] != "wire":
            continue
        conn_info = instance["props"]["__connections"]
        if "0" not in conn_info or "1" not in conn_info:
            continue

        start_conn = conn_info["0"][0]
        end_conn = conn_info["1"][0]
        start_info = comp_id_to_info.get(start_conn["instId"])
        end_info = comp_id_to_info.get(end_conn["instId"])

        # Skip wires that reference a component we did not draw.
        if start_info is None or end_info is None:
            continue

        start_points = start_info["conn_points"]
        end_points = end_info["conn_points"]
        start_cp_idx = start_conn["cpIndex"]
        end_cp_idx = end_conn["cpIndex"]
        # generate_json's fallback can point at connection point 0 of a component
        # that has none; skip such wires instead of raising IndexError.
        if not (0 <= start_cp_idx < len(start_points)
                and 0 <= end_cp_idx < len(end_points)):
            continue

        start_point = start_points[start_cp_idx]
        end_point = end_points[end_cp_idx]
        cv2.line(img, (start_point["x"], start_point["y"]),
                 (end_point["x"], end_point["y"]), (0, 255, 255), 2)
        cv2.circle(img, (start_point["x"], start_point["y"]), 6, (0, 0, 255), -1)
        cv2.circle(img, (end_point["x"], end_point["y"]), 6, (255, 0, 0), -1)
        cv2.putText(img, "start", (start_point["x"]+5, start_point["y"]-5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 255), 1)
        cv2.putText(img, "end", (end_point["x"]+5, end_point["y"]-5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 0, 0), 1)

    # === Step 4: draw component boxes and labels ===
    for comp_info in comp_id_to_info.values():
        bbox = comp_info["bbox"]
        display_label = comp_info["displayLabel"]

        # Ammeters green, voltmeters blue, everything else dark green (BGR).
        if display_label == "ammeter":
            box_color = (0, 255, 0)
        elif display_label == "voltmeter":
            box_color = (255, 0, 0)
        else:
            box_color = (0, 128, 0)

        cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[2], bbox[3]), box_color, 2)
        cv2.putText(img, display_label, (bbox[0], bbox[1] - 8),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, box_color, 2)

    # Save the annotated image. exist_ok avoids the check-then-create race of a
    # separate os.path.exists() test.
    output_dir = "output"
    os.makedirs(output_dir, exist_ok=True)
    save_path = os.path.join(output_dir, "recognized_result.jpg")
    cv2.imwrite(save_path, img)
    print(f"可视化结果已保存到: {save_path}")

    # Encode as JPEG, then base64, for embedding in the JSON response.
    _, encode_img = cv2.imencode('.jpg', img)
    img_base64 = base64.b64encode(encode_img).decode('utf-8')
    return img_base64
|
||||
|
||||
|
||||
def img_recognition(img):
    """Run the full recognition pipeline on *img* and assemble the response.

    Steps: detect wires, detect components, build the circuit JSON, render the
    annotated image.

    Returns:
        A dict with "success", "recognizedImage" (a JPEG data URI),
        "circuitData" (output of generate_json) and "components" (output of
        ele_recog.elements_recognition).
    """
    # 1. Wires and their endpoints.
    detected_wires = wires_recog.detect_wires_and_endpoints(img)
    # 2. Components, labelled with the model's original class names.
    detected_components = ele_recog.elements_recognition(img)
    # 3. Circuit description.
    circuit = generate_json(detected_wires, detected_components)
    # 4. Annotated picture as a base64 string.
    annotated_b64 = visualize_wires_and_components(img, circuit, detected_components)

    return {
        "success": True,
        "recognizedImage": "data:image/jpeg;base64," + annotated_b64,
        "circuitData": circuit,
        "components": detected_components,
    }
|
||||
|
||||
|
||||
# 测试入口(可选)
|
||||
# if __name__ == '__main__':
|
||||
# import time
|
||||
# start = time.perf_counter()
|
||||
# imgs_path = [r"你的测试图路径.jpg"]
|
||||
# img = cv2.imread(imgs_path[0])
|
||||
# if img is None:
|
||||
# raise FileNotFoundError(f"无法读取图像: {imgs_path[0]}")
|
||||
# result = img_recognition(img)
|
||||
# end = time.perf_counter()
|
||||
# print(f"处理耗时: {end - start:.2f}s")
|
||||
# print(f"识别元件数: {len(result['components'])}")
|
||||
# print(f"电路实例数(元件+导线): {len(result['circuitData']['instances'])}")
|
||||
31
src/pre_func/run.py
Normal file
31
src/pre_func/run.py
Normal file
@ -0,0 +1,31 @@
|
||||
import cv2
|
||||
import os # 新增os模块导入
|
||||
from img_prec import img_recognition
|
||||
from utils import ensure_output_dir # 导入确保输出目录存在的工具函数
|
||||
|
||||
# 1. 读取图像(替换为你的图像路径)
|
||||
# Load the input image (replace with your own image path).
image_path = "love.jpg"
img = cv2.imread(image_path)

if img is not None:
    # Make sure the output directory exists before anything is written.
    output_dir = ensure_output_dir()

    # Run the recognition pipeline.
    result = img_recognition(img)

    # Report the headline numbers.
    print("识别是否成功:", result["success"])
    print("识别到的仪器数量:", len(result["components"]))

    # Wires are the circuit instances whose key is "wire".
    wire_count = len(
        [item for item in result["circuitData"]["instances"] if item["key"] == "wire"]
    )
    print("识别到的导线数量:", wire_count)

    # Persist the full result as JSON in the output directory.
    import json

    json_path = os.path.join(output_dir, "recognition_result.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, indent=2)
    print(f"识别结果已保存到 {json_path}")
else:
    print(f"无法读取图像: {image_path}")
|
||||
130
src/pre_func/train_unet.py
Normal file
130
src/pre_func/train_unet.py
Normal file
@ -0,0 +1,130 @@
|
||||
import os, cv2, glob, numpy as np
|
||||
from tqdm import tqdm
|
||||
import torch, torch.nn as nn
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import albumentations as A
|
||||
from albumentations.pytorch import ToTensorV2
|
||||
|
||||
# ---------- Dataset ----------
|
||||
class WireDataset(Dataset):
    """Paired image/mask dataset for wire segmentation.

    Images and masks are paired by position after sorting the file names of
    each directory, so both directories must hold the same number of files
    with matching sort order.

    Args:
        img_dir: Directory containing the input images.
        mask_dir: Directory containing the grayscale masks.
        train: When true, apply random crop/flip/brightness/blur augmentation;
            otherwise use a deterministic centre crop.

    Raises:
        ValueError: If the two directories contain a different number of files.
    """

    def __init__(self, img_dir, mask_dir, train=True):
        self.img_paths = sorted(glob.glob(os.path.join(img_dir, "*")))
        self.mask_paths = sorted(glob.glob(os.path.join(mask_dir, "*")))
        # Positional pairing silently misaligns when the counts differ, so
        # fail loudly before any training starts.
        if len(self.img_paths) != len(self.mask_paths):
            raise ValueError(
                f"image/mask count mismatch: {len(self.img_paths)} files in "
                f"{img_dir!r} vs {len(self.mask_paths)} files in {mask_dir!r}"
            )

        # Shared geometry: longest side to 1024, then pad to a 1024x1024 canvas.
        aug_list = [
            A.LongestMaxSize(max_size=1024),
            A.PadIfNeeded(1024, 1024, border_mode=cv2.BORDER_CONSTANT, value=0),
        ]
        if train:
            aug_list += [
                A.RandomCrop(512, 512),
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.2),
                A.RandomBrightnessContrast(0.1, 0.1, p=0.5),
                A.GaussianBlur(blur_limit=3, p=0.2),
            ]
        else:
            aug_list += [A.CenterCrop(512, 512)]
        self.tf = A.Compose(aug_list + [A.Normalize(), ToTensorV2()])

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, i):
        """Return ``(image, mask)`` for sample ``i``; the mask gains a leading channel axis."""
        img = cv2.imread(self.img_paths[i], cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread returns None instead of raising on unreadable files.
            raise FileNotFoundError(f"cannot read image: {self.img_paths[i]}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        msk = cv2.imread(self.mask_paths[i], cv2.IMREAD_GRAYSCALE)
        if msk is None:
            raise FileNotFoundError(f"cannot read mask: {self.mask_paths[i]}")
        # Mask pixels at or below 127 become foreground (1.0), the rest 0.0.
        msk = (msk <= 127).astype(np.float32)

        out = self.tf(image=img, mask=msk)
        return out["image"], out["mask"].unsqueeze(0)
|
||||
|
||||
# ---------- U-Net ----------
|
||||
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution -> BatchNorm -> ReLU stages.

    Args:
        in_ch: Channel count of the input tensor.
        out_ch: Channel count produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # The first stage maps in_ch -> out_ch, the second out_ch -> out_ch.
        for stage_in in (in_ch, out_ch):
            stages.append(nn.Conv2d(stage_in, out_ch, 3, padding=1))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU(inplace=True))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)
|
||||
|
||||
class UNet(nn.Module):
    """Four-level U-Net with transposed-convolution upsampling.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels (raw logits, no activation).
        base: Width of the first encoder level; each level doubles it.
    """

    def __init__(self, in_ch=3, out_ch=1, base=64):
        super().__init__()
        # Attribute names and creation order are kept as-is so existing
        # checkpoints keep loading.
        # Encoder.
        self.d1 = DoubleConv(in_ch, base)
        self.d2 = DoubleConv(base, base * 2)
        self.d3 = DoubleConv(base * 2, base * 4)
        self.d4 = DoubleConv(base * 4, base * 8)
        self.bottom = DoubleConv(base * 8, base * 16)
        self.pool = nn.MaxPool2d(2)
        # Decoder: each level upsamples, then fuses with the matching skip.
        self.up4 = nn.ConvTranspose2d(base * 16, base * 8, 2, 2)
        self.u4 = DoubleConv(base * 16, base * 8)
        self.up3 = nn.ConvTranspose2d(base * 8, base * 4, 2, 2)
        self.u3 = DoubleConv(base * 8, base * 4)
        self.up2 = nn.ConvTranspose2d(base * 4, base * 2, 2, 2)
        self.u2 = DoubleConv(base * 4, base * 2)
        self.up1 = nn.ConvTranspose2d(base * 2, base, 2, 2)
        self.u1 = DoubleConv(base * 2, base)
        self.out = nn.Conv2d(base, out_ch, 1)

    def forward(self, x):
        # Encoder path; keep each level's output for the skip connections.
        skip1 = self.d1(x)
        skip2 = self.d2(self.pool(skip1))
        skip3 = self.d3(self.pool(skip2))
        skip4 = self.d4(self.pool(skip3))
        x = self.bottom(self.pool(skip4))

        # Decoder path, deepest level first.
        decoder_levels = (
            (self.up4, self.u4, skip4),
            (self.up3, self.u3, skip3),
            (self.up2, self.u2, skip2),
            (self.up1, self.u1, skip1),
        )
        for upsample, fuse, skip in decoder_levels:
            x = fuse(torch.cat([upsample(x), skip], 1))
        return self.out(x)
|
||||
|
||||
# ---------- Loss(BCE + Dice) ----------
|
||||
class DiceLoss(nn.Module):
    """Soft Dice loss computed on sigmoid probabilities.

    Args:
        eps: Smoothing term added to numerator and denominator.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, logits, targets):
        """Return ``1 - mean(dice)``, with Dice summed over dims 2 and 3."""
        probs = torch.sigmoid(logits)
        spatial_dims = (2, 3)
        overlap = (probs * targets).sum(dim=spatial_dims)
        total = (probs + targets).sum(dim=spatial_dims)
        dice = (2 * overlap + self.eps) / (total + self.eps)
        return 1 - dice.mean()
|
||||
|
||||
def bce_dice_loss(logits, targets):
    """Return the equally weighted sum of BCE-with-logits and soft Dice loss."""
    dice_term = DiceLoss()(logits, targets)
    bce_term = nn.functional.binary_cross_entropy_with_logits(logits, targets)
    return 0.5 * bce_term + 0.5 * dice_term
|
||||
|
||||
# ---------- Train ----------
|
||||
def train(
    data_root="/home/gqw/unet_ai/elements_wires_congition/src/pre_func/data_red",
    epochs=100,
    save_path="unet_wire_red.pth",
):
    """Train the wire-segmentation U-Net and keep the best checkpoint.

    Args:
        data_root: Directory containing ``train/images``, ``train/masks``,
            ``val/images`` and ``val/masks``.
        epochs: Number of training epochs.
        save_path: File the best model's ``state_dict`` is written to.

    The checkpoint is written only when the validation IoU matches or beats
    the best value seen so far, so the file on disk is the best epoch rather
    than the last one.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"使用设备: {device}")
    net = UNet().to(device)
    opt = torch.optim.AdamW(net.parameters(), lr=1e-3, weight_decay=1e-4)

    train_set = WireDataset(
        os.path.join(data_root, "train", "images"),
        os.path.join(data_root, "train", "masks"),
        train=True,
    )
    val_set = WireDataset(
        os.path.join(data_root, "val", "images"),
        os.path.join(data_root, "val", "masks"),
        train=False,
    )

    train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=4, shuffle=False, num_workers=2)

    # IoU is a higher-is-better metric, so start below any reachable value.
    best = -1.0
    for epoch in range(epochs):
        net.train()
        loss_sum = 0
        for img, msk in tqdm(train_loader, desc=f"Epoch {epoch}"):
            img, msk = img.to(device), msk.to(device)
            opt.zero_grad()
            logits = net(img)
            loss = bce_dice_loss(logits, msk)
            loss.backward()
            opt.step()
            loss_sum += loss.item()

        # Validation: mean per-batch IoU of predictions thresholded at 0.5.
        net.eval()
        iou_sum = 0
        n_batches = 0
        with torch.no_grad():
            for img, msk in val_loader:
                img, msk = img.to(device), msk.to(device)
                pr = (torch.sigmoid(net(img)) > 0.5).float()
                inter = (pr * msk).sum()
                union = (pr + msk - pr * msk).sum()
                iou = (inter + 1e-6) / (union + 1e-6)
                iou_sum += iou.item()
                n_batches += 1
        val_iou = iou_sum / max(n_batches, 1)
        # max(..., 1) guards against an empty training loader.
        train_loss = loss_sum / max(len(train_loader), 1)
        print(f"epoch {epoch} train_loss={train_loss:.4f} val_iou={val_iou:.4f}")

        # Save only on improvement. ">=" keeps the latest weights on ties,
        # which also covers an empty validation set (IoU stays 0.0).
        if val_iou >= best:
            best = val_iou
            torch.save(net.state_dict(), save_path)


if __name__ == "__main__":
    train()
|
||||
7
src/pre_func/utils.py
Normal file
7
src/pre_func/utils.py
Normal file
@ -0,0 +1,7 @@
|
||||
import os
|
||||
|
||||
def ensure_output_dir():
    """Create the ``output`` directory if needed and return its relative path.

    ``exist_ok=True`` makes the call race-free: a separate exists-check
    followed by ``makedirs`` fails if another process creates the directory
    in between.
    """
    os.makedirs("output", exist_ok=True)
    return "output"
|
||||
130
src/pre_func/wires_recog.py
Normal file
130
src/pre_func/wires_recog.py
Normal file
@ -0,0 +1,130 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import os
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
from skimage.morphology import skeletonize
|
||||
from train_unet import UNet
|
||||
from utils import ensure_output_dir
|
||||
|
||||
# 初始化UNet模型
|
||||
# Build the U-Net once at import time and load the trained wire weights.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = UNet().to(device)
# map_location remaps tensors saved from a CUDA device so the checkpoint
# also loads on CPU-only machines.
model.load_state_dict(torch.load("src/best_model/unet_wire_red.pth", map_location=device))
model.eval()  # inference mode: fixes BatchNorm statistics
|
||||
|
||||
def preprocess_image(image, target_size=(512, 512)):
    """Convert a BGR image into a normalised 1xCxHxW float tensor on ``device``.

    Args:
        image: BGR image array as loaded by OpenCV.
        target_size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        ``torch.float32`` tensor with a leading batch axis, moved to ``device``.
    """
    rgb = cv2.resize(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), target_size)
    scaled = rgb.astype(np.float32) / 255.0

    # Per-channel mean/std normalisation.
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])
    normalised = (scaled - channel_mean) / channel_std

    # HWC -> CHW, then add the batch axis.
    chw = np.transpose(normalised, (2, 0, 1))
    return torch.tensor(chw, dtype=torch.float32).unsqueeze(0).to(device)
|
||||
|
||||
# 修改后的函数:添加手动掩码路径参数
|
||||
def get_unet_mask(image, manual_mask_path="src/pre_func/wire_recognition_input.png"):
    """Return a 0/255 wire mask at the resolution of ``image``.

    If ``manual_mask_path`` points to a readable image, that file is used as
    the mask (resized and binarised at 127). Otherwise the U-Net predicts it.

    Args:
        image: BGR image array.
        manual_mask_path: Optional path to a hand-made mask; pass ``None`` or
            an empty string to always use the U-Net.

    Returns:
        ``uint8`` array shaped like ``image.shape[:2]`` with values 0 or 255.
    """
    height, width = image.shape[:2]

    if manual_mask_path and os.path.exists(manual_mask_path):
        manual = cv2.imread(manual_mask_path, cv2.IMREAD_GRAYSCALE)
        if manual is not None:
            # Match the input resolution, then force a strict 0/255 image.
            manual = cv2.resize(manual, (width, height))
            _, manual = cv2.threshold(manual, 127, 255, cv2.THRESH_BINARY)
            return manual
        # cv2.imread returns None for files it cannot decode; fall back to
        # the model instead of crashing inside cv2.resize.
        print(f"警告:无法读取手动掩码 {manual_mask_path},改用UNet生成掩码")

    input_tensor = preprocess_image(image)
    with torch.no_grad():
        output = torch.sigmoid(model(input_tensor))
    # 0.3 is the probability threshold for a wire pixel.
    predicted_mask = (output > 0.3).cpu().numpy().squeeze()
    mask = cv2.resize(predicted_mask.astype(np.uint8), (width, height)) * 255
    return mask
|
||||
|
||||
def get_skeleton(binary_mask):
    """Thin a 0/255 mask with ``skeletonize`` and return it as 0/255 ``uint8``."""
    unit_mask = binary_mask // 255  # 0/255 -> 0/1
    thinned = skeletonize(unit_mask)
    return thinned.astype(np.uint8) * 255
|
||||
|
||||
def find_endpoints(skel_img):
    """Return the endpoints of a skeleton image as ``(x, y)`` tuples.

    An endpoint is a pixel of value 255 whose 3x3 neighbourhood contains
    exactly two non-zero pixels: itself and one neighbour. Results are in
    row-major order (top to bottom, then left to right).

    The neighbourhood count is vectorised instead of looping over every
    pixel in Python. The image is zero-padded so that skeleton pixels on the
    image border are examined too.

    Args:
        skel_img: 2-D array; skeleton pixels are 255, background is 0.

    Returns:
        list of ``(x, y)`` integer tuples.
    """
    skel = np.asarray(skel_img)
    h, w = skel.shape

    # Zero padding lets border pixels have a full 3x3 window.
    padded = np.pad((skel != 0).astype(np.uint8), 1, mode="constant")

    # Sum the nine shifted views: the non-zero count of each 3x3 window,
    # centre pixel included.
    window_count = np.zeros((h, w), dtype=np.uint8)
    for dy in range(3):
        for dx in range(3):
            window_count += padded[dy:dy + h, dx:dx + w]

    ys, xs = np.nonzero((skel == 255) & (window_count == 2))
    return [(int(x), int(y)) for y, x in zip(ys, xs)]
|
||||
|
||||
def detect_wires_and_endpoints(image):
    """Segment wires in ``image`` and return one start/end pair per wire.

    The image is resized to a width of 1000 px, masked via ``get_unet_mask``
    and skeletonised. Each connected skeleton component with at least 100
    pixels and two or more endpoints yields one wire. The mask and an
    annotated preview are written to the output directory.

    Args:
        image: BGR image array.

    Returns:
        list of dicts ``{"start": {"x", "y"}, "end": {"x", "y"}}`` in
        resized-image coordinates.

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    if image is None:
        raise ValueError("传入的图像为空,请检查图像是否有效")

    # Resize to a fixed width of 1000 px, keeping the aspect ratio.
    scaled_height = int(image.shape[0] * 1000 / image.shape[1])
    img = cv2.resize(image, (1000, scaled_height))
    canvas = img.copy()  # wire annotations only, no component boxes

    # Wire mask, saved for inspection.
    mask_total = get_unet_mask(img)
    output_dir = ensure_output_dir()
    cv2.imwrite(os.path.join(output_dir, "unet_wire_mask.png"), mask_total)

    # Each connected component of the skeleton is treated as one wire.
    skeleton = get_skeleton(mask_total)
    num_labels, labels, _stats, _centroids = cv2.connectedComponentsWithStats(
        skeleton, connectivity=8
    )

    all_wires = []
    for label_id in range(1, num_labels):  # label 0 is the background
        wire_mask = (labels == label_id).astype(np.uint8) * 255

        # Drop small fragments.
        if cv2.countNonZero(wire_mask) < 100:
            continue

        endpoints = find_endpoints(wire_mask)
        if len(endpoints) < 2:
            continue
        start, end = endpoints[0], endpoints[-1]

        # Mark both ends and draw the straight connection between them.
        cv2.circle(canvas, start, 6, (0, 0, 255), -1)
        cv2.circle(canvas, end, 6, (255, 0, 0), -1)
        cv2.line(canvas, start, end, (0, 255, 255), 2)

        all_wires.append(
            {
                "start": {"x": int(start[0]), "y": int(start[1])},
                "end": {"x": int(end[0]), "y": int(end[1])},
            }
        )

    # Save the annotated wire preview.
    cv2.imwrite(os.path.join(output_dir, "wire_detection_result.png"), canvas)

    return all_wires
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test against a local image.
    image_path = "/home/gqw/unet_ai/love.jpg"
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"无法读取图像: {image_path}")

    wires = detect_wires_and_endpoints(img)
    print("检测到的导线数量:", len(wires))
    for number, wire in enumerate(wires, start=1):
        print(f"导线 #{number}: 起点 {wire['start']}, 终点 {wire['end']}")
|
||||
82
src/pre_func/yanzheng.py
Normal file
82
src/pre_func/yanzheng.py
Normal file
@ -0,0 +1,82 @@
|
||||
import json
|
||||
import base64
|
||||
import os
|
||||
from io import BytesIO
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
# 确保output文件夹存在
|
||||
def ensure_output_dir():
    """Create the ``output`` directory if needed and return its relative path.

    ``exist_ok=True`` avoids the race between a separate exists-check and
    the ``makedirs`` call.
    """
    output_dir = "output"
    os.makedirs(output_dir, exist_ok=True)
    return output_dir
|
||||
|
||||
# 1. 读取JSON文件(请替换为你的文件路径)
|
||||
# 1. Load the recognition result JSON (replace with your own path).
json_path = "/home/gqw/unet_ai/output/recognition_result.json"
with open(json_path, "r", encoding="utf-8") as f:
    result = json.load(f)

# 2. Decode the image embedded in the top-level "recognizedImage" data URI.
try:
    base64_str = result["recognizedImage"].split(",")[1]  # drop the data-URI prefix
    image_bytes = base64.b64decode(base64_str)
    img = Image.open(BytesIO(image_bytes))
except KeyError:
    # No embedded image: fall back to a local file (replace with your own path).
    print("未找到recognizedImage字段,使用本地图片...")
    img = Image.open("/home/gqw/unet_ai/test/5.jpg")

# 3. Prepare the drawing tools.
draw = ImageDraw.Draw(img)
try:
    # A font with CJK glyphs is needed to render the Chinese labels.
    font = ImageFont.truetype("simhei.ttf", 18)
except OSError:
    # truetype raises OSError when the font file cannot be opened. Catching
    # only that keeps KeyboardInterrupt and real bugs visible.
    font = ImageFont.load_default()
    print("警告:未找到中文字体,可能无法正常显示中文标签")

# 4. Annotate the detected components.
# Each entry is read as {"label": ..., "bbox": [x1, y1, x2, y2]}.
instrument_color = (255, 0, 0)  # red boxes for components
for idx, comp in enumerate(result["components"]):
    label = comp["label"]
    bbox = comp["bbox"]
    x1, y1, x2, y2 = bbox

    draw.rectangle([x1, y1, x2, y2], outline=instrument_color, width=3)

    # Keep the label inside the image: shift left near the right edge and
    # place it below the box top when there is no room above.
    text_x = x1 if x1 + 150 < img.width else img.width - 150
    text_y = y1 - 30 if y1 > 30 else y1 + 10
    draw.text(
        (text_x, text_y),
        f"{idx+1}. {label}",
        fill=instrument_color,
        font=font,
    )

# 5. (Optional) annotate wires: the circuitData instances whose key is "wire".
wire_color = (0, 0, 255)  # blue boxes for wires
wires = [inst for inst in result["circuitData"]["instances"] if inst["key"] == "wire"]
for idx, wire in enumerate(wires):
    # Draw a small square centred on the wire's (x, y) position.
    x, y = wire["x"], wire["y"]
    size = 20  # half side length of the marker square
    x1, y1 = x - size, y - size
    x2, y2 = x + size, y + size

    draw.rectangle([x1, y1, x2, y2], outline=wire_color, width=2)
    draw.text(
        (x1, y1 - 20),
        f"导线{idx+1}",
        fill=wire_color,
        font=font,
    )

# 6. Save the annotated image into the output directory.
output_dir = ensure_output_dir()
output_path = os.path.join(output_dir, "circuit_with_annotations.jpg")
img.save(output_path)
print(f"标注完成!结果已保存到:{output_path}")
|
||||
BIN
src/yolo/Arial.Unicode.ttf
Normal file
BIN
src/yolo/Arial.Unicode.ttf
Normal file
Binary file not shown.
24
src/yolo/config.yaml
Normal file
24
src/yolo/config.yaml
Normal file
@ -0,0 +1,24 @@
|
||||
# src/yolo/config.yaml
|
||||
path: ./dataset # 根目录(prepare_dataset.py 输出)
|
||||
train: /home/gqw/unet_ai/elements_wires_congition/src/yolo/dataset/images/train
|
||||
val: /home/gqw/unet_ai/elements_wires_congition/src/yolo/dataset/images/val
|
||||
|
||||
# nc 必须精确等于 names 列表的长度
|
||||
nc: 10
|
||||
|
||||
names:
|
||||
"0": "微安电流表" #躺着的
|
||||
"1": "待测表头" #站着的
|
||||
"2": "电阻箱"
|
||||
"3": "滑动变阻器"
|
||||
"4": "单刀双掷开关"
|
||||
"5": "电源"
|
||||
# <- 从 original_names.json 中拷过来的前若干项(顺序不要变)
|
||||
"6": "单刀单掷开关" # 代码中存在该标签映射
|
||||
"7": "灯泡" # 代码中存在该标签映射
|
||||
"8": "电池电源" # 代码中存在该标签映射(对应battery)
|
||||
"9": "电阻" # 代码中存在该标签映射
|
||||
# "10": "螺线管" # 代码中存在该标签映射(对应电感)
|
||||
# "11": "电容" # 代码中存在该标签映射
|
||||
|
||||
|
||||
26
src/yolo/infer.py
Normal file
26
src/yolo/infer.py
Normal file
@ -0,0 +1,26 @@
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
|
||||
# 完全固定所有路径(根据你的实际情况修改以下路径)
|
||||
# All paths are fixed here; adjust them to your environment.
WEIGHTS_PATH = '/home/gqw/unet_ai/elements_wires_congition/src/best_model/best9.16.pt'  # model weights
IMAGE_PATH = '/home/gqw/unet_ai/test/love.jpg'  # image to run detection on
OUTPUT_PATH = '/home/gqw/unet_ai/elements_wires_congition/src/yolo/detect_results'  # results directory
CONFIDENCE = 0.25  # confidence threshold
SAVE_RESULTS = True  # whether to save the detection results

# Load the model.
model = YOLO(WEIGHTS_PATH)

# Run detection on the single image.
_predict_kwargs = {
    'conf': CONFIDENCE,
    'save': SAVE_RESULTS,
    'project': OUTPUT_PATH,
    'name': 'predictions',
}
results = model(IMAGE_PATH, **_predict_kwargs)

# Report what was processed and where the results were written.
print(f"检测图片:{IMAGE_PATH}")
print(f"检测完成,结果已保存至:{OUTPUT_PATH}/predictions")
|
||||
|
||||
15
src/yolo/inspect_model_names.py
Normal file
15
src/yolo/inspect_model_names.py
Normal file
@ -0,0 +1,15 @@
|
||||
# inspect_model_names.py
|
||||
from ultralytics import YOLO
|
||||
import argparse, json
|
||||
|
||||
# Parse the weights path from the command line.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--weights', '-w',
    default='/home/gqw/unet_ai/elements_wires_congition/src/best_model/model_2.pt',
)
args = parser.parse_args()

# Load the model and show its class-id -> name mapping.
m = YOLO(args.weights)
print("模型里的类别字典 (.names):")
print(m.names)

# Also dump the mapping to a JSON file in the working directory.
with open('original_names.json', 'w', encoding='utf-8') as f:
    f.write(json.dumps(m.names, ensure_ascii=False, indent=2))
print("保存 original_names.json")
|
||||
116
src/yolo/prepare_dataset.py
Normal file
116
src/yolo/prepare_dataset.py
Normal file
@ -0,0 +1,116 @@
|
||||
# prepare_dataset.py
|
||||
import os, shutil, random, argparse, json
|
||||
from pathlib import Path
|
||||
|
||||
def copy_and_prefix(src_images, src_labels, out_images, out_labels, prefix):
    """Copy images and their label files, prepending ``prefix_`` to each name.

    Only ``.jpg``, ``.jpeg``, ``.png`` and ``.bmp`` files (case-insensitive)
    are copied. An image without a matching ``.txt`` label gets an empty
    label file, meaning no objects.

    Args:
        src_images: Directory holding the source images.
        src_labels: Directory holding the source ``.txt`` labels.
        out_images: Destination directory for images (created if missing).
        out_labels: Destination directory for labels (created if missing).
        prefix: String prepended to every copied file name.
    """
    allowed_suffixes = {'.jpg', '.jpeg', '.png', '.bmp'}
    for target_dir in (out_images, out_labels):
        os.makedirs(target_dir, exist_ok=True)

    for img_path in Path(src_images).glob('*'):
        if img_path.suffix.lower() in allowed_suffixes:
            shutil.copy2(img_path, os.path.join(out_images, f"{prefix}_{img_path.name}"))

            label_name = img_path.with_suffix('.txt').name
            label_src = os.path.join(src_labels, label_name)
            label_dst = os.path.join(out_labels, f"{prefix}_{label_name}")
            if not os.path.exists(label_src):
                # No annotation: write an empty label file.
                with open(label_dst, 'w', encoding='utf-8'):
                    pass
            else:
                shutil.copy2(label_src, label_dst)
|
||||
|
||||
def remap_label_file(label_path, mapping):
    """Rewrite the class id (first field) of every line in a label file.

    Args:
        label_path: Path of the label file; a missing file is a no-op.
        mapping: ``{old_id (str): new_id (int)}``. Ids absent from the
            mapping are left unchanged.

    Blank lines are dropped, fields are re-joined with single spaces, and the
    file is written back without a trailing newline.
    """
    if not os.path.exists(label_path):
        return

    with open(label_path, 'r', encoding='utf-8') as src:
        rows = [line.split() for line in src if line.strip()]

    for fields in rows:
        class_id = fields[0]
        fields[0] = str(mapping[class_id]) if class_id in mapping else class_id

    with open(label_path, 'w', encoding='utf-8') as dst:
        dst.write('\n'.join(' '.join(fields) for fields in rows))
|
||||
|
||||
def main():
    """Merge several image/label sources into one train/val dataset.

    Each ``--sources`` entry is ``IMAGE_DIR:LABEL_DIR``. Files are copied
    into temporary directories under ``--out`` with a per-source prefix,
    optionally remapped via ``--remap``, shuffled with ``--seed`` and split
    into ``images/{train,val}`` and ``labels/{train,val}``.

    Raises:
        ValueError: If a source entry does not contain ``:``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--sources', nargs='+', required=True,
                        help='一组数据源,每个源是 IMAGE_DIR:LABEL_DIR,例如 /data/old/images:/data/old/labels')
    parser.add_argument('--out', required=True, help='输出根目录,例如 ./dataset')
    parser.add_argument('--val', type=float, default=0.2, help='验证集比例(0~1)')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--remap', default=None,
                        help='可选的 JSON 文件,格式 {"old_id":"new_id", ...}(old_id作为字符串)')
    args = parser.parse_args()

    random.seed(args.seed)
    out = Path(args.out)
    imgs_all = out / 'images'
    labels_all = out / 'labels'
    imgs_all.mkdir(parents=True, exist_ok=True)
    labels_all.mkdir(parents=True, exist_ok=True)

    # Temporary collection directories, recreated from scratch on every run.
    temp_images = out / 'temp_images'
    temp_labels = out / 'temp_labels'
    if temp_images.exists():
        shutil.rmtree(temp_images)
    if temp_labels.exists():
        shutil.rmtree(temp_labels)
    temp_images.mkdir()
    temp_labels.mkdir()

    # Copy every source; the per-source prefix avoids file-name collisions.
    for i, src in enumerate(args.sources):
        if ':' not in src:
            raise ValueError("source 格式应为 IMAGE_DIR:LABEL_DIR")
        img_dir, lbl_dir = src.split(':', 1)
        prefix = f"src{i}"
        copy_and_prefix(img_dir, lbl_dir, str(temp_images), str(temp_labels), prefix)

    # Optional class-id remapping; JSON keys are treated as strings.
    mapping = {}
    if args.remap:
        with open(args.remap, 'r', encoding='utf-8') as f:
            raw = json.load(f)
        mapping = {str(k): int(v) for k, v in raw.items()}

    if mapping:
        for lf in temp_labels.glob('*.txt'):
            remap_label_file(str(lf), mapping)

    # Collect every copied image. The suffix list must match what
    # copy_and_prefix accepts, otherwise copied files (e.g. .bmp) would be
    # left in the temp directory and deleted below.
    image_suffixes = ('.jpg', '.jpeg', '.png', '.bmp')
    samples = [p.name for p in temp_images.glob('*') if p.suffix.lower() in image_suffixes]
    random.shuffle(samples)

    val_n = int(len(samples) * args.val)
    val_samples = set(samples[:val_n])
    train_samples = set(samples[val_n:])

    # Build the final images/<split> and labels/<split> layout.
    for split, sset in [('train', train_samples), ('val', val_samples)]:
        img_out = out / 'images' / split
        lbl_out = out / 'labels' / split
        img_out.mkdir(parents=True, exist_ok=True)
        lbl_out.mkdir(parents=True, exist_ok=True)
        for nm in sset:
            shutil.move(str(temp_images / nm), str(img_out / nm))
            lbl_name = Path(nm).with_suffix('.txt').name
            shutil.move(str(temp_labels / lbl_name), str(lbl_out / lbl_name))

    # Remove whatever is left of the temporary directories.
    if temp_images.exists():
        shutil.rmtree(temp_images, ignore_errors=True)
    if temp_labels.exists():
        shutil.rmtree(temp_labels, ignore_errors=True)

    print(f"完成。输出目录: {out}")
    print(f"train: {len(train_samples)}, val: {len(val_samples)}")


if __name__ == '__main__':
    main()
|
||||
35
src/yolo/train_yolo.py
Normal file
35
src/yolo/train_yolo.py
Normal file
@ -0,0 +1,35 @@
|
||||
# train_yolo.py
|
||||
import argparse
|
||||
from ultralytics import YOLO
|
||||
|
||||
def main():
    """Fine-tune a YOLO model using command-line configurable settings.

    Flags: ``--weights``, ``--data``, ``--epochs``, ``--imgsz``, ``--batch``,
    ``--device`` and ``--freeze``. ``--freeze`` is now forwarded to the
    trainer; previously it was parsed but ignored.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', '-w',
                        default='/home/gqw/unet_ai/elements_wires_congition/src/best_model/best.pt',
                        help='预训练权重路径')
    parser.add_argument('--data', '-d',
                        default='/home/gqw/unet_ai/elements_wires_congition/src/yolo/config.yaml',
                        help='data yaml 路径')
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--imgsz', type=int, default=640)
    parser.add_argument('--batch', type=int, default=16)
    parser.add_argument('--device', default=0, help='GPU 设备号或 "cpu"')
    parser.add_argument('--freeze', action='store_true', help='是否冻结backbone')
    args = parser.parse_args()

    # Load the pretrained model.
    model = YOLO(args.weights)
    print("开始训练,参数:", args)

    # Freeze the first 10 layers when requested.
    # NOTE(review): 10 is the backbone depth of the stock YOLOv8 layout;
    # confirm it for the architecture behind --weights.
    freeze_layers = 10 if args.freeze else None

    model.train(
        data=args.data,
        epochs=args.epochs,
        imgsz=args.imgsz,
        batch=args.batch,
        device=args.device,
        freeze=freeze_layers,
        amp=False,  # skip the AMP check, which would otherwise load yolo11n.pt
    )
    print("训练结束,权重保存在 runs/detect/ 下")


if __name__ == '__main__':
    main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user