Multiple records support

This commit is contained in:
Mael G. 2020-12-27 19:05:16 -04:00
parent a1cd3ea5a1
commit e05cab9d5e
6 changed files with 65 additions and 56 deletions

View file

@ -25,17 +25,18 @@ func parseQuery(m *dns.Msg) {
log.Infof("DNS : Query for %s (type : %v)\n", q.Name, q.Qtype) //Log log.Infof("DNS : Query for %s (type : %v)\n", q.Name, q.Qtype) //Log
record := utils.GetRecord(utils.Record{Fqdn: q.Name, Qtype: q.Qtype}) //Get the record in the SQL or Redis database records := utils.GetRecord(utils.Record{Fqdn: q.Name, Qtype: q.Qtype}) //Get the record in the SQL or Redis database
if record.Content != "" { //If the record is not empty for _, record := range records {
log.Infof("DNS : Record found for '%s' => '%s'\n", q.Name, record.Content) //Log the content as INFO if record.Content != "" { //If the record is not empty
rr, err := dns.NewRR(fmt.Sprintf("%s %v %s %s", q.Name, record.TTL, dns.TypeToString[q.Qtype], record.Content)) //Create the response log.Infof("DNS : Record found for '%s' => '%s'\n", q.Name, record.Content) //Log the content as INFO
if err == nil { //If no err rr, err := dns.NewRR(fmt.Sprintf("%s %v %s %s", q.Name, record.TTL, dns.TypeToString[q.Qtype], record.Content)) //Create the response
m.Answer = append(m.Answer, rr) //Append the record to the response if err == nil { //If no err
m.Answer = append(m.Answer, rr) //Append the record to the response
}
} else { //If the record is empty log it as DEBUG
logrus.Debugf("DNS : No record for '%s' (type '%v')\n", record.Fqdn, record.Qtype)
} }
} else { //If the record is empty log it as DEBUG
logrus.Debugf("DNS : No record for '%s' (type '%v')\n", record.Fqdn, record.Qtype)
} }
} }
} }

1
go.mod
View file

@ -3,7 +3,6 @@ module github.com/outout14/sacrebleu-dns
go 1.15 go 1.15
require ( require (
github.com/go-redis/redis v6.15.9+incompatible
github.com/go-redis/redis/v8 v8.4.2 github.com/go-redis/redis/v8 v8.4.2
github.com/mattn/go-colorable v0.1.8 github.com/mattn/go-colorable v0.1.8
github.com/miekg/dns v1.1.35 github.com/miekg/dns v1.1.35

2
go.sum
View file

@ -15,8 +15,6 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cu
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/go-redis/redis v6.15.9+incompatible h1:K0pv1D7EQUjfyoMql+r/jZqCLizCGKFlFgcHWWmHQjg=
github.com/go-redis/redis v6.15.9+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA=
github.com/go-redis/redis/v8 v8.4.2 h1:gKRo1KZ+O3kXRfxeRblV5Tr470d2YJZJVIAv2/S8960= github.com/go-redis/redis/v8 v8.4.2 h1:gKRo1KZ+O3kXRfxeRblV5Tr470d2YJZJVIAv2/S8960=
github.com/go-redis/redis/v8 v8.4.2/go.mod h1:A1tbYoHSa1fXwN+//ljcCYYJeLmVrwL9hbQN45Jdy0M= github.com/go-redis/redis/v8 v8.4.2/go.mod h1:A1tbYoHSa1fXwN+//ljcCYYJeLmVrwL9hbQN45Jdy0M=
github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs=

View file

@ -10,7 +10,7 @@ import (
//GetRecord : Check the SQL and REDIS database for a Record. //GetRecord : Check the SQL and REDIS database for a Record.
//A Record struct is used as input and output //A Record struct is used as input and output
func GetRecord(entry Record) Record { func GetRecord(entry Record) []Record {
//Check for strict record in Redis cache //Check for strict record in Redis cache
redisKey := entry.Fqdn + "--" + fmt.Sprint(entry.Qtype) redisKey := entry.Fqdn + "--" + fmt.Sprint(entry.Qtype)
result, redisErr := redisCheckForRecord(redisKey, entry) result, redisErr := redisCheckForRecord(redisKey, entry)
@ -29,23 +29,26 @@ func GetRecord(entry Record) Record {
if sqlErr { if sqlErr {
//Check for wildcard reverse in the SQL //Check for wildcard reverse in the SQL
logrus.Debug("QUERIES : Check for wildcard reverse in MySQL") logrus.Debug("QUERIES : Check for wildcard reverse in MySQL")
result, _ = sqlCheckForReverse6Wildcard(redisKey, entry.Fqdn, entry) result = sqlCheckForReverse6Wildcard(redisKey, entry.Fqdn, entry)
} }
} }
//For dynamic reverse dns //For dynamic reverse dns
//Check for it by looking for a "%s" in the record content //Check for it by looking for a "%s" in the record content
//If true, replace it with the formated IP //If true, replace it with the formated IP
if strings.Contains(result.Content, "%s") { for _, r := range result {
record := ExtractAddressFromReverse(entry.Fqdn) if strings.Contains(r.Content, "%s") {
var recordFormated string record := ExtractAddressFromReverse(entry.Fqdn)
if reverseCheck == 1 { var recordFormated string
recordFormated = strings.ReplaceAll(record, ".", "-") if reverseCheck == 1 {
} else { recordFormated = strings.ReplaceAll(record, ".", "-")
recordFormated = strings.ReplaceAll(record, ":", "-") } else {
recordFormated = strings.ReplaceAll(record, ":", "-")
}
r.Content = fmt.Sprintf(r.Content, recordFormated)
} }
result.Content = fmt.Sprintf(result.Content, recordFormated)
} }
} else if redisErr == redis.Nil { //If strict record NOT in Redis cache & not Reverse } else if redisErr == redis.Nil { //If strict record NOT in Redis cache & not Reverse
//Check for wildcard in Redis cache //Check for wildcard in Redis cache
logrus.Debug("QUERIES : Check for wildcard in redis cache") logrus.Debug("QUERIES : Check for wildcard in redis cache")

View file

@ -46,16 +46,17 @@ func RedisDatabase(conf *Conf) *redis.Client {
//Check for a record in the Redis database //Check for a record in the Redis database
//Requires the redis key (as string) and the record to check (struct) //Requires the redis key (as string) and the record to check (struct)
//Return a Record (struct) and error (if any) //Return a Record (struct) and error (if any)
func redisCheckForRecord(redisKey string, entry Record) (Record, error) { func redisCheckForRecord(redisKey string, entry Record) ([]Record, error) {
val, err := redisDb.Get(ctx, redisKey).Result() val, err := redisDb.Get(ctx, redisKey).Result()
var result []Record
//If Record in Redis cache //If Record in Redis cache
if err == nil { if err == nil {
err := json.Unmarshal([]byte(val), &entry) err := json.Unmarshal([]byte(val), &result)
logrus.Debugf("REDIS : %s => %s", redisKey, entry.Content) return result, err
return entry, err
} }
return entry, redis.Nil return result, redis.Nil
} }
//Add a record in the Redis database //Add a record in the Redis database

View file

@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/go-redis/redis"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"gorm.io/driver/mysql" "gorm.io/driver/mysql"
"gorm.io/driver/postgres" "gorm.io/driver/postgres"
@ -15,13 +14,10 @@ import (
//DB SQL database as global var //DB SQL database as global var
var db *gorm.DB var db *gorm.DB
//SQLDatabase Initialize the (My)SQL Database //SQLDatabase Initialize the SQL Database
//Requires a conf struct //Requires a conf struct
func SQLDatabase(conf *Conf) { func SQLDatabase(conf *Conf) *gorm.DB {
logrus.WithFields(logrus.Fields{"database": conf.Database.Db, "driver": conf.Database.Type}).Infof("SQL : Connection to DB") logrus.WithFields(logrus.Fields{"database": conf.Database.Db, "driver": conf.Database.Type}).Infof("SQL : Connection to DB")
//Connect to the Database
var err error
var gormLogLevel logger.LogLevel var gormLogLevel logger.LogLevel
//Set GORM log level based on conf AppMode //Set GORM log level based on conf AppMode
@ -31,21 +27,25 @@ func SQLDatabase(conf *Conf) {
gormLogLevel = logger.Silent gormLogLevel = logger.Silent
} }
//Connect to the Database
if conf.Database.Type == "postgresql" { if conf.Database.Type == "postgresql" {
dsn := fmt.Sprintf("user=%s password=%s host=%s port=%s database=%s sslmode=disable", conf.Database.Username, conf.Database.Password, conf.Database.Host, conf.Database.Port, conf.Database.Db) dsn := fmt.Sprintf("user=%s password=%s host=%s port=%s database=%s sslmode=disable", conf.Database.Username, conf.Database.Password, conf.Database.Host, conf.Database.Port, conf.Database.Db)
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{ DB, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(gormLogLevel),
})
} else {
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", conf.Database.Username, conf.Database.Password, conf.Database.Host, conf.Database.Port, conf.Database.Db)
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(gormLogLevel), Logger: logger.Default.LogMode(gormLogLevel),
}) })
CheckErr(err)
db = DB
return DB
} }
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", conf.Database.Username, conf.Database.Password, conf.Database.Host, conf.Database.Port, conf.Database.Db)
DB, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(gormLogLevel),
})
CheckErr(err) CheckErr(err)
db = DB
return DB
} }
//SQLMigrate : Launch the database migration (creation of tables) //SQLMigrate : Launch the database migration (creation of tables)
@ -55,33 +55,40 @@ func SQLMigrate() {
} }
//Check for a record in the SQL database //Check for a record in the SQL database
func sqlCheckForRecord(redisKey string, dKey string, entry Record) (Record, bool) { func sqlCheckForRecord(redisKey string, dKey string, entry Record) ([]Record, bool) {
db.Where("fqdn = ? AND type = ?", dKey, entry.Qtype).First(&entry) var records []Record
logrus.Debugf("SQL : %s => %s", entry.Fqdn, entry.Content) //log the result rows, err := db.Where("fqdn = ? AND type = ?", dKey, entry.Qtype).Model(&Record{}).Rows()
if err != nil {
if entry.Content != "" { //If Record content not empty return records, true
//Cache the request in Redis if any result
logrus.Debugf("REDIS : Set entry for %s", redisKey)
_ = redisSet(redisDb, redisKey, 30*time.Second, entry) //Set it in the Redis database for 30sec
return entry, false
} }
//Else return 1 for err defer rows.Close()
return entry, true for rows.Next() {
var entry Record
db.ScanRows(rows, &entry)
if entry.Content != "" { //If Record content not empty
records = append(records, entry)
}
}
//Cache the request in Redis if any result
_ = redisSet(redisDb, redisKey, 30*time.Second, records) //Set it in the Redis database for 30sec
return records, false
} }
//Check for a wildcard record in the SQL database //Check for a wildcard record in the SQL database
func sqlCheckForReverse6Wildcard(redisKey string, dKey string, entry Record) (Record, error) { func sqlCheckForReverse6Wildcard(redisKey string, dKey string, entry Record) []Record {
returnedEntry := entry returnedEntry := entry
rows, err := db.Table("records").Select("id", "content", "fqdn").Where("fqdn LIKE ?", "*%.ip6.arpa.").Rows() rows, err := db.Table("records").Select("id", "content", "fqdn").Where("fqdn LIKE ?", "*%.ip6.arpa.").Rows()
DbgErr(err) //Check for empty row or non important error DbgErr(err) //Check for empty row or non important error
var records []Record
//For each result check if it match the reverse IP //For each result check if it match the reverse IP
for rows.Next() { for rows.Next() {
err = rows.Scan(&returnedEntry.ID, &returnedEntry.Content, &returnedEntry.Fqdn) rows.Scan(&returnedEntry.ID, &returnedEntry.Content, &returnedEntry.Fqdn)
CheckErr(err) CheckErr(err)
//Check if the record is matching the reversed IP //Check if the record is matching the reversed IP
@ -89,11 +96,11 @@ func sqlCheckForReverse6Wildcard(redisKey string, dKey string, entry Record) (Re
logrus.Debug("REVERSE : Correct wildcard reverse.") logrus.Debug("REVERSE : Correct wildcard reverse.")
//Cache the request in Redis if any result //Cache the request in Redis if any result
_ = redisSet(redisDb, redisKey, 10*time.Second, returnedEntry) _ = redisSet(redisDb, redisKey, 10*time.Second, returnedEntry)
return returnedEntry, err records = append(records, entry)
} }
logrus.Debug("REVERSE : WRONG wildcard reverse .") logrus.Debug("REVERSE : WRONG wildcard reverse .")
} }
return entry, redis.Nil return records
} }