diff --git a/README.md b/README.md
index 50d768a1927623dd40cd1895a06e1ce81a5c321b..b162fa20a701a6caf7b6c96cea856e06d668cd56 100644
--- a/README.md
+++ b/README.md
@@ -55,18 +55,7 @@ bash scripts/download_first_stages.sh
 ```
 
 The first stage models can then be found in `models/first_stage_models/<model_spec>`
-### Training autoencoder models
 
-Configs for training a KL-regularized autoencoder on ImageNet are provided at `configs/autoencoder`.
-Training can be started by running
-```
-CUDA_VISIBLE_DEVICES=<GPU_ID> python main.py --base configs/autoencoder/<config_spec> -t --gpus 0,    
-```
-where `config_spec` is one of {`autoencoder_kl_8x8x64.yaml`(f=32, d=64), `autoencoder_kl_16x16x16.yaml`(f=16, d=16), 
-`autoencoder_kl_32x32x4`(f=8, d=4), `autoencoder_kl_64x64x3`(f=4, d=3)}.
-
-For training VQ-regularized models, see the [taming-transformers](https://github.com/CompVis/taming-transformers) 
-repository.
 
 
 ## Pretrained LDMs
@@ -78,9 +67,10 @@ repository.
 | LSUN-Bedrooms                   | Unconditional Image Synthesis   |  LDM-VQ-4 (200 DDIM steps, eta=1)| 2.95 (3.0)          | 2.22 (2.23)| 0.66 | 0.48 | https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip |                                                 |  
 | ImageNet                        | Class-conditional Image Synthesis | LDM-VQ-8 (200 DDIM steps, eta=1) | 7.77(7.76)* /15.82** | 201.56(209.52)* /78.82** | 0.84* / 0.65** | 0.35* / 0.63** |   https://ommer-lab.com/files/latent-diffusion/cin.zip                                                                   | *: w/ guiding, classifier_scale 10  **: w/o guiding, scores in bracket calculated with script provided by [ADM](https://github.com/openai/guided-diffusion) |   
 | Conceptual Captions             |  Text-conditional Image Synthesis | LDM-VQ-f4 (100 DDIM steps, eta=0) | 16.79         | 13.89           | N/A | N/A |              https://ommer-lab.com/files/latent-diffusion/text2img.zip                                | finetuned from LAION                            |   
-| OpenImages                      | Super-resolution   | N/A           | N/A            | N/A               | N/A    | N/A    |                                    https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip                                    | BSR image degradation                           |
+| OpenImages                      | Super-resolution   | LDM-VQ-4     | N/A            | N/A               | N/A    | N/A    |                                    https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip                                    | BSR image degradation                           |
 | OpenImages                      | Layout-to-Image Synthesis    | LDM-VQ-4 (200 DDIM steps, eta=0) | 32.02         | 15.92           | N/A    | N/A    |                  https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip                                           |                                                 | 
-| Landscapes      (finetuned 512) |  Semantic Image Synthesis   | LDM-VQ-4 (100 DDIM steps, eta=1) | N/A             | N/A               | N/A    | N/A    |           https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip                                    |                                                 |
+| Landscapes      |  Semantic Image Synthesis   | LDM-VQ-4  | N/A             | N/A               | N/A    | N/A    |           https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip                                    |                                                 |
+| Landscapes       |  Semantic Image Synthesis   | LDM-VQ-4  | N/A             | N/A               | N/A    | N/A    |           https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip                                    |             finetuned on resolution 512x512                                     |
 
 
 ### Get the models
@@ -116,10 +106,90 @@ python scripts/inpaint.py --indir data/inpainting_examples/ --outdir outputs/inp
 `indir` should contain images `*.png` and masks `<image_fname>_mask.png` like
 the examples provided in `data/inpainting_examples`.
 
+
+# Train your own LDMs
+
+## Data preparation
+
+### Faces 
+For downloading the CelebA-HQ and FFHQ datasets, proceed as described in the [taming-transformers](https://github.com/CompVis/taming-transformers#celeba-hq) 
+repository.
+
+### LSUN 
+
+The LSUN datasets can be conveniently downloaded via the script available [here](https://github.com/fyu/lsun).
+We performed a custom split into training and validation images, and provide the corresponding filenames
+at [https://ommer-lab.com/files/lsun.zip](https://ommer-lab.com/files/lsun.zip). 
+After downloading, extract them to `./data/lsun`. The beds/cats/churches subsets should
+also be placed/symlinked at `./data/lsun/bedrooms`/`./data/lsun/cats`/`./data/lsun/churches`, respectively.
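+
+A minimal Python sketch for creating those symlinks (the source paths are placeholders for
+wherever you extracted the LSUN subsets):
+
+```python
+import os
+
+# placeholder source locations -- point these at your extracted LSUN subsets
+subsets = {
+    "bedrooms": "/path/to/extracted/lsun/bedrooms",
+    "cats": "/path/to/extracted/lsun/cats",
+    "churches": "/path/to/extracted/lsun/churches",
+}
+
+os.makedirs("data/lsun", exist_ok=True)
+for name, src in subsets.items():
+    dst = os.path.join("data/lsun", name)
+    if not os.path.exists(dst):
+        os.symlink(os.path.abspath(src), dst)
+```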
+
+### ImageNet
+The code will try to download (through [Academic
+Torrents](http://academictorrents.com/)) and prepare ImageNet the first time it
+is used. However, since ImageNet is quite large, this requires a lot of disk
+space and time. If you already have ImageNet on your disk, you can speed things
+up by putting the data into
+`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` (which defaults to
+`~/.cache/autoencoders/data/ILSVRC2012_{split}/data/`), where `{split}` is one
+of `train`/`validation`. It should have the following structure:
+
+```
+${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/
+├── n01440764
+│   ├── n01440764_10026.JPEG
+│   ├── n01440764_10027.JPEG
+│   ├── ...
+├── n01443537
+│   ├── n01443537_10007.JPEG
+│   ├── n01443537_10014.JPEG
+│   ├── ...
+├── ...
+```
+
+If you haven't extracted the data, you can also place
+`ILSVRC2012_img_train.tar`/`ILSVRC2012_img_val.tar` (or symlinks to them) into
+`${XDG_CACHE}/autoencoders/data/ILSVRC2012_train/` /
+`${XDG_CACHE}/autoencoders/data/ILSVRC2012_validation/`, which will then be
+extracted into the above structure without downloading it again. Note that this
+will only happen if neither a folder
+`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` nor a file
+`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/.ready` exists. Remove them
+if you want to force running the dataset preparation again.
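+
+To check whether a split is already prepared, or to force re-preparation, something along these
+lines can be used (a sketch; it assumes the cache root comes from `XDG_CACHE_HOME` and falls back
+to `~/.cache` as described above):
+
+```python
+import os
+import shutil  # only needed if you force re-preparation below
+
+# resolve the cache root as described above: $XDG_CACHE_HOME if set, otherwise ~/.cache
+cache_root = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
+
+for split in ("train", "validation"):
+    base = os.path.join(cache_root, "autoencoders", "data", f"ILSVRC2012_{split}")
+    data_dir = os.path.join(base, "data")
+    ready_flag = os.path.join(base, ".ready")
+    print(f"{split}: prepared = {os.path.isdir(data_dir) or os.path.isfile(ready_flag)}")
+
+    # to force the preparation to run again, remove both markers:
+    # shutil.rmtree(data_dir, ignore_errors=True)
+    # if os.path.isfile(ready_flag):
+    #     os.remove(ready_flag)
+```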
+
+
+## Model Training
+
+Logs and checkpoints for trained models are saved to `logs/<START_DATE_AND_TIME>_<config_spec>`.
+
+### Training autoencoder models
+
+Configs for training a KL-regularized autoencoder on ImageNet are provided at `configs/autoencoder`.
+Training can be started by running
+```
+CUDA_VISIBLE_DEVICES=<GPU_ID> python main.py --base configs/autoencoder/<config_spec>.yaml -t --gpus 0,
+```
+where `config_spec` is one of {`autoencoder_kl_8x8x64`(f=32, d=64), `autoencoder_kl_16x16x16`(f=16, d=16), 
+`autoencoder_kl_32x32x4`(f=8, d=4), `autoencoder_kl_64x64x3`(f=4, d=3)}.
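+
+To sanity-check a trained (or downloaded) KL autoencoder, the config can be instantiated directly
+with `instantiate_from_config`. A minimal sketch, assuming the `AutoencoderKL` interface in
+`ldm/models/autoencoder.py` (`encode` returning a Gaussian posterior, `decode` taking latents);
+the checkpoint path is a placeholder:
+
+```python
+import torch
+from omegaconf import OmegaConf
+from ldm.util import instantiate_from_config
+
+# f=8, d=4 variant; pick any of the configs listed above
+config = OmegaConf.load("configs/autoencoder/autoencoder_kl_32x32x4.yaml")
+model = instantiate_from_config(config.model).eval()
+# optionally restore a trained checkpoint, e.g. logs/<run_name>/checkpoints/last.ckpt
+# model.load_state_dict(torch.load(ckpt_path, map_location="cpu")["state_dict"], strict=False)
+
+x = torch.randn(1, 3, 256, 256)  # dummy image batch in [-1, 1]
+with torch.no_grad():
+    posterior = model.encode(x)   # KL models return a diagonal Gaussian posterior
+    z = posterior.sample()        # latent of shape (1, 4, 32, 32) for the f=8, d=4 model
+    xrec = model.decode(z)
+print(z.shape, xrec.shape)
+```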
+
+For training VQ-regularized models, see the [taming-transformers](https://github.com/CompVis/taming-transformers) 
+repository.
+
+### Training LDMs 
+
+In ``configs/latent-diffusion/`` we provide configs for training LDMs on the LSUN (bedrooms and churches), CelebA-HQ, FFHQ and ImageNet datasets.
+Training can be started by running
+
+```shell script
+CUDA_VISIBLE_DEVICES=<GPU_ID> python main.py --base configs/latent-diffusion/<config_spec>.yaml -t --gpus 0,
+``` 
+
+where ``<config_spec>`` is one of {`celebahq-ldm-vq-4`(f=4, VQ-reg. autoencoder, spatial size 64x64x3), `ffhq-ldm-vq-4`(f=4, VQ-reg. autoencoder, spatial size 64x64x3),
+`lsun_bedrooms-ldm-vq-4`(f=4, VQ-reg. autoencoder, spatial size 64x64x3),
+`lsun_churches-ldm-kl-8`(f=8, KL-reg. autoencoder, spatial size 32x32x4), `cin-ldm-vq-f8`(f=8, VQ-reg. autoencoder, spatial size 32x32x4)}.
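+
+After training, samples can be drawn from a checkpoint in `logs/` with the `DDIMSampler`, mirroring
+what `log_images` does during training. A minimal sketch for an unconditional model (the checkpoint
+path is a placeholder, and the first-stage `ckpt_path` referenced in the config must point to a
+downloaded autoencoder checkpoint):
+
+```python
+import torch
+from omegaconf import OmegaConf
+from ldm.util import instantiate_from_config
+from ldm.models.diffusion.ddim import DDIMSampler
+
+config = OmegaConf.load("configs/latent-diffusion/celebahq-ldm-vq-4.yaml")
+model = instantiate_from_config(config.model)
+# placeholder path -- point this at a checkpoint produced by your training run
+sd = torch.load("logs/<run_name>/checkpoints/last.ckpt", map_location="cpu")["state_dict"]
+model.load_state_dict(sd, strict=False)
+model = model.cuda().eval()
+
+sampler = DDIMSampler(model)
+shape = (model.channels, model.image_size, model.image_size)  # latent shape, here (3, 64, 64)
+with torch.no_grad():
+    samples, _ = sampler.sample(200, 4, shape, conditioning=None, verbose=False, eta=1.0)
+    x_samples = model.decode_first_stage(samples)  # decode latents back to image space, [-1, 1]
+```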
+
 ## Coming Soon...
 
-* Code for training LDMs and the corresponding compression models.
-* Inference scripts for conditional LDMs for various conditioning modalities.
+* More inference scripts for conditional LDMs.
 * In the meantime, you can play with our colab notebook https://colab.research.google.com/drive/1xqzUi2iXQXDqXBHQGP9Mqt2YrYW6cx-J?usp=sharing
 * We will also release some further pretrained models.
 
diff --git a/configs/latent-diffusion/celebahq-ldm-vq-4.yaml b/configs/latent-diffusion/celebahq-ldm-vq-4.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..89b3df4fe1822295d509dca2237ea891cdd964bf
--- /dev/null
+++ b/configs/latent-diffusion/celebahq-ldm-vq-4.yaml
@@ -0,0 +1,86 @@
+model:
+  base_learning_rate: 2.0e-06
+  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  params:
+    linear_start: 0.0015
+    linear_end: 0.0195
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: image
+    image_size: 64
+    channels: 3
+    monitor: val/loss_simple_ema
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 64
+        in_channels: 3
+        out_channels: 3
+        model_channels: 224
+        attention_resolutions:
+        # note: this isn't actually the resolution but
+        # the downsampling factor, i.e. this corresponds to
+        # attention on spatial resolutions 8, 16 and 32, as the
+        # spatial resolution of the latents is 64 for f4
+        - 8
+        - 4
+        - 2
+        num_res_blocks: 2
+        channel_mult:
+        - 1
+        - 2
+        - 3
+        - 4
+        num_head_channels: 32
+    first_stage_config:
+      target: ldm.models.autoencoder.VQModelInterface
+      params:
+        embed_dim: 3
+        n_embed: 8192
+        ckpt_path: models/first_stage_models/vq-f4/model.ckpt
+        ddconfig:
+          double_z: false
+          z_channels: 3
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+    cond_stage_config: __is_unconditional__
+data:
+  target: main.DataModuleFromConfig
+  params:
+    batch_size: 48
+    num_workers: 5
+    wrap: false
+    train:
+      target: taming.data.faceshq.CelebAHQTrain
+      params:
+        size: 256
+    validation:
+      target: taming.data.faceshq.CelebAHQValidation
+      params:
+        size: 256
+
+
+lightning:
+  callbacks:
+    image_logger:
+      target: main.ImageLogger
+      params:
+        batch_frequency: 5000
+        max_images: 8
+        increase_log_steps: False
+
+  trainer:
+    benchmark: True
\ No newline at end of file
diff --git a/configs/latent-diffusion/cin-ldm-vq-f8.yaml b/configs/latent-diffusion/cin-ldm-vq-f8.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b8cd9e2ef5d26870bdbb26bf04a9b47aaa78feeb
--- /dev/null
+++ b/configs/latent-diffusion/cin-ldm-vq-f8.yaml
@@ -0,0 +1,98 @@
+model:
+  base_learning_rate: 1.0e-06
+  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  params:
+    linear_start: 0.0015
+    linear_end: 0.0195
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: image
+    cond_stage_key: class_label
+    image_size: 32
+    channels: 4
+    cond_stage_trainable: true
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 32
+        in_channels: 4
+        out_channels: 4
+        model_channels: 256
+        attention_resolutions:
+        # note: this isn't actually the resolution but
+        # the downsampling factor, i.e. this corresponds to
+        # attention on spatial resolutions 8, 16 and 32, as the
+        # spatial resolution of the latents is 32 for f8
+        - 4
+        - 2
+        - 1
+        num_res_blocks: 2
+        channel_mult:
+        - 1
+        - 2
+        - 4
+        num_head_channels: 32
+        use_spatial_transformer: true
+        transformer_depth: 1
+        context_dim: 512
+    first_stage_config:
+      target: ldm.models.autoencoder.VQModelInterface
+      params:
+        embed_dim: 4
+        n_embed: 16384
+        ckpt_path: models/first_stage_models/vq-f8/model.ckpt
+        ddconfig:
+          double_z: false
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 2
+          - 4
+          num_res_blocks: 2
+          attn_resolutions:
+          - 32
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.ClassEmbedder
+      params:
+        embed_dim: 512
+        key: class_label
+data:
+  target: main.DataModuleFromConfig
+  params:
+    batch_size: 64
+    num_workers: 12
+    wrap: false
+    train:
+      target: ldm.data.imagenet.ImageNetTrain
+      params:
+        config:
+          size: 256
+    validation:
+      target: ldm.data.imagenet.ImageNetValidation
+      params:
+        config:
+          size: 256
+
+
+lightning:
+  callbacks:
+    image_logger:
+      target: main.ImageLogger
+      params:
+        batch_frequency: 5000
+        max_images: 8
+        increase_log_steps: False
+
+  trainer:
+    benchmark: True
\ No newline at end of file
diff --git a/configs/latent-diffusion/ffhq-ldm-vq-4.yaml b/configs/latent-diffusion/ffhq-ldm-vq-4.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1899e30f77222142d7b33f45a6dcff086a31e174
--- /dev/null
+++ b/configs/latent-diffusion/ffhq-ldm-vq-4.yaml
@@ -0,0 +1,85 @@
+model:
+  base_learning_rate: 2.0e-06
+  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  params:
+    linear_start: 0.0015
+    linear_end: 0.0195
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: image
+    image_size: 64
+    channels: 3
+    monitor: val/loss_simple_ema
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 64
+        in_channels: 3
+        out_channels: 3
+        model_channels: 224
+        attention_resolutions:
+        # note: this isn't actually the resolution but
+        # the downsampling factor, i.e. this corresponds to
+        # attention on spatial resolutions 8, 16 and 32, as the
+        # spatial resolution of the latents is 64 for f4
+        - 8
+        - 4
+        - 2
+        num_res_blocks: 2
+        channel_mult:
+        - 1
+        - 2
+        - 3
+        - 4
+        num_head_channels: 32
+    first_stage_config:
+      target: ldm.models.autoencoder.VQModelInterface
+      params:
+        embed_dim: 3
+        n_embed: 8192
+        ckpt_path: models/first_stage_models/vq-f4/model.ckpt
+        ddconfig:
+          double_z: false
+          z_channels: 3
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+    cond_stage_config: __is_unconditional__
+data:
+  target: main.DataModuleFromConfig
+  params:
+    batch_size: 42
+    num_workers: 5
+    wrap: false
+    train:
+      target: taming.data.faceshq.FFHQTrain
+      params:
+        size: 256
+    validation:
+      target: taming.data.faceshq.FFHQValidation
+      params:
+        size: 256
+
+
+lightning:
+  callbacks:
+    image_logger:
+      target: main.ImageLogger
+      params:
+        batch_frequency: 5000
+        max_images: 8
+        increase_log_steps: False
+
+  trainer:
+    benchmark: True
\ No newline at end of file
diff --git a/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml b/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c4ca66c16c00a0c3fd13ae1ad03635039161e7ad
--- /dev/null
+++ b/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml
@@ -0,0 +1,85 @@
+model:
+  base_learning_rate: 2.0e-06
+  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  params:
+    linear_start: 0.0015
+    linear_end: 0.0195
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: image
+    image_size: 64
+    channels: 3
+    monitor: val/loss_simple_ema
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 64
+        in_channels: 3
+        out_channels: 3
+        model_channels: 224
+        attention_resolutions:
+        # note: this isn't actually the resolution but
+        # the downsampling factor, i.e. this corresponds to
+        # attention on spatial resolutions 8, 16 and 32, as the
+        # spatial resolution of the latents is 64 for f4
+        - 8
+        - 4
+        - 2
+        num_res_blocks: 2
+        channel_mult:
+        - 1
+        - 2
+        - 3
+        - 4
+        num_head_channels: 32
+    first_stage_config:
+      target: ldm.models.autoencoder.VQModelInterface
+      params:
+        ckpt_path: models/first_stage_models/vq-f4/model.ckpt
+        embed_dim: 3
+        n_embed: 8192
+        ddconfig:
+          double_z: false
+          z_channels: 3
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+    cond_stage_config: __is_unconditional__
+data:
+  target: main.DataModuleFromConfig
+  params:
+    batch_size: 48
+    num_workers: 5
+    wrap: false
+    train:
+      target: ldm.data.lsun.LSUNBedroomsTrain
+      params:
+        size: 256
+    validation:
+      target: ldm.data.lsun.LSUNBedroomsValidation
+      params:
+        size: 256
+
+
+lightning:
+  callbacks:
+    image_logger:
+      target: main.ImageLogger
+      params:
+        batch_frequency: 5000
+        max_images: 8
+        increase_log_steps: False
+
+  trainer:
+    benchmark: True
\ No newline at end of file
diff --git a/configs/latent-diffusion/lsun_churches_f8-autoencoder-ldm.yaml b/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml
similarity index 83%
rename from configs/latent-diffusion/lsun_churches_f8-autoencoder-ldm.yaml
rename to configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml
index a2e99c604eae1673fcda36d11d0aa052afd648ca..18dc8c2d9cfb925b0f45e5b89186d71e3274b086 100644
--- a/configs/latent-diffusion/lsun_churches_f8-autoencoder-ldm.yaml
+++ b/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml
@@ -45,7 +45,7 @@ model:
       params:
         embed_dim: 4
         monitor: "val/rec_loss"
-        ckpt_path: "/export/compvis-nfs/user/ablattma/logs/braket/2021-11-26T11-25-56_lsun_churches-convae-f8-ft_from_oi/checkpoints/step=000180071-fidfrechet_inception_distance=2.335.ckpt"
+        ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
         ddconfig:
           double_z: True
           z_channels: 4
@@ -65,7 +65,7 @@ model:
 data:
   target: main.DataModuleFromConfig
   params:
-    batch_size: 24  # TODO: was 96 in our experiments
+    batch_size: 96
     num_workers: 5
     wrap: False
     train:
@@ -82,14 +82,10 @@ lightning:
     image_logger:
       target: main.ImageLogger
       params:
-        batch_frequency: 1000 # TODO 5000
+        batch_frequency: 5000
         max_images: 8
         increase_log_steps: False
 
-    metrics_over_trainsteps_checkpoint:
-      target: pytorch_lightning.callbacks.ModelCheckpoint
-      params:
-        every_n_train_steps: 20000
 
   trainer:
     benchmark: True
\ No newline at end of file
diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
index 546e81a00adb38562be5a1eb11ec137c4207b87d..01c0e5666291da4b452cbe1cf3bbf2f2b9e7f0d9 100644
--- a/ldm/models/diffusion/ddim.py
+++ b/ldm/models/diffusion/ddim.py
@@ -5,8 +5,7 @@ import numpy as np
 from tqdm import tqdm
 from functools import partial
 
-from ldm.models.diffusion.ddpm import noise_like
-from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
 
 
 class DDIMSampler(object):
@@ -27,8 +26,7 @@ class DDIMSampler(object):
                                                   num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
         alphas_cumprod = self.model.alphas_cumprod
         assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
-
-        to_torch = partial(torch.tensor, dtype=torch.float32, device=self.model.device)
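+        # clone().detach() instead of torch.tensor(): the schedule values are already tensors,
+        # and torch.tensor() on an existing tensor raises a copy-construct warning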
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
 
         self.register_buffer('betas', to_torch(self.model.betas))
         self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
@@ -73,7 +71,8 @@ class DDIMSampler(object):
                corrector_kwargs=None,
                verbose=True,
                x_T=None,
-               log_every_t=100
+               log_every_t=100,
+               **kwargs
                ):
         if conditioning is not None:
             if isinstance(conditioning, dict):
diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py
index 642204ae1474df74d7b2c33ceefd350febc292e0..9835129694a77e488e937f9a3951470904e890db 100644
--- a/ldm/models/diffusion/ddpm.py
+++ b/ldm/models/diffusion/ddpm.py
@@ -16,14 +16,14 @@ from contextlib import contextmanager
 from functools import partial
 from tqdm import tqdm
 from torchvision.utils import make_grid
-from PIL import Image
 from pytorch_lightning.utilities.distributed import rank_zero_only
 
 from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
 from ldm.modules.ema import LitEma
 from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
 from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
-from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor
+from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+from ldm.models.diffusion.ddim import DDIMSampler
 
 
 __conditioning_keys__ = {'concat': 'c_concat',
@@ -37,12 +37,6 @@ def disabled_train(self, mode=True):
     return self
 
 
-def noise_like(shape, device, repeat=False):
-    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
-    noise = lambda: torch.randn(shape, device=device)
-    return repeat_noise() if repeat else noise()
-
-
 def uniform_on_device(r1, r2, shape, device):
     return (r1 - r2) * torch.rand(*shape, device=device) + r2
 
@@ -119,6 +113,7 @@ class DDPM(pl.LightningModule):
         if self.learn_logvar:
             self.logvar = nn.Parameter(self.logvar, requires_grad=True)
 
+
     def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                           linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
         if exists(given_betas):
@@ -1188,7 +1183,6 @@ class LatentDiffusion(DDPM):
 
         if start_T is not None:
             timesteps = min(timesteps, start_T)
-        print(timesteps, start_T)
         iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
             range(0, timesteps))
 
@@ -1222,7 +1216,7 @@ class LatentDiffusion(DDPM):
     @torch.no_grad()
     def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
                verbose=True, timesteps=None, quantize_denoised=False,
-               mask=None, x0=None, shape=None):
+               mask=None, x0=None, shape=None,**kwargs):
         if shape is None:
             shape = (batch_size, self.channels, self.image_size, self.image_size)
         if cond is not None:
@@ -1238,10 +1232,28 @@ class LatentDiffusion(DDPM):
                                   mask=mask, x0=x0)
 
     @torch.no_grad()
-    def log_images(self, batch, N=8, n_row=4, sample=True, sample_ddim=False, return_keys=None,
+    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
+        # draw samples either with the (strided) DDIM sampler or with full ancestral DDPM sampling
+        if ddim:
+            ddim_sampler = DDIMSampler(self)
+            shape = (self.channels, self.image_size, self.image_size)
+            samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
+                                                         shape, cond, verbose=False, **kwargs)
+        else:
+            samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
+                                                 return_intermediates=True, **kwargs)
+
+        return samples, intermediates
+
+
+    @torch.no_grad()
+    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
                    quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
                    plot_diffusion_rows=True, **kwargs):
-        # TODO: maybe add option for ddim sampling via DDIMSampler class
+
+        use_ddim = ddim_steps is not None
+
         log = dict()
         z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
                                            return_first_stage_outputs=True,
@@ -1288,7 +1300,9 @@ class LatentDiffusion(DDPM):
         if sample:
             # get denoise row
             with self.ema_scope("Plotting"):
-                samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+                samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                         ddim_steps=ddim_steps, eta=ddim_eta)
             x_samples = self.decode_first_stage(samples)
             log["samples"] = x_samples
             if plot_denoise_rows:
@@ -1299,8 +1313,11 @@ class LatentDiffusion(DDPM):
                     self.first_stage_model, IdentityFirstStage):
                 # also display when quantizing x0 while sampling
                 with self.ema_scope("Plotting Quantized Denoised"):
-                    samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
-                                                         quantize_denoised=True)
+                    samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                             ddim_steps=ddim_steps, eta=ddim_eta,
+                                                             quantize_denoised=True)
                 x_samples = self.decode_first_stage(samples.to(self.device))
                 log["samples_x0_quantized"] = x_samples
 
@@ -1312,19 +1329,17 @@ class LatentDiffusion(DDPM):
                 mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
                 mask = mask[:, None, ...]
                 with self.ema_scope("Plotting Inpaint"):
-                    samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
-                                                         quantize_denoised=False, x0=z[:N], mask=mask)
+                    samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
+                                                 ddim_steps=ddim_steps, x0=z[:N], mask=mask)
                 x_samples = self.decode_first_stage(samples.to(self.device))
                 log["samples_inpainting"] = x_samples
                 log["mask"] = mask
-                if plot_denoise_rows:
-                    denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
-                    log["denoise_row_inpainting"] = denoise_grid
 
                 # outpaint
                 with self.ema_scope("Plotting Outpaint"):
-                    samples = self.sample(cond=c, batch_size=N, return_intermediates=False,
-                                          quantize_denoised=False, x0=z[:N], mask=1. - mask)
+                    samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
+                                                 ddim_steps=ddim_steps, x0=z[:N], mask=1. - mask)
                 x_samples = self.decode_first_stage(samples.to(self.device))
                 log["samples_outpainting"] = x_samples
 
diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py
index b626b9ad3ea6515ea1f9f8c252af01b880e45d6f..a952e6c40308c33edd422da0ce6a60f47e73661b 100644
--- a/ldm/modules/diffusionmodules/util.py
+++ b/ldm/modules/diffusionmodules/util.py
@@ -259,3 +259,9 @@ class HybridConditioner(nn.Module):
         c_concat = self.concat_conditioner(c_concat)
         c_crossattn = self.crossattn_conditioner(c_crossattn)
         return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
+
+
+def noise_like(shape, device, repeat=False):
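+    # sample Gaussian noise of the given shape; with repeat=True the same noise is shared across the batch dimension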
+    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+    noise = lambda: torch.randn(shape, device=device)
+    return repeat_noise() if repeat else noise()
\ No newline at end of file
diff --git a/main.py b/main.py
index d3498b217f862b792c782fc934c58b56555146c6..e8e18c18fbb01f2e16d9376ea1bfb51f3b5df601 100644
--- a/main.py
+++ b/main.py
@@ -676,7 +676,10 @@ if __name__ == "__main__":
             ngpu = len(lightning_config.trainer.gpus.strip(",").split(','))
         else:
             ngpu = 1
-        accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches or 1
+        if 'accumulate_grad_batches' in lightning_config.trainer:
+            accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
+        else:
+            accumulate_grad_batches = 1
         print(f"accumulate_grad_batches = {accumulate_grad_batches}")
         lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
         if opt.scale_lr:
diff --git a/models/ldm/semantic_synthesis256/config.yaml b/models/ldm/semantic_synthesis256/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1a721cfffa5f0cd627be56c07cf5306ae4933cd1
--- /dev/null
+++ b/models/ldm/semantic_synthesis256/config.yaml
@@ -0,0 +1,59 @@
+model:
+  base_learning_rate: 1.0e-06
+  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  params:
+    linear_start: 0.0015
+    linear_end: 0.0205
+    log_every_t: 100
+    timesteps: 1000
+    loss_type: l1
+    first_stage_key: image
+    cond_stage_key: segmentation
+    image_size: 64
+    channels: 3
+    concat_mode: true
+    cond_stage_trainable: true
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 64
+        in_channels: 6
+        out_channels: 3
+        model_channels: 128
+        attention_resolutions:
+        - 32
+        - 16
+        - 8
+        num_res_blocks: 2
+        channel_mult:
+        - 1
+        - 4
+        - 8
+        num_heads: 8
+    first_stage_config:
+      target: ldm.models.autoencoder.VQModelInterface
+      params:
+        embed_dim: 3
+        n_embed: 8192
+        ddconfig:
+          double_z: false
+          z_channels: 3
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.SpatialRescaler
+      params:
+        n_stages: 2
+        in_channels: 182
+        out_channels: 3
diff --git a/scripts/download_first_stages.sh b/scripts/download_first_stages.sh
index 6e8b77585147c44db5bbfb359487b86c74d236b2..a8d79e99ccdff0a8d8762f23f3c0642401f32f6c 100644
--- a/scripts/download_first_stages.sh
+++ b/scripts/download_first_stages.sh
@@ -4,10 +4,10 @@ wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/la
 wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip
 wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip
 wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip
-wget -O models/first_stage_models/vq-f4-noattn/model.zip https://heibox.uni-heidelberg.de/f/9c6681f64bb94338a069/?dl=1
+wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip
 wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
 wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip
-wget -O models/first_stage_models/vq-f16/model.zip https://heibox.uni-heidelberg.de/f/0e42b04e2e904890a9b6/?dl=1
+wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip
 
 
 
diff --git a/scripts/download_models.sh b/scripts/download_models.sh
index a6e74de4d5eba51201c0a82376d439b0f994199b..84297d7b8b9a78d241edcd5adaf7d9aa273790de 100644
--- a/scripts/download_models.sh
+++ b/scripts/download_models.sh
@@ -6,9 +6,10 @@ wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/la
 wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip
 wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip
 wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip
+wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip
 wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip
 wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip
-wget -O models/ldm/inpainting_big/last.ckpt https://heibox.uni-heidelberg.de/f/4d9ac7ea40c64582b7c9/?dl=1
+wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip
 
 
 
@@ -33,10 +34,16 @@ unzip -o model.zip
 cd ../semantic_synthesis512
 unzip -o model.zip
 
+cd ../semantic_synthesis256
+unzip -o model.zip
+
 cd ../bsr_sr
 unzip -o model.zip
 
 cd ../layout2img-openimages256
 unzip -o model.zip
 
+cd ../inpainting_big
+unzip -o model.zip
+
 cd ../..