tlsrp

TLS reverse proxy
git clone git://git.rr3.xyz/tlsrp
Log | Files | Refs | README | LICENSE

commit edd853eceb61e994d9726533072ed45fd8aad88c
parent d1c7a12e46b97d25cd4a67f17c21ef597ec82235
Author: Robert Russell <robertrussell.72001@gmail.com>
Date:   Tue, 16 Jul 2024 14:47:35 -0700

Restructure hostnames and make listeners print their address

The latter is promised in the man page.

Diffstat:
Mtlsrp.go | 152+++++++++++++++++++++++++++++++++++++++++++++++--------------------------------
1 file changed, 91 insertions(+), 61 deletions(-)

diff --git a/tlsrp.go b/tlsrp.go @@ -19,15 +19,10 @@ import ( "time" ) -// TODO: FS-based config -// foo.rr3.xyz -// | _cert -// | _key -// | _unix OR _tcp -// Leading wildcards: -// _.rr3.xyz -// Explicit non-wildcard preferred. -// Just "_" means default for clients with no SNI support. +// TODO: Scrutinize. In particular, compare with sltls. +// TODO: Add support for more than just alternation in patterns, like +// arbitrary regex. +// XXX: We currently don't length check hostnames or the labels within. // We only enforce a timeout on the handshake. After the handshake is complete, // the sink is responsible for timing-out clients. @@ -54,7 +49,7 @@ func lookupSink(hostname hostname) (net.Addr, error) { } sink, ok := <-reply if !ok { - return nil, fmt.Errorf("no sink for hostname %q", hostname) + return nil, fmt.Errorf("no sink for hostname %s", hostname) } return sink, nil } @@ -72,71 +67,98 @@ func lookupCert(hostname hostname) (*tls.Certificate, error) { } cert, ok := <-reply if !ok { - return nil, fmt.Errorf("no certificate for hostname %q", hostname) + return nil, fmt.Errorf("no certificate for hostname %s", hostname) } return cert, nil } -var errHostnameEmpty = errors.New("empty hostname") - -type hostname string +type label string -// XXX: We currently don't length check hostnames or the labels within. -func parseHostname(s string) (hostname, error) { +func parseLabel(s string) (label, error) { if len(s) == 0 { - return "", errHostnameEmpty + return "", fmt.Errorf("empty label") } - + buf := make([]byte, 0, len(s)) - labels := strings.Split(s, ".") - for _, label := range labels { - if len(label) == 0 { - return "", fmt.Errorf("illegal hostname: empty label") + for i, r := range s { + first := i == 0 + last := i == len(s) - 1 + + switch { + case 'A' <= r && r <= 'Z': + r += 'a' - 'A' + case 'a' <= r && r <= 'z': + // Ok + case '0' <= r && r <= '9': + // Ok + case r == '-' && (!first && !last): + // Ok + case r == '-' && first: + return "", fmt.Errorf("hyphen at start of label") + case r == '-' && last: + return "", fmt.Errorf("hyphen at end of label") + default: + return "", fmt.Errorf("illegal rune in label: %q", r) } - if len(buf) > 0 { - buf = append(buf, '.') + buf = append(buf, byte(r)) + } + + return label(buf), nil +} + +type hostname []label + +func (hostname hostname) String() string { + // Ughh, Go can't convert between hostname and []string, + // so we can't use strings.Join. + + var sb strings.Builder + for i, label := range hostname { + if i > 0 { + sb.WriteByte('.') } + sb.WriteString(string(label)) + } - for i, r := range label { - first := i == 0 - last := i == len(label) - 1 - switch { - case 'A' <= r && r <= 'Z': - r += 'a' - 'A' - case 'a' <= r && r <= 'z': - // Ok - case '0' <= r && r <= '9': - // Ok - case r == '-' && (!first && !last): - // Ok - case r == '-' && first: - return "", fmt.Errorf("illegal hostname: hyphen at start of label") - case r == '-' && last: - return "", fmt.Errorf("illegal hostname: hyphen at end of label") - default: - return "", fmt.Errorf("illegal hostname: illegal rune: %q", r) - } - buf = append(buf, byte(r)) + return sb.String() +} + +func (hostname0 hostname) equal(hostname1 hostname) bool { + return slices.Equal(hostname0, hostname1) +} + +func parseHostname(s string) (hostname, error) { + if len(s) == 0 { + return nil, nil + } + + labelStrs := strings.Split(s, ".") + labels := make([]label, 0, len(labelStrs)) + + for _, labelStr := range labelStrs { + label, err := parseLabel(labelStr) + if err != nil { + return nil, fmt.Errorf("illegal hostname: %w", err) } + labels = append(labels, label) } - return hostname(buf), nil + return hostname(labels), nil } -// TODO: Add support for more than just alternation in patterns. type pattern []hostname -func (pattern pattern) matches(hostname hostname) bool { - return slices.Contains(pattern, hostname) +func (pat pattern) matches(hostname hostname) bool { + return slices.ContainsFunc(pat, hostname.equal) } -func parsePattern(ss []string) (pattern, error) { - pat := make(pattern, 0, len(ss)) +func parsePattern(hostnameStrs []string) (pattern, error) { + pat := make(pattern, 0, len(hostnameStrs)) - for _, s := range ss { - hostname, err := parseHostname(s) + for _, hostnameStr := range hostnameStrs { + hostname, err := parseHostname(hostnameStr) if err != nil { return nil, err } @@ -274,7 +296,7 @@ func manageConfig(cfgPath string) { } case msg := <-lookupSinkChan: - if msg.hostname != "" { + if msg.hostname != nil { for _, sink := range cfg.sinks { if sink.pattern.matches(msg.hostname) { msg.reply <- sink @@ -287,7 +309,7 @@ func manageConfig(cfgPath string) { close(msg.reply) case msg := <-lookupCertChan: - if msg.hostname != "" { + if msg.hostname != nil { for _, cert := range cfg.certs { if cert.pattern.matches(msg.hostname) { msg.reply <- cert.cert @@ -359,7 +381,7 @@ func splice(a, b conn) error { } func proxy(client *tls.Conn) { - logf := func(format string, a ...interface{}) { + logf := func(format string, a ...any) { log.Printf("client %s: %s\n", client.RemoteAddr(), fmt.Sprintf(format, a...)) } @@ -369,12 +391,12 @@ func proxy(client *tls.Conn) { err := handshake(client) if err != nil { - logf("handshake error: %s", err) + logf("%s", err) return } hostname, err := parseHostname(client.ConnectionState().ServerName) - if err != nil && !errors.Is(err, errHostnameEmpty) { + if err != nil { logf("rejected: %s", err) return } @@ -387,14 +409,14 @@ func proxy(client *tls.Conn) { sink, err := net.Dial(sinkAddr.Network(), sinkAddr.String()) if err != nil { - logf("dial error: %s", err) + logf("%s", err) return } defer sink.Close() err = splice(client, sink.(conn)) if err != nil { - logf("splice error: %s", err) + logf("%s", err) return } } @@ -403,11 +425,19 @@ func accept(listener net.Listener) { var wg sync.WaitGroup defer wg.Wait() + logf := func(format string, a ...any) { + log.Printf("source %s: %s\n", listener.Addr(), fmt.Sprintf(format, a...)) + } + + logf("accepting") + defer logf("closing") + for { conn, err := listener.Accept() if err != nil { if !errors.Is(err, net.ErrClosed) { - log.Fatalf("source %s error: %s\n", listener.Addr(), err) + logf("%s", err) + os.Exit(1) } return } @@ -449,7 +479,7 @@ func listen(sources []string) ([]net.Listener, error) { tlsConfig := &tls.Config{ GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { hostname, err := parseHostname(chi.ServerName) - if err != nil && !errors.Is(err, errHostnameEmpty) { + if err != nil { return nil, err } return lookupCert(hostname)