未验证 提交 84fe054e 编辑于 作者: dlagul's avatar dlagul 提交者: GitHub
浏览文件

Update trainer.py

上级 e9b5d6ac
加载中
加载中
加载中
加载中
+5 −5
原始行号 差异行号 差异行
@@ -53,7 +53,7 @@ class Trainer(object):
            print ("No Checkpoint Exists At '{}', Starting Fresh Training".format(self.checkpoints))
            self.start_epoch = 0

    def loss_fn(self, original_seq, recon_seq_mu, recon_seq_logvar, s_mean, 
    def loss_fn(self, original_seq, recon_seq_mu, recon_seq_logsigma, s_mean, 
                s_logvar, d_post_mean, d_post_logvar, d_prior_mean, d_prior_logvar):
        batch_size = original_seq.size(0)
        # See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2 for details.
@@ -64,8 +64,8 @@ class Trainer(object):
        # = -0.5*{log(2*pi)+2*log(sigma)+[(x-mu)/exp{log(sigma)}]^2}
        # Note that var = sigma^2, i.e., log(var) = 2*log(sigma),
        # so the “recon_seq_logvar” here is more appropriate to be called “recon_seq_logsigma”, but the name does not the matter
        loglikelihood = -0.5 * torch.sum(torch.pow(((original_seq.float()-recon_seq_mu.float())/torch.exp(recon_seq_logvar.float())), 2) 
                                         + 2 * recon_seq_logvar.float() 
        loglikelihood = -0.5 * torch.sum(torch.pow(((original_seq.float()-recon_seq_mu.float())/torch.exp(recon_seq_logsigma.float())), 2) 
                                         + 2 * recon_seq_logsigma.float() 
                                         + np.log(np.pi*2))
        # See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2, Equation (7) for details.
        kld_s = -0.5 * torch.sum(1 + s_logvar - torch.pow(s_mean, 2) - torch.exp(s_logvar))
@@ -89,8 +89,8 @@ class Trainer(object):
                _,_,data = dataitem
                data = data.to(self.device)
                self.optimizer.zero_grad()
                s_mean, s_logvar, s, d_post_mean, d_post_logvar, d, d_prior_mean, d_prior_logvar, recon_x_mu, recon_x_logvar = self.model(data)
                loss, llh, kld_s, kld_d = self.loss_fn(data, recon_x_mu, recon_x_logvar, s_mean, s_logvar, 
                s_mean, s_logvar, s, d_post_mean, d_post_logvar, d, d_prior_mean, d_prior_logvar, recon_x_mu, recon_x_logsigma = self.model(data)
                loss, llh, kld_s, kld_d = self.loss_fn(data, recon_x_mu, recon_x_logsigma, s_mean, s_logvar, 
                                                       d_post_mean, d_post_logvar, d_prior_mean, d_prior_logvar)
                loss.backward()
                self.optimizer.step()