Skip to content

Commit

Permalink
fix bug:预测图像没有在cuda上
Browse files Browse the repository at this point in the history
  • Loading branch information
yizt committed Aug 10, 2020
1 parent 5b672d3 commit 2f62684
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,13 @@ def inference(image, h, w):
:return: text
"""
image = torch.FloatTensor(image)
image = image.to(device)

if h > w:
predict = v_net(image)[0].detach().cpu().numpy() # [W,num_classes]
else:
predict = h_net(image)[0].detach().cpu().numpy() # [W,num_classes]

image.to(device)

label = np.argmax(predict[:], axis=1)
label = [alpha[class_id] for class_id in label]
label = [k for k, g in itertools.groupby(list(label))]
Expand Down

0 comments on commit 2f62684

Please sign in to comment.