mirror of
https://github.com/20kaushik02/CSE515_MWDB_Project.git
synced 2025-12-06 12:44:06 +00:00
starting out
This commit is contained in:
commit
c518f2090e
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
Datasets/
|
||||
BIN
phase1_project23.pdf
Normal file
BIN
phase1_project23.pdf
Normal file
Binary file not shown.
157
test.ipynb
Normal file
157
test.ipynb
Normal file
@ -0,0 +1,157 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Getting started"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"import torchvision.transforms as transforms\n",
|
||||
"import torchvision.models as models\n",
|
||||
"from torchinfo import summary\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"from PIL import Image\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"import warnings\n",
|
||||
"warnings.filterwarnings('ignore')\n",
|
||||
"%matplotlib inline\n",
|
||||
"\n",
|
||||
"dev = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
|
||||
"print(f'Using {dev} for inference')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Load Caltech101 dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torchvision.datasets as datasets\n",
|
||||
"\n",
|
||||
"dataset_path = \"C:\\Kaushik\\ASU\\CSE 515 - Multimedia and Web Databases\\Project\\Datasets\"\n",
|
||||
"\n",
|
||||
"dataset = datasets.Caltech101(root=\"C:\\Kaushik\\ASU\\CSE 515 - Multimedia and Web Databases\\Project\\Datasets\", download=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Visualize a sample image from the dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import random\n",
|
||||
"\n",
|
||||
"sample_image, _ = dataset.__getitem__(random.randint(0,len(dataset)))\n",
|
||||
"plt.imshow(sample_image)\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### ResNet50 - Example classification"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Step 1: Load model\n",
|
||||
"from torchvision.models import ResNet50_Weights\n",
|
||||
"\n",
|
||||
"weights = ResNet50_Weights.DEFAULT\n",
|
||||
"model = models.resnet50(weights)\n",
|
||||
"\n",
|
||||
"if(torch.cuda.is_available()):\n",
|
||||
" model = model.to(dev)\n",
|
||||
"\n",
|
||||
"model.eval()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from torchvision.io import read_image\n",
|
||||
"\n",
|
||||
"# Step 2: Initialize the inference transforms\n",
|
||||
"preprocess = weights.transforms()\n",
|
||||
"\n",
|
||||
"img = sample_image\n",
|
||||
"# Step 3: Apply inference preprocessing transforms\n",
|
||||
"batch = preprocess(img).unsqueeze(0)\n",
|
||||
"\n",
|
||||
"# (convert to CUDA tensor)\n",
|
||||
"batch = batch.to(dev)\n",
|
||||
"# Step 4: Use the model and print the predicted category\n",
|
||||
"\n",
|
||||
"prediction = model(batch).squeeze(0).softmax(0)\n",
|
||||
"class_id = prediction.argmax().item()\n",
|
||||
"score = prediction[class_id].item()\n",
|
||||
"category_name = weights.meta[\"categories\"][class_id]\n",
|
||||
"print(f\"{category_name}: {100 * score:.1f}%\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.5"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user