Skip to content
Snippets Groups Projects
Commit f13bf9bf authored by rromb's avatar rromb
Browse files

add vqgan loss with codebook statistic eval

parent 2b46bcb9
No related branches found
No related tags found
No related merge requests found
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
from taming.modules.losses.lpips import LPIPS
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
loss_real = (weights * loss_real).sum() / weights.sum()
loss_fake = (weights * loss_fake).sum() / weights.sum()
d_loss = 0.5 * (loss_real + loss_fake)
return d_loss
def adopt_weight(weight, global_step, threshold=0, value=0.):
if global_step < threshold:
weight = value
return weight
def measure_perplexity(predicted_indices, n_embed):
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
avg_probs = encodings.mean(0)
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
cluster_use = torch.sum(avg_probs > 0)
return perplexity, cluster_use
def l1(x, y):
return torch.abs(x-y)
def l2(x, y):
return torch.pow((x-y), 2)
class VQLPIPSWithDiscriminator(nn.Module):
def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
pixel_loss="l1"):
super().__init__()
assert disc_loss in ["hinge", "vanilla"]
assert perceptual_loss in ["lpips", "clips", "dists"]
assert pixel_loss in ["l1", "l2"]
self.codebook_weight = codebook_weight
self.pixel_weight = pixelloss_weight
if perceptual_loss == "lpips":
print(f"{self.__class__.__name__}: Running with LPIPS.")
self.perceptual_loss = LPIPS().eval()
else:
raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
self.perceptual_weight = perceptual_weight
if pixel_loss == "l1":
self.pixel_loss = l1
else:
self.pixel_loss = l2
self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
n_layers=disc_num_layers,
use_actnorm=use_actnorm,
ndf=disc_ndf
).apply(weights_init)
self.discriminator_iter_start = disc_start
if disc_loss == "hinge":
self.disc_loss = hinge_d_loss
elif disc_loss == "vanilla":
self.disc_loss = vanilla_d_loss
else:
raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.disc_conditional = disc_conditional
self.n_classes = n_classes
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
else:
nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight
return d_weight
def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
if not exists(codebook_loss):
codebook_loss = torch.tensor([0.]).to(inputs.device)
#rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
if self.perceptual_weight > 0:
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
rec_loss = rec_loss + self.perceptual_weight * p_loss
else:
p_loss = torch.tensor([0.0])
nll_loss = rec_loss
#nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
nll_loss = torch.mean(nll_loss)
# now the GAN part
if optimizer_idx == 0:
# generator update
if cond is None:
assert not self.disc_conditional
logits_fake = self.discriminator(reconstructions.contiguous())
else:
assert self.disc_conditional
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
g_loss = -torch.mean(logits_fake)
try:
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
except RuntimeError:
assert not self.training
d_weight = torch.tensor(0.0)
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
"{}/quant_loss".format(split): codebook_loss.detach().mean(),
"{}/nll_loss".format(split): nll_loss.detach().mean(),
"{}/rec_loss".format(split): rec_loss.detach().mean(),
"{}/p_loss".format(split): p_loss.detach().mean(),
"{}/d_weight".format(split): d_weight.detach(),
"{}/disc_factor".format(split): torch.tensor(disc_factor),
"{}/g_loss".format(split): g_loss.detach().mean(),
}
if predicted_indices is not None:
assert self.n_classes is not None
with torch.no_grad():
perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
log[f"{split}/perplexity"] = perplexity
log[f"{split}/cluster_usage"] = cluster_usage
return loss, log
if optimizer_idx == 1:
# second pass for discriminator update
if cond is None:
logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach())
else:
logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
"{}/logits_real".format(split): logits_real.detach().mean(),
"{}/logits_fake".format(split): logits_fake.detach().mean()
}
return d_loss, log
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment