https://arxiv.org/abs/2202.04200
Unlike text, images are not naturally sequential: imposing a raster-scan order on image tokens makes autoregressive models a poor fit for image generation, both because decoding one token at a time is slow and because the fixed ordering ignores the 2D structure of images.
MaskGIT is trained on a masked token prediction task, similar to the masked language modeling objective used in BERT.
Inference: at each iteration, the model predicts all tokens in parallel but keeps only the most confident predictions; the remaining tokens are masked out and re-predicted in the next iteration.
Because many tokens are committed per step, this is much faster than autoregressive decoding.
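For scale: a 256×256 image tokenized to a 16×16 latent grid gives 256 tokens, so an autoregressive decoder needs 256 sequential forward passes, while the paper decodes the same grid in 8 parallel iterations.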
Mask scheduling: the paper uses a cosine schedule, γ(r) = cos(πr/2).
After tokenization using a VQ-encoder, a binary mask M = [m_i] is applied to the input tokens Y = [y_i]. If m_i = 1, the token y_i is replaced with a special [MASK] token. Otherwise, the original token is kept.
A ratio r, sampled uniformly between 0 and 1 (inclusive), is passed to the mask scheduling function γ to determine how many tokens to mask: ⌈γ(r)·N⌉ tokens are replaced with [MASK] tokens, where N is the total number of tokens in the input sequence.
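A minimal sketch of this training-time masking step, assuming PyTorch, a (batch, N) tensor of VQ code indices, and a hypothetical MASK_ID reserved at the end of the codebook:

```python
import math
import torch

MASK_ID = 1024  # assumption: codebook has 1024 entries, [MASK] is appended after them

def cosine_schedule(r: float) -> float:
    # gamma(r) = cos(pi * r / 2): the cosine mask schedule
    return math.cos(math.pi * r / 2)

def mask_tokens(tokens: torch.Tensor):
    """Replace ceil(gamma(r) * N) randomly chosen tokens with [MASK]."""
    batch, n = tokens.shape
    r = torch.rand(1).item()                        # r ~ Uniform(0, 1)
    num_masked = math.ceil(cosine_schedule(r) * n)  # number of positions to mask
    # Choose num_masked random positions per example.
    idx = torch.rand(batch, n).topk(num_masked, dim=-1).indices
    mask = torch.zeros(batch, n, dtype=torch.bool).scatter_(1, idx, True)
    masked_tokens = torch.where(mask, torch.full_like(tokens, MASK_ID), tokens)
    return masked_tokens, mask
```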
Mask scheduling is important for generation quality.
Training objective: the model is trained to predict the original tokens with a negative log-likelihood (cross-entropy) loss over the codebook, computed only at the masked positions, effectively treating each masked position as a classification problem whose target is the ground-truth token.
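A sketch of that objective under the same assumptions as the masking snippet above (logits of shape (batch, N, codebook_size), plus the mask returned by mask_tokens):

```python
import torch
import torch.nn.functional as F

def masked_token_loss(logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor):
    """Negative log-likelihood over the codebook, computed only at masked positions.

    logits:  (batch, N, codebook_size) transformer outputs
    targets: (batch, N) ground-truth VQ code indices
    mask:    (batch, N) bool, True where the input token was replaced by [MASK]
    """
    vocab = logits.shape[-1]
    per_token = F.cross_entropy(logits.reshape(-1, vocab),
                                targets.reshape(-1),
                                reduction="none").reshape(targets.shape)
    return per_token[mask].mean()
```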
In theory, all tokens could be generated in one pass. However, this is inconsistent with the training task, so iterative decoding is used.
In the first iteration, all tokens are masked out.
The model predicts, in parallel, a probability distribution over the codebook at every position.
At each masked location i, a token is randomly sampled from the predicted distribution, using temperature annealing to encourage more diversity. This is similar to how language models decode by sampling a token from a probability distribution.
A confidence score is assigned to each token. The confidence score is the probability of the predicted token. For unmasked positions, the confidence score is 1.
A mask schedule determines how many tokens should be re-masked after each iteration: ⌈γ(t/T)·N⌉ at iteration t out of T total iterations.
Apply mask: The tokens are sorted based on their confidence scores, and the tokens with the lowest scores are masked for the next iteration.
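Putting these steps together, a sketch of the decoding loop, reusing the assumed MASK_ID (1024) and cosine schedule from the snippets above; the fixed sampling temperature here is a simplification of the paper's annealing:

```python
import math
import torch

@torch.no_grad()
def maskgit_decode(model, n_tokens=256, steps=8, temperature=1.0, mask_id=1024):
    # Iteration 0: start from a fully masked canvas.
    tokens = torch.full((1, n_tokens), mask_id, dtype=torch.long)
    for t in range(steps):
        logits = model(tokens)                                 # (1, N, codebook_size)
        probs = torch.softmax(logits / temperature, dim=-1)
        sampled = torch.multinomial(probs[0], 1).squeeze(-1)[None, :]   # (1, N)
        # Confidence = probability of the sampled token; already-kept tokens get 1.0.
        conf = torch.gather(probs, -1, sampled[..., None]).squeeze(-1)
        still_masked = tokens == mask_id
        conf = torch.where(still_masked, conf, torch.ones_like(conf))
        tokens = torch.where(still_masked, sampled, tokens)    # commit sampled tokens
        # Mask schedule: how many tokens stay masked after this iteration.
        num_to_mask = math.floor(math.cos(math.pi / 2 * (t + 1) / steps) * n_tokens)
        if num_to_mask == 0:
            break
        # Re-mask the lowest-confidence positions for the next iteration.
        lowest = conf.topk(num_to_mask, largest=False, dim=-1).indices
        tokens.scatter_(1, lowest, mask_id)
    return tokens
```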
The mask design, which determines how many tokens to mask in each iteration, is important for generation quality.
Requirements for γ(r): it must be a continuous function that is monotonically decreasing in r, with γ(0) → 1 and γ(1) → 0, so decoding starts fully masked and ends fully revealed.
Types of mask schedules:
Concave > Linear > Convex
Cosine > Cubic > Exponential
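As an illustration (the exact set of functions in the paper's ablation may differ), a few candidate γ(r) in each family:

```python
import math

# Concave: the mask ratio stays high early and drops quickly near the end.
def cosine(r): return math.cos(math.pi * r / 2)   # the schedule MaskGIT adopts
def cubic(r):  return 1 - r ** 3

# Linear: tokens are revealed at a constant rate.
def linear(r): return 1 - r

# Convex: most tokens are revealed in the first few iterations (ranked worst above).
def square_root(r): return 1 - math.sqrt(r)
```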
There is a sweet spot for the number of iterations. The authors hypothesize that too many iterations may discourage the model from keeping less confident predictions, which worsens token diversity.