1 // Copyright 2009 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
// An empty rewrite rule disables rewriting.
// NOTE(review): the enclosing function header and the closing braces of these
// statements are not shown here; confirm the structure against the full file.
20 if *rewriteRule == "" {
21 rewrite = nil // disable any previous rewrite
// Split the rule text on "->"; f[0] and f[1] are parsed below as the pattern
// and the replacement respectively.
24 f := strings.Split(*rewriteRule, "->")
// Malformed rule: report the expected form on stderr.
// NOTE(review): the guarding condition is not shown; presumably it checks that
// the split produced exactly two parts - confirm.
26 fmt.Fprintf(os.Stderr, "rewrite rule must be of the form 'pattern -> replacement'\n")
29 pattern := parseExpr(f[0], "pattern")
30 replace := parseExpr(f[1], "replacement")
// Install the rewrite as a closure that applies pattern -> replace to a whole
// file by delegating to rewriteFile.
31 rewrite = func(p *ast.File) *ast.File { return rewriteFile(pattern, replace, p) }
34 // parseExpr parses s as an expression.
35 // It might make sense to expand this to allow statement patterns,
36 // but there are problems with preserving formatting and also
37 // with what a wildcard for a statement looks like.
//
// The what argument names the role of s in diagnostics; the visible callers
// pass "pattern" and "replacement".
38 func parseExpr(s, what string) ast.Expr {
39 x, err := parser.ParseExpr(s)
// Report the failure on stderr, naming which part of the rule was bad, the
// offending text, and the parser's error.
// NOTE(review): what happens after the message (exit vs. return) is not shown
// here - confirm.
41 fmt.Fprintf(os.Stderr, "parsing %s %s at %s\n", what, s, err)
47 // Keep this function for debugging.
//
// dump prints msg followed by a colon and a newline, then prints the value
// wrapped by val with ast.Print, using the package-level fileSet to resolve
// position information.
49 func dump(msg string, val reflect.Value) {
50 fmt.Printf("%s:\n", msg)
51 ast.Print(fileSet, val.Interface())
56 // rewriteFile applies the rewrite rule 'pattern -> replace' to an entire file.
57 func rewriteFile(pattern, replace ast.Expr, p *ast.File) *ast.File {
// Record which comments belong to which nodes before rewriting, so that the
// file's comment list can be rebuilt from the rewritten tree at the end.
58 cmap := ast.NewCommentMap(fileSet, p, p.Comments)
// m holds the wildcard submatches recorded by match and consumed by subst.
59 m := make(map[string]reflect.Value)
// match and subst operate on reflect.Values, so convert once up front.
60 pat := reflect.ValueOf(pattern)
61 repl := reflect.ValueOf(replace)
// rewriteVal is declared before it is defined because the closure passes
// itself to apply.
63 var rewriteVal func(val reflect.Value) reflect.Value
64 rewriteVal = func(val reflect.Value) reflect.Value {
65 // don't bother if val is invalid to start with
67 return reflect.Value{}
// Rewrite the fields of val first, then try to match the resulting node itself.
69 val = apply(rewriteVal, val)
73 if match(m, pat, val) {
// On a match, build the replacement with the recorded submatches, passing the
// matched node's position as the position for the replacement's tokens.
74 val = subst(m, repl, reflect.ValueOf(val.Interface().(ast.Node).Pos()))
// Run the rewrite over every field of the file; the result is asserted back
// to *ast.File.
79 r := apply(rewriteVal, reflect.ValueOf(p)).Interface().(*ast.File)
80 r.Comments = cmap.Filter(r).Comments() // recreate comments list
84 // set is a wrapper for x.Set(y); it protects the caller from panics if x cannot be changed to y.
85 func set(x, y reflect.Value) {
86 // don't bother if x cannot be set or y is invalid
87 if !x.CanSet() || !y.IsValid() {
// A recovered panic whose value is a string mentioning "type mismatch" or
// "not assignable" is treated as a failed assignment and ignored.
// NOTE(review): how any other recovered value is handled is not shown here;
// presumably it is re-panicked - confirm.
91 if x := recover(); x != nil {
92 if s, ok := x.(string); ok &&
93 (strings.Contains(s, "type mismatch") || strings.Contains(s, "not assignable")) {
94 // x cannot be set to y - ignore this rewrite
103 // Values/types for special cases.
// Nil pointers of type *ast.Object and *ast.Scope, as reflect.Values.
105 objectPtrNil = reflect.ValueOf((*ast.Object)(nil))
106 scopePtrNil = reflect.ValueOf((*ast.Scope)(nil))
// reflect.Types of the node kinds that get special treatment; the functions
// below compare them against the Type() of the values being processed.
// positionType is the type of token.NoPos, i.e. token.Pos.
108 identType = reflect.TypeOf((*ast.Ident)(nil))
109 objectPtrType = reflect.TypeOf((*ast.Object)(nil))
110 positionType = reflect.TypeOf(token.NoPos)
111 callExprType = reflect.TypeOf((*ast.CallExpr)(nil))
112 scopePtrType = reflect.TypeOf((*ast.Scope)(nil))
115 // apply replaces each AST field x in val with f(x), returning val.
116 // To avoid extra conversions, f operates on the reflect.Value form.
117 func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value {
// NOTE(review): the guard for this early return is not shown; presumably it is
// taken when val is invalid - confirm.
119 return reflect.Value{}
122 // *ast.Objects introduce cycles and are likely incorrect after
123 // rewrite; don't follow them but replace with nil instead
124 if val.Type() == objectPtrType {
128 // similarly for scopes: they are likely incorrect after a rewrite;
129 // replace them with nil
130 if val.Type() == scopePtrType {
// Dispatch on the kind of the value that val refers to; reflect.Indirect
// dereferences val when it is a pointer.
134 switch v := reflect.Indirect(val); v.Kind() {
// Walk every element index.
136 for i := 0; i < v.Len(); i++ {
// Walk every struct field index.
141 for i := 0; i < v.NumField(); i++ {
145 case reflect.Interface:
// isWildcard reports whether s consists of exactly one rune and that rune is
// a lower-case letter according to unicode.IsLower.
152 func isWildcard(s string) bool {
// size is the byte length of the first rune; it equals len(s) only when that
// rune spans the whole string. For the empty string the decoded rune is
// utf8.RuneError, which is not lower case, so the result is false.
153 rune, size := utf8.DecodeRuneInString(s)
154 return size == len(s) && unicode.IsLower(rune)
157 // match reports whether pattern matches val,
158 // recording wildcard submatches in m.
159 // If m == nil, match checks whether pattern == val.
160 func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
161 // Wildcard matches any expression. If it appears multiple
162 // times in the pattern, it must match the same expression each time.
164 if m != nil && pattern.IsValid() && pattern.Type() == identType {
165 name := pattern.Interface().(*ast.Ident).Name
166 if isWildcard(name) && val.IsValid() {
167 // wildcards only match valid (non-nil) expressions.
168 if _, ok := val.Interface().(ast.Expr); ok && !val.IsNil() {
// A wildcard that was already bound must match the same expression again.
// Passing m == nil disables wildcard handling for this comparison, so the
// earlier submatch is compared with val exactly.
169 if old, ok := m[name]; ok {
170 return match(nil, old, val)
178 // Otherwise, pattern and val must match recursively.
// If either side is invalid, they match only when both are invalid.
179 if !pattern.IsValid() || !val.IsValid() {
180 return !pattern.IsValid() && !val.IsValid()
// The two values must have the same type before their contents are compared.
// NOTE(review): the body of this branch is not shown; presumably it returns
// false - confirm.
182 if pattern.Type() != val.Type() {
// Special cases, selected by the exact type of the pattern.
187 switch pattern.Type() {
189 // For identifiers, only the names need to match
190 // (and none of the other *ast.Object information).
191 // This is a common case, handle it all here instead
192 // of recursing down any further via reflection.
193 p := pattern.Interface().(*ast.Ident)
194 v := val.Interface().(*ast.Ident)
// Two nil identifiers match; otherwise both must be non-nil with equal names.
195 return p == nil && v == nil || p != nil && v != nil && p.Name == v.Name
196 case objectPtrType, positionType:
197 // object pointers and token positions always match
200 // For calls, the Ellipsis fields (token.Position) must
201 // match since that is how f(x) and f(x...) are different.
202 // Check them here but fall through for the remaining fields.
203 p := pattern.Interface().(*ast.CallExpr)
204 v := val.Interface().(*ast.CallExpr)
205 if p.Ellipsis.IsValid() != v.Ellipsis.IsValid() {
// Dereference pointers. reflect.Indirect of a nil pointer yields an invalid
// Value, so two nil pointers match and nil never matches non-nil.
210 p := reflect.Indirect(pattern)
211 v := reflect.Indirect(val)
212 if !p.IsValid() || !v.IsValid() {
213 return !p.IsValid() && !v.IsValid()
// Element-wise comparison: the lengths must agree and every pair of
// corresponding elements must match.
218 if p.Len() != v.Len() {
221 for i := 0; i < p.Len(); i++ {
222 if !match(m, p.Index(i), v.Index(i)) {
// Field-wise comparison: every pair of corresponding fields must match.
229 for i := 0; i < p.NumField(); i++ {
230 if !match(m, p.Field(i), v.Field(i)) {
236 case reflect.Interface:
// Compare the values held inside the two interfaces.
237 return match(m, p.Elem(), v.Elem())
240 // Handle token integers, etc.
241 return p.Interface() == v.Interface()
244 // subst returns a copy of pattern with values from m substituted in place
245 // of wildcards and pos used as the position of tokens from the pattern.
246 // If m == nil, subst returns a copy of pattern and doesn't change the line
247 // number information.
248 func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) reflect.Value {
// An invalid pattern substitutes to an invalid value.
249 if !pattern.IsValid() {
250 return reflect.Value{}
253 // Wildcard gets replaced with map value.
254 if m != nil && pattern.Type() == identType {
255 name := pattern.Interface().(*ast.Ident).Name
256 if isWildcard(name) {
// Return a copy of the recorded submatch. With m == nil and an invalid pos,
// the copy is made without further wildcard substitution and keeps its own
// positions.
257 if old, ok := m[name]; ok {
258 return subst(nil, old, reflect.Value{})
// Position values are only rewritten when a valid pos was supplied.
263 if pos.IsValid() && pattern.Type() == positionType {
264 // use new position only if old position was valid in the first place
265 if old := pattern.Interface().(token.Pos); !old.IsValid() {
// Otherwise copy the value structurally, substituting into each component.
272 switch p := pattern; p.Kind() {
// Build a new slice of the same type and length, substituting element by
// element.
274 v := reflect.MakeSlice(p.Type(), p.Len(), p.Len())
275 for i := 0; i < p.Len(); i++ {
276 v.Index(i).Set(subst(m, p.Index(i), pos))
// Build a new zero value of the same struct type, substituting field by field.
281 v := reflect.New(p.Type()).Elem()
282 for i := 0; i < p.NumField(); i++ {
283 v.Field(i).Set(subst(m, p.Field(i), pos))
// Build a new value of the same type; when p has a valid element, store the
// address of a substituted copy of that element.
// NOTE(review): the case label is not shown; presumably this is the pointer
// case - confirm.
288 v := reflect.New(p.Type()).Elem()
289 if elem := p.Elem(); elem.IsValid() {
290 v.Set(subst(m, elem, pos).Addr())
294 case reflect.Interface:
// Build a new interface value; when p holds a value, store a substituted copy
// of it.
295 v := reflect.New(p.Type()).Elem()
296 if elem := p.Elem(); elem.IsValid() {
297 v.Set(subst(m, elem, pos))