Skip to content
Snippets Groups Projects
Commit 239ed0fd authored by Patrick von Platen's avatar Patrick von Platen
Browse files

fix more

parent f3f60fcc
No related branches found
No related tags found
No related merge requests found
...@@ -266,7 +266,7 @@ def main(): ...@@ -266,7 +266,7 @@ def main():
x_samples_ddim = model.decode_first_stage(samples_ddim) x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1)
x_image = x_samples_ddim x_image = x_samples_ddim
safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
...@@ -295,11 +295,6 @@ def main(): ...@@ -295,11 +295,6 @@ def main():
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png')) Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
grid_count += 1 grid_count += 1
image = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
# run safety checker
safety_checker_input = pipe.feature_extractor(numpy_to_pil(image), return_tensors="pt")
image, has_nsfw_concept = pipe.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
toc = time.time() toc = time.time()
print(f"Your samples are ready and waiting for you here: \n{outpath} \n" print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment