44"""
55
66from itertools import cycle
7+ import warnings
78
89import matplotlib .pyplot as plt
910import matplotlib as mpl
@@ -63,18 +64,20 @@ def __init__(self, maxdim=1, thresh=np.inf, coeff=2, do_cocycles=False, verbose=
6364 self .do_cocycles = do_cocycles
6465 self .verbose = verbose
6566
67+ # Internal variables
6668 self .dgm_ = None
6769 self .cocycles_ = {}
6870 self .dm_ = None # Distance matrix
69- self .metric_ = None
70- self .num_edges_ = None #Number of edges added
71+ self .metric_ = None
72+ self .num_edges_ = None # Number of edges added
7173
7274 if self .verbose :
73- print ("Rips(maxdim={}, thres ={}, coef ={}, verbose={})" .format (
74- maxdim , thresh , coeff , verbose ))
75+ print ("Rips(maxdim={}, thresh ={}, coeff={}, do_cocycles ={}, verbose={})" .format (
76+ maxdim , thresh , coeff , do_cocycles , verbose ))
7577
7678 def transform (self , X , distance_matrix = False , metric = 'euclidean' ):
77- """Compute persistence diagrams for X data array.
79+ """ Compute persistence diagrams for X data array. If X is not a distance matrix,
80+ it will be converted to a distance matrix using the chosen metric.
7881
7982 Parameters
8083 ----------
@@ -93,17 +96,19 @@ def transform(self, X, distance_matrix=False, metric='euclidean'):
9396
9497 if not distance_matrix :
9598 if X .shape [0 ] == X .shape [1 ]:
96- from warnings import warn
97- warn ("The input matrix is square, but the distance_matrix flag is off. Did you mean to indicate that this was a distance matrix?" )
99+ warnings .warn ("The input matrix is square, but the distance_matrix flag is off. Did you mean to indicate that this was a distance matrix?" )
98100 elif X .shape [0 ] < X .shape [1 ]:
99- from warnings import warn
100- warn ("The input point cloud has more columns than rows; did you mean to transpose?" )
101+ warnings .warn ("The input point cloud has more columns than rows; did you mean to transpose?" )
102+
103+ self .metric_ = metric
101104 X = pairwise_distances (X , metric = metric )
102105 elif sparse .issparse (X ):
103106 #Sparse distance matrix
104107 X = sparse .csr_matrix .astype (X .tocsr (), dtype = np .float32 )
108+
105109 if not (X .shape [0 ] == X .shape [1 ]):
106110 raise Exception ('Distance matrix is not square' )
111+
107112 self .dm_ = X
108113
109114 dgm = self ._compute_rips (X )
@@ -144,19 +149,22 @@ def _compute_rips(self, dm):
144149 res = DRFDMSparse (coo .row , coo .col , coo .data , n_points , \
145150 self .maxdim , self .thresh , self .coeff , self .do_cocycles )
146151 else :
147- [ I , J ] = np .meshgrid (np .arange (n_points ), np .arange (n_points ))
152+ I , J = np .meshgrid (np .arange (n_points ), np .arange (n_points ))
148153 DParam = np .array (dm [I > J ], dtype = np .float32 )
149154 res = DRFDM (DParam , self .maxdim , self .thresh ,
150155 self .coeff , int (self .do_cocycles ))
156+
151157 pds = []
152158 for dim in range (self .maxdim + 1 ):
153159 # Number of homology classes in this dimension
154160 n_classes = int (res [0 ])
161+
155162 # First extract the persistence diagram
156163 res = res [1 ::]
157164 pd = np .array (res [0 :n_classes * 2 ])
158165 pds .append (np .reshape (pd , (n_classes , 2 )))
159166 res = res [n_classes * 2 ::]
167+
160168 # Now extract the representative cocycles if they were computed
161169 if self .do_cocycles and dim > 0 :
162170 self .cocycles_ [dim ] = []
0 commit comments