field_parser.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647
  1. package swag
  2. import (
  3. "fmt"
  4. "go/ast"
  5. "reflect"
  6. "regexp"
  7. "strconv"
  8. "strings"
  9. "sync"
  10. "unicode"
  11. "github.com/go-openapi/jsonreference"
  12. "github.com/go-openapi/spec"
  13. )
  14. var _ FieldParser = &tagBaseFieldParser{p: nil, field: nil, tag: ""}
  15. const (
  16. requiredLabel = "required"
  17. optionalLabel = "optional"
  18. swaggerTypeTag = "swaggertype"
  19. swaggerIgnoreTag = "swaggerignore"
  20. )
  21. type tagBaseFieldParser struct {
  22. p *Parser
  23. field *ast.Field
  24. tag reflect.StructTag
  25. }
  26. func newTagBaseFieldParser(p *Parser, field *ast.Field) FieldParser {
  27. fieldParser := tagBaseFieldParser{
  28. p: p,
  29. field: field,
  30. tag: "",
  31. }
  32. if fieldParser.field.Tag != nil {
  33. fieldParser.tag = reflect.StructTag(strings.ReplaceAll(field.Tag.Value, "`", ""))
  34. }
  35. return &fieldParser
  36. }
  37. func (ps *tagBaseFieldParser) ShouldSkip() bool {
  38. // Skip non-exported fields.
  39. if !ast.IsExported(ps.field.Names[0].Name) {
  40. return true
  41. }
  42. if ps.field.Tag == nil {
  43. return false
  44. }
  45. ignoreTag := ps.tag.Get(swaggerIgnoreTag)
  46. if strings.EqualFold(ignoreTag, "true") {
  47. return true
  48. }
  49. // json:"tag,hoge"
  50. name := strings.TrimSpace(strings.Split(ps.tag.Get(jsonTag), ",")[0])
  51. if name == "-" {
  52. return true
  53. }
  54. return false
  55. }
  56. func (ps *tagBaseFieldParser) FieldName() (string, error) {
  57. var name string
  58. if ps.field.Tag != nil {
  59. // json:"tag,hoge"
  60. name = strings.TrimSpace(strings.Split(ps.tag.Get(jsonTag), ",")[0])
  61. if name != "" {
  62. return name, nil
  63. }
  64. }
  65. switch ps.p.PropNamingStrategy {
  66. case SnakeCase:
  67. return toSnakeCase(ps.field.Names[0].Name), nil
  68. case PascalCase:
  69. return ps.field.Names[0].Name, nil
  70. default:
  71. return toLowerCamelCase(ps.field.Names[0].Name), nil
  72. }
  73. }
  74. func toSnakeCase(in string) string {
  75. var (
  76. runes = []rune(in)
  77. length = len(runes)
  78. out []rune
  79. )
  80. for idx := 0; idx < length; idx++ {
  81. if idx > 0 && unicode.IsUpper(runes[idx]) &&
  82. ((idx+1 < length && unicode.IsLower(runes[idx+1])) || unicode.IsLower(runes[idx-1])) {
  83. out = append(out, '_')
  84. }
  85. out = append(out, unicode.ToLower(runes[idx]))
  86. }
  87. return string(out)
  88. }
  89. func toLowerCamelCase(in string) string {
  90. var flag bool
  91. out := make([]rune, len(in))
  92. runes := []rune(in)
  93. for i, curr := range runes {
  94. if (i == 0 && unicode.IsUpper(curr)) || (flag && unicode.IsUpper(curr)) {
  95. out[i] = unicode.ToLower(curr)
  96. flag = true
  97. continue
  98. }
  99. out[i] = curr
  100. flag = false
  101. }
  102. return string(out)
  103. }
  104. func (ps *tagBaseFieldParser) CustomSchema() (*spec.Schema, error) {
  105. if ps.field.Tag == nil {
  106. return nil, nil
  107. }
  108. typeTag := ps.tag.Get(swaggerTypeTag)
  109. if typeTag != "" {
  110. return BuildCustomSchema(strings.Split(typeTag, ","))
  111. }
  112. return nil, nil
  113. }
  114. type structField struct {
  115. schemaType string
  116. arrayType string
  117. formatType string
  118. maximum *float64
  119. minimum *float64
  120. multipleOf *float64
  121. maxLength *int64
  122. minLength *int64
  123. maxItems *int64
  124. minItems *int64
  125. exampleValue interface{}
  126. enums []interface{}
  127. enumVarNames []interface{}
  128. unique bool
  129. }
  130. // splitNotWrapped slices s into all substrings separated by sep if sep is not
  131. // wrapped by brackets and returns a slice of the substrings between those separators.
  132. func splitNotWrapped(s string, sep rune) []string {
  133. openCloseMap := map[rune]rune{
  134. '(': ')',
  135. '[': ']',
  136. '{': '}',
  137. }
  138. var (
  139. result = make([]string, 0)
  140. current = strings.Builder{}
  141. openCount = 0
  142. openChar rune
  143. )
  144. for _, char := range s {
  145. switch {
  146. case openChar == 0 && openCloseMap[char] != 0:
  147. openChar = char
  148. openCount++
  149. current.WriteRune(char)
  150. case char == openChar:
  151. openCount++
  152. current.WriteRune(char)
  153. case openCount > 0 && char == openCloseMap[openChar]:
  154. openCount--
  155. current.WriteRune(char)
  156. case openCount == 0 && char == sep:
  157. result = append(result, current.String())
  158. openChar = 0
  159. current = strings.Builder{}
  160. default:
  161. current.WriteRune(char)
  162. }
  163. }
  164. if current.String() != "" {
  165. result = append(result, current.String())
  166. }
  167. return result
  168. }
  169. func (ps *tagBaseFieldParser) ComplementSchema(schema *spec.Schema) error {
  170. types := ps.p.GetSchemaTypePath(schema, 2)
  171. if len(types) == 0 {
  172. return fmt.Errorf("invalid type for field: %s", ps.field.Names[0])
  173. }
  174. if ps.field.Tag == nil {
  175. if ps.field.Doc != nil {
  176. schema.Description = strings.TrimSpace(ps.field.Doc.Text())
  177. }
  178. if schema.Description == "" && ps.field.Comment != nil {
  179. schema.Description = strings.TrimSpace(ps.field.Comment.Text())
  180. }
  181. return nil
  182. }
  183. field := &structField{
  184. schemaType: types[0],
  185. formatType: ps.tag.Get(formatTag),
  186. }
  187. if len(types) > 1 && (types[0] == ARRAY || types[0] == OBJECT) {
  188. field.arrayType = types[1]
  189. }
  190. jsonTagValue := ps.tag.Get(jsonTag)
  191. bindingTagValue := ps.tag.Get(bindingTag)
  192. if bindingTagValue != "" {
  193. parseValidTags(bindingTagValue, field)
  194. }
  195. validateTagValue := ps.tag.Get(validateTag)
  196. if validateTagValue != "" {
  197. parseValidTags(validateTagValue, field)
  198. }
  199. enumsTagValue := ps.tag.Get(enumsTag)
  200. if enumsTagValue != "" {
  201. err := parseEnumTags(enumsTagValue, field)
  202. if err != nil {
  203. return err
  204. }
  205. }
  206. if IsNumericType(field.schemaType) || IsNumericType(field.arrayType) {
  207. maximum, err := getFloatTag(ps.tag, maximumTag)
  208. if err != nil {
  209. return err
  210. }
  211. if maximum != nil {
  212. field.maximum = maximum
  213. }
  214. minimum, err := getFloatTag(ps.tag, minimumTag)
  215. if err != nil {
  216. return err
  217. }
  218. if minimum != nil {
  219. field.minimum = minimum
  220. }
  221. multipleOf, err := getFloatTag(ps.tag, multipleOfTag)
  222. if err != nil {
  223. return err
  224. }
  225. if multipleOf != nil {
  226. field.multipleOf = multipleOf
  227. }
  228. }
  229. if field.schemaType == STRING || field.arrayType == STRING {
  230. maxLength, err := getIntTag(ps.tag, maxLengthTag)
  231. if err != nil {
  232. return err
  233. }
  234. if maxLength != nil {
  235. field.maxLength = maxLength
  236. }
  237. minLength, err := getIntTag(ps.tag, minLengthTag)
  238. if err != nil {
  239. return err
  240. }
  241. if minLength != nil {
  242. field.minLength = minLength
  243. }
  244. }
  245. // json:"name,string" or json:",string"
  246. exampleTagValue, ok := ps.tag.Lookup(exampleTag)
  247. if ok {
  248. field.exampleValue = exampleTagValue
  249. if !strings.Contains(jsonTagValue, ",string") {
  250. example, err := defineTypeOfExample(field.schemaType, field.arrayType, exampleTagValue)
  251. if err != nil {
  252. return err
  253. }
  254. field.exampleValue = example
  255. }
  256. }
  257. // perform this after setting everything else (min, max, etc...)
  258. if strings.Contains(jsonTagValue, ",string") {
  259. // @encoding/json: "It applies only to fields of string, floating point, integer, or boolean types."
  260. defaultValues := map[string]string{
  261. // Zero Values as string
  262. STRING: "",
  263. INTEGER: "0",
  264. BOOLEAN: "false",
  265. NUMBER: "0",
  266. }
  267. defaultValue, ok := defaultValues[field.schemaType]
  268. if ok {
  269. field.schemaType = STRING
  270. *schema = *PrimitiveSchema(field.schemaType)
  271. if field.exampleValue == nil {
  272. // if exampleValue is not defined by the user,
  273. // we will force an example with a correct value
  274. // (eg: int->"0", bool:"false")
  275. field.exampleValue = defaultValue
  276. }
  277. }
  278. }
  279. if ps.field.Doc != nil {
  280. schema.Description = strings.TrimSpace(ps.field.Doc.Text())
  281. }
  282. if schema.Description == "" && ps.field.Comment != nil {
  283. schema.Description = strings.TrimSpace(ps.field.Comment.Text())
  284. }
  285. schema.ReadOnly = ps.tag.Get(readOnlyTag) == "true"
  286. if !reflect.ValueOf(schema.Ref).IsZero() && schema.ReadOnly {
  287. schema.AllOf = []spec.Schema{*spec.RefSchema(schema.Ref.String())}
  288. schema.Ref = spec.Ref{
  289. Ref: jsonreference.Ref{
  290. HasFullURL: false,
  291. HasURLPathOnly: false,
  292. HasFragmentOnly: false,
  293. HasFileScheme: false,
  294. HasFullFilePath: false,
  295. },
  296. } // clear out existing ref
  297. }
  298. defaultTagValue := ps.tag.Get(defaultTag)
  299. if defaultTagValue != "" {
  300. value, err := defineType(field.schemaType, defaultTagValue)
  301. if err != nil {
  302. return err
  303. }
  304. schema.Default = value
  305. }
  306. schema.Example = field.exampleValue
  307. if field.schemaType != ARRAY {
  308. schema.Format = field.formatType
  309. }
  310. extensionsTagValue := ps.tag.Get(extensionsTag)
  311. if extensionsTagValue != "" {
  312. schema.Extensions = setExtensionParam(extensionsTagValue)
  313. }
  314. varNamesTag := ps.tag.Get("x-enum-varnames")
  315. if varNamesTag != "" {
  316. varNames := strings.Split(varNamesTag, ",")
  317. if len(varNames) != len(field.enums) {
  318. return fmt.Errorf("invalid count of x-enum-varnames. expected %d, got %d", len(field.enums), len(varNames))
  319. }
  320. field.enumVarNames = nil
  321. for _, v := range varNames {
  322. field.enumVarNames = append(field.enumVarNames, v)
  323. }
  324. if field.schemaType == ARRAY {
  325. // Add the var names in the items schema
  326. if schema.Items.Schema.Extensions == nil {
  327. schema.Items.Schema.Extensions = map[string]interface{}{}
  328. }
  329. schema.Items.Schema.Extensions["x-enum-varnames"] = field.enumVarNames
  330. } else {
  331. // Add to top level schema
  332. if schema.Extensions == nil {
  333. schema.Extensions = map[string]interface{}{}
  334. }
  335. schema.Extensions["x-enum-varnames"] = field.enumVarNames
  336. }
  337. }
  338. eleSchema := schema
  339. if field.schemaType == ARRAY {
  340. // For Array only
  341. schema.MaxItems = field.maxItems
  342. schema.MinItems = field.minItems
  343. schema.UniqueItems = field.unique
  344. eleSchema = schema.Items.Schema
  345. eleSchema.Format = field.formatType
  346. }
  347. eleSchema.Maximum = field.maximum
  348. eleSchema.Minimum = field.minimum
  349. eleSchema.MultipleOf = field.multipleOf
  350. eleSchema.MaxLength = field.maxLength
  351. eleSchema.MinLength = field.minLength
  352. eleSchema.Enum = field.enums
  353. return nil
  354. }
  355. func getFloatTag(structTag reflect.StructTag, tagName string) (*float64, error) {
  356. strValue := structTag.Get(tagName)
  357. if strValue == "" {
  358. return nil, nil
  359. }
  360. value, err := strconv.ParseFloat(strValue, 64)
  361. if err != nil {
  362. return nil, fmt.Errorf("can't parse numeric value of %q tag: %v", tagName, err)
  363. }
  364. return &value, nil
  365. }
  366. func getIntTag(structTag reflect.StructTag, tagName string) (*int64, error) {
  367. strValue := structTag.Get(tagName)
  368. if strValue == "" {
  369. return nil, nil
  370. }
  371. value, err := strconv.ParseInt(strValue, 10, 64)
  372. if err != nil {
  373. return nil, fmt.Errorf("can't parse numeric value of %q tag: %v", tagName, err)
  374. }
  375. return &value, nil
  376. }
  377. func (ps *tagBaseFieldParser) IsRequired() (bool, error) {
  378. if ps.field.Tag == nil {
  379. return false, nil
  380. }
  381. bindingTag := ps.tag.Get(bindingTag)
  382. if bindingTag != "" {
  383. for _, val := range strings.Split(bindingTag, ",") {
  384. switch val {
  385. case requiredLabel:
  386. return true, nil
  387. case optionalLabel:
  388. return false, nil
  389. }
  390. }
  391. }
  392. validateTag := ps.tag.Get(validateTag)
  393. if validateTag != "" {
  394. for _, val := range strings.Split(validateTag, ",") {
  395. switch val {
  396. case requiredLabel:
  397. return true, nil
  398. case optionalLabel:
  399. return false, nil
  400. }
  401. }
  402. }
  403. return ps.p.RequiredByDefault, nil
  404. }
  405. func parseValidTags(validTag string, sf *structField) {
  406. // `validate:"required,max=10,min=1"`
  407. // ps. required checked by IsRequired().
  408. for _, val := range strings.Split(validTag, ",") {
  409. var (
  410. valValue string
  411. keyVal = strings.Split(val, "=")
  412. )
  413. switch len(keyVal) {
  414. case 1:
  415. case 2:
  416. valValue = strings.ReplaceAll(strings.ReplaceAll(keyVal[1], utf8HexComma, ","), utf8Pipe, "|")
  417. default:
  418. continue
  419. }
  420. switch keyVal[0] {
  421. case "max", "lte":
  422. sf.setMax(valValue)
  423. case "min", "gte":
  424. sf.setMin(valValue)
  425. case "oneof":
  426. sf.setOneOf(valValue)
  427. case "unique":
  428. if sf.schemaType == ARRAY {
  429. sf.unique = true
  430. }
  431. case "dive":
  432. // ignore dive
  433. return
  434. default:
  435. continue
  436. }
  437. }
  438. }
  439. func parseEnumTags(enumTag string, field *structField) error {
  440. enumType := field.schemaType
  441. if field.schemaType == ARRAY {
  442. enumType = field.arrayType
  443. }
  444. field.enums = nil
  445. for _, e := range strings.Split(enumTag, ",") {
  446. value, err := defineType(enumType, e)
  447. if err != nil {
  448. return err
  449. }
  450. field.enums = append(field.enums, value)
  451. }
  452. return nil
  453. }
  454. func (sf *structField) setOneOf(valValue string) {
  455. if len(sf.enums) != 0 {
  456. return
  457. }
  458. enumType := sf.schemaType
  459. if sf.schemaType == ARRAY {
  460. enumType = sf.arrayType
  461. }
  462. valValues := parseOneOfParam2(valValue)
  463. for i := range valValues {
  464. value, err := defineType(enumType, valValues[i])
  465. if err != nil {
  466. continue
  467. }
  468. sf.enums = append(sf.enums, value)
  469. }
  470. }
  471. func (sf *structField) setMin(valValue string) {
  472. value, err := strconv.ParseFloat(valValue, 64)
  473. if err != nil {
  474. return
  475. }
  476. switch sf.schemaType {
  477. case INTEGER, NUMBER:
  478. sf.minimum = &value
  479. case STRING:
  480. intValue := int64(value)
  481. sf.minLength = &intValue
  482. case ARRAY:
  483. intValue := int64(value)
  484. sf.minItems = &intValue
  485. }
  486. }
  487. func (sf *structField) setMax(valValue string) {
  488. value, err := strconv.ParseFloat(valValue, 64)
  489. if err != nil {
  490. return
  491. }
  492. switch sf.schemaType {
  493. case INTEGER, NUMBER:
  494. sf.maximum = &value
  495. case STRING:
  496. intValue := int64(value)
  497. sf.maxLength = &intValue
  498. case ARRAY:
  499. intValue := int64(value)
  500. sf.maxItems = &intValue
  501. }
  502. }
  503. const (
  504. utf8HexComma = "0x2C"
  505. utf8Pipe = "0x7C"
  506. )
  507. // These code copy from
  508. // https://github.com/go-playground/validator/blob/d4271985b44b735c6f76abc7a06532ee997f9476/baked_in.go#L207
  509. // ---.
  510. var oneofValsCache = map[string][]string{}
  511. var oneofValsCacheRWLock = sync.RWMutex{}
  512. var splitParamsRegex = regexp.MustCompile(`'[^']*'|\S+`)
  513. func parseOneOfParam2(param string) []string {
  514. oneofValsCacheRWLock.RLock()
  515. values, ok := oneofValsCache[param]
  516. oneofValsCacheRWLock.RUnlock()
  517. if !ok {
  518. oneofValsCacheRWLock.Lock()
  519. values = splitParamsRegex.FindAllString(param, -1)
  520. for i := 0; i < len(values); i++ {
  521. values[i] = strings.ReplaceAll(values[i], "'", "")
  522. }
  523. oneofValsCache[param] = values
  524. oneofValsCacheRWLock.Unlock()
  525. }
  526. return values
  527. }
  528. // ---.