2026-01-05 15:48:54 +08:00

54 lines
1.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
from sklearn.cluster import KMeans
from tensorflow.keras.models import load_model
# =============================
# 1. Load the trained U-Net model
# =============================
# NOTE(review): ".pth" is conventionally a PyTorch checkpoint, while Keras
# `load_model` expects a Keras / HDF5 / SavedModel artifact — confirm this file
# was really saved by Keras, otherwise loading fails on this line.
model = load_model("/home/gqw/unet_ai/elements_wires_congition/src/pre_func/unet_wire.pth", compile=False)
# Read the test image. cv2.imread does not raise on a missing/unreadable file;
# it returns None, which would otherwise surface later as an opaque
# AttributeError on `img.shape`. Fail loudly with the offending path instead.
image_path = "/home/gqw/unet_ai/test/IMG_0059.jpeg"
img = cv2.imread(image_path)
if img is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")
h, w, _ = img.shape  # original size, used to map the mask back later
# Resize to the network input size and scale pixel values into [0, 1].
# NOTE(review): OpenCV loads channels in BGR order; confirm the model was
# trained on BGR (not RGB) input.
input_img = cv2.resize(img, (256, 256)) / 255.0
input_img = np.expand_dims(input_img, axis=0)  # add batch axis -> (1, 256, 256, 3)
# =============================
# 2. U-Net segmentation prediction (binary mask)
# =============================
# Keep the first (only) batch item and the first output channel.
# NOTE(review): assumes a channels-last output whose channel 0 holds a
# per-pixel wire probability — confirm against the model definition.
pred = model.predict(input_img)[0, :, :, 0]
mask = (pred > 0.5).astype(np.uint8)  # threshold probabilities into a 0/1 mask
# Map the mask back to the original image size. Nearest-neighbour
# interpolation is the correct choice for label masks: it never invents
# intermediate values, so the mask stays strictly 0/1 without depending on
# uint8 rounding of bilinear blends along wire edges.
mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
# =============================
# 3. Extract the pixel colors inside the wire region
# =============================
wire_pixels = img[mask == 1]  # N x 3 array of BGR values (OpenCV channel order)
# Cluster the wire pixels by color, assuming 3 distinct wire colors.
n_colors = 3
if len(wire_pixels) >= n_colors:
    kmeans = KMeans(n_clusters=n_colors, random_state=0).fit(wire_pixels)
    labels = kmeans.labels_  # cluster index per wire pixel, same order as wire_pixels
else:
    # KMeans raises ValueError when there are fewer samples than clusters
    # (e.g. the U-Net found no wires). Fall back to putting every wire pixel
    # (possibly none) into cluster 0 so the rest of the script still runs.
    print(f"Warning: only {len(wire_pixels)} wire pixel(s) found; skipping KMeans clustering")
    kmeans = None
    labels = np.zeros(len(wire_pixels), dtype=np.int32)
# =============================
# 4. Build one binary mask per color cluster
# =============================
# Scatter the per-pixel cluster labels back into image layout; pixels outside
# the wire region get -1, which never matches a real cluster index.
label_image = np.full(mask.shape, -1, dtype=np.int64)
label_image[mask == 1] = labels
# One uint8 mask per cluster: 1 where a wire pixel belongs to that cluster, 0 elsewhere.
colored_masks = [(label_image == cluster_id).astype(np.uint8) for cluster_id in range(n_colors)]
# =============================
# 5. Visualize the result
# =============================
# Paint each wire pixel, on a copy of the input image, with its cluster's display color.
result = img.copy()
# Display colors in OpenCV's BGR order: red, green, blue.
colors = [(0, 0, 255), (0, 255, 0), (255, 0, 0)]
for i, cmask in enumerate(colored_masks):
    # Cycle through the palette if there are more clusters than colors.
    result[cmask == 1] = colors[i % len(colors)]
# cv2.imwrite signals failure by returning False rather than raising, so check
# it instead of reporting success unconditionally.
if not cv2.imwrite("segmented_by_color.jpg", result):
    raise OSError("Failed to write segmented_by_color.jpg")
print("输出完成segmented_by_color.jpg")