mirror of
https://github.com/20kaushik02/CSE515_MWDB_Project.git
synced 2025-12-06 10:34:07 +00:00
158 lines
3.6 KiB
Plaintext
158 lines
3.6 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Getting started"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import json\n",
|
|
"import numpy as np\n",
|
|
"\n",
|
|
"import torch\n",
|
|
"import torchvision.transforms as transforms\n",
|
|
"import torchvision.models as models\n",
|
|
"from torchinfo import summary\n",
|
|
"\n",
|
|
"\n",
|
|
"from PIL import Image\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"\n",
|
|
"import warnings\n",
|
|
"warnings.filterwarnings('ignore')\n",
|
|
"%matplotlib inline\n",
|
|
"\n",
|
|
"# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.\n",
|
|
"dev = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
|
"print(f'Using {dev} for inference')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Load Caltech101 dataset"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torchvision.datasets as datasets\n",
|
|
"\n",
|
|
"# Raw string so the backslashes in the Windows path are never parsed as escape sequences.\n",
|
|
"dataset_path = r\"C:\\Kaushik\\ASU\\CSE 515 - Multimedia and Web Databases\\Project\\Datasets\"\n",
|
|
"\n",
|
|
"# Reuse dataset_path instead of repeating the literal; download=False means nothing is fetched,\n",
|
|
"# so the Caltech101 files must already be present under that root.\n",
|
|
"dataset = datasets.Caltech101(root=dataset_path, download=False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Visualize a sample image from the dataset"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import random\n",
|
|
"\n",
|
|
"# randrange excludes its upper bound, so the index is always within the dataset;\n",
|
|
"# randint(0, len(dataset)) is inclusive and could pick len(dataset), raising IndexError.\n",
|
|
"sample_image, _ = dataset[random.randrange(len(dataset))]\n",
|
|
"plt.imshow(sample_image)\n",
|
|
"plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### ResNet50 - Example classification"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Step 1: Load model\n",
|
|
"from torchvision.models import ResNet50_Weights\n",
|
|
"\n",
|
|
"weights = ResNet50_Weights.DEFAULT\n",
|
|
"# Pass weights by keyword: torchvision deprecated positional use of this argument.\n",
|
|
"model = models.resnet50(weights=weights)\n",
|
|
"\n",
|
|
"# .to(dev) is a no-op when dev is the CPU, so no separate CUDA check is needed.\n",
|
|
"model = model.to(dev)\n",
|
|
"\n",
|
|
"# eval() puts the network in inference mode; as the cell's last expression it also displays the model.\n",
|
|
"model.eval()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from torchvision.io import read_image\n",
|
|
"\n",
|
|
"# Step 2: Initialize the inference transforms\n",
|
|
"preprocess = weights.transforms()\n",
|
|
"\n",
|
|
"# Some Caltech101 images are grayscale; the ResNet50 preprocessing and model expect\n",
|
|
"# 3 channels, so force RGB (a no-op for images that are already RGB).\n",
|
|
"img = sample_image.convert(\"RGB\")\n",
|
|
"# Step 3: Apply inference preprocessing transforms\n",
|
|
"batch = preprocess(img).unsqueeze(0)\n",
|
|
"\n",
|
|
"# Move the batch to the selected device (CUDA when available, otherwise CPU)\n",
|
|
"batch = batch.to(dev)\n",
|
|
"# Step 4: Use the model and print the predicted category\n",
|
|
"\n",
|
|
"# Inference only: no_grad avoids building the autograd graph for the forward pass.\n",
|
|
"with torch.no_grad():\n",
|
|
"    prediction = model(batch).squeeze(0).softmax(0)\n",
|
|
"class_id = prediction.argmax().item()\n",
|
|
"score = prediction[class_id].item()\n",
|
|
"category_name = weights.meta[\"categories\"][class_id]\n",
|
|
"print(f\"{category_name}: {100 * score:.1f}%\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.5"
|
|
},
|
|
"orig_nbformat": 4
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|