mirror of
https://github.com/20kaushik02/CSE515_MWDB_Project.git
synced 2025-12-06 09:34:07 +00:00
display images task 4
This commit is contained in:
parent
e7777526b0
commit
e97be19053
File diff suppressed because one or more lines are too long
@ -41,6 +41,22 @@ import matplotlib.pyplot as plt
|
|||||||
NUM_LABELS = 101
|
NUM_LABELS = 101
|
||||||
NUM_IMAGES = 4338
|
NUM_IMAGES = 4338
|
||||||
|
|
||||||
|
def datasetTransform(image):
|
||||||
|
"""Transform while loading dataset as scaled tensors of shape (channels, (img_shape))"""
|
||||||
|
return transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.ToTensor() # ToTensor by default scales to [0,1] range, the input range for ResNet
|
||||||
|
]
|
||||||
|
)(image)
|
||||||
|
|
||||||
|
def loadDataset(dataset):
|
||||||
|
"""Load TorchVision dataset with the defined transform"""
|
||||||
|
return dataset(
|
||||||
|
root=getenv("DATASET_PATH"),
|
||||||
|
download=False, # True if you wish to download for first time
|
||||||
|
transform=datasetTransform,
|
||||||
|
)
|
||||||
|
|
||||||
valid_classification_methods = {
|
valid_classification_methods = {
|
||||||
"m-nn": 1,
|
"m-nn": 1,
|
||||||
"decision-tree": 2,
|
"decision-tree": 2,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user