compiler.go 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935
  1. package encoder
  2. import (
  3. "context"
  4. "encoding"
  5. "encoding/json"
  6. "reflect"
  7. "sync/atomic"
  8. "unsafe"
  9. "github.com/goccy/go-json/internal/errors"
  10. "github.com/goccy/go-json/internal/runtime"
  11. )
  12. type marshalerContext interface {
  13. MarshalJSON(context.Context) ([]byte, error)
  14. }
  15. var (
  16. marshalJSONType = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
  17. marshalJSONContextType = reflect.TypeOf((*marshalerContext)(nil)).Elem()
  18. marshalTextType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
  19. jsonNumberType = reflect.TypeOf(json.Number(""))
  20. cachedOpcodeSets []*OpcodeSet
  21. cachedOpcodeMap unsafe.Pointer // map[uintptr]*OpcodeSet
  22. typeAddr *runtime.TypeAddr
  23. )
  24. func init() {
  25. typeAddr = runtime.AnalyzeTypeAddr()
  26. if typeAddr == nil {
  27. typeAddr = &runtime.TypeAddr{}
  28. }
  29. cachedOpcodeSets = make([]*OpcodeSet, typeAddr.AddrRange>>typeAddr.AddrShift+1)
  30. }
  31. func loadOpcodeMap() map[uintptr]*OpcodeSet {
  32. p := atomic.LoadPointer(&cachedOpcodeMap)
  33. return *(*map[uintptr]*OpcodeSet)(unsafe.Pointer(&p))
  34. }
  35. func storeOpcodeSet(typ uintptr, set *OpcodeSet, m map[uintptr]*OpcodeSet) {
  36. newOpcodeMap := make(map[uintptr]*OpcodeSet, len(m)+1)
  37. newOpcodeMap[typ] = set
  38. for k, v := range m {
  39. newOpcodeMap[k] = v
  40. }
  41. atomic.StorePointer(&cachedOpcodeMap, *(*unsafe.Pointer)(unsafe.Pointer(&newOpcodeMap)))
  42. }
  43. func compileToGetCodeSetSlowPath(typeptr uintptr) (*OpcodeSet, error) {
  44. opcodeMap := loadOpcodeMap()
  45. if codeSet, exists := opcodeMap[typeptr]; exists {
  46. return codeSet, nil
  47. }
  48. codeSet, err := newCompiler().compile(typeptr)
  49. if err != nil {
  50. return nil, err
  51. }
  52. storeOpcodeSet(typeptr, codeSet, opcodeMap)
  53. return codeSet, nil
  54. }
  55. func getFilteredCodeSetIfNeeded(ctx *RuntimeContext, codeSet *OpcodeSet) (*OpcodeSet, error) {
  56. if (ctx.Option.Flag & ContextOption) == 0 {
  57. return codeSet, nil
  58. }
  59. query := FieldQueryFromContext(ctx.Option.Context)
  60. if query == nil {
  61. return codeSet, nil
  62. }
  63. ctx.Option.Flag |= FieldQueryOption
  64. cacheCodeSet := codeSet.getQueryCache(query.Hash())
  65. if cacheCodeSet != nil {
  66. return cacheCodeSet, nil
  67. }
  68. queryCodeSet, err := newCompiler().codeToOpcodeSet(codeSet.Type, codeSet.Code.Filter(query))
  69. if err != nil {
  70. return nil, err
  71. }
  72. codeSet.setQueryCache(query.Hash(), queryCodeSet)
  73. return queryCodeSet, nil
  74. }
  75. type Compiler struct {
  76. structTypeToCode map[uintptr]*StructCode
  77. }
  78. func newCompiler() *Compiler {
  79. return &Compiler{
  80. structTypeToCode: map[uintptr]*StructCode{},
  81. }
  82. }
  83. func (c *Compiler) compile(typeptr uintptr) (*OpcodeSet, error) {
  84. // noescape trick for header.typ ( reflect.*rtype )
  85. typ := *(**runtime.Type)(unsafe.Pointer(&typeptr))
  86. code, err := c.typeToCode(typ)
  87. if err != nil {
  88. return nil, err
  89. }
  90. return c.codeToOpcodeSet(typ, code)
  91. }
  92. func (c *Compiler) codeToOpcodeSet(typ *runtime.Type, code Code) (*OpcodeSet, error) {
  93. noescapeKeyCode := c.codeToOpcode(&compileContext{
  94. structTypeToCodes: map[uintptr]Opcodes{},
  95. recursiveCodes: &Opcodes{},
  96. }, typ, code)
  97. if err := noescapeKeyCode.Validate(); err != nil {
  98. return nil, err
  99. }
  100. escapeKeyCode := c.codeToOpcode(&compileContext{
  101. structTypeToCodes: map[uintptr]Opcodes{},
  102. recursiveCodes: &Opcodes{},
  103. escapeKey: true,
  104. }, typ, code)
  105. noescapeKeyCode = copyOpcode(noescapeKeyCode)
  106. escapeKeyCode = copyOpcode(escapeKeyCode)
  107. setTotalLengthToInterfaceOp(noescapeKeyCode)
  108. setTotalLengthToInterfaceOp(escapeKeyCode)
  109. interfaceNoescapeKeyCode := copyToInterfaceOpcode(noescapeKeyCode)
  110. interfaceEscapeKeyCode := copyToInterfaceOpcode(escapeKeyCode)
  111. codeLength := noescapeKeyCode.TotalLength()
  112. return &OpcodeSet{
  113. Type: typ,
  114. NoescapeKeyCode: noescapeKeyCode,
  115. EscapeKeyCode: escapeKeyCode,
  116. InterfaceNoescapeKeyCode: interfaceNoescapeKeyCode,
  117. InterfaceEscapeKeyCode: interfaceEscapeKeyCode,
  118. CodeLength: codeLength,
  119. EndCode: ToEndCode(interfaceNoescapeKeyCode),
  120. Code: code,
  121. QueryCache: map[string]*OpcodeSet{},
  122. }, nil
  123. }
// typeToCode builds the Code tree for the top-level type typ.
//
// Custom marshalers take priority. typ itself is checked first. Then one
// pointer level is stripped and the element is checked; if it matches, the
// marshaler code is still built for the original type orgType. Otherwise the
// code is chosen by the element's kind, and isPtr records whether typ was a
// pointer. Kinds not listed in the switch are delegated to typeToCodeWithPtr.
func (c *Compiler) typeToCode(typ *runtime.Type) (Code, error) {
	switch {
	case c.implementsMarshalJSON(typ):
		return c.marshalJSONCode(typ)
	case c.implementsMarshalText(typ):
		return c.marshalTextCode(typ)
	}

	// Strip a single pointer level; orgType keeps the type as given.
	isPtr := false
	orgType := typ
	if typ.Kind() == reflect.Ptr {
		typ = typ.Elem()
		isPtr = true
	}
	switch {
	case c.implementsMarshalJSON(typ):
		return c.marshalJSONCode(orgType)
	case c.implementsMarshalText(typ):
		return c.marshalTextCode(orgType)
	}
	switch typ.Kind() {
	case reflect.Slice:
		elem := typ.Elem()
		if elem.Kind() == reflect.Uint8 {
			// Byte slices use bytesCode unless *elem implements a JSON or
			// text marshaler.
			p := runtime.PtrTo(elem)
			if !c.implementsMarshalJSONType(p) && !p.Implements(marshalTextType) {
				return c.bytesCode(typ, isPtr)
			}
		}
		return c.sliceCode(typ)
	case reflect.Map:
		if isPtr {
			// Pointer to map: compile the pointer type through ptrCode.
			return c.ptrCode(runtime.PtrTo(typ))
		}
		return c.mapCode(typ)
	case reflect.Struct:
		return c.structCode(typ, isPtr)
	case reflect.Int:
		return c.intCode(typ, isPtr)
	case reflect.Int8:
		return c.int8Code(typ, isPtr)
	case reflect.Int16:
		return c.int16Code(typ, isPtr)
	case reflect.Int32:
		return c.int32Code(typ, isPtr)
	case reflect.Int64:
		return c.int64Code(typ, isPtr)
	case reflect.Uint, reflect.Uintptr:
		return c.uintCode(typ, isPtr)
	case reflect.Uint8:
		return c.uint8Code(typ, isPtr)
	case reflect.Uint16:
		return c.uint16Code(typ, isPtr)
	case reflect.Uint32:
		return c.uint32Code(typ, isPtr)
	case reflect.Uint64:
		return c.uint64Code(typ, isPtr)
	case reflect.Float32:
		return c.float32Code(typ, isPtr)
	case reflect.Float64:
		return c.float64Code(typ, isPtr)
	case reflect.String:
		return c.stringCode(typ, isPtr)
	case reflect.Bool:
		return c.boolCode(typ, isPtr)
	case reflect.Interface:
		return c.interfaceCode(typ, isPtr)
	default:
		// At this point typ can implement encoding.TextMarshaler only if it is
		// itself a pointer whose element also implements it (that is why
		// implementsMarshalText rejected it above). In that case the original
		// type is compiled instead.
		if isPtr && typ.Implements(marshalTextType) {
			typ = orgType
		}
		return c.typeToCodeWithPtr(typ, isPtr)
	}
}
  197. func (c *Compiler) typeToCodeWithPtr(typ *runtime.Type, isPtr bool) (Code, error) {
  198. switch {
  199. case c.implementsMarshalJSON(typ):
  200. return c.marshalJSONCode(typ)
  201. case c.implementsMarshalText(typ):
  202. return c.marshalTextCode(typ)
  203. }
  204. switch typ.Kind() {
  205. case reflect.Ptr:
  206. return c.ptrCode(typ)
  207. case reflect.Slice:
  208. elem := typ.Elem()
  209. if elem.Kind() == reflect.Uint8 {
  210. p := runtime.PtrTo(elem)
  211. if !c.implementsMarshalJSONType(p) && !p.Implements(marshalTextType) {
  212. return c.bytesCode(typ, false)
  213. }
  214. }
  215. return c.sliceCode(typ)
  216. case reflect.Array:
  217. return c.arrayCode(typ)
  218. case reflect.Map:
  219. return c.mapCode(typ)
  220. case reflect.Struct:
  221. return c.structCode(typ, isPtr)
  222. case reflect.Interface:
  223. return c.interfaceCode(typ, false)
  224. case reflect.Int:
  225. return c.intCode(typ, false)
  226. case reflect.Int8:
  227. return c.int8Code(typ, false)
  228. case reflect.Int16:
  229. return c.int16Code(typ, false)
  230. case reflect.Int32:
  231. return c.int32Code(typ, false)
  232. case reflect.Int64:
  233. return c.int64Code(typ, false)
  234. case reflect.Uint:
  235. return c.uintCode(typ, false)
  236. case reflect.Uint8:
  237. return c.uint8Code(typ, false)
  238. case reflect.Uint16:
  239. return c.uint16Code(typ, false)
  240. case reflect.Uint32:
  241. return c.uint32Code(typ, false)
  242. case reflect.Uint64:
  243. return c.uint64Code(typ, false)
  244. case reflect.Uintptr:
  245. return c.uintCode(typ, false)
  246. case reflect.Float32:
  247. return c.float32Code(typ, false)
  248. case reflect.Float64:
  249. return c.float64Code(typ, false)
  250. case reflect.String:
  251. return c.stringCode(typ, false)
  252. case reflect.Bool:
  253. return c.boolCode(typ, false)
  254. }
  255. return nil, &errors.UnsupportedTypeError{Type: runtime.RType2Type(typ)}
  256. }
  257. const intSize = 32 << (^uint(0) >> 63)
  258. //nolint:unparam
  259. func (c *Compiler) intCode(typ *runtime.Type, isPtr bool) (*IntCode, error) {
  260. return &IntCode{typ: typ, bitSize: intSize, isPtr: isPtr}, nil
  261. }
  262. //nolint:unparam
  263. func (c *Compiler) int8Code(typ *runtime.Type, isPtr bool) (*IntCode, error) {
  264. return &IntCode{typ: typ, bitSize: 8, isPtr: isPtr}, nil
  265. }
  266. //nolint:unparam
  267. func (c *Compiler) int16Code(typ *runtime.Type, isPtr bool) (*IntCode, error) {
  268. return &IntCode{typ: typ, bitSize: 16, isPtr: isPtr}, nil
  269. }
  270. //nolint:unparam
  271. func (c *Compiler) int32Code(typ *runtime.Type, isPtr bool) (*IntCode, error) {
  272. return &IntCode{typ: typ, bitSize: 32, isPtr: isPtr}, nil
  273. }
  274. //nolint:unparam
  275. func (c *Compiler) int64Code(typ *runtime.Type, isPtr bool) (*IntCode, error) {
  276. return &IntCode{typ: typ, bitSize: 64, isPtr: isPtr}, nil
  277. }
  278. //nolint:unparam
  279. func (c *Compiler) uintCode(typ *runtime.Type, isPtr bool) (*UintCode, error) {
  280. return &UintCode{typ: typ, bitSize: intSize, isPtr: isPtr}, nil
  281. }
  282. //nolint:unparam
  283. func (c *Compiler) uint8Code(typ *runtime.Type, isPtr bool) (*UintCode, error) {
  284. return &UintCode{typ: typ, bitSize: 8, isPtr: isPtr}, nil
  285. }
  286. //nolint:unparam
  287. func (c *Compiler) uint16Code(typ *runtime.Type, isPtr bool) (*UintCode, error) {
  288. return &UintCode{typ: typ, bitSize: 16, isPtr: isPtr}, nil
  289. }
  290. //nolint:unparam
  291. func (c *Compiler) uint32Code(typ *runtime.Type, isPtr bool) (*UintCode, error) {
  292. return &UintCode{typ: typ, bitSize: 32, isPtr: isPtr}, nil
  293. }
  294. //nolint:unparam
  295. func (c *Compiler) uint64Code(typ *runtime.Type, isPtr bool) (*UintCode, error) {
  296. return &UintCode{typ: typ, bitSize: 64, isPtr: isPtr}, nil
  297. }
  298. //nolint:unparam
  299. func (c *Compiler) float32Code(typ *runtime.Type, isPtr bool) (*FloatCode, error) {
  300. return &FloatCode{typ: typ, bitSize: 32, isPtr: isPtr}, nil
  301. }
  302. //nolint:unparam
  303. func (c *Compiler) float64Code(typ *runtime.Type, isPtr bool) (*FloatCode, error) {
  304. return &FloatCode{typ: typ, bitSize: 64, isPtr: isPtr}, nil
  305. }
  306. //nolint:unparam
  307. func (c *Compiler) stringCode(typ *runtime.Type, isPtr bool) (*StringCode, error) {
  308. return &StringCode{typ: typ, isPtr: isPtr}, nil
  309. }
  310. //nolint:unparam
  311. func (c *Compiler) boolCode(typ *runtime.Type, isPtr bool) (*BoolCode, error) {
  312. return &BoolCode{typ: typ, isPtr: isPtr}, nil
  313. }
  314. //nolint:unparam
  315. func (c *Compiler) intStringCode(typ *runtime.Type) (*IntCode, error) {
  316. return &IntCode{typ: typ, bitSize: intSize, isString: true}, nil
  317. }
  318. //nolint:unparam
  319. func (c *Compiler) int8StringCode(typ *runtime.Type) (*IntCode, error) {
  320. return &IntCode{typ: typ, bitSize: 8, isString: true}, nil
  321. }
  322. //nolint:unparam
  323. func (c *Compiler) int16StringCode(typ *runtime.Type) (*IntCode, error) {
  324. return &IntCode{typ: typ, bitSize: 16, isString: true}, nil
  325. }
  326. //nolint:unparam
  327. func (c *Compiler) int32StringCode(typ *runtime.Type) (*IntCode, error) {
  328. return &IntCode{typ: typ, bitSize: 32, isString: true}, nil
  329. }
  330. //nolint:unparam
  331. func (c *Compiler) int64StringCode(typ *runtime.Type) (*IntCode, error) {
  332. return &IntCode{typ: typ, bitSize: 64, isString: true}, nil
  333. }
  334. //nolint:unparam
  335. func (c *Compiler) uintStringCode(typ *runtime.Type) (*UintCode, error) {
  336. return &UintCode{typ: typ, bitSize: intSize, isString: true}, nil
  337. }
  338. //nolint:unparam
  339. func (c *Compiler) uint8StringCode(typ *runtime.Type) (*UintCode, error) {
  340. return &UintCode{typ: typ, bitSize: 8, isString: true}, nil
  341. }
  342. //nolint:unparam
  343. func (c *Compiler) uint16StringCode(typ *runtime.Type) (*UintCode, error) {
  344. return &UintCode{typ: typ, bitSize: 16, isString: true}, nil
  345. }
  346. //nolint:unparam
  347. func (c *Compiler) uint32StringCode(typ *runtime.Type) (*UintCode, error) {
  348. return &UintCode{typ: typ, bitSize: 32, isString: true}, nil
  349. }
  350. //nolint:unparam
  351. func (c *Compiler) uint64StringCode(typ *runtime.Type) (*UintCode, error) {
  352. return &UintCode{typ: typ, bitSize: 64, isString: true}, nil
  353. }
  354. //nolint:unparam
  355. func (c *Compiler) bytesCode(typ *runtime.Type, isPtr bool) (*BytesCode, error) {
  356. return &BytesCode{typ: typ, isPtr: isPtr}, nil
  357. }
  358. //nolint:unparam
  359. func (c *Compiler) interfaceCode(typ *runtime.Type, isPtr bool) (*InterfaceCode, error) {
  360. return &InterfaceCode{typ: typ, isPtr: isPtr}, nil
  361. }
  362. //nolint:unparam
  363. func (c *Compiler) marshalJSONCode(typ *runtime.Type) (*MarshalJSONCode, error) {
  364. return &MarshalJSONCode{
  365. typ: typ,
  366. isAddrForMarshaler: c.isPtrMarshalJSONType(typ),
  367. isNilableType: c.isNilableType(typ),
  368. isMarshalerContext: typ.Implements(marshalJSONContextType) || runtime.PtrTo(typ).Implements(marshalJSONContextType),
  369. }, nil
  370. }
  371. //nolint:unparam
  372. func (c *Compiler) marshalTextCode(typ *runtime.Type) (*MarshalTextCode, error) {
  373. return &MarshalTextCode{
  374. typ: typ,
  375. isAddrForMarshaler: c.isPtrMarshalTextType(typ),
  376. isNilableType: c.isNilableType(typ),
  377. }, nil
  378. }
  379. func (c *Compiler) ptrCode(typ *runtime.Type) (*PtrCode, error) {
  380. code, err := c.typeToCodeWithPtr(typ.Elem(), true)
  381. if err != nil {
  382. return nil, err
  383. }
  384. ptr, ok := code.(*PtrCode)
  385. if ok {
  386. return &PtrCode{typ: typ, value: ptr.value, ptrNum: ptr.ptrNum + 1}, nil
  387. }
  388. return &PtrCode{typ: typ, value: code, ptrNum: 1}, nil
  389. }
  390. func (c *Compiler) sliceCode(typ *runtime.Type) (*SliceCode, error) {
  391. elem := typ.Elem()
  392. code, err := c.listElemCode(elem)
  393. if err != nil {
  394. return nil, err
  395. }
  396. if code.Kind() == CodeKindStruct {
  397. structCode := code.(*StructCode)
  398. structCode.enableIndirect()
  399. }
  400. return &SliceCode{typ: typ, value: code}, nil
  401. }
  402. func (c *Compiler) arrayCode(typ *runtime.Type) (*ArrayCode, error) {
  403. elem := typ.Elem()
  404. code, err := c.listElemCode(elem)
  405. if err != nil {
  406. return nil, err
  407. }
  408. if code.Kind() == CodeKindStruct {
  409. structCode := code.(*StructCode)
  410. structCode.enableIndirect()
  411. }
  412. return &ArrayCode{typ: typ, value: code}, nil
  413. }
  414. func (c *Compiler) mapCode(typ *runtime.Type) (*MapCode, error) {
  415. keyCode, err := c.mapKeyCode(typ.Key())
  416. if err != nil {
  417. return nil, err
  418. }
  419. valueCode, err := c.mapValueCode(typ.Elem())
  420. if err != nil {
  421. return nil, err
  422. }
  423. if valueCode.Kind() == CodeKindStruct {
  424. structCode := valueCode.(*StructCode)
  425. structCode.enableIndirect()
  426. }
  427. return &MapCode{typ: typ, key: keyCode, value: valueCode}, nil
  428. }
  429. func (c *Compiler) listElemCode(typ *runtime.Type) (Code, error) {
  430. switch {
  431. case c.isPtrMarshalJSONType(typ):
  432. return c.marshalJSONCode(typ)
  433. case !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType):
  434. return c.marshalTextCode(typ)
  435. case typ.Kind() == reflect.Map:
  436. return c.ptrCode(runtime.PtrTo(typ))
  437. default:
  438. // isPtr was originally used to indicate whether the type of top level is pointer.
  439. // However, since the slice/array element is a specification that can get the pointer address, explicitly set isPtr to true.
  440. // See here for related issues: https://github.com/goccy/go-json/issues/370
  441. code, err := c.typeToCodeWithPtr(typ, true)
  442. if err != nil {
  443. return nil, err
  444. }
  445. ptr, ok := code.(*PtrCode)
  446. if ok {
  447. if ptr.value.Kind() == CodeKindMap {
  448. ptr.ptrNum++
  449. }
  450. }
  451. return code, nil
  452. }
  453. }
  454. func (c *Compiler) mapKeyCode(typ *runtime.Type) (Code, error) {
  455. switch {
  456. case c.implementsMarshalText(typ):
  457. return c.marshalTextCode(typ)
  458. }
  459. switch typ.Kind() {
  460. case reflect.Ptr:
  461. return c.ptrCode(typ)
  462. case reflect.String:
  463. return c.stringCode(typ, false)
  464. case reflect.Int:
  465. return c.intStringCode(typ)
  466. case reflect.Int8:
  467. return c.int8StringCode(typ)
  468. case reflect.Int16:
  469. return c.int16StringCode(typ)
  470. case reflect.Int32:
  471. return c.int32StringCode(typ)
  472. case reflect.Int64:
  473. return c.int64StringCode(typ)
  474. case reflect.Uint:
  475. return c.uintStringCode(typ)
  476. case reflect.Uint8:
  477. return c.uint8StringCode(typ)
  478. case reflect.Uint16:
  479. return c.uint16StringCode(typ)
  480. case reflect.Uint32:
  481. return c.uint32StringCode(typ)
  482. case reflect.Uint64:
  483. return c.uint64StringCode(typ)
  484. case reflect.Uintptr:
  485. return c.uintStringCode(typ)
  486. }
  487. return nil, &errors.UnsupportedTypeError{Type: runtime.RType2Type(typ)}
  488. }
  489. func (c *Compiler) mapValueCode(typ *runtime.Type) (Code, error) {
  490. switch typ.Kind() {
  491. case reflect.Map:
  492. return c.ptrCode(runtime.PtrTo(typ))
  493. default:
  494. code, err := c.typeToCodeWithPtr(typ, false)
  495. if err != nil {
  496. return nil, err
  497. }
  498. ptr, ok := code.(*PtrCode)
  499. if ok {
  500. if ptr.value.Kind() == CodeKindMap {
  501. ptr.ptrNum++
  502. }
  503. }
  504. return code, nil
  505. }
  506. }
  507. func (c *Compiler) structCode(typ *runtime.Type, isPtr bool) (*StructCode, error) {
  508. typeptr := uintptr(unsafe.Pointer(typ))
  509. if code, exists := c.structTypeToCode[typeptr]; exists {
  510. derefCode := *code
  511. derefCode.isRecursive = true
  512. return &derefCode, nil
  513. }
  514. indirect := runtime.IfaceIndir(typ)
  515. code := &StructCode{typ: typ, isPtr: isPtr, isIndirect: indirect}
  516. c.structTypeToCode[typeptr] = code
  517. fieldNum := typ.NumField()
  518. tags := c.typeToStructTags(typ)
  519. fields := []*StructFieldCode{}
  520. for i, tag := range tags {
  521. isOnlyOneFirstField := i == 0 && fieldNum == 1
  522. field, err := c.structFieldCode(code, tag, isPtr, isOnlyOneFirstField)
  523. if err != nil {
  524. return nil, err
  525. }
  526. if field.isAnonymous {
  527. structCode := field.getAnonymousStruct()
  528. if structCode != nil {
  529. structCode.removeFieldsByTags(tags)
  530. if c.isAssignableIndirect(field, isPtr) {
  531. if indirect {
  532. structCode.isIndirect = true
  533. } else {
  534. structCode.isIndirect = false
  535. }
  536. }
  537. }
  538. } else {
  539. structCode := field.getStruct()
  540. if structCode != nil {
  541. if indirect {
  542. // if parent is indirect type, set child indirect property to true
  543. structCode.isIndirect = true
  544. } else {
  545. // if parent is not indirect type, set child indirect property to false.
  546. // but if parent's indirect is false and isPtr is true, then indirect must be true.
  547. // Do this only if indirectConversion is enabled at the end of compileStruct.
  548. structCode.isIndirect = false
  549. }
  550. }
  551. }
  552. fields = append(fields, field)
  553. }
  554. fieldMap := c.getFieldMap(fields)
  555. duplicatedFieldMap := c.getDuplicatedFieldMap(fieldMap)
  556. code.fields = c.filteredDuplicatedFields(fields, duplicatedFieldMap)
  557. if !code.disableIndirectConversion && !indirect && isPtr {
  558. code.enableIndirect()
  559. }
  560. delete(c.structTypeToCode, typeptr)
  561. return code, nil
  562. }
  563. func toElemType(t *runtime.Type) *runtime.Type {
  564. for t.Kind() == reflect.Ptr {
  565. t = t.Elem()
  566. }
  567. return t
  568. }
// structFieldCode builds the StructFieldCode for the struct field described by
// tag, belonging to structCode.
//
// isPtr reports whether the owning struct is reached through a pointer.
// isOnlyOneFirstField is true when this is the first field and the struct has
// exactly one field. The value code is chosen in this order:
//   - Only field of a pointer-reached struct, non-nilable, with MarshalJSON /
//     MarshalText only on *T: marshaler code addressed through the field. The
//     struct's indirect flag is cleared and indirect conversion is disabled.
//   - Pointer-reached struct whose field type has MarshalJSON / MarshalText
//     only on *T: marshaler code addressed through the field.
//   - Otherwise typeToCodeWithPtr. Pointer and interface value codes set
//     isNextOpPtrType.
//
// The two marshaler cases also turn off the field's nil check.
func (c *Compiler) structFieldCode(structCode *StructCode, tag *runtime.StructTag, isPtr, isOnlyOneFirstField bool) (*StructFieldCode, error) {
	field := tag.Field
	fieldType := runtime.Type2RType(field.Type)
	isIndirectSpecialCase := isPtr && isOnlyOneFirstField
	fieldCode := &StructFieldCode{
		typ:    fieldType,
		key:    tag.Key,
		tag:    tag,
		offset: field.Offset,
		// Treated as embedded only if anonymous, not given an explicit key by
		// its tag, and (through any pointer levels) a struct.
		isAnonymous:   field.Anonymous && !tag.IsTaggedKey && toElemType(fieldType).Kind() == reflect.Struct,
		isTaggedKey:   tag.IsTaggedKey,
		isNilableType: c.isNilableType(fieldType),
		isNilCheck:    true,
	}
	switch {
	case c.isMovePointerPositionFromHeadToFirstMarshalJSONFieldCase(fieldType, isIndirectSpecialCase):
		code, err := c.marshalJSONCode(fieldType)
		if err != nil {
			return nil, err
		}
		fieldCode.value = code
		fieldCode.isAddrForMarshaler = true
		fieldCode.isNilCheck = false
		structCode.isIndirect = false
		structCode.disableIndirectConversion = true
	case c.isMovePointerPositionFromHeadToFirstMarshalTextFieldCase(fieldType, isIndirectSpecialCase):
		code, err := c.marshalTextCode(fieldType)
		if err != nil {
			return nil, err
		}
		fieldCode.value = code
		fieldCode.isAddrForMarshaler = true
		fieldCode.isNilCheck = false
		structCode.isIndirect = false
		structCode.disableIndirectConversion = true
	case isPtr && c.isPtrMarshalJSONType(fieldType):
		// *struct{ field T }
		// func (*T) MarshalJSON() ([]byte, error)
		code, err := c.marshalJSONCode(fieldType)
		if err != nil {
			return nil, err
		}
		fieldCode.value = code
		fieldCode.isAddrForMarshaler = true
		fieldCode.isNilCheck = false
	case isPtr && c.isPtrMarshalTextType(fieldType):
		// *struct{ field T }
		// func (*T) MarshalText() ([]byte, error)
		code, err := c.marshalTextCode(fieldType)
		if err != nil {
			return nil, err
		}
		fieldCode.value = code
		fieldCode.isAddrForMarshaler = true
		fieldCode.isNilCheck = false
	default:
		code, err := c.typeToCodeWithPtr(fieldType, isPtr)
		if err != nil {
			return nil, err
		}
		switch code.Kind() {
		case CodeKindPtr, CodeKindInterface:
			fieldCode.isNextOpPtrType = true
		}
		fieldCode.value = code
	}
	return fieldCode, nil
}
  637. func (c *Compiler) isAssignableIndirect(fieldCode *StructFieldCode, isPtr bool) bool {
  638. if isPtr {
  639. return false
  640. }
  641. codeType := fieldCode.value.Kind()
  642. if codeType == CodeKindMarshalJSON {
  643. return false
  644. }
  645. if codeType == CodeKindMarshalText {
  646. return false
  647. }
  648. return true
  649. }
  650. func (c *Compiler) getFieldMap(fields []*StructFieldCode) map[string][]*StructFieldCode {
  651. fieldMap := map[string][]*StructFieldCode{}
  652. for _, field := range fields {
  653. if field.isAnonymous {
  654. for k, v := range c.getAnonymousFieldMap(field) {
  655. fieldMap[k] = append(fieldMap[k], v...)
  656. }
  657. continue
  658. }
  659. fieldMap[field.key] = append(fieldMap[field.key], field)
  660. }
  661. return fieldMap
  662. }
  663. func (c *Compiler) getAnonymousFieldMap(field *StructFieldCode) map[string][]*StructFieldCode {
  664. fieldMap := map[string][]*StructFieldCode{}
  665. structCode := field.getAnonymousStruct()
  666. if structCode == nil || structCode.isRecursive {
  667. fieldMap[field.key] = append(fieldMap[field.key], field)
  668. return fieldMap
  669. }
  670. for k, v := range c.getFieldMapFromAnonymousParent(structCode.fields) {
  671. fieldMap[k] = append(fieldMap[k], v...)
  672. }
  673. return fieldMap
  674. }
  675. func (c *Compiler) getFieldMapFromAnonymousParent(fields []*StructFieldCode) map[string][]*StructFieldCode {
  676. fieldMap := map[string][]*StructFieldCode{}
  677. for _, field := range fields {
  678. if field.isAnonymous {
  679. for k, v := range c.getAnonymousFieldMap(field) {
  680. // Do not handle tagged key when embedding more than once
  681. for _, vv := range v {
  682. vv.isTaggedKey = false
  683. }
  684. fieldMap[k] = append(fieldMap[k], v...)
  685. }
  686. continue
  687. }
  688. fieldMap[field.key] = append(fieldMap[field.key], field)
  689. }
  690. return fieldMap
  691. }
  692. func (c *Compiler) getDuplicatedFieldMap(fieldMap map[string][]*StructFieldCode) map[*StructFieldCode]struct{} {
  693. duplicatedFieldMap := map[*StructFieldCode]struct{}{}
  694. for _, fields := range fieldMap {
  695. if len(fields) == 1 {
  696. continue
  697. }
  698. if c.isTaggedKeyOnly(fields) {
  699. for _, field := range fields {
  700. if field.isTaggedKey {
  701. continue
  702. }
  703. duplicatedFieldMap[field] = struct{}{}
  704. }
  705. } else {
  706. for _, field := range fields {
  707. duplicatedFieldMap[field] = struct{}{}
  708. }
  709. }
  710. }
  711. return duplicatedFieldMap
  712. }
  713. func (c *Compiler) filteredDuplicatedFields(fields []*StructFieldCode, duplicatedFieldMap map[*StructFieldCode]struct{}) []*StructFieldCode {
  714. filteredFields := make([]*StructFieldCode, 0, len(fields))
  715. for _, field := range fields {
  716. if field.isAnonymous {
  717. structCode := field.getAnonymousStruct()
  718. if structCode != nil && !structCode.isRecursive {
  719. structCode.fields = c.filteredDuplicatedFields(structCode.fields, duplicatedFieldMap)
  720. if len(structCode.fields) > 0 {
  721. filteredFields = append(filteredFields, field)
  722. }
  723. continue
  724. }
  725. }
  726. if _, exists := duplicatedFieldMap[field]; exists {
  727. continue
  728. }
  729. filteredFields = append(filteredFields, field)
  730. }
  731. return filteredFields
  732. }
  733. func (c *Compiler) isTaggedKeyOnly(fields []*StructFieldCode) bool {
  734. var taggedKeyFieldCount int
  735. for _, field := range fields {
  736. if field.isTaggedKey {
  737. taggedKeyFieldCount++
  738. }
  739. }
  740. return taggedKeyFieldCount == 1
  741. }
  742. func (c *Compiler) typeToStructTags(typ *runtime.Type) runtime.StructTags {
  743. tags := runtime.StructTags{}
  744. fieldNum := typ.NumField()
  745. for i := 0; i < fieldNum; i++ {
  746. field := typ.Field(i)
  747. if runtime.IsIgnoredStructField(field) {
  748. continue
  749. }
  750. tags = append(tags, runtime.StructTagFromField(field))
  751. }
  752. return tags
  753. }
  754. // *struct{ field T } => struct { field *T }
  755. // func (*T) MarshalJSON() ([]byte, error)
  756. func (c *Compiler) isMovePointerPositionFromHeadToFirstMarshalJSONFieldCase(typ *runtime.Type, isIndirectSpecialCase bool) bool {
  757. return isIndirectSpecialCase && !c.isNilableType(typ) && c.isPtrMarshalJSONType(typ)
  758. }
  759. // *struct{ field T } => struct { field *T }
  760. // func (*T) MarshalText() ([]byte, error)
  761. func (c *Compiler) isMovePointerPositionFromHeadToFirstMarshalTextFieldCase(typ *runtime.Type, isIndirectSpecialCase bool) bool {
  762. return isIndirectSpecialCase && !c.isNilableType(typ) && c.isPtrMarshalTextType(typ)
  763. }
  764. func (c *Compiler) implementsMarshalJSON(typ *runtime.Type) bool {
  765. if !c.implementsMarshalJSONType(typ) {
  766. return false
  767. }
  768. if typ.Kind() != reflect.Ptr {
  769. return true
  770. }
  771. // type kind is reflect.Ptr
  772. if !c.implementsMarshalJSONType(typ.Elem()) {
  773. return true
  774. }
  775. // needs to dereference
  776. return false
  777. }
  778. func (c *Compiler) implementsMarshalText(typ *runtime.Type) bool {
  779. if !typ.Implements(marshalTextType) {
  780. return false
  781. }
  782. if typ.Kind() != reflect.Ptr {
  783. return true
  784. }
  785. // type kind is reflect.Ptr
  786. if !typ.Elem().Implements(marshalTextType) {
  787. return true
  788. }
  789. // needs to dereference
  790. return false
  791. }
  792. func (c *Compiler) isNilableType(typ *runtime.Type) bool {
  793. if !runtime.IfaceIndir(typ) {
  794. return true
  795. }
  796. switch typ.Kind() {
  797. case reflect.Ptr:
  798. return true
  799. case reflect.Map:
  800. return true
  801. case reflect.Func:
  802. return true
  803. default:
  804. return false
  805. }
  806. }
  807. func (c *Compiler) implementsMarshalJSONType(typ *runtime.Type) bool {
  808. return typ.Implements(marshalJSONType) || typ.Implements(marshalJSONContextType)
  809. }
  810. func (c *Compiler) isPtrMarshalJSONType(typ *runtime.Type) bool {
  811. return !c.implementsMarshalJSONType(typ) && c.implementsMarshalJSONType(runtime.PtrTo(typ))
  812. }
  813. func (c *Compiler) isPtrMarshalTextType(typ *runtime.Type) bool {
  814. return !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType)
  815. }
  816. func (c *Compiler) codeToOpcode(ctx *compileContext, typ *runtime.Type, code Code) *Opcode {
  817. codes := code.ToOpcode(ctx)
  818. codes.Last().Next = newEndOp(ctx, typ)
  819. c.linkRecursiveCode(ctx)
  820. return codes.First()
  821. }
// linkRecursiveCode fills in the jump targets recorded in ctx.recursiveCodes
// for recursive struct references.
//
// For the first entry of each struct type, the opcodes compiled for that type
// (ctx.structTypeToCodes) are copied and the head op is converted with
// PtrHeadToHead. The copy is terminated with an OpRecursiveEnd op and stored
// into that entry's Jmp together with the current and next lengths. Later
// entries of the same type get a value copy of that CompiledCode.
func (c *Compiler) linkRecursiveCode(ctx *compileContext) {
	recursiveCodes := map[uintptr]*CompiledCode{}
	for _, recursive := range *ctx.recursiveCodes {
		typeptr := uintptr(unsafe.Pointer(recursive.Type))
		codes := ctx.structTypeToCodes[typeptr]
		// This type was already linked: reuse its compiled code.
		if recursiveCode, ok := recursiveCodes[typeptr]; ok {
			*recursive.Jmp = *recursiveCode
			continue
		}
		code := copyOpcode(codes.First())
		code.Op = code.Op.PtrHeadToHead()
		lastCode := newEndOp(&compileContext{}, recursive.Type)
		lastCode.Op = OpRecursiveEnd
		// OpRecursiveEnd must be set before calling TotalLength.
		code.End.Next = lastCode
		totalLength := code.TotalLength()
		// Idx, ElemIdx and Length must be set after calling TotalLength.
		lastCode.Idx = uint32((totalLength + 1) * uintptrSize)
		lastCode.ElemIdx = lastCode.Idx + uintptrSize
		lastCode.Length = lastCode.Idx + 2*uintptrSize
		// Extend the length to allocate slots for elemIdx and length.
		curTotalLength := uintptr(recursive.TotalLength()) + 3
		nextTotalLength := uintptr(totalLength) + 3
		compiled := recursive.Jmp
		compiled.Code = code
		compiled.CurLen = curTotalLength
		compiled.NextLen = nextTotalLength
		compiled.Linked = true
		recursiveCodes[typeptr] = compiled
	}
}