NAPSU-MQ inference configuration #55
@lumip The way I used to configure NAPSU-MQ inference is very verbose when changing just one parameter, but it makes the signature of

Example from the notebook:

```python
model = NapsuMQModel(required_marginals=required_marginals)
inference_config = NapsuMQInferenceConfig(
    mcmc_config=NapsuMQMCMCConfig(
        num_samples=1000
    )
)
result = model.fit(
    data=orig_df,
    rng=inference_rng,
    epsilon=1,
    delta=(n ** (-2)),
    inference_config=inference_config,
)
```
I think overall these are good changes. I've added some thoughts about possible improvements in comments directly at the relevant code.
```python
        query_sets: Optional[Iterable] = None,
        inference_config: NapsuMQInferenceConfig = NapsuMQInferenceConfig(),
        show_progress: bool = True,
        return_diagnostics: bool = False) -> 'NapsuMQResult':
```
We need to keep the signatures of `fit` compatible between this and `DPVIModel` so that the two can be used somewhat interchangeably in code. So I think for now it would be good to keep `**kwargs` to absorb and ignore any unknown arguments here. (And long term we should think about whether there's a way to make downstream code truly agnostic about the method it is passed for inference.)

From that perspective, maybe `NapsuMQModel` should keep the inference config as a lifetime variable, i.e., it's passed during initialization instead of here.

Do we expect use cases where we'd like to change the inference configuration often for otherwise the same model?
I don't think using `**kwargs` to just absorb unknown arguments is a good idea, as that can silently ignore errors, for example misspelled argument names.
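
The concern about misspelled names can be illustrated with a minimal sketch. Both functions below are hypothetical stand-ins, not the project's actual `fit` methods; only the argument name `show_progress` is taken from this discussion.

```python
def fit_strict(data, epsilon, show_progress=True):
    """Hypothetical fit with an explicit signature and no **kwargs."""
    return {"n": len(data), "epsilon": epsilon, "show_progress": show_progress}


def fit_lenient(data, epsilon, show_progress=True, **kwargs):
    """Hypothetical fit that absorbs and ignores unknown keyword arguments."""
    return {"n": len(data), "epsilon": epsilon, "show_progress": show_progress}


# The misspelled keyword lands in **kwargs and is dropped without any error,
# so show_progress keeps its default value.
lenient_result = fit_lenient([1, 2, 3], epsilon=1.0, show_progres=False)

# Without **kwargs, the same typo is rejected at call time.
try:
    fit_strict([1, 2, 3], epsilon=1.0, show_progres=False)
    strict_error = None
except TypeError as err:
    strict_error = str(err)
```

In the lenient variant the caller never learns that the progress setting was not applied, which is the kind of silent failure described above.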
The config has a default value, so users can switch from DPVI to NAPSU-MQ with the same `fit` call if they want to use the default config. If they are using a custom NAPSU-MQ config and want to switch to DPVI, they should be required to change the `fit` call, as the NAPSU-MQ config won't do anything with DPVI.

I changed the `show_progress` argument to `silent` for compatibility with DPVI.
I can get behind not using `**kwargs`, but then we need to sort out the differences in API between the two methods. With the changes here, NAPSU-MQ would have the following extraneous arguments that are not present for DPVI (or the base class `InferenceModel`):
- `query_sets`
- `inference_config`
- `return_diagnostics`

We could make `return_diagnostics` a common argument, which could be handy for DPVI as well. Alternatively, we could include the diagnostics as part of the result object.

We could do something similar for `inference_config`, but I prefer making it an argument of `NapsuMQModel.__init__`, i.e., making a `NapsuMQModel` instance completely encapsulate all details of its inference process (`DPVIModel` already mostly functions that way).
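
A minimal sketch of that encapsulation, assuming hypothetical class names, defaults, and a placeholder computation (none of this is the project's actual implementation): the config is fixed at construction, `fit` takes only method-independent arguments, and diagnostics travel with the result object.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence


@dataclass(frozen=True)
class SketchMCMCConfig:
    # num_samples appears in the discussion; the default value is an assumption.
    num_samples: int = 2000


@dataclass(frozen=True)
class SketchInferenceConfig:
    mcmc_config: SketchMCMCConfig = field(default_factory=SketchMCMCConfig)


@dataclass
class SketchResult:
    samples: List[float]
    # Diagnostics are carried by the result instead of a return_diagnostics flag.
    diagnostics: Dict[str, Any]


class SketchModel:
    """Hypothetical model that owns its inference configuration."""

    def __init__(self, inference_config: Optional[SketchInferenceConfig] = None):
        self.inference_config = inference_config or SketchInferenceConfig()

    def fit(self, data: Sequence[float], epsilon: float, delta: float,
            show_progress: bool = True) -> SketchResult:
        n = self.inference_config.mcmc_config.num_samples
        mean = sum(data) / len(data)
        samples = [mean] * n  # placeholder, not real MCMC sampling
        return SketchResult(samples=samples, diagnostics={"num_samples": n})


model = SketchModel(
    SketchInferenceConfig(mcmc_config=SketchMCMCConfig(num_samples=1000))
)
result = model.fit(data=[1.0, 2.0, 3.0], epsilon=1.0, delta=1e-6,
                   show_progress=False)
```

With this shape, switching inference methods only changes the line that constructs the model; the `fit` call itself stays the same.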

What is the intended difference between `query_sets` here and `required_marginals` passed in `__init__`? Could they be unified, or both lifted into `__init__`?

Finally, I have reverted your change for `show_progress` and adapted `DPVIModel` accordingly (also removed the `verbose` argument there). I think `show_progress` is the better, more descriptive argument name here.
The difference between `query_sets` and `required_marginals` is that the former chooses the queries explicitly, while the latter specifies queries that are always included when the other queries are selected automatically. I renamed them to make this clear and moved both to `__init__`.
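
The distinction can be sketched as follows. The function name, the parameter names, and the "automatic" selection rule (pairs of adjacent columns) are placeholders chosen for illustration; the real selection algorithm and the renamed arguments are not shown in this thread.

```python
from typing import Iterable, List, Optional, Sequence, Tuple

QuerySet = Tuple[str, ...]


def choose_query_sets(
    columns: Sequence[str],
    explicit_query_sets: Optional[Iterable[Iterable[str]]] = None,
    required_marginals: Iterable[Iterable[str]] = (),
) -> List[QuerySet]:
    """Hypothetical helper showing how the two arguments could interact."""
    # Explicit choice: use exactly what the caller passed, no automatic selection.
    if explicit_query_sets is not None:
        return [tuple(q) for q in explicit_query_sets]

    # Placeholder for automatic selection: marginals over adjacent column pairs.
    automatic = [tuple(pair) for pair in zip(columns, columns[1:])]

    # Required marginals are always included alongside the automatic ones.
    selected = [tuple(m) for m in required_marginals]
    for query in automatic:
        if query not in selected:
            selected.append(query)
    return selected


explicit = choose_query_sets(["a", "b", "c"], explicit_query_sets=[("a", "c")])
automatic = choose_query_sets(["a", "b", "c"], required_marginals=[("a", "c")])
```

Here `explicit` contains only the requested marginal, while `automatic` contains the required marginal plus whatever the placeholder rule selects.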
I also moved `inference_config` to `__init__`.
This now looks good to me, thanks. I have made some small additional changes to `InferenceModel` and `DPVIModel` so that all `fit` functions now share the same arguments.

I'll do a rebase and some cleanup and then merge it.
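
A minimal sketch of what a shared `fit` signature allows, assuming hypothetical class names and an assumed argument list (`data`, `rng`, `epsilon`, `delta`, `show_progress` all appear in this thread, but the exact final signature is not shown here):

```python
import inspect
from abc import ABC, abstractmethod


class InferenceModelSketch(ABC):
    """Hypothetical base class defining the common fit signature."""

    @abstractmethod
    def fit(self, data, rng, epsilon, delta, show_progress=True):
        ...


class DPVISketch(InferenceModelSketch):
    def fit(self, data, rng, epsilon, delta, show_progress=True):
        return "dpvi"  # placeholder result


class NapsuMQSketch(InferenceModelSketch):
    def __init__(self, inference_config=None):
        # Method-specific settings live on the instance, not in fit.
        self.inference_config = inference_config

    def fit(self, data, rng, epsilon, delta, show_progress=True):
        return "napsu-mq"  # placeholder result


def fit_parameters(cls):
    """Return the parameter names of cls.fit, in order."""
    return list(inspect.signature(cls.fit).parameters)


def run_inference(model, data):
    # Downstream code does not need to know which method it was given.
    return model.fit(data, rng=None, epsilon=1.0, delta=1e-6, show_progress=False)


results = [run_inference(m, [1, 2, 3]) for m in (DPVISketch(), NapsuMQSketch())]
```

A check like `fit_parameters` could also serve as a small regression test that keeps the subclasses from drifting apart again.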
…e to NapsuMQModel
8ba9ff7 to 90b51a0
Allow changing the inference parameters for NAPSU-MQ. Also add an option to return the `InferenceData` object from NAPSU-MQ, which allows inspecting MCMC diagnostics with arviz.