commit 62bc6d7d55f89b1a77418c8ca84bbe0fb8ca4262 Author: 碳灰面包 <3266919615@qq.com> Date: Wed Jun 18 19:19:55 2025 +0800 version 0.1 diff --git a/best_model/best.pt b/best_model/best.pt new file mode 100644 index 0000000..108a88e Binary files /dev/null and b/best_model/best.pt differ diff --git a/ele_recog.py b/ele_recog.py new file mode 100644 index 0000000..48caa4b --- /dev/null +++ b/ele_recog.py @@ -0,0 +1,24 @@ +import cv2 +from ultralytics import YOLO + + +def elements_recognition(img): + model = YOLO('best_model/best.pt') + original = img + img = cv2.resize(original, (1000, int(original.shape[0] * 1000 / original.shape[1]))) + results = model(img)[0] + components = [] + + for box in results.boxes: + cls = int(box.cls[0]) + label = model.names[cls] + x1, y1, x2, y2 = map(int, box.xyxy[0]) + components.append({ + "label": label, + "bbox": [x1, y1, x2, y2] + }) + + return components + + + diff --git a/img_prec.py b/img_prec.py new file mode 100644 index 0000000..772267d --- /dev/null +++ b/img_prec.py @@ -0,0 +1,223 @@ +import json +import datetime +import wires_recog +import ele_recog +import cv2 +import time +import numpy as np +import base64 + +# 生成json数据 +def generate_json(wires, components): + vertices = [] + elements = [] + + vertix_count = 0 + element_count = 0 + labels = {"微安电流表": "ammeter", + "待测表头": "ammeter", + "电阻箱": "RESISTOR_BOX", + "滑动变阻器": "VARIABLE_RESISTOR", + "单刀双掷开关": "switch", + "电源": "battery" + } + + for comp in components: + bbox = comp["bbox"] + label = labels[comp["label"]] + elem = { + "id": f"element_{element_count}", + "startVertexId": "", + "endVertexId": "", + "type": label, + } + if bbox[2]-bbox[0] >= bbox[3] - bbox[1]: + vertix = { + "id": f"vertix_{vertix_count}", + "x": bbox[0] + ((bbox[2] - bbox[0]) / 10), + "y": bbox[1] + ((bbox[3] - bbox[1]) / 3), + } + elem["startVertexId"] = vertix["id"] + vertices.append(vertix) + vertix_count += 1 + vertix = { + "id": f"vertix_{vertix_count}", + "x": bbox[0] + ((bbox[2] - bbox[0]) * 9 / 10), + "y": bbox[1] + ((bbox[3] - bbox[1]) / 3), + } + elem["endVertexId"] = vertix["id"] + vertices.append(vertix) + vertix_count += 1 + else: + vertix = { + "id": f"vertix_{vertix_count}", + "x": bbox[0] + ((bbox[2] - bbox[0]) / 2), + "y": bbox[1] + ((bbox[3] - bbox[1]) / 9), + } + elem["startVertexId"] = vertix["id"] + vertices.append(vertix) + vertix_count += 1 + vertix = { + "id": f"vertix_{vertix_count}", + "x": bbox[0] + ((bbox[2] - bbox[0]) / 2), + "y": bbox[1] + ((bbox[3] - bbox[1]) * 8 / 9), + } + elem["endVertexId"] = vertix["id"] + vertices.append(vertix) + vertix_count += 1 + + if label == "switch": + elem["closed"] = False + elif label == "ammeter": + elem["internalResistance"] = 0.01 + elif label == "RESISTOR_BOX" or label == "VARIABLE_RESISTOR": + elem["type"] = "resistor" + elem["resistorType"] = label + elif label == "battery": + elem["voltage"] = 9 + elem["batterType"] = "BATTERRY" + elem["internalResistance"] = 0.0001 + elements.append(elem) + element_count += 1 + + def find_nearest(point): + min_dist = float('inf') + nearest_vertex = None + for vertex in vertices: + ver = (vertex["x"], vertex["y"]) + dist = np.linalg.norm(np.array(point) - np.array(ver)) + if dist < min_dist: + min_dist = dist + nearest_vertex = vertex + return nearest_vertex + # 加入wire + for wire in wires: + wire_start = (wire["start"]["x"], wire["start"]["y"]) + wire_end = (wire["end"]["x"], wire["end"]["y"]) + nearest_start = find_nearest(wire_start) + nearest_end = find_nearest(wire_end) + + elements.append({ + "id": f"element_{element_count}", + "startVertexId": nearest_start["id"], + "endVertexId": nearest_end["id"], + "type": "wire", + "resistance": 3e-8 + }) + element_count += 1 + + data = { + "formatVersion": "1.0", + "metadata": { + "title": "Exported Circuit", + "description": "Circuit exported from image", + "created": datetime.datetime.now(datetime.UTC).isoformat() + "Z" + }, + "vertices": vertices, + "elements": elements, + "displaySettings": { + "showCurrent": True, + "currentType": "electrons", + "wireResistivity": 1e-10, + "sourceResistance": 0.0001 + } + } + return data + + +def visualize_wires_and_components(image, results, components): + original = image + img = cv2.resize(original, (1000, int(original.shape[0] * 1000 / original.shape[1]))) + if img is None: + raise FileNotFoundError(f"无法读取图像:") + + for p in results["vertices"]: + point = (int(p["x"]), int(p["y"])) + cv2.circle(img, point, 6, (0, 0, 255), -1) + + # ==== 画导线 ==== + for wire in results["elements"]: + if wire["type"] != "wire": + continue + point = wire["startVertexId"] + for v in results["vertices"]: + if v["id"] == point: + point = v + + start = (int(point["x"]), int(point["y"])) + point = wire["endVertexId"] + for v in results["vertices"]: + if v["id"] == point: + point = v + end = (int(point["x"]), int(point["y"])) + + # 起点:红色,终点:蓝色 + cv2.circle(img, start, 6, (0, 0, 255), -1) + cv2.circle(img, end, 6, (255, 0, 0), -1) + cv2.line(img, start, end, (0, 255, 255), 2) + + cv2.putText(img, "start", (start[0]+5, start[1]-5), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 255), 1) + cv2.putText(img, "end", (end[0]+5, end[1]-5), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 0, 0), 1) + + # ==== 画元件框 ==== + labels = {"微安电流表": "ammeter", + "待测表头": "ammeter", + "电阻箱": "RESISTOR_BOX", + "滑动变阻器": "VARIABLE_RESISTOR", + "单刀双掷开关": "switch", + "电源": "BATTERY" + } + for comp in components: + label = labels[comp["label"]] + x1, y1, x2, y2 = comp["bbox"] + + cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2) + cv2.putText(img, label, (x1, y1 - 8), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2) + _, encode_img = cv2.imencode('.jpg', img) + img_base64 = base64.b64encode(encode_img).decode('utf-8') + #cv2.imwrite('output.jpg', img) + return img_base64 + # 显示图像 + # resized = cv2.resize(img, (0, 0), fx=0.6, fy=0.6) + # cv2.namedWindow("Wires and Components", cv2.WINDOW_NORMAL) + # cv2.imshow("Wires and Components", resized) + # cv2.waitKey(0) + # cv2.destroyAllWindows() + + +def img_recognition(img): + wires = wires_recog.detect_wires_and_endpoints(img) + elements = ele_recog.elements_recognition(img) + results = generate_json(wires, elements) + results_img = visualize_wires_and_components(img, results, elements) + sumx = 0 + sumy = 0 + for tmp in results["vertices"]: + sumx += tmp["x"] + sumy += tmp["y"] + for i in range(len(results["vertices"])): + results["vertices"][i]["x"] -= sumx/len(results["vertices"]) + results["vertices"][i]["y"] -= sumy/len(results["vertices"]) + results["vertices"][i]["x"] *= 0.6 + results["vertices"][i]["y"] *= 0.6 + + request = { + "success": True, + "recognizedImage": f"data:image/jpeg;base64,{results_img}", + "circuitData": results + } + # with open('test.json', "w") as f: + # json.dump(request, f, indent=2) + # print(f"✅ 已导出电路 JSON 至 {'result_json'}") + return request + + +# if __name__ == '__main__': +# start = time.perf_counter() +# imgs_path = [ +# +# ] +# for img_path in imgs_path: +# img_recognition(img_path) +# end = time.perf_counter() +# print(f"处理{len(imgs_path)}张图片耗时:{end - start:.2f}s") diff --git a/mian.py b/mian.py new file mode 100644 index 0000000..cbb33eb --- /dev/null +++ b/mian.py @@ -0,0 +1,27 @@ +from flask import Flask, request, jsonify +from flask_cors import CORS +import numpy as np +import cv2 +from img_prec import img_recognition +app = Flask(__name__) +CORS(app) + + +@app.route('/process_image', methods=['POST']) +def process_image(): + if 'image' not in request.files: + return jsonify({'error': 'No image part in the request'}), 400 + + file = request.files['image'] + if file.filename == '': + return jsonify({'error': 'No selected image'}), 400 + + file_bytes = np.frombuffer(file.read(), np.uint8) + img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR) + + result = img_recognition(img) + return result + + +if __name__ == '__main__': + app.run(debug=True, port=8000) diff --git a/wires_recog.py b/wires_recog.py new file mode 100644 index 0000000..9e2b2ca --- /dev/null +++ b/wires_recog.py @@ -0,0 +1,103 @@ +import cv2 +import numpy as np +from skimage.morphology import skeletonize + + +def show(title, img, scale=0.6): + resized = cv2.resize(img, (0, 0), fx=scale, fy=scale) + cv2.imshow(title, resized) + cv2.waitKey(0) + + +def extract_wire_mask(hsv_img, lower, upper, name='color'): + mask = cv2.inRange(hsv_img, lower, upper) + # show(f"{name} - begin", mask) + + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)) + closed = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=2) + dilated = cv2.dilate(closed, kernel, iterations=1) + # show(f"{name} - process + exbound", dilated) + return dilated + + +def get_skeleton(binary_mask, name='skeleton'): + skel = skeletonize(binary_mask // 255).astype(np.uint8) * 255 + # show(f"{name} - bouns", skel) + return skel + + +def find_endpoints(skel_img): + endpoints = [] + h, w = skel_img.shape + for y in range(1, h - 1): + for x in range(1, w - 1): + if skel_img[y, x] == 255: + patch = skel_img[y - 1:y + 2, x - 1:x + 2] + if cv2.countNonZero(patch) == 2: + endpoints.append((x, y)) + return endpoints + + +def detect_wires_and_endpoints(image): + original = image + img = cv2.resize(original, (1000, int(original.shape[0] * 1000 / original.shape[1]))) + hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) + + ranges = { + 'red': [(np.array([0, 112, 38]), np.array([8, 255, 255])), + (np.array([160, 70, 50]), np.array([180, 255, 255]))], + 'green': [(np.array([35, 80, 80]), np.array([85, 255, 255]))], + 'yellow': [(np.array([19, 115, 103]), np.array([35, 255, 255]))] + } + + result_img = img.copy() + all_wires = [] + + for color_name, hsv_ranges in ranges.items(): + # print(f"\n🟢 正在处理颜色: {color_name.upper()}") + + mask_total = np.zeros(hsv.shape[:2], dtype=np.uint8) + for (lower, upper) in hsv_ranges: + mask = extract_wire_mask(hsv, lower, upper, color_name) + mask_total = cv2.bitwise_or(mask_total, mask) + + skeleton = get_skeleton(mask_total, f"{color_name}_skeleton") + num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(skeleton, connectivity=8) + + for i in range(1, num_labels): + wire_mask = (labels == i).astype(np.uint8) * 255 + pixel_count = cv2.countNonZero(wire_mask) + + if pixel_count < 100: + continue + + endpoints = find_endpoints(wire_mask) + wire_vis = cv2.cvtColor(wire_mask, cv2.COLOR_GRAY2BGR) + + if len(endpoints) >= 2: + start, end = endpoints[0], endpoints[-1] + # print(f"✅ {color_name.upper()}导线 #{i}: 起点 {start},终点 {end},像素数 {pixel_count}") + cv2.circle(wire_vis, start, 6, (0, 0, 255), -1) + cv2.circle(wire_vis, end, 6, (255, 0, 0), -1) + cv2.line(wire_vis, start, end, (0, 255, 255), 2) + + cv2.circle(result_img, start, 6, (0, 0, 255), -1) + cv2.circle(result_img, end, 6, (255, 0, 0), -1) + cv2.line(result_img, start, end, (0, 255, 255), 2) + + # 保存导线数据 + wire_data = { + "start": {"x": int(start[0]), "y": int(start[1])}, + "end": {"x": int(end[0]), "y": int(end[1])}, + } + all_wires.append(wire_data) + + # show(f"{color_name.upper()} 导线 #{i}", wire_vis) + + # 显示图像 + # cv2.imshow('tmp', result_img) + # cv2.waitKey(0) + return all_wires + +# 示例调用 +# detect_wires_and_endpoints("img/5.jpg")