From 6b256ca19d85e88fc46baad0e174dc4e846911d0 Mon Sep 17 00:00:00 2001
From: Kaushik Narayan R
Date: Wed, 11 Oct 2023 17:38:45 -0700
Subject: [PATCH] get_all_fd bugfix

---
Review notes (text here, after the three-dash line, is not part of the commit):
- Line breaks and indentation were restored from a whitespace-collapsed copy of
  this patch; if a context line does not match the target file, apply with
  `git apply --ignore-whitespace`.
- show_similar_labels_for_image was corrected during review, so the blob ids on
  the index line below no longer describe the post-image.

 Phase 2/utils.py | 124 ++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 123 insertions(+), 1 deletion(-)

diff --git a/Phase 2/utils.py b/Phase 2/utils.py
index 312d324..44ec0c0 100644
--- a/Phase 2/utils.py
+++ b/Phase 2/utils.py
@@ -270,9 +270,28 @@ def resnet_extractor(image):
     )
 
 
-def get_all_fd(image_id, img=None, label=None):
+def resnet_output(image):
+    """Get image features from ResNet50 (full execution)"""
+    resized_image = (
+        torch.Tensor(np.array(transforms.Resize((224, 224))(image)).flatten())
+        .view(1, 3, 224, 224)
+        .to(dev)
+    )
+
+    with torch.no_grad():
+        features = model(resized_image)
+
+    return features.detach().cpu().tolist()
+
+
+def get_all_fd(image_id, given_image=None, given_label=None):
     """Get all feature descriptors of a given image"""
+    if image_id == -1:
+        img, label = given_image, given_label
+    else:
+        img, label = dataset[image_id]
     img_shape = np.array(img).shape
+    print(img_shape)
     if img_shape[0] >= 3:
         true_channels = 3
     else:
@@ -283,6 +302,7 @@ def get_all_fd(image_id, img=None, label=None):
     cm_fd = CM_transform(img).tolist()
     hog_fd = HOG_transform(img).tolist()
     avgpool_1024_fd, layer3_1024_fd, fc_1000_fd = resnet_extractor(img)
+    resnet_fd = resnet_output(img)
 
     return {
         "image_id": image_id,
@@ -293,6 +313,7 @@ def get_all_fd(image_id, img=None, label=None):
         "avgpool_fd": avgpool_1024_fd,
         "layer3_fd": layer3_1024_fd,
         "fc_fd": fc_1000_fd,
+        "resnet_fd": resnet_fd,
     }
 
 
@@ -336,6 +357,7 @@ valid_feature_models = {
     "avgpool": "avgpool_fd",
     "layer3": "layer3_fd",
     "fc": "fc_fd",
+    "resnet": "resnet_fd"
 }
 valid_distance_measures = {
     "euclidean": euclidean_distance_measure,
@@ -348,6 +370,7 @@ feature_distance_matches = {
     "layer3_fd": pearson_distance_measure,
     "avgpool_fd": pearson_distance_measure,
     "fc_fd": pearson_distance_measure,
+    "resnet_fd": pearson_distance_measure,
 }
 
 
@@ -576,6 +599,105 @@ def show_similar_images_for_label(
     plt.show()
 
 
+def show_similar_labels_for_image(
+    fd_collection,
+    target_image_id,
+    target_image=None,
+    target_label=None,
+    k=10,
+    feature_model="fc",
+    distance_measure=pearson_distance_measure,
+    save_plots=False,
+):
+    """Show the k labels most similar to a target image
+
+    Each label is represented by its closest image in `fd_collection`; labels are
+    ranked by that distance and the k best are printed and plotted.
+
+    Args:
+        fd_collection: descriptor store, queried with `find` / `find_one`
+        target_image_id: dataset ID of the target, or -1 to use `target_image`
+        target_image: target image, only read when `target_image_id` is -1
+        target_label: label of `target_image`, only read when `target_image_id` is -1
+        k: number of labels to show
+        feature_model: a key of `valid_feature_models`
+        distance_measure: one of the values of `valid_distance_measures`
+        save_plots: currently unused
+
+    Raises:
+        AssertionError: if `feature_model` or `distance_measure` is not a valid option
+    """
+    # validate up front, before either argument is used for a lookup
+    assert (
+        feature_model in valid_feature_models
+    ), "feature_model should be one of " + str(valid_feature_models)
+
+    assert (
+        distance_measure in valid_distance_measures.values()
+    ), "distance_measure should be one of " + str(list(valid_distance_measures.keys()))
+
+    # descriptors are stored under e.g. "fc_fd", not under the model name "fc"
+    fd_key = valid_feature_models[feature_model]
+
+    # if target from dataset
+    if target_image_id != -1:
+        print(
+            "Showing {} similar labels for image ID {}, using {} for {} feature descriptor...".format(
+                k, target_image_id, distance_measure.__name__, feature_model
+            )
+        )
+
+        target_fds = None
+        if target_image_id % 2 == 0:
+            # NOTE(review): presumably only even IDs are stored in the database - confirm
+            target_fds = fd_collection.find_one({"image_id": target_image_id})
+        if target_fds is None:
+            # not in database: calculate target image's feature descriptors
+            target_fds = get_all_fd(target_image_id)
+
+        target_label = target_fds["true_label"]
+
+    else:
+        print(
+            "Showing {} similar labels for given image, using {} for {} feature descriptor...".format(
+                k, distance_measure.__name__, feature_model
+            )
+        )
+
+        target_fds = get_all_fd(-1, target_image, target_label)
+
+    target_image_fd = np.array(target_fds[fd_key])
+
+    # every descriptor except HOG is only compared against 3-channel images
+    if feature_model != "hog":
+        all_images = fd_collection.find({"true_channels": 3})
+    else:
+        all_images = fd_collection.find()
+
+    # closest image seen so far for each label: label -> (distance, image ID)
+    closest_per_label = {}
+    for cur_img in all_images:
+        cur_img_id = cur_img["image_id"]
+        cur_label = cur_img["true_label"]
+        # skip target itself and, as before, all images carrying the target's label
+        # NOTE(review): confirm the target's own label is meant to be left out
+        if cur_img_id == target_image_id or cur_label == target_label:
+            continue
+        cur_dist = distance_measure(np.array(cur_img[fd_key]), target_image_fd)
+        best = closest_per_label.get(cur_label)
+        if best is None or cur_dist < best[0]:
+            closest_per_label[cur_label] = (cur_dist, cur_img_id)
+
+    # k labels with the smallest distance, closest first
+    ranked = sorted(closest_per_label.items(), key=lambda item: item[1][0])[:k]
+
+    for label, (dist, image_id) in ranked:
+        print("Label: ", label, "; distance: ", dist)
+        sample_image, _ = dataset[image_id]
+        plt.imshow(transforms.ToPILImage()(sample_image))
+        plt.show()
+
+
 valid_dim_reduction_methods = {
     "svd": 1,
     "nmf": 2,