forked from keon/algorithms
-
Notifications
You must be signed in to change notification settings - Fork 8
/
sparse_mul.py
99 lines (88 loc) · 2.66 KB
/
sparse_mul.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""
Given two sparse matrices A and B, return the result of AB.
You may assume that A's column number is equal to B's row number.
Example:
A = [
[ 1, 0, 0],
[-1, 0, 3]
]
B = [
[ 7, 0, 0 ],
[ 0, 0, 0 ],
[ 0, 0, 1 ]
]
     |  1 0 0 |   | 7 0 0 |   |  7 0 0 |
AB = | -1 0 3 | x | 0 0 0 | = | -7 0 3 |
                  | 0 0 1 |
"""
# Python solution without table (~156ms):
def multiply(self, a, b):
    """
    Multiply two sparse matrices given as lists of rows, skipping zeros.

    :type a: List[List[int]]
    :type b: List[List[int]]
    :rtype: List[List[int]]
    :raises Exception: if the column count of ``a`` differs from the
        row count of ``b``.

    Returns ``None`` when either argument is ``None`` and ``[]`` when
    ``a`` has no rows.
    """
    if a is None or b is None:
        return None
    if not a:
        # No rows in A means no rows in the product.
        return []
    # The shared dimension is A's column count, which must match B's row
    # count; B's own column count only determines the result's width.
    inner = len(a[0])
    if len(b) != inner:
        raise Exception("A's column number must be equal to B's row number.")
    width = len(b[0]) if b else 0
    c = [[0] * width for _ in a]
    for i, row_a in enumerate(a):
        for k, ele_a in enumerate(row_a):
            if not ele_a:
                # A zero in A contributes nothing for the whole of row k in B.
                continue
            for j, ele_b in enumerate(b[k]):
                if ele_b:
                    c[i][j] += ele_a * ele_b
    return c
# Python solution with only one table for B (~196ms):
def multiply(self, a, b):
    """
    Multiply two sparse matrices, indexing B's non-zero entries first.

    Each row of ``b`` is converted to a ``{column: value}`` dict holding
    only its non-zero entries, so the inner loop never visits a zero of B.

    :type a: List[List[int]]
    :type b: List[List[int]]
    :rtype: List[List[int]]
    :raises Exception: if the column count of ``a`` differs from the
        row count of ``b``.

    Returns ``None`` when either argument is ``None`` and ``[]`` when
    ``a`` has no rows.
    """
    if a is None or b is None:
        return None
    if not a:
        # No rows in A means no rows in the product.
        return []
    inner = len(a[0])
    if len(b) != inner:
        raise Exception("A's column number must be equal to B's row number.")
    width = len(b[0]) if b else 0
    c = [[0] * width for _ in a]
    # table_b[k] maps column index -> value for the non-zero entries of
    # row k of B. Every row gets an entry (possibly empty) so the lookup
    # below never misses.
    table_b = {
        k: {j: ele_b for j, ele_b in enumerate(row_b) if ele_b}
        for k, row_b in enumerate(b)
    }
    for i, row_a in enumerate(a):
        for k, ele_a in enumerate(row_a):
            if not ele_a:
                continue
            # .items() works on Python 2 and 3; .iteritems() is Python 2 only.
            for j, ele_b in table_b[k].items():
                c[i][j] += ele_a * ele_b
    return c
# Python solution with two tables (~196ms):
def multiply(self, a, b):
    """
    Multiply two sparse matrices using non-zero lookup tables for both.

    Both ``a`` and ``b`` are converted to ``{row: {column: value}}`` dicts
    that hold only non-zero entries; rows that are entirely zero are
    omitted. The product is then accumulated from matching entries only.

    :type a: List[List[int]]
    :type b: List[List[int]]
    :rtype: List[List[int]]
    :raises Exception: if the column count of ``a`` differs from the
        row count of ``b``.

    Returns ``None`` when either argument is ``None`` and ``[]`` when
    ``a`` has no rows.
    """
    if a is None or b is None:
        return None
    if not a:
        # No rows in A means no rows in the product.
        return []
    # The shared dimension is A's column count, not B's.
    inner = len(a[0])
    if len(b) != inner:
        raise Exception("A's column number must be equal to B's row number.")
    width = len(b[0]) if b else 0

    table_a, table_b = {}, {}
    for table, matrix in ((table_a, a), (table_b, b)):
        for i, row in enumerate(matrix):
            for j, ele in enumerate(row):
                if ele:
                    # Create the row's dict lazily so all-zero rows are absent.
                    table.setdefault(i, {})[j] = ele

    c = [[0] * width for _ in a]
    for i, row_a in table_a.items():
        for k, ele_a in row_a.items():
            row_b = table_b.get(k)
            if row_b is None:
                # Row k of B is all zeros, so it contributes nothing.
                continue
            for j, ele_b in row_b.items():
                c[i][j] += ele_a * ele_b
    return c