Image-to-image translation is a computer vision task that deals with generating a modified image from an original input image based on certain conditions. These conditions can be multiple labels, multiple styles, or both. In recent successful methods, the input image is translated according to the labels, and the output image is generated from the translated feature map according to the styles. The labels and styles are fed to the models via text or reference images. The translation sometimes introduces unnecessary manipulations and alterations of identity attributes that are difficult to control in a semi-supervised setting.
Chinese researchers Xinyang Li, Shengchuan Zhang, Jie Hu, Liujuan Cao, Xiaopeng Hong, Xudong Mao, Feiyue Huang, Yongjian Wu and Rongrong Ji have introduced a new approach to control the image-to-image translation process via Hierarchical Style Disentanglement (HiSD).
HiSD breaks the original labels into tags and attributes. It ensures that the tags are independent of each other and that the attributes of a tag are mutually exclusive. At inference time, the model first resolves the tags and then the attributes in a sequential manner. Finally, styles are defined by latent codes sampled from noise or extracted from reference images. Thus, improper or unwanted manipulations are avoided. The tags, the attributes and the style requirements are arranged in a clear hierarchical structure that leads to state-of-the-art disentanglement performance on public datasets.
HiSD processes all the conditions (i.e., tags, attributes and styles) independently so that they can be controlled individually or in combination. The model extracts styles either from reference images, by encoding them into latent codes, or by sampling from Gaussian noise. It adds the style to the input image without affecting its identity or the other styles, tags and attributes.
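In practice, this hierarchy can be expressed as a sequence of translation steps, each selecting a tag, an attribute and a style source (a random latent code or a reference image). The short sketch below mirrors the step format used by the official implementation later in this article; the tag and attribute indices shown are only illustrative and depend on the dataset configuration.

# A translation request is a list of steps; each step picks one tag,
# one attribute and a style source (random latent code or reference image).
# Tag/attribute indices are illustrative and depend on the dataset config.
steps = [
    # latent-guided: style sampled from Gaussian noise, e.g. tag 'Bangs', attribute 'with'
    {'type': 'latent-guided', 'tag': 0, 'attribute': 0, 'seed': None},
    # reference-guided: style extracted from a reference image, e.g. tag 'Glasses'
    {'type': 'reference-guided', 'tag': 1, 'reference': 'examples/reference_glasses_0.jpg'},
]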
Python implementation of HiSD
HiSD needs a Python environment and the PyTorch framework to set up and run. Use of a GPU runtime is optional: the pre-trained HiSD model can be loaded and inference performed on a CPU runtime. Install the dependencies using the following command.
!pip install tensorboardx
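Since the GPU is optional here, the active PyTorch device can be checked with a couple of lines before proceeding.

import torch

# Inference in this article runs on CPU; a GPU runtime, if available, can be used instead.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)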
The following command downloads the source code from the official repository to the local machine.
!git clone https://github.com/imlixinyang/HiSD.git
Output:
Change the directory to /content/HiSD/ using the following command.
%cd HiSD/
Download the publicly available CelebAMask-HQ dataset from Google Drive to the local machine to proceed further. Ensure that the training images are stored in the directory /HiSD/datasets and their corresponding labels are stored in the directory /HiSD/labels. The following command preprocesses the dataset for training.
!python /content/HiSD/preprocessors/celeba-hq.py --img_path /HiSD/datasets/ --label_path /HiSD/labels/ --target_path datasets --start 3002 --end 30002
The following command trains the model with the configuration adjusted to the machine and dataset. It creates a new directory named 'outputs' under the current path to store its outputs.
!python core/train.py --config configs/celeba-hq.yaml --gpus 0
Once the dataset preprocessing is finished and the pre-trained model checkpoint is restored, the model can be used for inference. A sample implementation is carried out with the following simple Python code. First, import the necessary modules and libraries.
%cd /content/HiSD/
from core.utils import get_config
from core.trainer import HiSD_Trainer
import argparse
import torchvision.utils as vutils
import sys
import torch
import os
from torchvision import transforms
from PIL import Image
import numpy as np
import time
import matplotlib.pyplot as plt
Download the checkpoint file from the official page using the following command.
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1KDrNWLejpo02fcalUOrAJOl1hGoccBKl' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1KDrNWLejpo02fcalUOrAJOl1hGoccBKl" -O checkpoint_256_celeba-hq.pt && rm -rf /tmp/cookies.txt
Output:
Move the checkpoint file to the /HiSD directory using the following commands.
%cd /content/
!mv checkpoint_256_celeba-hq.pt HiSD/
Load the checkpoint and prepare the model for inference using the following code.
device = 'cpu'
config = get_config('configs/celeba-hq_256.yaml')
noise_dim = config['noise_dim']
image_size = config['new_size']
checkpoint = 'checkpoint_256_celeba-hq.pt'

trainer = HiSD_Trainer(config)
# assumed CPU device
# if GPU is available, set map_location = None
state_dict = torch.load(checkpoint, map_location=torch.device('cpu'))
trainer.models.gen.load_state_dict(state_dict['gen_test'])
trainer.models.gen.to(device)

E = trainer.models.gen.encode     # encoder
T = trainer.models.gen.translate  # translator
G = trainer.models.gen.decode     # decoder
M = trainer.models.gen.map        # mapper (latent-guided styles)
F = trainer.models.gen.extract    # extractor (reference-guided styles)

transform = transforms.Compose([transforms.Resize(image_size),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
Define a function to perform the image-to-image translation.
def translate(input, steps):
    # Encode the input image into a content feature map.
    x = transform(Image.open(input).convert('RGB')).unsqueeze(0).to(device)
    c = E(x)
    c_trg = c
    # Apply each translation step sequentially on the translated features.
    for j in range(len(steps)):
        step = steps[j]
        if step['type'] == 'latent-guided':
            if step['seed'] is not None:
                torch.manual_seed(step['seed'])
                torch.cuda.manual_seed(step['seed'])
            # Sample a style code from Gaussian noise for the given tag and attribute.
            z = torch.randn(1, noise_dim).to(device)
            s_trg = M(z, step['tag'], step['attribute'])
        elif step['type'] == 'reference-guided':
            # Extract the style code from a reference image for the given tag.
            reference = transform(Image.open(step['reference']).convert('RGB')).unsqueeze(0).to(device)
            s_trg = F(reference, step['tag'])
        c_trg = T(c_trg, s_trg, step['tag'])
    # Decode the translated features back into an image in the [0, 1] range.
    x_trg = G(c_trg)
    output = x_trg.squeeze(0).cpu().permute(1, 2, 0).add(1).mul(1/2).clamp(0, 1).detach().numpy()
    return output
The following commands set the desired tags, attributes and styles to perform the translation. They use the built-in example images; users can opt for their own images instead.
First example inference:
input = 'examples/input_0.jpg'
# e.g. 1: change tag 'Bangs' to attribute 'with' using 3 latent-guided styles (generated from random noise).
steps = [
    {'type': 'latent-guided', 'tag': 0, 'attribute': 0, 'seed': None}
]
plt.figure(figsize=(12, 4))
for i in range(3):
    plt.subplot(1, 3, i + 1)
    output = translate(input, steps)
    plt.imshow(output, aspect='auto')
plt.show()
Output:
Second example inference:
input = 'examples/input_1.jpg'
plt.figure(figsize=(12, 4))

# e.g. 2: change tag 'Glasses' to attribute 'with' using reference-guided styles (extracted from another image).
steps = [
    {'type': 'reference-guided', 'tag': 1, 'reference': 'examples/reference_glasses_0.jpg'}
]
output = translate(input, steps)
plt.subplot(131)
plt.imshow(output, aspect='auto')

steps = [
    {'type': 'reference-guided', 'tag': 1, 'reference': 'examples/reference_glasses_1.jpg'}
]
output = translate(input, steps)
plt.subplot(132)
plt.imshow(output, aspect='auto')

steps = [
    {'type': 'reference-guided', 'tag': 1, 'reference': 'examples/reference_glasses_2.jpg'}
]
output = translate(input, steps)
plt.subplot(133)
plt.imshow(output, aspect='auto')

plt.show()
Output:
Third example inference:
input = 'examples/input_2.jpg'
# e.g. 3: change tags 'Bangs' and 'Glasses' to attribute 'with' and 'Hair color' to 'black' in one translation.
steps = [
    {'type': 'reference-guided', 'tag': 0, 'reference': 'examples/reference_bangs_0.jpg'},
    {'type': 'reference-guided', 'tag': 1, 'reference': 'examples/reference_glasses_0.jpg'},
    {'type': 'latent-guided', 'tag': 2, 'attribute': 0, 'seed': None}
]
output = translate(input, steps)
plt.figure(figsize=(5, 5))
plt.imshow(output, aspect='auto')
plt.show()
Output:
Performance of HiSD
HiSD is trained and evaluated on the well-known CelebA-HQ dataset, which contains 30,000 facial images of celebrities labelled with tags and attributes such as hair colour, the presence of glasses, bangs, beard and gender. The first 3,000 images are used as test images, and the remaining 27,000 images are used as training images. Competing models are trained on the same dataset under identical device configurations to enable a fair comparison.
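For reference, the arithmetic of this split can be sketched as follows; the indices are only an illustration of the test/train partition described above and correspond to the --start/--end arguments passed to the preprocessor earlier.

# Illustrative split of the 30,000 CelebA-HQ images into test and train sets.
all_indices = list(range(30000))
test_indices = all_indices[:3000]    # first 3,000 images held out for testing
train_indices = all_indices[3000:]   # remaining 27,000 images used for training
print(len(test_indices), len(train_indices))  # 3000 27000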
HiSD outperforms current state-of-the-art models, including SDIT, ELEGANT and StarGANv2, on the FID (Fréchet Inception Distance) scale, which is used to measure both the realism of the generated images and the disentanglement of the manipulated attributes.
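FID compares the Inception feature statistics of real and generated images, with lower values indicating better realism. A minimal sketch of computing it is shown below, assuming the torchmetrics and torch-fidelity packages are installed; it is not part of the HiSD code base, and the dummy tensors stand in for real CelebA-HQ images and HiSD outputs.

import torch
from torchmetrics.image.fid import FrechetInceptionDistance

# FID compares Inception-v3 feature statistics of real vs. generated images (lower is better).
fid = FrechetInceptionDistance(feature=2048)
real_images = torch.randint(0, 256, (16, 3, 256, 256), dtype=torch.uint8)  # placeholder real batch
fake_images = torch.randint(0, 256, (16, 3, 256, 256), dtype=torch.uint8)  # placeholder generated batch
fid.update(real_images, real=True)
fid.update(fake_images, real=False)
print(fid.compute())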
Note: Images and illustrations other than the code outputs are taken from the original research paper and the official repository.