Add redis connection test before start #29

Closed
ada wants to merge 32 commits from ada/redis-ping into main
3 changed files with 33 additions and 17 deletions
Showing only changes of commit 6567a7c0cd - Show all commits

26
db.go
View file

@ -3,19 +3,20 @@ package main
import (
"context"
"github.com/redis/go-redis/v9"
"log"
"time"
)
var ctx = context.Background()
func connectDB() *redis.Client {
db := redis.NewClient(&redis.Options{
localDb := redis.NewClient(&redis.Options{
Addr: currentConfig.redisAddr,
Username: currentConfig.redisUser,
Password: currentConfig.redisPassword,
DB: currentConfig.redisDB,
})
return db
return localDb
}
func insertPaste(key string, content string, secret string, ttl time.Duration) {
@ -28,15 +29,24 @@ func insertPaste(key string, content string, secret string, ttl time.Duration) {
content: content,
secret: secret,
}
db := connectDB()
db.HSet(ctx, key, "content", hash.content)
db.HSet(ctx, key, "secret", hash.secret)
err := db.HSet(ctx, key, "content", hash.content)
if err != nil {
log.Println(err)
}
err = db.HSet(ctx, key, "secret", hash.secret)
if ttl > -1 {
connectDB().Do(ctx, key, ttl)
db.Do(ctx, key, ttl)
}
}
func getContent(key string) string {
db := connectDB()
return db.HGet(ctx, key, "content").Val()
content := db.HGet(ctx, key, "content").Val()
return content
}
func deleteContent(key string) {
err := db.Del(ctx, key)
if err != nil {
log.Println(err)
}
}

15
main.go
View file

@ -2,6 +2,7 @@ package main
import (
"fmt"
"github.com/redis/go-redis/v9"
"html/template"
"io"
"log"
@ -10,6 +11,7 @@ import (
)
var currentConfig config
var db *redis.Client
type pasteView struct {
Content string
@ -19,7 +21,6 @@ type pasteView struct {
func handleRequest(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
clearPath := strings.ReplaceAll(r.URL.Path, "/raw", "")
db := connectDB()
switch r.Method {
case "GET":
if path == "/" {
@ -31,14 +32,14 @@ func handleRequest(w http.ResponseWriter, r *http.Request) {
} else {
if urlExist(clearPath) {
if strings.HasSuffix(path, "/raw") {
pasteContent := db.HGet(ctx, clearPath, "content").Val()
pasteContent := getContent(clearPath)
w.Header().Set("Content-Type", "text/plain")
_, err := io.WriteString(w, pasteContent)
if err != nil {
log.Println(err)
}
} else {
pasteContent := db.HGet(ctx, path, "content").Val()
pasteContent := getContent(path)
s := pasteView{Content: pasteContent, Key: strings.TrimPrefix(path, "/")}
t, err := template.ParseFiles("templates/paste.html")
if err != nil {
@ -68,11 +69,8 @@ func handleRequest(w http.ResponseWriter, r *http.Request) {
urlItem := strings.Split(path, "/")
if urlExist("/" + urlItem[2]) {
secret := r.URL.Query().Get("secret")
if secret == db.HGet(ctx, "/"+urlItem[2], "secret").Val() {
err := db.Del(ctx, "/"+urlItem[2])
if err != nil {
log.Println(err)
}
if verifySecret("/"+urlItem[2], secret) {
deleteContent("/" + urlItem[2])
w.WriteHeader(http.StatusNoContent)
} else {
w.WriteHeader(http.StatusForbidden)
@ -87,6 +85,7 @@ func handleRequest(w http.ResponseWriter, r *http.Request) {
}
func main() {
db = connectDB()
currentConfig = getConfig()
listen := currentConfig.host + ":" + currentConfig.port
http.HandleFunc("/", handleRequest)

View file

@ -28,6 +28,13 @@ func generateSecret() string {
}
func urlExist(url string) bool {
exist := connectDB().Exists(ctx, url).Val()
exist := db.Exists(ctx, url).Val()
return exist == 1
}
func verifySecret(url string, secret string) bool {
if secret == db.HGet(ctx, url, "secret").Val() {
return true
}
return false
}