sql.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. // Copyright 2018 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package builder
  5. import (
  6. sql2 "database/sql"
  7. "fmt"
  8. "reflect"
  9. "strings"
  10. "time"
  11. )
  12. func condToSQL(cond Cond) (string, []interface{}, error) {
  13. if cond == nil || !cond.IsValid() {
  14. return "", nil, nil
  15. }
  16. w := NewWriter()
  17. if err := cond.WriteTo(w); err != nil {
  18. return "", nil, err
  19. }
  20. return w.String(), w.args, nil
  21. }
  22. func condToBoundSQL(cond Cond) (string, error) {
  23. if cond == nil || !cond.IsValid() {
  24. return "", nil
  25. }
  26. w := NewWriter()
  27. if err := cond.WriteTo(w); err != nil {
  28. return "", err
  29. }
  30. return ConvertToBoundSQL(w.String(), w.args)
  31. }
  32. // ToSQL convert a builder or conditions to SQL and args
  33. func ToSQL(cond interface{}) (string, []interface{}, error) {
  34. switch cond.(type) {
  35. case Cond:
  36. return condToSQL(cond.(Cond))
  37. case *Builder:
  38. return cond.(*Builder).ToSQL()
  39. }
  40. return "", nil, ErrNotSupportType
  41. }
  42. // ToBoundSQL convert a builder or conditions to parameters bound SQL
  43. func ToBoundSQL(cond interface{}) (string, error) {
  44. switch cond.(type) {
  45. case Cond:
  46. return condToBoundSQL(cond.(Cond))
  47. case *Builder:
  48. return cond.(*Builder).ToBoundSQL()
  49. }
  50. return "", ErrNotSupportType
  51. }
  52. func noSQLQuoteNeeded(a interface{}) bool {
  53. if a == nil {
  54. return false
  55. }
  56. switch a.(type) {
  57. case int, int8, int16, int32, int64:
  58. return true
  59. case uint, uint8, uint16, uint32, uint64:
  60. return true
  61. case float32, float64:
  62. return true
  63. case bool:
  64. return true
  65. case string:
  66. return false
  67. case time.Time, *time.Time:
  68. return false
  69. }
  70. t := reflect.TypeOf(a)
  71. switch t.Kind() {
  72. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  73. return true
  74. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
  75. return true
  76. case reflect.Float32, reflect.Float64:
  77. return true
  78. case reflect.Bool:
  79. return true
  80. case reflect.String:
  81. return false
  82. }
  83. return false
  84. }
  85. // ConvertToBoundSQL will convert SQL and args to a bound SQL
  86. func ConvertToBoundSQL(sql string, args []interface{}) (string, error) {
  87. buf := strings.Builder{}
  88. var i, j, start int
  89. for ; i < len(sql); i++ {
  90. if sql[i] == '?' {
  91. _, err := buf.WriteString(sql[start:i])
  92. if err != nil {
  93. return "", err
  94. }
  95. start = i + 1
  96. if len(args) == j {
  97. return "", ErrNeedMoreArguments
  98. }
  99. arg := args[j]
  100. if namedArg, ok := arg.(sql2.NamedArg); ok {
  101. arg = namedArg.Value
  102. }
  103. if noSQLQuoteNeeded(arg) {
  104. _, err = fmt.Fprint(&buf, arg)
  105. } else {
  106. // replace ' -> '' (standard replacement) to avoid critical SQL injection,
  107. // NOTICE: may allow some injection like % (or _) in LIKE query
  108. _, err = fmt.Fprintf(&buf, "'%v'", strings.Replace(fmt.Sprintf("%v", arg), "'",
  109. "''", -1))
  110. }
  111. if err != nil {
  112. return "", err
  113. }
  114. j = j + 1
  115. }
  116. }
  117. _, err := buf.WriteString(sql[start:])
  118. if err != nil {
  119. return "", err
  120. }
  121. return buf.String(), nil
  122. }
  123. // ConvertPlaceholder replaces the place holder ? to $1, $2 ... or :1, :2 ... according prefix
  124. func ConvertPlaceholder(sql, prefix string) (string, error) {
  125. buf := strings.Builder{}
  126. var i, j, start int
  127. var ready = true
  128. for ; i < len(sql); i++ {
  129. if sql[i] == '\'' && i > 0 && sql[i-1] != '\\' {
  130. ready = !ready
  131. }
  132. if ready && sql[i] == '?' {
  133. if _, err := buf.WriteString(sql[start:i]); err != nil {
  134. return "", err
  135. }
  136. start = i + 1
  137. j = j + 1
  138. if _, err := buf.WriteString(fmt.Sprintf("%v%d", prefix, j)); err != nil {
  139. return "", err
  140. }
  141. }
  142. }
  143. if _, err := buf.WriteString(sql[start:]); err != nil {
  144. return "", err
  145. }
  146. return buf.String(), nil
  147. }