Conversation
J-Gann commented Jul 24, 2023
- Added documentation for pruning of pretrained models
- Fixed bugs regarding pruning of non-ResNet models:
  - problems in the traversal of network layers while collecting layer information
  - problems where the last network layer was pruned
- Added a new method for loading pretrained models
- Fixed a permission issue regarding wandb on the cluster
- Added reference models
```python
if len(key_elements) > 1:
    parents = {info.var_name: info for info in layer_list if not info.is_leaf_layer}
if parent_info.parent_info:
```
Here, some models threw an exception: during the traversal of the model layers, when listing the children of a parent, children of children were not excluded. This led to an incorrect traversal of the layer tree. I excluded all children of children by adding the condition `and info.parent_info.var_name == parent_info.var_name`.
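A minimal sketch of the fix, with `layer_list`, `parent_info`, and `var_name` taken from the diff above; the surrounding traversal logic is assumed:

```python
# Keep only the direct children of parent_info: without the var_name check,
# children of children (whose parent differs from parent_info) slip through
# and break the traversal order of the layer tree.
children = [
    info for info in layer_list
    if info.parent_info and info.parent_info.var_name == parent_info.var_name
]
```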
```python
return False
return True
if isinstance(layer_info.module, torch.nn.Conv2d):
    if layer_info.output_size == []:
```
This guards against an unhandled state of `output_size` I ran into: for some `Conv2d` layers, `output_size` is still the empty list.
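A sketch of how the guard could fit into the surrounding check; only the two `if` lines come from the diff, the function name and the fallback branches are assumptions:

```python
import torch

def is_prunable(layer_info) -> bool:
    """Decide whether a layer can be pruned; skip Conv2d layers whose
    output_size was never populated (the empty-list state hit above)."""
    if isinstance(layer_info.module, torch.nn.Conv2d):
        if layer_info.output_size == []:
            return False  # output size unknown, pruning would fail later
        return True
    return False
```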
```python
self._frozen_layers = frozen_layers
self._layer_dict = {layer_key: module for layer_key, module in self._reference_model.named_modules() if
                    not [*module.children()]}
last_layer = list(self._layer_dict.keys())[-1]
```
I discovered that galen assumes the last layer of the model is named "fc", as stated e.g. here. This leads to unexpected and difficult-to-resolve errors during pruning. I propose always adding the last layer of the network to the list of frozen layers. Alternatively, the assumption should be stated in the documentation.
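A minimal sketch of the proposal, reusing `self._layer_dict` and `self._frozen_layers` from the diff above and assuming `frozen_layers` is a list; the freezing mechanism itself is unchanged:

```python
# Freeze the final leaf layer by key instead of relying on it being named
# "fc"; the last entry of the leaf-only layer dict is the network's output
# layer, which must never be pruned.
last_layer = list(self._layer_dict.keys())[-1]
if last_layer not in self._frozen_layers:
    self._frozen_layers.append(last_layer)
```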
```diff
- wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=vars(args),
+ wandb.init(project=args.wandb_project, config=vars(args),
```
`entity=args.wandb_entity` leads to permission problems on the cluster when not running as superuser, because wandb tries to access the /tmp folder. Dropping this argument resolves the problem.
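If the entity should remain usable outside the cluster, one hedged alternative is to pass it only when explicitly set; the argument names come from the diff, the conditional itself is an assumption:

```python
import wandb

# Only forward the entity when it was explicitly provided, so cluster runs
# without superuser rights fall back to wandb's default entity.
init_kwargs = dict(project=args.wandb_project, config=vars(args))
if getattr(args, "wandb_entity", None):
    init_kwargs["entity"] = args.wandb_entity
wandb.init(**init_kwargs)
```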
```python
# model on torch hub
name, repo = select_str.split("@")
model = torch.hub.load(repo, name, pretrained=True, num_classes=num_classes)
elif "/" in select_str:
```
This is a proposal for an additional method of loading pretrained models that were saved using `torch.save(model, PATH)`.
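A self-contained sketch of the branch structure: the torch-hub branch comes from the diff above, while the function name, the local-path branch, and the error handling are assumptions.

```python
import torch

def load_pretrained(select_str: str, num_classes: int):
    if "@" in select_str:
        # model on torch hub, selected as "<name>@<repo>"
        name, repo = select_str.split("@")
        return torch.hub.load(repo, name, pretrained=True, num_classes=num_classes)
    elif "/" in select_str:
        # model saved locally with torch.save(model, PATH); torch.load
        # restores the full module object from disk
        return torch.load(select_str)
    raise ValueError(f"Cannot interpret model selector: {select_str}")
```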
```diff
   - pandas
   - pip:
-    - torch-pruning
+    - torch-pruning==0.2.8
```
The torch-pruning API has changed since version 0.2.8; newer versions would require refactoring the galen code, so the dependency is pinned here.