Compare commits
10 Commits
18dd498210
...
9d2402d278
| Author | SHA1 | Date | |
|---|---|---|---|
| 9d2402d278 | |||
| d921497253 | |||
| e396c59d50 | |||
| a1da9df94d | |||
| ff3f886270 | |||
| 7d798195cb | |||
| d0e5885578 | |||
| 6ada6e7944 | |||
| b189f96b8d | |||
| 0521992bf2 |
BIN
best_model/model_2.pt
Normal file
BIN
best_model/model_2.pt
Normal file
Binary file not shown.
63
body_congition.py
Normal file
63
body_congition.py
Normal file
@ -0,0 +1,63 @@
|
||||
import cv2
|
||||
import mediapipe as mp
|
||||
import csv
|
||||
import os
|
||||
|
||||
VIDEO_PATH = r"E:\code\tmp\test.mp4" # 视频文件路径
|
||||
|
||||
# 检查视频文件存在
|
||||
if not os.path.exists(VIDEO_PATH):
|
||||
raise FileNotFoundError(f"视频文件未找到: {VIDEO_PATH}")
|
||||
|
||||
# 初始化 MediaPipe Pose 模块
|
||||
mp_pose = mp.solutions.pose
|
||||
mp_drawing = mp.solutions.drawing_utils
|
||||
pose = mp_pose.Pose(
|
||||
static_image_mode=False,
|
||||
model_complexity=1,
|
||||
enable_segmentation=False,
|
||||
min_detection_confidence=0.5,
|
||||
min_tracking_confidence=0.5
|
||||
)
|
||||
|
||||
# 打开视频文件
|
||||
cap = cv2.VideoCapture(VIDEO_PATH)
|
||||
if not cap.isOpened():
|
||||
raise RuntimeError(f"无法打开视频文件: {VIDEO_PATH}")
|
||||
|
||||
while True:
|
||||
success, frame = cap.read()
|
||||
if not success:
|
||||
break # 视频播放完毕
|
||||
|
||||
# 可以调整窗口
|
||||
frame = cv2.resize(frame, (1920, 1080))
|
||||
|
||||
# 转为RGB格式
|
||||
image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
image_rgb.flags.writeable = False
|
||||
|
||||
# 骨架识别
|
||||
results = pose.process(image_rgb)
|
||||
|
||||
# 绘制骨架图像
|
||||
image_rgb.flags.writeable = True
|
||||
image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
|
||||
|
||||
if results.pose_landmarks:
|
||||
mp_drawing.draw_landmarks(
|
||||
image=image_bgr,
|
||||
landmark_list=results.pose_landmarks,
|
||||
connections=mp_pose.POSE_CONNECTIONS,
|
||||
landmark_drawing_spec=mp_drawing.DrawingSpec(color=(0, 255, 0), thickness=2, circle_radius=2),
|
||||
connection_drawing_spec=mp_drawing.DrawingSpec(color=(255, 0, 0), thickness=2)
|
||||
)
|
||||
|
||||
# 显示画面
|
||||
cv2.imshow("跳水视频骨架识别(MediaPipe)", image_bgr)
|
||||
key = cv2.waitKey(1)
|
||||
if key == 27: # 按Esc键退出
|
||||
break
|
||||
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
2
main.py
2
main.py
@ -2,7 +2,7 @@ from flask import Flask, request, jsonify
|
||||
from flask_cors import CORS
|
||||
import numpy as np
|
||||
import cv2
|
||||
from img_prec import img_recognition
|
||||
from pre_func.img_prec import img_recognition
|
||||
app = Flask(__name__)
|
||||
CORS(app)
|
||||
|
||||
|
||||
@ -3,13 +3,15 @@ from ultralytics import YOLO
|
||||
|
||||
|
||||
def elements_recognition(img):
|
||||
model = YOLO('best_model/best.pt')
|
||||
model = YOLO('../best_model/model_2.pt')
|
||||
original = img
|
||||
img = cv2.resize(original, (1000, int(original.shape[0] * 1000 / original.shape[1])))
|
||||
results = model(img)[0]
|
||||
components = []
|
||||
|
||||
for box in results.boxes:
|
||||
if box.conf < 0.6:
|
||||
continue
|
||||
cls = int(box.cls[0])
|
||||
label = model.names[cls]
|
||||
x1, y1, x2, y2 = map(int, box.xyxy[0])
|
||||
@ -1,12 +1,11 @@
|
||||
import json
|
||||
import datetime
|
||||
import wires_recog
|
||||
import ele_recog
|
||||
from pre_func import wires_recog
|
||||
from pre_func import ele_recog
|
||||
import cv2
|
||||
import time
|
||||
import numpy as np
|
||||
import base64
|
||||
|
||||
|
||||
# 生成json数据
|
||||
def generate_json(wires, components):
|
||||
vertices = []
|
||||
@ -19,7 +18,13 @@ def generate_json(wires, components):
|
||||
"电阻箱": "RESISTOR_BOX",
|
||||
"滑动变阻器": "VARIABLE_RESISTOR",
|
||||
"单刀双掷开关": "switch",
|
||||
"电源": "battery"
|
||||
"单刀单掷开关": "switch",
|
||||
"灯泡": "lightBulb",
|
||||
"电源": "BATTERY",
|
||||
"电池电源": "BATTERY",
|
||||
"电阻": "RESISTOR",
|
||||
"黑盒电流表": "ammeter",
|
||||
"螺线管": "solenoid"
|
||||
}
|
||||
|
||||
for comp in components:
|
||||
@ -69,14 +74,29 @@ def generate_json(wires, components):
|
||||
if label == "switch":
|
||||
elem["closed"] = False
|
||||
elif label == "ammeter":
|
||||
elem["internalResistance"] = 0.01
|
||||
elif label == "RESISTOR_BOX" or label == "VARIABLE_RESISTOR":
|
||||
elem["resistance"] = 1
|
||||
elem["type"] = "seriesAmmeter"
|
||||
elem["customLabel"] = "电流表"
|
||||
elem["customDisplayFunction"] = "i => `${i.toFixed(2)} A`"
|
||||
elif label == "voltmeter":
|
||||
elem["resistance"] = 1
|
||||
elem["type"] = "seriesAmmeter"
|
||||
elem["customLabel"] = "电压表"
|
||||
elem["customDisplayFunction"] = "i => `${i.toFixed(2)} V`"
|
||||
elif label == "RESISTOR_BOX" or label == "VARIABLE_RESISTOR" or label == "RESISTOR":
|
||||
elem["type"] = "resistor"
|
||||
elem["resistorType"] = label
|
||||
elem["resistance"] = 10
|
||||
elif label == "battery":
|
||||
elem["voltage"] = 9
|
||||
elem["batterType"] = "BATTERRY"
|
||||
elem["internalResistance"] = 0.0001
|
||||
elem["internalResistance"] = 0.01
|
||||
elif label == "resistor":
|
||||
elem["resistorType"] = "RESISTOR"
|
||||
elem["resistance"] = 1
|
||||
elif label == "lightBulb":
|
||||
elem["resistance"] = 10
|
||||
elem["isReal"] = False
|
||||
elements.append(elem)
|
||||
element_count += 1
|
||||
|
||||
@ -165,24 +185,30 @@ def visualize_wires_and_components(image, results, components):
|
||||
"电阻箱": "RESISTOR_BOX",
|
||||
"滑动变阻器": "VARIABLE_RESISTOR",
|
||||
"单刀双掷开关": "switch",
|
||||
"电源": "BATTERY"
|
||||
"单刀单掷开关": "switch",
|
||||
"灯泡":"lightBulb",
|
||||
"电源": "BATTERY",
|
||||
"电池电源": "BATTERY",
|
||||
"电阻": "RESISTOR",
|
||||
"黑盒电流表": "ammeter",
|
||||
"螺线管": "solenoid"
|
||||
}
|
||||
for comp in components:
|
||||
label = labels[comp["label"]]
|
||||
x1, y1, x2, y2 = comp["bbox"]
|
||||
|
||||
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
||||
cv2.putText(img, label, (x1, y1 - 8), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
|
||||
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 128, 0), 2)
|
||||
cv2.putText(img, label, (x1, y1 - 8), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 128, 0), 2)
|
||||
_, encode_img = cv2.imencode('.jpg', img)
|
||||
img_base64 = base64.b64encode(encode_img).decode('utf-8')
|
||||
#cv2.imwrite('output.jpg', img)
|
||||
# cv2.imwrite('output.jpg', img)
|
||||
# cv2.namedWindow("Wires and Components", cv2.WINDOW_NORMAL)
|
||||
# cv2.imshow("Wires and Components", img)
|
||||
# cv2.waitKey(0)
|
||||
# cv2.destroyAllWindows()
|
||||
return img_base64
|
||||
# 显示图像
|
||||
# resized = cv2.resize(img, (0, 0), fx=0.6, fy=0.6)
|
||||
# cv2.namedWindow("Wires and Components", cv2.WINDOW_NORMAL)
|
||||
# cv2.imshow("Wires and Components", resized)
|
||||
# cv2.waitKey(0)
|
||||
# cv2.destroyAllWindows()
|
||||
|
||||
|
||||
def img_recognition(img):
|
||||
@ -204,7 +230,8 @@ def img_recognition(img):
|
||||
request = {
|
||||
"success": True,
|
||||
"recognizedImage": f"data:image/jpeg;base64,{results_img}",
|
||||
"circuitData": results
|
||||
"circuitData": results,
|
||||
"components": elements
|
||||
}
|
||||
# with open('test.json', "w") as f:
|
||||
# json.dump(request, f, indent=2)
|
||||
@ -214,10 +241,8 @@ def img_recognition(img):
|
||||
|
||||
# if __name__ == '__main__':
|
||||
# start = time.perf_counter()
|
||||
# imgs_path = [
|
||||
#
|
||||
# ]
|
||||
# for img_path in imgs_path:
|
||||
# img_recognition(img_path)
|
||||
# imgs_path = [r"E:\code\tmp\mmexport1754723780254.jpg"]
|
||||
# img = cv2.imread(imgs_path[0])
|
||||
# img_recognition(img)
|
||||
# end = time.perf_counter()
|
||||
# print(f"处理{len(imgs_path)}张图片耗时:{end - start:.2f}s")
|
||||
@ -55,7 +55,8 @@ def detect_wires_and_endpoints(image):
|
||||
|
||||
for color_name, hsv_ranges in ranges.items():
|
||||
# print(f"\n🟢 正在处理颜色: {color_name.upper()}")
|
||||
|
||||
if color_name == 'green':
|
||||
continue
|
||||
mask_total = np.zeros(hsv.shape[:2], dtype=np.uint8)
|
||||
for (lower, upper) in hsv_ranges:
|
||||
mask = extract_wire_mask(hsv, lower, upper, color_name)
|
||||
13
pyproject.toml
Normal file
13
pyproject.toml
Normal file
@ -0,0 +1,13 @@
|
||||
[project]
|
||||
name = "circuit-recognition"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"flask>=3.1.1",
|
||||
"flask-cors>=6.0.1",
|
||||
"numpy>=2.3.0",
|
||||
"scikit-image>=0.25.2",
|
||||
"ultralytics>=8.3.156",
|
||||
]
|
||||
Loading…
x
Reference in New Issue
Block a user