1 // Copyright 2016 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
8 // This program is run via "go generate" (via a directive in sort.go)
9 // to generate zfuncversion.go.
11 // It copies sort.go to zfuncversion.go, only retaining funcs which
12 // take a "data Interface" parameter, and renaming each to have a
13 // "_func" suffix and taking a "data lessSwap" instead. It then rewrites
14 // each internal function call to the appropriate _func variants.
29 var fset
= token
.NewFileSet()
32 af
, err
:= parser
.ParseFile(fset
, "sort.go", nil, 0)
40 var newDecl
[]ast
.Decl
41 for _
, d
:= range af
.Decls
{
42 fd
, ok
:= d
.(*ast
.FuncDecl
)
46 if fd
.Recv
!= nil || fd
.Name
.IsExported() {
50 if len(typ
.Params
.List
) < 1 {
53 arg0
:= typ
.Params
.List
[0]
54 arg0Name
:= arg0
.Names
[0].Name
55 arg0Type
:= arg0
.Type
.(*ast
.Ident
)
56 if arg0Name
!= "data" || arg0Type
.Name
!= "Interface" {
59 arg0Type
.Name
= "lessSwap"
61 newDecl
= append(newDecl
, fd
)
64 ast
.Walk(visitFunc(rewriteCalls
), af
)
67 if err
:= format
.Node(&out
, fset
, af
); err
!= nil {
68 log
.Fatalf("format.Node: %v", err
)
71 // Get rid of blank lines after removal of comments.
72 src
:= regexp
.MustCompile(`\n{2,}`).ReplaceAll(out
.Bytes(), []byte("\n"))
74 // Add comments to each func, for the lost reader.
75 // This is so much easier than adding comments via the AST
76 // and trying to get position info correct.
77 src
= regexp
.MustCompile(`(?m)^func (\w+)`).ReplaceAll(src
, []byte("\n// Auto-generated variant of sort.go:$1\nfunc ${1}_func"))
80 src
, err
= format
.Source(src
)
82 log
.Fatalf("format.Source: %v on\n%s", err
, src
)
86 out
.WriteString(`// Code generated from sort.go using genzfunc.go; DO NOT EDIT.
88 // Copyright 2016 The Go Authors. All rights reserved.
89 // Use of this source code is governed by a BSD-style
90 // license that can be found in the LICENSE file.
95 const target
= "zfuncversion.go"
96 if err
:= os
.WriteFile(target
, out
.Bytes(), 0644); err
!= nil {
// visitFunc lets a plain function be used wherever an ast.Visitor is
// expected. Whatever the function returns is the visitor ast.Walk will use
// for the children of the node just visited; a nil result stops descent
// below that node.
type visitFunc func(ast.Node) ast.Visitor

// Visit satisfies ast.Visitor by handing node straight to the wrapped
// function and returning its result unchanged. Note that ast.Walk may call
// Visit with a nil node after a node's children have been walked, so the
// wrapped function has to cope with a nil argument.
func (fn visitFunc) Visit(node ast.Node) ast.Visitor {
	return fn(node)
}
105 func rewriteCalls(n ast
.Node
) ast
.Visitor
{
106 ce
, ok
:= n
.(*ast
.CallExpr
)
110 return visitFunc(rewriteCalls
)
113 func rewriteCall(ce
*ast
.CallExpr
) {
114 ident
, ok
:= ce
.Fun
.(*ast
.Ident
)
116 // e.g. skip SelectorExpr (data.Less(..) calls)
120 if ident
.Name
== "int" || ident
.Name
== "uint" {
123 if len(ce
.Args
) < 1 {
126 ident
.Name
+= "_func"