package dns

import (
	"context"
	"errors"
	"fmt"
	"net"
	"strings"

	"github.com/containers/gvisor-tap-vsock/pkg/types"
	"github.com/miekg/dns"
	log "github.com/sirupsen/logrus"
)

// dnsHandler serves DNS requests. Names under one of its zones are answered
// from that zone's records; other queries are looked up through a
// net.Resolver (see addAnswers).
type dnsHandler struct {
	// zones are the locally served zones. addAnswers checks them in order
	// and the first zone whose suffix matches a query name handles it.
	zones []types.Zone
}

// handle answers a single DNS request r. It builds a reply with SetReply,
// advertises recursion, fills the answers via addAnswers and writes the
// reply back through w. handle has no error result, so a write failure
// can only be logged.
func (h *dnsHandler) handle(w dns.ResponseWriter, r *dns.Msg) {
	reply := &dns.Msg{}
	reply.SetReply(r)
	reply.RecursionAvailable = true
	h.addAnswers(reply)

	err := w.WriteMsg(reply)
	if err == nil {
		return
	}
	log.Error(err)
}

// addAnswers fills m with answers for each question it carries. m is
// expected to already hold the request's questions (handle prepares it
// with SetReply).
//
// A question whose name ends in "."+zone.Name for one of h.zones is answered
// locally by addLocalAnswers, and the first matching zone wins. Any other
// question goes to addUpstreamAnswers, which uses a net.Resolver. Questions
// are processed independently, so a failure on one does not stop the
// others. Results are appended to m.Answer; failures are reported in
// m.Rcode.
func (h *dnsHandler) addAnswers(m *dns.Msg) {
	// Built once per message instead of once per question.
	resolver := &net.Resolver{PreferGo: false}
	for _, q := range m.Question {
		if h.addLocalAnswers(m, q) {
			continue
		}
		h.addUpstreamAnswers(m, q, resolver)
	}
}

// addLocalAnswers tries to answer q from h.zones and reports whether a zone
// claimed it. The first zone whose suffix "."+zone.Name matches the end of
// q.Name claims q. The comparison ignores case because DNS names are
// case-insensitive. Within that zone:
//   - non-A queries are left without an answer;
//   - an A query is answered by the first record whose Name equals the
//     remaining labels (ignoring case) or whose Regexp matches them;
//   - failing that, by zone.DefaultIP when it is set;
//   - failing that, m.Rcode is set to NXDOMAIN.
//
// When no zone matches, m is left untouched and false is returned.
func (h *dnsHandler) addLocalAnswers(m *dns.Msg, q dns.Question) bool {
	for _, zone := range h.zones {
		zoneSuffix := fmt.Sprintf(".%s", zone.Name)
		cut := len(q.Name) - len(zoneSuffix)
		if cut < 0 || !strings.EqualFold(q.Name[cut:], zoneSuffix) {
			continue
		}
		if q.Qtype != dns.TypeA {
			return true
		}
		// Slice the original name rather than a lowercased copy, so Regexp
		// records see exactly the labels the client sent.
		withoutZone := q.Name[:cut]
		for _, record := range zone.Records {
			if (record.Name != "" && strings.EqualFold(record.Name, withoutZone)) ||
				(record.Regexp != nil && record.Regexp.MatchString(withoutZone)) {
				h.appendA(m, q.Name, record.IP)
				return true
			}
		}
		// An empty (or nil) DefaultIP means the zone has no catch-all address.
		if len(zone.DefaultIP) > 0 {
			h.appendA(m, q.Name, zone.DefaultIP)
			return true
		}
		m.Rcode = dns.RcodeNameError
		return true
	}
	return false
}

// addUpstreamAnswers resolves q through resolver. Only NS and A queries are
// looked up; other query types get no answer. When a lookup fails, m.Rcode
// is set from lookupRcode.
func (h *dnsHandler) addUpstreamAnswers(m *dns.Msg, q dns.Question, resolver *net.Resolver) {
	switch q.Qtype {
	case dns.TypeNS:
		records, err := resolver.LookupNS(context.TODO(), q.Name)
		if err != nil {
			m.Rcode = h.lookupRcode(err)
			return
		}
		for _, ns := range records {
			m.Answer = append(m.Answer, &dns.NS{
				Hdr: dns.RR_Header{
					Name:   q.Name,
					Rrtype: dns.TypeNS,
					Class:  dns.ClassINET,
					Ttl:    0,
				},
				Ns: ns.Host,
			})
		}
	case dns.TypeA:
		ips, err := resolver.LookupIPAddr(context.TODO(), q.Name)
		if err != nil {
			m.Rcode = h.lookupRcode(err)
			return
		}
		for _, ip := range ips {
			// LookupIPAddr may return IPv6 addresses too; only IPv4 fits an A record.
			if v4 := ip.IP.To4(); v4 != nil {
				h.appendA(m, q.Name, v4)
			}
		}
	}
}

// appendA appends to m an A record that maps name to ip, with a TTL of 0.
func (h *dnsHandler) appendA(m *dns.Msg, name string, ip net.IP) {
	m.Answer = append(m.Answer, &dns.A{
		Hdr: dns.RR_Header{
			Name:   name,
			Rrtype: dns.TypeA,
			Class:  dns.ClassINET,
			Ttl:    0,
		},
		A: ip,
	})
}

// lookupRcode maps a resolver error to a DNS response code. NXDOMAIN is
// returned only when the resolver reports that the name was not found. Any
// other failure (timeout, unreachable upstream, ...) yields SERVFAIL, so
// clients do not cache a transient error as proof that the name is absent.
func (h *dnsHandler) lookupRcode(err error) int {
	var dnsErr *net.DNSError
	if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
		return dns.RcodeNameError
	}
	return dns.RcodeServerFailure
}

// Serve answers DNS queries received on udpConn, using zones for locally
// served names. Every query name is routed to a dnsHandler through a
// catch-all "." mux pattern. Serve blocks and returns whatever error
// ActivateAndServe returns.
func Serve(udpConn net.PacketConn, zones []types.Zone) error {
	h := &dnsHandler{zones: zones}

	mux := dns.NewServeMux()
	mux.HandleFunc(".", h.handle)

	return (&dns.Server{Handler: mux, PacketConn: udpConn}).ActivateAndServe()
}
