handle ls2

This commit is contained in:
Niraj Sonje 2023-10-13 20:50:03 -07:00
parent c652a6606e
commit b657e75c42

View File

@ -2,7 +2,7 @@
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": 13,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -21,7 +21,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 16, "execution_count": 14,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -35,7 +35,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 17, "execution_count": 15,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -45,14 +45,14 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 18, "execution_count": 16,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"label_sim-cm_fd-kmeans-10-semantics.json loaded\n" "cm_fd-cp-10-semantics.json loaded\n"
] ]
} }
], ],
@ -69,6 +69,7 @@
"if k < 1:\n", "if k < 1:\n",
" raise ValueError(\"k should be a positive integer\")\n", " raise ValueError(\"k should be a positive integer\")\n",
"\n", "\n",
"if selected_latent_space != 'cp':\n",
" selected_dim_reduction_method = str(\n", " selected_dim_reduction_method = str(\n",
" input(\n", " input(\n",
" \"Enter dimensionality reduction method - one of \"\n", " \"Enter dimensionality reduction method - one of \"\n",
@ -109,7 +110,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 19, "execution_count": 17,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -187,7 +188,39 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 20, "execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"def extract_similarities_ls2(data, label):\n",
"\n",
" LS_f = np.array(data[\"feature-semantic\"])\n",
" LS_i = np.array(data[\"image-semantic\"])\n",
" S = np.array(data[\"semantics-core\"])\n",
"\n",
" if len(S.shape) == 1:\n",
" S = np.diag(S)\n",
"\n",
" label_rep = calculate_label_representatives(fd_collection, label, selected_feature_model)\n",
" comparison_feature_space = np.matmul(label_rep, LS_f)\n",
" comparison_vector = np.matmul(comparison_feature_space, S)\n",
"\n",
" comparison_image_space = np.matmul(LS_i, S)\n",
" distances = []\n",
"\n",
" n = len(comparison_image_space)\n",
" for i in range(n):\n",
" distances.append({\"image\": i, \"distance\": math.dist(comparison_vector, comparison_image_space[i])})\n",
" \n",
" distances = sorted(distances, key=lambda x: x[\"distance\"], reverse=False)[:knum]\n",
"\n",
" for x in distances:\n",
" print(x)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -237,18 +270,18 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 21, "execution_count": 20,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"{'image_id': 2309, 'distance': 4.117664288663269}\n", "{'image': 823, 'distance': 4006.335159603778}\n",
"{'image_id': 1930, 'distance': 4.117664288663269}\n", "{'image': 809, 'distance': 4006.3942621209867}\n",
"{'image_id': 1940, 'distance': 4.117664288663269}\n", "{'image': 806, 'distance': 4006.421689986329}\n",
"{'image_id': 1929, 'distance': 4.117664288663269}\n", "{'image': 832, 'distance': 4006.422683206996}\n",
"{'image_id': 2250, 'distance': 4.117664288663269}\n" "{'image': 830, 'distance': 4006.44733072835}\n"
] ]
} }
], ],
@ -262,6 +295,10 @@
" case \"label_sim\":\n", " case \"label_sim\":\n",
"\n", "\n",
" extract_similarities_ls3(selected_dim_reduction_method, data, label)\n", " extract_similarities_ls3(selected_dim_reduction_method, data, label)\n",
"\n",
" case \"cp\":\n",
"\n",
" extract_similarities_ls2(data, label)\n",
" " " "
] ]
} }