Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to change covariance matrix type for GMM class #50

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

dominik-strutz
Copy link

This PR adds changes to the zuko.flows.mixture.GMM class, which allow the user to change the type of the covariance matrix used for each of the Gaussian components of the mixture.

The options added are

  • covariance_type, which allows to change the type of the covariance matrices
  • tied a switch which allows to control if covariance matrices are tied between components
  • cov_rank the rank of the low-rank covariance matrix when covariance_type is 'lowrank'

Since the construction of the shapes got quite long I moved this part in its own function.

Below is an illustration of the effect these different choices have for a mixture of 3 two-dimensional Gaussians.

image

)


def _determine_shapes(components, features, covariance_type, tied, cov_rank):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The following pattern would allow to reduce code duplication

leading = 1 if tied else components

if covariance_type == 'full':
    shapes.extend([
        (leading, features),
        (leading, features * (features - 1) // 2),
    ])
elif ...

@francois-rozet
Copy link
Member

Hello @dominik-strutz, I quickly went over the code, and it looks nice! I think we should add some tests however, maybe in a new tests/test_flows_gmm.py file.

Are you also planning to improve the initialization as well?

@dominik-strutz
Copy link
Author

Hi @francois-rozet, Yes, I am happy to write some tests.

I'm also happy to try to improve the initialization. Following sklearn again, the initialization methods 'random' and 'random_from_data' (the latter one working quite well in my limited experience) should be easy enough to implement. Also, implementing 'k-mean' or 'k-means++' should be manageable. I don't know how to handle the conditional case yet, but having only the unconditional one would be a good first start.

What is your opinion on how to structure the initialization? I think it would be beneficial to keep the __init__ method quite simple and have a separate method (e.g., GMM.initialization) that takes user-provided data samples and sets the phi variable or the last layer of the network to reasonable values given an initialization method.

I will give the initialisation a try and let you know how it goes.

P.S: I have no idea why the pre-commit hook fails. I used ruff --fix locally to reformat, and from what I can see, it fulfils the requirements.

@francois-rozet
Copy link
Member

francois-rozet commented Apr 4, 2024

I don't know how to handle the conditional case yet, but having only the unconditional one would be a good first start.

I think a good way to handle the conditional case would be to make the weight $W$ of the last layer small (e.g. standard initialization * 1e-2) and set the bias $b$ to the unconditional initialization.

What is your opinion on how to structure the initialization?

I agree that a separate method could be appropriate, similar to the reset_parameters of nn.Linear.

P.S: I have no idea why the pre-commit hook fails. I used ruff --fix locally to reformat, and from what I can see, it fulfils the requirements.

I pulled your branch and ruff check . at the root of the repo returns

zuko/flows/mixture.py:7:1: I001 [*] Import block is un-sorted or un-formatted
Found 1 error.
[*] 1 fixable with the `--fix` option.

Maybe you were not at the root? My version of ruff is 0.1.14 by the way.

@francois-rozet
Copy link
Member

@dominik-strutz Do you still plan on contributing this PR?

@dominik-strutz
Copy link
Author

Yes, I still like to contribute but haven't found much free time to do it recently. I have implemented most of the initialization methods for the unconditional case, but it still needs to be polished up and tested. The extension for the conditional case shouldn't take too long afterwards. If you or someone else wants to continue this sooner, I'm happy to push an intermediary commit of everything I have so far.

@francois-rozet
Copy link
Member

No problem, take your time! I am currently updating a few things and wanted to know if I should wait for this PR for the next minor release.

@francois-rozet francois-rozet marked this pull request as draft July 22, 2024 17:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants