From 9362c043c1a4bcb00877d082202ba1cba125b276 Mon Sep 17 00:00:00 2001 From: Christina Xu Date: Wed, 4 Sep 2024 15:39:13 -0400 Subject: [PATCH] Convert tensor dtype to float64 Signed-off-by: Christina Xu --- .../modules/text_classification/sequence_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/caikit_nlp/modules/text_classification/sequence_classification.py b/caikit_nlp/modules/text_classification/sequence_classification.py index a485c59b..dadc8dcc 100644 --- a/caikit_nlp/modules/text_classification/sequence_classification.py +++ b/caikit_nlp/modules/text_classification/sequence_classification.py @@ -185,7 +185,7 @@ def _get_scores(self, text: Union[str, List[str]]): softmax = torch.nn.Softmax(dim=1) raw_scores = softmax(logits) - scores = raw_scores.numpy() + scores = raw_scores.double().numpy() num_labels = self.model.num_labels num_texts = 1 # str if isinstance(text, List):