From dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1 Mon Sep 17 00:00:00 2001 From: tminer Date: Tue, 4 Jun 2024 20:47:22 +0100 Subject: [PATCH] import packaging to be compatible with setuptools==70.0.0 (#449) * import packaging to be compatible with setuptools==70.0.0 * importing the version module --------- Co-authored-by: Jamie Co-authored-by: Jong Wook Kim --- clip/clip.py | 8 ++++---- requirements.txt | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/clip/clip.py b/clip/clip.py index f7a5da5e6..398a6282c 100644 --- a/clip/clip.py +++ b/clip/clip.py @@ -2,8 +2,8 @@ import os import urllib import warnings -from typing import Any, Union, List -from pkg_resources import packaging +from packaging import version +from typing import Union, List import torch from PIL import Image @@ -20,7 +20,7 @@ BICUBIC = Image.BICUBIC -if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): +if version.parse(torch.__version__) < version.parse("1.7.1"): warnings.warn("PyTorch version 1.7.1 or higher is recommended") @@ -228,7 +228,7 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: b sot_token = _tokenizer.encoder["<|startoftext|>"] eot_token = _tokenizer.encoder["<|endoftext|>"] all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] - if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): + if version.parse(torch.__version__) < version.parse("1.8.0"): result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) else: result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) diff --git a/requirements.txt b/requirements.txt index 6b98c33f3..e083ffba8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ ftfy +packaging regex tqdm torch