-
Notifications
You must be signed in to change notification settings - Fork 29
/
data.py
executable file
·79 lines (71 loc) · 2.47 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from numpy import *
from scipy.sparse import lil_matrix, csr_matrix, csc_matrix
"""
Data may be passed in as a numpy matrix whose columns are the data
vectors with the class label in the last row, or as a DefDict mapping
each point (a tuple of coordinates) to a tuple of its classes.
"""
class DefDict(dict):
    """A dict that returns a fixed default for missing keys.

    Unlike ``collections.defaultdict``, looking up a missing key does NOT
    insert it; the same ``default`` object is simply returned each time.
    """

    def __init__(self, default, *args, **kwargs):
        """Create the mapping.

        default -- value returned by ``d[key]`` when ``key`` is absent.
        *args   -- mappings (or iterables of ``(key, value)`` pairs) whose
                   items are copied in, in order.
        **kwargs -- additional ``key=value`` items to store.
        """
        dict.__init__(self)
        self.default = default
        self.update(*args, **kwargs)

    def __getitem__(self, key):
        """Return the stored value for ``key``, or ``self.default``."""
        try:
            return dict.__getitem__(self, key)
        except KeyError:
            return self.default

    def __setitem__(self, key, val):
        """Store ``val`` under ``key``."""
        dict.__setitem__(self, key, val)

    def update(self, *args, **kwargs):
        """Copy items from every positional source and from ``kwargs``.

        Each positional argument may be a mapping or an iterable of
        ``(key, value)`` pairs; later sources override earlier ones.
        """
        for other in args:
            if hasattr(other, "keys"):
                for key in other.keys():
                    self[key] = other[key]
            else:
                for key, val in other:
                    self[key] = val
        # Keyword items were previously accepted but silently discarded.
        for key in kwargs:
            self[key] = kwargs[key]

    def addto(self, other):
        """Add each value of ``other`` onto the value stored under its key.

        Missing keys start from ``self.default``.
        """
        for key in other:
            # Use ``+`` rather than ``+=``: for a missing key the left-hand
            # side is the shared default object, and an in-place add would
            # mutate it (and alias it under every newly created key).
            self[key] = self[key] + other[key]
class Data:
    """Hold one data set and convert lazily between its representations.

    Exactly one of ``matrix`` (dense ndarray), ``dict`` (DefDict) or
    ``sparse`` (scipy sparse matrix) is set by the constructor; the dense
    and dict forms are derived on demand and cached on the instance.
    """

    def __init__(self, inp):
        """Wrap ``inp``.

        inp -- a numpy ndarray, a csc/csr/lil sparse matrix, or a DefDict.

        Raises RuntimeError for any other type.
        """
        self.matrix = None
        self.dict = None
        self.sparse = None
        if isinstance(inp, ndarray):
            # asarray returns a plain ndarray unchanged and strips
            # subclasses (e.g. numpy.matrix) whose indexing differs.
            self.matrix = asarray(inp)
        elif isinstance(inp, (csc_matrix, csr_matrix, lil_matrix)):
            self.sparse = inp
        elif isinstance(inp, DefDict):
            self.dict = inp
        else:
            raise RuntimeError("unsupported input type: %s" % type(inp))

    def asDict(self):
        """Return the data as a DefDict mapping point -> tuple of classes.

        Each column of the dense matrix is one sample: all rows but the
        last form the point (used as a tuple key) and the last row is the
        class, converted to int. Classes of repeated points accumulate in
        column order.
        """
        # ``is None`` on purpose: ``== None`` is elementwise on arrays.
        if self.dict is None:
            matrix = self.asMatrix()
            labelled = DefDict([])
            for col in matrix.T:
                point = tuple(col[:-1])
                labelled[point] = tuple(labelled[point]) + (int(col[-1]),)
            # Cache only once fully built, so a failure part-way through
            # does not leave a truncated dict behind.
            self.dict = labelled
        return self.dict

    def asMatrix(self):
        """Return the data as a dense ndarray, one sample per column.

        Built from the sparse matrix when one was supplied, otherwise from
        the dict, emitting one column (point followed by class) for every
        class attached to every point.
        """
        if self.matrix is None:
            if self.sparse is not None:
                # numpy's array() does not densify a scipy sparse matrix;
                # toarray() does.
                self.matrix = self.sparse.toarray()
            else:
                cols = []
                for point in self.dict:
                    for label in self.dict[point]:
                        cols.append(hstack((point, label)))
                self.matrix = column_stack(cols)
        return self.matrix

    def asSparseMatrix(self):
        """Return the sparse matrix this object was constructed from.

        Raises RuntimeError when the object was not built from a sparse
        matrix; no conversion from the other forms is attempted.
        """
        if self.sparse is None:
            raise RuntimeError("no sparse representation available")
        return self.sparse
if __name__ == "__main__":
    # Self-test: build a Data from a dict, convert it to a matrix, wrap that
    # matrix in a second Data and check that both views agree.
    testmat = DefDict((), {(1, 2, 3, 4, 5): (1,),
                           (4, 2, 5, 1, 0): (2,),
                           (5, 3, 2, 1, 1): (3, 4)})
    data1 = Data(testmat)
    data2 = Data(data1.asMatrix())
    # Single-argument print(...) behaves the same on Python 2 and 3; the
    # former `print (expr).all()` form would misparse under Python 3.
    print((data1.asMatrix() == data2.asMatrix()).all())
    print(data2.asDict() == data1.asDict())