Commenting

This commit is contained in:
Mael G. 2020-12-14 18:20:24 -04:00
parent 3459498605
commit 98537616c8
9 changed files with 77 additions and 44 deletions

View file

@ -2,14 +2,15 @@ package core
import "github.com/miekg/dns" import "github.com/miekg/dns"
//Handle the DNS request //Handle the DNS request using miekg/dns
//Requires dns.ReponseWriter and dns.Msg args
func HandleDnsRequest(w dns.ResponseWriter, r *dns.Msg) { func HandleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
//dns.Msg object //dns.Msg object
//Will be passed to the parseQuery() function //Will be passed to the parseQuery() function
m := new(dns.Msg) m := new(dns.Msg)
m.SetReply(r) m.SetReply(r)
m.Compress = true //Less CPU usage (?) m.Compress = false
if r.Opcode == dns.OpcodeQuery { //Only respond to dns queries if r.Opcode == dns.OpcodeQuery { //Only respond to dns queries
parseQuery(m) parseQuery(m)

View file

@ -19,20 +19,21 @@ import (
*/ */
//Function called by handleDnsRequest to parse the query from records //Function called by handleDnsRequest to parse the query from records
//Requires dns.ReponseWriter args
func parseQuery(m *dns.Msg) { func parseQuery(m *dns.Msg) {
for _, q := range m.Question { for _, q := range m.Question {
log.Infof("DNS : Query for %s (type : %v)\n", q.Name, q.Qtype) //Log log.Infof("DNS : Query for %s (type : %v)\n", q.Name, q.Qtype) //Log
record := utils.GetRecord(utils.Record{Fqdn: q.Name, Qtype: q.Qtype}) record := utils.GetRecord(utils.Record{Fqdn: q.Name, Qtype: q.Qtype}) //Get the record in the SQL or Redis database
if record.Content != "" { //If the record is found, return it if record.Content != "" { //If the record is not empty
log.Infof("DNS : Record found for '%s' => '%s'\n", q.Name, record.Content) log.Infof("DNS : Record found for '%s' => '%s'\n", q.Name, record.Content) //Log the content as INFO
rr, err := dns.NewRR(fmt.Sprintf("%s %v %s %s", q.Name, record.TTL, dns.TypeToString[q.Qtype], record.Content)) //Create the response rr, err := dns.NewRR(fmt.Sprintf("%s %v %s %s", q.Name, record.TTL, dns.TypeToString[q.Qtype], record.Content)) //Create the response
if err == nil { //If no err if err == nil { //If no err
m.Answer = append(m.Answer, rr) m.Answer = append(m.Answer, rr) //Append the record to the response
} }
} else { } else { //If the record is empty log it as DEBUG
logrus.Debugf("DNS : No record for '%s' (type '%v')\n", record.Fqdn, record.Qtype) logrus.Debugf("DNS : No record for '%s' (type '%v')\n", record.Fqdn, record.Qtype)
} }

19
main.go
View file

@ -18,30 +18,31 @@ var DB *sql.DB
//Main loop //Main loop
func main() { func main() {
//Get config patch //Get the config patch from --config flag
configPatch := flag.String("config", "extra/config.ini.example", "the patch to the config file") configPatch := flag.String("config", "extra/config.ini.example", "the patch to the config file")
flag.Parse() flag.Parse()
//Load Configuration //Load the INI configuration file
conf = new(utils.Conf) conf = new(utils.Conf)
err := ini.MapTo(conf, *configPatch) err := ini.MapTo(conf, *configPatch)
utils.CheckErr(err) utils.CheckErr(err)
//Set up the Logrus logger
utils.InitLogger(conf) utils.InitLogger(conf)
// attach request handler func //Attach DNS request handler func for all domains
dns.HandleFunc(".", core.HandleDnsRequest) dns.HandleFunc(".", core.HandleDnsRequest)
//Init redis database //Initialize the redis database
utils.RedisDatabase(conf) utils.RedisDatabase(conf)
//Init sql database //Initialize the sql database
utils.SqlDatabase(conf) utils.SqlDatabase(conf)
// start server //Start the DNS server
server := &dns.Server{Addr: conf.App.Ip + strconv.Itoa(conf.App.Port), Net: "udp"} //define the server server := &dns.Server{Addr: conf.App.Ip + strconv.Itoa(conf.App.Port), Net: "udp"} //define the server
logrus.WithFields(logrus.Fields{"ip": conf.App.Ip, "port": conf.App.Port}).Infof("SERVER : Started") logrus.WithFields(logrus.Fields{"ip": conf.App.Ip, "port": conf.App.Port}).Infof("SERVER : Started") //log
err = server.ListenAndServe() //start it err = server.ListenAndServe() //start it
utils.CheckErr(err) utils.CheckErr(err)
defer server.Shutdown() //shut down on application closing defer server.Shutdown() //shut down on application closing

View file

@ -8,28 +8,34 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
//Check the SQL and REDIS database for a Record.
//A Record struct is used as input and output
func GetRecord(entry Record) Record { func GetRecord(entry Record) Record {
//Check for strict record in Redis cache //Check for strict record in Redis cache
redisKey := entry.Fqdn + "--" + fmt.Sprint(entry.Qtype) redisKey := entry.Fqdn + "--" + fmt.Sprint(entry.Qtype)
result, redisErr := redisCheckForRecord(redisKey, entry) result, redisErr := redisCheckForRecord(redisKey, entry)
var sqlErr int var sqlErr int //The err returned for sqlCheckForRecord or sqlCheckForReverse6Wildcard
//If reverse DNS //If reverse DNS
reverseCheck := IsReverse(entry.Fqdn) reverseCheck := IsReverse(entry.Fqdn)
if reverseCheck > 0 { if reverseCheck > 0 {
//If reverse record not found in redis
if redisErr == redis.Nil { if redisErr == redis.Nil {
//Check for it in the SQL database
logrus.Debug("QUERIES : Check for strict reverse in MySQL") logrus.Debug("QUERIES : Check for strict reverse in MySQL")
result, sqlErr = sqlCheckForRecord(redisKey, entry.Fqdn, entry) result, sqlErr = sqlCheckForRecord(redisKey, entry.Fqdn, entry)
if sqlErr == 1 { if sqlErr == 1 {
//Check for wildcard reverse in the SQL
logrus.Debug("QUERIES : Check for wildcard reverse in MySQL") logrus.Debug("QUERIES : Check for wildcard reverse in MySQL")
result, _ = sqlCheckForReverse6Wildcard(redisKey, entry.Fqdn, entry) result, _ = sqlCheckForReverse6Wildcard(redisKey, entry.Fqdn, entry)
} }
} }
//For dynamic reverse dns //For dynamic reverse dns
//Check for it by looking for a "%s" in the record content
//If true, replace it with the formated IP
if strings.Contains(result.Content, "%s") { if strings.Contains(result.Content, "%s") {
record := ExtractAddressFromReverse(entry.Fqdn) record := ExtractAddressFromReverse(entry.Fqdn)
var recordFormated string var recordFormated string

View file

@ -10,17 +10,22 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
//Redis context
var ctx = context.Background() var ctx = context.Background()
//Redis client as global var
var redisDb *redis.Client var redisDb *redis.Client
//Initialize the Redis Database
//Requires a conf struct
//Return a *redis.Client
func RedisDatabase(conf *Conf) *redis.Client { func RedisDatabase(conf *Conf) *redis.Client {
logrus.WithFields(logrus.Fields{"ip": conf.Redis.Ip, "port": conf.Redis.Port}).Infof("REDIS : Connection to DB") logrus.WithFields(logrus.Fields{"ip": conf.Redis.Ip, "port": conf.Redis.Port}).Infof("REDIS : Connection to DB")
rdb := redis.NewClient(&redis.Options{ rdb := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%v", conf.Redis.Ip, conf.Redis.Port), Addr: fmt.Sprintf("%s:%v", conf.Redis.Ip, conf.Redis.Port),
Password: conf.Redis.Password, Password: conf.Redis.Password,
DB: conf.Redis.Db, DB: conf.Redis.Db,
}) }) //Connect to the DB
//Test Redis connection //Test Redis connection
err := rdb.Set(ctx, "alive", 1, 0).Err() err := rdb.Set(ctx, "alive", 1, 0).Err()
@ -38,6 +43,9 @@ func RedisDatabase(conf *Conf) *redis.Client {
return rdb return rdb
} }
//Check for a record in the Redis database
//Requires the redis key (as string) and the record to check (struct)
//Return a Record (struct) and error (if any)
func redisCheckForRecord(redisKey string, entry Record) (Record, error) { func redisCheckForRecord(redisKey string, entry Record) (Record, error) {
val, err := redisDb.Get(ctx, redisKey).Result() val, err := redisDb.Get(ctx, redisKey).Result()
@ -52,6 +60,8 @@ func redisCheckForRecord(redisKey string, entry Record) (Record, error) {
} }
} }
//Add a record in the Redis database
//Return an error (if any)
func redisSet(c *redis.Client, key string, ttl time.Duration, value interface{}) error { func redisSet(c *redis.Client, key string, ttl time.Duration, value interface{}) error {
p, err := json.Marshal(value) p, err := json.Marshal(value)
if err != nil { if err != nil {

View file

@ -7,7 +7,8 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// Based on github.com/coredns/coredns/plugin/pkg/dnsutil/reverse.go // DISCLAIMER : Based on https://github.com/coredns/coredns/blob/master/plugin/pkg/dnsutil/reverse.go
// DISCLAIMER : Will be rewrited from scratch in future release
// IsReverse returns 0 is name is not in a reverse zone. Anything > 0 indicates // IsReverse returns 0 is name is not in a reverse zone. Anything > 0 indicates
// name is in a reverse zone. The returned integer will be 1 for in-addr.arpa. (IPv4) // name is in a reverse zone. The returned integer will be 1 for in-addr.arpa. (IPv4)

View file

@ -10,24 +10,27 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
//SQL database as global var
var DB *sql.DB var DB *sql.DB
//Initialize the (My)SQL Database
//Requires a conf struct
func SqlDatabase(conf *Conf) { func SqlDatabase(conf *Conf) {
logrus.WithFields(logrus.Fields{"database": conf.Database.Db}).Infof("SQL : Connection to DB") logrus.WithFields(logrus.Fields{"database": conf.Database.Db}).Infof("SQL : Connection to DB")
//db conn //Connect to the Database
db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s:%s)/%s", conf.Database.Username, conf.Database.Password, conf.Database.Ip, conf.Database.Port, conf.Database.Db)) db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s:%s)/%s", conf.Database.Username, conf.Database.Password, conf.Database.Ip, conf.Database.Port, conf.Database.Db))
CheckErr(err) CheckErr(err)
DB = db DB = db
SqlTest() //Test SQL conn SqlTest() //Test SQL connexion
// if there is an error opening the connection, handle it
CheckErr(err)
} }
//Test the SQL connexion by selecting all records from the database
func SqlTest() { func SqlTest() {
_, err := DB.Query("SELECT name, content FROM records") _, err := DB.Query("SELECT name, content FROM records")
CheckErr(err) CheckErr(err) //Panic if any error
} }
//Check for a record in the SQL database
func sqlCheckForRecord(redisKey string, dKey string, entry Record) (Record, int) { func sqlCheckForRecord(redisKey string, dKey string, entry Record) (Record, int) {
dbg := DB.QueryRow( dbg := DB.QueryRow(
"SELECT id, content, ttl FROM records WHERE `name` = ? AND `type` = ?;", dKey, entry.Qtype).Scan( "SELECT id, content, ttl FROM records WHERE `name` = ? AND `type` = ?;", dKey, entry.Qtype).Scan(
@ -36,35 +39,36 @@ func sqlCheckForRecord(redisKey string, dKey string, entry Record) (Record, int)
&entry.TTL, &entry.TTL,
) )
//logrus.WithFields(logrus.Fields{"name": dKey, "type": entry.Qtype}).Debugf("SQL : ") if dbg != nil { //If any err
if dbg != nil {
logrus.Debugf("SQL : %v", dbg) logrus.Debugf("SQL : %v", dbg)
} }
logrus.Debugf("SQL : %s => %s", entry.Fqdn, entry.Content) logrus.Debugf("SQL : %s => %s", entry.Fqdn, entry.Content) //log the result
if entry.Content != "" { if entry.Content != "" { //If Record content not empty
//Cache the request in Redis if any result //Cache the request in Redis if any result
logrus.Debugf("REDIS : Set entry for %s", redisKey) logrus.Debugf("REDIS : Set entry for %s", redisKey)
_ = redisSet(redisDb, redisKey, 30*time.Second, entry) _ = redisSet(redisDb, redisKey, 30*time.Second, entry) //Set it in the Redis database for 30sec
return entry, 0 return entry, 0
} else { } else {
//Else return nil //Else return 1 for err
return entry, 1 return entry, 1
} }
} }
//Check for a wildcard record in the SQL database
func sqlCheckForReverse6Wildcard(redisKey string, dKey string, entry Record) (Record, error) { func sqlCheckForReverse6Wildcard(redisKey string, dKey string, entry Record) (Record, error) {
returnedEntry := entry returnedEntry := entry
results, err := DB.Query("SELECT id, content, name FROM records WHERE name LIKE '*%.ip6.arpa.';") results, err := DB.Query("SELECT id, content, name FROM records WHERE name LIKE '*%.ip6.arpa.';") //Get ALL reverse IPs
DbgErr(err) DbgErr(err) //Check for empty row or non important error
//For each result check if it match the reverse IP
for results.Next() { for results.Next() {
err = results.Scan(&returnedEntry.Id, &returnedEntry.Content, &returnedEntry.Fqdn) err = results.Scan(&returnedEntry.Id, &returnedEntry.Content, &returnedEntry.Fqdn)
CheckErr(err) CheckErr(err)
//Check if the record is matching the reversed IP
if checkReverse6(entry, returnedEntry) { if checkReverse6(entry, returnedEntry) {
logrus.Debug("REVERSE : Correct wildcard reverse.") logrus.Debug("REVERSE : Correct wildcard reverse.")
//Cache the request in Redis if any result //Cache the request in Redis if any result

View file

@ -1,6 +1,6 @@
package utils package utils
//Structs for configuration //Struct for App (dns server) configuration in the config.ini file
type App struct { type App struct {
Port int Port int
Ip string Ip string
@ -8,6 +8,7 @@ type App struct {
Logfile bool Logfile bool
} }
//Struct for SQL Database configuration in the config.ini file
type Database struct { type Database struct {
Ip string Ip string
Port string Port string
@ -16,6 +17,7 @@ type Database struct {
Db string Db string
} }
//Struct for Redis Database configuration in the config.ini file
type Redis struct { type Redis struct {
Ip string Ip string
Port int Port int
@ -24,6 +26,7 @@ type Redis struct {
Ttl int Ttl int
} }
//Struct for the whole config.ini file when it will be parsed by go-ini
type Conf struct { type Conf struct {
App_mode string App_mode string
App App
@ -31,6 +34,7 @@ type Conf struct {
Redis Redis
} }
//Struct for a Domain (not used currently).
type Domain struct { type Domain struct {
ID int `json:"id"` ID int `json:"id"`
FriendlyName string FriendlyName string
@ -39,6 +43,8 @@ type Domain struct {
LastEdit string LastEdit string
} }
//Struct for a domain record
//Defined by it's ID, DomainID (parent domain), Fqdn (or name), Content (value of the record), Type (as Qtype/int), TTL (used only for the DNS response and not the Redis TTL)
type Record struct { type Record struct {
Id int Id int
DomainId int DomainId int

View file

@ -10,29 +10,30 @@ import (
"github.com/snowzach/rotatefilehook" "github.com/snowzach/rotatefilehook"
) )
//If fatal error, log it and panic
func CheckErr(err error) { func CheckErr(err error) {
if err != nil { if err != nil {
log.Fatalf("%s\n ", err.Error()) log.Fatalf("%s\n ", err.Error())
panic(err)
} }
} }
//Only used for non fatal errors. //If basic error, log it as classic error but don't panic and keep kalm
func DbgErr(err error) { func DbgErr(err error) {
if err != nil { if err != nil {
log.Errorf("%s\n ", err.Error()) log.Errorf("%s\n ", err.Error())
panic(err)
} }
} }
//Init the logrus logger with rotateFileHook.
//Conf struct passed to get informations about the logger (debug or not)
func InitLogger(conf *Conf) { func InitLogger(conf *Conf) {
var logLevel = logrus.InfoLevel var logLevel = logrus.InfoLevel //By default the level is Info.
if conf.App_mode != "production" { if conf.App_mode != "production" { //If the configuration contains anything different than "production"; the level is set to Debug
logLevel = logrus.DebugLevel logLevel = logrus.DebugLevel
} }
rotateFileHook, err := rotatefilehook.NewRotateFileHook(rotatefilehook.RotateFileConfig{ rotateFileHook, err := rotatefilehook.NewRotateFileHook(rotatefilehook.RotateFileConfig{ //Rotate file hook, By default 50Mb max and 28 days retention
Filename: conf.App.Logdir + "/console.log", Filename: conf.App.Logdir + "/console.log",
MaxSize: 50, // megabytes MaxSize: 50, // megabytes
MaxBackups: 3, MaxBackups: 3,
@ -49,15 +50,15 @@ func InitLogger(conf *Conf) {
logrus.Fatalf("Failed to initialize file rotate hook: %v", err) logrus.Fatalf("Failed to initialize file rotate hook: %v", err)
} }
logrus.SetLevel(logLevel) logrus.SetLevel(logLevel) //Set the log level
logrus.SetOutput(colorable.NewColorableStdout()) logrus.SetOutput(colorable.NewColorableStdout()) //Force colors in the Stdout
logrus.SetFormatter(&logrus.TextFormatter{ logrus.SetFormatter(&logrus.TextFormatter{
ForceColors: false, ForceColors: false,
FullTimestamp: true, FullTimestamp: true,
TimestampFormat: time.RFC822, TimestampFormat: time.RFC822,
}) })
if conf.App.Logfile { if conf.App.Logfile { //If file logging is enabled
logrus.AddHook(rotateFileHook) logrus.AddHook(rotateFileHook)
} }
@ -65,6 +66,8 @@ func InitLogger(conf *Conf) {
log.WithFields(log.Fields{"logLevel": logLevel}).Debug("Log level") log.WithFields(log.Fields{"logLevel": logLevel}).Debug("Log level")
} }
//Check if a reverse wildcard correspond to a record using strings.Contains
//Return bool
func checkReverse6(entry Record, result Record) bool { func checkReverse6(entry Record, result Record) bool {
check := strings.Replace(entry.Fqdn, result.Fqdn[1:], "", 1) check := strings.Replace(entry.Fqdn, result.Fqdn[1:], "", 1)
logrus.WithFields(logrus.Fields{"entry": entry.Fqdn, "result": result.Fqdn[1:]}).Debug("REVERSE checkReverse6 :") logrus.WithFields(logrus.Fields{"entry": entry.Fqdn, "result": result.Fqdn[1:]}).Debug("REVERSE checkReverse6 :")