version 0.12
model_2.pt 是专门用来演示的模型 body_congition.py 是识别人体骨骼的
This commit is contained in:
parent
d921497253
commit
9d2402d278
BIN
best_model/model_2.pt
Normal file
BIN
best_model/model_2.pt
Normal file
Binary file not shown.
63
body_congition.py
Normal file
63
body_congition.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
import cv2
|
||||||
|
import mediapipe as mp
|
||||||
|
import csv
|
||||||
|
import os
|
||||||
|
|
||||||
|
VIDEO_PATH = r"E:\code\tmp\test.mp4" # 视频文件路径
|
||||||
|
|
||||||
|
# 检查视频文件存在
|
||||||
|
if not os.path.exists(VIDEO_PATH):
|
||||||
|
raise FileNotFoundError(f"视频文件未找到: {VIDEO_PATH}")
|
||||||
|
|
||||||
|
# 初始化 MediaPipe Pose 模块
|
||||||
|
mp_pose = mp.solutions.pose
|
||||||
|
mp_drawing = mp.solutions.drawing_utils
|
||||||
|
pose = mp_pose.Pose(
|
||||||
|
static_image_mode=False,
|
||||||
|
model_complexity=1,
|
||||||
|
enable_segmentation=False,
|
||||||
|
min_detection_confidence=0.5,
|
||||||
|
min_tracking_confidence=0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
# 打开视频文件
|
||||||
|
cap = cv2.VideoCapture(VIDEO_PATH)
|
||||||
|
if not cap.isOpened():
|
||||||
|
raise RuntimeError(f"无法打开视频文件: {VIDEO_PATH}")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
success, frame = cap.read()
|
||||||
|
if not success:
|
||||||
|
break # 视频播放完毕
|
||||||
|
|
||||||
|
# 可以调整窗口
|
||||||
|
frame = cv2.resize(frame, (1920, 1080))
|
||||||
|
|
||||||
|
# 转为RGB格式
|
||||||
|
image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||||
|
image_rgb.flags.writeable = False
|
||||||
|
|
||||||
|
# 骨架识别
|
||||||
|
results = pose.process(image_rgb)
|
||||||
|
|
||||||
|
# 绘制骨架图像
|
||||||
|
image_rgb.flags.writeable = True
|
||||||
|
image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
|
||||||
|
|
||||||
|
if results.pose_landmarks:
|
||||||
|
mp_drawing.draw_landmarks(
|
||||||
|
image=image_bgr,
|
||||||
|
landmark_list=results.pose_landmarks,
|
||||||
|
connections=mp_pose.POSE_CONNECTIONS,
|
||||||
|
landmark_drawing_spec=mp_drawing.DrawingSpec(color=(0, 255, 0), thickness=2, circle_radius=2),
|
||||||
|
connection_drawing_spec=mp_drawing.DrawingSpec(color=(255, 0, 0), thickness=2)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 显示画面
|
||||||
|
cv2.imshow("跳水视频骨架识别(MediaPipe)", image_bgr)
|
||||||
|
key = cv2.waitKey(1)
|
||||||
|
if key == 27: # 按Esc键退出
|
||||||
|
break
|
||||||
|
|
||||||
|
cap.release()
|
||||||
|
cv2.destroyAllWindows()
|
||||||
2
main.py
2
main.py
@ -2,7 +2,7 @@ from flask import Flask, request, jsonify
|
|||||||
from flask_cors import CORS
|
from flask_cors import CORS
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
from img_prec import img_recognition
|
from pre_func.img_prec import img_recognition
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
CORS(app)
|
CORS(app)
|
||||||
|
|
||||||
|
|||||||
@ -3,13 +3,15 @@ from ultralytics import YOLO
|
|||||||
|
|
||||||
|
|
||||||
def elements_recognition(img):
|
def elements_recognition(img):
|
||||||
model = YOLO('best_model/best.pt')
|
model = YOLO('../best_model/model_2.pt')
|
||||||
original = img
|
original = img
|
||||||
img = cv2.resize(original, (1000, int(original.shape[0] * 1000 / original.shape[1])))
|
img = cv2.resize(original, (1000, int(original.shape[0] * 1000 / original.shape[1])))
|
||||||
results = model(img)[0]
|
results = model(img)[0]
|
||||||
components = []
|
components = []
|
||||||
|
|
||||||
for box in results.boxes:
|
for box in results.boxes:
|
||||||
|
if box.conf < 0.6:
|
||||||
|
continue
|
||||||
cls = int(box.cls[0])
|
cls = int(box.cls[0])
|
||||||
label = model.names[cls]
|
label = model.names[cls]
|
||||||
x1, y1, x2, y2 = map(int, box.xyxy[0])
|
x1, y1, x2, y2 = map(int, box.xyxy[0])
|
||||||
@ -1,12 +1,11 @@
|
|||||||
import json
|
|
||||||
import datetime
|
import datetime
|
||||||
import wires_recog
|
from pre_func import wires_recog
|
||||||
import ele_recog
|
from pre_func import ele_recog
|
||||||
import cv2
|
import cv2
|
||||||
import time
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
|
|
||||||
# 生成json数据
|
# 生成json数据
|
||||||
def generate_json(wires, components):
|
def generate_json(wires, components):
|
||||||
vertices = []
|
vertices = []
|
||||||
@ -19,7 +18,10 @@ def generate_json(wires, components):
|
|||||||
"电阻箱": "RESISTOR_BOX",
|
"电阻箱": "RESISTOR_BOX",
|
||||||
"滑动变阻器": "VARIABLE_RESISTOR",
|
"滑动变阻器": "VARIABLE_RESISTOR",
|
||||||
"单刀双掷开关": "switch",
|
"单刀双掷开关": "switch",
|
||||||
|
"单刀单掷开关": "switch",
|
||||||
|
"灯泡": "lightBulb",
|
||||||
"电源": "BATTERY",
|
"电源": "BATTERY",
|
||||||
|
"电池电源": "BATTERY",
|
||||||
"电阻": "RESISTOR",
|
"电阻": "RESISTOR",
|
||||||
"黑盒电流表": "ammeter",
|
"黑盒电流表": "ammeter",
|
||||||
"螺线管": "solenoid"
|
"螺线管": "solenoid"
|
||||||
@ -92,6 +94,9 @@ def generate_json(wires, components):
|
|||||||
elif label == "resistor":
|
elif label == "resistor":
|
||||||
elem["resistorType"] = "RESISTOR"
|
elem["resistorType"] = "RESISTOR"
|
||||||
elem["resistance"] = 1
|
elem["resistance"] = 1
|
||||||
|
elif label == "lightBulb":
|
||||||
|
elem["resistance"] = 10
|
||||||
|
elem["isReal"] = False
|
||||||
elements.append(elem)
|
elements.append(elem)
|
||||||
element_count += 1
|
element_count += 1
|
||||||
|
|
||||||
@ -180,7 +185,10 @@ def visualize_wires_and_components(image, results, components):
|
|||||||
"电阻箱": "RESISTOR_BOX",
|
"电阻箱": "RESISTOR_BOX",
|
||||||
"滑动变阻器": "VARIABLE_RESISTOR",
|
"滑动变阻器": "VARIABLE_RESISTOR",
|
||||||
"单刀双掷开关": "switch",
|
"单刀双掷开关": "switch",
|
||||||
|
"单刀单掷开关": "switch",
|
||||||
|
"灯泡":"lightBulb",
|
||||||
"电源": "BATTERY",
|
"电源": "BATTERY",
|
||||||
|
"电池电源": "BATTERY",
|
||||||
"电阻": "RESISTOR",
|
"电阻": "RESISTOR",
|
||||||
"黑盒电流表": "ammeter",
|
"黑盒电流表": "ammeter",
|
||||||
"螺线管": "solenoid"
|
"螺线管": "solenoid"
|
||||||
@ -233,10 +241,8 @@ def img_recognition(img):
|
|||||||
|
|
||||||
# if __name__ == '__main__':
|
# if __name__ == '__main__':
|
||||||
# start = time.perf_counter()
|
# start = time.perf_counter()
|
||||||
# imgs_path = [
|
# imgs_path = [r"E:\code\tmp\mmexport1754723780254.jpg"]
|
||||||
#
|
# img = cv2.imread(imgs_path[0])
|
||||||
# ]
|
# img_recognition(img)
|
||||||
# for img_path in imgs_path:
|
|
||||||
# img_recognition(img_path)
|
|
||||||
# end = time.perf_counter()
|
# end = time.perf_counter()
|
||||||
# print(f"处理{len(imgs_path)}张图片耗时:{end - start:.2f}s")
|
# print(f"处理{len(imgs_path)}张图片耗时:{end - start:.2f}s")
|
||||||
@ -55,7 +55,8 @@ def detect_wires_and_endpoints(image):
|
|||||||
|
|
||||||
for color_name, hsv_ranges in ranges.items():
|
for color_name, hsv_ranges in ranges.items():
|
||||||
# print(f"\n🟢 正在处理颜色: {color_name.upper()}")
|
# print(f"\n🟢 正在处理颜色: {color_name.upper()}")
|
||||||
|
if color_name == 'green':
|
||||||
|
continue
|
||||||
mask_total = np.zeros(hsv.shape[:2], dtype=np.uint8)
|
mask_total = np.zeros(hsv.shape[:2], dtype=np.uint8)
|
||||||
for (lower, upper) in hsv_ranges:
|
for (lower, upper) in hsv_ranges:
|
||||||
mask = extract_wire_mask(hsv, lower, upper, color_name)
|
mask = extract_wire_mask(hsv, lower, upper, color_name)
|
||||||
Loading…
x
Reference in New Issue
Block a user