mirror of
https://github.com/20kaushik02/CSE515_MWDB_Project.git
synced 2025-12-06 10:44:06 +00:00
refactored pranav's code for task 1
This commit is contained in:
parent
57e35d2388
commit
dcde5f75f4
@ -131,23 +131,18 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"k = int(input(\"Enter value of k: \"))\n",
|
"k = int(input(\"Enter value of k: \"))\n",
|
||||||
"if k < 1:\n",
|
"if k < 1:\n",
|
||||||
" raise ValueError(\"k should be positive integer\")\n",
|
" raise ValueError(\"k should be a positive integer\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"selected_feature_model = str(\n",
|
"selected_feature_model = valid_feature_models[\n",
|
||||||
" input(\"Enter feature model - one of \" + str(valid_feature_models))\n",
|
" str(input(\"Enter feature model - one of \" + str(list(valid_feature_models.keys()))))\n",
|
||||||
")\n",
|
"]\n",
|
||||||
"\n",
|
"\n",
|
||||||
"selected_distance_measure = valid_distance_measures[\n",
|
"selected_distance_measure = valid_distance_measures[\n",
|
||||||
" str(\n",
|
" str(input(\"Enter distance measure - one of \" + str(list(valid_distance_measures.keys()))))\n",
|
||||||
" input(\n",
|
|
||||||
" \"Enter distance measure - one of \"\n",
|
|
||||||
" + str(list(valid_distance_measures.keys()))\n",
|
|
||||||
" )\n",
|
|
||||||
" )\n",
|
|
||||||
"]\n",
|
"]\n",
|
||||||
"\n",
|
"\n",
|
||||||
"if selected_image_id == -1:\n",
|
"if selected_image_id == -1:\n",
|
||||||
" show_similar_images(\n",
|
" show_similar_images_for_image(\n",
|
||||||
" fd_collection,\n",
|
" fd_collection,\n",
|
||||||
" -1,\n",
|
" -1,\n",
|
||||||
" sample_image,\n",
|
" sample_image,\n",
|
||||||
@ -158,7 +153,7 @@
|
|||||||
" save_plots=False,\n",
|
" save_plots=False,\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
"else:\n",
|
"else:\n",
|
||||||
" show_similar_images(\n",
|
" show_similar_images_for_image(\n",
|
||||||
" fd_collection,\n",
|
" fd_collection,\n",
|
||||||
" selected_image_id,\n",
|
" selected_image_id,\n",
|
||||||
" None,\n",
|
" None,\n",
|
||||||
|
|||||||
99
Phase 2/task_1.ipynb
Normal file
99
Phase 2/task_1.ipynb
Normal file
File diff suppressed because one or more lines are too long
114
Phase 2/utils.py
114
Phase 2/utils.py
@ -322,15 +322,28 @@ def pearson_distance_measure(img_1_fd, img_2_fd):
|
|||||||
return 0.5 * (1 - pearsonr(img_1_fd_reshaped, img_2_fd_reshaped).statistic)
|
return 0.5 * (1 - pearsonr(img_1_fd_reshaped, img_2_fd_reshaped).statistic)
|
||||||
|
|
||||||
|
|
||||||
valid_feature_models = ["cm", "hog", "avgpool", "layer3", "fc"]
|
valid_feature_models = {
|
||||||
|
"cm": "cm_fd",
|
||||||
|
"hog": "hog_fd",
|
||||||
|
"avgpool": "avgpool_fd",
|
||||||
|
"layer3": "layer3_fd",
|
||||||
|
"fc": "fc_fd",
|
||||||
|
}
|
||||||
valid_distance_measures = {
|
valid_distance_measures = {
|
||||||
"euclidean": euclidean_distance_measure,
|
"euclidean": euclidean_distance_measure,
|
||||||
"cosine": cosine_distance_measure,
|
"cosine": cosine_distance_measure,
|
||||||
"pearson": pearson_distance_measure,
|
"pearson": pearson_distance_measure,
|
||||||
}
|
}
|
||||||
|
feature_distance_matches = {
|
||||||
|
"cm_fd": euclidean_distance_measure,
|
||||||
|
"hog_fd": cosine_distance_measure,
|
||||||
|
"layer3_fd": pearson_distance_measure,
|
||||||
|
"avgpool_fd": pearson_distance_measure,
|
||||||
|
"fc_fd": pearson_distance_measure,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def show_similar_images(
|
def show_similar_images_for_image(
|
||||||
fd_collection,
|
fd_collection,
|
||||||
target_image_id,
|
target_image_id,
|
||||||
target_image=None,
|
target_image=None,
|
||||||
@ -343,8 +356,8 @@ def show_similar_images(
|
|||||||
"""Set `target_image_id = -1` if giving image data and label manually"""
|
"""Set `target_image_id = -1` if giving image data and label manually"""
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
feature_model in valid_feature_models
|
feature_model in valid_feature_models.values()
|
||||||
), "feature_model should be one of " + str(valid_feature_models)
|
), "feature_model should be one of " + str(list(valid_feature_models.keys()))
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
distance_measure in valid_distance_measures.values()
|
distance_measure in valid_distance_measures.values()
|
||||||
@ -372,14 +385,14 @@ def show_similar_images(
|
|||||||
target_image, target_label = dataset[target_image_id]
|
target_image, target_label = dataset[target_image_id]
|
||||||
target_image_fds = get_all_fd(target_image_id, target_image, target_label)
|
target_image_fds = get_all_fd(target_image_id, target_image, target_label)
|
||||||
|
|
||||||
target_image_fd = np.array(target_image_fds[feature_model + "_fd"])
|
target_image_fd = np.array(target_image_fds[feature_model])
|
||||||
|
|
||||||
for cur_img in all_images:
|
for cur_img in all_images:
|
||||||
cur_img_id = cur_img["image_id"]
|
cur_img_id = cur_img["image_id"]
|
||||||
# skip target itself
|
# skip target itself
|
||||||
if cur_img_id == target_image_id:
|
if cur_img_id == target_image_id:
|
||||||
continue
|
continue
|
||||||
cur_img_fd = np.array(cur_img[feature_model + "_fd"])
|
cur_img_fd = np.array(cur_img[feature_model])
|
||||||
|
|
||||||
cur_dist = distance_measure(
|
cur_dist = distance_measure(
|
||||||
cur_img_fd,
|
cur_img_fd,
|
||||||
@ -428,11 +441,11 @@ def show_similar_images(
|
|||||||
min_dists = {-1: 0}
|
min_dists = {-1: 0}
|
||||||
|
|
||||||
target_image_fds = get_all_fd(-1, target_image, target_label)
|
target_image_fds = get_all_fd(-1, target_image, target_label)
|
||||||
target_image_fd = np.array(target_image_fds[feature_model + "_fd"])
|
target_image_fd = np.array(target_image_fds[feature_model])
|
||||||
|
|
||||||
for cur_img in all_images:
|
for cur_img in all_images:
|
||||||
cur_img_id = cur_img["image_id"]
|
cur_img_id = cur_img["image_id"]
|
||||||
cur_img_fd = np.array(cur_img[feature_model + "_fd"])
|
cur_img_fd = np.array(cur_img[feature_model])
|
||||||
cur_dist = distance_measure(
|
cur_dist = distance_measure(
|
||||||
cur_img_fd,
|
cur_img_fd,
|
||||||
target_image_fd,
|
target_image_fd,
|
||||||
@ -468,3 +481,88 @@ def show_similar_images(
|
|||||||
f"Plots/Image_{target_image_id}_{feature_model}_{distance_measure.__name__}_k{k}.png"
|
f"Plots/Image_{target_image_id}_{feature_model}_{distance_measure.__name__}_k{k}.png"
|
||||||
)
|
)
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_label_representatives(fd_collection, label, feature_model):
    """Calculate the representative feature vector of a label.

    The representative is the per-dimension mean of the `feature_model`
    feature vectors of every document returned by
    `fd_collection.find({"true_label": label})`.

    Args:
        fd_collection: object exposing `find(query)` that yields dict-like
            image documents.
        label: value matched against each document's `true_label` field.
        feature_model: key of the feature vector inside each document
            (e.g. `"fc_fd"`).

    Returns:
        A list holding the mean of each dimension across the label's images.

    Raises:
        ValueError: if no document matches `label`, since a mean over zero
            vectors is undefined.
    """
    # Pull out only the requested feature vector of each image under the label
    label_fds = [
        img_fds[feature_model]
        for img_fds in fd_collection.find({"true_label": label})
    ]

    # Fail loudly here rather than returning [] and breaking later inside a
    # distance measure with an unrelated-looking error
    if not label_fds:
        raise ValueError(f"No images found for label {label}")

    # zip(*...) transposes the vectors so each `col` holds one dimension's
    # values across all images.
    # NOTE(review): assumes flat, equal-length vectors - zip() silently
    # truncates to the shortest one otherwise.
    num_vectors = len(label_fds)
    return [sum(col) / num_vectors for col in zip(*label_fds)]
|
||||||
|
|
||||||
|
|
||||||
|
def show_similar_images_for_label(
    fd_collection,
    target_label,
    k=10,
    feature_model="fc",
    distance_measure=pearson_distance_measure,
    save_plots=False,
):
    """Plot the `k` images whose feature vectors are closest to a label.

    The label is represented by the vector returned from
    `calculate_label_representatives`; every document of `fd_collection` is
    compared against it with `distance_measure`.

    Args:
        fd_collection: object exposing `find()` that yields dict-like image
            documents carrying `image_id` and the feature vector fields.
        target_label: label whose representative vector is the query.
        k: number of images to show; must be at least 1.
        feature_model: either a key of `valid_feature_models` (e.g. `"fc"`)
            or one of its values (e.g. `"fc_fd"`).
        distance_measure: one of the functions in `valid_distance_measures`.
        save_plots: when true, the figure is also written under `Plots/`.

    Raises:
        ValueError: if `k` is below 1 or the label has no representative.
        AssertionError: if `feature_model` or `distance_measure` is not valid.
    """
    import heapq

    if k < 1:
        raise ValueError("k should be a positive integer")

    # Accept the short name as well as the stored field name; without this the
    # default "fc" can never satisfy the check below, which compares against
    # the dict's values ("fc_fd", ...)
    feature_model = valid_feature_models.get(feature_model, feature_model)

    assert (
        feature_model in valid_feature_models.values()
    ), "feature_model should be one of " + str(list(valid_feature_models.keys()))

    assert (
        distance_measure in valid_distance_measures.values()
    ), "distance_measure should be one of " + str(list(valid_distance_measures.keys()))

    print(
        "Showing {} similar images for label {}, using {} for {} feature descriptor...".format(
            k, target_label, distance_measure.__name__, feature_model
        )
    )

    # Representative feature vector for the label (converted once, not per image)
    label_rep = np.array(
        calculate_label_representatives(fd_collection, target_label, feature_model)
    )
    if label_rep.size == 0:
        raise ValueError(f"No images found for label {target_label}")

    # Lazily yield (image_id, distance) pairs so only k candidates are held
    image_distances = (
        (
            cur_img["image_id"],
            distance_measure(np.array(cur_img[feature_model]), label_rep),
        )
        for cur_img in fd_collection.find()
    )

    # k nearest images, already ordered by increasing distance
    nearest = heapq.nsmallest(k, image_distances, key=lambda item: item[1])

    # squeeze=False keeps a 2-D axes array even for k == 1, so indexing is uniform
    _fig, axs = plt.subplots(1, k, figsize=(48, 12), squeeze=False)
    axs = axs[0]

    for ax, (img_id, distance) in zip(axs, nearest):
        cur_img, _cur_label = dataset[img_id]
        ax.imshow(transforms.ToPILImage()(cur_img))
        ax.set_title(f"Distance: {round(distance, 3)}")
        ax.axis("off")

    # Hide leftover panels when the collection holds fewer than k images
    for ax in axs[len(nearest):]:
        ax.axis("off")

    if save_plots:
        plt.savefig(
            f"Plots/Label_{target_label}_{feature_model}_{distance_measure.__name__}_k{k}.png"
        )
    plt.show()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user