sdfvae/trainer.py  +5 −5

@@ -53,7 +53,7 @@ class Trainer(object):
             print ("No Checkpoint Exists At '{}', Starting Fresh Training".format(self.checkpoints))
             self.start_epoch = 0
 
-    def loss_fn(self, original_seq, recon_seq_mu, recon_seq_logvar, s_mean,
+    def loss_fn(self, original_seq, recon_seq_mu, recon_seq_logsigma, s_mean,
                 s_logvar, d_post_mean, d_post_logvar, d_prior_mean, d_prior_logvar):
         batch_size = original_seq.size(0)
         # See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2 for details.
@@ -64,8 +64,8 @@ class Trainer(object):
         #   = -0.5*{log(2*pi)+2*log(sigma)+[(x-mu)/exp{log(sigma)}]^2}
         # Note that var = sigma^2, i.e., log(var) = 2*log(sigma),
-        # so the “recon_seq_logvar” here is more appropriate to be called “recon_seq_logsigma”, but the name does not the matter
-        loglikelihood = -0.5 * torch.sum(torch.pow(((original_seq.float()-recon_seq_mu.float())/torch.exp(recon_seq_logvar.float())), 2) + 2 * recon_seq_logvar.float()
+        loglikelihood = -0.5 * torch.sum(torch.pow(((original_seq.float()-recon_seq_mu.float())/torch.exp(recon_seq_logsigma.float())), 2) + 2 * recon_seq_logsigma.float()
                                          + np.log(np.pi*2))
         # See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2, Equation (7) for details.
         kld_s = -0.5 * torch.sum(1 + s_logvar - torch.pow(s_mean, 2) - torch.exp(s_logvar))
@@ -89,8 +89,8 @@ class Trainer(object):
                 _,_,data = dataitem
                 data = data.to(self.device)
                 self.optimizer.zero_grad()
-                s_mean, s_logvar, s, d_post_mean, d_post_logvar, d, d_prior_mean, d_prior_logvar, recon_x_mu, recon_x_logvar = self.model(data)
-                loss, llh, kld_s, kld_d = self.loss_fn(data, recon_x_mu, recon_x_logvar, s_mean, s_logvar,
+                s_mean, s_logvar, s, d_post_mean, d_post_logvar, d, d_prior_mean, d_prior_logvar, recon_x_mu, recon_x_logsigma = self.model(data)
+                loss, llh, kld_s, kld_d = self.loss_fn(data, recon_x_mu, recon_x_logsigma, s_mean, s_logvar,
                                                        d_post_mean, d_post_logvar, d_prior_mean, d_prior_logvar)
                 loss.backward()
                 self.optimizer.step()
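The rename is justified by how the tensor is used: the reconstruction term divides by `torch.exp(...)` and adds `2 * (...)`, which is the Gaussian log-density written in terms of log(sigma), whereas `kld_s` keeps the conventional log-variance form. Below is a minimal sanity check, not part of the patch, that verifies both parameterizations against `torch.distributions` as an independent reference; the random tensors are stand-ins for the model's outputs, and the variable names simply mirror the diff.

    import numpy as np
    import torch
    from torch.distributions import Normal, kl_divergence

    torch.manual_seed(0)
    x = torch.randn(4, 3)          # stand-in for original_seq
    mu = torch.randn(4, 3)         # stand-in for recon_seq_mu
    logsigma = torch.randn(4, 3)   # stand-in for recon_seq_logsigma

    # Hand-rolled term from the diff:
    # -0.5 * sum{ [(x-mu)/exp(logsigma)]^2 + 2*logsigma + log(2*pi) }
    llh_manual = -0.5 * torch.sum(
        torch.pow((x - mu) / torch.exp(logsigma), 2)
        + 2 * logsigma
        + np.log(np.pi * 2)
    )

    # Reference: sum of log N(x; mu, sigma) with sigma = exp(logsigma),
    # confirming the tensor is log(sigma), not log(sigma^2).
    llh_ref = Normal(mu, torch.exp(logsigma)).log_prob(x).sum()
    assert torch.allclose(llh_manual, llh_ref, atol=1e-5)

    # kld_s, by contrast, uses a log-variance parameterization: with
    # sigma = exp(0.5 * logvar), the closed form below equals KL(N(mu, sigma) || N(0, 1)).
    s_mean, s_logvar = torch.randn(4, 3), torch.randn(4, 3)
    kld_manual = -0.5 * torch.sum(1 + s_logvar - torch.pow(s_mean, 2) - torch.exp(s_logvar))
    kld_ref = kl_divergence(
        Normal(s_mean, torch.exp(0.5 * s_logvar)),
        Normal(torch.zeros_like(s_mean), torch.ones_like(s_logvar)),
    ).sum()
    assert torch.allclose(kld_manual, kld_ref, atol=1e-5)
    print("log-likelihood and KL checks passed")

The two parameterizations differ only by a factor of 2 in the log term, so either works as long as the name matches the usage; the check above makes that correspondence concrete for both terms touched by this change.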