session.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. package base
  2. import (
  3. "fmt"
  4. "github.com/druidcaesa/gotool"
  5. "sync"
  6. "ulink-admin/frame"
  7. "ulink-admin/frame/third_plugins/http"
  8. "xorm.io/xorm"
  9. )
  // Per-routine caches keyed by http.GetRoutineId():
  //   - sessionMap holds *SessionPlus values (filled by GetSession).
  //   - userMap holds *UserInfo values (filled by SetCurUser).
  // Clean removes the current routine's entry from both.
  var (
  	sessionMap sync.Map
  	userMap    sync.Map
  )
  // User type identifiers.
  // NOTE(review): presumably the values carried in UserInfo.UserType — confirm against callers.
  const (
  	ADMIN  = "admin"
  	MEMBER = "member"
  )
  // UserInfo describes the user bound to the current routine via SetCurUser
  // and read back with GetCurUser.
  type UserInfo struct {
  	Id        int64  // user id
  	Name      string // user name
  	IsAdmin   bool   // admin flag (NOTE(review): exact semantics not shown in this file)
  	ComponyId int64  // presumably the company id; the misspelled name is kept for compatibility
  	UserType  string // presumably ADMIN or MEMBER — TODO confirm
  }
  23. func GetCurUser() *UserInfo {
  24. routineId := http.GetRoutineId()
  25. txv, ok := userMap.Load(routineId)
  26. if ok {
  27. return txv.(interface{}).(*UserInfo)
  28. } else {
  29. return nil
  30. }
  31. }
  32. func SetCurUser(user *UserInfo) {
  33. routineId := http.GetRoutineId()
  34. userMap.Store(routineId, user)
  35. }
  36. type SessionPlus struct {
  37. Session *xorm.Session //session
  38. IsTx bool //是否开启事务
  39. Level int //事务层级
  40. Id uint64
  41. }
  42. func GetSession() *SessionPlus {
  43. //var transMap map[string]*xorm.Session
  44. var sessionPlus *SessionPlus
  45. //获取线程id
  46. routineId := http.GetRoutineId()
  47. txv, ok := sessionMap.Load(routineId)
  48. fmt.Printf("routineId:%d\n", routineId)
  49. //如果从根据当前线程ID从缓存中获取到了。
  50. //否则则获取一个新的Session
  51. if ok {
  52. sessionPlus = txv.(interface{}).(*SessionPlus)
  53. //session = transMap["tx"]
  54. } else {
  55. sessionPlus = &SessionPlus{Session: SqlDB.NewSession(), IsTx: false, Id: routineId}
  56. sessionMap.Store(routineId, sessionPlus)
  57. }
  58. return sessionPlus
  59. }
  60. func Transaction(f func(session *xorm.Session)) {
  61. var err error
  62. sessionPlus := GetSession()
  63. if sessionPlus.IsTx == false {
  64. if err = sessionPlus.Session.Begin(); err != nil {
  65. gotool.Logs.ErrorLog().Println("session begin failed, err msg: %s", err.Error())
  66. frame.Throw(frame.BUSINESS_CODE, "session begin failed")
  67. } else {
  68. sessionPlus.IsTx = true
  69. sessionPlus.Level = 1
  70. }
  71. } else {
  72. sessionPlus.Level = sessionPlus.Level + 1
  73. }
  74. defer func() {
  75. //抛出异常,开启事务,并且事务层级等于1
  76. //支持镶套事务处理
  77. if sessionPlus.Level <= 1 && sessionPlus.IsTx {
  78. if r := recover(); r != nil {
  79. gotool.Logs.ErrorLog().Println("异常%v", r)
  80. if err := sessionPlus.Session.Rollback(); err != nil {
  81. gotool.Logs.ErrorLog().Println("session rollback failed, err msg: %s", err.Error())
  82. }
  83. sessionPlus.IsTx = false
  84. sessionPlus.Session.Close()
  85. //Clean()
  86. frame.Throw(frame.SQL_CODE, fmt.Sprintf("数据库错误%v", r))
  87. } else if err = sessionPlus.Session.Commit(); err != nil {
  88. gotool.Logs.ErrorLog().Println("session commit failed, err msg: %s", err.Error())
  89. }
  90. sessionPlus.IsTx = false
  91. sessionPlus.Session.Close()
  92. //Clean()
  93. }
  94. if sessionPlus.Level > 1 && sessionPlus.IsTx {
  95. sessionPlus.Level = sessionPlus.Level - 1
  96. if r := recover(); r != nil {
  97. frame.Throw(100, "镶嵌事务错误")
  98. }
  99. }
  100. }()
  101. f(sessionPlus.Session)
  102. }
  103. // Clean 关闭Session和根据进程ID清理缓存,以便GC回收
  104. func Clean() {
  105. sessionMap.Delete(http.GetRoutineId())
  106. userMap.Delete(http.GetRoutineId())
  107. }