CSE515_MWDB_Project/test.ipynb
2023-09-02 17:28:29 -07:00

158 lines
3.6 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Getting started"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import numpy as np\n",
"\n",
"import torch\n",
"import torchvision.transforms as transforms\n",
"import torchvision.models as models\n",
"from torchinfo import summary\n",
"\n",
"\n",
"from PIL import Image\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import warnings\n",
"warnings.filterwarnings('ignore')\n",
"%matplotlib inline\n",
"\n",
"# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.\n",
"dev = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f'Using {dev} for inference')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load Caltech101 dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import torchvision.datasets as datasets\n",
"\n",
"# Root directory handed to torchvision's Caltech101; override it with the\n",
"# CALTECH101_ROOT environment variable instead of editing this cell.\n",
"# The raw string keeps the Windows backslashes literal: in a plain string,\n",
"# sequences such as \"\\K\" and \"\\P\" are invalid escapes (SyntaxWarning on Python 3.12+).\n",
"dataset_path = os.environ.get(\n",
"    \"CALTECH101_ROOT\",\n",
"    r\"C:\\Kaushik\\ASU\\CSE 515 - Multimedia and Web Databases\\Project\\Datasets\",\n",
")\n",
"\n",
"# download=False: the dataset must already be present under dataset_path.\n",
"dataset = datasets.Caltech101(root=dataset_path, download=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Visualize a sample image from the dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"\n",
"# randrange(n) draws from [0, n); random.randint(0, n) would also include n,\n",
"# which is one past the last valid index and raises IndexError.\n",
"sample_index = random.randrange(len(dataset))\n",
"sample_image, sample_label = dataset[sample_index]\n",
"\n",
"# cmap only affects single-channel (grayscale) images; RGB images ignore it.\n",
"plt.imshow(sample_image, cmap=\"gray\")\n",
"plt.title(f\"#{sample_index}: {dataset.categories[sample_label]}\")\n",
"plt.axis(\"off\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ResNet50 - Example classification"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Step 1: Load model\n",
"from torchvision.models import ResNet50_Weights\n",
"\n",
"# DEFAULT is torchvision's alias for the recommended pretrained ResNet50 weights.\n",
"weights = ResNet50_Weights.DEFAULT\n",
"# `weights` is keyword-only since torchvision 0.13; passing it positionally only\n",
"# works through a deprecated compatibility shim (its warning is silenced by the\n",
"# notebook's warnings filter, so the problem would go unnoticed).\n",
"model = models.resnet50(weights=weights)\n",
"\n",
"# .to(dev) works for both CPU and CUDA, so no availability check is needed here.\n",
"model = model.to(dev)\n",
"\n",
"# Evaluation mode: disables dropout and makes BatchNorm use running statistics.\n",
"model.eval()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Step 2: Initialize the inference transforms that match the chosen weights\n",
"preprocess = weights.transforms()\n",
"\n",
"# Some Caltech101 images are grayscale (single channel), but the weights'\n",
"# normalization uses 3-channel mean/std, so convert to RGB before preprocessing.\n",
"img = sample_image.convert(\"RGB\")\n",
"\n",
"# Step 3: Apply inference preprocessing and add a leading batch dimension\n",
"batch = preprocess(img).unsqueeze(0)\n",
"\n",
"# Move the batch to the same device as the model (CPU or CUDA)\n",
"batch = batch.to(dev)\n",
"\n",
"# Step 4: Use the model and print the predicted category.\n",
"# no_grad() avoids building an autograd graph that prediction never uses.\n",
"with torch.no_grad():\n",
"    prediction = model(batch).squeeze(0).softmax(0)\n",
"class_id = prediction.argmax().item()\n",
"score = prediction[class_id].item()\n",
"category_name = weights.meta[\"categories\"][class_id]\n",
"print(f\"{category_name}: {100 * score:.1f}%\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}