Skip to content

Commit 99d68cc

Browse files
authored
Merge pull request #26 from sauln/fixup-sparsedm
Fixup sparsedm
2 parents 9893ca6 + 2c6f814 commit 99d68cc

File tree

5 files changed

+50
-22
lines changed

5 files changed

+50
-22
lines changed

ripser/pyRips.pxd

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from libcpp.vector cimport vector
22

33
cdef extern from "ripser.cpp":
4-
vector[float] pythondm(float* D, int N, int modulus,
5-
int dim_max, float threshold, int do_cocycles)
4+
vector[float] rips_dm(float* D, int N, int modulus,
5+
int dim_max, float threshold, int do_cocycles)
66

77
cdef extern from "ripser.cpp":
8-
vector[float] pythondmsparse(int* I, int* J, float* V, int NEdges,
9-
int N, int modulus, int dim_max,
10-
float threshold, int do_cocycles)
8+
vector[float] rips_dm_sparse(int* I, int* J, float* V, int NEdges,
9+
int N, int modulus, int dim_max,
10+
float threshold, int do_cocycles)

ripser/pyRipser.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def doRipsFiltrationDM(np.ndarray[float,ndim=1,mode="c"] DParam not None, int ma
1010

1111
cdef int N = DParam.shape[0]
1212

13-
res = pyRips.pythondm(&DParam[0], N, coeff, maxHomDim, thresh, do_cocycles)
13+
res = pyRips.rips_dm(&DParam[0], N, coeff, maxHomDim, thresh, do_cocycles)
1414

1515
return res
1616

@@ -20,6 +20,6 @@ def doRipsFiltrationDMSparse(np.ndarray[int,ndim=1,mode="c"] I not None, np.ndar
2020

2121
cdef int NEdges = I.size
2222

23-
res = pyRips.pythondmsparse(&I[0], &J[0], &V[0], NEdges, N, coeff, maxHomDim, thresh, do_cocycles)
23+
res = pyRips.rips_dm_sparse(&I[0], &J[0], &V[0], NEdges, N, coeff, maxHomDim, thresh, do_cocycles)
2424

2525
return res

ripser/ripser.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -932,7 +932,7 @@ void ripser<sparse_distance_matrix>::assemble_columns_to_reduce(
932932
simplices.swap(next_simplices);
933933
}
934934

935-
std::vector<value_t> pythondm(float* D, int N, int modulus, int dim_max, float threshold, int do_cocycles) {
935+
std::vector<value_t> rips_dm(float* D, int N, int modulus, int dim_max, float threshold, int do_cocycles) {
936936
//Setup distance matrix and figure out threshold
937937
std::vector<value_t> retvec;
938938
std::vector<value_t> distances(D, D+N);
@@ -961,7 +961,6 @@ std::vector<value_t> pythondm(float* D, int N, int modulus, int dim_max, float t
961961
if (d <= threshold) ++num_edges;
962962
}
963963

964-
std::cout << "dim_max = " << dim_max << "\n";
965964
if (threshold >= max) {
966965
ripser<compressed_lower_distance_matrix> r(std::move(dist), dim_max, threshold, ratio,
967966
modulus, do_cocycles);
@@ -979,7 +978,7 @@ std::vector<value_t> pythondm(float* D, int N, int modulus, int dim_max, float t
979978
}
980979

981980

982-
std::vector<value_t> pythondmsparse(int* I, int* J, float* V, int NEdges,
981+
std::vector<value_t> rips_dm_sparse(int* I, int* J, float* V, int NEdges,
983982
int N, int modulus, int dim_max, float threshold, int do_cocycles) {
984983
//Setup distance matrix and figure out threshold
985984
std::vector<value_t> retvec;
@@ -1010,13 +1009,13 @@ int unwrapvector(std::vector<value_t> vec, float** out) {
10101009
extern "C" {
10111010
int cripser(float** out, float* D, int N,
10121011
int modulus, int dim_max, float threshold, int do_cocycles) {
1013-
std::vector<value_t> resvec = pythondm(D, N, modulus, dim_max, threshold, do_cocycles);
1012+
std::vector<value_t> resvec = rips_dm(D, N, modulus, dim_max, threshold, do_cocycles);
10141013
return unwrapvector(resvec, out);
10151014
}
10161015

10171016
int cripsersparse(float** out, int* I, int* J, float* V, int NEdges, int N,
10181017
int modulus, int dim_max, float threshold, int do_cocycles) {
1019-
std::vector<value_t> resvec = pythondmsparse(I, J, V, NEdges, N, modulus, dim_max, threshold, do_cocycles);
1018+
std::vector<value_t> resvec = rips_dm_sparse(I, J, V, NEdges, N, modulus, dim_max, threshold, do_cocycles);
10201019
return unwrapvector(resvec, out);
10211020
}
10221021
}

ripser/ripser.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55

66
from itertools import cycle
7+
import warnings
78

89
import matplotlib.pyplot as plt
910
import matplotlib as mpl
@@ -63,18 +64,20 @@ def __init__(self, maxdim=1, thresh=np.inf, coeff=2, do_cocycles=False, verbose=
6364
self.do_cocycles = do_cocycles
6465
self.verbose = verbose
6566

67+
# Internal variables
6668
self.dgm_ = None
6769
self.cocycles_ = {}
6870
self.dm_ = None # Distance matrix
69-
self.metric_ = None
70-
self.num_edges_ = None #Number of edges added
71+
self.metric_ = None
72+
self.num_edges_ = None # Number of edges added
7173

7274
if self.verbose:
73-
print("Rips(maxdim={}, thres={}, coef={}, verbose={})".format(
74-
maxdim, thresh, coeff, verbose))
75+
print("Rips(maxdim={}, thresh={}, coeff={}, do_cocycles={}, verbose={})".format(
76+
maxdim, thresh, coeff, do_cocycles, verbose))
7577

7678
def transform(self, X, distance_matrix=False, metric='euclidean'):
77-
"""Compute persistence diagrams for X data array.
79+
""" Compute persistence diagrams for X data array. If X is not a distance matrix,
80+
it will be converted to a distance matrix using the chosen metric.
7881
7982
Parameters
8083
----------
@@ -93,17 +96,19 @@ def transform(self, X, distance_matrix=False, metric='euclidean'):
9396

9497
if not distance_matrix:
9598
if X.shape[0] == X.shape[1]:
96-
from warnings import warn
97-
warn("The input matrix is square, but the distance_matrix flag is off. Did you mean to indicate that this was a distance matrix?")
99+
warnings.warn("The input matrix is square, but the distance_matrix flag is off. Did you mean to indicate that this was a distance matrix?")
98100
elif X.shape[0] < X.shape[1]:
99-
from warnings import warn
100-
warn("The input point cloud has more columns than rows; did you mean to transpose?")
101+
warnings.warn("The input point cloud has more columns than rows; did you mean to transpose?")
102+
103+
self.metric_ = metric
101104
X = pairwise_distances(X, metric=metric)
102105
elif sparse.issparse(X):
103106
#Sparse distance matrix
104107
X = sparse.csr_matrix.astype(X.tocsr(), dtype=np.float32)
108+
105109
if not (X.shape[0] == X.shape[1]):
106110
raise Exception('Distance matrix is not square')
111+
107112
self.dm_ = X
108113

109114
dgm = self._compute_rips(X)
@@ -144,19 +149,22 @@ def _compute_rips(self, dm):
144149
res = DRFDMSparse(coo.row, coo.col, coo.data, n_points, \
145150
self.maxdim, self.thresh, self.coeff, self.do_cocycles)
146151
else:
147-
[I, J] = np.meshgrid(np.arange(n_points), np.arange(n_points))
152+
I, J = np.meshgrid(np.arange(n_points), np.arange(n_points))
148153
DParam = np.array(dm[I > J], dtype=np.float32)
149154
res = DRFDM(DParam, self.maxdim, self.thresh,
150155
self.coeff, int(self.do_cocycles))
156+
151157
pds = []
152158
for dim in range(self.maxdim + 1):
153159
# Number of homology classes in this dimension
154160
n_classes = int(res[0])
161+
155162
# First extract the persistence diagram
156163
res = res[1::]
157164
pd = np.array(res[0:n_classes*2])
158165
pds.append(np.reshape(pd, (n_classes, 2)))
159166
res = res[n_classes*2::]
167+
160168
# Now extract the representative cocycles if they were computed
161169
if self.do_cocycles and dim > 0:
162170
self.cocycles_[dim] = []

test/test_ripser.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,27 @@ def test_instantiate(self):
2929
rip = Rips()
3030
assert rip is not None
3131

32+
class TestTransform():
33+
def test_input_warnings(self):
34+
35+
rips = Rips()
36+
data = np.random.random((3,10))
37+
38+
with pytest.warns(UserWarning, match='has more columns than rows') as w:
39+
rips.transform(data)
40+
41+
data = np.random.random((3,3))
42+
with pytest.warns(UserWarning, match='input matrix is square, but the distance_matrix') as w:
43+
rips.transform(data)
44+
45+
def test_non_square_dist_matrix(self):
46+
rips = Rips()
47+
data = np.random.random((3,10))
48+
49+
with pytest.raises(Exception):
50+
rips.transform(data, distance_matrix=True)
51+
52+
3253
class TestParams():
3354
def test_defaults(self):
3455
data = np.random.random((100,3))

0 commit comments

Comments
 (0)