From 505806ea95ad82fed0643c0a8e482674a929894b Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Mon, 24 Apr 2023 18:48:04 -0400 Subject: [PATCH 1/4] add_movq_to_pretrained --- .gitignore | 1 + muse/modeling_taming_vqgan.py | 132 ++++++++++++++++----- scripts/add_spectral_norm_to_fp16_vqgan.py | 42 +++++++ 3 files changed, 144 insertions(+), 31 deletions(-) create mode 100644 scripts/add_spectral_norm_to_fp16_vqgan.py diff --git a/.gitignore b/.gitignore index 70ead23a..87f647e2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ # Byte-compiled / optimized / DLL files +vqgan-f16-8192-laion-movq output.jpg __pycache__/ *.py[cod] diff --git a/muse/modeling_taming_vqgan.py b/muse/modeling_taming_vqgan.py index ce2aa2ff..01872587 100644 --- a/muse/modeling_taming_vqgan.py +++ b/muse/modeling_taming_vqgan.py @@ -23,6 +23,35 @@ from .modeling_utils import ConfigMixin, ModelMixin, register_to_config +class SpatialNorm(nn.Module): + def __init__( + self, + zq_channels, + num_channels, + norm_layer=nn.GroupNorm, + freeze_norm_layer=False, + add_conv=False, + **norm_layer_params, + ): + super().__init__() + self.norm_layer = norm_layer(num_channels=num_channels, **norm_layer_params) + if freeze_norm_layer: + for p in self.norm_layer.parameters: + p.requires_grad = False + self.add_conv = add_conv + if self.add_conv: + self.conv = nn.Conv2d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=1) + self.conv_y = nn.Conv2d(zq_channels, num_channels, kernel_size=1, stride=1, padding=0) + self.conv_b = nn.Conv2d(zq_channels, num_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, f, zq): + f_size = f.shape[-2:] + zq = F.interpolate(zq, size=f_size, mode="nearest") + if self.add_conv: + zq = self.conv(zq) + norm_f = self.norm_layer(f) + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f class Upsample(nn.Module): def __init__(self, in_channels: int, with_conv: bool): @@ -69,6 +98,7 @@ def __init__( out_channels: int = None, use_conv_shortcut: bool = False, dropout_prob: float = 0.0, + zq_ch: int = None, ): super().__init__() @@ -76,8 +106,10 @@ def __init__( self.out_channels = out_channels self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels self.use_conv_shortcut = use_conv_shortcut - - self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + if zq_ch: + self.norm1 = SpatialNorm(num_groups=32, zq_channels=zq_ch, num_channels=in_channels, eps=1e-6, affine=True) + else: + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) self.conv1 = nn.Conv2d( self.in_channels, self.out_channels_, @@ -85,8 +117,10 @@ def __init__( stride=1, padding=1, ) - - self.norm2 = nn.GroupNorm(num_groups=32, num_channels=self.out_channels_, eps=1e-6, affine=True) + if zq_ch: + self.norm2 = SpatialNorm(num_groups=32, zq_channels=zq_ch, num_channels=self.out_channels_, eps=1e-6, affine=True) + else: + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=self.out_channels_, eps=1e-6, affine=True) self.dropout = nn.Dropout(dropout_prob) self.conv2 = nn.Conv2d( self.out_channels_, @@ -114,13 +148,20 @@ def __init__( padding=0, ) - def forward(self, hidden_states): + def forward(self, hidden_states, quantized_states=None): residual = hidden_states - hidden_states = self.norm1(hidden_states) + if quantized_states: + hidden_states = self.norm1(hidden_states, quantized_states) + else: + hidden_states = self.norm1(hidden_states) + hidden_states = F.silu(hidden_states) hidden_states = self.conv1(hidden_states) - hidden_states = self.norm2(hidden_states) + if quantized_states: + hidden_states = self.norm2(hidden_states, quantized_states) + else: + hidden_states = self.norm2(hidden_states) hidden_states = F.silu(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states) @@ -135,19 +176,25 @@ def forward(self, hidden_states): class AttnBlock(nn.Module): - def __init__(self, in_channels: int): + def __init__(self, in_channels: int, zq_ch: int = None): super().__init__() self.in_channels = in_channels conv = partial(nn.Conv2d, self.in_channels, self.in_channels, kernel_size=1, stride=1, padding=0) - - self.norm = nn.GroupNorm(num_groups=32, num_channels=self.in_channels, eps=1e-6, affine=True) + if zq_ch: + self.norm = SpatialNorm(num_groups=32, zq_channels=zq_ch, num_channels=self.in_channels, eps=1e-6, affine=True) + else: + self.norm = nn.GroupNorm(num_groups=32, num_channels=self.in_channels, eps=1e-6, affine=True) self.q, self.k, self.v = conv(), conv(), conv() self.proj_out = conv() - def forward(self, hidden_states): + def forward(self, hidden_states, quantized_states=None): residual = hidden_states - hidden_states = self.norm(hidden_states) + if quantized_states: + hidden_states = self.norm(hidden_states, quantized_states) + else: + hidden_states = self.norm(hidden_states) + query = self.q(hidden_states) key = self.k(hidden_states) @@ -175,7 +222,7 @@ def forward(self, hidden_states): class UpsamplingBlock(nn.Module): - def __init__(self, config, curr_res: int, block_idx: int): + def __init__(self, config, curr_res: int, block_idx: int, zq_ch: int = None): super().__init__() self.config = config @@ -192,10 +239,10 @@ def __init__(self, config, curr_res: int, block_idx: int): res_blocks = [] attn_blocks = [] for _ in range(self.config.num_res_blocks + 1): - res_blocks.append(ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout)) + res_blocks.append(ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout, zq_ch=zq_ch)) block_in = block_out if self.curr_res in self.config.attn_resolutions: - attn_blocks.append(AttnBlock(block_in)) + attn_blocks.append(AttnBlock(block_in, zq_ch=zq_ch)) self.block = nn.ModuleList(res_blocks) self.attn = nn.ModuleList(attn_blocks) @@ -204,11 +251,11 @@ def __init__(self, config, curr_res: int, block_idx: int): if self.block_idx != 0: self.upsample = Upsample(block_in, self.config.resample_with_conv) - def forward(self, hidden_states): + def forward(self, hidden_states, quantized_states=None): for i, res_block in enumerate(self.block): - hidden_states = res_block(hidden_states) + hidden_states = res_block(hidden_states, quantized_states) if len(self.attn) > 1: - hidden_states = self.attn[i](hidden_states) + hidden_states = self.attn[i](hidden_states, quantized_states) if self.upsample is not None: hidden_states = self.upsample(hidden_states) @@ -256,7 +303,7 @@ def forward(self, hidden_states): class MidBlock(nn.Module): - def __init__(self, config, in_channels: int, no_attn: False, dropout: float): + def __init__(self, config, in_channels: int, no_attn: False, dropout: float, zq_ch: int = None): super().__init__() self.config = config @@ -268,13 +315,15 @@ def __init__(self, config, in_channels: int, no_attn: False, dropout: float): self.in_channels, self.in_channels, dropout_prob=self.dropout, + zq_ch=zq_ch ) if not no_attn: - self.attn_1 = AttnBlock(self.in_channels) + self.attn_1 = AttnBlock(self.in_channels, zq_ch=zq_ch) self.block_2 = ResnetBlock( self.in_channels, self.in_channels, dropout_prob=self.dropout, + zq_ch=zq_ch ) def forward(self, hidden_states): @@ -341,7 +390,7 @@ def forward(self, pixel_values): class Decoder(nn.Module): - def __init__(self, config): + def __init__(self, config, zq_ch=None): super().__init__() self.config = config @@ -361,19 +410,22 @@ def __init__(self, config): ) # middle - self.mid = MidBlock(config, block_in, self.config.no_attn_mid_block, self.config.dropout) + self.mid = MidBlock(config, block_in, self.config.no_attn_mid_block, self.config.dropout, zq_ch=zq_ch) # upsampling upsample_blocks = [] for i_level in reversed(range(self.config.num_resolutions)): - upsample_blocks.append(UpsamplingBlock(self.config, curr_res, block_idx=i_level)) + upsample_blocks.append(UpsamplingBlock(self.config, curr_res, block_idx=i_level, zq_ch=zq_ch)) if i_level != 0: curr_res = curr_res * 2 self.up = nn.ModuleList(list(reversed(upsample_blocks))) # reverse to get consistent order # end block_out = self.config.hidden_channels * self.config.channel_mult[0] - self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out, eps=1e-6, affine=True) + if zq_ch is not None: + self.norm_out = SpatialNorm(num_groups=32, zq_channels=zq_ch, num_channels=block_out, eps=1e-6, affine=True) + else: + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out, eps=1e-6, affine=True) self.conv_out = nn.Conv2d( block_out, self.config.num_channels, @@ -382,19 +434,22 @@ def __init__(self, config): padding=1, ) - def forward(self, hidden_states): + def forward(self, hidden_states, quantized_states=None): # z to block_in hidden_states = self.conv_in(hidden_states) # middle - hidden_states = self.mid(hidden_states) + hidden_states = self.mid(hidden_states, quantized_states) # upsampling for block in reversed(self.up): - hidden_states = block(hidden_states) + hidden_states = block(hidden_states, quantized_states) # end - hidden_states = self.norm_out(hidden_states) + if quantized_states: + hidden_states = self.norm_out(hidden_states, quantized_states) + else: + hidden_states = self.norm_out(hidden_states) hidden_states = F.silu(hidden_states) hidden_states = self.conv_out(hidden_states) @@ -518,15 +573,27 @@ def __init__( dropout: float = 0.0, resample_with_conv: bool = True, commitment_cost: float = 0.25, + use_z_channels: bool = False, ): super().__init__() - + self.use_z_channels = use_z_channels + self.resolution = resolution + self.channel_mult = channel_mult self.config.num_resolutions = len(channel_mult) self.config.reduction_factor = 2 ** (self.config.num_resolutions - 1) self.config.latent_size = resolution // self.config.reduction_factor + self.config.no_attn_mid_block = no_attn_mid_block + self.config.attn_resolutions = attn_resolutions + self.config.z_channels = z_channels + self.config.num_embeddings = num_embeddings + self.config.quantized_embed_dim = quantized_embed_dim + self.encoder = Encoder(self.config) - self.decoder = Decoder(self.config) + if use_z_channels: + self.decoder = Decoder(self.config, zq_ch=self.config.z_channels) + else: + self.decoder = Decoder(self.config) self.quantize = VectorQuantizer( self.config.num_embeddings, self.config.quantized_embed_dim, self.config.commitment_cost ) @@ -552,7 +619,10 @@ def encode(self, pixel_values, return_loss=False): def decode(self, quantized_states): hidden_states = self.post_quant_conv(quantized_states) - reconstructed_pixel_values = self.decoder(hidden_states) + if self.use_z_channels: + reconstructed_pixel_values = self.decoder(hidden_states, quantized_states) + else: + reconstructed_pixel_values = self.decoder(hidden_states) return reconstructed_pixel_values def decode_code(self, codebook_indices): diff --git a/scripts/add_spectral_norm_to_fp16_vqgan.py b/scripts/add_spectral_norm_to_fp16_vqgan.py new file mode 100644 index 00000000..aeb3959b --- /dev/null +++ b/scripts/add_spectral_norm_to_fp16_vqgan.py @@ -0,0 +1,42 @@ +import json +from argparse import ArgumentParser +from muse import VQGANModel +import torch + + +def add_spectral_norm_to_vae(args): + vae = VQGANModel.from_pretrained(args.vae) + vae_with_spectral = VQGANModel(vae.resolution, + no_attn_mid_block=args.no_attn_mid_block, + z_channels=args.z_channels, + channel_mult=vae.channel_mult, + quantized_embed_dim=args.quantized_embed_dim, + num_embeddings=args.num_embeddings, + attn_resolutions=() if len(args.attn_resolutions) == 0 else [int(resolution) for resolution in args.attn_resolutions.split('|')], + use_z_channels=True + ) + original_state_dict = vae.state_dict() + output_dict = {} + for key in original_state_dict: + if "decoder" in key and "norm" in key: + weight_or_bias = key.split(".")[-1] + new_key = ".".join(key.split(".")[:-1])+".norm_layer."+weight_or_bias + output_dict[new_key] = original_state_dict[key] + else: + output_dict[key] = original_state_dict[key] + vae_with_spectral.load_state_dict(output_dict, strict=False) + print(args.movq_vae_output_path) + vae_with_spectral.save_pretrained(args.movq_vae_output_path) + + # print(vae_with_spectral.decoder) +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--vae", type=str, default="openMUSE/vqgan-f16-8192-laion") + parser.add_argument("--movq_vae_output_path", type=str, default="vqgan-f16-8192-laion-movq") + parser.add_argument("--no_attn_mid_block", action="store_false", default=True) + parser.add_argument("--z_channels", type=int, default=64) + parser.add_argument("--attn_resolutions", type=str, default="", help="Attention resolutions split by |") + parser.add_argument("--quantized_embed_dim", type=int, default=64) + parser.add_argument("--num_embeddings", type=int, default=8192) + args = parser.parse_args() + add_spectral_norm_to_vae(args) From 13fe7891f153c3f3e75fbd606789e431f7d5a891 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Mon, 15 May 2023 18:31:44 -0400 Subject: [PATCH 2/4] Fixing logs --- scripts/add_spectral_norm_to_fp16_vqgan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/add_spectral_norm_to_fp16_vqgan.py b/scripts/add_spectral_norm_to_fp16_vqgan.py index aeb3959b..d8182245 100644 --- a/scripts/add_spectral_norm_to_fp16_vqgan.py +++ b/scripts/add_spectral_norm_to_fp16_vqgan.py @@ -25,7 +25,7 @@ def add_spectral_norm_to_vae(args): else: output_dict[key] = original_state_dict[key] vae_with_spectral.load_state_dict(output_dict, strict=False) - print(args.movq_vae_output_path) + print(f"Saving to {args.movq_vae_output_path}") vae_with_spectral.save_pretrained(args.movq_vae_output_path) # print(vae_with_spectral.decoder) From 9f475e29f341934e794e24f40313970d76c2757d Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Mon, 15 May 2023 18:45:07 -0400 Subject: [PATCH 3/4] Adding missing quantized states from midblock --- muse/modeling_taming_vqgan.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/muse/modeling_taming_vqgan.py b/muse/modeling_taming_vqgan.py index 01872587..bb063a20 100644 --- a/muse/modeling_taming_vqgan.py +++ b/muse/modeling_taming_vqgan.py @@ -326,11 +326,11 @@ def __init__(self, config, in_channels: int, no_attn: False, dropout: float, zq_ zq_ch=zq_ch ) - def forward(self, hidden_states): - hidden_states = self.block_1(hidden_states) + def forward(self, hidden_states, quantized_states=None): + hidden_states = self.block_1(hidden_states, quantized_states) if not self.no_attn: - hidden_states = self.attn_1(hidden_states) - hidden_states = self.block_2(hidden_states) + hidden_states = self.attn_1(hidden_states, quantized_states) + hidden_states = self.block_2(hidden_states, quantized_states) return hidden_states From 22529abdc48dd1cc8a896157a93005dda142fb61 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Mon, 15 May 2023 18:49:19 -0400 Subject: [PATCH 4/4] Fixed ambiguous if --- muse/modeling_taming_vqgan.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/muse/modeling_taming_vqgan.py b/muse/modeling_taming_vqgan.py index bb063a20..d8a03e8c 100644 --- a/muse/modeling_taming_vqgan.py +++ b/muse/modeling_taming_vqgan.py @@ -150,7 +150,7 @@ def __init__( def forward(self, hidden_states, quantized_states=None): residual = hidden_states - if quantized_states: + if quantized_states is not None: hidden_states = self.norm1(hidden_states, quantized_states) else: hidden_states = self.norm1(hidden_states) @@ -158,7 +158,7 @@ def forward(self, hidden_states, quantized_states=None): hidden_states = F.silu(hidden_states) hidden_states = self.conv1(hidden_states) - if quantized_states: + if quantized_states is not None: hidden_states = self.norm2(hidden_states, quantized_states) else: hidden_states = self.norm2(hidden_states) @@ -190,7 +190,7 @@ def __init__(self, in_channels: int, zq_ch: int = None): def forward(self, hidden_states, quantized_states=None): residual = hidden_states - if quantized_states: + if quantized_states is not None: hidden_states = self.norm(hidden_states, quantized_states) else: hidden_states = self.norm(hidden_states) @@ -446,7 +446,7 @@ def forward(self, hidden_states, quantized_states=None): hidden_states = block(hidden_states, quantized_states) # end - if quantized_states: + if quantized_states is not None: hidden_states = self.norm_out(hidden_states, quantized_states) else: hidden_states = self.norm_out(hidden_states)