Skip to content

Commit

Permalink
add plugin rate_limiter
Browse files Browse the repository at this point in the history
  • Loading branch information
IrineSistiana committed Sep 22, 2023
1 parent 8fe4f8d commit fd6cfaa
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 0 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ require (
golang.org/x/crypto v0.13.0 // indirect
golang.org/x/mod v0.12.0 // indirect
golang.org/x/text v0.13.0 // indirect
golang.org/x/time v0.3.0 // indirect
golang.org/x/tools v0.13.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
Expand Down
145 changes: 145 additions & 0 deletions pkg/rate_limiter/rate_limiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
package rate_limiter

import (
"io"
"net/netip"
"sync"
"time"

"golang.org/x/time/rate"
)

type RateLimiter interface {
Allow(addr netip.Addr) bool
io.Closer
}

type limiter struct {
limit rate.Limit
burst int
mask4 int
mask6 int

closeOnce sync.Once
closeNotify chan struct{}
m sync.Mutex
tables map[netip.Addr]*limiterEntry
}

type limiterEntry struct {
l *rate.Limiter
lastSeen time.Time
sync.Once
}

// limit and burst should be greater than zero.
// If gcInterval is <= 0, it will be automatically chosen between 2~10s.
// In this case, if the token refill time (burst/limit) is greater than 10s,
// the actual average qps limit may be higher than expected.
// If mask is zero or greater than 32/128. The default is 32/48.
// If mask is negative, the masks will be 0.
func NewRateLimiter(limit rate.Limit, burst int, gcInterval time.Duration, mask4, mask6 int) RateLimiter {
if mask4 > 32 || mask4 == 0 {
mask4 = 32
}
if mask4 < 0 {
mask4 = 0
}

if mask6 > 128 || mask6 == 0 {
mask6 = 48
}
if mask6 < 0 {
mask6 = 0
}

if gcInterval <= 0 {
if limit <= 0 || burst <= 0 {
gcInterval = time.Second * 2
} else {
refillSec := float64(burst) / float64(limit)
if refillSec < 2 {
refillSec = 2
}
if refillSec > 10 {
refillSec = 10
}
gcInterval = time.Duration(refillSec) * time.Second
}
}

l := &limiter{
limit: limit,
burst: burst,
mask4: mask4,
mask6: mask6,
closeNotify: make(chan struct{}),
tables: make(map[netip.Addr]*limiterEntry),
}
go l.gcLoop(gcInterval)
return l
}

func (l *limiter) Allow(a netip.Addr) bool {
a = l.applyMask(a)
now := time.Now()
l.m.Lock()
e, ok := l.tables[a]
if !ok {
e = &limiterEntry{
l: rate.NewLimiter(l.limit, l.burst),
lastSeen: now,
}
l.tables[a] = e
}
e.lastSeen = now
clientLimiter := e.l
l.m.Unlock()
return clientLimiter.AllowN(now, 1)
}

func (l *limiter) Close() error {
l.closeOnce.Do(func() {
close(l.closeNotify)
})
return nil
}

func (l *limiter) gcLoop(gcInterval time.Duration) {
ticker := time.NewTicker(gcInterval)
defer ticker.Stop()

for {
select {
case <-l.closeNotify:
return
case now := <-ticker.C:
l.doGc(now, gcInterval)
}
}
}

func (l *limiter) doGc(now time.Time, gcInterval time.Duration) {
l.m.Lock()
defer l.m.Unlock()

for a, e := range l.tables {
if now.Sub(e.lastSeen) > gcInterval {
delete(l.tables, a)
}
}
}

func (l *limiter) applyMask(a netip.Addr) netip.Addr {
switch {
case a.Is4():
m, _ := a.Prefix(l.mask4)
return m.Addr()
case a.Is4In6():
m, _ := netip.AddrFrom4(a.As4()).Prefix(l.mask4)
return m.Addr()
default:
m, _ := a.Prefix(l.mask6)
return m.Addr()
}
}
1 change: 1 addition & 0 deletions plugin/enabled_plugins.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ import (
_ "github.com/IrineSistiana/mosdns/v5/plugin/executable/metrics_collector"
_ "github.com/IrineSistiana/mosdns/v5/plugin/executable/nftset"
_ "github.com/IrineSistiana/mosdns/v5/plugin/executable/query_summary"
_ "github.com/IrineSistiana/mosdns/v5/plugin/executable/rate_limiter"
_ "github.com/IrineSistiana/mosdns/v5/plugin/executable/redirect"
_ "github.com/IrineSistiana/mosdns/v5/plugin/executable/reverse_lookup"
_ "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence"
Expand Down
85 changes: 85 additions & 0 deletions plugin/executable/rate_limiter/rate_limiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/

package rate_limiter

import (
"context"

"github.com/IrineSistiana/mosdns/v5/coremain"
"github.com/IrineSistiana/mosdns/v5/pkg/query_context"
"github.com/IrineSistiana/mosdns/v5/pkg/rate_limiter"
"github.com/IrineSistiana/mosdns/v5/pkg/utils"
"github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence"
"github.com/miekg/dns"
"golang.org/x/time/rate"
)

const PluginType = "rate_limiter"

func init() {
coremain.RegNewPluginFunc(PluginType, Init, func() any { return new(Args) })
}

type Args struct {
Qps float64 `yaml:"qps"`
Burst int `yaml:"burst"`
Mask4 int `yaml:"mask4"`
Mask6 int `yaml:"mask6"`
}

func (args *Args) init() {
utils.SetDefaultUnsignNum(&args.Qps, 20)
utils.SetDefaultUnsignNum(&args.Burst, 40)
utils.SetDefaultUnsignNum(&args.Mask4, 32)
utils.SetDefaultUnsignNum(&args.Mask4, 48)
}

var _ sequence.Executable = (*RateLimiter)(nil)

type RateLimiter struct {
l rate_limiter.RateLimiter
}

func Init(_ *coremain.BP, args any) (any, error) {
return New(*(args.(*Args))), nil
}

func New(args Args) *RateLimiter {
args.init()
l := rate_limiter.NewRateLimiter(rate.Limit(args.Qps), args.Burst, 0, args.Mask4, args.Mask6)
return &RateLimiter{l: l}
}

func (s *RateLimiter) Exec(ctx context.Context, qCtx *query_context.Context) error {
clientAddr := qCtx.QueryMeta().ClientAddr
if clientAddr.IsValid() {
if !s.l.Allow(clientAddr) {
qCtx.SetResponse(refuse(qCtx.Q()))
}
}
return nil
}

func refuse(q *dns.Msg) *dns.Msg {
r := new(dns.Msg)
r.SetReply(q)
r.Rcode = dns.RcodeRefused
return r
}

0 comments on commit fd6cfaa

Please sign in to comment.