未验证 提交 bb076fdc 编辑于 作者: dlagul's avatar dlagul 提交者: GitHub
浏览文件

Update trainer.py

上级 b4dbd5ac
加载中
加载中
加载中
加载中
+1 −5
原始行号 差异行号 差异行
@@ -113,7 +113,6 @@ def main():
    parser.add_argument('--gpu_id', type=int, default=0)
    # Dataset options
    parser.add_argument('--dataset_path', type=str, default='')
    parser.add_argument('--data_nums', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--T', type=int, default=20)
@@ -149,9 +148,6 @@ def main():
    if not os.path.exists(args.dataset_path):
        raise ValueError('Unknown dataset path: {}'.format(args.dataset_path))
   
    if args.data_nums == 0:
        raise ValueError('Wrong data numbers: {}'.format(args.data_nums))
   
    if not os.path.exists(args.log_path):
        os.makedirs(args.log_path)

@@ -174,7 +170,7 @@ def main():
                         args.win_size,
                         args.T,args.l)
        
    kpi_value_train = KpiReader(args.dataset_path, args.data_nums)
    kpi_value_train = KpiReader(args.dataset_path)

    train_loader = torch.utils.data.DataLoader(kpi_value_train, 
                                               batch_size  = args.batch_size,