1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55
def test_model():
    """Run pretrained ResNet-101 in eager FP32 over the validation set.

    Baseline for the other ``test_*`` variants. Relies on the module-level
    names ``models``, ``device``, ``dataloaders`` and ``tqdm``. Model outputs
    are discarded, so this presumably serves as a throughput benchmark
    (tqdm reports iterations/s).
    """
    # NOTE: ``pretrained=True`` is deprecated since torchvision 0.13 in favour
    # of ``weights=``; kept here for compatibility with older torchvision.
    model = models.resnet101(pretrained=True).eval().to(device)
    # Inference only: disable autograd so no graph is recorded per batch.
    with torch.no_grad():
        for inputs, _labels in tqdm(dataloaders['valid']):
            inputs = inputs.to(device)
            model(inputs)
def test_autocast_model():
    """Run pretrained ResNet-101 eagerly under CUDA autocast over the validation set.

    Same loop as the FP32 baseline, but every forward pass executes inside
    ``torch.autocast(device_type="cuda")`` so eligible ops run in reduced
    precision. Relies on the module-level names ``models``, ``device``,
    ``dataloaders`` and ``tqdm``. Outputs are discarded.
    """
    model = models.resnet101(pretrained=True).eval().to(device)
    # no_grad: inference only, so skip recording an autograd graph per batch.
    with torch.no_grad(), torch.autocast(device_type="cuda"):
        for inputs, _labels in tqdm(dataloaders['valid']):
            inputs = inputs.to(device)
            model(inputs)
def test_jit_model():
    """Run a scripted and frozen ResNet-101 (FP32) over the validation set.

    The model is compiled with ``torch.jit.script`` and then
    ``torch.jit.freeze`` (which requires eval mode, set via ``.eval()``).
    Relies on the module-level names ``models``, ``device``, ``dataloaders``
    and ``tqdm``. Outputs are discarded.
    """
    model = models.resnet101(pretrained=True).eval().to(device)
    model = torch.jit.script(model)
    model = torch.jit.freeze(model)
    # Inference only: disable autograd so no graph is recorded per batch.
    with torch.no_grad():
        for inputs, _labels in tqdm(dataloaders['valid']):
            inputs = inputs.to(device)
            model(inputs)
def test_antocast_jit_model():
    """Script and freeze ResNet-101 under CUDA autocast, then run it under autocast.

    The function name's "antocast" spelling is kept unchanged so existing
    callers still resolve. Relies on the module-level names ``models``,
    ``device``, ``dataloaders`` and ``tqdm``. Outputs are discarded.
    """
    model = models.resnet101(pretrained=True).eval().to(device)
    # ``torch.autocast`` replaces the deprecated ``torch.cuda.amp.autocast``
    # and matches the other autocast variants in this file. The autocast
    # weight-cast cache stays disabled (as in the original) while the model is
    # scripted/frozen; PyTorch advises disabling it around JIT capture.
    with torch.no_grad(), torch.autocast(device_type="cuda", cache_enabled=False):
        model = torch.jit.script(model)
        model = torch.jit.freeze(model)
        # NOTE(review): the original paste lost its indentation, so it was
        # ambiguous whether this loop ran inside the autocast block. It is
        # placed inside so the frozen graph actually executes in mixed precision.
        for inputs, _labels in tqdm(dataloaders['valid']):
            inputs = inputs.to(device)
            model(inputs)
def test_trt_model():
    """Convert ResNet-101 to a TensorRT engine (FP32) via torch2trt and run it.

    Relies on the module-level names ``models``, ``device``, ``dataloaders``,
    ``tqdm``, ``torch2trt`` and ``TRTModule``. Outputs are discarded.
    """
    model = models.resnet101(pretrained=True).eval().to(device)
    # Example input torch2trt uses to build the engine. Use ``device`` (not
    # ``.cuda()``) so it lands on the same GPU as the model.
    # NOTE(review): the engine is built from a (1, 3, 512, 512) input; confirm
    # that batches from dataloaders['valid'] are compatible with that shape and
    # with torch2trt's max batch size.
    x = torch.ones((1, 3, 512, 512)).to(device)
    with torch.no_grad():
        model_trt = torch2trt(model, [x])
        # Round-trip the engine through state_dict into a fresh TRTModule,
        # presumably to exercise the same path as loading a saved engine.
        model = TRTModule()
        model.load_state_dict(model_trt.state_dict())
        for inputs, _labels in tqdm(dataloaders['valid']):
            inputs = inputs.to(device)
            model(inputs)
def test_autocast_trt_model():
    """Convert ResNet-101 to a reduced-precision (FP16) TensorRT engine and run it.

    ``torch.autocast`` only affects ops dispatched by PyTorch. A TensorRT
    engine executes outside the dispatcher, so wrapping the TRTModule call in
    autocast (as the original did) leaves the engine in FP32. Instead, the
    precision is chosen at build time with torch2trt's ``fp16_mode=True``.
    Relies on the module-level names ``models``, ``device``, ``dataloaders``,
    ``tqdm``, ``torch2trt`` and ``TRTModule``. Outputs are discarded.
    """
    model = models.resnet101(pretrained=True).eval().to(device)
    # Example input torch2trt uses to build the engine, on the model's device.
    # NOTE(review): the engine is built from a (1, 3, 512, 512) input; confirm
    # that batches from dataloaders['valid'] are compatible with that shape and
    # with torch2trt's max batch size.
    x = torch.ones((1, 3, 512, 512)).to(device)
    with torch.no_grad():
        model_trt = torch2trt(model, [x], fp16_mode=True)
        # Round-trip the engine through state_dict into a fresh TRTModule,
        # presumably to exercise the same path as loading a saved engine.
        model = TRTModule()
        model.load_state_dict(model_trt.state_dict())
        for inputs, _labels in tqdm(dataloaders['valid']):
            inputs = inputs.to(device)
            model(inputs)
|