Skip to content
Snippets Groups Projects
Commit 92d97847 authored by Sun Jin Kim
Browse files

translate_operation_tensor can now translate probabilistically, treating each softmaxed segment of the tensor as a multinomial probability distribution to sample from.

parent 70d8d4de
No related branches found
No related tags found
No related merge requests found
......@@ -59,7 +59,7 @@ class aa_learner:
self.history = []
def translate_operation_tensor(self, operation_tensor):
def translate_operation_tensor(self, operation_tensor, argmax=False):
'''
takes in a tensor representing an operation and returns an actual operation which
is in the form of:
......@@ -69,12 +69,16 @@ class aa_learner:
Args:
operation_tensor (tensor):
- If discrete_p_m is True, we expect to take in a tensor with
We expect this tensor to already have been softmaxed.
Furthermore,
- If self.discrete_p_m is True, we expect to take in a tensor with
dimension (self.fun_num + self.p_bins + self.m_bins)
- If discrete_p_m is False, we expect to take in a tensor with
- If self.discrete_p_m is False, we expect to take in a tensor with
dimension (self.fun_num + 1 + 1)
continuous_p_m (boolean): whether the operation_tensor has continuous representations
of probability and magnitude
argmax (boolean):
Whether we are taking the argmax of the softmaxed tensors.
If this is False, we treat the softmaxed outputs as multinomial pdf's.
'''
# if probability and magnitude are represented as discrete variables
if self.discrete_p_m:
......@@ -82,9 +86,23 @@ class aa_learner:
prob_t = operation_tensor[self.fun_num : self.fun_num+self.p_bins]
mag_t = operation_tensor[-self.m_bins : ]
fun = torch.argmax(fun_t)
prob = torch.argmax(prob_t) # 0 <= p <= 10
mag = torch.argmax(mag_t) # 0 <= m <= 9
# make sure they are of right size
assert fun_t.shape==(self.fun_num,), f'{fun_t.shape} != {self.fun_num}'
assert prob_t.shape==(self.p_bins,), f'{prob_t.shape} != {self.p_bins}'
assert mag_t.shape==(self.m_bins,), f'{mag_t.shape} != {self.m_bins}'
if argmax==True:
fun = torch.argmax(fun_t)
prob = torch.argmax(prob_t) # 0 <= p <= 10
mag = torch.argmax(mag_t) # 0 <= m <= 9
elif argmax==False:
# we need these to add up to 1 to be valid pdf's of multinomials
assert torch.sum(fun_t)==1
assert torch.sum(prob_t)==1
assert torch.sum(mag_t)==1
fun = torch.multinomial(fun_t, 1) # 0 <= fun <= self.fun_num-1
prob = torch.multinomial(prob_t, 1) # 0 <= p <= 10
mag = torch.multinomial(mag_t, 1) # 0 <= m <= 9
function = augmentation_space[fun][0]
prob = prob/10
......@@ -96,17 +114,31 @@ class aa_learner:
p = operation_tensor[-2].item() # 0 < p < 1
m = operation_tensor[-1].item() # 0 < m < 9
fun = torch.argmax(fun_t)
# make sure the shape is correct
assert fun_t.shape==(self.fun_num,), f'{fun_t.shape} != {self.fun_num}'
if argmax==True:
fun = torch.argmax(fun_t)
elif argmax==False:
assert torch.sum(fun_t)==1
fun = torch.multinomial(fun_t, 1)
function = augmentation_space[fun][0]
prob = round(p, 1) # round to nearest first decimal digit
mag = round(m) # round to nearest integer
# If argmax is False, we treat operation_tensor as a concatenation of three
# multinomial pdf's.
assert 0 <= prob <= 1
assert 0 <= mag <= self.m_bins-1
# if the image function does not require a magnitude, we set the magnitude to None
if augmentation_space[fun][0] == True: # if the image function has a magnitude
if augmentation_space[fun][1] == True: # if the image function has a magnitude
return (function, prob, mag)
else:
return (function, prob, None)
def generate_new_policy(self):
......
......@@ -63,11 +63,11 @@ class randomsearch_learner(aa_learner):
random_mag = np.random.randint(0, self.m_bins)
fun_t= torch.zeros(self.fun_num)
fun_t[random_fun] = 1
fun_t[random_fun] = 1.0
prob_t = torch.zeros(self.p_bins)
prob_t[random_prob] = 1
prob_t[random_prob] = 1.0
mag_t = torch.zeros(self.m_bins)
mag_t[random_mag] = 1
mag_t[random_mag] = 1.0
return torch.cat([fun_t, prob_t, mag_t])
......@@ -152,10 +152,10 @@ if __name__=='__main__':
# We can initialize the train_dataset with its transform as None.
# Later on, we will change this object's transform attribute to the policy
# that we want to test
train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False,
transform=None)
test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False,
transform=torchvision.transforms.ToTensor())
train_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/train',
train=True, download=True, transform=None)
test_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/test',
train=False, download=True, transform=torchvision.transforms.ToTensor())
child_network = cn.lenet
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment