Skip to content

Commit

Permalink
Remove phased rotate
Browse files Browse the repository at this point in the history
  • Loading branch information
tomgrek committed Apr 28, 2019
1 parent 230d5fa commit b1ad484
Showing 1 changed file with 2 additions and 23 deletions.
25 changes: 2 additions & 23 deletions nn/rotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,7 @@ def __init__(self, model_name, nentity, nrelation, hidden_dim, gamma,
a=-self.embedding_range.item(),
b=self.embedding_range.item())

if model_name == 'pRotatE':
self.modulus = nn.Parameter(torch.Tensor([[0.5 * self.embedding_range.item()]]))

if model_name not in ['TransE', 'DistMult', 'ComplEx', 'RotatE', 'pRotatE']:
if model_name not in ['ComplEx', 'RotatE']:
raise ValueError('model {} not supported'.format(model_name))

def forward(self, sample, mode='single'):
Expand Down Expand Up @@ -111,8 +108,7 @@ def forward(self, sample, mode='single'):

model_func = {
'ComplEx': self.ComplEx,
'RotatE': self.RotatE,
'pRotatE': self.pRotatE
'RotatE': self.RotatE
}

if self.model_name in model_func:
Expand Down Expand Up @@ -166,23 +162,6 @@ def RotatE(self, head, relation, tail, mode):
score = self.gamma.item() - score.sum(dim=2)
return score

def pRotatE(self, head, relation, tail, mode):

phase_head = head/(self.embedding_range.item()/math.pi)
phase_relation = relation/(self.embedding_range.item()/math.pi)
phase_tail = tail/(self.embedding_range.item()/math.pi)

if mode == 'head-batch':
score = phase_head + (phase_relation - phase_tail)
else:
score = (phase_head + phase_relation) - phase_tail

score = torch.sin(score)
score = torch.abs(score)

score = self.gamma.item() - score.sum(dim = 2) * self.modulus
return score

@staticmethod
def train_step(model, optimizer, train_iterator, args):
model.train()
Expand Down

0 comments on commit b1ad484

Please sign in to comment.