1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
import copy

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
def parse_args():
    """Build and return the fixed training configuration.

    Returns:
        dict: Keyword arguments for ``main`` -- dataset directory,
        checkpoint path, loader settings and the step counts for
        validation, warm-up, saving and total training length.
    """
    return dict(
        data_dir="../input/ml2021springhw43/Dataset",
        save_path="./model.ckpt",
        batch_size=32,
        n_workers=2,
        valid_steps=2000,
        warmup_steps=1000,
        save_steps=10000,
        total_steps=70000,
    )
def main(
    data_dir,
    save_path,
    batch_size,
    n_workers,
    valid_steps,
    warmup_steps,
    total_steps,
    save_steps,
):
    """Train a ``Classifier``, validating and checkpointing periodically.

    Args:
        data_dir: Dataset location, forwarded to ``get_dataloader``.
        save_path: File path the best ``state_dict`` is written to.
        batch_size: Batch size, forwarded to ``get_dataloader``.
        n_workers: Worker count, forwarded to ``get_dataloader``.
        valid_steps: Run validation every ``valid_steps`` training steps
            (also the length of each progress bar).
        warmup_steps: Warm-up length passed to the LR scheduler.
        total_steps: Total number of optimisation steps to run.
        save_steps: Every ``save_steps`` steps, write the best weights seen
            so far to ``save_path`` (skipped until a validation has run).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[Info]: Use {device} now!")

    train_loader, valid_loader, speaker_num = get_dataloader(data_dir, batch_size, n_workers)
    train_iterator = iter(train_loader)
    print("[Info]: Finish loading data!", flush=True)

    model = Classifier(n_spks=speaker_num).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr=1e-3)
    scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
    print("[Info]: Finish creating model!", flush=True)

    best_accuracy = -1.0
    best_state_dict = None

    pbar = tqdm(total=valid_steps, ncols=0, desc="Train", unit=" step")

    for step in range(total_steps):
        # Training is step-based, not epoch-based: restart the loader
        # whenever it is exhausted.
        try:
            batch = next(train_iterator)
        except StopIteration:
            train_iterator = iter(train_loader)
            batch = next(train_iterator)

        loss, accuracy = model_fn(batch, model, criterion, device)
        batch_loss = loss.item()
        batch_accuracy = accuracy.item()

        # Backward pass, parameter update, LR schedule update, gradient reset.
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        pbar.update()
        pbar.set_postfix(
            loss=f"{batch_loss:.2f}",
            accuracy=f"{batch_accuracy:.2f}",
            step=step + 1,
        )

        if (step + 1) % valid_steps == 0:
            # Close the training bar before validation runs.
            pbar.close()

            valid_accuracy = valid(valid_loader, model, criterion, device)

            if valid_accuracy > best_accuracy:
                best_accuracy = valid_accuracy
                # state_dict() returns references to the live parameter
                # tensors; without a deep copy the "best" snapshot would be
                # overwritten in place by every later optimizer step.
                best_state_dict = copy.deepcopy(model.state_dict())

            pbar = tqdm(total=valid_steps, ncols=0, desc="Train", unit=" step")

        if (step + 1) % save_steps == 0 and best_state_dict is not None:
            torch.save(best_state_dict, save_path)
            pbar.write(f"Step {step + 1}, best model saved. (accuracy={best_accuracy:.4f})")

    pbar.close()
# Script entry point: run training with the configuration from parse_args().
if __name__ == "__main__":
    main(**parse_args())
|