diff --git a/core/parseQuery.go b/core/parseQuery.go index 533d69d..68db067 100644 --- a/core/parseQuery.go +++ b/core/parseQuery.go @@ -25,17 +25,18 @@ func parseQuery(m *dns.Msg) { log.Infof("DNS : Query for %s (type : %v)\n", q.Name, q.Qtype) //Log - record := utils.GetRecord(utils.Record{Fqdn: q.Name, Qtype: q.Qtype}) //Get the record in the SQL or Redis database + records := utils.GetRecord(utils.Record{Fqdn: q.Name, Qtype: q.Qtype}) //Get the record in the SQL or Redis database - if record.Content != "" { //If the record is not empty - log.Infof("DNS : Record found for '%s' => '%s'\n", q.Name, record.Content) //Log the content as INFO - rr, err := dns.NewRR(fmt.Sprintf("%s %v %s %s", q.Name, record.TTL, dns.TypeToString[q.Qtype], record.Content)) //Create the response - if err == nil { //If no err - m.Answer = append(m.Answer, rr) //Append the record to the response + for _, record := range records { + if record.Content != "" { //If the record is not empty + log.Infof("DNS : Record found for '%s' => '%s'\n", q.Name, record.Content) //Log the content as INFO + rr, err := dns.NewRR(fmt.Sprintf("%s %v %s %s", q.Name, record.TTL, dns.TypeToString[q.Qtype], record.Content)) //Create the response + if err == nil { //If no err + m.Answer = append(m.Answer, rr) //Append the record to the response + } + } else { //If the record is empty log it as DEBUG + logrus.Debugf("DNS : No record for '%s' (type '%v')\n", record.Fqdn, record.Qtype) } - } else { //If the record is empty log it as DEBUG - logrus.Debugf("DNS : No record for '%s' (type '%v')\n", record.Fqdn, record.Qtype) } - } } diff --git a/go.mod b/go.mod index 626a11f..f880ce4 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module github.com/outout14/sacrebleu-dns go 1.15 require ( - github.com/go-redis/redis v6.15.9+incompatible github.com/go-redis/redis/v8 v8.4.2 github.com/mattn/go-colorable v0.1.8 github.com/miekg/dns v1.1.35 diff --git a/go.sum b/go.sum index c83abcc..ada5278 100644 --- a/go.sum +++ b/go.sum @@ -15,8 +15,6 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cu github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= -github.com/go-redis/redis v6.15.9+incompatible h1:K0pv1D7EQUjfyoMql+r/jZqCLizCGKFlFgcHWWmHQjg= -github.com/go-redis/redis v6.15.9+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= github.com/go-redis/redis/v8 v8.4.2 h1:gKRo1KZ+O3kXRfxeRblV5Tr470d2YJZJVIAv2/S8960= github.com/go-redis/redis/v8 v8.4.2/go.mod h1:A1tbYoHSa1fXwN+//ljcCYYJeLmVrwL9hbQN45Jdy0M= github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= diff --git a/utils/queries.go b/utils/queries.go index 66f046f..7dd3f3f 100644 --- a/utils/queries.go +++ b/utils/queries.go @@ -10,7 +10,7 @@ import ( //GetRecord : Check the SQL and REDIS database for a Record. //A Record struct is used as input and output -func GetRecord(entry Record) Record { +func GetRecord(entry Record) []Record { //Check for strict record in Redis cache redisKey := entry.Fqdn + "--" + fmt.Sprint(entry.Qtype) result, redisErr := redisCheckForRecord(redisKey, entry) @@ -29,23 +29,26 @@ func GetRecord(entry Record) Record { if sqlErr { //Check for wildcard reverse in the SQL logrus.Debug("QUERIES : Check for wildcard reverse in MySQL") - result, _ = sqlCheckForReverse6Wildcard(redisKey, entry.Fqdn, entry) + result = sqlCheckForReverse6Wildcard(redisKey, entry.Fqdn, entry) } } //For dynamic reverse dns //Check for it by looking for a "%s" in the record content //If true, replace it with the formated IP - if strings.Contains(result.Content, "%s") { - record := ExtractAddressFromReverse(entry.Fqdn) - var recordFormated string - if reverseCheck == 1 { - recordFormated = strings.ReplaceAll(record, ".", "-") - } else { - recordFormated = strings.ReplaceAll(record, ":", "-") + for _, r := range result { + if strings.Contains(r.Content, "%s") { + record := ExtractAddressFromReverse(entry.Fqdn) + var recordFormated string + if reverseCheck == 1 { + recordFormated = strings.ReplaceAll(record, ".", "-") + } else { + recordFormated = strings.ReplaceAll(record, ":", "-") + } + r.Content = fmt.Sprintf(r.Content, recordFormated) } - result.Content = fmt.Sprintf(result.Content, recordFormated) } + } else if redisErr == redis.Nil { //If strict record NOT in Redis cache & not Reverse //Check for wildcard in Redis cache logrus.Debug("QUERIES : Check for wildcard in redis cache") diff --git a/utils/redis.go b/utils/redis.go index 60e21fe..e17417a 100644 --- a/utils/redis.go +++ b/utils/redis.go @@ -46,16 +46,17 @@ func RedisDatabase(conf *Conf) *redis.Client { //Check for a record in the Redis database //Requires the redis key (as string) and the record to check (struct) //Return a Record (struct) and error (if any) -func redisCheckForRecord(redisKey string, entry Record) (Record, error) { +func redisCheckForRecord(redisKey string, entry Record) ([]Record, error) { val, err := redisDb.Get(ctx, redisKey).Result() + var result []Record + //If Record in Redis cache if err == nil { - err := json.Unmarshal([]byte(val), &entry) - logrus.Debugf("REDIS : %s => %s", redisKey, entry.Content) - return entry, err + err := json.Unmarshal([]byte(val), &result) + return result, err } - return entry, redis.Nil + return result, redis.Nil } //Add a record in the Redis database diff --git a/utils/sql.go b/utils/sql.go index f87f302..421b704 100644 --- a/utils/sql.go +++ b/utils/sql.go @@ -4,7 +4,6 @@ import ( "fmt" "time" - "github.com/go-redis/redis" "github.com/sirupsen/logrus" "gorm.io/driver/mysql" "gorm.io/driver/postgres" @@ -15,13 +14,10 @@ import ( //DB SQL database as global var var db *gorm.DB -//SQLDatabase Initialize the (My)SQL Database +//SQLDatabase Initialize the SQL Database //Requires a conf struct -func SQLDatabase(conf *Conf) { +func SQLDatabase(conf *Conf) *gorm.DB { logrus.WithFields(logrus.Fields{"database": conf.Database.Db, "driver": conf.Database.Type}).Infof("SQL : Connection to DB") - //Connect to the Database - - var err error var gormLogLevel logger.LogLevel //Set GORM log level based on conf AppMode @@ -31,21 +27,25 @@ func SQLDatabase(conf *Conf) { gormLogLevel = logger.Silent } + //Connect to the Database if conf.Database.Type == "postgresql" { dsn := fmt.Sprintf("user=%s password=%s host=%s port=%s database=%s sslmode=disable", conf.Database.Username, conf.Database.Password, conf.Database.Host, conf.Database.Port, conf.Database.Db) - db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{ - Logger: logger.Default.LogMode(gormLogLevel), - }) - - } else { - dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", conf.Database.Username, conf.Database.Password, conf.Database.Host, conf.Database.Port, conf.Database.Db) - - db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{ + DB, err := gorm.Open(postgres.Open(dsn), &gorm.Config{ Logger: logger.Default.LogMode(gormLogLevel), }) + CheckErr(err) + db = DB + return DB } + dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", conf.Database.Username, conf.Database.Password, conf.Database.Host, conf.Database.Port, conf.Database.Db) + DB, err := gorm.Open(mysql.Open(dsn), &gorm.Config{ + Logger: logger.Default.LogMode(gormLogLevel), + }) CheckErr(err) + db = DB + + return DB } //SQLMigrate : Launch the database migration (creation of tables) @@ -55,33 +55,40 @@ func SQLMigrate() { } //Check for a record in the SQL database -func sqlCheckForRecord(redisKey string, dKey string, entry Record) (Record, bool) { - db.Where("fqdn = ? AND type = ?", dKey, entry.Qtype).First(&entry) +func sqlCheckForRecord(redisKey string, dKey string, entry Record) ([]Record, bool) { + var records []Record - logrus.Debugf("SQL : %s => %s", entry.Fqdn, entry.Content) //log the result - - if entry.Content != "" { //If Record content not empty - //Cache the request in Redis if any result - logrus.Debugf("REDIS : Set entry for %s", redisKey) - _ = redisSet(redisDb, redisKey, 30*time.Second, entry) //Set it in the Redis database for 30sec - return entry, false + rows, err := db.Where("fqdn = ? AND type = ?", dKey, entry.Qtype).Model(&Record{}).Rows() + if err != nil { + return records, true } - //Else return 1 for err - return entry, true + defer rows.Close() + for rows.Next() { + var entry Record + db.ScanRows(rows, &entry) + if entry.Content != "" { //If Record content not empty + records = append(records, entry) + } + } + //Cache the request in Redis if any result + _ = redisSet(redisDb, redisKey, 30*time.Second, records) //Set it in the Redis database for 30sec + return records, false } //Check for a wildcard record in the SQL database -func sqlCheckForReverse6Wildcard(redisKey string, dKey string, entry Record) (Record, error) { +func sqlCheckForReverse6Wildcard(redisKey string, dKey string, entry Record) []Record { returnedEntry := entry rows, err := db.Table("records").Select("id", "content", "fqdn").Where("fqdn LIKE ?", "*%.ip6.arpa.").Rows() DbgErr(err) //Check for empty row or non important error + var records []Record + //For each result check if it match the reverse IP for rows.Next() { - err = rows.Scan(&returnedEntry.ID, &returnedEntry.Content, &returnedEntry.Fqdn) + rows.Scan(&returnedEntry.ID, &returnedEntry.Content, &returnedEntry.Fqdn) CheckErr(err) //Check if the record is matching the reversed IP @@ -89,11 +96,11 @@ func sqlCheckForReverse6Wildcard(redisKey string, dKey string, entry Record) (Re logrus.Debug("REVERSE : Correct wildcard reverse.") //Cache the request in Redis if any result _ = redisSet(redisDb, redisKey, 10*time.Second, returnedEntry) - return returnedEntry, err + records = append(records, entry) } logrus.Debug("REVERSE : WRONG wildcard reverse .") } - return entry, redis.Nil + return records }