diff --git a/main.go b/main.go index 772ab0d..3a9d332 100644 --- a/main.go +++ b/main.go @@ -11,6 +11,7 @@ import ( "os" "path" "strconv" + "strings" "time" _ "time/tzdata" @@ -19,6 +20,24 @@ import ( _ "github.com/mattn/go-sqlite3" ) +type OrderType = string + +const ( + OrderTypeAsc OrderType = "ASC" + OrderTypeDesc OrderType = "DESC" +) + +func validateOrderType(orderType string) OrderType { + switch strings.ToLower(orderType) { + case "asc": + return OrderTypeAsc + case "desc": + return OrderTypeDesc + default: + return OrderTypeAsc + } +} + type Podcast struct { ID int64 `json:"id" db:"id"` Name string `json:"name" db:"name"` @@ -198,8 +217,23 @@ func downloadEpisodeAudioFile(url string) ([]byte, error) { return data, nil } -func getPodcastEpisodes(db *sqlx.DB, podcastId int64) ([]*Episode, error) { - rows, err := db.Queryx("SELECT * FROM episodes WHERE podcast_id = ?", podcastId) +func getPodcastEpisodes(db *sqlx.DB, podcastId int64, orderBy string, orderType OrderType) ([]*Episode, error) { + allowedColumns := map[string]bool{ + "title": true, + "pubdate": true, + "number": true, + "created_at": true, + } + if !allowedColumns[orderBy] { + orderBy = "pubdate" + } + + orderType = validateOrderType(orderType) + + query := fmt.Sprintf("SELECT * FROM episodes WHERE podcast_id = :podcast_id ORDER BY %s %s", orderBy, orderType) + rows, err := db.NamedQuery(query, map[string]any{ + "podcast_id": podcastId, + }) if err != nil { return nil, fmt.Errorf("failed to query db: %v", err) } @@ -425,7 +459,7 @@ func main() { return } - episodes, err := getPodcastEpisodes(db, int64(id)) + episodes, err := getPodcastEpisodes(db, int64(id), "pubdate", OrderTypeDesc) if err != nil { http.Error(w, err.Error(), 500) return