未验证 提交 e9b5d6ac 编辑于 作者: dlagul's avatar dlagul 提交者: GitHub
浏览文件

Update model.py

上级 8f3d70fd
加载中
加载中
加载中
加载中
+4 −4
原始行号 差异行号 差异行
@@ -158,10 +158,10 @@ class SDFVAE(nn.Module):
                                               padding=(self.pd[0][0],self.pd[0][1]), 
                                               nonlinearity=nn.Tanh()) 
                      ) 
            self.deconv_fc_logvar = nn.Sequential(
            self.deconv_fc_logsigma = nn.Sequential(
                         LinearUnit(self.dec_init_dim, self.conv_dim*2),
                         LinearUnit(self.conv_dim*2, self.cd[0]*self.cd[1]*self.cd[2]))
            self.deconv_logvar = nn.Sequential(
            self.deconv_logsigma = nn.Sequential(
                      ConvUnitTranspose(64, 32, kernel=(self.krl[2][0],self.krl[2][1]),
                                                  stride=(self.srd[2][0],self.srd[2][1]),
                                                  padding=(self.pd[2][0],self.pd[2][1])),  
@@ -254,9 +254,9 @@ class SDFVAE(nn.Module):
 
    def decoder_logsigma(self, sdh):
        if self.enc_dec == 'CNN':
            x = self.deconv_fc_logvar(sdh)
            x = self.deconv_fc_logsigma(sdh)
            x = x.view(-1, self.cd[0], self.cd[1], self.cd[2])
            x = self.deconv_logvar(x)
            x = self.deconv_logsigma(x)
            x = x.view(-1, self.T, 1, self.n, self.w)
        else:
            raise ValueError('Unknown encoder and decoder: {}'.format(self.enc_dec))