未验证 提交 b1b41b64 编辑于 作者: dlagul's avatar dlagul 提交者: GitHub
浏览文件

Update model.py

上级 5a77131d
加载中
加载中
加载中
加载中
+6 −6
原始行号 差异行号 差异行
@@ -230,7 +230,7 @@ class SDFVAE(nn.Module):



    def encode_frames(self, x):
    def encoder_x(self, x):
        if self.enc_dec == 'CNN':
            x = x.view(-1, 1, self.n, self.w)
            x = self.conv(x)
@@ -241,7 +241,7 @@ class SDFVAE(nn.Module):
            raise ValueError('Unknown encoder and decoder: {}'.format(self.enc_dec))
        return x

    def decode_frames_mu(self, sdh):
    def decoder_mu(self, sdh):
        if self.enc_dec == 'CNN':
            x = self.deconv_fc_mu(sdh)
            x = x.view(-1, self.cd[0], self.cd[1], self.cd[2])
@@ -252,7 +252,7 @@ class SDFVAE(nn.Module):
        return x
    
 
    def decode_frames_logvar(self, sdh):
    def decoder_logvar(self, sdh):
        if self.enc_dec == 'CNN':
            x = self.deconv_fc_logvar(sdh)
            x = x.view(-1, self.cd[0], self.cd[1], self.cd[2])
@@ -331,13 +331,13 @@ class SDFVAE(nn.Module):
    def forward(self, x):
        x = x.float()
        d_mean_prior, d_logvar_prior, _ = self.sample_d_lstmcell(x.size(0), random_sampling = self.training)
        x_hat = self.encode_frames(x)
        x_hat = self.encoder_x(x)
        d_mean, d_logvar, d, h = self.encode_d(x.size(0), x_hat) 
        s_mean, s_logvar, s = self.encode_s(x_hat)
        s_expand = s.unsqueeze(1).expand(-1, self.T, self.s_dim)
        ds = torch.cat((d, s_expand), dim=2)
        dsh = torch.cat((ds, h), dim=2)
        recon_x_mu = self.decode_frames_mu(dsh)
        recon_x_logvar = self.decode_frames_logvar(dsh)
        recon_x_mu = self.decoder_mu(dsh)
        recon_x_logvar = self.decoder_logvar(dsh)
        return s_mean, s_logvar, s, d_mean, d_logvar, d, d_mean_prior, d_logvar_prior, recon_x_mu, recon_x_logvar