tlsrp

TLS reverse proxy
git clone git://git.rr3.xyz/tlsrp
Log | Files | Refs | README | LICENSE

tlsrp.go (6622B)


      1 package main
      2 
      3 import (
      4 	"context"
      5 	"crypto/tls"
      6 	"errors"
      7 	"flag"
      8 	"fmt"
      9 	"golang.org/x/sys/unix"
     10 	"io"
     11 	"log"
     12 	"net"
     13 	"os"
     14 	"os/signal"
     15 	"strings"
     16 	"sync"
     17 	"time"
     18 )
     19 
     20 // We only enforce a timeout on the handshake. After the handshake is complete,
     21 // the sink is responsible for timing-out clients.
     22 const handshakeTimeout = 10 * time.Second
     23 
     24 // softExit or hardExit is closed when exiting. During a soft exit, we
     25 // accept no new clients, but existing clients should finish gracefully;
     26 // during a hard exit, we accept no new clients, and exiting clients
     27 // should be forcefully disconnected. Since we have two different exit
     28 // modes, we don't use context.Context's to handle cancellation.
     29 var softExit chan struct{}
     30 var hardExit chan struct{}
     31 
     32 func init() {
     33 	softExit = make(chan struct{})
     34 	hardExit = make(chan struct{})
     35 }
     36 
     37 type socket interface {
     38 	io.ReadWriteCloser
     39 	CloseWrite() error
     40 }
     41 
     42 func parseSNI(sni string) (hostname, error) {
     43 	var hostname hostname
     44 
     45 	if sni != "" {
     46 		var err error
     47 		hostname, err = parseHostname(sni)
     48 		if err != nil {
     49 			return nil, err
     50 		}
     51 	}
     52 
     53 	return hostname, nil
     54 }
     55 
     56 func handshake(conn *tls.Conn) error {
     57 	ctx, cancel := context.WithTimeout(context.Background(), handshakeTimeout)
     58 	defer cancel()
     59 
     60 	handshakeErr := make(chan error, 1)
     61 	go func() { handshakeErr <- conn.HandshakeContext(ctx) }()
     62 
     63 	var err error
     64 	select {
     65 	case err = <-handshakeErr:
     66 	case <-hardExit:
     67 	}
     68 
     69 	return err
     70 }
     71 
     72 func splice(a, b socket) error {
     73 	a2bErr := make(chan error, 1)
     74 	go func() {
     75 		_, err := io.Copy(b, a)
     76 		a2bErr <- err
     77 	}()
     78 
     79 	b2aErr := make(chan error, 1)
     80 	go func() {
     81 		_, err := io.Copy(a, b)
     82 		b2aErr <- err
     83 	}()
     84 
     85 	// In the first two cases, we call CloseWrite (not Close) on the
     86 	// corresponding destination connection, so that the other copy goroutine
     87 	// can continue reading from it (until it hits EOF). In the hard exit
     88 	// case, we want to exit ASAP, so we close both ends of both connections.
     89 	var err error
     90 	select {
     91 	case err = <-a2bErr:
     92 		b.CloseWrite()
     93 		<-b2aErr
     94 	case err = <-b2aErr:
     95 		a.CloseWrite()
     96 		<-a2bErr
     97 	case <-hardExit:
     98 		a.Close()
     99 		b.Close()
    100 		<-a2bErr
    101 		<-b2aErr
    102 	}
    103 
    104 	return err
    105 }
    106 
    107 func proxy(client *tls.Conn) {
    108 	logf := func(format string, a ...any) {
    109 		log.Printf("client %s: %s\n", client.RemoteAddr(), fmt.Sprintf(format, a...))
    110 	}
    111 
    112 	logf("connected")
    113 	defer logf("disconnected")
    114 	defer client.Close()
    115 
    116 	err := handshake(client)
    117 	if err != nil {
    118 		logf("%s", err)
    119 		return
    120 	}
    121 
    122 	hostname, err := parseSNI(client.ConnectionState().ServerName)
    123 	if err != nil {
    124 		logf("rejected: %s", err)
    125 		return
    126 	}
    127 
    128 	sink, err := lookupSink(hostname)
    129 	if err != nil {
    130 		logf("rejected: %s", err)
    131 		return
    132 	}
    133 
    134 	server, err := net.Dial(sink.network, sink.address)
    135 	if err != nil {
    136 		logf("%s", err)
    137 		return
    138 	}
    139 	defer server.Close()
    140 
    141 	err = splice(client, server.(socket))
    142 	if err != nil {
    143 		logf("%s", err)
    144 		return
    145 	}
    146 }
    147 
    148 func accept(listener net.Listener) {
    149 	logf := func(format string, a ...any) {
    150 		log.Printf("source %s: %s\n", listener.Addr(), fmt.Sprintf(format, a...))
    151 	}
    152 
    153 	logf("accepting")
    154 	defer logf("closing")
    155 
    156 	var wg sync.WaitGroup
    157 	defer wg.Wait()
    158 
    159 	for {
    160 		conn, err := listener.Accept()
    161 		if err != nil {
    162 			if !errors.Is(err, net.ErrClosed) {
    163 				logf("%s", err)
    164 				os.Exit(1)
    165 				// XXX: Exiting here might be an over-reaction to the error.
    166 				// Although, keep in mind that tlsrp should be running under
    167 				// some service manager that should restart tlsrp if it exits.
    168 			}
    169 			return
    170 		}
    171 
    172 		wg.Add(1)
    173 		go func() {
    174 			proxy(conn.(*tls.Conn))
    175 			wg.Done()
    176 		}()
    177 	}
    178 }
    179 
    180 func listen(sources []string) ([]net.Listener, error) {
    181 	tlsConfig := &tls.Config{
    182 		GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
    183 			hostname, err := parseHostname(chi.ServerName)
    184 			if err != nil {
    185 				return nil, err
    186 			}
    187 
    188 			cert, err := lookupCert(hostname)
    189 			if err != nil {
    190 				return nil, err
    191 			}
    192 
    193 			return cert.cert, nil
    194 		},
    195 	}
    196 
    197 	listeners := make([]net.Listener, 0, len(sources))
    198 	for _, source := range sources {
    199 		fields := strings.SplitN(source, ":", 2)
    200 		if len(fields) != 2 {
    201 			return nil, fmt.Errorf("invalid source: expected colon separating network type from address")
    202 		}
    203 
    204 		network := fields[0]
    205 		address := fields[1]
    206 
    207 		switch network {
    208 		case "tcp", "unix":
    209 			listener, err := tls.Listen(network, address, tlsConfig)
    210 			if err != nil {
    211 				return nil, err
    212 			}
    213 			listeners = append(listeners, listener)
    214 		default:
    215 			return nil, fmt.Errorf("invalid source: expected network type of \"tcp\" or \"unix\"")
    216 		}
    217 	}
    218 
    219 	return listeners, nil
    220 }
    221 
    222 func manageConfig(cfgPath string) {
    223 	cfg, err := loadConfig(cfgPath)
    224 	if err != nil {
    225 		log.Printf("failed to load initial configuration: %s\n", err)
    226 		// Proceed with an empty configuration (i.e., with no sinks and no
    227 		// certs), causing every client to be rejected.
    228 		cfg = &config{}
    229 	}
    230 
    231 	sighup := make(chan os.Signal, 1)
    232 	signal.Notify(sighup, unix.SIGHUP)
    233 
    234 	for {
    235 		select {
    236 		case <-sighup:
    237 			log.Println("received SIGHUP; reloading configuration")
    238 			newCfg, err := loadConfig(cfgPath)
    239 			if err == nil {
    240 				cfg = newCfg
    241 			} else {
    242 				log.Printf("failed to reload configuration: %s\n", err)
    243 			}
    244 
    245 		case msg := <-lookupSinkChan:
    246 			if msg.hostname != nil {
    247 				for _, sink := range cfg.sinks {
    248 					if sink.pattern.matches(msg.hostname) {
    249 						msg.reply <- sink
    250 						break
    251 					}
    252 				}
    253 			} else if len(cfg.sinks) > 0 {
    254 				msg.reply <- cfg.sinks[0]
    255 			}
    256 			close(msg.reply)
    257 
    258 		case msg := <-lookupCertChan:
    259 			if msg.hostname != nil {
    260 				for _, cert := range cfg.certs {
    261 					if cert.pattern.matches(msg.hostname) {
    262 						msg.reply <- cert
    263 						break
    264 					}
    265 				}
    266 			} else if len(cfg.certs) > 0 {
    267 				msg.reply <- cfg.certs[0]
    268 			}
    269 			close(msg.reply)
    270 		}
    271 	}
    272 }
    273 
    274 func manageExit() {
    275 	sigs := make(chan os.Signal, 3)
    276 	signal.Notify(sigs, unix.SIGINT, unix.SIGQUIT, unix.SIGTERM)
    277 
    278 	<-sigs
    279 	log.Println("received SIGINT/SIGQUIT/SIGTERM; exiting softly")
    280 	close(softExit)
    281 
    282 	<-sigs
    283 	log.Println("received another SIGINT/SIGQUIT/SIGTERM; exiting harshly")
    284 	close(hardExit)
    285 }
    286 
    287 func main() {
    288 	flag.Usage = func() {
    289 		fmt.Fprintf(os.Stderr, "Usage: %s CONFIG_PATH SOURCE...", os.Args[0])
    290 	}
    291 	flag.Parse()
    292 	if flag.NArg() < 2 {
    293 		log.Fatalln("expected at least 2 arguments")
    294 	}
    295 
    296 	cfgPath := flag.Args()[0]
    297 
    298 	listeners, err := listen(flag.Args()[1:])
    299 	if err != nil {
    300 		log.Fatalln(err.Error())
    301 	}
    302 
    303 	go manageExit()
    304 	go manageConfig(cfgPath)
    305 
    306 	var wg sync.WaitGroup
    307 	defer wg.Wait()
    308 
    309 	for _, l := range listeners {
    310 		wg.Add(1)
    311 		go func (l net.Listener) {
    312 			accept(l)
    313 			wg.Done()
    314 		}(l)
    315 	}
    316 
    317 	select {
    318 	case <-softExit:
    319 	case <-hardExit:
    320 	}
    321 	for _, l := range listeners {
    322 		l.Close()
    323 	}
    324 }