Skip to content

Commit

Permalink
Merge pull request #236 from GokuMohandas/dev
Browse files Browse the repository at this point in the history
fixed predict with probs error
  • Loading branch information
GokuMohandas authored Aug 4, 2023
2 parents 68e031d + d7f2822 commit 0cfb704
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions madewithml/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Dict, Iterable, List
from urllib.parse import urlparse

import numpy as np
import pandas as pd
import ray
import torch
Expand Down Expand Up @@ -62,8 +63,6 @@ def predict_with_proba(
"""
preprocessor = predictor.get_preprocessor()
z = predictor.predict(data=df)["predictions"]
import numpy as np

y_prob = torch.tensor(np.stack(z)).softmax(dim=1).numpy()
results = []
for i, prob in enumerate(y_prob):
Expand Down Expand Up @@ -130,7 +129,7 @@ def predict(

# Predict
sample_df = pd.DataFrame([{"title": title, "description": description, "tag": "other"}])
results = predict_with_proba(df=sample_df, predictor=predictor, index_to_class=preprocessor.index_to_class)
results = predict_with_proba(df=sample_df, predictor=predictor)
logger.info(json.dumps(results, cls=NumpyEncoder, indent=2))
return results

Expand Down

0 comments on commit 0cfb704

Please sign in to comment.