Skip to content
Snippets Groups Projects
Commit eef5da90 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

finish

parent d0c714ae
No related branches found
No related tags found
No related merge requests found
...@@ -19,8 +19,10 @@ from ldm.models.diffusion.plms import PLMSSampler ...@@ -19,8 +19,10 @@ from ldm.models.diffusion.plms import PLMSSampler
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor from transformers import AutoFeatureExtractor
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-v-1-3", use_auth_token=True) # load safety model
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-v-1-3", use_auth_token=True) safety_model_id = "CompVis/stable-diffusion-v-1-3"
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, use_auth_token=True)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, use_auth_token=True)
def chunk(it, size): def chunk(it, size):
it = iter(it) it = iter(it)
...@@ -266,16 +268,23 @@ def main(): ...@@ -266,16 +268,23 @@ def main():
x_samples_ddim = model.decode_first_stage(samples_ddim) x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
x_image = x_samples_ddim
safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 2, 1)
if not opt.skip_save: if not opt.skip_save:
for x_sample in x_samples_ddim: for x_sample in x_checked_image_torch:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
Image.fromarray(x_sample.astype(np.uint8)).save( Image.fromarray(x_sample.astype(np.uint8)).save(
os.path.join(sample_path, f"{base_count:05}.png")) os.path.join(sample_path, f"{base_count:05}.png"))
base_count += 1 base_count += 1
if not opt.skip_grid: if not opt.skip_grid:
all_samples.append(x_samples_ddim) all_samples.append(x_checked_image_torch)
if not opt.skip_grid: if not opt.skip_grid:
# additionally, save as grid # additionally, save as grid
...@@ -288,12 +297,6 @@ def main(): ...@@ -288,12 +297,6 @@ def main():
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png')) Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
grid_count += 1 grid_count += 1
image = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
# run safety checker
safety_checker_input = pipe.feature_extractor(numpy_to_pil(image), return_tensors="pt")
image, has_nsfw_concept = pipe.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
print(f"Your samples are ready and waiting for you here: \n{outpath} \n" print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
f" \nEnjoy.") f" \nEnjoy.")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment