Files
itsc/main.go
2022-11-23 18:27:39 +08:00

395 lines
8.3 KiB
Go

package main
import (
	"crypto/subtle"
	"errors"
	"log"
	"net/http"
	"os"
	"strconv"

	"github.com/gin-gonic/gin"
	"github.com/gin-gonic/gin/binding"
	"github.com/gorilla/websocket"
	"itsc/db"
)
// defaultToken is the admin token used when ITSC_TOKEN is not set.
// NOTE(review): this value is committed to source control; deployments should
// set ITSC_TOKEN and this fallback should eventually be removed/rotated.
const defaultToken = "woshimima"

// TOKEN is the admin token that auth-protected endpoints require. It is read
// once at startup from the ITSC_TOKEN environment variable, falling back to
// defaultToken so existing deployments keep working unchanged.
var TOKEN = tokenOrDefault(os.Getenv("ITSC_TOKEN"))

// tokenOrDefault returns v, or defaultToken when v is empty.
func tokenOrDefault(v string) string {
	if v == "" {
		return defaultToken
	}
	return v
}
// ok writes the standard success reply: HTTP 200 with the JSON body
// {"status": "OK"}.
func ok(c *gin.Context) {
	body := gin.H{"status": "OK"}
	c.JSON(http.StatusOK, body)
}
// auth is gin middleware that guards admin endpoints.
//
// It binds the request body as JSON and expects a "token" field equal to
// TOKEN. It aborts with 400 when the body cannot be parsed and with 403 when
// the token does not match. ShouldBindBodyWith caches the body, so the
// downstream handler can bind the same body again.
func auth(c *gin.Context) {
	var req struct {
		Token string `json:"token"`
	}
	if err := c.ShouldBindBodyWith(&req, binding.JSON); err != nil {
		c.AbortWithError(400, errors.New("解析Token错误"))
		return
	}
	// Compare in constant time so response latency does not reveal how many
	// leading bytes of a guessed token were correct.
	if subtle.ConstantTimeCompare([]byte(req.Token), []byte(TOKEN)) != 1 {
		c.AbortWithError(403, errors.New("Token错误"))
		return
	}
}
// wsUpgrader upgrades HTTP requests on /api/ws to WebSocket connections.
// Its CheckOrigin hook approves every request regardless of the Origin header.
// NOTE(review): that means pages on any site can open this socket; restrict
// it if the endpoint ever exposes more than an echo.
var wsUpgrader = websocket.Upgrader{
	CheckOrigin: func(*http.Request) bool { return true },
}
// main wires up the HTTP server. It serves the built web front-end from
// ./web/build, exposes the JSON API under /api (mutating admin routes are
// guarded by auth) and listens on the address in the ITSC_LISTEN environment
// variable.
func main() {
	r := gin.Default()
	// Unknown paths fall back to the front-end entry page.
	r.NoRoute(func(c *gin.Context) {
		c.File("./web/build/index.html")
	})
	group := r.Group("/")
	group.StaticFile("/", "./web/build/index.html")
	group.Static("/web", "./web/build")
	// Registered after the static routes, so only routes added below
	// (the /api group) get the JSON error rendering.
	group.Use(renderErrors)

	api := group.Group("/api")
	api.GET("/timetables", handleListTimetables)
	api.POST("/timetables", auth, handleCreateTimetable)
	api.DELETE("/timetables/:id", auth, handleDeleteTimetable)
	api.GET("/timetables/:id/username/:username", handleGetTimetable)
	api.POST("/timetables/:id", auth, handleCreateTimeslot)
	api.DELETE("/timetables/:id/:tsid", auth, handleDeleteTimeslot)
	api.GET("/timetables/:id/:tsid", handleListTakes)
	api.DELETE("/timetables/:id/:tsid/:tkname", auth, handleDeleteTake)
	api.PUT("/timetables/:id", auth, handleUpdateTimetable)
	api.PUT("/timetables/:id/:tsid", handleTakeTimeslot)
	api.GET("/ws", handleWebSocket)

	log.Println("Started")
	// r.Run only returns on failure (e.g. the address is in use); report it
	// and exit non-zero instead of terminating silently.
	log.Fatal(r.Run(os.Getenv("ITSC_LISTEN")))
}

// renderErrors runs the rest of the handler chain and, if any handler recorded
// errors (via c.AbortWithError), writes them as JSON. The -1 status tells gin
// to keep the status code the failing handler already set.
func renderErrors(c *gin.Context) {
	c.Next()
	if len(c.Errors) > 0 {
		c.JSON(-1, gin.H{
			"errors": c.Errors.Errors(),
			"note":   "General error handler abort",
		})
	}
}

// paramInt64 parses path parameter name as a base-10 int64. On failure it
// aborts the request with 400 and reports false.
func paramInt64(c *gin.Context, name string) (int64, bool) {
	v, err := strconv.ParseInt(c.Param(name), 10, 64)
	if err != nil {
		c.AbortWithError(400, err)
		return 0, false
	}
	return v, true
}

// handleListTimetables: GET /api/timetables.
// Responds {"timetables": [...]} with every row of db.GetAllTimetables.
func handleListTimetables(c *gin.Context) {
	rows, err := db.GetAllTimetables.Query()
	if err != nil {
		c.AbortWithError(400, err)
		return
	}
	defer rows.Close()
	// Non-nil so an empty result encodes as [] rather than null.
	timetables := make([]*db.Timetable, 0)
	for rows.Next() {
		t := &db.Timetable{}
		if err := rows.Scan(&t.ID, &t.Name, &t.Status, &t.Limit); err != nil {
			c.AbortWithError(400, err)
			return
		}
		timetables = append(timetables, t)
	}
	if err := rows.Err(); err != nil {
		c.AbortWithError(400, err)
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"timetables": timetables,
	})
}

// handleCreateTimetable: POST /api/timetables (admin).
// Body: {"newTimeTableName": string, "newLimit": int}. Responds {"id": newID}.
func handleCreateTimetable(c *gin.Context) {
	type Request struct {
		Name  string `json:"newTimeTableName"`
		Limit int64  `json:"newLimit"`
	}
	req := &Request{}
	if err := c.ShouldBindBodyWith(req, binding.JSON); err != nil {
		c.AbortWithError(400, err)
		return
	}
	var id int64
	if err := db.CreateNewTimetable.QueryRow(req.Name, req.Limit).Scan(&id); err != nil {
		c.AbortWithError(400, err)
		return
	}
	c.JSON(200, gin.H{
		"id": id,
	})
}

// handleDeleteTimetable: DELETE /api/timetables/:id (admin).
func handleDeleteTimetable(c *gin.Context) {
	timetableID, parsed := paramInt64(c, "id")
	if !parsed {
		return
	}
	if _, err := db.DeleteTimetable.Exec(timetableID); err != nil {
		c.AbortWithError(400, err)
		return
	}
	ok(c)
}

// handleGetTimetable: GET /api/timetables/:id/username/:username.
// Responds with the timetable's timeslots as seen by :username, plus the
// timetable's name, status and limit. Those three fields are read from the
// joined slot rows, so they stay zero-valued when the timetable has no slots.
func handleGetTimetable(c *gin.Context) {
	timetableID, parsed := paramInt64(c, "id")
	if !parsed {
		return
	}
	rows, err := db.GetTimeSlotsByTimetable.Query(timetableID, c.Param("username"))
	if err != nil {
		c.AbortWithError(400, err)
		return
	}
	defer rows.Close()
	timeslots := make([]*db.TimeSlot, 0)
	var timetableName string
	var timetableStatus bool
	var timetableLimit int64
	for rows.Next() {
		s := &db.TimeSlot{}
		err := rows.Scan(&s.ID, &s.Name, &s.Time, &s.Take, &s.Capacity,
			&timetableName, &timetableStatus, &s.Success, &timetableLimit)
		if err != nil {
			c.AbortWithError(400, err)
			return
		}
		timeslots = append(timeslots, s)
	}
	if err := rows.Err(); err != nil {
		c.AbortWithError(400, err)
		return
	}
	c.JSON(200, gin.H{
		"timeslots":       timeslots,
		"timetableName":   timetableName,
		"timetableStatus": timetableStatus,
		"timetableLimit":  timetableLimit,
	})
}

// handleCreateTimeslot: POST /api/timetables/:id (admin).
// Body is a db.TimeSlot; its Name, Time and Capacity are stored.
func handleCreateTimeslot(c *gin.Context) {
	timetableID, parsed := paramInt64(c, "id")
	if !parsed {
		return
	}
	req := &db.TimeSlot{}
	if err := c.ShouldBindBodyWith(req, binding.JSON); err != nil {
		c.AbortWithError(400, err)
		return
	}
	// The new ID is not returned to the client; scanning it surfaces insert errors.
	var newID int64
	row := db.CreateNewTimeslot.QueryRow(timetableID, req.Name, req.Time, req.Capacity)
	if err := row.Scan(&newID); err != nil {
		c.AbortWithError(400, err)
		return
	}
	ok(c)
}

// handleDeleteTimeslot: DELETE /api/timetables/:id/:tsid (admin).
// Only :tsid is used.
func handleDeleteTimeslot(c *gin.Context) {
	timeslotID, parsed := paramInt64(c, "tsid")
	if !parsed {
		return
	}
	if _, err := db.DeleteTimeslot.Exec(timeslotID); err != nil {
		c.AbortWithError(400, err)
		return
	}
	ok(c)
}

// handleListTakes: GET /api/timetables/:id/:tsid.
// Responds {"takes": [...]} with the username and creation time of every take
// of timeslot :tsid.
func handleListTakes(c *gin.Context) {
	timeslotID, parsed := paramInt64(c, "tsid")
	if !parsed {
		return
	}
	rows, err := db.GetTakesByTimeslot.Query(timeslotID)
	if err != nil {
		c.AbortWithError(400, err)
		return
	}
	defer rows.Close()
	takes := make([]*db.Take, 0)
	for rows.Next() {
		tk := &db.Take{}
		if err := rows.Scan(&tk.Username, &tk.Created); err != nil {
			c.AbortWithError(400, err)
			return
		}
		takes = append(takes, tk)
	}
	if err := rows.Err(); err != nil {
		c.AbortWithError(400, err)
		return
	}
	c.JSON(200, gin.H{
		"takes": takes,
	})
}

// handleDeleteTake: DELETE /api/timetables/:id/:tsid/:tkname (admin).
// Removes :tkname's take of timeslot :tsid and decrements the slot's take
// counter in one transaction.
// NOTE(review): the counter update runs even if DeleteTake matched no row,
// which could let the counter drift; consider checking RowsAffected.
func handleDeleteTake(c *gin.Context) {
	timeslotID, parsed := paramInt64(c, "tsid")
	if !parsed {
		return
	}
	tkname := c.Param("tkname")
	tx, err := db.DB.Begin()
	if err != nil {
		// tx is nil here; there is nothing to roll back.
		c.AbortWithError(400, err)
		return
	}
	// Undoes the work on any early return; a no-op after a successful Commit.
	defer tx.Rollback()
	if _, err := tx.Stmt(db.DeleteTake).Exec(timeslotID, tkname); err != nil {
		c.AbortWithError(400, err)
		return
	}
	if _, err := tx.Stmt(db.UpdateTakeCount).Exec(-1, timeslotID); err != nil {
		c.AbortWithError(400, err)
		return
	}
	if err := tx.Commit(); err != nil {
		c.AbortWithError(400, err)
		return
	}
	ok(c)
}

// handleUpdateTimetable: PUT /api/timetables/:id (admin).
// Body is a db.Timetable; its Status and Limit are stored.
func handleUpdateTimetable(c *gin.Context) {
	timetableID, parsed := paramInt64(c, "id")
	if !parsed {
		return
	}
	req := &db.Timetable{}
	if err := c.ShouldBindBodyWith(req, binding.JSON); err != nil {
		c.AbortWithError(400, err)
		return
	}
	if _, err := db.UpdateTimetableStatus.Exec(req.Status, req.Limit, timetableID); err != nil {
		c.AbortWithError(400, err)
		return
	}
	ok(c)
}

// handleTakeTimeslot: PUT /api/timetables/:id/:tsid.
// Body is a db.Take. A Username prefixed with '!' withdraws that user (prefix
// stripped) from the timeslot; any other non-empty Username signs up for it.
func handleTakeTimeslot(c *gin.Context) {
	timetableID, parsed := paramInt64(c, "id")
	if !parsed {
		return
	}
	timeslotID, parsed := paramInt64(c, "tsid")
	if !parsed {
		return
	}
	req := &db.Take{}
	if err := c.ShouldBindBodyWith(req, binding.JSON); err != nil {
		c.AbortWithError(400, err)
		return
	}
	// Guards the Username[0] index below, which would otherwise panic.
	if req.Username == "" {
		c.AbortWithError(400, errors.New("用户名不能为空"))
		return
	}
	if req.Username[0] == '!' {
		withdrawFromTimeslot(c, req.Username[1:], timeslotID)
		return
	}
	signUpForTimeslot(c, req.Username, timetableID, timeslotID)
}

// withdrawFromTimeslot removes username's take of timeslotID and decrements
// the slot's take counter in one transaction.
// NOTE(review): the counter update runs even if UserUntakeTimeslot matched no
// row, which could let the counter drift; consider checking RowsAffected.
func withdrawFromTimeslot(c *gin.Context, username string, timeslotID int64) {
	tx, err := db.DB.Begin()
	if err != nil {
		c.AbortWithError(400, err)
		return
	}
	// Undoes the work on any early return; a no-op after a successful Commit.
	defer tx.Rollback()
	if _, err := tx.Stmt(db.UserUntakeTimeslot).Exec(username, timeslotID); err != nil {
		c.AbortWithError(400, err)
		return
	}
	if _, err := tx.Stmt(db.UpdateTakeCount).Exec(-1, timeslotID); err != nil {
		c.AbortWithError(400, err)
		return
	}
	if err := tx.Commit(); err != nil {
		c.AbortWithError(400, err)
		return
	}
	ok(c)
}

// signUpForTimeslot records that username takes timeslotID and increments the
// slot's take counter in one transaction. It refuses with 401 when CheckLimit
// reports the user may not take more slots in timetableID, and with 403 when
// CheckTableStatus reports the timetable is not open.
func signUpForTimeslot(c *gin.Context, username string, timetableID, timeslotID int64) {
	allowed, err := checkTakeLimit(username, timetableID)
	if err != nil {
		c.AbortWithError(400, err)
		return
	}
	if !allowed {
		c.AbortWithError(401, errors.New("超出报名数量限制"))
		return
	}
	tx, err := db.DB.Begin()
	if err != nil {
		c.AbortWithError(400, err)
		return
	}
	// Undoes the work on any early return; a no-op after a successful Commit.
	defer tx.Rollback()
	var status bool
	if err := tx.Stmt(db.CheckTableStatus).QueryRow(timetableID).Scan(&status); err != nil {
		c.AbortWithError(400, err)
		return
	}
	if !status {
		c.AbortWithError(403, errors.New("此表暂未开始招募"))
		return
	}
	if _, err := tx.Stmt(db.UserTakeTimeslot).Exec(username, timeslotID); err != nil {
		c.AbortWithError(400, err)
		return
	}
	if _, err := tx.Stmt(db.UpdateTakeCount).Exec(1, timeslotID); err != nil {
		c.AbortWithError(400, err)
		return
	}
	if err := tx.Commit(); err != nil {
		c.AbortWithError(400, err)
		return
	}
	ok(c)
}

// checkTakeLimit runs db.CheckLimit for the user and timetable and returns its
// boolean result. If the query yields several rows the last one wins; if it
// yields none, the user is allowed.
func checkTakeLimit(username string, timetableID int64) (bool, error) {
	rows, err := db.CheckLimit.Query(username, timetableID)
	if err != nil {
		return false, err
	}
	defer rows.Close()
	allowed := true
	for rows.Next() {
		if err := rows.Scan(&allowed); err != nil {
			return false, err
		}
	}
	if err := rows.Err(); err != nil {
		return false, err
	}
	return allowed, nil
}

// handleWebSocket: GET /api/ws.
// Upgrades to a WebSocket and echoes every message back with the same message
// type, replacing a "ping" payload with "pong". It stops at the first read or
// write error.
func handleWebSocket(c *gin.Context) {
	conn, err := wsUpgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		// Upgrade has already replied with an HTTP error; writing another
		// response here would corrupt it, so only log.
		log.Println("websocket upgrade:", err)
		return
	}
	defer conn.Close()
	for {
		mt, msg, err := conn.ReadMessage()
		if err != nil {
			return
		}
		if string(msg) == "ping" {
			msg = []byte("pong")
		}
		if err := conn.WriteMessage(mt, msg); err != nil {
			return
		}
	}
}