80 lines
2.1 KiB
Python
80 lines
2.1 KiB
Python
#
# Demo 2: Spectral clustering on RGB images (d2a, d2b)
#
# Combines image_to_graph with spectral_clustering.
#
# author: Christos Choutouridis <cchoutou@ece.auth.gr>
# date: 05/07/2025
#
|
|
|
|
try:
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from scipy.io import loadmat
|
|
from image_to_graph import image_to_graph
|
|
from spectral_clustering import spectral_clustering
|
|
except ImportError as e:
|
|
print("Missing package:", e)
|
|
exit(1)
|
|
|
|
|
|
def plot_clusters_on_image(image: np.ndarray, cluster_idx: np.ndarray, k: int, title: str, fname: str):
    """
    Render a per-pixel cluster-label map with the spatial layout of ``image``
    and save it to ``fname``.

    Only the spatial size of ``image`` is used; its pixel values are not drawn.
    Labels are colored with matplotlib's 'tab10' colormap scaled to [0, k-1]
    ('tab10' has 10 colors, so for k > 10 some clusters will share a color).

    Parameters:
    -----------
    image : np.ndarray of shape (M, N, 3) or (M, N)
        Original image; only its first two dimensions (M, N) are read.

    cluster_idx : np.ndarray with M*N elements
        Cluster labels in row-major order, e.g. shape (M*N,) or (M*N, 1).

    k : int
        Number of clusters; sets the color scale (vmin=0, vmax=k-1).

    title : str
        Title for the plot.

    fname : str
        Output filename. Missing parent directories are created.

    Raises:
    -------
    ValueError
        If ``cluster_idx`` does not contain exactly M*N labels.
    """
    import os

    M, N = image.shape[:2]
    labels = np.asarray(cluster_idx)
    # Validate before touching the filesystem so a bad call leaves no trace.
    if labels.size != M * N:
        raise ValueError(
            f"cluster_idx has {labels.size} elements, expected {M * N} for a {M}x{N} image"
        )
    clustered_img = labels.reshape(M, N)

    # savefig does not create directories; make sure the target folder exists.
    out_dir = os.path.dirname(fname)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        ax.imshow(clustered_img, cmap='tab10', vmin=0, vmax=k - 1)
        ax.set_title(title)
        ax.axis('off')
        fig.tight_layout()
        fig.savefig(fname)
        print(f"Saved: {fname}")
    finally:
        # Always release the figure, even if saving fails, to avoid leaking memory.
        plt.close(fig)
|
|
|
|
|
|
def run_demo2(
    normalized: bool = False,
    *,
    mat_file: str = "dip_hw_3.mat",
    names: tuple = ("d2a", "d2b"),
    k_values: tuple = (2, 3, 4),
    out_dir: str = "plots",
):
    """
    Run spectral clustering on the test images stored in a .mat file and save
    one label-map plot per (image, k) pair.

    For each image key, the affinity matrix is built once with ``image_to_graph``.
    It is then reused for every k. Plots are written to
    ``<out_dir>/demo2_<name>_k<k>_<Normalized|Unnormalized>.png``.

    Parameters:
    -----------
    normalized : bool
        Forwarded to ``spectral_clustering`` as its ``normalized`` flag; also
        shown in the log output, plot title and output filename.

    mat_file : str
        Path of the .mat file to load (default "dip_hw_3.mat").

    names : tuple of str
        Keys of the images to process inside ``mat_file``.

    k_values : tuple of int
        Cluster counts to try for each image.

    out_dir : str
        Directory for the output plots; created if missing.
    """
    import os

    data = loadmat(mat_file)
    # Human-readable mode label used in the log line, plot title and filename.
    normalized_str = "Normalized" if normalized else "Unnormalized"

    # savefig does not create directories; ensure the output folder exists.
    os.makedirs(out_dir, exist_ok=True)

    for name in names:
        img = data[name]
        print(f"\n=== {normalized_str} test for Image {name} - shape: {img.shape} ===")

        # The affinity matrix depends only on the image, so compute it once per image.
        affinity_mat = image_to_graph(img)
        print("Affinity matrix computed.")

        for k in k_values:
            print(f" Clustering with k={k}...")
            labels = spectral_clustering(affinity_mat, k=k, normalized=normalized)

            plot_clusters_on_image(
                img,
                labels,
                k,
                # Include the mode so normalized/unnormalized plots are distinguishable.
                title=f"{name} {normalized_str} spectral clustering (k={k})",
                fname=f"{out_dir}/demo2_{name}_k{k}_{normalized_str}.png",
            )
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the demo twice: first with normalized=False, then with normalized=True.
    for use_normalized in (False, True):
        run_demo2(normalized=use_normalized)
|