run_on_batch and run_on_instance
The run_on_batch and run_on_instance functions of the NN_interface class are very awkward to use. This is because they internally use the transform functions defined on the nn_architecture class. However, this is inconsistent with the way transform functions were intended to be used, i.e., as functions defined by the user to process the input and output of the NN.
To train a NN we use the BatchGenerator to process the input. In the BatchGenerator we have to explicitly declare which transform function we are using. Often this is the default transform function declared in the NN class (for instance, ResNetInterface.transform_batch), but we could also have used a different one that we defined ourselves. In the latter case, when we call run_on_batch or run_on_instance, it will try to use the internal transform function defined on the class instead of the custom one we trained the model with. One way to bypass this issue is to overwrite the transform_batch function of the interface before running the model (see the sketch below). While this works, the concern is that we are not consistent in how we want the user to use the transform function: to train the model it must be passed explicitly, but to run the model it is used implicitly. We should stick with just one approach. In the end it is just confusing, and more often than not I prefer to run the model manually by applying the transform function myself and calling the relevant tensorflow functions.
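A minimal sketch of the current inconsistency and the workaround (the BatchGenerator arguments and the names my_transform_batch, train_table, input_batch are illustrative, not the exact API):

# Training: the transform function is passed explicitly to the generator.
train_gen = BatchGenerator(batch_size=32,
                           data_table=train_table,
                           output_transform_func=my_transform_batch)
model.train_loop(n_epochs=10, train_generator=train_gen)

# Inference: run_on_batch silently falls back to the transform function
# defined on the class, so we have to patch the interface by hand to make
# it use the same function the model was trained with.
model.transform_batch = my_transform_batch
scores = model.run_on_batch(input_batch)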
My suggestion to make things clearer is to change the arguments and description of the run_on_batch and run_on_instance functions:
Instead of having two booleans for the input and output...
def run_on_batch(self, input_batch, return_raw_output=False, transform_input=True):
# Code here
Replace them with the functions themselves, that is, the user has to pass which functions they want to use. If none is passed, no transform is applied to the input and the raw output is returned (equivalent to setting transform_input to False and return_raw_output to True in the current signature):
def run_on_batch(self, input_batch, input_transform_function=None, output_transform_function=None):
# Code here
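A minimal sketch of how the proposed signature could behave; the forward pass shown here is illustrative and would go through whatever tensorflow code path the interface already uses:

def run_on_batch(self, input_batch, input_transform_function=None,
                 output_transform_function=None):
    # Only transform the input if the user explicitly asked for it.
    if input_transform_function is not None:
        input_batch = input_transform_function(input_batch)

    # Forward pass; stands in for the interface's actual tensorflow call.
    raw_output = self.model(input_batch, training=False)

    # Only transform the output if the user explicitly asked for it.
    if output_transform_function is not None:
        return output_transform_function(raw_output)
    return raw_output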
For instance, a call to run_on_batch could be:
run_on_batch(input, input_transform_function=ResNetInterface.transform_batch, output_transform_function=ResNetInterface._transform_output)
or if the user wants to use their own transform functions...
run_on_batch(input, input_transform_function=myOwnInputTransformFunc, output_transform_function=myOwnOutputTransformFunc)
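For completeness, a hypothetical pair of user-defined transform functions; the preprocessing shown is purely illustrative, and anything matching the model's expected input and output would work:

import numpy as np

def myOwnInputTransformFunc(batch):
    # e.g. normalize the batch and add the channel axis the NN expects
    batch = (batch - batch.mean()) / (batch.std() + 1e-8)
    return batch[..., np.newaxis]

def myOwnOutputTransformFunc(raw_output):
    # e.g. collapse the softmax scores into predicted class labels
    return np.argmax(raw_output, axis=-1)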