diff --git a/Dockerfile b/Dockerfile index 78ea19c..09ee1ae 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,8 +5,8 @@ WORKDIR /go/src/github.com/lyft/cni-ipvlan-vpc-k8s/ RUN go get github.com/golang/dep && \ go install github.com/golang/dep/cmd/dep && \ - go get -u gopkg.in/alecthomas/gometalinter.v1 && \ - gometalinter.v1 --install + go get -u gopkg.in/alecthomas/gometalinter.v2 && \ + gometalinter.v2 --install COPY . /go/src/github.com/lyft/cni-ipvlan-vpc-k8s/ diff --git a/Makefile b/Makefile index a06cfc5..3019105 100644 --- a/Makefile +++ b/Makefile @@ -20,25 +20,27 @@ cache: .PHONY: lint lint: - gometalinter.v1 --disable-all \ - --enable=golint --enable=staticcheck --enable=gosimple --enable=unused --enable=gas \ + gometalinter.v2 --disable-all \ + --enable=golint --enable=megacheck --enable=gas \ --enable=gofmt \ - --deadline=10m --vendor ./... + --deadline=10m --vendor ./... \ + --exclude="Errors unhandled.*" \ + --enable-gc .PHONY: test test: dep cache lint ifndef GOOS - go test -v ./aws/... ./nl ./cmd/cni-ipvlan-vpc-k8s-tool + go test -v ./aws/... ./nl ./cmd/cni-ipvlan-vpc-k8s-tool ./lib/... else @echo Tests not available when cross-compiling endif .PHONY: build build: dep cache - go build -o $(NAME)-ipam ./plugin/ipam/main.go - go build -o $(NAME)-ipvlan ./plugin/ipvlan/ipvlan.go - go build -o $(NAME)-unnumbered-ptp ./plugin/unnumbered-ptp/unnumbered-ptp.go - go build -ldflags "-X main.version=$(VERSION)" -o $(NAME)-tool ./cmd/cni-ipvlan-vpc-k8s-tool/cni-ipvlan-vpc-k8s-tool.go + go build -i -o $(NAME)-ipam ./plugin/ipam/main.go + go build -i -o $(NAME)-ipvlan ./plugin/ipvlan/ipvlan.go + go build -i -o $(NAME)-unnumbered-ptp ./plugin/unnumbered-ptp/unnumbered-ptp.go + go build -i -ldflags "-X main.version=$(VERSION)" -o $(NAME)-tool ./cmd/cni-ipvlan-vpc-k8s-tool/cni-ipvlan-vpc-k8s-tool.go tar cvzf cni-ipvlan-vpc-k8s-$(VERSION).tar.gz $(NAME)-ipam $(NAME)-ipvlan $(NAME)-unnumbered-ptp $(NAME)-tool @@ -58,7 +60,7 @@ interactive-docker: test-docker ci: go get -u github.com/golang/dep/cmd/dep go install github.com/golang/dep/cmd/dep - go get -u gopkg.in/alecthomas/gometalinter.v1 - gometalinter.v1 --install + go get -u gopkg.in/alecthomas/gometalinter.v2 + gometalinter.v2 --install $(MAKE) all diff --git a/aws/cache/cacheable.go b/aws/cache/cacheable.go index a33c888..c8d445c 100644 --- a/aws/cache/cacheable.go +++ b/aws/cache/cacheable.go @@ -6,6 +6,8 @@ import ( "os" "path" "time" + + "github.com/lyft/cni-ipvlan-vpc-k8s/lib" ) const ( @@ -36,35 +38,10 @@ func cachePath() string { return path.Join(cacheRoot, cacheProgram) } -// JSONTime is a RFC3339 encoded time with JSON marshallers -type JSONTime struct { - time.Time -} - -// MarshalJSON marshals a JSONTime to an RFC3339 string -func (j *JSONTime) MarshalJSON() ([]byte, error) { - return json.Marshal(j.Time.Format(time.RFC3339)) -} - -// UnmarshalJSON unmarshals a JSONTime to a time.Time -func (j *JSONTime) UnmarshalJSON(js []byte) error { - var rawString string - err := json.Unmarshal(js, &rawString) - if err != nil { - return err - } - t, err := time.Parse(time.RFC3339, rawString) - if err != nil { - return err - } - j.Time = t - return nil -} - // Cacheable defines metadata for objects which can be cached to files as JSON type Cacheable struct { - Expires JSONTime `json:"_expires"` - Contents interface{} `json":contents"` + Expires lib.JSONTime `json:"_expires"` + Contents interface{} `json":contents"` } func ensureDirectory() error { diff --git a/aws/cache/cacheable_test.go b/aws/cache/cacheable_test.go index 87fea72..8579943 100644 --- a/aws/cache/cacheable_test.go +++ b/aws/cache/cacheable_test.go @@ -1,44 +1,15 @@ package cache import ( - "encoding/json" "testing" "time" ) -type Foo struct { - TheTime JSONTime `json:"time"` -} - type Thing struct { Value string `json:"value"` TValue int } -func TestJSONTime_MarshalJSON(t *testing.T) { - input := Foo{JSONTime{time.Date(2017, 1, 1, 1, 1, 0, 0, time.UTC)}} - output, err := json.Marshal(&input) - if err != nil { - t.Error(err) - } - if string(output) != `{"time":"2017-01-01T01:01:00Z"}` { - t.Error(string(output)) - } -} - -func TestJSONTime_UnmarshalJSON(t *testing.T) { - input := []byte(`{"time":"2017-01-01T01:01:00Z"}`) - var foo Foo - err := json.Unmarshal(input, &foo) - if err != nil { - t.Error(err) - } - expected := Foo{JSONTime{time.Date(2017, 1, 1, 1, 1, 0, 0, time.UTC)}} - if !foo.TheTime.Time.Equal(expected.TheTime.Time) { - t.Errorf("Times were not equal: %v %v %v", foo.TheTime.Time, expected.TheTime.Time, foo) - } -} - func TestGet(t *testing.T) { var d Thing state := Get("key_not_exist", &d) diff --git a/aws/client_test.go b/aws/client_test.go index e16ec9c..34cb27d 100644 --- a/aws/client_test.go +++ b/aws/client_test.go @@ -25,7 +25,7 @@ func TestClientCreate(t *testing.T) { } client2, err := defaultClient.newEC2() - if client != client2 { + if client != client2 || err != nil { t.Errorf("Clients returned were not identical (no caching)") } diff --git a/aws/vpc.go b/aws/vpc.go index bb2aaaf..8db805f 100644 --- a/aws/vpc.go +++ b/aws/vpc.go @@ -101,7 +101,7 @@ func (v *vpcclient) DescribeVPCPeerCIDRs(vpcID string) ([]*net.IPNet, error) { // and visible to the API, even if the CIDR is not active in // one of the peered VPCs. We store all of the CIDRs in a map // to de-duplicate them. - cidrs := make(map[string]bool, 0) + cidrs := make(map[string]bool) for _, peering := range res.VpcPeeringConnections { var peer *ec2.VpcPeeringConnectionVpcInfo diff --git a/cmd/cni-ipvlan-vpc-k8s-tool/cni-ipvlan-vpc-k8s-tool.go b/cmd/cni-ipvlan-vpc-k8s-tool/cni-ipvlan-vpc-k8s-tool.go index 7483916..7438b62 100644 --- a/cmd/cni-ipvlan-vpc-k8s-tool/cni-ipvlan-vpc-k8s-tool.go +++ b/cmd/cni-ipvlan-vpc-k8s-tool/cni-ipvlan-vpc-k8s-tool.go @@ -6,12 +6,15 @@ import ( "os" "strings" "text/tabwriter" + "time" "github.com/urfave/cli" - "github.com/lyft/cni-ipvlan-vpc-k8s" "github.com/lyft/cni-ipvlan-vpc-k8s/aws" + "github.com/lyft/cni-ipvlan-vpc-k8s/lib" + "github.com/lyft/cni-ipvlan-vpc-k8s/lib/freeip" "github.com/lyft/cni-ipvlan-vpc-k8s/nl" + "github.com/lyft/cni-ipvlan-vpc-k8s/registry" ) var version string @@ -22,7 +25,7 @@ func filterBuild(input string) (map[string]string, error) { return nil, nil } - ret := make(map[string]string, 0) + ret := make(map[string]string) tuples := strings.Split(input, ",") for _, t := range tuples { kv := strings.Split(t, "=") @@ -40,7 +43,7 @@ func filterBuild(input string) (map[string]string, error) { } func actionNewInterface(c *cli.Context) error { - return cniipvlanvpck8s.LockfileRun(func() error { + return lib.LockfileRun(func() error { filtersRaw := c.String("subnet_filter") filters, err := filterBuild(filtersRaw) if err != nil { @@ -66,7 +69,7 @@ func actionNewInterface(c *cli.Context) error { } func actionBugs(c *cli.Context) error { - return cniipvlanvpck8s.LockfileRun(func() error { + return lib.LockfileRun(func() error { w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0) fmt.Fprintln(w, "bug\tafflicted\t") for _, bug := range aws.ListBugs(aws.DefaultClient) { @@ -78,7 +81,7 @@ func actionBugs(c *cli.Context) error { } func actionRemoveInterface(c *cli.Context) error { - return cniipvlanvpck8s.LockfileRun(func() error { + return lib.LockfileRun(func() error { interfaces := c.Args() if len(interfaces) <= 0 { @@ -96,7 +99,7 @@ func actionRemoveInterface(c *cli.Context) error { } func actionDeallocate(c *cli.Context) error { - return cniipvlanvpck8s.LockfileRun(func() error { + return lib.LockfileRun(func() error { releaseIps := c.Args() for _, toRelease := range releaseIps { @@ -122,7 +125,7 @@ func actionDeallocate(c *cli.Context) error { } func actionAllocate(c *cli.Context) error { - return cniipvlanvpck8s.LockfileRun(func() error { + return lib.LockfileRun(func() error { index := c.Int("index") res, err := aws.DefaultClient.AllocateIPFirstAvailableAtIndex(index) if err != nil { @@ -137,7 +140,7 @@ func actionAllocate(c *cli.Context) error { } func actionFreeIps(c *cli.Context) error { - ips, err := cniipvlanvpck8s.FindFreeIPsAtIndex(0) + ips, err := freeip.FindFreeIPsAtIndex(0) if err != nil { fmt.Println(err) return err @@ -270,6 +273,60 @@ func actionSubnets(c *cli.Context) error { return nil } +func actionRegistryList(c *cli.Context) error { + return lib.LockfileRun(func() error { + + reg := ®istry.Registry{} + ips, err := reg.List() + if err != nil { + return err + } + w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0) + fmt.Fprintln(w, "ip\t") + for _, ip := range ips { + fmt.Fprintf(w, "%v\t\n", + ip) + } + w.Flush() + return nil + }) +} + +func actionRegistryGc(c *cli.Context) error { + return lib.LockfileRun(func() error { + + reg := ®istry.Registry{} + freeAfter := c.Duration("free-after") + if freeAfter <= 0*time.Second { + fmt.Fprintf(os.Stderr, + "Invalid duration specified. free-after must be > 0 seconds. Got %v. Please specify with --free-minutes=[time]\n", freeAfter) + return fmt.Errorf("invalid duration") + } + + // Insert free-after jitter of 15% of the period + freeAfter = registry.Jitter(freeAfter, 0.15) + + // Invert free-after + freeAfter *= -1 + + ips, err := reg.TrackedBefore(time.Now().Add(freeAfter)) + if err != nil { + fmt.Fprintln(os.Stderr, err) + return err + } + for _, ip := range ips { + err := aws.DefaultClient.DeallocateIP(&ip) + if err == nil { + reg.ForgetIP(ip) + } else { + fmt.Fprintf(os.Stderr, "Can't deallocate %v due to %v", ip, err) + } + } + + return nil + }) +} + func main() { if !aws.DefaultClient.Available() { fmt.Fprintln(os.Stderr, "This command must be run from a running ec2 instance") @@ -355,6 +412,20 @@ func main() { Usage: "Show the peered VPC CIDRs associated with current interfaces", Action: actionVpcPeerCidr, }, + { + Name: "registry-list", + Usage: "List all known free IPs in the internal registry", + Action: actionRegistryList, + }, + { + Name: "registry-gc", + Usage: "Free all IPs that have remained unused for a given time interval", + Action: actionRegistryGc, + Flags: []cli.Flag{ + cli.DurationFlag{Name: "free-after", + Value: 0 * time.Second}, + }, + }, } app.Version = version app.Copyright = "(c) 2017-2018 Lyft Inc." diff --git a/freeip.go b/lib/freeip/freeip.go similarity index 97% rename from freeip.go rename to lib/freeip/freeip.go index acd93d8..b47f918 100644 --- a/freeip.go +++ b/lib/freeip/freeip.go @@ -1,4 +1,4 @@ -package cniipvlanvpck8s +package freeip import ( "github.com/lyft/cni-ipvlan-vpc-k8s/aws" diff --git a/lib/jsontime.go b/lib/jsontime.go new file mode 100644 index 0000000..90a02c5 --- /dev/null +++ b/lib/jsontime.go @@ -0,0 +1,31 @@ +package lib + +import ( + "encoding/json" + "time" +) + +// JSONTime is a RFC3339 encoded time with JSON marshallers +type JSONTime struct { + time.Time +} + +// MarshalJSON marshals a JSONTime to an RFC3339 string +func (j *JSONTime) MarshalJSON() ([]byte, error) { + return json.Marshal(j.Time.Format(time.RFC3339)) +} + +// UnmarshalJSON unmarshals a JSONTime to a time.Time +func (j *JSONTime) UnmarshalJSON(js []byte) error { + var rawString string + err := json.Unmarshal(js, &rawString) + if err != nil { + return err + } + t, err := time.Parse(time.RFC3339, rawString) + if err != nil { + return err + } + j.Time = t + return nil +} diff --git a/lib/jsontime_test.go b/lib/jsontime_test.go new file mode 100644 index 0000000..be699fb --- /dev/null +++ b/lib/jsontime_test.go @@ -0,0 +1,35 @@ +package lib + +import ( + "encoding/json" + "testing" + "time" +) + +type Foo struct { + TheTime JSONTime `json:"time"` +} + +func TestJSONTime_MarshalJSON(t *testing.T) { + input := Foo{JSONTime{time.Date(2017, 1, 1, 1, 1, 0, 0, time.UTC)}} + output, err := json.Marshal(&input) + if err != nil { + t.Error(err) + } + if string(output) != `{"time":"2017-01-01T01:01:00Z"}` { + t.Error(string(output)) + } +} + +func TestJSONTime_UnmarshalJSON(t *testing.T) { + input := []byte(`{"time":"2017-01-01T01:01:00Z"}`) + var foo Foo + err := json.Unmarshal(input, &foo) + if err != nil { + t.Error(err) + } + expected := Foo{JSONTime{time.Date(2017, 1, 1, 1, 1, 0, 0, time.UTC)}} + if !foo.TheTime.Time.Equal(expected.TheTime.Time) { + t.Errorf("Times were not equal: %v %v %v", foo.TheTime.Time, expected.TheTime.Time, foo) + } +} diff --git a/lock.go b/lib/lock.go similarity index 96% rename from lock.go rename to lib/lock.go index 17d551d..234598f 100644 --- a/lock.go +++ b/lib/lock.go @@ -1,4 +1,4 @@ -package cniipvlanvpck8s +package lib import ( "fmt" diff --git a/nl/desc.go b/nl/desc.go index 187294d..48229f9 100644 --- a/nl/desc.go +++ b/nl/desc.go @@ -50,12 +50,7 @@ func GetIPs() ([]BoundIP, error) { var namespaces []string files, err := ioutil.ReadDir("/var/run/netns/") - if err != nil { - if err != nil { - // Ignore this error - } - files = []os.FileInfo{} - } else { + if err == nil { for _, file := range files { namespaces = append(namespaces, filepath.Join("/var/run/netns", file.Name())) diff --git a/nl/down_test.go b/nl/down_test.go index a3126bb..0e8bafb 100644 --- a/nl/down_test.go +++ b/nl/down_test.go @@ -13,7 +13,7 @@ func TestDownInterface(t *testing.T) { return } - CreateTestInterface("lyft3") + CreateTestInterface(t, "lyft3") defer RemoveInterface("lyft3") if err := UpInterface("lyft3"); err != nil { @@ -31,7 +31,7 @@ func TestRemoveInterface(t *testing.T) { return } - CreateTestInterface("lyft4") + CreateTestInterface(t, "lyft4") if err := RemoveInterface("lyft4"); err != nil { t.Fatalf("Failed to RemoveInterface lyft4: %v", err) diff --git a/nl/up_test.go b/nl/up_test.go index 3ebc06a..13ea128 100644 --- a/nl/up_test.go +++ b/nl/up_test.go @@ -1,14 +1,13 @@ package nl import ( - "fmt" "os" "testing" "github.com/vishvananda/netlink" ) -func CreateTestInterface(name string) { +func CreateTestInterface(t *testing.T, name string) { lyftBridge := &netlink.Bridge{LinkAttrs: netlink.LinkAttrs{ TxQLen: -1, Name: name, @@ -17,7 +16,7 @@ func CreateTestInterface(name string) { err := netlink.LinkAdd(lyftBridge) if err != nil { RemoveInterface(name) - fmt.Errorf("Could not add %s: %v", lyftBridge.Name, err) + t.Errorf("Could not add %s: %v", lyftBridge.Name, err) } lyft1, _ := netlink.LinkByName(name) @@ -30,7 +29,7 @@ func TestUpInterface(t *testing.T) { return } - CreateTestInterface("lyft1") + CreateTestInterface(t, "lyft1") defer RemoveInterface("lyft1") if err := UpInterface("lyft1"); err != nil { @@ -44,7 +43,7 @@ func TestUpInterfacePoll(t *testing.T) { return } - CreateTestInterface("lyft2") + CreateTestInterface(t, "lyft2") defer RemoveInterface("lyft2") if err := UpInterfacePoll("lyft2"); err != nil { diff --git a/plugin/ipam/main.go b/plugin/ipam/main.go index 6fcb900..456ba44 100644 --- a/plugin/ipam/main.go +++ b/plugin/ipam/main.go @@ -30,9 +30,11 @@ import ( "github.com/containernetworking/plugins/pkg/ns" "github.com/vishvananda/netlink" - "github.com/lyft/cni-ipvlan-vpc-k8s" "github.com/lyft/cni-ipvlan-vpc-k8s/aws" + "github.com/lyft/cni-ipvlan-vpc-k8s/lib" + "github.com/lyft/cni-ipvlan-vpc-k8s/lib/freeip" "github.com/lyft/cni-ipvlan-vpc-k8s/nl" + "github.com/lyft/cni-ipvlan-vpc-k8s/registry" ) // PluginConf contains configuration parameters @@ -87,8 +89,13 @@ func cmdAdd(args *skel.CmdArgs) error { var alloc *aws.AllocationResult // Try to find a free IP first - possibly from a broken container, // or torn down namespace. - free, err := cniipvlanvpck8s.FindFreeIPsAtIndex(conf.IPAM.IfaceIndex) + free, err := freeip.FindFreeIPsAtIndex(conf.IPAM.IfaceIndex) if err == nil && len(free) > 0 { + // Since we found this IP in the free list, remove it from + // the registry + + registry := ®istry.Registry{} + registry.ForgetIP(*free[0].IP) alloc = free[0] } else { // allocate an IP on an available interface @@ -185,7 +192,7 @@ func cmdDel(args *skel.CmdArgs) error { var addrs []netlink.Addr // enter the namespace to grab the list of IPs - err = ns.WithNetNSPath(args.Netns, func(_ ns.NetNS) error { + _ = ns.WithNetNSPath(args.Netns, func(_ ns.NetNS) error { iface, err := netlink.LinkByName(args.IfName) if err != nil { return err @@ -200,6 +207,13 @@ func cmdDel(args *skel.CmdArgs) error { aws.DefaultClient.DeallocateIP(&addr.IP) } } + + // Mark this IP as free in the registry + registry := ®istry.Registry{} + for _, addr := range addrs { + registry.TrackIP(addr.IP) + } + return nil } @@ -208,5 +222,5 @@ func main() { skel.PluginMain(cmdAdd, cmdDel, version.PluginSupports(version.Current())) return nil } - _ = cniipvlanvpck8s.LockfileRun(run) + _ = lib.LockfileRun(run) } diff --git a/plugin/unnumbered-ptp/unnumbered-ptp.go b/plugin/unnumbered-ptp/unnumbered-ptp.go index 8a779c1..0aac912 100644 --- a/plugin/unnumbered-ptp/unnumbered-ptp.go +++ b/plugin/unnumbered-ptp/unnumbered-ptp.go @@ -157,7 +157,7 @@ func findFreeTable(start int) (int, error) { } // find first slot that's available for both V4 and V6 usage for i := start; i < math.MaxUint32; i++ { - if allocatedTableIDs[i] == false { + if !allocatedTableIDs[i] { return i, nil } } @@ -471,7 +471,7 @@ func cmdDel(args *skel.CmdArgs) error { // If the device isn't there then don't try to clean up IP masq either. var ipnets []netlink.Addr vethPeerIndex := -1 - err = ns.WithNetNSPath(args.Netns, func(_ ns.NetNS) error { + _ = ns.WithNetNSPath(args.Netns, func(_ ns.NetNS) error { var err error // lookup pod IPs from the args.IfName device (usually eth0) @@ -494,7 +494,7 @@ func cmdDel(args *skel.CmdArgs) error { if err != nil && err != ip.ErrLinkNotFound { return err } - vethPeerIndex, err = netlink.VethPeerIndex(&netlink.Veth{LinkAttrs: *vethIface.Attrs()}) + vethPeerIndex, _ = netlink.VethPeerIndex(&netlink.Veth{LinkAttrs: *vethIface.Attrs()}) return nil }) @@ -507,7 +507,7 @@ func cmdDel(args *skel.CmdArgs) error { addrBits = 32 } - err = ip.TeardownIPMasq(&net.IPNet{IP: ipn.IP, Mask: net.CIDRMask(addrBits, addrBits)}, chain, comment) + _ = ip.TeardownIPMasq(&net.IPNet{IP: ipn.IP, Mask: net.CIDRMask(addrBits, addrBits)}, chain, comment) } if vethPeerIndex != -1 { diff --git a/registry/registry.go b/registry/registry.go new file mode 100644 index 0000000..be6ade3 --- /dev/null +++ b/registry/registry.go @@ -0,0 +1,244 @@ +package registry + +import ( + "encoding/json" + "fmt" + "log" + "math/rand" + "net" + "os" + "path" + "sync" + "time" + + "github.com/lyft/cni-ipvlan-vpc-k8s/lib" +) + +const ( + registryDir = "cni-ipvlan-vpc-k8s" + registryFile = "registry.json" + registrySchemaVersion = 1 +) + +func defaultRegistry() registryContents { + return registryContents{ + SchemaVersion: registrySchemaVersion, + IPs: map[string]*registryIP{}, + } +} + +type registryIP struct { + ReleasedOn lib.JSONTime `json:"released_on"` +} + +type registryContents struct { + SchemaVersion int `json:"schema_version"` + IPs map[string]*registryIP `json:"ips"` +} + +// Registry defines a re-usable IP registry which tracks IPs that are +// free in the system and when they were last released back to the pool. +type Registry struct { + path string + lock sync.Mutex +} + +// registryPath gives a default location for the registry +// which varies based on invoking user ID +func registryPath() string { + uid := os.Getuid() + if uid != 0 { + // Non-root users of the registry + return path.Join("/run/user", fmt.Sprintf("%d", uid), registryDir) + } + + return path.Join("/var/lib", registryDir) +} + +func (r *Registry) ensurePath() (string, error) { + if len(r.path) == 0 { + r.path = registryPath() + } + rpath := r.path + info, err := os.Stat(rpath) + if err != nil || !info.IsDir() { + err = os.MkdirAll(rpath, os.ModeDir|0700) + if err != nil { + return "", err + } + } + rpath = path.Join(rpath, registryFile) + return rpath, nil +} + +func (r *Registry) load() (*registryContents, error) { + // Load the pre-versioned schema + contents := defaultRegistry() + rpath, err := r.ensurePath() + if err != nil { + return nil, err + } + + file, err := os.Open(rpath) + if os.IsNotExist(err) { + // Return an empty registry + return &contents, nil + } else if err != nil { + return nil, err + } + + defer file.Close() + + decoder := json.NewDecoder(file) + if decoder == nil { + return nil, fmt.Errorf("invalid decoder") + } + + err = decoder.Decode(&contents) + if err != nil { + log.Printf("invalid registry format, returning empty registry %v", err) + contents = defaultRegistry() + } + + // Reset the registry if the version is not what we can deal with. + // Add more states here, or just let the registry be blown away + // on invalid loads + if contents.SchemaVersion != registrySchemaVersion { + contents = defaultRegistry() + } + if contents.IPs == nil { + contents = defaultRegistry() + } + return &contents, nil +} + +func (r *Registry) save(rc *registryContents) error { + rpath, err := r.ensurePath() + if err != nil { + return err + } + file, err := os.OpenFile(rpath, os.O_RDWR|os.O_CREATE, 0600) + if err != nil { + return err + } + defer file.Close() + + encoder := json.NewEncoder(file) + if encoder == nil { + return fmt.Errorf("could not make a new encoder") + } + rc.SchemaVersion = registrySchemaVersion + err = encoder.Encode(rc) + return err +} + +// TrackIP records an IP in the free registry with the current system +// time as the current freed-time. If an IP is freed again, the time +// will be updated to the new current time. +func (r *Registry) TrackIP(ip net.IP) error { + r.lock.Lock() + defer r.lock.Unlock() + + contents, err := r.load() + if err != nil { + return err + } + + contents.IPs[ip.String()] = ®istryIP{lib.JSONTime{time.Now()}} + return r.save(contents) +} + +// ForgetIP removes an IP from the registry +func (r *Registry) ForgetIP(ip net.IP) error { + r.lock.Lock() + defer r.lock.Unlock() + + contents, err := r.load() + if err != nil { + return err + } + + delete(contents.IPs, ip.String()) + + return r.save(contents) +} + +// HasIP checks if an IP is in an registry +func (r *Registry) HasIP(ip net.IP) (bool, error) { + r.lock.Lock() + defer r.lock.Unlock() + + contents, err := r.load() + if err != nil { + return false, err + } + + _, ok := contents.IPs[ip.String()] + return ok, nil +} + +// TrackedBefore returns a list of all IPs last recorded time _before_ +// the time passed to this function. You probably want to call this with +// time.Now().Add(-duration) +func (r *Registry) TrackedBefore(t time.Time) ([]net.IP, error) { + r.lock.Lock() + defer r.lock.Unlock() + + contents, err := r.load() + if err != nil { + return nil, err + } + + returned := []net.IP{} + for ipString, entry := range contents.IPs { + if entry.ReleasedOn.Before(t) { + ip := net.ParseIP(ipString) + if ip == nil { + continue + } + returned = append(returned, ip) + } + } + return returned, nil +} + +// Clear clears the registry unconditionally +func (r *Registry) Clear() error { + r.lock.Lock() + defer r.lock.Unlock() + + rpath, err := r.ensurePath() + if err != nil { + return err + } + + err = os.Remove(rpath) + if os.IsNotExist(err) { + return nil + } + + return err +} + +// List returns a list of all tracked IPs +func (r *Registry) List() (ret []net.IP, err error) { + r.lock.Lock() + defer r.lock.Unlock() + + contents, err := r.load() + if err != nil { + return nil, err + } + for stringIP := range contents.IPs { + ret = append(ret, net.ParseIP(stringIP)) + } + return +} + +// Jitter takes a duration and adjusts it forward by a up to `pct` percent +// uniformly. +func Jitter(d time.Duration, pct float64) time.Duration { + jitter := rand.Int63n(int64(float64(d) * pct)) + d += time.Duration(jitter) + return d +} diff --git a/registry/registry_test.go b/registry/registry_test.go new file mode 100644 index 0000000..2015132 --- /dev/null +++ b/registry/registry_test.go @@ -0,0 +1,121 @@ +package registry + +import ( + "net" + "testing" + "time" +) + +const ( + IP1 = "127.0.0.1" + IP2 = "127.0.0.2" + IP3 = "127.0.0.3" +) + +func TestRegistry_TrackIp(t *testing.T) { + r := &Registry{} + + err := r.TrackIP(net.ParseIP(IP1)) + if err != nil { + t.Fatalf("Failed to track IP %v", err) + } + + if ok, err := r.HasIP(net.ParseIP(IP1)); !ok || err != nil { + t.Fatalf("Did not remember IP %v %v %v", IP1, ok, err) + } +} + +func TestRegistry_Clear(t *testing.T) { + r := &Registry{} + + err := r.Clear() + if err != nil { + t.Fatalf("clear failed %v", err) + } + + r.TrackIP(net.ParseIP(IP3)) + if ok, err := r.HasIP(net.ParseIP(IP3)); !ok || err != nil { + t.Fatalf("Did not remember IP %v %v %v", IP3, ok, err) + } + + err = r.Clear() + if err != nil { + t.Fatalf("clear failed %v", err) + } + + if ok, err := r.HasIP(net.ParseIP(IP3)); ok || err != nil { + t.Fatalf("Did not forget IP %v %v %v", IP3, ok, err) + } +} + +func TestRegistry_ForgetIP(t *testing.T) { + r := &Registry{} + + err := r.Clear() + if err != nil { + t.Fatalf("clear failed %v", err) + } + + r.TrackIP(net.ParseIP(IP2)) + if ok, err := r.HasIP(net.ParseIP(IP2)); !ok || err != nil { + t.Fatalf("Did not remember IP %v %v %v", IP3, ok, err) + } + + err = r.ForgetIP(net.ParseIP(IP2)) + if err != nil { + t.Fatalf("forget failed %v", err) + } + + if ok, err := r.HasIP(net.ParseIP(IP2)); ok || err != nil { + t.Fatalf("Did not forget IP %v %v %v", IP2, ok, err) + } + + // Forget an IP never registered + err = r.ForgetIP(net.ParseIP(IP1)) + if err != nil { + t.Fatalf("forgetting an IP not tracked should not be an error") + } + +} + +func TestRegistry_TrackedBefore(t *testing.T) { + r := &Registry{} + + err := r.Clear() + if err != nil { + t.Fatalf("clear failed %v", err) + } + + r.TrackIP(net.ParseIP(IP1)) + now := time.Now() + + before, err := r.TrackedBefore(now.Add(100 * time.Hour)) + if err != nil { + t.Fatalf("error tracked before %v", err) + } + + if len(before) != 1 { + t.Fatalf("Invalid number of entries, got %v", before) + } + + after, err := r.TrackedBefore(now.Add(-100 * time.Hour)) + if err != nil { + t.Fatalf("error tracked before %v", err) + } + + if len(after) != 0 { + t.Fatalf("Should return no IPs before the future, got %v", after) + } +} + +func TestJitter(t *testing.T) { + d1 := 1 * time.Second + d1p := Jitter(d1, 0.10) + if d1 >= d1p { + t.Fatalf("Jitter did not move forward: %v %v", d1, d1p) + } + + if d1p > 1101*time.Millisecond { + t.Fatalf("Jitter moved more than 10pct forward %v", d1p) + } +}