diff --git a/models_cpc.py b/models_cpc.py index 6bed4e1..92d2bd0 100644 --- a/models_cpc.py +++ b/models_cpc.py @@ -15,11 +15,13 @@ from timm.models.vision_transformer import PatchEmbed, Block +from huggingface_hub import PyTorchModelHubMixin + import numpy as np -class MaskedAutoencoderViT(nn.Module): +class MaskedAutoencoderViT(nn.Module, PyTorchModelHubMixin, repo_url="https://github.com/MCG-NJU/CoMAE.git", pipeline_tag="image-feature-extraction"): def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=24, num_heads=16, diff --git a/models_mm_mae.py b/models_mm_mae.py index 7cab268..d42ba28 100644 --- a/models_mm_mae.py +++ b/models_mm_mae.py @@ -18,9 +18,11 @@ from util.pos_embed import get_2d_sincos_pos_embed +from huggingface_hub import PyTorchModelHubMixin -class MaskedAutoencoderViT(nn.Module): + +class MaskedAutoencoderViT(nn.Module, PyTorchModelHubMixin, repo_url="https://github.com/MCG-NJU/CoMAE.git", pipeline_tag="image-feature-extraction"): """ Masked Autoencoder with VisionTransformer backbone """