Skip to content

Commit 01bdccc

Browse files
committed
add analysis of exponential distribution and some analysis of error against centroid sizes
1 parent 31c79be commit 01bdccc

File tree

4 files changed

+183
-33
lines changed

4 files changed

+183
-33
lines changed

core/src/main/java/com/tdunning/math/stats/ScaleFunction.java

Lines changed: 3 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -123,7 +123,7 @@ public double normalizer(double compression, double n) {
123123
},
124124

125125
/**
126-
* Generates cluster sizes proportional to sqrt(1-q) for q >= 1/2, and uniform cluster sizes for q < 1/2 by gluing
126+
* Generates cluster sizes proportional to sqrt(1-q) for q geq 1/2, and uniform cluster sizes for q lt 1/2 by gluing
127127
* the graph of the K_1 function to its tangent line at q=1/2. Changing the split point is possible.
128128
*/
129129
K_1_GLUED {
@@ -327,7 +327,7 @@ private double Z(double compression, double n) {
327327
},
328328

329329
/**
330-
* Generates cluster sizes proportional to 1-q for q >= 1/2, and uniform cluster sizes for q < 1/2 by gluing
330+
* Generates cluster sizes proportional to 1-q for q geq 1/2, and uniform cluster sizes for q lt 1/2 by gluing
331331
* the graph of the K_2 function to its tangent line at q=1/2. Changing the split point is possible.
332332
*/
333333
K_2_GLUED {
@@ -494,7 +494,7 @@ private double Z(double compression, double n) {
494494
},
495495

496496
/**
497-
* Generates cluster sizes proportional to 1-q for q >= 1/2, and uniform cluster sizes for q < 1/2 by gluing
497+
* Generates cluster sizes proportional to 1-q for q geq 1/2, and uniform cluster sizes for q lt 1/2 by gluing
498498
* the graph of the K_3 function to its tangent line at q=1/2.
499499
*/
500500
K_3_GLUED {

core/src/test/java/com/tdunning/math/stats/MergingDigestTest.java

Lines changed: 60 additions & 12 deletions
Original file line number · Diff line number · Diff line change
@@ -17,19 +17,26 @@
1717

1818
package com.tdunning.math.stats;
1919

20-
import com.carrotsearch.randomizedtesting.annotations.Seed;
20+
import java.io.FileWriter;
21+
import java.io.IOException;
22+
import java.nio.ByteBuffer;
23+
import java.util.ArrayList;
24+
import java.util.Arrays;
25+
import java.util.HashMap;
26+
import java.util.List;
27+
import java.util.Map;
28+
import java.util.Random;
29+
2130
import org.apache.commons.math3.util.Pair;
2231
import org.apache.mahout.common.RandomUtils;
2332
import org.apache.mahout.math.jet.random.AbstractContinousDistribution;
33+
import org.apache.mahout.math.jet.random.Exponential;
2434
import org.apache.mahout.math.jet.random.Uniform;
2535
import org.junit.Before;
2636
import org.junit.BeforeClass;
2737
import org.junit.Test;
2838

29-
import java.io.FileWriter;
30-
import java.io.IOException;
31-
import java.nio.ByteBuffer;
32-
import java.util.*;
39+
import com.carrotsearch.randomizedtesting.annotations.Seed;
3340

3441
//to freeze the tests with a particular seed, put the seed on the next line
3542
//@Seed("84527677CF03B566:A6FF596BDDB2D59D")
@@ -59,10 +66,25 @@ protected TDigest fromBytes(ByteBuffer bytes) {
5966
return MergingDigest.fromBytes(bytes);
6067
}
6168

62-
69+
@Test
70+
public void writeUniformAsymmetricScaleFunctionResults() {
71+
try {
72+
writeAsymmetricScaleFunctionResults(Distribution.UNIFORM);
73+
} catch (Exception e) {
74+
e.printStackTrace();
75+
}
76+
}
6377

6478
@Test
65-
public void writeAsymmetricScaleFunctionResults() {
79+
public void writeExponentialAsymmetricScaleFunctionResults() {
80+
try {
81+
writeAsymmetricScaleFunctionResults(Distribution.EXPONENTIAL);
82+
} catch (Exception e) {
83+
e.printStackTrace();
84+
}
85+
}
86+
87+
private void writeAsymmetricScaleFunctionResults(Distribution distribution) throws Exception {
6688

6789
List<ScaleFunction> scaleFcns = Arrays.asList(ScaleFunction.K_0, ScaleFunction.K_1,
6890
ScaleFunction.K_2, ScaleFunction.K_3, ScaleFunction.K_1_GLUED,
@@ -79,12 +101,12 @@ public void writeAsymmetricScaleFunctionResults() {
79101
digestParams.put(fcn.toString() + "_USUAL", new Pair<>(fcn, false));
80102
}
81103
}
82-
writeSeveralDigestUniformResults(digestParams, numTrials, "../docs/asymmetric/data/merging/");
83-
104+
writeSeveralDigestUniformResults(digestParams, numTrials, distribution,
105+
"../docs/asymmetric/data/merging/" + distribution.name() + "/");
84106
}
85107

86-
public void writeSeveralDigestUniformResults(Map<String, Pair<ScaleFunction, Boolean>> digestParams, int numTrials,
87-
String writeLocation) {
108+
private void writeSeveralDigestUniformResults(Map<String, Pair<ScaleFunction, Boolean>> digestParams,
109+
int numTrials, Distribution distribution, String writeLocation) throws Exception {
88110

89111
int trialSize = 1_000_000;
90112
double compression = 100;
@@ -93,8 +115,12 @@ public void writeSeveralDigestUniformResults(Map<String, Pair<ScaleFunction, Boo
93115

94116
Map<String, List<Integer>> centroidCounts= new HashMap<>();
95117

118+
Map<String, List<List<Integer>>> centroidSequences= new HashMap<>();
119+
120+
96121
for (Map.Entry<String, Pair<ScaleFunction, Boolean>> entry : digestParams.entrySet()) {
97122
centroidCounts.put(entry.getKey(), new ArrayList<Integer>());
123+
centroidSequences.put(entry.getKey(), new ArrayList<List<Integer>>());
98124
try {
99125
Map<Double, List<String>> records = new HashMap<>();
100126
for (double q : quants) {
@@ -105,7 +131,12 @@ public void writeSeveralDigestUniformResults(Map<String, Pair<ScaleFunction, Boo
105131
digest.setScaleFunction(entry.getValue().getFirst());
106132
digest.setUseAlternatingSort(entry.getValue().getSecond());
107133
Random rand = new Random();
108-
AbstractContinousDistribution gen = new Uniform(50, 51, rand);
134+
AbstractContinousDistribution gen;
135+
if (distribution.equals(Distribution.UNIFORM)) {
136+
gen = new Uniform(50, 51, rand);
137+
} else if (distribution.equals(Distribution.EXPONENTIAL)) {
138+
gen = new Exponential(5, rand);
139+
} else throw new Exception("distribution not specified");
109140
double[] data = new double[trialSize];
110141
for (int i = 0; i < trialSize; i++) {
111142
data[i] = gen.nextDouble();
@@ -121,6 +152,12 @@ public void writeSeveralDigestUniformResults(Map<String, Pair<ScaleFunction, Boo
121152
String.valueOf(Math.abs(q1 - q2) / Math.min(q, 1 - q)) + "\n");
122153
}
123154
centroidCounts.get(entry.getKey()).add(digest.centroids().size());
155+
156+
List<Integer> seq = new ArrayList<>();
157+
for (Centroid c : digest.centroids()) {
158+
seq.add(c.count());
159+
}
160+
centroidSequences.get(entry.getKey()).add(seq);
124161
}
125162
for (double q : quants) {
126163
FileWriter csvWriter = new FileWriter(writeLocation + entry.getKey() + "_" + String.valueOf(q) + ".csv");
@@ -140,6 +177,17 @@ public void writeSeveralDigestUniformResults(Map<String, Pair<ScaleFunction, Boo
140177
csvWriter.flush();
141178
csvWriter.close();
142179

180+
181+
FileWriter csvWriter2 = new FileWriter(writeLocation + entry.getKey() + "_centroid_sizes.csv");
182+
for (List<Integer> ct : centroidSequences.get(entry.getKey())) {
183+
for (Integer c : ct) {
184+
csvWriter2.append(c.toString()).append(",");
185+
}
186+
csvWriter2.append("\n");
187+
}
188+
csvWriter2.flush();
189+
csvWriter2.close();
190+
143191
} catch (IOException e) {
144192
System.out.println(e.toString());
145193
return;

core/src/test/java/com/tdunning/math/stats/TDigestTest.java

Lines changed: 50 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -18,8 +18,10 @@
1818
package com.tdunning.math.stats;
1919

2020
import com.google.common.collect.Lists;
21+
2122
import org.apache.mahout.common.RandomUtils;
2223
import org.apache.mahout.math.jet.random.AbstractContinousDistribution;
24+
import org.apache.mahout.math.jet.random.Exponential;
2325
import org.apache.mahout.math.jet.random.Gamma;
2426
import org.apache.mahout.math.jet.random.Normal;
2527
import org.apache.mahout.math.jet.random.Uniform;
@@ -31,7 +33,6 @@
3133
import java.util.concurrent.*;
3234
import java.util.concurrent.atomic.AtomicInteger;
3335

34-
3536
/**
3637
* Base test case for TDigests, just extend this class and implement the abstract methods.
3738
*/
@@ -44,6 +45,8 @@ public abstract class TDigestTest extends AbstractTest {
4445

4546
private static String digestName;
4647

48+
protected enum Distribution {UNIFORM, EXPONENTIAL};
49+
4750
@BeforeClass
4851
public static void freezeSeed() {
4952
RandomUtils.useTestSeed();
@@ -129,6 +132,23 @@ public void offsetUniform() {
129132

130133
@Test
131134
public void writeUniformResultsWithCompression() {
135+
try {
136+
writeResultsWithCompression(Distribution.UNIFORM);
137+
} catch (Exception e) {
138+
e.printStackTrace();
139+
}
140+
}
141+
142+
@Test
143+
public void writeExponentialResultsWithCompression() {
144+
try {
145+
writeResultsWithCompression(Distribution.EXPONENTIAL);
146+
} catch (Exception e) {
147+
e.printStackTrace();
148+
}
149+
}
150+
151+
private void writeResultsWithCompression(Distribution distribution) throws Exception {
132152

133153
List<ScaleFunction> scaleFcns = Arrays.asList(ScaleFunction.K_0, ScaleFunction.K_1,
134154
ScaleFunction.K_2, ScaleFunction.K_3, ScaleFunction.K_1_GLUED,
@@ -138,8 +158,12 @@ public void writeUniformResultsWithCompression() {
138158

139159
Map<ScaleFunction, List<Integer>> centroidCounts= new HashMap<>();
140160

161+
Map<ScaleFunction, List<List<Integer>>> centroidSequences= new HashMap<>();
162+
141163
for (ScaleFunction scaleFcn : scaleFcns) {
142164
centroidCounts.put(scaleFcn, new ArrayList<Integer>());
165+
centroidSequences.put(scaleFcn, new ArrayList<List<Integer>>());
166+
143167
try {
144168
Map<Double, List<String>> records = new HashMap<>();
145169
double[] quants = new double[]{0.00001, 0.0001, 0.001, 0.01, 0.1,
@@ -152,7 +176,12 @@ public void writeUniformResultsWithCompression() {
152176
TDigest digest = factory(compression).create();
153177
digest.setScaleFunction(scaleFcn);
154178
Random rand = new Random();
155-
AbstractContinousDistribution gen = new Uniform(50, 51, rand);
179+
AbstractContinousDistribution gen;
180+
if (distribution.equals(Distribution.UNIFORM)) {
181+
gen = new Uniform(50, 51, rand);
182+
} else if (distribution.equals(Distribution.EXPONENTIAL)) {
183+
gen = new Exponential(5, rand);
184+
} else throw new Exception("distribution not specified");
156185
double[] data = new double[trialSize];
157186
for (int i = 0; i < trialSize; i++) {
158187
data[i] = gen.nextDouble();
@@ -168,6 +197,12 @@ public void writeUniformResultsWithCompression() {
168197
String.valueOf(Math.abs(q1 - q2) / Math.min(q, 1 - q)) + "\n");
169198
}
170199
centroidCounts.get(scaleFcn).add(digest.centroids().size());
200+
201+
List<Integer> seq = new ArrayList<>();
202+
for (Centroid c : digest.centroids()) {
203+
seq.add(c.count());
204+
}
205+
centroidSequences.get(scaleFcn).add(seq);
171206
}
172207
}
173208

@@ -180,7 +215,7 @@ public void writeUniformResultsWithCompression() {
180215
}
181216

182217
for (double q : quants) {
183-
FileWriter csvWriter = new FileWriter("../docs/asymmetric/data/tree/" + fcnName + "_" + String.valueOf(q) + ".csv");
218+
FileWriter csvWriter = new FileWriter("../docs/asymmetric/data/tree/" + distribution.name() + "/" + fcnName + "_" + String.valueOf(q) + ".csv");
184219
csvWriter.append("error_q,norm_error_q\n");
185220
for (String obs : records.get(q)) {
186221
csvWriter.append(obs);
@@ -189,14 +224,25 @@ public void writeUniformResultsWithCompression() {
189224
csvWriter.close();
190225
}
191226

192-
FileWriter csvWriter = new FileWriter("../docs/asymmetric/data/tree/" + fcnName + "_centroid_counts.csv");
227+
FileWriter csvWriter = new FileWriter("../docs/asymmetric/data/tree/" + distribution.name() + "/" + fcnName + "_centroid_counts.csv");
193228
csvWriter.append("centroid_count\n");
194229
for (Integer ct : centroidCounts.get(scaleFcn)) {
195230
csvWriter.append(ct.toString()).append("\n");
196231
}
197232
csvWriter.flush();
198233
csvWriter.close();
199234

235+
236+
FileWriter csvWriter2 = new FileWriter("../docs/asymmetric/data/tree/" + distribution.name() + "/" + fcnName + "_centroid_sizes.csv");
237+
for (List<Integer> ct : centroidSequences.get(scaleFcn)) {
238+
for (Integer c : ct) {
239+
csvWriter2.append(c.toString()).append(",");
240+
}
241+
csvWriter2.append("\n");
242+
}
243+
csvWriter2.flush();
244+
csvWriter2.close();
245+
200246
} catch (IOException e) {
201247
return;
202248
}

0 commit comments

Comments (0)