Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 49 additions & 10 deletions mapreader/classify/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,33 @@ def __len__(self) -> int:
"""
return len(self.patch_df)

def get_patch(self,idx : int | torch.Tensor) -> Image:
"""
Return the image at the given index.

Parameters
----------
idx : int or Tensor
The index of the desired image, or a Tensor containing the index.

Returns
-------
PIL.Image.Image
The transformed image at the given index.
"""
if torch.is_tensor(idx):
idx = idx.tolist()

img_path = self.patch_df.iloc[idx][self.patch_paths_col]

if os.path.exists(img_path):
img = Image.open(img_path).convert(self.image_mode)
else:
raise ValueError(
f'[ERROR] "{img_path} cannot be found.\n\n\'Please check the image exists, your file paths are correct and that ``.patch_paths_col`` is set to the correct column.'
)
return self.transform(img)

def __getitem__(
self, idx: int | torch.Tensor
) -> tuple[tuple[torch.Tensor], str, int]:
Expand All @@ -203,17 +230,8 @@ def __getitem__(
if torch.is_tensor(idx):
idx = idx.tolist()

img_path = self.patch_df.iloc[idx][self.patch_paths_col]

if os.path.exists(img_path):
img = Image.open(img_path).convert(self.image_mode)
else:
raise ValueError(
f'[ERROR] "{img_path} cannot be found.\n\n\
Please check the image exists, your file paths are correct and that ``.patch_paths_col`` is set to the correct column.'
)

img = self.transform(img)
img = self.get_patch(idx)

if self.label_col in self.patch_df.iloc[idx].keys():
image_label = self.patch_df.iloc[idx][self.label_col]
Expand Down Expand Up @@ -844,3 +862,24 @@ def __getitem__(
image_label_index = -1

return (context_img,), image_label, image_label_index

class PatchFromImageDataset(PatchDataset):
# overwrite the __getitem__ method to return the image patch
def __init__(self, args*, **kwargs):
super().__init__(args*, **kwargs)
#if "method" in kwargs:
# self.method = kwargs["method"]
#else:
# self.method = "pixels"
def get_patch(self, idx):
# load original image
img= self.return_orig_image(idx)

# use self.patch_df to find which part of the image to return
bounds=self.patch_df.iloc[idx]["pixel_bounds"]
img= img.crop(bounds)


return self.transform(img)


Loading