I would like to create a custom Keras layer (a codebook for a VQ-VAE model). During training I want a tf.Variable that tracks how often each code is used, so I can restart unused codes. I created my Codebook layer as follows:
class Codebook(layers.Layer):
    def __init__(self, num_codes, code_reset_limit=None, **kwargs):
        super().__init__(**kwargs)
        self.num_codes = num_codes
        self.code_reset_limit = code_reset_limit
        if self.code_reset_limit:
            self.code_counter = tf.Variable(tf.zeros(num_codes, dtype=tf.int32),
                                            trainable=False)

    def build(self, input_shape):
        self.codes = self.add_weight(name='codes',
                                     shape=(self.num_codes, input_shape[-1]),
                                     initializer='random_uniform',
                                     trainable=True)
        super().build(input_shape)
The issue is that the Layer class finds the member variable self.code_counter and adds it to the list of weights that are saved with the layer. It also expects self.code_counter to be present when weights are loaded, which is not the case when I run in inference mode. How can I make Keras not track a variable in my layer? I do not want it persisted or included in layer.weights.
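For reference, here is a sketch of the kind of workaround I have in mind (not necessarily the idiomatic way): hide the counter inside a plain Python object, since Keras only auto-tracks attributes that are themselves trackable (variables, layers, and containers of them). The _Untracked helper and the property name are my own invention:

```python
import tensorflow as tf
from tensorflow.keras import layers

class _Untracked:
    """Plain container; not a Trackable, so Keras's attribute
    auto-tracking ignores whatever it holds."""
    def __init__(self, value):
        self.value = value

class Codebook(layers.Layer):
    def __init__(self, num_codes, code_reset_limit=None, **kwargs):
        super().__init__(**kwargs)
        self.num_codes = num_codes
        self.code_reset_limit = code_reset_limit
        if self.code_reset_limit:
            # Hidden behind _Untracked, so the variable never shows up
            # in layer.weights and is not saved with the layer.
            self._code_counter = _Untracked(
                tf.Variable(tf.zeros(num_codes, dtype=tf.int32),
                            trainable=False))

    @property
    def code_counter(self):
        # Convenience accessor; properties are not tracked either.
        return self._code_counter.value

    def build(self, input_shape):
        self.codes = self.add_weight(name='codes',
                                     shape=(self.num_codes, input_shape[-1]),
                                     initializer='random_uniform',
                                     trainable=True)
        super().build(input_shape)

layer = Codebook(8, code_reset_limit=4)
layer.build((None, 16))
print(len(layer.weights))  # only 'codes'; the counter is not tracked
```

Is wrapping the variable like this safe, or is there a supported way to opt an attribute out of tracking?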