In dm-haiku, the parameters of a neural network are stored in a dictionary whose keys are module (and submodule) names. If you want to traverse the values, there are multiple ways of doing so, as shown in this dm-haiku issue. However, traversal doesn't follow the order in which the modules are applied (the keys are visited in sorted order), which makes it hard to pick out submodules. For example, if I have 2 linear layers, each followed by an MLP layer, then using hk.data_structures.traverse(params) will (roughly) return:
['linear', 'linear_2', 'mlp/~/1', 'mlp/~/2'],
whereas I would like it to return:
['linear', 'mlp/~/1', 'linear_2', 'mlp/~/2'].
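The first ordering is simply what sorting the keys produces: hk.data_structures.traverse iterates over modules in sorted key order, and 'linear_2' sorts before 'mlp/~/1'. Below is a minimal sketch of an order-aware alternative. It assumes you can supply the module names in the order they are called. The `call_order` list and the `traverse_in_order` helper are hypothetical, not part of Haiku. The params are modelled as Haiku's two-level `{module_name: {param_name: value}}` mapping, with strings standing in for arrays:

```python
from typing import Any, Iterator, Mapping, Sequence, Tuple

# Haiku-style two-level params mapping; strings stand in for arrays here.
params = {
    "linear": {"w": "W0", "b": "b0"},
    "linear_2": {"w": "W2", "b": "b2"},
    "mlp/~/1": {"w": "M1"},
    "mlp/~/2": {"w": "M2"},
}


def traverse_sorted(
    structure: Mapping[str, Mapping[str, Any]]
) -> Iterator[Tuple[str, str, Any]]:
    """Mimics hk.data_structures.traverse: modules and params in sorted key order."""
    for module_name in sorted(structure):
        for name in sorted(structure[module_name]):
            yield module_name, name, structure[module_name][name]


def traverse_in_order(
    structure: Mapping[str, Mapping[str, Any]], order: Sequence[str]
) -> Iterator[Tuple[str, str, Any]]:
    """Hypothetical helper: same (module, name, value) triples, modules follow `order`."""
    missing = set(structure) - set(order)
    if missing:
        raise KeyError(f"modules missing from order: {sorted(missing)}")
    for module_name in order:
        bundle = structure[module_name]
        for name in sorted(bundle):
            yield module_name, name, bundle[name]


# Assumed to be written by hand (or recorded while building the network).
call_order = ["linear", "mlp/~/1", "linear_2", "mlp/~/2"]

# dict.fromkeys de-duplicates module names while keeping first-seen order.
print(list(dict.fromkeys(m for m, _, _ in traverse_sorted(params))))
# ['linear', 'linear_2', 'mlp/~/1', 'mlp/~/2']
print(list(dict.fromkeys(m for m, _, _ in traverse_in_order(params, call_order))))
# ['linear', 'mlp/~/1', 'linear_2', 'mlp/~/2']
```

The helper yields the same triples as traverse, so code written against traverse can switch to it without other changes.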
I want this ordering for several reasons. I may be building an invertible neural network, where the parameters must be applied in reverse order. I may want to isolate constituent parts for other purposes (e.g. transfer learning). More generally, I want more control over how and where trained parameters are (re)used.
To deal with this, I've resorted to matching the names with a regex and putting them in the order I want. I then use hk.data_structures.filter(predicate, params) to select the parameters by the sorted module names. However, this is quite tedious if I have to write a new regex every time I want to do this.
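One way to avoid rewriting a regex each time is to build the predicate once from a list of module names. hk.data_structures.filter calls its predicate as `predicate(module_name, name, value)`, so the factory below returns a function with that shape. This is a minimal sketch. `make_module_predicate` is a hypothetical helper. `filter_params` is a small pure-Python stand-in for hk.data_structures.filter, so the example runs without Haiku installed:

```python
from typing import Any, Callable, Dict, Iterable, Mapping

Predicate = Callable[[str, str, Any], bool]


def make_module_predicate(prefixes: Iterable[str]) -> Predicate:
    """Hypothetical helper: match the given modules and all of their submodules.

    Haiku joins parent and submodule names with "/~/", so "mlp" matches
    "mlp" and "mlp/~/1" but not "mlp_1".
    """
    prefix_tuple = tuple(prefixes)

    def predicate(module_name: str, name: str, value: Any) -> bool:
        del name, value  # selection depends only on the module name
        return any(
            module_name == p or module_name.startswith(p + "/~/") for p in prefix_tuple
        )

    return predicate


def filter_params(
    predicate: Predicate, structure: Mapping[str, Mapping[str, Any]]
) -> Dict[str, Dict[str, Any]]:
    """Stand-in for hk.data_structures.filter: keep entries where predicate is True."""
    out: Dict[str, Dict[str, Any]] = {}
    for module_name, bundle in structure.items():
        kept = {n: v for n, v in bundle.items() if predicate(module_name, n, v)}
        if kept:
            out[module_name] = kept
    return out


params = {
    "linear": {"w": "W0"},
    "linear_2": {"w": "W2"},
    "mlp/~/1": {"w": "M1"},
    "mlp/~/2": {"w": "M2"},
}

mlp_only = filter_params(make_module_predicate(["mlp"]), params)
# {'mlp/~/1': {'w': 'M1'}, 'mlp/~/2': {'w': 'M2'}}
```

With Haiku installed, the same predicate should be usable directly, e.g. `hk.data_structures.filter(make_module_predicate(["mlp"]), params)`. Matching on whole module names and the "/~/" separator avoids the accidental hits a loose regex can produce, such as "linear" also matching "linear_2".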
Is there a way to convert a dm-haiku params dictionary into something like a pytree with a hierarchy and ordering that makes this easier? I believe Equinox handles parameters this way (and I plan to look into how it does so soon), but I wanted to check whether I'm overlooking a simple method for grouping, reversing, or otherwise permuting the params dictionary.
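Below is a minimal sketch of one such conversion. It assumes the block layout is supplied by hand, since the call order can't be recovered from the dict itself. The params are stored as a list of blocks, each a list of `(module_name, bundle)` pairs. When JAX flattens a pytree, dict keys are visited in sorted order but lists and tuples keep their order. Grouping, reversing, and slicing therefore become ordinary list operations. Flattening back to a dict gives a mapping you can pass to `apply`. The helpers `to_blocks` and `from_blocks` are hypothetical, and the params use strings in place of arrays:

```python
from typing import Any, Dict, List, Mapping, Sequence, Tuple

Bundle = Mapping[str, Any]
Block = List[Tuple[str, Bundle]]


def to_blocks(
    params: Mapping[str, Bundle], blocks: Sequence[Sequence[str]]
) -> List[Block]:
    """Hypothetical helper: regroup Haiku-style params into ordered blocks.

    Each block is a list of (module_name, bundle) pairs, so the ordering
    survives pytree flattening, unlike the keys of a dict.
    """
    used = [name for block in blocks for name in block]
    if sorted(used) != sorted(params):
        raise ValueError("blocks must mention every module exactly once")
    return [[(name, params[name]) for name in block] for block in blocks]


def from_blocks(blocks: Sequence[Block]) -> Dict[str, Bundle]:
    """Flatten blocks back into a {module_name: bundle} mapping."""
    return {name: bundle for block in blocks for name, bundle in block}


params = {
    "linear": {"w": "W0"},
    "linear_2": {"w": "W2"},
    "mlp/~/1": {"w": "M1"},
    "mlp/~/2": {"w": "M2"},
}

# Each linear layer grouped with the MLP that follows it.
blocks = to_blocks(params, [["linear", "mlp/~/1"], ["linear_2", "mlp/~/2"]])
order = [name for block in blocks for name, _ in block]
# ['linear', 'mlp/~/1', 'linear_2', 'mlp/~/2']

# Reverse both the block order and the order inside each block,
# e.g. to walk an invertible network backwards.
reversed_blocks = [block[::-1] for block in blocks[::-1]]
# module order: ['mlp/~/2', 'linear_2', 'mlp/~/1', 'linear']

# Isolate the first block, e.g. for transfer learning.
first_block_params = from_blocks(blocks[:1])
# {'linear': {'w': 'W0'}, 'mlp/~/1': {'w': 'M1'}}
```

Keeping the module names inside the pairs means any reordered or sliced structure can still be mapped back to the names Haiku's `apply` looks up.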