bugfix for fit_laplace absent dims #609
Conversation
…ms on data containers and deterministics
ricardoV94
left a comment
why are we forcing dims on variables that didn't have them originally?
Hey @ricardoV94, I am not sure I understand what you mean here. I compared the returned …
pm.sample doesn't add dims to model variables. Those show up in the conversion to InferenceData by arviz.
…variables that did not have them originally
Thank you, @ricardoV94. I was not aware of that. I made a change so that we don't force any dims on variables that were not assigned any in the original model context.

Hey @ricardoV94, anything I can add/modify here to move it forward?
```python
laplace_model.add_coords(
    {name: np.arange(shape) for name, shape in zip(dims[2:], dim_shapes)}
)
if dim_shapes[0] is not None:
```
I don't follow what's happening here. If `dim_shapes` is coming from `rv.type.shape`, you could have `dim_shapes[0]` not be None and a follow-up entry be None. Example: `x = pm.Data("x", np.zeros((5, 3)), shape=(5, None))` is valid.

The whole approach seems a bit backwards. Instead of trying to plug the sample dims into a model that may not have properly defined dims to begin with, why not sample first and then attach those in the final inference data object, which always has dims, explicit or automatic?

If you just want to patch for now, be more exhaustive and check `not any(d is None for d in dim_shapes)`.
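The difference between the two checks can be sketched in plain Python. This is an illustrative stand-alone snippet, not the PR's actual code; the function names `first_dim_known`/`all_dims_known` are hypothetical, and the tuples stand in for what `rv.type.shape` would report:

```python
# Why checking only dim_shapes[0] is not enough: a static shape like
# (5, None) -- e.g. from pm.Data("x", np.zeros((5, 3)), shape=(5, None)) --
# passes a first-element check even though the shape is not fully known.

def first_dim_known(dim_shapes):
    """The original, incomplete check: only inspects the leading dim."""
    return dim_shapes[0] is not None

def all_dims_known(dim_shapes):
    """The exhaustive check suggested above: every dim must be static."""
    return not any(d is None for d in dim_shapes)

print(first_dim_known((5, None)))  # True  -- would wrongly proceed
print(all_dims_known((5, None)))   # False -- correctly treats shape as unknown
print(all_dims_known((5, 3)))      # True  -- fully static shape
```

Only the exhaustive check catches partially unknown shapes, which is what makes the single-element test unsafe here.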
@ricardoV94 Thank you for taking the time to review my changes. I went back and looked at the logic, and I agree with you that it would be more appropriate to sample and allow the inference data object to create dims automatically. However, due to the addition of the (temp_chain, temp_draw) dims in `model_to_laplace_approx`, the automatic dims are incremented by 2. For example, the expected `mu_dim_0` would become `mu_dim_2`. I wrote a helper function to rename these automatically generated dims/coords post creation. Please let me know if that looks okay to you.
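A hypothetical version of such a rename helper is sketched below. The name `shift_default_dim_names` and the `offset=2` default are illustrative (the offset accounts for the two temp dims); the actual helper in the PR may differ:

```python
import re

def shift_default_dim_names(dim_names, offset=2):
    """Map auto-generated names 'var_dim_<k>' to 'var_dim_<k - offset>'.

    Because two extra leading dims (temp_chain, temp_draw) were present
    during conversion, arviz names e.g. mu's first real dim "mu_dim_2"
    instead of "mu_dim_0". Names that don't match the default pattern,
    or whose index is below the offset, are left alone.
    """
    pattern = re.compile(r"^(.*_dim_)(\d+)$")
    renames = {}
    for name in dim_names:
        m = pattern.match(name)
        if m and int(m.group(2)) >= offset:
            renames[name] = f"{m.group(1)}{int(m.group(2)) - offset}"
    return renames

print(shift_default_dim_names(["chain", "draw", "mu_dim_2", "mu_dim_3"]))
# {'mu_dim_2': 'mu_dim_0', 'mu_dim_3': 'mu_dim_1'}
```

The resulting mapping could then be passed to something like xarray's `rename` on the posterior dataset to restore the expected default names.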
Makes sense, yes. I think the ultimate issue here was trying to put sampling dims in the model instead of working on the random function directly, but that's a larger question that doesn't need to be had in this PR.
Yes, you are correct. It is odd that we are assigning the sampling dims (temp_chain, temp_draw) in the model. I need to dig a little deeper and see how to change this architectural design to allow the random function to handle the sampling dims instead of the model. There are some comments in the implementation mentioning that (temp_chain, temp_draw) are supposedly batch dimensions. Maybe the correct approach would be to treat those separately from (chain, draw) and let the random function just name them using defaults. For example (using the above example), we would then get (mu_dim_0, mu_dim_1, mu_dim_2) with shapes (2, 500, 100). I am not sure, though.
The point though is that the chain/draw dimensions from regular sampling don't exist in the graph either. Here we were trying to create a vectorized posterior/predictive/compute-deterministics sort of function, which is fine, but maybe it shouldn't 1) be restricted to this method specifically, or 2) be hacked into the model. We basically need to get the underlying function that's used by this routine after creating the batched model, call it once, and then handle the conversion to InferenceData ourselves, which is when you tell it that there are 2 batch dimensions with names chain/draw for every variable.

This should be thought about separately from this PR.
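The suggested flow can be sketched as follows. Everything here is a hypothetical stand-in (the names `to_dims_mapping` and `batched_draws` are not the actual pymc-extras API): call the batched function once, get raw arrays whose two leading axes are the batch dimensions, and build the dims mapping yourself before handing it to the InferenceData conversion:

```python
import numpy as np

def to_dims_mapping(draws):
    """Name the two leading batch axes chain/draw for every variable,
    and give the remaining axes arviz-style default names.

    `draws` maps variable names to arrays of shape (chain, draw, *var_shape).
    """
    dims = {}
    for var, arr in draws.items():
        dims[var] = ["chain", "draw"] + [
            f"{var}_dim_{i}" for i in range(arr.ndim - 2)
        ]
    return dims

# e.g. 2 chains x 500 draws of a length-100 vector `mu`
batched_draws = {"mu": np.zeros((2, 500, 100))}
print(to_dims_mapping(batched_draws))
# {'mu': ['chain', 'draw', 'mu_dim_0']}
```

A mapping like this could then be passed as the `dims` argument when building the InferenceData, so the batch axes are named chain/draw without the model ever knowing about them.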
…s the dimension names after
```python
assert "laplace_approximation" not in list(idata.posterior.data_vars.keys())
assert "unpacked_var_names" not in list(idata.posterior.coords.keys())

assert idata["posterior"].beta.shape[-2:] == (3, 2)
```
shouldn't we test that we get the expected sample dims as well?
Yes, this is definitely a good idea. I added assert statements to test for the sampling dims, and I also consolidated the test with `test_fit_laplace_ragged_coords` because there was a lot of duplicated code.
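A minimal sketch of what such assertions check, using a stand-in object instead of a real arviz InferenceData (the variable name `beta` and the shapes follow the quoted test above; the real test operates on `idata.posterior`):

```python
from types import SimpleNamespace

# Stand-in for idata.posterior.beta: an xarray DataArray exposes .dims
# (a tuple of dimension names) and .shape. The values mirror 2 chains,
# 500 draws, and a (3, 2)-shaped beta.
beta = SimpleNamespace(
    dims=("chain", "draw", "beta_dim_0", "beta_dim_1"),
    shape=(2, 500, 3, 2),
)

assert beta.dims[:2] == ("chain", "draw")  # sampling dims present, in order
assert beta.shape[-2:] == (3, 2)           # variable dims match the model
```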
Failing tests seem unrelated; we should open an issue for them / handle their failure. I'll merge this because it was marinating for so long already.

Thanks @Dekermanjian and sorry for the delay.

Thank you @ricardoV94. I will check if those failing tests were just because my branch wasn't up-to-date with main. If I find those failures on the latest main commit I will open an issue for them.
Updated the `dim_shape` assignment logic in `fit_laplace` to handle absent dims on data containers and deterministics. I also added a test for that specific scenario.
closes #581