package main import ( "errors" "log" "net/http" "strconv" "itsc/db" "github.com/gin-gonic/gin" "github.com/gin-gonic/gin/binding" "github.com/gorilla/websocket" ) var TOKEN = "woshimima" func ok(c *gin.Context) { c.JSON(200, gin.H{ "status": "OK", }) } func auth(c *gin.Context) { type Request struct { Token string `json:"token"` } req := &Request{} err := c.ShouldBindBodyWith(req, binding.JSON) if err != nil { c.AbortWithError(400, errors.New("解析Token错误")) return } if req.Token != TOKEN { c.AbortWithError(403, errors.New("Token错误")) return } } var wsUpgrader = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, } func main() { r := gin.Default() group := r.Group("/") group.StaticFile("/", "./web/build/index.html") group.Static("/web", "./web/build") group.Use(func(c *gin.Context) { c.Next() if len(c.Errors) > 0 { c.JSON(-1, gin.H{ "errors": c.Errors.Errors(), "note": "General error handler abort", }) } }) api := group.Group("/api") api.GET("/timetables", func(c *gin.Context) { timetables := make([]*db.Timetable, 0) rows, err := db.GetAllTimetables.Query() if err != nil { c.AbortWithError(400, err) return } for rows.Next() { s := &db.Timetable{} rows.Scan(&s.ID, &s.Name, &s.Status) timetables = append(timetables, s) } c.JSON(http.StatusOK, gin.H{ "timetables": timetables, }) }) api.POST("/timetables", auth, func(c *gin.Context) { type Request struct { Name string `json:"newTimeTableName"` } req := &Request{} err := c.ShouldBindBodyWith(req, binding.JSON) if err != nil { c.AbortWithError(400, err) return } row := db.CreateNewTimetable.QueryRow(req.Name) var id int64 err = row.Scan(&id) if err != nil { c.AbortWithError(400, err) return } c.JSON(200, gin.H{ "id": id, }) }) api.DELETE("/timetables/:id", auth, func(c *gin.Context) { timetableID, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil { c.AbortWithError(400, err) return } _, err = db.DeleteTimetable.Exec(timetableID) if err != nil { c.AbortWithError(400, err) return } ok(c) }) api.GET("/timetables/:id/username/:username", func(c *gin.Context) { timetableID, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil { c.AbortWithError(400, err) return } rows, err := db.GetTimeSlotsByTimetable.Query(timetableID, c.Param("username")) if err != nil { c.AbortWithError(400, err) return } timeslots := make([]*db.TimeSlot, 0) var timetableName string var timetableStatus bool for rows.Next() { s := &db.TimeSlot{} err = rows.Scan(&s.ID, &s.Name, &s.Time, &s.Take, &s.Capacity, &timetableName, &timetableStatus, &s.Success) if err != nil { c.AbortWithError(400, err) return } timeslots = append(timeslots, s) } c.JSON(200, gin.H{ "timeslots": timeslots, "timetableName": timetableName, "timetableStatus": timetableStatus, }) }) api.POST("/timetables/:id", auth, func(c *gin.Context) { timetableID, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil { c.AbortWithError(400, err) return } req := &db.TimeSlot{} err = c.ShouldBindBodyWith(req, binding.JSON) if err != nil { c.AbortWithError(400, err) return } row := db.CreateNewTimeslot.QueryRow(timetableID, req.Name, req.Time, req.Capacity) var newID int64 err = row.Scan(&newID) if err != nil { c.AbortWithError(400, err) return } ok(c) }) api.DELETE("/timetables/:id/:tsid", auth, func(c *gin.Context) { timeslotID, err := strconv.ParseInt(c.Param("tsid"), 10, 64) if err != nil { c.AbortWithError(400, err) return } _, err = db.DeleteTimeslot.Exec(timeslotID) if err != nil { c.AbortWithError(400, err) return } ok(c) }) api.GET("/timetables/:id/:tsid", func(c *gin.Context) { timeslotID, err := strconv.ParseInt(c.Param("tsid"), 10, 64) if err != nil { c.AbortWithError(400, err) return } takes := make([]*db.Take, 0) rows, err := db.GetTakesByTimeslot.Query(timeslotID) if err != nil { c.AbortWithError(400, err) return } for rows.Next() { s := &db.Take{} err = rows.Scan(&s.Username, &s.Created) if err != nil { c.AbortWithError(400, err) return } takes = append(takes, s) } c.JSON(200, gin.H{ "takes": takes, }) }) api.DELETE("/timetables/:id/:tsid/:tkname", auth, func(c *gin.Context) { timeslotID, err := strconv.ParseInt(c.Param("tsid"), 10, 64) if err != nil { c.AbortWithError(400, err) return } tkname := c.Param("tkname") tx, err := db.DB.Begin() if err != nil { c.AbortWithError(400, err) tx.Rollback() return } deleteTakeStmt := tx.Stmt(db.DeleteTake) _, err = deleteTakeStmt.Exec(timeslotID, tkname) if err != nil { c.AbortWithError(400, err) tx.Rollback() return } UpdateTakeCountStmt := tx.Stmt(db.UpdateTakeCount) _, err = UpdateTakeCountStmt.Exec(-1, timeslotID) if err != nil { c.AbortWithError(400, err) tx.Rollback() return } err = tx.Commit() if err != nil { c.AbortWithError(400, err) tx.Rollback() return } ok(c) }) api.PUT("/timetables/:id", auth, func(c *gin.Context) { timetableID, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil { c.AbortWithError(400, err) return } req := &db.Timetable{} err = c.ShouldBindBodyWith(req, binding.JSON) if err != nil { c.AbortWithError(400, err) return } _, err = db.UpdateTimetableStatus.Exec(req.Status, timetableID) if err != nil { c.AbortWithError(400, err) return } ok(c) }) api.PUT("/timetables/:id/:tsid", func(c *gin.Context) { timetableID, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil { c.AbortWithError(400, err) return } timeslotID, err := strconv.ParseInt(c.Param("tsid"), 10, 64) if err != nil { c.AbortWithError(400, err) return } req := &db.Take{} err = c.ShouldBindBodyWith(req, binding.JSON) if err != nil { c.AbortWithError(400, err) return } if err != nil { c.AbortWithError(400, err) return } tx, err := db.DB.Begin() if err != nil { c.AbortWithError(400, err) tx.Rollback() return } updateCountStmt := tx.Stmt(db.UpdateTakeCount) if req.Username[0] == '!' { untakeStmt := tx.Stmt(db.UserUntakeTimeslot) username := req.Username[1:len(req.Username)] _, err = untakeStmt.Exec(username, timeslotID) if err != nil { c.AbortWithError(400, err) tx.Rollback() return } _, err = updateCountStmt.Exec(-1, timeslotID) if err != nil { c.AbortWithError(400, err) tx.Rollback() return } } else { checkStatusStmt := tx.Stmt(db.CheckTableStatus) var status bool row := checkStatusStmt.QueryRow(timetableID) err = row.Scan(&status) if err != nil { c.AbortWithError(400, err) tx.Rollback() return } if !status { c.AbortWithError(403, errors.New("此表暂未开始招募")) tx.Rollback() return } takeStmt := tx.Stmt(db.UserTakeTimeslot) _, err = takeStmt.Exec(req.Username, timeslotID) if err != nil { c.AbortWithError(400, err) tx.Rollback() return } _, err = updateCountStmt.Exec(1, timeslotID) if err != nil { c.AbortWithError(400, err) tx.Rollback() return } } err = tx.Commit() if err != nil { c.AbortWithError(400, err) return } ok(c) }) api.GET("/ws", func(c *gin.Context) { ws, err := wsUpgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { c.AbortWithError(401, err) return } defer ws.Close() for { mt, message, err := ws.ReadMessage() if err != nil { break } if string(message) == "ping" { message = []byte("pong") } err = ws.WriteMessage(mt, message) if err != nil { break } } }) log.Println("Started") r.Run(":8081") }