diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py index 3fb9da5f2c6782d98224784f97cef7a6a05fa5a0..4b7cd86ab04508bf77f69a92e5615f9faaeebc03 100644 --- a/MetaAugment/autoaugment_learners/aa_learner.py +++ b/MetaAugment/autoaugment_learners/aa_learner.py @@ -59,7 +59,7 @@ class aa_learner: self.history = [] - def translate_operation_tensor(self, operation_tensor): + def translate_operation_tensor(self, operation_tensor, argmax=False): ''' takes in a tensor representing an operation and returns an actual operation which is in the form of: @@ -69,12 +69,16 @@ class aa_learner: Args: operation_tensor (tensor): - - If discrete_p_m is True, we expect to take in a tensor with + We expect this tensor to already have been softmaxed. + Furthermore, + - If self.discrete_p_m is True, we expect to take in a tensor with dimension (self.fun_num + self.p_bins + self.m_bins) - - If discrete_p_m is False, we expect to take in a tensor with + - If self.discrete_p_m is False, we expect to take in a tensor with dimension (self.fun_num + 1 + 1) - continuous_p_m (boolean): whether the operation_tensor has continuous representations - of probability and magnitude + + argmax (boolean): + Whether we are taking the argmax of the softmaxed tensors. + If this is False, we treat the softmaxed outputs as multinomial pdf's. ''' # if probability and magnitude are represented as discrete variables if self.discrete_p_m: @@ -82,9 +86,23 @@ class aa_learner: prob_t = operation_tensor[self.fun_num : self.fun_num+self.p_bins] mag_t = operation_tensor[-self.m_bins : ] - fun = torch.argmax(fun_t) - prob = torch.argmax(prob_t) # 0 <= p <= 10 - mag = torch.argmax(mag_t) # 0 <= m <= 9 + # make sure they are of right size + assert fun_t.shape==(self.fun_num,), f'{fun_t.shape} != {self.fun_num}' + assert prob_t.shape==(self.p_bins,), f'{prob_t.shape} != {self.p_bins}' + assert mag_t.shape==(self.m_bins,), f'{mag_t.shape} != {self.m_bins}' + + if argmax==True: + fun = torch.argmax(fun_t) + prob = torch.argmax(prob_t) # 0 <= p <= 10 + mag = torch.argmax(mag_t) # 0 <= m <= 9 + elif argmax==False: + # we need these to add up to 1 to be valid pdf's of multinomials + assert torch.sum(fun_t)==1 + assert torch.sum(prob_t)==1 + assert torch.sum(mag_t)==1 + fun = torch.multinomial(fun_t, 1) # 0 <= fun <= self.fun_num-1 + prob = torch.multinomial(prob_t, 1) # 0 <= p <= 10 + mag = torch.multinomial(mag_t, 1) # 0 <= m <= 9 function = augmentation_space[fun][0] prob = prob/10 @@ -96,17 +114,31 @@ class aa_learner: p = operation_tensor[-2].item() # 0 < p < 1 m = operation_tensor[-1].item() # 0 < m < 9 - fun = torch.argmax(fun_t) - + # make sure the shape is correct + assert fun_t.shape==(self.fun_num,), f'{fun_t.shape} != {self.fun_num}' + + if argmax==True: + fun = torch.argmax(fun_t) + elif argmax==False: + assert torch.sum(fun_t)==1 + fun = torch.multinomial(fun_t, 1) + function = augmentation_space[fun][0] prob = round(p, 1) # round to nearest first decimal digit mag = round(m) # round to nearest integer + # If argmax is False, we treat operation_tensor as a concatenation of three + # multinomial pdf's. + assert 0 <= prob <= 1 + assert 0 <= mag <= self.m_bins-1 + # if the image function does not require a magnitude, we set the magnitude to None - if augmentation_space[fun][0] == True: # if the image function has a magnitude + if augmentation_space[fun][1] == True: # if the image function has a magnitude return (function, prob, mag) else: return (function, prob, None) + + def generate_new_policy(self): diff --git a/MetaAugment/autoaugment_learners/randomsearch_learner.py b/MetaAugment/autoaugment_learners/randomsearch_learner.py index da76cf0b16bdd829e19fbbd649b37d5577df586a..3657224f7bb538c0c436c3f12ff451ce1ccc2c1b 100644 --- a/MetaAugment/autoaugment_learners/randomsearch_learner.py +++ b/MetaAugment/autoaugment_learners/randomsearch_learner.py @@ -63,11 +63,11 @@ class randomsearch_learner(aa_learner): random_mag = np.random.randint(0, self.m_bins) fun_t= torch.zeros(self.fun_num) - fun_t[random_fun] = 1 + fun_t[random_fun] = 1.0 prob_t = torch.zeros(self.p_bins) - prob_t[random_prob] = 1 + prob_t[random_prob] = 1.0 mag_t = torch.zeros(self.m_bins) - mag_t[random_mag] = 1 + mag_t[random_mag] = 1.0 return torch.cat([fun_t, prob_t, mag_t]) @@ -152,10 +152,10 @@ if __name__=='__main__': # We can initialize the train_dataset with its transform as None. # Later on, we will change this object's transform attribute to the policy # that we want to test - train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, - transform=None) - test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False, - transform=torchvision.transforms.ToTensor()) + train_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/train', + train=True, download=True, transform=None) + test_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/test', + train=False, download=True, transform=torchvision.transforms.ToTensor()) child_network = cn.lenet