tlsrp.go (6622B)
1 package main 2 3 import ( 4 "context" 5 "crypto/tls" 6 "errors" 7 "flag" 8 "fmt" 9 "golang.org/x/sys/unix" 10 "io" 11 "log" 12 "net" 13 "os" 14 "os/signal" 15 "strings" 16 "sync" 17 "time" 18 ) 19 20 // We only enforce a timeout on the handshake. After the handshake is complete, 21 // the sink is responsible for timing-out clients. 22 const handshakeTimeout = 10 * time.Second 23 24 // softExit or hardExit is closed when exiting. During a soft exit, we 25 // accept no new clients, but existing clients should finish gracefully; 26 // during a hard exit, we accept no new clients, and exiting clients 27 // should be forcefully disconnected. Since we have two different exit 28 // modes, we don't use context.Context's to handle cancellation. 29 var softExit chan struct{} 30 var hardExit chan struct{} 31 32 func init() { 33 softExit = make(chan struct{}) 34 hardExit = make(chan struct{}) 35 } 36 37 type socket interface { 38 io.ReadWriteCloser 39 CloseWrite() error 40 } 41 42 func parseSNI(sni string) (hostname, error) { 43 var hostname hostname 44 45 if sni != "" { 46 var err error 47 hostname, err = parseHostname(sni) 48 if err != nil { 49 return nil, err 50 } 51 } 52 53 return hostname, nil 54 } 55 56 func handshake(conn *tls.Conn) error { 57 ctx, cancel := context.WithTimeout(context.Background(), handshakeTimeout) 58 defer cancel() 59 60 handshakeErr := make(chan error, 1) 61 go func() { handshakeErr <- conn.HandshakeContext(ctx) }() 62 63 var err error 64 select { 65 case err = <-handshakeErr: 66 case <-hardExit: 67 } 68 69 return err 70 } 71 72 func splice(a, b socket) error { 73 a2bErr := make(chan error, 1) 74 go func() { 75 _, err := io.Copy(b, a) 76 a2bErr <- err 77 }() 78 79 b2aErr := make(chan error, 1) 80 go func() { 81 _, err := io.Copy(a, b) 82 b2aErr <- err 83 }() 84 85 // In the first two cases, we call CloseWrite (not Close) on the 86 // corresponding destination connection, so that the other copy goroutine 87 // can continue reading from it (until it hits EOF). In the hard exit 88 // case, we want to exit ASAP, so we close both ends of both connections. 89 var err error 90 select { 91 case err = <-a2bErr: 92 b.CloseWrite() 93 <-b2aErr 94 case err = <-b2aErr: 95 a.CloseWrite() 96 <-a2bErr 97 case <-hardExit: 98 a.Close() 99 b.Close() 100 <-a2bErr 101 <-b2aErr 102 } 103 104 return err 105 } 106 107 func proxy(client *tls.Conn) { 108 logf := func(format string, a ...any) { 109 log.Printf("client %s: %s\n", client.RemoteAddr(), fmt.Sprintf(format, a...)) 110 } 111 112 logf("connected") 113 defer logf("disconnected") 114 defer client.Close() 115 116 err := handshake(client) 117 if err != nil { 118 logf("%s", err) 119 return 120 } 121 122 hostname, err := parseSNI(client.ConnectionState().ServerName) 123 if err != nil { 124 logf("rejected: %s", err) 125 return 126 } 127 128 sink, err := lookupSink(hostname) 129 if err != nil { 130 logf("rejected: %s", err) 131 return 132 } 133 134 server, err := net.Dial(sink.network, sink.address) 135 if err != nil { 136 logf("%s", err) 137 return 138 } 139 defer server.Close() 140 141 err = splice(client, server.(socket)) 142 if err != nil { 143 logf("%s", err) 144 return 145 } 146 } 147 148 func accept(listener net.Listener) { 149 logf := func(format string, a ...any) { 150 log.Printf("source %s: %s\n", listener.Addr(), fmt.Sprintf(format, a...)) 151 } 152 153 logf("accepting") 154 defer logf("closing") 155 156 var wg sync.WaitGroup 157 defer wg.Wait() 158 159 for { 160 conn, err := listener.Accept() 161 if err != nil { 162 if !errors.Is(err, net.ErrClosed) { 163 logf("%s", err) 164 os.Exit(1) 165 // XXX: Exiting here might be an over-reaction to the error. 166 // Although, keep in mind that tlsrp should be running under 167 // some service manager that should restart tlsrp if it exits. 168 } 169 return 170 } 171 172 wg.Add(1) 173 go func() { 174 proxy(conn.(*tls.Conn)) 175 wg.Done() 176 }() 177 } 178 } 179 180 func listen(sources []string) ([]net.Listener, error) { 181 tlsConfig := &tls.Config{ 182 GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { 183 hostname, err := parseHostname(chi.ServerName) 184 if err != nil { 185 return nil, err 186 } 187 188 cert, err := lookupCert(hostname) 189 if err != nil { 190 return nil, err 191 } 192 193 return cert.cert, nil 194 }, 195 } 196 197 listeners := make([]net.Listener, 0, len(sources)) 198 for _, source := range sources { 199 fields := strings.SplitN(source, ":", 2) 200 if len(fields) != 2 { 201 return nil, fmt.Errorf("invalid source: expected colon separating network type from address") 202 } 203 204 network := fields[0] 205 address := fields[1] 206 207 switch network { 208 case "tcp", "unix": 209 listener, err := tls.Listen(network, address, tlsConfig) 210 if err != nil { 211 return nil, err 212 } 213 listeners = append(listeners, listener) 214 default: 215 return nil, fmt.Errorf("invalid source: expected network type of \"tcp\" or \"unix\"") 216 } 217 } 218 219 return listeners, nil 220 } 221 222 func manageConfig(cfgPath string) { 223 cfg, err := loadConfig(cfgPath) 224 if err != nil { 225 log.Printf("failed to load initial configuration: %s\n", err) 226 // Proceed with an empty configuration (i.e., with no sinks and no 227 // certs), causing every client to be rejected. 228 cfg = &config{} 229 } 230 231 sighup := make(chan os.Signal, 1) 232 signal.Notify(sighup, unix.SIGHUP) 233 234 for { 235 select { 236 case <-sighup: 237 log.Println("received SIGHUP; reloading configuration") 238 newCfg, err := loadConfig(cfgPath) 239 if err == nil { 240 cfg = newCfg 241 } else { 242 log.Printf("failed to reload configuration: %s\n", err) 243 } 244 245 case msg := <-lookupSinkChan: 246 if msg.hostname != nil { 247 for _, sink := range cfg.sinks { 248 if sink.pattern.matches(msg.hostname) { 249 msg.reply <- sink 250 break 251 } 252 } 253 } else if len(cfg.sinks) > 0 { 254 msg.reply <- cfg.sinks[0] 255 } 256 close(msg.reply) 257 258 case msg := <-lookupCertChan: 259 if msg.hostname != nil { 260 for _, cert := range cfg.certs { 261 if cert.pattern.matches(msg.hostname) { 262 msg.reply <- cert 263 break 264 } 265 } 266 } else if len(cfg.certs) > 0 { 267 msg.reply <- cfg.certs[0] 268 } 269 close(msg.reply) 270 } 271 } 272 } 273 274 func manageExit() { 275 sigs := make(chan os.Signal, 3) 276 signal.Notify(sigs, unix.SIGINT, unix.SIGQUIT, unix.SIGTERM) 277 278 <-sigs 279 log.Println("received SIGINT/SIGQUIT/SIGTERM; exiting softly") 280 close(softExit) 281 282 <-sigs 283 log.Println("received another SIGINT/SIGQUIT/SIGTERM; exiting harshly") 284 close(hardExit) 285 } 286 287 func main() { 288 flag.Usage = func() { 289 fmt.Fprintf(os.Stderr, "Usage: %s CONFIG_PATH SOURCE...", os.Args[0]) 290 } 291 flag.Parse() 292 if flag.NArg() < 2 { 293 log.Fatalln("expected at least 2 arguments") 294 } 295 296 cfgPath := flag.Args()[0] 297 298 listeners, err := listen(flag.Args()[1:]) 299 if err != nil { 300 log.Fatalln(err.Error()) 301 } 302 303 go manageExit() 304 go manageConfig(cfgPath) 305 306 var wg sync.WaitGroup 307 defer wg.Wait() 308 309 for _, l := range listeners { 310 wg.Add(1) 311 go func (l net.Listener) { 312 accept(l) 313 wg.Done() 314 }(l) 315 } 316 317 select { 318 case <-softExit: 319 case <-hardExit: 320 } 321 for _, l := range listeners { 322 l.Close() 323 } 324 }