Commit 824bb396 authored by cc215's avatar cc215 💬
Browse files

add LRU cache support to avoid memory overload when the training dataset is too large

parent 62f69885
from collections import OrderedDict
from collections.abc import MutableMapping
class Cache(MutableMapping):
    """Dictionary-like cache with a limited maximum capacity.

    This is a simplified LRU (least recently used) caching scheme: when the
    cache is full and a key that is not yet stored is inserted, the least
    recently used entry is removed to make room for the new one. Reading an
    entry (``cache[key]``, ``key in cache``, ``cache.get(key)``) and
    overwriting an existing entry both count as a "use" and mark that entry
    as the most recently used.

    Source: https://stackoverflow.com/questions/2437617/how-to-limit-the-size-of-a-dictionary

    :param maxlen: maximum number of entries kept. ``0`` means nothing is
        retained; ``None`` or a negative value leaves the cache unbounded.
    :param items: optional iterable of ``(key, value)`` pairs used to
        pre-populate the cache, inserted in iteration order.
    """

    def __init__(self, maxlen, items=None):
        self._maxlen = maxlen
        # Entries ordered from least recently used (first) to most recently
        # used (last).
        self.d = OrderedDict()
        if items:
            for k, v in items:
                self[k] = v

    @property
    def maxlen(self):
        """Maximum number of entries the cache retains (read-only)."""
        return self._maxlen

    def __getitem__(self, key):
        """Return the value stored under ``key`` and mark it most recently used.

        :raises KeyError: if ``key`` is not in the cache.
        """
        value = self.d[key]
        self.d.move_to_end(key)
        return value

    def __setitem__(self, key, value):
        """Store ``value`` under ``key``, evicting old entries when full."""
        if key in self.d:
            # Overwriting never grows the cache; only refresh the recency.
            self.d.move_to_end(key)
        else:
            limit = self.maxlen
            if limit is not None and limit >= 0:
                if limit == 0:
                    # A zero-capacity cache retains nothing; popping from the
                    # empty store would raise KeyError.
                    return
                # ``>=`` in a loop (rather than a single ``==`` check) also
                # recovers if the store was ever filled past the limit.
                while len(self.d) >= limit:
                    self.d.popitem(last=False)
        self.d[key] = value

    def __delitem__(self, key):
        """Remove ``key`` from the cache.

        :raises KeyError: if ``key`` is not in the cache.
        """
        del self.d[key]

    def __iter__(self):
        """Iterate over the keys, least recently used first.

        A snapshot of the keys is iterated because the inherited ``items()``
        and ``values()`` views look every key up through ``__getitem__``,
        which reorders the underlying ``OrderedDict``; iterating it directly
        would fail with "OrderedDict mutated during iteration".
        """
        return iter(list(self.d))

    def __len__(self):
        """Return the number of entries currently stored."""
        return len(self.d)
\ No newline at end of file
......@@ -9,11 +9,11 @@ from torch.utils.data import Dataset
import numpy as np
from common_utils.basic_operations import switch_kv_in_dict
from common_utils.data_structure import Cache
class BaseSegDataset(Dataset):
def __init__(self, dataset_name, transform, no_aug_transform,image_size, label_size, idx2cls_dict=None, num_classes=2,
use_cache=False, formalized_label_dict=None,keep_orig_image_label_pair=False):
use_cache=False, formalized_label_dict=None,keep_orig_image_label_pair=False,maximum_cache_size=20000):
'''
:param dataset_name:
......@@ -47,7 +47,7 @@ class BaseSegDataset(Dataset):
# num_classes)
self.formalized_label_dict = self.idx2cls_dict if formalized_label_dict is None else formalized_label_dict
self.use_cache = use_cache
self.cache_dict = {}
self.cache_dict = Cache(maxlen=maximum_cache_size)
self.index = 0
self.voxelspacing = [1., 1., 1.]
self.keep_orig_image_label_pair=keep_orig_image_label_pair
......
......@@ -321,7 +321,7 @@ if __name__ == '__main__':
pad_size = (256, 256, 1)
crop_size = (192, 192, 1)
tr = Transformations(data_aug_policy_name='UKBB_advancedv4', pad_size=pad_size, crop_size=crop_size).get_transformation()
dataset = CardiacUKBBDataset(debug=True,transform=tr['train'], if_resample=True,
dataset = CardiacUKBBDataset(debug=True,transform=tr['train'], if_resample=True, use_cache=True,
no_aug_transform=tr['validate'],formalized_label_dict={0: 'BG', 1: 'LV',2:'MYO',3:'RV'})
train_loader = DataLoader(dataset=dataset, num_workers=0, batch_size=1, shuffle=True, drop_last=True)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment