# via https://github.com/huggingface/pytorch-pretrained-BigGAN
"""Sample a batch of class-conditional images from a pre-trained BigGAN-deep model.

The script loads the 'biggan-deep-512' checkpoint, draws truncated noise for a
fixed list of class indices, runs the generator without gradient tracking,
saves the results as PNG images and, when the optional libsixel bindings are
installed, previews them in a sixel-compatible terminal.
"""
import logging

import torch
from PIL import Image
from pytorch_pretrained_biggan import (
    BigGAN,
    display_in_terminal,
    one_hot_from_int,
    one_hot_from_names,
    save_as_images,
    truncated_noise_sample,
)

# OPTIONAL: if you want more information on what's happening, activate the
# logger as follows.
logging.basicConfig(level=logging.INFO)

# Load the pre-trained BigGAN-deep generator.
model = BigGAN.from_pretrained('biggan-deep-512')

# Prepare an input batch.
# Truncation value shared by the noise sampler and the generator call below.
# NOTE(review): 0.001 is very close to 0; under BigGAN's truncation trick this
# presumably trades almost all sample variety for fidelity — confirm intended.
truncation = 0.001
# Class indices to generate, one image per entry (presumably ImageNet-1k ids).
class_ids = [22, 65, 555, 333]
batch_size = len(class_ids)
class_vector = one_hot_from_int(class_ids, batch_size=batch_size)
# Alternatively, select the classes by name:
# class_vector = one_hot_from_names(
#     ['soap bubble', 'coffee', 'mushroom', 'daisy'], batch_size=batch_size)
noise_vector = truncated_noise_sample(truncation=truncation, batch_size=batch_size)

# Convert the NumPy arrays returned by the helpers into tensors.
noise_vector = torch.from_numpy(noise_vector)
class_vector = torch.from_numpy(class_vector)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
noise_vector = noise_vector.to(device)
class_vector = class_vector.to(device)
model.to(device)

# Generate the images; no_grad avoids recording an autograd graph for inference.
with torch.no_grad():
    output = model(noise_vector, class_vector, truncation)

# Move the result back to the CPU for the image helpers (a no-op on CPU).
output = output.to('cpu')

# Save results as PNG images first, so they are kept even when the optional
# terminal preview below is unavailable.
save_as_images(output)

# If you have a sixel-compatible terminal you can preview the images in it
# (see https://github.com/saitoha/libsixel for details).
# display_in_terminal raises ImportError when its optional dependencies are
# missing; the preview is best-effort, so log and continue instead of crashing.
try:
    display_in_terminal(output)
except ImportError as err:
    logging.warning("Skipping terminal preview: %s", err)