diff --git a/db.go b/db.go index e2a7cbd..5dcdd03 100644 --- a/db.go +++ b/db.go @@ -3,19 +3,20 @@ package main import ( "context" "github.com/redis/go-redis/v9" + "log" "time" ) var ctx = context.Background() func connectDB() *redis.Client { - db := redis.NewClient(&redis.Options{ + localDb := redis.NewClient(&redis.Options{ Addr: currentConfig.redisAddr, Username: currentConfig.redisUser, Password: currentConfig.redisPassword, DB: currentConfig.redisDB, }) - return db + return localDb } func insertPaste(key string, content string, secret string, ttl time.Duration) { @@ -28,15 +29,24 @@ func insertPaste(key string, content string, secret string, ttl time.Duration) { content: content, secret: secret, } - db := connectDB() - db.HSet(ctx, key, "content", hash.content) - db.HSet(ctx, key, "secret", hash.secret) + err := db.HSet(ctx, key, "content", hash.content) + if err != nil { + log.Println(err) + } + err = db.HSet(ctx, key, "secret", hash.secret) if ttl > -1 { - connectDB().Do(ctx, key, ttl) + db.Do(ctx, key, ttl) } } func getContent(key string) string { - db := connectDB() - return db.HGet(ctx, key, "content").Val() + content := db.HGet(ctx, key, "content").Val() + return content +} + +func deleteContent(key string) { + err := db.Del(ctx, key) + if err != nil { + log.Println(err) + } } diff --git a/main.go b/main.go index 34e05d9..aa7c81c 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "github.com/redis/go-redis/v9" "html/template" "io" "log" @@ -10,6 +11,7 @@ import ( ) var currentConfig config +var db *redis.Client type pasteView struct { Content string @@ -19,7 +21,6 @@ type pasteView struct { func handleRequest(w http.ResponseWriter, r *http.Request) { path := r.URL.Path clearPath := strings.ReplaceAll(r.URL.Path, "/raw", "") - db := connectDB() switch r.Method { case "GET": if path == "/" { @@ -31,14 +32,14 @@ func handleRequest(w http.ResponseWriter, r *http.Request) { } else { if urlExist(clearPath) { if strings.HasSuffix(path, "/raw") { - pasteContent := db.HGet(ctx, clearPath, "content").Val() + pasteContent := getContent(clearPath) w.Header().Set("Content-Type", "text/plain") _, err := io.WriteString(w, pasteContent) if err != nil { log.Println(err) } } else { - pasteContent := db.HGet(ctx, path, "content").Val() + pasteContent := getContent(path) s := pasteView{Content: pasteContent, Key: strings.TrimPrefix(path, "/")} t, err := template.ParseFiles("templates/paste.html") if err != nil { @@ -68,11 +69,8 @@ func handleRequest(w http.ResponseWriter, r *http.Request) { urlItem := strings.Split(path, "/") if urlExist("/" + urlItem[2]) { secret := r.URL.Query().Get("secret") - if secret == db.HGet(ctx, "/"+urlItem[2], "secret").Val() { - err := db.Del(ctx, "/"+urlItem[2]) - if err != nil { - log.Println(err) - } + if verifySecret("/"+urlItem[2], secret) { + deleteContent("/" + urlItem[2]) w.WriteHeader(http.StatusNoContent) } else { w.WriteHeader(http.StatusForbidden) @@ -87,6 +85,7 @@ func handleRequest(w http.ResponseWriter, r *http.Request) { } func main() { + db = connectDB() currentConfig = getConfig() listen := currentConfig.host + ":" + currentConfig.port http.HandleFunc("/", handleRequest) diff --git a/utils.go b/utils.go index 7f9cef9..36305da 100644 --- a/utils.go +++ b/utils.go @@ -28,6 +28,13 @@ func generateSecret() string { } func urlExist(url string) bool { - exist := connectDB().Exists(ctx, url).Val() + exist := db.Exists(ctx, url).Val() return exist == 1 } + +func verifySecret(url string, secret string) bool { + if secret == db.HGet(ctx, url, "secret").Val() { + return true + } + return false +}