Misalignment in n_classes Parameter Handling Between BatchGenerator and transform_batch
There appears to be a discrepancy in how the `n_classes` parameter is handled between `BatchGenerator` and the `transform_batch` function. Specifically, `transform_batch` defines an optional parameter `n_classes` with a default value of 2, but when `BatchGenerator` invokes `transform_batch`, it does not pass `n_classes` explicitly. Consequently, `transform_batch` always operates under the assumption that `n_classes=2`, regardless of the actual number of classes intended by the user or required by the dataset being processed. This leads to an error when training with more than 2 classes, even if the user has specified more than 2 classes in the network recipe.
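For concreteness, here is a minimal sketch of the mismatch. The class and parameter names (`NNInterface`, `output_transform_func`) and the one-hot encoding are simplified stand-ins for illustration, not our actual signatures:

```python
import numpy as np

class NNInterface:
    @classmethod
    def transform_batch(cls, x, y, n_classes=2):
        # labels are one-hot encoded against n_classes categories; with the
        # hard-coded default, any label >= 2 overflows the encoding
        return x, np.eye(n_classes)[y]

class BatchGenerator:
    def __init__(self, x, y, batch_size, output_transform_func=None):
        self.x, self.y = x, y
        self.batch_size = batch_size
        self.transform = output_transform_func  # n_classes is never stored here

    def __next__(self):
        xb, yb = self.x[:self.batch_size], self.y[:self.batch_size]
        # n_classes is not forwarded, so transform_batch always runs with n_classes=2
        return self.transform(xb, yb)

gen = BatchGenerator(np.zeros((4, 8)), np.array([0, 1, 2, 1]), batch_size=4,
                     output_transform_func=NNInterface.transform_batch)
# next(gen)  # IndexError: np.eye(2) has no row for label 2
```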
I believe there are a couple of possible solutions to this, but I would like your opinions, @oliskir and @frazao.
Solution 1: Add another parameter to `BatchGenerator` called `n_classes` and pass it along to `transform_batch`. This would be the easiest solution but likely not the best: the user has already defined the number of classes in the recipe, so we would effectively be asking them to specify it twice.
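A rough sketch of how that wiring could look (hypothetical names again):

```python
class BatchGenerator:
    def __init__(self, x, y, batch_size, output_transform_func=None, n_classes=2):
        self.x, self.y = x, y
        self.batch_size = batch_size
        self.transform = output_transform_func
        self.n_classes = n_classes  # duplicates what the recipe already specifies

    def __next__(self):
        xb, yb = self.x[:self.batch_size], self.y[:self.batch_size]
        # forward n_classes explicitly instead of relying on the default of 2
        return self.transform(xb, yb, n_classes=self.n_classes)
```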
Solution 2: The `n_classes` argument is currently used in the `__init__` of each of our defined NN architectures to build the network, particularly for the Dense layer at the end. We could consider having the default transform functions read this attribute directly as the default instead of `n_classes=2`. The problem with this approach is that `transform_batch` is currently a classmethod, and this doesn't really work unless we change it to an instance method, which I don't think works for us.
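To make the classmethod limitation concrete, here is a sketch (hypothetical, simplified) of the instance-method variant and why the current classmethod cannot do the same:

```python
import numpy as np

class NNInterface:
    def __init__(self, n_classes):
        self.n_classes = n_classes  # set once, when the recipe builds the network

    # As an instance method, transform_batch could read n_classes from the
    # network instance, so the user would never have to specify it twice:
    def transform_batch(self, x, y):
        return x, np.eye(self.n_classes)[y]

    # As a classmethod there is no self, only cls, so there is no instance
    # attribute to fall back on; the hard-coded default is all it knows:
    #
    # @classmethod
    # def transform_batch(cls, x, y, n_classes=2):
    #     return x, np.eye(n_classes)[y]
```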
Currently, `n_classes` is defined in the constructor of each specific NN architecture, for instance:
```python
class ResNetArch(NNArch):
    def __init__(self, n_classes, pre_trained_base=None, block_sets=None, initial_filters=16,
                 initial_strides=1, initial_kernel=(3,3), strides=2, kernel=(3,3),
                 batch_norm_momentum=0.99, dropout_rate=0, **kwargs):
        super(ResNetArch, self).__init__(**kwargs)
        self.n_classes = n_classes  # used when building the final Dense layer
```
What do you think?