未验证 提交 5a77131d 编辑于 作者: dlagul's avatar dlagul 提交者: GitHub
浏览文件

Update model.py

上级 6cff38f7
加载中
加载中
加载中
加载中
+31 −0
原始行号 差异行号 差异行
@@ -188,12 +188,30 @@ class SDFVAE(nn.Module):
        d_logvars = None
            
        d_t = torch.zeros(batch_size, self.d_dim, device=self.device)
        # Here we assume that p(d_0) = N(0,I), thus d_mean_0 = 0, d_logvar_0 = 0 due to log1 = 0
        d_mean_t = torch.zeros(batch_size, self.d_dim, device=self.device)
        d_logvar_t = torch.zeros(batch_size, self.d_dim, device=self.device)
        h_t = torch.zeros(batch_size, self.hidden_dim, device=self.device)
        c_t = torch.zeros(batch_size, self.hidden_dim, device=self.device)
            
        for _ in range(self.T):
            '''
            When t = 1:
            # According to Figure 5(a), in the beginning, we use the information hidden in h_0 (no any information) to get d_1
            # Here d_mean_1 and d_logvar_1 are still 0, due to h_0 is 0, thus prior p(d_1|d_0) = N(0,I)
            # So we sample d_1 from N(0, I) based on reparameterization trick
            # Next, we update h_1 by using Eq. (2), h1 = r_p(h_0, d_1), also see Figure 5(a)
            
            When t = 2:
            # Here d_2 ~ p(d_2|h_1), since h1 = r_p(h_0, d_1), d_2 ~ p(d_2|d_1), we still sample d_2 based on reparameterization trick
            # It should be noted that p(d_2|d_1) is not N(0,I), due to d_mean_2 and d_logvar_0 are no longer 0,
            # but parameterized by NNs.
            # Then update h_2 by using Eq. (2), h2 = r_p(h_1, d_2),
            # So we construct the time-dependent prior of latent variable d
            
            When t = 3:
            ...
            '''
            enc_d_t = self.enc_d_prior(h_t)
            d_mean_t = self.d_mean_prior(enc_d_t)
            d_logvar_t = self.d_logvar_prior(enc_d_t)
@@ -278,6 +296,19 @@ class SDFVAE(nn.Module):
        c_t = torch.zeros(batch_size, self.hidden_dim, device=self.device)

        for t in range(self.T):
            '''
            Note: the following t denotes the index of x_t, not the t in the loop
            When t = 1:
            # (1) d_1 ~ p(d_1|d_<1,x_=<1), Eq. (10) (sample d_1 based on reparameterization trick)
            # (2) h_1 = r(h_0, d_1, x_1), Eq. (3)
            
            When t = 2:
            # (1) d_2 ~ p(d_2|h_1,x_2), thus d_2 ~ p(d_2|d_<2,x_=<2) Eq. (10) (sample d_2 based on reparameterization trick)
            # (2) h_2 = r(h_1, d_2, x_2), Eq. (3)
            
            When t = 3:
            ...
            '''
            phi_conv_t = self.phi_conv(x[:,t,:]) 
            enc_d_t = self.enc_d(torch.cat([phi_conv_t, h_t], 1)) 
            d_mean_t = self.d_mean(enc_d_t)