Skip to content

Commit

Permalink
update config
Browse files Browse the repository at this point in the history
  • Loading branch information
oaksharks committed Nov 21, 2024
1 parent b7ca9c6 commit 333b4d3
Showing 1 changed file with 25 additions and 8 deletions.
33 changes: 25 additions & 8 deletions deeptables/models/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,9 +923,17 @@ def get_config(self):


class VarLenColumnEmbedding(Layer):
def __init__(self, emb_vocab_size, emb_output_dim, dropout_rate=0. , **kwargs):
def __init__(self, emb_vocab_size, emb_output_dim,
embeddings_initializer,
embeddings_regularizer,
activity_regularizer,
dropout_rate=0.,
**kwargs):
self.emb_vocab_size = emb_vocab_size
self.emb_output_dim = emb_output_dim
self.embeddings_initializer = embeddings_initializer
self.embeddings_regularizer = embeddings_regularizer
self.activity_regularizer = activity_regularizer
self.dropout_rate = dropout_rate
super(VarLenColumnEmbedding, self).__init__(**kwargs)
self.dropout = None
Expand All @@ -937,28 +945,37 @@ def compute_output_shape(self, input_shape):

def build(self, input_shape=None):
super(VarLenColumnEmbedding, self).build(input_shape)
self.emb_layer = Embedding(input_dim=self.emb_vocab_size, output_dim=self.emb_output_dim)
self.emb_layer = Embedding(input_dim=self.emb_vocab_size,
output_dim=self.emb_output_dim,
embeddings_initializer=self.embeddings_initializer,
embeddings_regularizer=self.embeddings_regularizer,
activity_regularizer=self.activity_regularizer)
if self.dropout_rate > 0:
self.dropout = SpatialDropout1D(self.dropout_rate, name='var_len_emb_dropout')
else:
self.dropout = None
self.built = True

def call(self, inputs):
embedding_output = self.emb_layer.call(inputs)
embedding_output = embedding_output.reshape((embedding_output[0], 1, -1))
embedding_output = self.emb_layer(inputs)
embedding_output_reshape = tf.reshape(embedding_output, [embedding_output.shape[0], 1, -1])
if self.dropout is not None:
dropout_output = self.dropout(embedding_output)
dropout_output = self.dropout(embedding_output_reshape)
else:
dropout_output = embedding_output
dropout_output = embedding_output_reshape
return dropout_output

    def compute_mask(self, inputs, mask=None):
        # Stop mask propagation: the reshape performed in call() collapses the
        # time dimension, so any per-timestep mask from the Embedding
        # sub-layer would no longer align with the output.
        return None

def get_config(self, ):
config = { 'dropout_rate': self.dropout_rate,
'emb_layer': self.emb_layer.get_config()}
config = { 'dropout_rate': self.dropout_rate,
'emb_layer': self.emb_layer.get_config(),
'embeddings_initializer': self.embeddings_initializer,
'embeddings_regularizer': self.embeddings_regularizer,
'emb_vocab_size': self.emb_vocab_size,
'emb_output_dim': self.emb_output_dim
}
base_config = super(VarLenColumnEmbedding, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

Expand Down

0 comments on commit 333b4d3

Please sign in to comment.