Mirror of https://github.com/20kaushik02/CSE515_MWDB_Project.git (synced 2025-12-06 08:54:07 +00:00)

Commit 6d1abb4f19: Merge branch 'master' of https://github.com/20kaushik02/CSE515_MWDB_Project
BIN  Phase 3/decision_tree_10_150.pkl (new file)
Binary file not shown.
7675  Phase 3/task_3.ipynb
File diff suppressed because it is too large.
243  Phase 3/task_3_sklearn_decision_tree.ipynb (new file)
@@ -0,0 +1,243 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from utils import *\n",
"warnings.filterwarnings('ignore')\n",
"%matplotlib inline\n",
"from statistics import mode\n",
"import pickle\n",
"from sklearn.metrics import precision_recall_fscore_support\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.tree import DecisionTreeClassifier"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"fd_collection = getCollection(\"team_5_mwdb_phase_2\", \"fd_collection\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"selected_feature_model = \"fc_fd\""
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"all_images = list(fd_collection.find())\n",
"all_images = sorted(all_images, key = lambda x: x[\"image_id\"])\n",
"\n",
"odd_image_ids = [img[\"image_id\"] for img in all_images if img[\"image_id\"] % 2 != 0]\n",
"\n",
"even_image_labels = [img[\"true_label\"] for img in all_images if img[\"image_id\"] % 2 == 0]\n",
"odd_image_labels = [img[\"true_label\"] for img in all_images if img[\"image_id\"] % 2 != 0]\n",
"\n",
"feature_vectors = [np.array(img[selected_feature_model]).flatten() for img in all_images]\n",
"\n",
"total_len = len(feature_vectors)\n",
"even_feature_vectors = []\n",
"odd_feature_vectors = []\n",
"\n",
"for i in range(total_len):\n",
"    if i % 2 == 0:\n",
"        even_feature_vectors.append(feature_vectors[i])\n",
"    else:\n",
"        odd_feature_vectors.append(feature_vectors[i])\n",
"\n",
"even_feature_vectors = np.array(even_feature_vectors)\n",
"odd_feature_vectors = np.array(odd_feature_vectors)\n",
"\n",
"odd_len = odd_feature_vectors.shape[0]\n",
"even_len = even_feature_vectors.shape[0]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"max_depth = 10\n",
"reduced_dimensionality = 150\n",
"\n",
"# pca = PCA(n_components = reduced_dimensionality)\n",
"# even_feature_vectors = pca.fit_transform(even_feature_vectors)\n",
"# odd_feature_vectors = pca.fit_transform(odd_feature_vectors)\n",
"\n",
"\n",
"clf = DecisionTreeClassifier()\n",
"clf.fit(even_feature_vectors, even_image_labels)\n",
"predictions = clf.predict(odd_feature_vectors)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Class 0: Precision=0.7946428571428571, Recall=0.8202764976958525, F1-score=0.8072562358276645\n",
"Class 1: Precision=0.8195121951219512, Recall=0.7706422018348624, F1-score=0.7943262411347518\n",
"Class 2: Precision=0.9361702127659575, Recall=0.88, F1-score=0.9072164948453608\n",
"Class 3: Precision=0.9897172236503856, Recall=0.9649122807017544, F1-score=0.9771573604060915\n",
"Class 4: Precision=1.0, Recall=0.9259259259259259, F1-score=0.9615384615384615\n",
"Class 5: Precision=0.9503722084367245, Recall=0.9575, F1-score=0.9539227895392279\n",
"Class 6: Precision=0.09523809523809523, Recall=0.09523809523809523, F1-score=0.09523809523809523\n",
"Class 7: Precision=0.875, Recall=0.6666666666666666, F1-score=0.7567567567567567\n",
"Class 8: Precision=0.9473684210526315, Recall=0.75, F1-score=0.8372093023255814\n",
"Class 9: Precision=0.5652173913043478, Recall=0.48148148148148145, F1-score=0.52\n",
"Class 10: Precision=0.7307692307692307, Recall=0.8260869565217391, F1-score=0.7755102040816326\n",
"Class 11: Precision=1.0, Recall=0.875, F1-score=0.9333333333333333\n",
"Class 12: Precision=0.8524590163934426, Recall=0.8125, F1-score=0.8319999999999999\n",
"Class 13: Precision=0.7272727272727273, Recall=0.8163265306122449, F1-score=0.7692307692307693\n",
"Class 14: Precision=0.32, Recall=0.36363636363636365, F1-score=0.3404255319148936\n",
"Class 15: Precision=0.5102040816326531, Recall=0.5952380952380952, F1-score=0.5494505494505494\n",
"Class 16: Precision=0.5471698113207547, Recall=0.6304347826086957, F1-score=0.5858585858585859\n",
"Class 17: Precision=0.7931034482758621, Recall=0.92, F1-score=0.851851851851852\n",
"Class 18: Precision=0.75, Recall=0.5714285714285714, F1-score=0.6486486486486486\n",
"Class 19: Precision=0.9824561403508771, Recall=0.9032258064516129, F1-score=0.9411764705882352\n",
"Class 20: Precision=0.9090909090909091, Recall=0.43478260869565216, F1-score=0.5882352941176471\n",
"Class 21: Precision=0.84, Recall=0.7, F1-score=0.7636363636363636\n",
"Class 22: Precision=0.7083333333333334, Recall=0.5483870967741935, F1-score=0.6181818181818182\n",
"Class 23: Precision=0.5254237288135594, Recall=0.5849056603773585, F1-score=0.5535714285714286\n",
"Class 24: Precision=0.64, Recall=0.6666666666666666, F1-score=0.6530612244897959\n",
"Class 25: Precision=0.6, Recall=0.7058823529411765, F1-score=0.6486486486486486\n",
"Class 26: Precision=0.78125, Recall=0.6756756756756757, F1-score=0.7246376811594203\n",
"Class 27: Precision=0.5227272727272727, Recall=0.6571428571428571, F1-score=0.5822784810126581\n",
"Class 28: Precision=0.5, Recall=0.48, F1-score=0.4897959183673469\n",
"Class 29: Precision=0.6086956521739131, Recall=0.56, F1-score=0.5833333333333334\n",
"Class 30: Precision=0.7241379310344828, Recall=0.7241379310344828, F1-score=0.7241379310344829\n",
"Class 31: Precision=1.0, Recall=0.9696969696969697, F1-score=0.9846153846153847\n",
"Class 32: Precision=0.7272727272727273, Recall=0.6153846153846154, F1-score=0.6666666666666667\n",
"Class 33: Precision=0.875, Recall=0.6363636363636364, F1-score=0.7368421052631579\n",
"Class 34: Precision=0.5609756097560976, Recall=0.6764705882352942, F1-score=0.6133333333333334\n",
"Class 35: Precision=0.7027027027027027, Recall=0.7027027027027027, F1-score=0.7027027027027027\n",
"Class 36: Precision=0.9259259259259259, Recall=0.78125, F1-score=0.847457627118644\n",
"Class 37: Precision=0.7096774193548387, Recall=0.8148148148148148, F1-score=0.7586206896551724\n",
"Class 38: Precision=0.9285714285714286, Recall=0.8125, F1-score=0.8666666666666666\n",
"Class 39: Precision=0.8541666666666666, Recall=0.9761904761904762, F1-score=0.9111111111111111\n",
"Class 40: Precision=0.8888888888888888, Recall=0.9411764705882353, F1-score=0.9142857142857143\n",
"Class 41: Precision=0.7647058823529411, Recall=0.7878787878787878, F1-score=0.7761194029850745\n",
"Class 42: Precision=0.5, Recall=0.5217391304347826, F1-score=0.5106382978723404\n",
"Class 43: Precision=0.8888888888888888, Recall=0.47058823529411764, F1-score=0.6153846153846153\n",
"Class 44: Precision=0.8333333333333334, Recall=0.8823529411764706, F1-score=0.8571428571428571\n",
"Class 45: Precision=0.2926829268292683, Recall=0.48, F1-score=0.3636363636363636\n",
"Class 46: Precision=1.0, Recall=0.98, F1-score=0.98989898989899\n",
"Class 47: Precision=0.96, Recall=0.96, F1-score=0.96\n",
"Class 48: Precision=0.5416666666666666, Recall=0.6190476190476191, F1-score=0.5777777777777778\n",
"Class 49: Precision=0.92, Recall=0.8518518518518519, F1-score=0.8846153846153846\n",
"Class 50: Precision=0.3333333333333333, Recall=0.36363636363636365, F1-score=0.34782608695652173\n",
"Class 51: Precision=0.8461538461538461, Recall=0.825, F1-score=0.8354430379746836\n",
"Class 52: Precision=0.5217391304347826, Recall=0.8, F1-score=0.6315789473684211\n",
"Class 53: Precision=0.68, Recall=0.53125, F1-score=0.5964912280701754\n",
"Class 54: Precision=0.9545454545454546, Recall=0.9767441860465116, F1-score=0.9655172413793104\n",
"Class 55: Precision=0.8392857142857143, Recall=0.8245614035087719, F1-score=0.8318584070796461\n",
"Class 56: Precision=0.6774193548387096, Recall=0.6774193548387096, F1-score=0.6774193548387096\n",
"Class 57: Precision=0.8604651162790697, Recall=0.925, F1-score=0.891566265060241\n",
"Class 58: Precision=0.8787878787878788, Recall=0.7435897435897436, F1-score=0.8055555555555556\n",
"Class 59: Precision=0.5333333333333333, Recall=0.38095238095238093, F1-score=0.4444444444444444\n",
"Class 60: Precision=0.6206896551724138, Recall=0.5454545454545454, F1-score=0.5806451612903226\n",
"Class 61: Precision=0.36, Recall=0.42857142857142855, F1-score=0.391304347826087\n",
"Class 62: Precision=0.5263157894736842, Recall=0.5, F1-score=0.5128205128205129\n",
"Class 63: Precision=0.5625, Recall=0.6136363636363636, F1-score=0.5869565217391304\n",
"Class 64: Precision=0.45454545454545453, Recall=0.3125, F1-score=0.3703703703703703\n",
"Class 65: Precision=0.3090909090909091, Recall=0.4473684210526316, F1-score=0.3655913978494624\n",
"Class 66: Precision=1.0, Recall=0.8888888888888888, F1-score=0.9411764705882353\n",
"Class 67: Precision=0.2777777777777778, Recall=0.2777777777777778, F1-score=0.2777777777777778\n",
"Class 68: Precision=0.6923076923076923, Recall=0.9473684210526315, F1-score=0.7999999999999999\n",
"Class 69: Precision=0.6875, Recall=0.4583333333333333, F1-score=0.5499999999999999\n",
"Class 70: Precision=0.9333333333333333, Recall=0.7368421052631579, F1-score=0.8235294117647058\n",
"Class 71: Precision=0.85, Recall=0.7727272727272727, F1-score=0.8095238095238095\n",
"Class 72: Precision=0.8076923076923077, Recall=0.7777777777777778, F1-score=0.7924528301886792\n",
"Class 73: Precision=1.0, Recall=0.7647058823529411, F1-score=0.8666666666666666\n",
"Class 74: Precision=0.3617021276595745, Recall=0.6071428571428571, F1-score=0.4533333333333333\n",
"Class 75: Precision=1.0, Recall=1.0, F1-score=1.0\n",
"Class 76: Precision=0.7, Recall=0.7, F1-score=0.7\n",
"Class 77: Precision=0.7241379310344828, Recall=0.875, F1-score=0.7924528301886793\n",
"Class 78: Precision=0.9090909090909091, Recall=1.0, F1-score=0.9523809523809523\n",
"Class 79: Precision=0.7096774193548387, Recall=0.6875, F1-score=0.6984126984126984\n",
"Class 80: Precision=0.2777777777777778, Recall=0.2631578947368421, F1-score=0.27027027027027023\n",
"Class 81: Precision=0.803921568627451, Recall=0.9761904761904762, F1-score=0.8817204301075269\n",
"Class 82: Precision=0.5294117647058824, Recall=0.3103448275862069, F1-score=0.391304347826087\n",
"Class 83: Precision=0.4, Recall=0.35294117647058826, F1-score=0.37500000000000006\n",
"Class 84: Precision=0.8181818181818182, Recall=0.84375, F1-score=0.8307692307692308\n",
"Class 85: Precision=0.39285714285714285, Recall=0.4782608695652174, F1-score=0.4313725490196078\n",
"Class 86: Precision=1.0, Recall=0.9534883720930233, F1-score=0.9761904761904763\n",
"Class 87: Precision=0.7, Recall=0.7241379310344828, F1-score=0.711864406779661\n",
"Class 88: Precision=0.8333333333333334, Recall=0.9375, F1-score=0.8823529411764706\n",
"Class 89: Precision=0.8, Recall=0.8888888888888888, F1-score=0.8421052631578948\n",
"Class 90: Precision=0.868421052631579, Recall=0.7857142857142857, F1-score=0.825\n",
"Class 91: Precision=0.9615384615384616, Recall=1.0, F1-score=0.9803921568627451\n",
"Class 92: Precision=0.8333333333333334, Recall=0.9302325581395349, F1-score=0.8791208791208791\n",
"Class 93: Precision=0.9090909090909091, Recall=0.8108108108108109, F1-score=0.8571428571428571\n",
"Class 94: Precision=0.925, Recall=0.925, F1-score=0.925\n",
"Class 95: Precision=0.2857142857142857, Recall=0.2222222222222222, F1-score=0.25\n",
"Class 96: Precision=0.53125, Recall=0.5666666666666667, F1-score=0.5483870967741935\n",
"Class 97: Precision=0.5555555555555556, Recall=0.5882352941176471, F1-score=0.5714285714285715\n",
"Class 98: Precision=0.78125, Recall=0.8928571428571429, F1-score=0.8333333333333334\n",
"Class 99: Precision=0.3125, Recall=0.2631578947368421, F1-score=0.2857142857142857\n",
"Class 100: Precision=0.6666666666666666, Recall=0.6666666666666666, F1-score=0.6666666666666666\n",
"Accuracy: 77.82%\n"
]
}
],
"source": [
"precision, recall, f1_score, _ = precision_recall_fscore_support(odd_image_labels, predictions, labels=range(101))\n",
"\n",
"for i in range(101):\n",
"    print(f\"Class {i}: Precision={precision[i]}, Recall={recall[i]}, F1-score={f1_score[i]}\")\n",
"\n",
"correct_predictions = sum(1 for actual, predicted in zip(odd_image_labels, predictions) if actual == predicted)\n",
"accuracy = (correct_predictions / len(odd_image_labels)) * 100.0\n",
"print(f\"Accuracy: {accuracy:.2f}%\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
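The evaluation cell above prints per-class precision, recall, and F1 plus a manually computed accuracy. A minimal sketch of how the same predictions could additionally be summarized with macro- and weighted-averaged scores; it assumes the notebook's `odd_image_labels` and `predictions` variables, and the `average=` / `zero_division=` arguments are standard scikit-learn options rather than anything the notebook itself uses:

```python
# Aggregate the per-class scores reported above into single numbers.
# Assumes odd_image_labels (ground truth) and predictions (decision tree output)
# exist as in the notebook.
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

macro_p, macro_r, macro_f1, _ = precision_recall_fscore_support(
    odd_image_labels, predictions, average="macro", zero_division=0
)
weighted_p, weighted_r, weighted_f1, _ = precision_recall_fscore_support(
    odd_image_labels, predictions, average="weighted", zero_division=0
)

print(f"Macro    P/R/F1: {macro_p:.4f} / {macro_r:.4f} / {macro_f1:.4f}")
print(f"Weighted P/R/F1: {weighted_p:.4f} / {weighted_r:.4f} / {weighted_f1:.4f}")
# accuracy_score matches the manual correct/total computation in the notebook
print(f"Accuracy: {accuracy_score(odd_image_labels, predictions) * 100:.2f}%")
```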
File diff suppressed because one or more lines are too long
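This commit also adds the binary `Phase 3/decision_tree_10_150.pkl`, whose name matches the notebook's `max_depth = 10` and `reduced_dimensionality = 150`, and the notebook imports `pickle` without using it in the cells shown. A hypothetical sketch of how such a file might be produced from the notebook's variables; unlike the commented-out cell, it fits PCA on the training (even) vectors only and reuses that projection for the test (odd) vectors:

```python
# Hypothetical sketch: reduce to 150 dimensions, train a depth-limited tree,
# and pickle it under the name added in this commit. Variable names come from
# the notebook; the cell that actually wrote the .pkl is not shown in this diff.
from sklearn.decomposition import PCA
from sklearn.tree import DecisionTreeClassifier
import pickle

pca = PCA(n_components=reduced_dimensionality)           # 150
even_reduced = pca.fit_transform(even_feature_vectors)   # fit on the training split only
odd_reduced = pca.transform(odd_feature_vectors)         # reuse the same projection

clf = DecisionTreeClassifier(max_depth=max_depth)        # 10
clf.fit(even_reduced, even_image_labels)

with open(f"decision_tree_{max_depth}_{reduced_dimensionality}.pkl", "wb") as f:
    pickle.dump({"pca": pca, "clf": clf}, f)
```

Fitting PCA separately on the odd vectors, as in the commented-out lines, would project train and test data into different spaces, so reusing the transform learned on the even split is the safer choice.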
161  Phase 3/utils.py
@@ -68,7 +68,6 @@ def getCollection(db, collection):
    client = MongoClient("mongodb://localhost:27017")
    return client[db][collection]


def euclidean_distance_measure(img_1_fd, img_2_fd):
    img_1_fd_reshaped = img_1_fd.flatten()
    img_2_fd_reshaped = img_2_fd.flatten()
@@ -86,75 +85,88 @@ valid_feature_models = {
    "resnet": "resnet_fd",
}


class Node:
    def __init__(self, feature=None, threshold=None, left=None, right=None, value=None):
        self.feature = feature # Index of feature to split on
        self.threshold = threshold # Threshold value for the feature
        self.left = left # Left child node
        self.right = right # Right child node
        self.value = value # Class label for leaf node (if applicable)

class DecisionTree:
    def __init__(self, max_depth=None):
        self.max_depth = max_depth
        self.tree = {}

    def calculate_gini(self, labels):
        classes, counts = np.unique(labels, return_counts=True)
        probabilities = counts / len(labels)
        gini = 1 - sum(probabilities ** 2)
        return gini

    def find_best_split(self, data, labels):
        best_gini = float('inf')
        best_index = None
        best_value = None

        for index in range(len(data[0])):
            unique_values = np.unique(data[:, index])
            for value in unique_values:
                left_indices = np.where(data[:, index] <= value)[0]
                right_indices = np.where(data[:, index] > value)[0]

                left_gini = self.calculate_gini(labels[left_indices])
                right_gini = self.calculate_gini(labels[right_indices])

                gini = (len(left_indices) * left_gini + len(right_indices) * right_gini) / len(data)

                if gini < best_gini:
                    best_gini = gini
                    best_index = index
                    best_value = value

        return best_index, best_value

    def build_tree(self, data, labels, depth=0):
        if len(np.unique(labels)) == 1 or (self.max_depth and depth >= self.max_depth):
            return {'class': np.argmax(np.bincount(labels))}

        best_index, best_value = self.find_best_split(data, labels)
        left_indices = np.where(data[:, best_index] <= best_value)[0]
        right_indices = np.where(data[:, best_index] > best_value)[0]

        left_subtree = self.build_tree(data[left_indices], labels[left_indices], depth + 1)
        right_subtree = self.build_tree(data[right_indices], labels[right_indices], depth + 1)

        return {'index': best_index, 'value': best_value,
                'left': left_subtree, 'right': right_subtree}

    def fit(self, data, labels):
        self.tree = self.build_tree(data, labels)

    def predict_sample(self, sample, tree):
        if 'class' in tree:
            return tree['class']
        self.max_depth = max_depth # Maximum depth of the tree
        self.tree = None # Root node of the tree

    def entropy(self, y):
        _, counts = np.unique(y, return_counts=True)
        probabilities = counts / len(y)
        return -np.sum(probabilities * np.log2(probabilities))

    def information_gain(self, X, y, feature, threshold):
        left_idxs = X[:, feature] <= threshold
        right_idxs = ~left_idxs

        if sample[tree['index']] <= tree['value']:
            return self.predict_sample(sample, tree['left'])
        left_y = y[left_idxs]
        right_y = y[right_idxs]

        p_left = len(left_y) / len(y)
        p_right = len(right_y) / len(y)

        gain = self.entropy(y) - (p_left * self.entropy(left_y) + p_right * self.entropy(right_y))
        return gain

    def find_best_split(self, X, y):
        best_gain = 0
        best_feature = None
        best_threshold = None

        for feature in range(X.shape[1]):
            thresholds = np.unique(X[:, feature])
            for threshold in thresholds:
                gain = self.information_gain(X, y, feature, threshold)
                if gain > best_gain:
                    best_gain = gain
                    best_feature = feature
                    best_threshold = threshold

        return best_feature, best_threshold

    def build_tree(self, X, y, depth=0):
        if len(np.unique(y)) == 1 or depth == self.max_depth:
            return Node(value=np.argmax(np.bincount(y)))

        best_feature, best_threshold = self.find_best_split(X, y)

        if best_feature is None:
            return Node(value=np.argmax(np.bincount(y)))

        left_idxs = X[:, best_feature] <= best_threshold
        right_idxs = ~left_idxs

        left_subtree = self.build_tree(X[left_idxs], y[left_idxs], depth + 1)
        right_subtree = self.build_tree(X[right_idxs], y[right_idxs], depth + 1)

        return Node(feature=best_feature, threshold=best_threshold, left=left_subtree, right=right_subtree)

    def fit(self, X, y):
        self.tree = self.build_tree(X, y)

    def predict_instance(self, x, node):
        if node.value is not None:
            return node.value

        if x[node.feature] <= node.threshold:
            return self.predict_instance(x, node.left)
        else:
            return self.predict_sample(sample, tree['right'])

    def predict(self, data):
        return self.predict_instance(x, node.right)

    def predict(self, X):
        predictions = []
        for sample in data:
            prediction = self.predict_sample(sample, self.tree)
            predictions.append(prediction)
        return predictions

        for x in X:
            pred = self.predict_instance(x, self.tree)
            predictions.append(pred)
        return np.array(predictions)

class LSH:
    def __init__(self, data, num_layers, num_hashes):
@@ -163,7 +175,7 @@ class LSH:
        self.num_hashes = num_hashes
        self.hash_tables = [defaultdict(list) for _ in range(num_layers)]
        self.unique_images_considered = set()
        self.overall_images_considered = set()
        self.overall_images_considered = []
        self.create_hash_tables()

    def hash_vector(self, vector, seed):
@@ -177,25 +189,32 @@ class LSH:
            hash_code = self.hash_vector(vector, seed=layer)
            self.hash_tables[layer][hash_code].append(i)

    def find_similar(self, external_image, t, threshold=0.9):
    def find_similar(self, external_image, t):
        similar_images = set()
        visited_buckets = set()
        unique_images_considered = set()
        unique_images_considered = []

        for layer in range(self.num_layers):
            hash_code = self.hash_vector(external_image, seed=layer)
            visited_buckets.add(hash_code)

            # Handling exact matches explicitly
            if hash_code in self.hash_tables[layer]:
                for idx in self.hash_tables[layer][hash_code]:
                    similar_images.add(idx)
                    unique_images_considered.append(idx)

            # Searching in nearby buckets based on Hamming distance
            for key in self.hash_tables[layer]:
                if key != hash_code and self.hamming_distance(key, hash_code) <= 2:
                if self.hamming_distance(key, hash_code) <= 1:
                    visited_buckets.add(key)

                    for idx in self.hash_tables[layer][key]:
                        similar_images.add(idx)
                        unique_images_considered.add(idx)
                        unique_images_considered.append(idx)

        self.unique_images_considered = unique_images_considered
        self.overall_images_considered = similar_images
        self.overall_images_considered = unique_images_considered
        self.unique_images_considered = set(unique_images_considered)

        similarities = [
            (idx, self.euclidean_distance(external_image, self.data[idx])) for idx in similar_images
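In the `Phase 3/utils.py` hunk above, the diff viewer interleaves the removed Gini-based `DecisionTree` (dict-valued nodes, `predict_sample`, `predict(data)`) with the added entropy-based one (`Node` objects, `information_gain`, `predict_instance`, `predict(X)`). A minimal usage sketch of the added class under that reading, with a hand check of the entropy formula it implements; the toy arrays are made up for illustration:

```python
# Toy check of the entropy / information-gain DecisionTree added in utils.py.
# Assumes utils.py exposes the new Node-based DecisionTree as read from the diff;
# labels must be small non-negative ints because build_tree uses np.bincount.
import numpy as np
from utils import DecisionTree

X = np.array([[1.0, 5.0], [2.0, 4.5], [8.0, 1.0], [9.0, 0.5]])
y = np.array([0, 0, 1, 1])

dt = DecisionTree(max_depth=3)
dt.fit(X, y)
print(dt.predict(X))   # expected: [0 0 1 1] on this separable toy data

# Hand check of the entropy used for the splits: for y above,
# H = -(0.5*log2(0.5) + 0.5*log2(0.5)) = 1.0 bit
print(dt.entropy(y))   # 1.0
```

Since the added `build_tree` stops on `depth == self.max_depth`, leaving `max_depth` at its default of `None` simply grows the tree until a node is pure or no split improves the information gain.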
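The last two hunks tighten `LSH.find_similar`: buckets within Hamming distance 2 of the query's hash code are replaced by distance <= 1, and the bookkeeping switches from sets to lists so that repeated hits across layers are kept before being de-duplicated into `unique_images_considered`. A small self-contained sketch of that nearby-bucket lookup idea over bit-string hash codes; the real `hash_vector` and `hamming_distance` implementations are not shown in this diff, so the code representation here is an assumption:

```python
# Illustration of nearby-bucket lookup by Hamming distance, in the spirit of
# LSH.find_similar. Hash codes are assumed to be fixed-length bit strings.
from collections import defaultdict

def hamming_distance(a: str, b: str) -> int:
    return sum(ch1 != ch2 for ch1, ch2 in zip(a, b))

hash_table = defaultdict(list)      # one layer: bucket code -> image indices
hash_table["1010"] = [3, 17]
hash_table["1011"] = [5]            # distance 1 from "1010"
hash_table["0011"] = [8]            # distance 2 from "1010"

query_code = "1010"
for radius in (2, 1):               # old threshold vs. new threshold
    candidates = [idx for code, bucket in hash_table.items()
                  if hamming_distance(code, query_code) <= radius
                  for idx in bucket]
    print(radius, sorted(candidates))   # radius 2 -> [3, 5, 8, 17]; radius 1 -> [3, 5, 17]
```

Tightening the radius from 2 to 1 shrinks the candidate set each layer contributes, trading some recall for fewer exact `euclidean_distance` computations in the re-ranking step that follows.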