# 2025-04-29 20:36:41 +03:00

# 93 lines
# 2.7 KiB
# Python

#
# Digital Image Processing HW01 assignment demo file
#
#
# author: Christos Choutouridis <cchoutou@ece.auth.gr>
# date: 29/04/2025
#
import os
import sys

# Only the third-party packages can realistically be missing, so keep the
# try body limited to them.
try:
    import numpy as np
    import matplotlib.pyplot as plt
    from PIL import Image
except ImportError as e:
    print("Missing package: ", e, file=sys.stderr)
    print("Run: pip install -r requirements.txt to install.", file=sys.stderr)
    # sys.exit is always available; the bare ``exit`` builtin is injected by
    # the ``site`` module and does not exist under ``python -S`` or in
    # frozen executables.
    sys.exit(1)

from hist_utils import calculate_hist_of_img
from hist_modif import perform_hist_eq, perform_hist_matching

# Define filenames
input_filename = "input_img.jpg"
ref_filename = "ref_img.jpg"

# Load images and convert them to Pillow's 8-bit grayscale mode ("L").
# The context managers close the underlying file handles deterministically;
# convert() returns an independent image, so it stays usable afterwards.
with Image.open(input_filename) as _src:
    input_img = _src.convert("L")
with Image.open(ref_filename) as _src:
    ref_img = _src.convert("L")

# Convert to numpy arrays in [0,1] (8-bit values divided by 255)
input_array = np.array(input_img).astype(float) / 255.0
ref_array = np.array(ref_img).astype(float) / 255.0

# Create output directory (no error if it already exists)
os.makedirs("demo_outputs", exist_ok=True)
def plot_comparison(input_array, output_array, title, filename, bar_width=0.01):
    """
    Create a 2x2 comparison figure and save it to ``filename``.

    Layout: input image (top-left), output image (top-right), input
    histogram (bottom-left), output histogram (bottom-right).

    :param input_array: Grayscale input image; displayed on a fixed
        [0, 1] gray scale (vmin=0, vmax=1).
    :param output_array: Grayscale processed image, displayed the same way.
    :param title: Figure super-title.
    :param filename: Path of the image file to write.
    :param bar_width: Width of each histogram bar (default 0.01, the value
        that used to be hard-coded).
    """
    fig, axes = plt.subplots(2, 2, figsize=(10, 8))
    try:
        # Top row: the two images on the same fixed gray mapping, so their
        # brightness is directly comparable.
        image_panels = (
            (axes[0, 0], input_array, "Input Image"),
            (axes[0, 1], output_array, "Output Image"),
        )
        for ax, img, panel_title in image_panels:
            ax.imshow(img, cmap="gray", vmin=0, vmax=1)
            ax.set_title(panel_title)
            ax.axis("off")

        # Bottom row: normalized histograms of the two images.
        hist_panels = (
            (axes[1, 0], input_array, "Input Histogram"),
            (axes[1, 1], output_array, "Output Histogram"),
        )
        for ax, img, panel_title in hist_panels:
            hist = calculate_hist_of_img(img, return_normalized=True)
            ax.bar(list(hist.keys()), list(hist.values()), width=bar_width)
            ax.set_title(panel_title)

        fig.suptitle(title, fontsize=16)
        # Keep the top 5% of the figure free for the super-title.
        fig.tight_layout(rect=[0, 0, 1, 0.95])
        # Operate on this figure explicitly rather than pyplot's "current"
        # figure, which may differ if other figures exist.
        fig.savefig(filename)
    finally:
        # Always release the figure, even if plotting or saving fails, so
        # pyplot does not keep accumulating open figures.
        plt.close(fig)
# Histogram-modification modes exercised by the demo
modes = ["greedy", "non-greedy", "post-disturbance"]

# Each experiment runs over all modes before the next one starts:
# (progress message, output file prefix, plot title, transform).
_experiments = (
    (
        "Perform histogram equalization in mode: ",
        "equalization",
        "Histogram Equalization",
        lambda m: perform_hist_eq(input_array, m),
    ),
    (
        "Perform histogram matching in mode: ",
        "matching",
        "Histogram Matching",
        lambda m: perform_hist_matching(input_array, ref_array, m),
    ),
)

for message, prefix, plot_title, transform in _experiments:
    for current_mode in modes:
        print(message, current_mode)
        result_img = transform(current_mode)
        plot_comparison(
            input_array,
            result_img,
            f"{plot_title} ({current_mode})",
            f"demo_outputs/{prefix}_{current_mode}.png",
        )