In [110]:
from utils import *
warnings.filterwarnings('ignore')
%matplotlib inline
from statistics import mode
import pickle

In [111]:
# Handle to the "fd_collection" collection of "team_5_mwdb_phase_2"; later cells
# query it with .find(). getCollection is not defined in this notebook —
# presumably it comes from the `utils` star-import (confirm).
fd_collection = getCollection("team_5_mwdb_phase_2", "fd_collection")

In [112]:
# Document key whose value is used as the feature vector of every image.
selected_feature_model = "fc_fd"

# Ask which classifier to run. input() already returns str; surrounding
# whitespace is stripped so " m-nn " is accepted.
classification_method = input(
    "Enter classification method - one of "
    + str(list(valid_classification_methods.keys()))
).strip()

# Fail here with a clear message instead of a bare KeyError when the
# classification cell looks the name up in valid_classification_methods.
if classification_method not in valid_classification_methods:
    raise ValueError(
        f"Unknown classification method {classification_method!r}; expected one of "
        f"{list(valid_classification_methods.keys())}"
    )

# The neighbour count is only requested for the m-NN classifier.
if classification_method == "m-nn":
    m = int(input("Enter value of m: "))
    if m < 1:
        raise ValueError("m should be a positive integer")

In [113]:
# Load every document and order by image_id so the split below is deterministic.
all_images = sorted(fd_collection.find(), key=lambda img: img["image_id"])

# Split on the parity of image_id itself. Labels, ids and feature vectors all
# use this same test, so they stay aligned even when some ids are missing
# (splitting vectors by list position is only correct for ids 0..N-1 with no gaps).
is_even = [img["image_id"] % 2 == 0 for img in all_images]

# Ids of the odd-numbered images, in the same order as odd_feature_vectors.
odd_image_ids = [
    img["image_id"] for img, even in zip(all_images, is_even) if not even
]

even_image_labels = [
    img["true_label"] for img, even in zip(all_images, is_even) if even
]
odd_image_labels = [
    img["true_label"] for img, even in zip(all_images, is_even) if not even
]

# One flat 1-D vector per image for the selected feature model.
feature_vectors = [np.array(img[selected_feature_model]).flatten() for img in all_images]

total_len = len(feature_vectors)

even_feature_vectors = np.array(
    [vec for vec, even in zip(feature_vectors, is_even) if even]
)
odd_feature_vectors = np.array(
    [vec for vec, even in zip(feature_vectors, is_even) if not even]
)

odd_len = odd_feature_vectors.shape[0]
even_len = even_feature_vectors.shape[0]

In [114]:
match valid_classification_methods[classification_method]:

    case 1:

        predictions = []

        for i, odd_vector in enumerate(odd_feature_vectors):

            pq = []

            for j, even_vector in enumerate(even_feature_vectors):
                
                distance = np.linalg.norm(odd_vector - even_vector)

                if len(pq) < m:
                    heapq.heappush(pq, (-distance, even_image_labels[j]))
                else:
                    heapq.heappushpop(pq, (-distance, even_image_labels[j]))
            
            labels = [label for dist, label in pq]

            
            pred = max(set(labels), key = labels.count)

            predictions.append(pred)

            print(f"Image ID: {i * 2 + 1} is similar to {pred}")
    

    case 2:

        max_depth = 10

        if os.path.exists(f'decision_tree_{max_depth}.pkl'):
            with open(f'decision_tree_{max_depth}.pkl', 'rb') as file:
                tree = pickle.load(file)
                print("Decision tree loaded")
        else:
            print("Creating the decision tree ...")
            tree = DecisionTree(max_depth = max_depth)
            tree.fit(even_feature_vectors, even_image_labels)
            print("Decision tree formed")
            with open(f'decision_tree_{max_depth}.pkl', 'wb') as file:
                pickle.dump(tree, file)

        predictions = tree.predict(odd_feature_vectors)

        pred_len = len(predictions)

        for i in range(pred_len):
            print(f"Image ID: {i * 2 + 1} is similar to {predictions[i]}")


correct_predictions = sum(1 for actual, predicted in zip(odd_image_labels, predictions) if actual == predicted)
accuracy = (correct_predictions / len(odd_image_labels)) * 100.0
print(f"Accuracy: {accuracy:.2f}%")        


Image ID: 1 is similar to 1
Image ID: 3 is similar to 0
Image ID: 5 is similar to 1
Image ID: 7 is similar to 1
Image ID: 9 is similar to 1
Image ID: 11 is similar to 1
Image ID: 13 is similar to 0
Image ID: 15 is similar to 0
Image ID: 17 is similar to 1
Image ID: 19 is similar to 1
Image ID: 21 is similar to 0
Image ID: 23 is similar to 1
Image ID: 25 is similar to 1
Image ID: 27 is similar to 0
Image ID: 29 is similar to 0
Image ID: 31 is similar to 0
Image ID: 33 is similar to 1
Image ID: 35 is similar to 0
Image ID: 37 is similar to 1
Image ID: 39 is similar to 0
Image ID: 41 is similar to 1
Image ID: 43 is similar to 1
Image ID: 45 is similar to 1
Image ID: 47 is similar to 0
Image ID: 49 is similar to 1
Image ID: 51 is similar to 1
Image ID: 53 is similar to 0
Image ID: 55 is similar to 1
Image ID: 57 is similar to 1
Image ID: 59 is similar to 0
Image ID: 61 is similar to 1
Image ID: 63 is similar to 0
Image ID: 65 is similar to 1
Image ID: 67 is similar to 0
Image ID: 69 is sim