Skip to content

Refactor run_on_batch and run_on_instance for Custom Transform Functions

Bruno Padovese requested to merge nn_run_methods_refactor into development

As described in #12 This MR addresses inconsistencies in the usage of transform functions in the run_on_batch and run_on_instance methods of the NN_interface class when comapred to how they are used for training. The methods now accept custom transform functions as arguments, making the behavior more flexible and consistent with the rest of the system.

Changes

  • Refactored run_on_batch to accept input_transform_function and output_transform_function as optional parameters.
  • Refactored run_on_instance to also accept input_transform_function and output_transform_function as optional parameters.
  • Updated the documentation and tests to reflect these changes.

Issue Addressed

The previous implementation had inconsistencies in how transform functions were used. While training the neural network allowed for explicit declaration of transform functions, running the model used internal transform functions implicitly, leading to potential inconsistencies and confusion.

Thoughts: Thinking now, I can see the motivation for using the transofrm functions implicitly. However, I believe the code does need to be changed for that appraoch and the document greatly improved.

Besides the issue of consistency, I see two other problems with the way these function currently work:

  1. The way the run_on_batch and run_on_instance are currently declared, it is not obvious to the user which functions they need to overwrite to achieve the desired behavior. In fact, there are two_input_transform functions declared in the nn_interface class: _transform_input and transform_batch, and again, it is not obvious without looking at the code which one the run_on_batch and run_on_instance functions will use
def run_on_batch(self, input_batch, return_raw_output=False, transform_input=True):
  1. The _transform_input and _transform_output internal functions have the _ prefixe, which causes them to not show in the documentation built by sphinx. So the only way for the user to actually visualize these functions is to actually dig up the code.

This is not a mjor issue, and can be solved by expanding the documentation and maybe changing the code a little bit, but what I proposed I think also solves the problem. The changes are not major either way.

What do you think the best approach would be? @frazao

Edited by Bruno Padovese

Merge request reports