-
Notifications
You must be signed in to change notification settings - Fork 0
/
CategoricalDistribution.java
118 lines (101 loc) · 3.01 KB
/
CategoricalDistribution.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import java.util.Random;
/**
 * A categorical (discrete) probability distribution over the outcomes
 * {@code 0 .. outcomes - 1} that can draw random outcomes.
 *
 * <p>Both the probability of each outcome and the cumulative distribution
 * (running sums of those probabilities) are stored. {@link #nextValue()} draws
 * a uniform value in [0, 1) and returns the first outcome whose cumulative
 * probability exceeds it.
 *
 * <p>Instances perform no synchronization; the setters replace internal state.
 */
public class CategoricalDistribution implements java.io.Serializable
{
    /* the number of possible outcomes; kept equal to distribution.length */
    int outcomes;
    /* the probability distribution (non-negative, sums to 1) and the cumulative
       one, both stored as arrays of probabilities */
    double[] distribution;
    double[] cumulativeDistribution;
    Random randomGenerator = new Random(); // maybe we can add a seed

    /*** ctors ***/

    /**
     * Creates a uniform distribution over {@code values} outcomes.
     *
     * @param values the number of outcomes; must be positive
     * @throws IllegalArgumentException if {@code values <= 0}
     */
    public CategoricalDistribution(int values)
    {
        if (values <= 0)
        {
            throw new IllegalArgumentException("number of outcomes must be positive: " + values);
        }
        outcomes = values;
        distribution = new double[values];
        initUniformDistribution(distribution);
        cumulativeDistribution = new double[values];
        initCumulativeDistribution(distribution, cumulativeDistribution);
    }

    /**
     * Creates a distribution from non-negative (possibly unnormalized) weights.
     * The weights are copied and rescaled to sum to 1; the caller's array is
     * left unchanged.
     *
     * @param values        the number of outcomes; must equal {@code probabilities.length}
     * @param probabilities non-negative, finite weights with a positive sum
     * @throws NullPointerException     if {@code probabilities} is null
     * @throws IllegalArgumentException if the lengths disagree, {@code values <= 0},
     *                                  or the weights are invalid
     */
    public CategoricalDistribution(int values, double[] probabilities)
    {
        java.util.Objects.requireNonNull(probabilities, "probabilities");
        if (values != probabilities.length)
        {
            throw new IllegalArgumentException(
                "expected " + values + " probabilities but got " + probabilities.length);
        }
        if (values <= 0)
        {
            throw new IllegalArgumentException("number of outcomes must be positive: " + values);
        }
        outcomes = values;
        /* copy, so later writes by the caller cannot desynchronize the cumulative array */
        distribution = probabilities.clone();
        normalizeDistribution(distribution);
        cumulativeDistribution = new double[values];
        initCumulativeDistribution(distribution, cumulativeDistribution);
    }

    /*** initing ***/

    /* fills the array so that every outcome has probability 1 / length */
    private void initUniformDistribution(double[] probabilities)
    {
        java.util.Arrays.fill(probabilities, 1.0 / probabilities.length);
    }

    /**
     * Rescales the weights in place so they lie in [0, 1] and sum to 1.
     *
     * @throws IllegalArgumentException if any weight is negative, NaN or infinite,
     *                                  or if the weights do not have a positive, finite sum
     */
    private void normalizeDistribution(double[] probabilities)
    {
        double sum = 0.0;
        for (int i = 0; i < probabilities.length; ++i)
        {
            double p = probabilities[i];
            if (Double.isNaN(p) || Double.isInfinite(p) || p < 0.0)
            {
                throw new IllegalArgumentException(
                    "probability at index " + i + " must be finite and non-negative: " + p);
            }
            sum += p;
        }
        /* a zero sum would turn every entry into NaN */
        if (!(sum > 0.0) || Double.isInfinite(sum))
        {
            throw new IllegalArgumentException("probabilities must have a positive, finite sum: " + sum);
        }
        for (int i = 0; i < probabilities.length; ++i)
        {
            probabilities[i] /= sum;
        }
    }

    /* writes the running sums of probabilities into cumulativeProbabilities */
    private void initCumulativeDistribution(double[] probabilities,
                                            double[] cumulativeProbabilities)
    {
        double sum = 0.0;
        int lastPositive = -1;
        for (int i = 0; i < probabilities.length; ++i)
        {
            sum += probabilities[i];
            cumulativeProbabilities[i] = sum;
            if (probabilities[i] > 0.0)
            {
                lastPositive = i;
            }
        }
        /* Rounding can leave the running sum slightly below 1.0, so a uniform draw
           close to 1 would match no outcome. Pin the tail, starting at the last
           outcome with positive probability, to exactly 1.0: every draw in [0, 1)
           then maps to a valid outcome, and trailing zero-probability outcomes
           are never chosen. */
        if (lastPositive >= 0)
        {
            java.util.Arrays.fill(cumulativeProbabilities, lastPositive, probabilities.length, 1.0);
        }
    }

    /*** getters and setters ***/

    /**
     * Sets the weight of one outcome, then renormalizes the whole distribution.
     * The stored probability of {@code index} therefore ends up as
     * {@code prob / (prob + sum of the other current probabilities)}.
     * If the update is invalid, the distribution is left unchanged.
     *
     * @throws ArrayIndexOutOfBoundsException if {@code index} is out of range
     * @throws IllegalArgumentException       if the resulting weights are invalid
     *                                        (negative, non-finite, or all zero)
     */
    public void setOutcomeProbability(int index, double prob)
    {
        double[] updated = distribution.clone();
        updated[index] = prob;
        /* validate on a copy so a failure leaves the current state intact */
        normalizeDistribution(updated);
        distribution = updated;
        initCumulativeDistribution(distribution, cumulativeDistribution);
    }

    /**
     * Replaces the whole distribution with a copy of {@code dist}, normalized to
     * sum to 1. The number of outcomes becomes {@code dist.length}.
     *
     * @throws NullPointerException     if {@code dist} is null
     * @throws IllegalArgumentException if the weights are invalid (the current
     *                                  distribution is then left unchanged)
     */
    public void setOutcomeProbabilities(double[] dist)
    {
        java.util.Objects.requireNonNull(dist, "dist");
        double[] copy = dist.clone();
        normalizeDistribution(copy);
        distribution = copy;
        outcomes = copy.length;
        /* the cumulative array must track the (possibly new) number of outcomes */
        if (cumulativeDistribution.length != copy.length)
        {
            cumulativeDistribution = new double[copy.length];
        }
        initCumulativeDistribution(distribution, cumulativeDistribution);
    }

    /**
     * Returns the normalized probability of outcome {@code index}.
     *
     * @throws ArrayIndexOutOfBoundsException if {@code index} is out of range
     */
    public double getOutcomeProbability(int index)
    {
        return distribution[index];
    }

    /*** random value generation ***/

    /**
     * Draws a random outcome according to the distribution.
     *
     * @return an outcome index in {@code [0, outcomes)}
     */
    public int nextValue()
    {
        double randomUnif = randomGenerator.nextDouble();
        for (int i = 0; i < outcomes; ++i)
        {
            if (randomUnif < cumulativeDistribution[i])
            {
                return i;
            }
        }
        /* Only reachable if the cumulative array does not end at 1.0. Fall back
           to the last outcome with positive probability instead of returning
           an invalid index. */
        for (int i = outcomes - 1; i > 0; --i)
        {
            if (distribution[i] > 0.0)
            {
                return i;
            }
        }
        return 0;
    }
}