
Wrong dimension order in unpatchify? #6

Open
Xact-sniper opened this issue Jun 27, 2024 · 1 comment

Comments

@Xact-sniper

minRF/dit.py

Lines 287 to 288 in 72feb0c

    x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
    x = torch.einsum("nhwpqc->nchpwq", x)

Should this not be:

    x = x.reshape(shape=(x.shape[0], h, w, c, p, p))
    x = torch.einsum("nhwcpq->nchpwq", x)

I would expect unpatchify(patchify(image)) == image, but as written that is not the case.
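A quick round-trip check (my own sketch, using standalone patchify/unpatchify functions that mirror the repo's logic) suggests the proposed (..., c, p, p) order does make the identity hold:

```python
import torch

def patchify(x, p):
    # (B, C, H, W) -> (B, num_patches, C*p*p), channel-major within each patch
    B, C, H, W = x.shape
    x = x.view(B, C, H // p, p, W // p, p)
    return x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)

def unpatchify_fixed(x, p, c):
    # inverse of patchify, using the proposed (h, w, c, p, p) ordering
    h = w = int(x.shape[1] ** 0.5)
    x = x.reshape(x.shape[0], h, w, c, p, p)
    x = torch.einsum("nhwcpq->nchpwq", x)
    return x.reshape(x.shape[0], c, h * p, w * p)

img = torch.arange(3 * 8 * 8).reshape(1, 3, 8, 8).float()
print(torch.equal(unpatchify_fixed(patchify(img, 4), 4, 3), img))  # True
```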

@cloneofsimo
Owner

You are 100% correct that this is not the case. This is a bug on my side.

However, it's actually fine, because all of the information in each patch still gets mapped back into that patch. The order only gets mixed within the patch, so the result is equivalent up to a fixed permutation, which the nn.Linear will learn to recover.
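The claim that a linear layer can recover a fixed permutation can be sketched like this (my own illustration, not from the repo): permuting a linear layer's input is the same as permuting its weight columns, so training can simply learn the permuted weights.

```python
import torch

torch.manual_seed(0)
x = torch.randn(16)           # a flattened patch
W = torch.randn(8, 16)        # weight of a hypothetical nn.Linear(16, 8)
perm = torch.randperm(16)     # the fixed within-patch shuffle

# Feeding the shuffled patch through W is identical to feeding the
# original patch through a column-permuted W, so a linear layer can
# absorb the shuffle by learning permuted weights.
absorbed = W[:, torch.argsort(perm)]
print(torch.allclose(W @ x[perm], absorbed @ x))  # True
```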

What I mean is that set(patch_of(image)) == set(patch_of(unpatchify(patchify(image)))), i.e., pixels don't get mixed across patches.

You can verify this by running the following code, which always prints True.

import torch

class PatchProcessor:
    def __init__(self, patch_size, out_channels):
        self.patch_size = patch_size
        self.out_channels = out_channels

    def unpatchify(self, x):
        c = self.out_channels
        p = self.patch_size
        h = w = int(x.shape[1] ** 0.5)
        # Note: (h, w, p, p, c) is the buggy order under discussion.
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def patchify(self, x):
        B, C, H, W = x.size()
        x = x.view(
            B,
            C,
            H // self.patch_size,
            self.patch_size,
            W // self.patch_size,
            self.patch_size,
        )
        x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
        return x

patch_size = 4
out_channels = 3  # Assuming an RGB image
processor = PatchProcessor(patch_size, out_channels)
SIZE = 32

image = torch.arange(out_channels * SIZE * SIZE).reshape(1, out_channels, SIZE, SIZE).float()

patched_image = processor.patchify(image)

reconstructed_image = processor.unpatchify(patched_image)


for idx in range(0, SIZE, patch_size):
    for jdx in range(0, SIZE, patch_size):
        sets_bef = set(image[:, :, idx:idx + patch_size, jdx:jdx + patch_size].flatten().tolist())
        sets_aft = set(reconstructed_image[:, :, idx:idx + patch_size, jdx:jdx + patch_size].flatten().tolist())
        print(f"Patch ({idx // patch_size}, {jdx // patch_size}):", sets_bef == sets_aft)
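For contrast (my own addition, not from the thread): elementwise equality does fail under this ordering, even though the per-patch sets agree. A self-contained check, reusing the same patchify logic and the buggy (h, w, p, p, c) reshape:

```python
import torch

p, c, S = 4, 3, 8
img = torch.arange(c * S * S).reshape(1, c, S, S).float()

# patchify: (B, C, H, W) -> (B, num_patches, C*p*p)
x = img.view(1, c, S // p, p, S // p, p).permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)

# unpatchify with the buggy (h, w, p, p, c) order
h = w = S // p
y = torch.einsum("nhwpqc->nchpwq", x.reshape(1, h, w, p, p, c)).reshape(1, c, S, S)

print(torch.equal(y, img))  # False: values get shuffled within each patch
print(set(y[0, :, :p, :p].flatten().tolist())
      == set(img[0, :, :p, :p].flatten().tolist()))  # True: same values per patch
```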

However, this was not intended, and what you pointed out is correct. This is an unnecessary channel-wise shuffle that doesn't need to be here, so I'll remove it in the future.
