patman: Support updating a branch with review tags
[platform/kernel/u-boot.git] / tools / patman / status.py
1 # SPDX-License-Identifier: GPL-2.0+
2 #
3 # Copyright 2020 Google LLC
4 #
5 """Talks to the patchwork service to figure out what patches have been reviewed
6 and commented on. Allows creation of a new branch based on the old but with the
7 review tags collected from patchwork.
8 """
9
10 import collections
11 import concurrent.futures
12 from itertools import repeat
13 import re
14
15 import pygit2
16 import requests
17
18 from patman import patchstream
19 from patman.patchstream import PatchStream
20 from patman import terminal
21 from patman import tout
22
23 # Patches which are part of a multi-patch series are shown with a prefix like
24 # [prefix, version, sequence], for example '[RFC, v2, 3/5]'. All but the last
25 # part is optional. This decodes the string into groups. For single patches
26 # the [] part is not present:
27 # Groups: (ignore, ignore, ignore, prefix, version, sequence, subject)
28 RE_PATCH = re.compile(r'(\[(((.*),)?(.*),)?(.*)\]\s)?(.*)$')
29
30 # This decodes the sequence string into a patch number and patch count
31 RE_SEQ = re.compile(r'(\d+)/(\d+)')
32
33 def to_int(vals):
34     """Convert a list of strings into integers, using 0 if not an integer
35
36     Args:
37         vals (list): List of strings
38
39     Returns:
40         list: List of integers, one for each input string
41     """
42     out = [int(val) if val.isdigit() else 0 for val in vals]
43     return out
44
45
46 class Patch(dict):
47     """Models a patch in patchwork
48
49     This class records information obtained from patchwork
50
51     Some of this information comes from the 'Patch' column:
52
53         [RFC,v2,1/3] dm: Driver and uclass changes for tiny-dm
54
55     This shows the prefix, version, seq, count and subject.
56
57     The other properties come from other columns in the display.
58
59     Properties:
60         pid (str): ID of the patch (typically an integer)
61         seq (int): Sequence number within series (1=first) parsed from sequence
62             string
63         count (int): Number of patches in series, parsed from sequence string
64         raw_subject (str): Entire subject line, e.g.
65             "[1/2,v2] efi_loader: Sort header file ordering"
66         prefix (str): Prefix string or None (e.g. 'RFC')
67         version (str): Version string or None (e.g. 'v2')
68         raw_subject (str): Raw patch subject
69         subject (str): Patch subject with [..] part removed (same as commit
70             subject)
71     """
72     def __init__(self, pid):
73         super().__init__()
74         self.id = pid  # Use 'id' to match what the Rest API provides
75         self.seq = None
76         self.count = None
77         self.prefix = None
78         self.version = None
79         self.raw_subject = None
80         self.subject = None
81
82     # These make us more like a dictionary
83     def __setattr__(self, name, value):
84         self[name] = value
85
86     def __getattr__(self, name):
87         return self[name]
88
89     def __hash__(self):
90         return hash(frozenset(self.items()))
91
92     def __str__(self):
93         return self.raw_subject
94
95     def parse_subject(self, raw_subject):
96         """Parse the subject of a patch into its component parts
97
98         See RE_PATCH for details. The parsed info is placed into seq, count,
99         prefix, version, subject
100
101         Args:
102             raw_subject (str): Subject string to parse
103
104         Raises:
105             ValueError: the subject cannot be parsed
106         """
107         self.raw_subject = raw_subject.strip()
108         mat = RE_PATCH.search(raw_subject.strip())
109         if not mat:
110             raise ValueError("Cannot parse subject '%s'" % raw_subject)
111         self.prefix, self.version, seq_info, self.subject = mat.groups()[3:]
112         mat_seq = RE_SEQ.match(seq_info) if seq_info else False
113         if mat_seq is None:
114             self.version = seq_info
115             seq_info = None
116         if self.version and not self.version.startswith('v'):
117             self.prefix = self.version
118             self.version = None
119         if seq_info:
120             if mat_seq:
121                 self.seq = int(mat_seq.group(1))
122                 self.count = int(mat_seq.group(2))
123         else:
124             self.seq = 1
125             self.count = 1
126
127 def compare_with_series(series, patches):
128     """Compare a list of patches with a series it came from
129
130     This prints any problems as warnings
131
132     Args:
133         series (Series): Series to compare against
134         patches (:type: list of Patch): list of Patch objects to compare with
135
136     Returns:
137         tuple
138             dict:
139                 key: Commit number (0...n-1)
140                 value: Patch object for that commit
141             dict:
142                 key: Patch number  (0...n-1)
143                 value: Commit object for that patch
144     """
145     # Check the names match
146     warnings = []
147     patch_for_commit = {}
148     all_patches = set(patches)
149     for seq, cmt in enumerate(series.commits):
150         pmatch = [p for p in all_patches if p.subject == cmt.subject]
151         if len(pmatch) == 1:
152             patch_for_commit[seq] = pmatch[0]
153             all_patches.remove(pmatch[0])
154         elif len(pmatch) > 1:
155             warnings.append("Multiple patches match commit %d ('%s'):\n   %s" %
156                             (seq + 1, cmt.subject,
157                              '\n   '.join([p.subject for p in pmatch])))
158         else:
159             warnings.append("Cannot find patch for commit %d ('%s')" %
160                             (seq + 1, cmt.subject))
161
162
163     # Check the names match
164     commit_for_patch = {}
165     all_commits = set(series.commits)
166     for seq, patch in enumerate(patches):
167         cmatch = [c for c in all_commits if c.subject == patch.subject]
168         if len(cmatch) == 1:
169             commit_for_patch[seq] = cmatch[0]
170             all_commits.remove(cmatch[0])
171         elif len(cmatch) > 1:
172             warnings.append("Multiple commits match patch %d ('%s'):\n   %s" %
173                             (seq + 1, patch.subject,
174                              '\n   '.join([c.subject for c in cmatch])))
175         else:
176             warnings.append("Cannot find commit for patch %d ('%s')" %
177                             (seq + 1, patch.subject))
178
179     return patch_for_commit, commit_for_patch, warnings
180
181 def call_rest_api(subpath):
182     """Call the patchwork API and return the result as JSON
183
184     Args:
185         subpath (str): URL subpath to use
186
187     Returns:
188         dict: Json result
189
190     Raises:
191         ValueError: the URL could not be read
192     """
193     url = 'https://patchwork.ozlabs.org/api/1.2/%s' % subpath
194     response = requests.get(url)
195     if response.status_code != 200:
196         raise ValueError("Could not read URL '%s'" % url)
197     return response.json()
198
199 def collect_patches(series, series_id, rest_api=call_rest_api):
200     """Collect patch information about a series from patchwork
201
202     Uses the Patchwork REST API to collect information provided by patchwork
203     about the status of each patch.
204
205     Args:
206         series (Series): Series object corresponding to the local branch
207             containing the series
208         series_id (str): Patch series ID number
209         rest_api (function): API function to call to access Patchwork, for
210             testing
211
212     Returns:
213         list: List of patches sorted by sequence number, each a Patch object
214
215     Raises:
216         ValueError: if the URL could not be read or the web page does not follow
217             the expected structure
218     """
219     data = rest_api('series/%s/' % series_id)
220
221     # Get all the rows, which are patches
222     patch_dict = data['patches']
223     count = len(patch_dict)
224     num_commits = len(series.commits)
225     if count != num_commits:
226         tout.Warning('Warning: Patchwork reports %d patches, series has %d' %
227                      (count, num_commits))
228
229     patches = []
230
231     # Work through each row (patch) one at a time, collecting the information
232     warn_count = 0
233     for pw_patch in patch_dict:
234         patch = Patch(pw_patch['id'])
235         patch.parse_subject(pw_patch['name'])
236         patches.append(patch)
237     if warn_count > 1:
238         tout.Warning('   (total of %d warnings)' % warn_count)
239
240     # Sort patches by patch number
241     patches = sorted(patches, key=lambda x: x.seq)
242     return patches
243
244 def find_new_responses(new_rtag_list, seq, cmt, patch, rest_api=call_rest_api):
245     """Find new rtags collected by patchwork that we don't know about
246
247     This is designed to be run in parallel, once for each commit/patch
248
249     Args:
250         new_rtag_list (list): New rtags are written to new_rtag_list[seq]
251             list, each a dict:
252                 key: Response tag (e.g. 'Reviewed-by')
253                 value: Set of people who gave that response, each a name/email
254                     string
255         seq (int): Position in new_rtag_list to update
256         cmt (Commit): Commit object for this commit
257         patch (Patch): Corresponding Patch object for this patch
258         rest_api (function): API function to call to access Patchwork, for
259             testing
260     """
261     if not patch:
262         return
263
264     # Get the content for the patch email itself as well as all comments
265     data = rest_api('patches/%s/' % patch.id)
266     pstrm = PatchStream.process_text(data['content'], True)
267
268     rtags = collections.defaultdict(set)
269     for response, people in pstrm.commit.rtags.items():
270         rtags[response].update(people)
271
272     data = rest_api('patches/%s/comments/' % patch.id)
273
274     for comment in data:
275         pstrm = PatchStream.process_text(comment['content'], True)
276         for response, people in pstrm.commit.rtags.items():
277             rtags[response].update(people)
278
279     # Find the tags that are not in the commit
280     new_rtags = collections.defaultdict(set)
281     base_rtags = cmt.rtags
282     for tag, people in rtags.items():
283         for who in people:
284             is_new = (tag not in base_rtags or
285                       who not in base_rtags[tag])
286             if is_new:
287                 new_rtags[tag].add(who)
288     new_rtag_list[seq] = new_rtags
289
290 def show_responses(rtags, indent, is_new):
291     """Show rtags collected
292
293     Args:
294         rtags (dict): review tags to show
295             key: Response tag (e.g. 'Reviewed-by')
296             value: Set of people who gave that response, each a name/email string
297         indent (str): Indentation string to write before each line
298         is_new (bool): True if this output should be highlighted
299
300     Returns:
301         int: Number of review tags displayed
302     """
303     col = terminal.Color()
304     count = 0
305     for tag, people in rtags.items():
306         for who in people:
307             terminal.Print(indent + '%s %s: ' % ('+' if is_new else ' ', tag),
308                            newline=False, colour=col.GREEN, bright=is_new)
309             terminal.Print(who, colour=col.WHITE, bright=is_new)
310             count += 1
311     return count
312
313 def create_branch(series, new_rtag_list, branch, dest_branch, overwrite,
314                   repo=None):
315     """Create a new branch with review tags added
316
317     Args:
318         series (Series): Series object for the existing branch
319         new_rtag_list (list): List of review tags to add, one for each commit,
320                 each a dict:
321             key: Response tag (e.g. 'Reviewed-by')
322             value: Set of people who gave that response, each a name/email
323                 string
324         branch (str): Existing branch to update
325         dest_branch (str): Name of new branch to create
326         overwrite (bool): True to force overwriting dest_branch if it exists
327         repo (pygit2.Repository): Repo to use (use None unless testing)
328
329     Returns:
330         int: Total number of review tags added across all commits
331
332     Raises:
333         ValueError: if the destination branch name is the same as the original
334             branch, or it already exists and @overwrite is False
335     """
336     if branch == dest_branch:
337         raise ValueError(
338             'Destination branch must not be the same as the original branch')
339     if not repo:
340         repo = pygit2.Repository('.')
341     count = len(series.commits)
342     new_br = repo.branches.get(dest_branch)
343     if new_br:
344         if not overwrite:
345             raise ValueError("Branch '%s' already exists (-f to overwrite)" %
346                              dest_branch)
347         new_br.delete()
348     if not branch:
349         branch = 'HEAD'
350     target = repo.revparse_single('%s~%d' % (branch, count))
351     repo.branches.local.create(dest_branch, target)
352
353     num_added = 0
354     for seq in range(count):
355         parent = repo.branches.get(dest_branch)
356         cherry = repo.revparse_single('%s~%d' % (branch, count - seq - 1))
357
358         repo.merge_base(cherry.oid, parent.target)
359         base_tree = cherry.parents[0].tree
360
361         index = repo.merge_trees(base_tree, parent, cherry)
362         tree_id = index.write_tree(repo)
363
364         lines = []
365         if new_rtag_list[seq]:
366             for tag, people in new_rtag_list[seq].items():
367                 for who in people:
368                     lines.append('%s: %s' % (tag, who))
369                     num_added += 1
370         message = patchstream.insert_tags(cherry.message.rstrip(),
371                                           sorted(lines))
372
373         repo.create_commit(
374             parent.name, cherry.author, cherry.committer, message, tree_id,
375             [parent.target])
376     return num_added
377
378 def check_patchwork_status(series, series_id, branch, dest_branch, force,
379                            rest_api=call_rest_api, test_repo=None):
380     """Check the status of a series on Patchwork
381
382     This finds review tags and comments for a series in Patchwork, displaying
383     them to show what is new compared to the local series.
384
385     Args:
386         series (Series): Series object for the existing branch
387         series_id (str): Patch series ID number
388         branch (str): Existing branch to update, or None
389         dest_branch (str): Name of new branch to create, or None
390         force (bool): True to force overwriting dest_branch if it exists
391         rest_api (function): API function to call to access Patchwork, for
392             testing
393         test_repo (pygit2.Repository): Repo to use (use None unless testing)
394     """
395     patches = collect_patches(series, series_id, rest_api)
396     col = terminal.Color()
397     count = len(series.commits)
398     new_rtag_list = [None] * count
399
400     patch_for_commit, _, warnings = compare_with_series(series, patches)
401     for warn in warnings:
402         tout.Warning(warn)
403
404     patch_list = [patch_for_commit.get(c) for c in range(len(series.commits))]
405
406     with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
407         futures = executor.map(
408             find_new_responses, repeat(new_rtag_list), range(count),
409             series.commits, patch_list, repeat(rest_api))
410     for fresponse in futures:
411         if fresponse:
412             raise fresponse.exception()
413
414     num_to_add = 0
415     for seq, cmt in enumerate(series.commits):
416         patch = patch_for_commit.get(seq)
417         if not patch:
418             continue
419         terminal.Print('%3d %s' % (patch.seq, patch.subject[:50]),
420                        colour=col.BLUE)
421         cmt = series.commits[seq]
422         base_rtags = cmt.rtags
423         new_rtags = new_rtag_list[seq]
424
425         indent = ' ' * 2
426         show_responses(base_rtags, indent, False)
427         num_to_add += show_responses(new_rtags, indent, True)
428
429     terminal.Print("%d new response%s available in patchwork%s" %
430                    (num_to_add, 's' if num_to_add != 1 else '',
431                     '' if dest_branch
432                     else ' (use -d to write them to a new branch)'))
433
434     if dest_branch:
435         num_added = create_branch(series, new_rtag_list, branch,
436                                   dest_branch, force, test_repo)
437         terminal.Print(
438             "%d response%s added from patchwork into new branch '%s'" %
439             (num_added, 's' if num_added != 1 else '', dest_branch))