//go:build go1.18 // +build go1.18 package swag import ( "errors" "fmt" "github.com/go-openapi/spec" "go/ast" "strings" "sync" "unicode" ) var genericDefinitionsMutex = &sync.RWMutex{} var genericsDefinitions = map[*TypeSpecDef]map[string]*TypeSpecDef{} type genericTypeSpec struct { ArrayDepth int TypeSpec *TypeSpecDef Name string } func (s *genericTypeSpec) Type() ast.Expr { if s.TypeSpec != nil { return &ast.SelectorExpr{ X: &ast.Ident{Name: ""}, Sel: &ast.Ident{Name: s.Name}, } } return &ast.Ident{Name: s.Name} } func (s *genericTypeSpec) TypeDocName() string { if s.TypeSpec != nil { return strings.Replace(TypeDocName(s.TypeSpec.FullName(), s.TypeSpec.TypeSpec), "-", "_", -1) } return s.Name } func typeSpecFullName(typeSpecDef *TypeSpecDef) string { fullName := typeSpecDef.FullName() if typeSpecDef.TypeSpec.TypeParams != nil { fullName = fullName + "[" for i, typeParam := range typeSpecDef.TypeSpec.TypeParams.List { if i > 0 { fullName = fullName + "-" } fullName = fullName + typeParam.Names[0].Name } fullName = fullName + "]" } return fullName } func (pkgDefs *PackagesDefinitions) parametrizeGenericType(file *ast.File, original *TypeSpecDef, fullGenericForm string, parseDependency bool) *TypeSpecDef { genericDefinitionsMutex.RLock() tSpec, ok := genericsDefinitions[original][fullGenericForm] genericDefinitionsMutex.RUnlock() if ok { return tSpec } pkgName := strings.Split(fullGenericForm, ".")[0] genericTypeName, genericParams := splitStructName(fullGenericForm) if genericParams == nil { return nil } genericParamTypeDefs := map[string]*genericTypeSpec{} if len(genericParams) != len(original.TypeSpec.TypeParams.List) { return nil } for i, genericParam := range genericParams { arrayDepth := 0 for { if len(genericParam) <= 2 || genericParam[:2] != "[]" { break } genericParam = genericParam[2:] arrayDepth++ } tdef := pkgDefs.FindTypeSpec(genericParam, file, parseDependency) if tdef != nil && !strings.Contains(genericParam, ".") { genericParam = fullTypeName(file.Name.Name, genericParam) } genericParamTypeDefs[original.TypeSpec.TypeParams.List[i].Names[0].Name] = &genericTypeSpec{ ArrayDepth: arrayDepth, TypeSpec: tdef, Name: genericParam, } } parametrizedTypeSpec := &TypeSpecDef{ File: original.File, PkgPath: original.PkgPath, TypeSpec: &ast.TypeSpec{ Doc: original.TypeSpec.Doc, Comment: original.TypeSpec.Comment, Assign: original.TypeSpec.Assign, }, } ident := &ast.Ident{ NamePos: original.TypeSpec.Name.NamePos, Obj: original.TypeSpec.Name.Obj, } if strings.Contains(genericTypeName, ".") { genericTypeName = strings.Split(genericTypeName, ".")[1] } var typeName = []string{TypeDocName(fullTypeName(pkgName, genericTypeName), parametrizedTypeSpec.TypeSpec)} for _, def := range original.TypeSpec.TypeParams.List { if specDef, ok := genericParamTypeDefs[def.Names[0].Name]; ok { var prefix = "" if specDef.ArrayDepth > 0 { prefix = "array_" if specDef.ArrayDepth > 1 { prefix = fmt.Sprintf("array%d_", specDef.ArrayDepth) } } typeName = append(typeName, prefix+specDef.TypeDocName()) } } ident.Name = strings.Join(typeName, "-") ident.Name = strings.Replace(ident.Name, ".", "_", -1) pkgNamePrefix := pkgName + "_" if strings.HasPrefix(ident.Name, pkgNamePrefix) { ident.Name = fullTypeName(pkgName, ident.Name[len(pkgNamePrefix):]) } ident.Name = string(IgnoreNameOverridePrefix) + ident.Name parametrizedTypeSpec.TypeSpec.Name = ident newType := pkgDefs.resolveGenericType(original.File, original.TypeSpec.Type, genericParamTypeDefs, parseDependency) genericDefinitionsMutex.Lock() defer genericDefinitionsMutex.Unlock() parametrizedTypeSpec.TypeSpec.Type = newType if genericsDefinitions[original] == nil { genericsDefinitions[original] = map[string]*TypeSpecDef{} } genericsDefinitions[original][fullGenericForm] = parametrizedTypeSpec return parametrizedTypeSpec } // splitStructName splits a generic struct name in his parts func splitStructName(fullGenericForm string) (string, []string) { //remove all spaces character fullGenericForm = strings.Map(func(r rune) rune { if unicode.IsSpace(r) { return -1 } return r }, fullGenericForm) // split only at the first '[' and remove the last ']' if fullGenericForm[len(fullGenericForm)-1] != ']' { return "", nil } genericParams := strings.SplitN(fullGenericForm[:len(fullGenericForm)-1], "[", 2) if len(genericParams) == 1 { return "", nil } // generic type name genericTypeName := genericParams[0] depth := 0 genericParams = strings.FieldsFunc(genericParams[1], func(r rune) bool { if r == '[' { depth++ } else if r == ']' { depth-- } else if r == ',' && depth == 0 { return true } return false }) if depth != 0 { return "", nil } return genericTypeName, genericParams } func (pkgDefs *PackagesDefinitions) resolveGenericType(file *ast.File, expr ast.Expr, genericParamTypeDefs map[string]*genericTypeSpec, parseDependency bool) ast.Expr { switch astExpr := expr.(type) { case *ast.Ident: if genTypeSpec, ok := genericParamTypeDefs[astExpr.Name]; ok { retType := genTypeSpec.Type() for i := 0; i < genTypeSpec.ArrayDepth; i++ { retType = &ast.ArrayType{Elt: retType} } return retType } case *ast.ArrayType: return &ast.ArrayType{ Elt: pkgDefs.resolveGenericType(file, astExpr.Elt, genericParamTypeDefs, parseDependency), Len: astExpr.Len, Lbrack: astExpr.Lbrack, } case *ast.StarExpr: return &ast.StarExpr{ Star: astExpr.Star, X: pkgDefs.resolveGenericType(file, astExpr.X, genericParamTypeDefs, parseDependency), } case *ast.IndexExpr, *ast.IndexListExpr: fullGenericName, _ := getGenericFieldType(file, expr, genericParamTypeDefs) typeDef := pkgDefs.findGenericTypeSpec(fullGenericName, file, parseDependency) if typeDef != nil { return typeDef.TypeSpec.Type } case *ast.StructType: newStructTypeDef := &ast.StructType{ Struct: astExpr.Struct, Incomplete: astExpr.Incomplete, Fields: &ast.FieldList{ Opening: astExpr.Fields.Opening, Closing: astExpr.Fields.Closing, }, } for _, field := range astExpr.Fields.List { newField := &ast.Field{ Type: field.Type, Doc: field.Doc, Names: field.Names, Tag: field.Tag, Comment: field.Comment, } newField.Type = pkgDefs.resolveGenericType(file, field.Type, genericParamTypeDefs, parseDependency) newStructTypeDef.Fields.List = append(newStructTypeDef.Fields.List, newField) } return newStructTypeDef } return expr } func getExtendedGenericFieldType(file *ast.File, field ast.Expr, genericParamTypeDefs map[string]*genericTypeSpec) (string, error) { switch fieldType := field.(type) { case *ast.ArrayType: fieldName, err := getExtendedGenericFieldType(file, fieldType.Elt, genericParamTypeDefs) return "[]" + fieldName, err case *ast.StarExpr: return getExtendedGenericFieldType(file, fieldType.X, genericParamTypeDefs) case *ast.Ident: if genericParamTypeDefs != nil { if typeSpec, ok := genericParamTypeDefs[fieldType.Name]; ok { return typeSpec.Name, nil } } if fieldType.Obj == nil { return fieldType.Name, nil } tSpec := &TypeSpecDef{ File: file, TypeSpec: fieldType.Obj.Decl.(*ast.TypeSpec), PkgPath: file.Name.Name, } return tSpec.FullName(), nil default: return getFieldType(file, field) } } func getGenericFieldType(file *ast.File, field ast.Expr, genericParamTypeDefs map[string]*genericTypeSpec) (string, error) { var fullName string var baseName string var err error switch fieldType := field.(type) { case *ast.IndexListExpr: baseName, err = getGenericTypeName(file, fieldType.X) if err != nil { return "", err } fullName = baseName + "[" for _, index := range fieldType.Indices { fieldName, err := getExtendedGenericFieldType(file, index, genericParamTypeDefs) if err != nil { return "", err } fullName += fieldName + "," } fullName = strings.TrimRight(fullName, ",") + "]" case *ast.IndexExpr: baseName, err = getGenericTypeName(file, fieldType.X) if err != nil { return "", err } indexName, err := getExtendedGenericFieldType(file, fieldType.Index, genericParamTypeDefs) if err != nil { return "", err } fullName = fmt.Sprintf("%s[%s]", baseName, indexName) } if fullName == "" { return "", fmt.Errorf("unknown field type %#v", field) } var packageName string if !strings.Contains(baseName, ".") { if file.Name == nil { return "", errors.New("file name is nil") } packageName, _ = getFieldType(file, file.Name) } return strings.TrimLeft(fmt.Sprintf("%s.%s", packageName, fullName), "."), nil } func getGenericTypeName(file *ast.File, field ast.Expr) (string, error) { switch fieldType := field.(type) { case *ast.Ident: if fieldType.Obj == nil { return fieldType.Name, nil } tSpec := &TypeSpecDef{ File: file, TypeSpec: fieldType.Obj.Decl.(*ast.TypeSpec), PkgPath: file.Name.Name, } return tSpec.FullName(), nil case *ast.ArrayType: tSpec := &TypeSpecDef{ File: file, TypeSpec: fieldType.Elt.(*ast.Ident).Obj.Decl.(*ast.TypeSpec), PkgPath: file.Name.Name, } return tSpec.FullName(), nil case *ast.SelectorExpr: return fmt.Sprintf("%s.%s", fieldType.X.(*ast.Ident).Name, fieldType.Sel.Name), nil } return "", fmt.Errorf("unknown type %#v", field) } func (parser *Parser) parseGenericTypeExpr(file *ast.File, typeExpr ast.Expr) (*spec.Schema, error) { switch expr := typeExpr.(type) { // suppress debug messages for these types case *ast.InterfaceType: case *ast.StructType: case *ast.Ident: case *ast.StarExpr: case *ast.SelectorExpr: case *ast.ArrayType: case *ast.MapType: case *ast.FuncType: case *ast.IndexExpr: name, err := getExtendedGenericFieldType(file, expr, nil) if err == nil { if schema, err := parser.getTypeSchema(name, file, false); err == nil { return schema, nil } } parser.debug.Printf("Type definition of type '%T' is not supported yet. Using 'object' instead. (%s)\n", typeExpr, err) default: parser.debug.Printf("Type definition of type '%T' is not supported yet. Using 'object' instead.\n", typeExpr) } return PrimitiveSchema(OBJECT), nil }