Skip to content

Commit

Permalink
added first version of PAiNN
Browse files Browse the repository at this point in the history
  • Loading branch information
aimat committed Jul 9, 2021
1 parent 6609df3 commit d5722d2
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 14 deletions.
1 change: 1 addition & 0 deletions kgcnn/layers/geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def call(self, inputs, **kwargs):
Returns:
tf.RaggedTensor: Scalar product of shape (batch, [N], ...)
"""
# axis = [i for i in range(len(inputs.shape))][self.axis]
if isinstance(inputs, tf.RaggedTensor) and inputs.ragged_rank == 1:
v = inputs.values
out = tf.reduce_sum(tf.square(v), axis=-1)
Expand Down
14 changes: 7 additions & 7 deletions kgcnn/layers/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ def call(self, inputs, **kwargs):
# Simply wrapper for self._kgcnn_wrapper_layer. Only works for simply element-wise operations.
if all([isinstance(x, tf.RaggedTensor) for x in inputs]):
# NOTE: row partitions could differ between inputs; the first input's row_splits
# are reused and checked only if self.ragged_validate is True.
if all([x.ragged_rank == 1 for x in inputs]) and not self.ragged_validate:
if all([x.ragged_rank == 1 for x in inputs]):
out = self._kgcnn_wrapper_layer([x.values for x in inputs], **kwargs) # will be all Tensor
out = tf.RaggedTensor.from_row_splits(out, inputs[0].row_splits, validate=False)
out = tf.RaggedTensor.from_row_splits(out, inputs[0].row_splits, validate=self.ragged_validate)
return out
else:
print("WARNING: Layer", self.name, "fail call on values for ragged_rank=1, attempting keras call... ")
Expand All @@ -107,9 +107,9 @@ def call(self, inputs, **kwargs):
# Simply wrapper for self._kgcnn_wrapper_layer. Only works for simply element-wise operations.
if all([isinstance(x, tf.RaggedTensor) for x in inputs]):
# NOTE: row partitions could differ between inputs; the first input's row_splits
# are reused and checked only if self.ragged_validate is True.
if all([x.ragged_rank == 1 for x in inputs]) and not self.ragged_validate:
if all([x.ragged_rank == 1 for x in inputs]):
out = self._kgcnn_wrapper_layer([x.values for x in inputs], **kwargs) # will be all Tensor
out = tf.RaggedTensor.from_row_splits(out, inputs[0].row_splits, validate=False)
out = tf.RaggedTensor.from_row_splits(out, inputs[0].row_splits, validate=self.ragged_validate)
return out
else:
print("WARNING: Layer", self.name, "fail call on values for ragged_rank=1, attempting keras call... ")
Expand All @@ -130,9 +130,9 @@ def call(self, inputs, **kwargs):
# Simply wrapper for self._kgcnn_wrapper_layer. Only works for simply element-wise operations.
if all([isinstance(x, tf.RaggedTensor) for x in inputs]):
# NOTE: row partitions could differ between inputs; the first input's row_splits
# are reused and checked only if self.ragged_validate is True.
if all([x.ragged_rank == 1 for x in inputs]) and not self.ragged_validate:
if all([x.ragged_rank == 1 for x in inputs]):
out = self._kgcnn_wrapper_layer([x.values for x in inputs], **kwargs) # will be all Tensor
out = tf.RaggedTensor.from_row_splits(out, inputs[0].row_splits, validate=False)
out = tf.RaggedTensor.from_row_splits(out, inputs[0].row_splits, validate=self.ragged_validate)
return out
else:
print("WARNING: Layer", self.name, "fail call on values for ragged_rank=1, attempting keras call... ")
Expand All @@ -156,7 +156,7 @@ def call(self, inputs, **kwargs):
# NOTE: row partitions could differ between inputs; the first input's row_splits
# are reused and checked only if self.ragged_validate is True.
# For a defined inner dimension and ragged_rank=1 we can concatenate on values directly.
if all([x.ragged_rank == 1 for x in inputs]) and self._kgcnn_wrapper_layer.axis == -1 and all(
[x.shape[-1] is not None for x in inputs]) and not self.ragged_validate:
[x.shape[-1] is not None for x in inputs]):
out = self._kgcnn_wrapper_layer([x.values for x in inputs], **kwargs) # will be all Tensor
out = tf.RaggedTensor.from_row_splits(out, inputs[0].row_splits, validate=self.ragged_validate)
return out
Expand Down
14 changes: 7 additions & 7 deletions kgcnn/layers/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ def call(self, inputs, **kwargs):
return out


@tf.keras.utils.register_keras_serializable(package='kgcnn', name='DenseEquivariant')
class LinearEquivariant(GraphBaseLayer):
@tf.keras.utils.register_keras_serializable(package='kgcnn', name='TrafoEquivariant')
class TrafoEquivariant(GraphBaseLayer):
"""Linear combination of equivariant features.
Used by PAiNN. Requires ragged_rank=1 and rank=4.
TODO: Will remove this later and replace by a more general version.
Expand All @@ -228,7 +228,7 @@ def __init__(self,
bias_constraint=None,
**kwargs):
"""Initialize layer with the same arguments as tf.keras.layers.Dense."""
super(LinearEquivariant, self).__init__(**kwargs)
super(TrafoEquivariant, self).__init__(**kwargs)
self._kgcnn_wrapper_args = ["units", "activation", "use_bias", "kernel_initializer", "bias_initializer",
"kernel_regularizer", "bias_regularizer", "activity_regularizer",
"kernel_constraint", "bias_constraint"]
Expand Down Expand Up @@ -302,10 +302,10 @@ def __init__(self, units,
# Layer
self.lay_dense1 = Dense(units=self.units, activation=activation, use_bias=self.use_bias, **kernel_args,
**self._kgcnn_info)
self.lay_lin_u = LinearEquivariant(self.units, activation='linear', use_bias=False, **kernel_args,
**self._kgcnn_info)
self.lay_lin_v = LinearEquivariant(self.units, activation='linear', use_bias=False, **kernel_args,
**self._kgcnn_info)
self.lay_lin_u = TrafoEquivariant(self.units, activation='linear', use_bias=False, **kernel_args,
**self._kgcnn_info)
self.lay_lin_v = TrafoEquivariant(self.units, activation='linear', use_bias=False, **kernel_args,
**self._kgcnn_info)
self.lay_a = Dense(units=self.units*3, activation='linear', use_bias=self.use_bias, **kernel_args,
**self._kgcnn_info)

Expand Down

0 comments on commit d5722d2

Please sign in to comment.