# Source listing: 54 lines, 1.7 KiB, Python.
import cv2
import numpy as np
from sklearn.cluster import KMeans
from tensorflow.keras.models import load_model

# =============================
# 1. Load the trained U-Net model and the test image
# =============================
# Location of the trained segmentation model and of the image to analyse.
MODEL_PATH = "/home/gqw/unet_ai/elements_wires_congition/src/pre_func/unet_wire.pth"
IMAGE_PATH = "/home/gqw/unet_ai/test/IMG_0059.jpeg"
# (width, height) the image is resized to before being fed to the network.
INPUT_SIZE = (256, 256)

# NOTE(review): `.pth` is the usual extension for PyTorch checkpoints, but this
# file is loaded with tf.keras `load_model`, which expects a `.keras`/`.h5` file
# or a SavedModel directory. Confirm the file really is a Keras model.
model = load_model(MODEL_PATH, compile=False)

# Read the test image. OpenCV returns pixels in BGR order and returns None
# (instead of raising) when the file is missing or cannot be decoded, so fail
# early with a clear message rather than on `img.shape` below.
img = cv2.imread(IMAGE_PATH)
if img is None:
    raise FileNotFoundError(f"Could not read image: {IMAGE_PATH}")

h, w = img.shape[:2]
# Scale pixel values to [0, 1] and add a leading batch axis -> (1, 256, 256, 3).
input_img = cv2.resize(img, INPUT_SIZE).astype(np.float32) / 255.0
input_img = np.expand_dims(input_img, axis=0)
# =============================
# 2. U-Net segmentation -> binary wire mask
# =============================
# Probability cut-off used to binarise the network output.
MASK_THRESHOLD = 0.5

# Take the first (only) image of the batch and its channel 0; the indexing
# implies a 4-D output, presumably (batch, H, W, channels) with a per-pixel
# wire probability in channel 0 -- confirm against the model definition.
pred = model.predict(input_img)[0, :, :, 0]
mask = (pred > MASK_THRESHOLD).astype(np.uint8)  # 1 = wire, 0 = background

# Resize the label mask back to the original resolution. Nearest-neighbour
# only copies existing labels, which is the correct choice for a label mask
# (the default bilinear interpolation blends neighbouring values).
mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
# =============================
# 3. Collect the colours of the wire pixels and cluster them
# =============================
# Boolean indexing returns the selected pixels in row-major order as an
# (N, 3) array. cv2.imread produced BGR, so each row is (B, G, R), not RGB.
wire_pixels = img[mask == 1]

# KMeans clustering (assumes the wires come in n_colors distinct colours).
n_colors = 3
if len(wire_pixels) < n_colors:
    # Without this guard KMeans fails with a much less helpful message when
    # the segmentation finds no (or almost no) wire pixels.
    raise ValueError(
        f"Only {len(wire_pixels)} wire pixels were segmented; "
        f"at least {n_colors} are needed to form {n_colors} colour clusters."
    )

# n_init is pinned explicitly: its default changed to "auto" in
# scikit-learn 1.4 (and emits a FutureWarning on 1.2/1.3), which would make
# results depend on the installed version.
kmeans = KMeans(n_clusters=n_colors, n_init=10, random_state=0).fit(wire_pixels)
labels = kmeans.labels_  # cluster id per wire pixel, same order as wire_pixels
# =============================
# 4. Build one binary mask per colour cluster
# =============================
# Write each wire pixel's cluster id into a full-size label map. Background
# pixels keep -1, which never equals a cluster id. Boolean indexing visits
# pixels in the same row-major order that produced `labels`.
label_map = np.full(mask.shape, -1, dtype=np.int32)
label_map[mask == 1] = labels

# colored_masks[k] is a uint8 image with 1 where the pixel belongs to cluster k.
colored_masks = [
    (label_map == cluster_id).astype(np.uint8) for cluster_id in range(n_colors)
]
# =============================
# 5. Visualise the result
# =============================
OUTPUT_PATH = "segmented_by_color.jpg"

result = img.copy()
# Overlay colour for each cluster, in OpenCV's BGR order: red, green, blue.
colors = [(0, 0, 255), (0, 255, 0), (255, 0, 0)]
for i, cmask in enumerate(colored_masks):
    # The palette is reused cyclically if there are more clusters than colours.
    result[cmask == 1] = colors[i % len(colors)]

# cv2.imwrite reports failure (unwritable path, unsupported extension) by
# returning False rather than raising, so check it before reporting success.
if not cv2.imwrite(OUTPUT_PATH, result):
    raise OSError(f"Failed to write {OUTPUT_PATH}")

print("输出完成:segmented_by_color.jpg")