diff --git a/best_model/model_2.pt b/best_model/model_2.pt new file mode 100644 index 0000000..395bbc4 Binary files /dev/null and b/best_model/model_2.pt differ diff --git a/body_congition.py b/body_congition.py new file mode 100644 index 0000000..45ca3af --- /dev/null +++ b/body_congition.py @@ -0,0 +1,63 @@ +import cv2 +import mediapipe as mp +import csv +import os + +VIDEO_PATH = r"E:\code\tmp\test.mp4" # 视频文件路径 + +# 检查视频文件存在 +if not os.path.exists(VIDEO_PATH): + raise FileNotFoundError(f"视频文件未找到: {VIDEO_PATH}") + +# 初始化 MediaPipe Pose 模块 +mp_pose = mp.solutions.pose +mp_drawing = mp.solutions.drawing_utils +pose = mp_pose.Pose( + static_image_mode=False, + model_complexity=1, + enable_segmentation=False, + min_detection_confidence=0.5, + min_tracking_confidence=0.5 +) + +# 打开视频文件 +cap = cv2.VideoCapture(VIDEO_PATH) +if not cap.isOpened(): + raise RuntimeError(f"无法打开视频文件: {VIDEO_PATH}") + +while True: + success, frame = cap.read() + if not success: + break # 视频播放完毕 + + # 可以调整窗口 + frame = cv2.resize(frame, (1920, 1080)) + + # 转为RGB格式 + image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + image_rgb.flags.writeable = False + + # 骨架识别 + results = pose.process(image_rgb) + + # 绘制骨架图像 + image_rgb.flags.writeable = True + image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR) + + if results.pose_landmarks: + mp_drawing.draw_landmarks( + image=image_bgr, + landmark_list=results.pose_landmarks, + connections=mp_pose.POSE_CONNECTIONS, + landmark_drawing_spec=mp_drawing.DrawingSpec(color=(0, 255, 0), thickness=2, circle_radius=2), + connection_drawing_spec=mp_drawing.DrawingSpec(color=(255, 0, 0), thickness=2) + ) + + # 显示画面 + cv2.imshow("跳水视频骨架识别(MediaPipe)", image_bgr) + key = cv2.waitKey(1) + if key == 27: # 按Esc键退出 + break + +cap.release() +cv2.destroyAllWindows() \ No newline at end of file diff --git a/main.py b/main.py index cbb33eb..ac97119 100644 --- a/main.py +++ b/main.py @@ -2,7 +2,7 @@ from flask import Flask, request, jsonify from flask_cors import CORS import numpy as np import cv2 -from img_prec import img_recognition +from pre_func.img_prec import img_recognition app = Flask(__name__) CORS(app) diff --git a/ele_recog.py b/pre_func/ele_recog.py similarity index 84% rename from ele_recog.py rename to pre_func/ele_recog.py index 48caa4b..5023280 100644 --- a/ele_recog.py +++ b/pre_func/ele_recog.py @@ -3,13 +3,15 @@ from ultralytics import YOLO def elements_recognition(img): - model = YOLO('best_model/best.pt') + model = YOLO('../best_model/model_2.pt') original = img img = cv2.resize(original, (1000, int(original.shape[0] * 1000 / original.shape[1]))) results = model(img)[0] components = [] for box in results.boxes: + if box.conf < 0.6: + continue cls = int(box.cls[0]) label = model.names[cls] x1, y1, x2, y2 = map(int, box.xyxy[0]) diff --git a/img_prec.py b/pre_func/img_prec.py similarity index 93% rename from img_prec.py rename to pre_func/img_prec.py index 791a8a0..163fb0e 100644 --- a/img_prec.py +++ b/pre_func/img_prec.py @@ -1,12 +1,11 @@ -import json import datetime -import wires_recog -import ele_recog +from pre_func import wires_recog +from pre_func import ele_recog import cv2 -import time import numpy as np import base64 + # 生成json数据 def generate_json(wires, components): vertices = [] @@ -19,7 +18,10 @@ def generate_json(wires, components): "电阻箱": "RESISTOR_BOX", "滑动变阻器": "VARIABLE_RESISTOR", "单刀双掷开关": "switch", + "单刀单掷开关": "switch", + "灯泡": "lightBulb", "电源": "BATTERY", + "电池电源": "BATTERY", "电阻": "RESISTOR", "黑盒电流表": "ammeter", "螺线管": "solenoid" @@ -92,6 +94,9 @@ def generate_json(wires, components): elif label == "resistor": elem["resistorType"] = "RESISTOR" elem["resistance"] = 1 + elif label == "lightBulb": + elem["resistance"] = 10 + elem["isReal"] = False elements.append(elem) element_count += 1 @@ -180,7 +185,10 @@ def visualize_wires_and_components(image, results, components): "电阻箱": "RESISTOR_BOX", "滑动变阻器": "VARIABLE_RESISTOR", "单刀双掷开关": "switch", + "单刀单掷开关": "switch", + "灯泡":"lightBulb", "电源": "BATTERY", + "电池电源": "BATTERY", "电阻": "RESISTOR", "黑盒电流表": "ammeter", "螺线管": "solenoid" @@ -233,10 +241,8 @@ def img_recognition(img): # if __name__ == '__main__': # start = time.perf_counter() -# imgs_path = [ -# -# ] -# for img_path in imgs_path: -# img_recognition(img_path) +# imgs_path = [r"E:\code\tmp\mmexport1754723780254.jpg"] +# img = cv2.imread(imgs_path[0]) +# img_recognition(img) # end = time.perf_counter() # print(f"处理{len(imgs_path)}张图片耗时:{end - start:.2f}s") diff --git a/wires_recog.py b/pre_func/wires_recog.py similarity index 98% rename from wires_recog.py rename to pre_func/wires_recog.py index 9e2b2ca..a121b80 100644 --- a/wires_recog.py +++ b/pre_func/wires_recog.py @@ -55,7 +55,8 @@ def detect_wires_and_endpoints(image): for color_name, hsv_ranges in ranges.items(): # print(f"\n🟢 正在处理颜色: {color_name.upper()}") - + if color_name == 'green': + continue mask_total = np.zeros(hsv.shape[:2], dtype=np.uint8) for (lower, upper) in hsv_ranges: mask = extract_wire_mask(hsv, lower, upper, color_name)