Skip to content
Snippets Groups Projects
Commit 363e8919 authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

Make aa_learner.translate_... robust wrt input size

parent 38498217
No related branches found
No related tags found
No related merge requests found
......@@ -177,7 +177,7 @@ class aa_learner:
mag = torch.multinomial(mag_t, 1).item() # 0 <= m <= 9
function = augmentation_space[fun_idx][0]
prob = prob_idx/10
prob = prob_idx/self.p_bins
indices = (fun_idx, prob_idx, mag)
......@@ -207,8 +207,8 @@ class aa_learner:
function = augmentation_space[fun_idx][0]
assert 0 <= prob <= 1
assert 0 <= mag <= self.m_bins-1
assert 0 <= prob <= 1, prob
assert 0 <= mag <= self.m_bins-1, (mag, self.m_bins)
# if the image function does not require a magnitude, we set the magnitude to None
if augmentation_space[fun_idx][1] == True: # if the image function has a magnitude
......@@ -335,6 +335,8 @@ class aa_learner:
if isinstance(child_network_architecture, types.FunctionType):
child_network = child_network_architecture()
elif isinstance(child_network_architecture, type):
child_network = child_network_architecture()
elif isinstance(child_network_architecture, torch.nn.Module):
child_network = copy.deepcopy(child_network_architecture)
else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment