1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
import copy

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
 
  def parse_args():     config = {         "data_dir": "../input/ml2021springhw43/Dataset",         "save_path": "./model.ckpt",         "batch_size": 32,         "n_workers": 2,         "valid_steps": 2000,         "warmup_steps": 1000,         "save_steps": 10000,         "total_steps": 70000,       }          return config
 
  def main(data_dir, save_path, batch_size, n_workers, valid_steps, warmup_steps, total_steps, save_steps):     """Main function."""     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")     print(f"[Info]: Use {device} now!")
      train_loader, valid_loader, speaker_num = get_dataloader(data_dir, batch_size, n_workers)     train_iterator = iter(train_loader)     print(f"[Info]: Finish loading data!",flush = True)
      model = Classifier(n_spks=speaker_num).to(device)     criterion = nn.CrossEntropyLoss()     optimizer = AdamW(model.parameters(), lr=1e-3)     scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)     print(f"[Info]: Finish creating model!",flush = True)
      best_accuracy = -1.0     best_state_dict = None
      pbar = tqdm(total=valid_steps, ncols=0, desc="Train", unit=" step")
      for step in range(total_steps):                  try:             batch = next(train_iterator)         except StopIteration:              train_iterator = iter(train_loader)             batch = next(train_iterator)
          loss, accuracy = model_fn(batch, model, criterion, device)          batch_loss = loss.item()         batch_accuracy = accuracy.item()
                   loss.backward()         optimizer.step()         scheduler.step()         optimizer.zero_grad()                       pbar.update()         pbar.set_postfix(           loss=f"{batch_loss:.2f}",           accuracy=f"{batch_accuracy:.2f}",           step=step + 1,         )
                   if (step + 1) % valid_steps == 0:              pbar.close()
              valid_accuracy = valid(valid_loader, model, criterion, device)
                           if valid_accuracy > best_accuracy:                 best_accuracy = valid_accuracy                 best_state_dict = model.state_dict() 
              pbar = tqdm(total=valid_steps, ncols=0, desc="Train", unit=" step") 
                   if (step + 1) % save_steps == 0 and best_state_dict is not None:              torch.save(best_state_dict, save_path)             pbar.write(f"Step {step + 1}, best model saved. (accuracy={best_accuracy:.4f})")
      pbar.close()
 
  if __name__ == "__main__":     main(**parse_args())
   |