mirror of
https://github.com/20kaushik02/CSE515_MWDB_Project.git
synced 2025-12-06 07:54:07 +00:00
display images task 4
This commit is contained in:
parent
e7777526b0
commit
e97be19053
File diff suppressed because one or more lines are too long
@ -41,6 +41,22 @@ import matplotlib.pyplot as plt
|
||||
NUM_LABELS = 101
|
||||
NUM_IMAGES = 4338
|
||||
|
||||
def datasetTransform(image):
|
||||
"""Transform while loading dataset as scaled tensors of shape (channels, (img_shape))"""
|
||||
return transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor() # ToTensor by default scales to [0,1] range, the input range for ResNet
|
||||
]
|
||||
)(image)
|
||||
|
||||
def loadDataset(dataset):
|
||||
"""Load TorchVision dataset with the defined transform"""
|
||||
return dataset(
|
||||
root=getenv("DATASET_PATH"),
|
||||
download=False, # True if you wish to download for first time
|
||||
transform=datasetTransform,
|
||||
)
|
||||
|
||||
valid_classification_methods = {
|
||||
"m-nn": 1,
|
||||
"decision-tree": 2,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user