加载中 sdfvae/model.py +6 −6 原始行号 差异行号 差异行 加载中 @@ -230,7 +230,7 @@ class SDFVAE(nn.Module): def encode_frames(self, x): def encoder_x(self, x): if self.enc_dec == 'CNN': x = x.view(-1, 1, self.n, self.w) x = self.conv(x) 加载中 @@ -241,7 +241,7 @@ class SDFVAE(nn.Module): raise ValueError('Unknown encoder and decoder: {}'.format(self.enc_dec)) return x def decode_frames_mu(self, sdh): def decoder_mu(self, sdh): if self.enc_dec == 'CNN': x = self.deconv_fc_mu(sdh) x = x.view(-1, self.cd[0], self.cd[1], self.cd[2]) 加载中 @@ -252,7 +252,7 @@ class SDFVAE(nn.Module): return x def decode_frames_logvar(self, sdh): def decoder_logvar(self, sdh): if self.enc_dec == 'CNN': x = self.deconv_fc_logvar(sdh) x = x.view(-1, self.cd[0], self.cd[1], self.cd[2]) 加载中 加载中 @@ -331,13 +331,13 @@ class SDFVAE(nn.Module): def forward(self, x): x = x.float() d_mean_prior, d_logvar_prior, _ = self.sample_d_lstmcell(x.size(0), random_sampling = self.training) x_hat = self.encode_frames(x) x_hat = self.encoder_x(x) d_mean, d_logvar, d, h = self.encode_d(x.size(0), x_hat) s_mean, s_logvar, s = self.encode_s(x_hat) s_expand = s.unsqueeze(1).expand(-1, self.T, self.s_dim) ds = torch.cat((d, s_expand), dim=2) dsh = torch.cat((ds, h), dim=2) recon_x_mu = self.decode_frames_mu(dsh) recon_x_logvar = self.decode_frames_logvar(dsh) recon_x_mu = self.decoder_mu(dsh) recon_x_logvar = self.decoder_logvar(dsh) return s_mean, s_logvar, s, d_mean, d_logvar, d, d_mean_prior, d_logvar_prior, recon_x_mu, recon_x_logvar 加载中
sdfvae/model.py +6 −6 原始行号 差异行号 差异行 加载中 @@ -230,7 +230,7 @@ class SDFVAE(nn.Module): def encode_frames(self, x): def encoder_x(self, x): if self.enc_dec == 'CNN': x = x.view(-1, 1, self.n, self.w) x = self.conv(x) 加载中 @@ -241,7 +241,7 @@ class SDFVAE(nn.Module): raise ValueError('Unknown encoder and decoder: {}'.format(self.enc_dec)) return x def decode_frames_mu(self, sdh): def decoder_mu(self, sdh): if self.enc_dec == 'CNN': x = self.deconv_fc_mu(sdh) x = x.view(-1, self.cd[0], self.cd[1], self.cd[2]) 加载中 @@ -252,7 +252,7 @@ class SDFVAE(nn.Module): return x def decode_frames_logvar(self, sdh): def decoder_logvar(self, sdh): if self.enc_dec == 'CNN': x = self.deconv_fc_logvar(sdh) x = x.view(-1, self.cd[0], self.cd[1], self.cd[2]) 加载中 加载中 @@ -331,13 +331,13 @@ class SDFVAE(nn.Module): def forward(self, x): x = x.float() d_mean_prior, d_logvar_prior, _ = self.sample_d_lstmcell(x.size(0), random_sampling = self.training) x_hat = self.encode_frames(x) x_hat = self.encoder_x(x) d_mean, d_logvar, d, h = self.encode_d(x.size(0), x_hat) s_mean, s_logvar, s = self.encode_s(x_hat) s_expand = s.unsqueeze(1).expand(-1, self.T, self.s_dim) ds = torch.cat((d, s_expand), dim=2) dsh = torch.cat((ds, h), dim=2) recon_x_mu = self.decode_frames_mu(dsh) recon_x_logvar = self.decode_frames_logvar(dsh) recon_x_mu = self.decoder_mu(dsh) recon_x_logvar = self.decoder_logvar(dsh) return s_mean, s_logvar, s, d_mean, d_logvar, d, d_mean_prior, d_logvar_prior, recon_x_mu, recon_x_logvar