mirror of
https://github.com/20kaushik02/CSE515_MWDB_Project.git
synced 2025-12-06 09:24:07 +00:00
get_all_fd bugfix
This commit is contained in:
parent
331f346756
commit
6b256ca19d
128
Phase 2/utils.py
128
Phase 2/utils.py
@ -270,9 +270,28 @@ def resnet_extractor(image):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_all_fd(image_id, img=None, label=None):
|
def resnet_output(image):
    """Return the output of a full forward pass of the module-level `model`.

    The image is resized to 224x224, flattened, and reinterpreted as a
    single-image batch of shape (1, 3, 224, 224) before being moved to the
    module-level device `dev`. The result is converted to a (nested) Python
    list. The original docstring identifies the model as ResNet50.
    """
    # Resize first, then rebuild the pixels as a one-image batch tensor.
    resized = transforms.Resize((224, 224))(image)
    flat_pixels = np.array(resized).flatten()
    batch = torch.Tensor(flat_pixels).view(1, 3, 224, 224).to(dev)

    # Inference only: no gradient bookkeeping needed.
    with torch.no_grad():
        output = model(batch)

    return output.detach().cpu().tolist()
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_fd(image_id, given_image=None, given_label=None):
|
||||||
"""Get all feature descriptors of a given image"""
|
"""Get all feature descriptors of a given image"""
|
||||||
|
if image_id == -1:
|
||||||
|
img, label = given_image, given_label
|
||||||
|
else:
|
||||||
|
img, label = dataset[image_id]
|
||||||
img_shape = np.array(img).shape
|
img_shape = np.array(img).shape
|
||||||
|
print(img_shape)
|
||||||
if img_shape[0] >= 3:
|
if img_shape[0] >= 3:
|
||||||
true_channels = 3
|
true_channels = 3
|
||||||
else:
|
else:
|
||||||
@ -283,6 +302,7 @@ def get_all_fd(image_id, img=None, label=None):
|
|||||||
cm_fd = CM_transform(img).tolist()
|
cm_fd = CM_transform(img).tolist()
|
||||||
hog_fd = HOG_transform(img).tolist()
|
hog_fd = HOG_transform(img).tolist()
|
||||||
avgpool_1024_fd, layer3_1024_fd, fc_1000_fd = resnet_extractor(img)
|
avgpool_1024_fd, layer3_1024_fd, fc_1000_fd = resnet_extractor(img)
|
||||||
|
resnet_fd = resnet_output(img)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"image_id": image_id,
|
"image_id": image_id,
|
||||||
@ -293,6 +313,7 @@ def get_all_fd(image_id, img=None, label=None):
|
|||||||
"avgpool_fd": avgpool_1024_fd,
|
"avgpool_fd": avgpool_1024_fd,
|
||||||
"layer3_fd": layer3_1024_fd,
|
"layer3_fd": layer3_1024_fd,
|
||||||
"fc_fd": fc_1000_fd,
|
"fc_fd": fc_1000_fd,
|
||||||
|
"resnet_fd": resnet_fd,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -336,6 +357,7 @@ valid_feature_models = {
|
|||||||
"avgpool": "avgpool_fd",
|
"avgpool": "avgpool_fd",
|
||||||
"layer3": "layer3_fd",
|
"layer3": "layer3_fd",
|
||||||
"fc": "fc_fd",
|
"fc": "fc_fd",
|
||||||
|
"resnet": "resnet_fd"
|
||||||
}
|
}
|
||||||
valid_distance_measures = {
|
valid_distance_measures = {
|
||||||
"euclidean": euclidean_distance_measure,
|
"euclidean": euclidean_distance_measure,
|
||||||
@ -348,6 +370,7 @@ feature_distance_matches = {
|
|||||||
"layer3_fd": pearson_distance_measure,
|
"layer3_fd": pearson_distance_measure,
|
||||||
"avgpool_fd": pearson_distance_measure,
|
"avgpool_fd": pearson_distance_measure,
|
||||||
"fc_fd": pearson_distance_measure,
|
"fc_fd": pearson_distance_measure,
|
||||||
|
"resnet_fd": pearson_distance_measure,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -576,6 +599,109 @@ def show_similar_images_for_label(
|
|||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def show_similar_labels_for_image(
    fd_collection,
    target_image_id,
    target_image=None,
    target_label=None,
    k=10,
    feature_model="fc",
    distance_measure=pearson_distance_measure,
    save_plots=False,
):
    """Print and plot the k labels closest to a target image.

    Each label is represented by its single closest image under
    `distance_measure` on the chosen feature descriptor. The target's own
    label is excluded from the results.

    Args:
        fd_collection: collection of stored feature descriptors; queried with
            `find` / `find_one` and expected to yield documents carrying
            "image_id", "true_label" and "<feature_model>_fd" fields.
        target_image_id: dataset index of the target image, or -1 to use
            `target_image` / `target_label` instead.
        target_image: image to use when `target_image_id` is -1.
        target_label: label of `target_image` when `target_image_id` is -1
            (may be None).
        k: number of labels to show.
        feature_model: key of `valid_feature_models` selecting the descriptor.
        distance_measure: one of the callables in `valid_distance_measures`.
        save_plots: accepted for signature compatibility; currently unused.

    Raises:
        AssertionError: if `feature_model` or `distance_measure` is not valid.
    """
    # Validate up front, before any database access or feature extraction.
    assert (
        feature_model in valid_feature_models
    ), "feature_model should be one of " + str(valid_feature_models)

    assert (
        distance_measure in valid_distance_measures.values()
    ), "distance_measure should be one of " + str(list(valid_distance_measures.keys()))

    # Stored documents and get_all_fd() results key descriptors as "<model>_fd".
    fd_key = feature_model + "_fd"

    if target_image_id != -1:
        # Target comes from the dataset.
        print(
            "Showing {} similar labels for image ID {}, using {} for {} feature descriptor...".format(
                k, target_image_id, distance_measure.__name__, feature_model
            )
        )

        target_image_fds = None
        if target_image_id % 2 == 0:
            # Even IDs are looked up in the database.
            target_image_fds = fd_collection.find_one({"image_id": target_image_id})
        if target_image_fds is None:
            # Odd IDs (or a database miss) are computed on the fly.
            target_image_fds = get_all_fd(target_image_id)

        target_label = target_image_fds["true_label"]
    else:
        # Target is an externally supplied image.
        print(
            "Showing {} similar labels for given image, using {} for {} feature descriptor...".format(
                k, distance_measure.__name__, feature_model
            )
        )
        target_image_fds = get_all_fd(-1, target_image, target_label)

    target_image_fd = np.array(target_image_fds[fd_key])

    # Every model except HOG is restricted to documents with three channels.
    if feature_model != "hog":
        all_images = fd_collection.find({"true_channels": 3})
    else:
        all_images = fd_collection.find()

    # label -> (distance, image_id) of the closest image seen for that label.
    best_per_label = {}
    for cur_img in all_images:
        cur_img_id = cur_img["image_id"]
        # Skip the target itself.
        if cur_img_id == target_image_id:
            continue

        # The label is already on the document; no second query is needed.
        label = cur_img["true_label"]
        # The target's own label is not reported as a "similar" label.
        if label == target_label:
            continue

        cur_dist = distance_measure(np.array(cur_img[fd_key]), target_image_fd)

        best = best_per_label.get(label)
        if best is None or cur_dist < best[0]:
            best_per_label[label] = (cur_dist, cur_img_id)

    # Keep the k labels whose representative image is closest to the target.
    closest = sorted(best_per_label.items(), key=lambda item: item[1][0])[:k]

    for label, (dist, image_id) in closest:
        print("Label: ", label, "; distance: ", dist)
        sample_image, _ = dataset[image_id]
        plt.imshow(transforms.ToPILImage()(sample_image))
        plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
valid_dim_reduction_methods = {
|
valid_dim_reduction_methods = {
|
||||||
"svd": 1,
|
"svd": 1,
|
||||||
"nmf": 2,
|
"nmf": 2,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user