display images task 4

This commit is contained in:
pranavbrkr 2023-11-27 19:16:32 -07:00
parent e7777526b0
commit e97be19053
2 changed files with 109 additions and 17 deletions

File diff suppressed because one or more lines are too long

View File

@ -41,6 +41,22 @@ import matplotlib.pyplot as plt
NUM_LABELS = 101
NUM_IMAGES = 4338
def datasetTransform(image):
"""Transform while loading dataset as scaled tensors of shape (channels, (img_shape))"""
return transforms.Compose(
[
transforms.ToTensor() # ToTensor by default scales to [0,1] range, the input range for ResNet
]
)(image)
def loadDataset(dataset):
"""Load TorchVision dataset with the defined transform"""
return dataset(
root=getenv("DATASET_PATH"),
download=False, # True if you wish to download for first time
transform=datasetTransform,
)
valid_classification_methods = {
"m-nn": 1,
"decision-tree": 2,