加载中 sdfvae/tester.py +6 −0 原始行号 差异行号 差异行 加载中 @@ -82,6 +82,9 @@ class Tester(object): def loss_fn(self, original_seq, recon_seq_mu, recon_seq_logvar, s_mean, s_logvar, d_post_mean, d_post_logvar, d_prior_mean, d_prior_logvar): batch_size = original_seq.size(0) # See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2 for details. # N(x|mu,var) # = log{1/(sqrt(2*pi)*var)exp{-(x-mu)^2/(2*var^2)}} # = -0.5*{log(2*pi)+2*log(var)+[(x-mu)/exp{log(var)}]^2} loglikelihood = -0.5 * torch.sum(torch.pow(((original_seq.float()-recon_seq_mu.float())/torch.exp(recon_seq_logvar.float())), 2) + 2 * recon_seq_logvar.float() + np.log(np.pi*2)) 加载中 @@ -98,6 +101,9 @@ class Tester(object): def loglikelihood_last_timestamp(self, x, recon_x_mu, recon_x_logvar): # See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2 for details. # N(x|mu,var) # = log{1/(sqrt(2*pi)*var)exp{-(x-mu)^2/(2*var^2)}} # = -0.5*{log(2*pi)+2*log(var)+[(x-mu)/exp{log(var)}]^2} llh = -0.5 * torch.sum(torch.pow(((x.float()-recon_x_mu.float())/torch.exp(recon_x_logvar.float())), 2) + 2 * recon_x_logvar.float() + np.log(np.pi*2)) 加载中 加载中
sdfvae/tester.py +6 −0 原始行号 差异行号 差异行 加载中 @@ -82,6 +82,9 @@ class Tester(object): def loss_fn(self, original_seq, recon_seq_mu, recon_seq_logvar, s_mean, s_logvar, d_post_mean, d_post_logvar, d_prior_mean, d_prior_logvar): batch_size = original_seq.size(0) # See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2 for details. # N(x|mu,var) # = log{1/(sqrt(2*pi)*var)exp{-(x-mu)^2/(2*var^2)}} # = -0.5*{log(2*pi)+2*log(var)+[(x-mu)/exp{log(var)}]^2} loglikelihood = -0.5 * torch.sum(torch.pow(((original_seq.float()-recon_seq_mu.float())/torch.exp(recon_seq_logvar.float())), 2) + 2 * recon_seq_logvar.float() + np.log(np.pi*2)) 加载中 @@ -98,6 +101,9 @@ class Tester(object): def loglikelihood_last_timestamp(self, x, recon_x_mu, recon_x_logvar): # See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2 for details. # N(x|mu,var) # = log{1/(sqrt(2*pi)*var)exp{-(x-mu)^2/(2*var^2)}} # = -0.5*{log(2*pi)+2*log(var)+[(x-mu)/exp{log(var)}]^2} llh = -0.5 * torch.sum(torch.pow(((x.float()-recon_x_mu.float())/torch.exp(recon_x_logvar.float())), 2) + 2 * recon_x_logvar.float() + np.log(np.pi*2)) 加载中