Skip to content

Commit

Permalink
fix: gemm weight swapped dimension
Browse files Browse the repository at this point in the history
the dimensions of the weights are swapped in some cases
  • Loading branch information
jan-haug authored Mar 14, 2023
1 parent 73046a6 commit 1626c74
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion onnx2torch/node_converters/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
if not trans_b:
weights = weights.T

in_features, out_features = weights.shape[0], weights.shape[1]
in_features, out_features = weights.shape[1], weights.shape[0]
torch_module = nn.Linear(
in_features=in_features,
out_features=out_features,
Expand Down

0 comments on commit 1626c74

Please sign in to comment.