-
Notifications
You must be signed in to change notification settings - Fork 0
/
dynamic_range_sampler.h
107 lines (85 loc) · 4 KB
/
dynamic_range_sampler.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#ifndef DYNAMIC_SAMPLED_SOFTMAX_LOSS_DYNAMICRANGESAMPLER_H
#define DYNAMIC_SAMPLED_SOFTMAX_LOSS_DYNAMICRANGESAMPLER_H
#include <vector>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/random/distribution_sampler.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/lib/random/weighted_picker.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
// NOTE(review): a using-directive at header scope is visible to every file
// that includes this header. The declarations below use names such as
// random::SimplePhilox and gtl::ArraySlice without a tensorflow:: qualifier,
// presumably relying on this directive, so it cannot simply be dropped.
using namespace tensorflow;
// Forward declaration at global scope; Env is not referenced again in this
// header.
class Env;
// Abstract base class for samplers that draw integer values from a
// distribution over [0, range). Concrete samplers implement Sample() and
// Probability(); the batch helpers are declared here and defined outside
// this header.
class DynamicRangeSampler {
 public:
  // Stores `range` and CHECK-fails unless it is strictly positive.
  explicit DynamicRangeSampler(int64 range) : range_(range) {
    CHECK_GT(range_, 0);
  }

  // Virtual so that derived samplers can be destroyed through a base
  // pointer. Only declared here; the definition lives outside this header.
  virtual ~DynamicRangeSampler();

  // Sample a single value
  virtual int64 Sample(random::SimplePhilox *rnd) const = 0;

  // The probability that a single call to Sample() returns the given value.
  // Assumes that value is in [0, range). No range checking is done.
  virtual float Probability(int64 value) const = 0;

  // Fill "batch" with samples from the distribution.
  // If unique=true, then we re-pick each element until we get a
  // value distinct from all previously picked values in the batch.
  void SampleBatch(random::SimplePhilox *rnd, bool unique,
                   gtl::MutableArraySlice<int64> batch) const;

  // Fill "batch" with samples from the distribution, and report
  // "expected counts".
  //
  // The "expected count" of a value is an estimate of the expected
  // number of occurrences of the value in the batch returned by a
  // call to this function with the given parameters. If unique=true,
  // the expected count is an inclusion probability. For details on
  // this estimation, see the comment to "ExpectedCountHelper" in the
  // .cc file.
  //
  // Expected counts for the elements of the returned "batch" are reported
  // in the aligned array "batch_expected_count".
  //
  // The user can optionally provide "extras", containing values in the range.
  // The expected counts for the extras are reported in the aligned array
  // "extras_expected_count".
  //
  // "batch_expected_count" must have size equal to 0 or to the size of "batch".
  // "extras" and "extras_expected_count" must have equal size.
  void SampleBatchGetExpectedCount(
      random::SimplePhilox *rnd, bool unique,
      gtl::MutableArraySlice<int64> batch,
      gtl::MutableArraySlice<float> batch_expected_count,
      gtl::ArraySlice<int64> extras,
      gtl::MutableArraySlice<float> extras_expected_count) const;

  // Same as SampleBatchGetExpectedCount (see above), but with avoided values.
  // We repick to avoid all of the values in "avoided_values".
  // "avoided_values" is only supported with unique=true. If
  // unique=false, then avoided_values must be empty.
  virtual void SampleBatchGetExpectedCountAvoid(
      random::SimplePhilox *rnd, bool unique,
      gtl::MutableArraySlice<int64> batch,
      gtl::MutableArraySlice<float> batch_expected_count,
      gtl::ArraySlice<int64> extras,
      gtl::MutableArraySlice<float> extras_expected_count,
      gtl::ArraySlice<int64> avoided_values) const;

  // Does this sampler need to be updated with values, e.g. UnigramSampler.
  // The base implementation answers false.
  virtual bool NeedsUpdates() const { return false; }

  // Updates the underlying distribution. The base implementation does not
  // support updates and logs a FATAL message instead; samplers that report
  // NeedsUpdates() == true are expected to override it.
  virtual void Update(gtl::ArraySlice<int64> values) {
    LOG(FATAL) << "Update not supported for this sampler type.";
  }

  // Returns the range passed to the constructor. Const-qualified because it
  // only reads range_, so it is usable through const references and pointers
  // like the other query methods of this class.
  int64 range() const { return range_; }

 protected:
  // Number of distinct values; fixed at construction and checked to be > 0.
  const int64 range_;
};
// Concrete DynamicRangeSampler. Only the declarations live in this header;
// the constructor, Sample() and Probability() are defined elsewhere.
class DynamicUniformSampler : public DynamicRangeSampler {
 public:
  // Builds a sampler for the given range (definition not in this header).
  explicit DynamicUniformSampler(int64 range);

  // Nothing beyond the base-class teardown is required here.
  ~DynamicUniformSampler() override = default;

  // Implements DynamicRangeSampler::Sample.
  int64 Sample(random::SimplePhilox *rnd) const override;

  // Implements DynamicRangeSampler::Probability.
  float Probability(int64 value) const override;

 private:
  // NOTE(review): judging by the name this caches 1 / range; it is set by the
  // out-of-line constructor, so confirm against the .cc file.
  const float inv_range_;
};
// TODO: Implement other samplers.
#endif //DYNAMIC_SAMPLED_SOFTMAX_LOSS_DYNAMICRANGESAMPLER_H