From 7c0c4a183830ed0abb096e57223d3198c18daae3 Mon Sep 17 00:00:00 2001 From: Kaushik Narayan R Date: Sat, 14 Oct 2023 12:30:18 -0700 Subject: [PATCH] nmf update, use own svd nmf --- .gitignore | 3 ++- Phase 2/utils.py | 60 ++++++++++++++++++------------------------------ 2 files changed, 24 insertions(+), 39 deletions(-) diff --git a/.gitignore b/.gitignore index 8bf7c7d..74f653e 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ Datasets/ Other code/ *.zip *.env -__pycache__ \ No newline at end of file +__pycache__ +*.json diff --git a/Phase 2/utils.py b/Phase 2/utils.py index 263cc17..5153773 100644 --- a/Phase 2/utils.py +++ b/Phase 2/utils.py @@ -5,10 +5,9 @@ import random import cv2 import numpy as np from scipy.stats import pearsonr -from scipy.sparse.linalg import svds -from sklearn.decomposition import NMF +# from scipy.sparse.linalg import svds +# from sklearn.decomposition import NMF from sklearn.decomposition import LatentDirichletAllocation -from sklearn.discriminant_analysis import LinearDiscriminantAnalysis # from sklearn.cluster import KMeans @@ -286,7 +285,7 @@ def resnet_output(image): with torch.no_grad(): features = model(resized_image) features = torch.nn.Softmax()(features) - + return features.detach().cpu().tolist() @@ -817,17 +816,25 @@ def svd(matrix, k): return left_singular_vectors, np.diag(singular_values), right_singular_vectors.T -def nmf(matrix, k, num_iterations=100): + +def nmf(matrix, k, H=None, update_H=True, num_iterations=100): + """ + Non-negative matrix factorization by multiplicative update + + Pass `H` and `update_H=False` to transform given data as per the given H matrix, else leave `H=None` and `update_H=True` to fit and transform + """ d1, d2 = matrix.shape # Initialize W and H matrices with random non-negative values W = np.random.rand(d1, k) - H = np.random.rand(k, d2) + if update_H is True: + H = np.random.rand(k, d2) for iteration in range(num_iterations): - # Update H matrix - numerator_h = np.dot(W.T, matrix) - denominator_h = np.dot(np.dot(W.T, W), H) - H *= numerator_h / denominator_h + if update_H is True: + # Update H matrix + numerator_h = np.dot(W.T, matrix) + denominator_h = np.dot(np.dot(W.T, W), H) + H *= numerator_h / denominator_h # Update W matrix numerator_w = np.dot(matrix, H.T) @@ -836,6 +843,7 @@ def nmf(matrix, k, num_iterations=100): return W, H + def extract_latent_semantics_from_feature_model( fd_collection, k, @@ -879,9 +887,8 @@ def extract_latent_semantics_from_feature_model( match valid_dim_reduction_methods[dim_reduction_method]: # singular value decomposition - # sparse version of SVD to get only k singular values case 1: - U, S, V_T = svds(feature_vectors, k=k) + U, S, V_T = svd(feature_vectors, k=k) all_latent_semantics = { "image-semantic": U.tolist(), @@ -906,18 +913,7 @@ def extract_latent_semantics_from_feature_model( min_value = np.min(feature_vectors) feature_vectors_shifted = feature_vectors - min_value - model = NMF( - n_components=k, - init="random", - solver="cd", - alpha_H=0.01, - alpha_W=0.01, - max_iter=10000, - ) - model.fit(feature_vectors_shifted) - - W = model.transform(feature_vectors_shifted) - H = model.components_ + W, H = nmf(feature_vectors_shifted, k) all_latent_semantics = { "image-semantic": W.tolist(), @@ -1053,9 +1049,8 @@ def extract_latent_semantics_from_sim_matrix( match valid_dim_reduction_methods[dim_reduction_method]: # singular value decomposition - # sparse version of SVD to get only k singular values case 1: - U, S, V_T = svds(feature_vectors, k=k) + U, S, V_T = svd(feature_vectors, k=k) all_latent_semantics = { "image-semantic": U.tolist(), @@ -1080,18 +1075,7 @@ def extract_latent_semantics_from_sim_matrix( min_value = np.min(feature_vectors) feature_vectors_shifted = feature_vectors - min_value - model = NMF( - n_components=k, - init="random", - solver="cd", - alpha_H=0.01, - alpha_W=0.01, - max_iter=10000, - ) - model.fit(feature_vectors_shifted) - - W = model.transform(feature_vectors_shifted) - H = model.components_ + W, H = nmf(feature_vectors_shifted, k) all_latent_semantics = { "image-semantic": W.tolist(),