version 0.1
This commit is contained in:
commit
62bc6d7d55
BIN
best_model/best.pt
Normal file
BIN
best_model/best.pt
Normal file
Binary file not shown.
24
ele_recog.py
Normal file
24
ele_recog.py
Normal file
@ -0,0 +1,24 @@
|
||||
import cv2
|
||||
from ultralytics import YOLO
|
||||
|
||||
|
||||
def elements_recognition(img):
|
||||
model = YOLO('best_model/best.pt')
|
||||
original = img
|
||||
img = cv2.resize(original, (1000, int(original.shape[0] * 1000 / original.shape[1])))
|
||||
results = model(img)[0]
|
||||
components = []
|
||||
|
||||
for box in results.boxes:
|
||||
cls = int(box.cls[0])
|
||||
label = model.names[cls]
|
||||
x1, y1, x2, y2 = map(int, box.xyxy[0])
|
||||
components.append({
|
||||
"label": label,
|
||||
"bbox": [x1, y1, x2, y2]
|
||||
})
|
||||
|
||||
return components
|
||||
|
||||
|
||||
|
||||
223
img_prec.py
Normal file
223
img_prec.py
Normal file
@ -0,0 +1,223 @@
|
||||
import json
|
||||
import datetime
|
||||
import wires_recog
|
||||
import ele_recog
|
||||
import cv2
|
||||
import time
|
||||
import numpy as np
|
||||
import base64
|
||||
|
||||
# 生成json数据
|
||||
def generate_json(wires, components):
|
||||
vertices = []
|
||||
elements = []
|
||||
|
||||
vertix_count = 0
|
||||
element_count = 0
|
||||
labels = {"微安电流表": "ammeter",
|
||||
"待测表头": "ammeter",
|
||||
"电阻箱": "RESISTOR_BOX",
|
||||
"滑动变阻器": "VARIABLE_RESISTOR",
|
||||
"单刀双掷开关": "switch",
|
||||
"电源": "battery"
|
||||
}
|
||||
|
||||
for comp in components:
|
||||
bbox = comp["bbox"]
|
||||
label = labels[comp["label"]]
|
||||
elem = {
|
||||
"id": f"element_{element_count}",
|
||||
"startVertexId": "",
|
||||
"endVertexId": "",
|
||||
"type": label,
|
||||
}
|
||||
if bbox[2]-bbox[0] >= bbox[3] - bbox[1]:
|
||||
vertix = {
|
||||
"id": f"vertix_{vertix_count}",
|
||||
"x": bbox[0] + ((bbox[2] - bbox[0]) / 10),
|
||||
"y": bbox[1] + ((bbox[3] - bbox[1]) / 3),
|
||||
}
|
||||
elem["startVertexId"] = vertix["id"]
|
||||
vertices.append(vertix)
|
||||
vertix_count += 1
|
||||
vertix = {
|
||||
"id": f"vertix_{vertix_count}",
|
||||
"x": bbox[0] + ((bbox[2] - bbox[0]) * 9 / 10),
|
||||
"y": bbox[1] + ((bbox[3] - bbox[1]) / 3),
|
||||
}
|
||||
elem["endVertexId"] = vertix["id"]
|
||||
vertices.append(vertix)
|
||||
vertix_count += 1
|
||||
else:
|
||||
vertix = {
|
||||
"id": f"vertix_{vertix_count}",
|
||||
"x": bbox[0] + ((bbox[2] - bbox[0]) / 2),
|
||||
"y": bbox[1] + ((bbox[3] - bbox[1]) / 9),
|
||||
}
|
||||
elem["startVertexId"] = vertix["id"]
|
||||
vertices.append(vertix)
|
||||
vertix_count += 1
|
||||
vertix = {
|
||||
"id": f"vertix_{vertix_count}",
|
||||
"x": bbox[0] + ((bbox[2] - bbox[0]) / 2),
|
||||
"y": bbox[1] + ((bbox[3] - bbox[1]) * 8 / 9),
|
||||
}
|
||||
elem["endVertexId"] = vertix["id"]
|
||||
vertices.append(vertix)
|
||||
vertix_count += 1
|
||||
|
||||
if label == "switch":
|
||||
elem["closed"] = False
|
||||
elif label == "ammeter":
|
||||
elem["internalResistance"] = 0.01
|
||||
elif label == "RESISTOR_BOX" or label == "VARIABLE_RESISTOR":
|
||||
elem["type"] = "resistor"
|
||||
elem["resistorType"] = label
|
||||
elif label == "battery":
|
||||
elem["voltage"] = 9
|
||||
elem["batterType"] = "BATTERRY"
|
||||
elem["internalResistance"] = 0.0001
|
||||
elements.append(elem)
|
||||
element_count += 1
|
||||
|
||||
def find_nearest(point):
|
||||
min_dist = float('inf')
|
||||
nearest_vertex = None
|
||||
for vertex in vertices:
|
||||
ver = (vertex["x"], vertex["y"])
|
||||
dist = np.linalg.norm(np.array(point) - np.array(ver))
|
||||
if dist < min_dist:
|
||||
min_dist = dist
|
||||
nearest_vertex = vertex
|
||||
return nearest_vertex
|
||||
# 加入wire
|
||||
for wire in wires:
|
||||
wire_start = (wire["start"]["x"], wire["start"]["y"])
|
||||
wire_end = (wire["end"]["x"], wire["end"]["y"])
|
||||
nearest_start = find_nearest(wire_start)
|
||||
nearest_end = find_nearest(wire_end)
|
||||
|
||||
elements.append({
|
||||
"id": f"element_{element_count}",
|
||||
"startVertexId": nearest_start["id"],
|
||||
"endVertexId": nearest_end["id"],
|
||||
"type": "wire",
|
||||
"resistance": 3e-8
|
||||
})
|
||||
element_count += 1
|
||||
|
||||
data = {
|
||||
"formatVersion": "1.0",
|
||||
"metadata": {
|
||||
"title": "Exported Circuit",
|
||||
"description": "Circuit exported from image",
|
||||
"created": datetime.datetime.now(datetime.UTC).isoformat() + "Z"
|
||||
},
|
||||
"vertices": vertices,
|
||||
"elements": elements,
|
||||
"displaySettings": {
|
||||
"showCurrent": True,
|
||||
"currentType": "electrons",
|
||||
"wireResistivity": 1e-10,
|
||||
"sourceResistance": 0.0001
|
||||
}
|
||||
}
|
||||
return data
|
||||
|
||||
|
||||
def visualize_wires_and_components(image, results, components):
|
||||
original = image
|
||||
img = cv2.resize(original, (1000, int(original.shape[0] * 1000 / original.shape[1])))
|
||||
if img is None:
|
||||
raise FileNotFoundError(f"无法读取图像:")
|
||||
|
||||
for p in results["vertices"]:
|
||||
point = (int(p["x"]), int(p["y"]))
|
||||
cv2.circle(img, point, 6, (0, 0, 255), -1)
|
||||
|
||||
# ==== 画导线 ====
|
||||
for wire in results["elements"]:
|
||||
if wire["type"] != "wire":
|
||||
continue
|
||||
point = wire["startVertexId"]
|
||||
for v in results["vertices"]:
|
||||
if v["id"] == point:
|
||||
point = v
|
||||
|
||||
start = (int(point["x"]), int(point["y"]))
|
||||
point = wire["endVertexId"]
|
||||
for v in results["vertices"]:
|
||||
if v["id"] == point:
|
||||
point = v
|
||||
end = (int(point["x"]), int(point["y"]))
|
||||
|
||||
# 起点:红色,终点:蓝色
|
||||
cv2.circle(img, start, 6, (0, 0, 255), -1)
|
||||
cv2.circle(img, end, 6, (255, 0, 0), -1)
|
||||
cv2.line(img, start, end, (0, 255, 255), 2)
|
||||
|
||||
cv2.putText(img, "start", (start[0]+5, start[1]-5), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 255), 1)
|
||||
cv2.putText(img, "end", (end[0]+5, end[1]-5), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 0, 0), 1)
|
||||
|
||||
# ==== 画元件框 ====
|
||||
labels = {"微安电流表": "ammeter",
|
||||
"待测表头": "ammeter",
|
||||
"电阻箱": "RESISTOR_BOX",
|
||||
"滑动变阻器": "VARIABLE_RESISTOR",
|
||||
"单刀双掷开关": "switch",
|
||||
"电源": "BATTERY"
|
||||
}
|
||||
for comp in components:
|
||||
label = labels[comp["label"]]
|
||||
x1, y1, x2, y2 = comp["bbox"]
|
||||
|
||||
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
||||
cv2.putText(img, label, (x1, y1 - 8), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
|
||||
_, encode_img = cv2.imencode('.jpg', img)
|
||||
img_base64 = base64.b64encode(encode_img).decode('utf-8')
|
||||
#cv2.imwrite('output.jpg', img)
|
||||
return img_base64
|
||||
# 显示图像
|
||||
# resized = cv2.resize(img, (0, 0), fx=0.6, fy=0.6)
|
||||
# cv2.namedWindow("Wires and Components", cv2.WINDOW_NORMAL)
|
||||
# cv2.imshow("Wires and Components", resized)
|
||||
# cv2.waitKey(0)
|
||||
# cv2.destroyAllWindows()
|
||||
|
||||
|
||||
def img_recognition(img):
|
||||
wires = wires_recog.detect_wires_and_endpoints(img)
|
||||
elements = ele_recog.elements_recognition(img)
|
||||
results = generate_json(wires, elements)
|
||||
results_img = visualize_wires_and_components(img, results, elements)
|
||||
sumx = 0
|
||||
sumy = 0
|
||||
for tmp in results["vertices"]:
|
||||
sumx += tmp["x"]
|
||||
sumy += tmp["y"]
|
||||
for i in range(len(results["vertices"])):
|
||||
results["vertices"][i]["x"] -= sumx/len(results["vertices"])
|
||||
results["vertices"][i]["y"] -= sumy/len(results["vertices"])
|
||||
results["vertices"][i]["x"] *= 0.6
|
||||
results["vertices"][i]["y"] *= 0.6
|
||||
|
||||
request = {
|
||||
"success": True,
|
||||
"recognizedImage": f"data:image/jpeg;base64,{results_img}",
|
||||
"circuitData": results
|
||||
}
|
||||
# with open('test.json', "w") as f:
|
||||
# json.dump(request, f, indent=2)
|
||||
# print(f"✅ 已导出电路 JSON 至 {'result_json'}")
|
||||
return request
|
||||
|
||||
|
||||
# if __name__ == '__main__':
|
||||
# start = time.perf_counter()
|
||||
# imgs_path = [
|
||||
#
|
||||
# ]
|
||||
# for img_path in imgs_path:
|
||||
# img_recognition(img_path)
|
||||
# end = time.perf_counter()
|
||||
# print(f"处理{len(imgs_path)}张图片耗时:{end - start:.2f}s")
|
||||
27
mian.py
Normal file
27
mian.py
Normal file
@ -0,0 +1,27 @@
|
||||
from flask import Flask, request, jsonify
|
||||
from flask_cors import CORS
|
||||
import numpy as np
|
||||
import cv2
|
||||
from img_prec import img_recognition
|
||||
app = Flask(__name__)
|
||||
CORS(app)
|
||||
|
||||
|
||||
@app.route('/process_image', methods=['POST'])
|
||||
def process_image():
|
||||
if 'image' not in request.files:
|
||||
return jsonify({'error': 'No image part in the request'}), 400
|
||||
|
||||
file = request.files['image']
|
||||
if file.filename == '':
|
||||
return jsonify({'error': 'No selected image'}), 400
|
||||
|
||||
file_bytes = np.frombuffer(file.read(), np.uint8)
|
||||
img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
|
||||
|
||||
result = img_recognition(img)
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(debug=True, port=8000)
|
||||
103
wires_recog.py
Normal file
103
wires_recog.py
Normal file
@ -0,0 +1,103 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from skimage.morphology import skeletonize
|
||||
|
||||
|
||||
def show(title, img, scale=0.6):
|
||||
resized = cv2.resize(img, (0, 0), fx=scale, fy=scale)
|
||||
cv2.imshow(title, resized)
|
||||
cv2.waitKey(0)
|
||||
|
||||
|
||||
def extract_wire_mask(hsv_img, lower, upper, name='color'):
|
||||
mask = cv2.inRange(hsv_img, lower, upper)
|
||||
# show(f"{name} - begin", mask)
|
||||
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
|
||||
closed = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=2)
|
||||
dilated = cv2.dilate(closed, kernel, iterations=1)
|
||||
# show(f"{name} - process + exbound", dilated)
|
||||
return dilated
|
||||
|
||||
|
||||
def get_skeleton(binary_mask, name='skeleton'):
|
||||
skel = skeletonize(binary_mask // 255).astype(np.uint8) * 255
|
||||
# show(f"{name} - bouns", skel)
|
||||
return skel
|
||||
|
||||
|
||||
def find_endpoints(skel_img):
|
||||
endpoints = []
|
||||
h, w = skel_img.shape
|
||||
for y in range(1, h - 1):
|
||||
for x in range(1, w - 1):
|
||||
if skel_img[y, x] == 255:
|
||||
patch = skel_img[y - 1:y + 2, x - 1:x + 2]
|
||||
if cv2.countNonZero(patch) == 2:
|
||||
endpoints.append((x, y))
|
||||
return endpoints
|
||||
|
||||
|
||||
def detect_wires_and_endpoints(image):
|
||||
original = image
|
||||
img = cv2.resize(original, (1000, int(original.shape[0] * 1000 / original.shape[1])))
|
||||
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
|
||||
|
||||
ranges = {
|
||||
'red': [(np.array([0, 112, 38]), np.array([8, 255, 255])),
|
||||
(np.array([160, 70, 50]), np.array([180, 255, 255]))],
|
||||
'green': [(np.array([35, 80, 80]), np.array([85, 255, 255]))],
|
||||
'yellow': [(np.array([19, 115, 103]), np.array([35, 255, 255]))]
|
||||
}
|
||||
|
||||
result_img = img.copy()
|
||||
all_wires = []
|
||||
|
||||
for color_name, hsv_ranges in ranges.items():
|
||||
# print(f"\n🟢 正在处理颜色: {color_name.upper()}")
|
||||
|
||||
mask_total = np.zeros(hsv.shape[:2], dtype=np.uint8)
|
||||
for (lower, upper) in hsv_ranges:
|
||||
mask = extract_wire_mask(hsv, lower, upper, color_name)
|
||||
mask_total = cv2.bitwise_or(mask_total, mask)
|
||||
|
||||
skeleton = get_skeleton(mask_total, f"{color_name}_skeleton")
|
||||
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(skeleton, connectivity=8)
|
||||
|
||||
for i in range(1, num_labels):
|
||||
wire_mask = (labels == i).astype(np.uint8) * 255
|
||||
pixel_count = cv2.countNonZero(wire_mask)
|
||||
|
||||
if pixel_count < 100:
|
||||
continue
|
||||
|
||||
endpoints = find_endpoints(wire_mask)
|
||||
wire_vis = cv2.cvtColor(wire_mask, cv2.COLOR_GRAY2BGR)
|
||||
|
||||
if len(endpoints) >= 2:
|
||||
start, end = endpoints[0], endpoints[-1]
|
||||
# print(f"✅ {color_name.upper()}导线 #{i}: 起点 {start},终点 {end},像素数 {pixel_count}")
|
||||
cv2.circle(wire_vis, start, 6, (0, 0, 255), -1)
|
||||
cv2.circle(wire_vis, end, 6, (255, 0, 0), -1)
|
||||
cv2.line(wire_vis, start, end, (0, 255, 255), 2)
|
||||
|
||||
cv2.circle(result_img, start, 6, (0, 0, 255), -1)
|
||||
cv2.circle(result_img, end, 6, (255, 0, 0), -1)
|
||||
cv2.line(result_img, start, end, (0, 255, 255), 2)
|
||||
|
||||
# 保存导线数据
|
||||
wire_data = {
|
||||
"start": {"x": int(start[0]), "y": int(start[1])},
|
||||
"end": {"x": int(end[0]), "y": int(end[1])},
|
||||
}
|
||||
all_wires.append(wire_data)
|
||||
|
||||
# show(f"{color_name.upper()} 导线 #{i}", wire_vis)
|
||||
|
||||
# 显示图像
|
||||
# cv2.imshow('tmp', result_img)
|
||||
# cv2.waitKey(0)
|
||||
return all_wires
|
||||
|
||||
# 示例调用
|
||||
# detect_wires_and_endpoints("img/5.jpg")
|
||||
Loading…
x
Reference in New Issue
Block a user