packages.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442
  1. package swag
  2. import (
  3. "go/ast"
  4. goparser "go/parser"
  5. "go/token"
  6. "os"
  7. "path/filepath"
  8. "runtime"
  9. "sort"
  10. "strings"
  11. "golang.org/x/tools/go/loader"
  12. )
  13. // PackagesDefinitions map[package import path]*PackageDefinitions.
  14. type PackagesDefinitions struct {
  15. files map[*ast.File]*AstFileInfo
  16. packages map[string]*PackageDefinitions
  17. uniqueDefinitions map[string]*TypeSpecDef
  18. }
  19. // NewPackagesDefinitions create object PackagesDefinitions.
  20. func NewPackagesDefinitions() *PackagesDefinitions {
  21. return &PackagesDefinitions{
  22. files: make(map[*ast.File]*AstFileInfo),
  23. packages: make(map[string]*PackageDefinitions),
  24. uniqueDefinitions: make(map[string]*TypeSpecDef),
  25. }
  26. }
  27. // CollectAstFile collect ast.file.
  28. func (pkgDefs *PackagesDefinitions) CollectAstFile(packageDir, path string, astFile *ast.File) error {
  29. if pkgDefs.files == nil {
  30. pkgDefs.files = make(map[*ast.File]*AstFileInfo)
  31. }
  32. if pkgDefs.packages == nil {
  33. pkgDefs.packages = make(map[string]*PackageDefinitions)
  34. }
  35. // return without storing the file if we lack a packageDir
  36. if packageDir == "" {
  37. return nil
  38. }
  39. path, err := filepath.Abs(path)
  40. if err != nil {
  41. return err
  42. }
  43. dependency, ok := pkgDefs.packages[packageDir]
  44. if ok {
  45. // return without storing the file if it already exists
  46. _, exists := dependency.Files[path]
  47. if exists {
  48. return nil
  49. }
  50. dependency.Files[path] = astFile
  51. } else {
  52. pkgDefs.packages[packageDir] = &PackageDefinitions{
  53. Name: astFile.Name.Name,
  54. Files: map[string]*ast.File{path: astFile},
  55. TypeDefinitions: make(map[string]*TypeSpecDef),
  56. }
  57. }
  58. pkgDefs.files[astFile] = &AstFileInfo{
  59. File: astFile,
  60. Path: path,
  61. PackagePath: packageDir,
  62. }
  63. return nil
  64. }
  65. // RangeFiles for range the collection of ast.File in alphabetic order.
  66. func rangeFiles(files map[*ast.File]*AstFileInfo, handle func(filename string, file *ast.File) error) error {
  67. sortedFiles := make([]*AstFileInfo, 0, len(files))
  68. for _, info := range files {
  69. // ignore package path prefix with 'vendor' or $GOROOT,
  70. // because the router info of api will not be included these files.
  71. if strings.HasPrefix(info.PackagePath, "vendor") || strings.HasPrefix(info.Path, runtime.GOROOT()) {
  72. continue
  73. }
  74. sortedFiles = append(sortedFiles, info)
  75. }
  76. sort.Slice(sortedFiles, func(i, j int) bool {
  77. return strings.Compare(sortedFiles[i].Path, sortedFiles[j].Path) < 0
  78. })
  79. for _, info := range sortedFiles {
  80. err := handle(info.Path, info.File)
  81. if err != nil {
  82. return err
  83. }
  84. }
  85. return nil
  86. }
  87. // ParseTypes parse types
  88. // @Return parsed definitions.
  89. func (pkgDefs *PackagesDefinitions) ParseTypes() (map[*TypeSpecDef]*Schema, error) {
  90. parsedSchemas := make(map[*TypeSpecDef]*Schema)
  91. for astFile, info := range pkgDefs.files {
  92. pkgDefs.parseTypesFromFile(astFile, info.PackagePath, parsedSchemas)
  93. pkgDefs.parseFunctionScopedTypesFromFile(astFile, info.PackagePath, parsedSchemas)
  94. }
  95. return parsedSchemas, nil
  96. }
  97. func (pkgDefs *PackagesDefinitions) parseTypesFromFile(astFile *ast.File, packagePath string, parsedSchemas map[*TypeSpecDef]*Schema) {
  98. for _, astDeclaration := range astFile.Decls {
  99. if generalDeclaration, ok := astDeclaration.(*ast.GenDecl); ok && generalDeclaration.Tok == token.TYPE {
  100. for _, astSpec := range generalDeclaration.Specs {
  101. if typeSpec, ok := astSpec.(*ast.TypeSpec); ok {
  102. typeSpecDef := &TypeSpecDef{
  103. PkgPath: packagePath,
  104. File: astFile,
  105. TypeSpec: typeSpec,
  106. }
  107. if idt, ok := typeSpec.Type.(*ast.Ident); ok && IsGolangPrimitiveType(idt.Name) && parsedSchemas != nil {
  108. parsedSchemas[typeSpecDef] = &Schema{
  109. PkgPath: typeSpecDef.PkgPath,
  110. Name: astFile.Name.Name,
  111. Schema: PrimitiveSchema(TransToValidSchemeType(idt.Name)),
  112. }
  113. }
  114. if pkgDefs.uniqueDefinitions == nil {
  115. pkgDefs.uniqueDefinitions = make(map[string]*TypeSpecDef)
  116. }
  117. fullName := typeSpecFullName(typeSpecDef)
  118. anotherTypeDef, ok := pkgDefs.uniqueDefinitions[fullName]
  119. if ok {
  120. if typeSpecDef.PkgPath == anotherTypeDef.PkgPath {
  121. continue
  122. } else {
  123. delete(pkgDefs.uniqueDefinitions, fullName)
  124. }
  125. } else {
  126. pkgDefs.uniqueDefinitions[fullName] = typeSpecDef
  127. }
  128. if pkgDefs.packages[typeSpecDef.PkgPath] == nil {
  129. pkgDefs.packages[typeSpecDef.PkgPath] = &PackageDefinitions{
  130. Name: astFile.Name.Name,
  131. TypeDefinitions: map[string]*TypeSpecDef{typeSpecDef.Name(): typeSpecDef},
  132. }
  133. } else if _, ok = pkgDefs.packages[typeSpecDef.PkgPath].TypeDefinitions[typeSpecDef.Name()]; !ok {
  134. pkgDefs.packages[typeSpecDef.PkgPath].TypeDefinitions[typeSpecDef.Name()] = typeSpecDef
  135. }
  136. }
  137. }
  138. }
  139. }
  140. }
  141. func (pkgDefs *PackagesDefinitions) parseFunctionScopedTypesFromFile(astFile *ast.File, packagePath string, parsedSchemas map[*TypeSpecDef]*Schema) {
  142. for _, astDeclaration := range astFile.Decls {
  143. funcDeclaration, ok := astDeclaration.(*ast.FuncDecl)
  144. if ok && funcDeclaration.Body != nil {
  145. for _, stmt := range funcDeclaration.Body.List {
  146. if declStmt, ok := (stmt).(*ast.DeclStmt); ok {
  147. if genDecl, ok := (declStmt.Decl).(*ast.GenDecl); ok && genDecl.Tok == token.TYPE {
  148. for _, astSpec := range genDecl.Specs {
  149. if typeSpec, ok := astSpec.(*ast.TypeSpec); ok {
  150. typeSpecDef := &TypeSpecDef{
  151. PkgPath: packagePath,
  152. File: astFile,
  153. TypeSpec: typeSpec,
  154. ParentSpec: astDeclaration,
  155. }
  156. if idt, ok := typeSpec.Type.(*ast.Ident); ok && IsGolangPrimitiveType(idt.Name) && parsedSchemas != nil {
  157. parsedSchemas[typeSpecDef] = &Schema{
  158. PkgPath: typeSpecDef.PkgPath,
  159. Name: astFile.Name.Name,
  160. Schema: PrimitiveSchema(TransToValidSchemeType(idt.Name)),
  161. }
  162. }
  163. if pkgDefs.uniqueDefinitions == nil {
  164. pkgDefs.uniqueDefinitions = make(map[string]*TypeSpecDef)
  165. }
  166. fullName := typeSpecFullName(typeSpecDef)
  167. anotherTypeDef, ok := pkgDefs.uniqueDefinitions[fullName]
  168. if ok {
  169. if typeSpecDef.PkgPath == anotherTypeDef.PkgPath {
  170. continue
  171. } else {
  172. delete(pkgDefs.uniqueDefinitions, fullName)
  173. }
  174. } else {
  175. pkgDefs.uniqueDefinitions[fullName] = typeSpecDef
  176. }
  177. if pkgDefs.packages[typeSpecDef.PkgPath] == nil {
  178. pkgDefs.packages[typeSpecDef.PkgPath] = &PackageDefinitions{
  179. Name: astFile.Name.Name,
  180. TypeDefinitions: map[string]*TypeSpecDef{fullName: typeSpecDef},
  181. }
  182. } else if _, ok = pkgDefs.packages[typeSpecDef.PkgPath].TypeDefinitions[fullName]; !ok {
  183. pkgDefs.packages[typeSpecDef.PkgPath].TypeDefinitions[fullName] = typeSpecDef
  184. }
  185. }
  186. }
  187. }
  188. }
  189. }
  190. }
  191. }
  192. }
  193. func (pkgDefs *PackagesDefinitions) findTypeSpec(pkgPath string, typeName string) *TypeSpecDef {
  194. if pkgDefs.packages == nil {
  195. return nil
  196. }
  197. pd, found := pkgDefs.packages[pkgPath]
  198. if found {
  199. typeSpec, ok := pd.TypeDefinitions[typeName]
  200. if ok {
  201. return typeSpec
  202. }
  203. }
  204. return nil
  205. }
  206. func (pkgDefs *PackagesDefinitions) loadExternalPackage(importPath string) error {
  207. cwd, err := os.Getwd()
  208. if err != nil {
  209. return err
  210. }
  211. conf := loader.Config{
  212. ParserMode: goparser.ParseComments,
  213. Cwd: cwd,
  214. }
  215. conf.Import(importPath)
  216. loaderProgram, err := conf.Load()
  217. if err != nil {
  218. return err
  219. }
  220. for _, info := range loaderProgram.AllPackages {
  221. pkgPath := strings.TrimPrefix(info.Pkg.Path(), "vendor/")
  222. for _, astFile := range info.Files {
  223. pkgDefs.parseTypesFromFile(astFile, pkgPath, nil)
  224. }
  225. }
  226. return nil
  227. }
  228. // findPackagePathFromImports finds out the package path of a package via ranging imports of an ast.File
  229. // @pkg the name of the target package
  230. // @file current ast.File in which to search imports
  231. // @fuzzy search for the package path that the last part matches the @pkg if true
  232. // @return the package path of a package of @pkg.
  233. func (pkgDefs *PackagesDefinitions) findPackagePathFromImports(pkg string, file *ast.File, fuzzy bool) string {
  234. if file == nil {
  235. return ""
  236. }
  237. if strings.ContainsRune(pkg, '.') {
  238. pkg = strings.Split(pkg, ".")[0]
  239. }
  240. hasAnonymousPkg := false
  241. matchLastPathPart := func(pkgPath string) bool {
  242. paths := strings.Split(pkgPath, "/")
  243. return paths[len(paths)-1] == pkg
  244. }
  245. // prior to match named package
  246. for _, imp := range file.Imports {
  247. if imp.Name != nil {
  248. if imp.Name.Name == pkg {
  249. return strings.Trim(imp.Path.Value, `"`)
  250. }
  251. if imp.Name.Name == "_" {
  252. hasAnonymousPkg = true
  253. }
  254. continue
  255. }
  256. if pkgDefs.packages != nil {
  257. path := strings.Trim(imp.Path.Value, `"`)
  258. if fuzzy {
  259. if matchLastPathPart(path) {
  260. return path
  261. }
  262. continue
  263. }
  264. pd, ok := pkgDefs.packages[path]
  265. if ok && pd.Name == pkg {
  266. return path
  267. }
  268. }
  269. }
  270. // match unnamed package
  271. if hasAnonymousPkg && pkgDefs.packages != nil {
  272. for _, imp := range file.Imports {
  273. if imp.Name == nil {
  274. continue
  275. }
  276. if imp.Name.Name == "_" {
  277. path := strings.Trim(imp.Path.Value, `"`)
  278. if fuzzy {
  279. if matchLastPathPart(path) {
  280. return path
  281. }
  282. } else if pd, ok := pkgDefs.packages[path]; ok && pd.Name == pkg {
  283. return path
  284. }
  285. }
  286. }
  287. }
  288. return ""
  289. }
  290. // FindTypeSpec finds out TypeSpecDef of a type by typeName
  291. // @typeName the name of the target type, if it starts with a package name, find its own package path from imports on top of @file
  292. // @file the ast.file in which @typeName is used
  293. // @pkgPath the package path of @file.
  294. func (pkgDefs *PackagesDefinitions) FindTypeSpec(typeName string, file *ast.File, parseDependency bool) *TypeSpecDef {
  295. if IsGolangPrimitiveType(typeName) {
  296. return nil
  297. }
  298. if file == nil { // for test
  299. return pkgDefs.uniqueDefinitions[typeName]
  300. }
  301. parts := strings.Split(strings.Split(typeName, "[")[0], ".")
  302. if len(parts) > 1 {
  303. isAliasPkgName := func(file *ast.File, pkgName string) bool {
  304. if file != nil && file.Imports != nil {
  305. for _, pkg := range file.Imports {
  306. if pkg.Name != nil && pkg.Name.Name == pkgName {
  307. return true
  308. }
  309. }
  310. }
  311. return false
  312. }
  313. if !isAliasPkgName(file, parts[0]) {
  314. typeDef, ok := pkgDefs.uniqueDefinitions[typeName]
  315. if ok {
  316. return typeDef
  317. }
  318. }
  319. pkgPath := pkgDefs.findPackagePathFromImports(parts[0], file, false)
  320. if len(pkgPath) == 0 {
  321. // check if the current package
  322. if parts[0] == file.Name.Name {
  323. pkgPath = pkgDefs.files[file].PackagePath
  324. } else if parseDependency {
  325. // take it as an external package, needs to be loaded
  326. if pkgPath = pkgDefs.findPackagePathFromImports(parts[0], file, true); len(pkgPath) > 0 {
  327. if err := pkgDefs.loadExternalPackage(pkgPath); err != nil {
  328. return nil
  329. }
  330. }
  331. }
  332. }
  333. if def := pkgDefs.findGenericTypeSpec(typeName, file, parseDependency); def != nil {
  334. return def
  335. }
  336. return pkgDefs.findTypeSpec(pkgPath, parts[1])
  337. }
  338. if def := pkgDefs.findGenericTypeSpec(fullTypeName(file.Name.Name, typeName), file, parseDependency); def != nil {
  339. return def
  340. }
  341. typeDef, ok := pkgDefs.uniqueDefinitions[fullTypeName(file.Name.Name, typeName)]
  342. if ok {
  343. return typeDef
  344. }
  345. typeDef = pkgDefs.findTypeSpec(pkgDefs.files[file].PackagePath, typeName)
  346. if typeDef != nil {
  347. return typeDef
  348. }
  349. for _, imp := range file.Imports {
  350. if imp.Name != nil && imp.Name.Name == "." {
  351. typeDef := pkgDefs.findTypeSpec(strings.Trim(imp.Path.Value, `"`), typeName)
  352. if typeDef != nil {
  353. return typeDef
  354. }
  355. }
  356. }
  357. return nil
  358. }
  359. func (pkgDefs *PackagesDefinitions) findGenericTypeSpec(typeName string, file *ast.File, parseDependency bool) *TypeSpecDef {
  360. if strings.Contains(typeName, "[") {
  361. // genericName differs from typeName in that it does not contain any type parameters
  362. genericName := strings.SplitN(typeName, "[", 2)[0]
  363. for tName, tSpec := range pkgDefs.uniqueDefinitions {
  364. if !strings.Contains(tName, "[") {
  365. continue
  366. }
  367. if strings.Contains(tName, genericName) {
  368. if parametrized := pkgDefs.parametrizeGenericType(file, tSpec, typeName, parseDependency); parametrized != nil {
  369. return parametrized
  370. }
  371. }
  372. }
  373. }
  374. return nil
  375. }