forked from bwalex/tftp-http-proxy
-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.go
131 lines (114 loc) · 3.65 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
package main
import (
"flag"
"fmt"
"io"
"log"
"net/http"
"net/url"
"strings"
"time"
"github.com/pin/tftp"
systemd "github.com/coreos/go-systemd/daemon"
)
// Default values for the command-line flags registered in main.
const (
	// httpBaseUrlDefault is used when -http-base-url is not given.
	httpBaseUrlDefault = "http://127.0.0.1/tftp"
	// tftpTimeoutDefault is used when -tftp-timeout is not given.
	tftpTimeoutDefault = 5 * time.Second
	// tftpBindAddrDefault is used when -tftp-bind-address is not given:
	// port 69 with an empty host part.
	tftpBindAddrDefault = ":69"
	// appendPathDefault is used when -http-append-path is not given, i.e. the
	// TFTP filename is not appended to the base URL by default.
	appendPathDefault = false
)
var globalState = struct {
httpBaseUrl string
httpClient *http.Client
appendPath bool
}{
httpBaseUrl: httpBaseUrlDefault,
httpClient: nil,
appendPath: appendPathDefault,
}
// tftpReadHandler serves one TFTP read request by issuing an HTTP GET and
// streaming the response body to the TFTP client.
//
// The GET goes to globalState.httpBaseUrl. When globalState.appendPath is
// set, filename (with leading slashes stripped) is appended to it. The peer
// address and the requested filename are forwarded in the X-TFTP-IP,
// X-TFTP-Port and X-TFTP-File request headers. If the response carries a
// Content-Length, it is passed to SetSize (the TFTP tsize option).
//
// It returns nil once rf.ReadFrom has consumed the body without error. It
// returns an error in these cases:
//   - the request cannot be built or sent;
//   - the server answers 404 ("File not found") or any other non-200 status;
//   - the transfer itself fails.
//
// NOTE(review): the returned error is presumably relayed to the TFTP client
// by the tftp library. Confirm against the tftp package in use.
func tftpReadHandler(filename string, rf io.ReaderFrom) error {
	// The transfer object also exposes the peer address and the tsize
	// setter. Use the checked assertion so an unexpected implementation
	// fails this request instead of panicking.
	ot, ok := rf.(tftp.OutgoingTransfer)
	if !ok {
		log.Printf("ERR: TFTP transfer %T does not implement tftp.OutgoingTransfer", rf)
		return fmt.Errorf("internal error: unsupported transfer type %T", rf)
	}
	raddr := ot.RemoteAddr() // net.UDPAddr

	// filename is client-controlled. %q escapes control characters, so a
	// crafted name cannot inject fake lines into the log.
	log.Printf("INFO: New TFTP request (%q) from %s", filename, raddr.IP.String())

	uri := globalState.httpBaseUrl
	if globalState.appendPath {
		// parseBaseURL already verified that the base has a scheme and a
		// host and made it end in "/". http.NewRequest validates the final
		// URL with url.Parse.
		// NOTE(review): filename is appended verbatim. Any "?", "#", "%xx"
		// or ".." sequences in it reach URL parsing and the HTTP server
		// unescaped, so the backend must treat the path as untrusted input.
		uri += strings.TrimLeft(filename, "/")
	}

	req, err := http.NewRequest(http.MethodGet, uri, nil)
	if err != nil {
		log.Printf("ERR: http request setup failed: %v", err)
		return err
	}
	req.Header.Add("X-TFTP-IP", raddr.IP.String())
	req.Header.Add("X-TFTP-Port", fmt.Sprintf("%d", raddr.Port))
	req.Header.Add("X-TFTP-File", filename)

	resp, err := globalState.httpClient.Do(req)
	if err != nil {
		log.Printf("ERR: http request failed: %v", err)
		return err
	}
	defer resp.Body.Close()

	switch resp.StatusCode {
	case http.StatusOK:
		// Success: continue with the transfer below.
	case http.StatusNotFound:
		log.Printf("INFO: http FileNotFound response: %s", resp.Status)
		// Message text kept unchanged; it may be visible to TFTP clients.
		return fmt.Errorf("File not found")
	default:
		log.Printf("ERR: http request returned status %s", resp.Status)
		return fmt.Errorf("HTTP request error: %s", resp.Status)
	}

	// ContentLength is negative when the server did not provide a length;
	// only advertise a transfer size when one is known.
	if resp.ContentLength >= 0 {
		ot.SetSize(resp.ContentLength)
	}

	if _, err := rf.ReadFrom(resp.Body); err != nil {
		log.Printf("ERR: ReadFrom failed: %v", err)
		return err
	}
	return nil
}
// parseBaseURL validates the -http-base-url value and returns its
// re-serialized form (url.URL.String).
//
// The value must parse with url.ParseRequestURI and must carry both a scheme
// and a host. Otherwise the problem is logged and the function panics via
// log.Panicf. When appendPath is true, the returned string always ends in "/",
// so a filename can be concatenated directly onto it.
func parseBaseURL(baseUrl string, appendPath bool) string {
	parsed, err := url.ParseRequestURI(baseUrl)
	// Each failing case panics, so later cases are only reached when the
	// earlier checks passed (and parsed is non-nil).
	switch {
	case err != nil:
		log.Panicf("FATAL: invalid base URL: %v\n", err)
	case parsed.Scheme == "":
		log.Panicf("FATAL: invalid base URL: No scheme found.\n")
	case parsed.Host == "":
		log.Panicf("FATAL: invalid base URL: No host found.\n")
	}

	normalized := parsed.String()
	if !appendPath || strings.HasSuffix(normalized, "/") {
		return normalized
	}
	return normalized + "/"
}
// main parses the command-line flags, fills in globalState, and runs the TFTP
// server until it returns an error (which is fatal). Once the listener is up,
// it sends READY=1 via systemd's notify mechanism and logs whether the
// notification was delivered.
func main() {
	httpBaseUrlPtr := flag.String("http-base-url", httpBaseUrlDefault, "HTTP base URL")
	appendPathPtr := flag.Bool("http-append-path", appendPathDefault, "append TFTP filename to URL")
	httpTimeoutPtr := flag.Duration("http-timeout", 0,
		"overall time limit per HTTP request, including streaming the body to the TFTP client (0 = no limit)")
	tftpTimeoutPtr := flag.Duration("tftp-timeout", tftpTimeoutDefault, "TFTP timeout")
	bindAddrPtr := flag.String("tftp-bind-address", tftpBindAddrDefault, "TFTP addr to bind to")
	flag.Parse()

	globalState.httpBaseUrl = parseBaseURL(*httpBaseUrlPtr, *appendPathPtr)
	// Without a timeout, an unresponsive HTTP backend blocks the request
	// handler indefinitely. http.Client.Timeout bounds the whole exchange,
	// including reading the body. The body is consumed as the TFTP transfer
	// progresses, so a non-zero limit must allow for the largest file served.
	// The default of 0 keeps the previous unbounded behavior.
	globalState.httpClient = &http.Client{Timeout: *httpTimeoutPtr}
	globalState.appendPath = *appendPathPtr

	s := tftp.NewServer(tftpReadHandler, nil)
	s.SetTimeout(*tftpTimeoutPtr)
	err := s.ListenAndServe2(*bindAddrPtr, func() {
		log.Printf("INFO: Listening TFTP requests on: %s", *bindAddrPtr)
		sent, err := systemd.SdNotify(true, "READY=1\n")
		switch {
		case err != nil:
			log.Printf("WARN: Unable to send systemd daemon successful start message: %v\n", err)
		case sent:
			log.Printf("DEBUG: Systemd was notified.\n")
		default:
			log.Printf("DEBUG: Systemd notifications are not supported.\n")
		}
	})
	if err != nil {
		log.Panicf("FATAL: tftp server: %v\n", err)
	}
}