1 # Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
2 # Use of this source code is governed by a BSD-style license that can be
3 # found in the LICENSE file.
5 """Support generic spreadsheet-like table information."""
7 from __future__ import print_function
13 from chromite.lib import cros_build_lib
17 """Class to represent column headers and rows of data."""
19 __slots__ = ['_column_set', # Set of column headers (for faster lookup)
20 '_columns', # List of column headers in order
21 '_name', # Name to associate with table
22 '_rows', # List of row dicts
27 CSV_BQ = '__BEGINQUOTE__'
28 CSV_EQ = '__ENDQUOTE__'
31 def _SplitCSVLine(line):
32 '''Split a single CSV line into separate values.
34 Behavior illustrated by the following examples, with all but
35 the last example taken from Google Docs spreadsheet behavior:
36 'a,b,c,d': ==> ['a', 'b', 'c', 'd'],
37 'a, b, c, d': ==> ['a', ' b', ' c', ' d'],
38 'a,b,c,': ==> ['a', 'b', 'c', ''],
39 'a,"b c",d': ==> ['a', 'b c', 'd'],
40 'a,"b, c",d': ==> ['a', 'b, c', 'd'],
41 'a,"b, c, d",e': ==> ['a', 'b, c, d', 'e'],
42 'a,"""b, c""",d': ==> ['a', '"b, c"', 'd'],
43 'a,"""b, c"", d",e': ==> ['a', '"b, c", d', 'e'],
44 'a,b\,c,d': ==> ['a', 'b,c', 'd'],
46 Return a list of values.
48 # Split on commas, handling two special cases:
49 # 1) Escaped commas are not separators.
50 # 2) A quoted value can have non-separator commas in it. Quotes
53 for val in re.split(r'(?<!\\),', line):
58 # Handle regular double quotes at beginning/end specially.
60 val = Table.CSV_BQ + val[1:]
61 if val[-1] == '"' and (val[-2] != '"' or val[-3] == '"'):
62 val = val[0:-1] + Table.CSV_EQ
64 # Remove escape characters now.
65 val = val.replace(r'\,', ',') # \ before ,
66 val = val.replace('""', '"') # " before " (Google Spreadsheet syntax)
68 prevval = vals[-1] if vals else None
70 # If previous value started with quote and ended without one, then
71 # the current value is just a continuation of the previous value.
72 if prevval and prevval.startswith(Table.CSV_BQ):
73 val = prevval + "," + val
74 # Once entire value is read, strip surrounding quotes
75 if val.endswith(Table.CSV_EQ):
76 vals[-1] = val[len(Table.CSV_BQ):-len(Table.CSV_EQ)]
79 elif val.endswith(Table.CSV_EQ):
80 vals.append(val[len(Table.CSV_BQ):-len(Table.CSV_EQ)])
84 # If an unpaired Table.CSV_BQ is still in vals, then replace with ".
85 vals = [val.replace(Table.CSV_BQ, '"') for val in vals]
90 def LoadFromCSV(csv_file, name=None):
91 """Create a new Table object by loading contents of |csv_file|."""
92 if type(csv_file) is file:
93 file_handle = csv_file
95 file_handle = open(csv_file, 'r')
98 for line in file_handle:
102 vals = Table._SplitCSVLine(line)
106 table = Table(vals, name=name)
110 table.AppendRow(vals)
114 def __init__(self, columns, name=None):
115 self._columns = columns
116 self._column_set = set(columns)
121 """Return a table-like string representation of this table."""
122 cols = ['%10s' % col for col in self._columns]
123 text = 'Columns: %s\n' % ', '.join(cols)
126 for row in self._rows:
127 vals = ['%10s' % row[col] for col in self._columns]
128 text += 'Row %3d: %s\n' % (ix, ', '.join(vals))
132 def __nonzero__(self):
133 """Define boolean equivalent for this table."""
134 return bool(self._columns)
137 """Length of table equals the number of rows."""
138 return self.GetNumRows()
140 def __eq__(self, other):
141 """Return true if two tables are equal."""
142 # pylint: disable=W0212
143 return self._columns == other._columns and self._rows == other._rows
145 def __ne__(self, other):
146 """Return true if two tables are not equal."""
147 return not self == other
149 def __getitem__(self, index):
150 """Access one or more rows by index or slice."""
151 return self.GetRowByIndex(index)
153 def __delitem__(self, index):
154 """Delete one or more rows by index or slice."""
155 self.RemoveRowByIndex(index)
158 """Declare that this class supports iteration (over rows)."""
159 return self._rows.__iter__()
162 """Return name associated with table, None if not available."""
def SetName(self, name):
    """Set the name associated with table."""
    self._name = name
170 """Remove all row data."""
def GetNumRows(self):
    """Return the number of rows in the table."""
    row_count = len(self._rows)
    return row_count
def GetNumColumns(self):
    """Return the number of columns in the table."""
    column_count = len(self._columns)
    return column_count
def GetColumns(self):
    """Return list of column names in order.

    The result is a new list, so changing it does not change the table.
    """
    return [header for header in self._columns]
def GetRowByIndex(self, index):
    """Access one or more rows by index or slice.

    A single integer index yields that one row.  If more than one row is
    returned (slice access) they will be contained in a list.
    """
    all_rows = self._rows
    return all_rows[index]
192 def _GenRowFilter(self, id_values):
193 """Return a method that returns true for rows matching |id_values|."""
195 """Filter function for rows with id_values."""
196 for key in id_values:
197 if id_values[key] != row.get(key, None):
def GetRowsByValue(self, id_values):
    """Return list of rows matching key/value pairs in |id_values|.

    The returned list holds the table's own row dicts, in table order.
    """
    # This is a linear scan over all rows.  If row retrieval by value
    # becomes heavily used for larger tables, a more efficient approach
    # would trade some pre-processing and extra storage for speed.
    matches = self._GenRowFilter(id_values)
    selected = []
    for candidate in self._rows:
        if matches(candidate):
            selected.append(candidate)
    return selected
def GetRowIndicesByValue(self, id_values):
    """Return list of indices for rows matching k/v pairs in |id_values|.

    Indices are returned in ascending order; the list is empty when no row
    matches.
    """
    grep = self._GenRowFilter(id_values)
    return [ix for ix, row in enumerate(self._rows) if grep(row)]
220 def _PrepareValuesForAdd(self, values):
221 """Prepare a |values| dict/list to be added as a row.
223 If |values| is a dict, verify that only supported column
224 values are included. Add empty string values for columns
225 not seen in the row. The original dict may be altered.
227 If |values| is a list, translate it to a dict using known
228 column order. Append empty values as needed to match number
231 Return prepared dict.
233 if isinstance(values, dict):
235 if not col in self._column_set:
236 raise LookupError("Tried adding data to unknown column '%s'" % col)
238 for col in self._columns:
239 if not col in values:
240 values[col] = self.EMPTY_CELL
242 elif isinstance(values, list):
243 if len(values) > len(self._columns):
244 raise LookupError("Tried adding row with too many columns")
245 if len(values) < len(self._columns):
246 shortage = len(self._columns) - len(values)
247 values.extend([self.EMPTY_CELL] * shortage)
249 values = dict(zip(self._columns, values))
def AppendRow(self, values):
    """Add a single row of data to the table, according to |values|.

    The |values| argument can be either a dict or list.  It is passed
    through _PrepareValuesForAdd and the resulting row is placed after all
    existing rows.
    """
    self._rows.append(self._PrepareValuesForAdd(values))
def SetRowByIndex(self, index, values):
    """Replace the row at |index| with values from |values| dict.

    |values| is passed through _PrepareValuesForAdd before being stored.
    """
    self._rows[index] = self._PrepareValuesForAdd(values)
def RemoveRowByIndex(self, index):
    """Remove the row at |index|.

    |index| is used directly in a del statement on the row list.
    """
    row_list = self._rows
    del row_list[index]
def HasColumn(self, name):
    """Return True if column |name| is in this table, False otherwise."""
    known_columns = self._column_set
    return name in known_columns
def GetColumnIndex(self, name):
    """Return the column index for column |name|, -1 if not found."""
    for ix, col in enumerate(self._columns):
        if col == name:
            return ix
    # No column matched.
    return -1
def GetColumnByIndex(self, index):
    """Return the column name at |index|."""
    headers = self._columns
    return headers[index]
def InsertColumn(self, index, name, value=None):
    """Insert a new column |name| into table at index |index|.

    If |value| is specified, all rows will have |value| in the new column.
    Otherwise, they will have the EMPTY_CELL value.

    Raises:
      LookupError: If the table already has a column called |name|.
    """
    if self.HasColumn(name):
        raise LookupError("Column %s already exists in table." % name)

    # Keep the ordered list and the lookup set in step.
    self._columns.insert(index, name)
    self._column_set.add(name)

    # Give every existing row a cell for the new column.
    for existing_row in self._rows:
        if value is None:
            existing_row[name] = self.EMPTY_CELL
        else:
            existing_row[name] = value
def AppendColumn(self, name, value=None):
    """Same as InsertColumn, but new column is appended after existing ones."""
    end_index = self.GetNumColumns()
    self.InsertColumn(end_index, name, value)
def ProcessRows(self, row_processor):
    """Invoke |row_processor| on each row in sequence.

    Args:
      row_processor: Callable taking one row dict.  It receives the table's
        own row objects, in table order; its return value is ignored.
    """
    for row in self._rows:
        row_processor(row)
309 def MergeTable(self, other_table, id_columns, merge_rules=None,
310 allow_new_columns=False, key=None, reverse=False,
312 """Merge |other_table| into this table, identifying rows by |id_columns|.
314 The |id_columns| argument can either be a list of identifying columns names
315 or a single column name (string). The values in these columns will be used
316 to identify the existing row that each row in |other_table| should be
319 The |merge_rules| specify what to do when there is a merge conflict. Every
320 column where a conflict is anticipated should have an entry in the
321 |merge_rules| dict. The value should be one of:
322 'join_with:<text>' = Join the two conflicting values with <text>
323 'accept_this_val' = Keep value in 'this' table and discard 'other' value.
324 'accept_other_val' = Keep value in 'other' table and discard 'this' value.
325 function = Keep return value from function(col_name, this_val, other_val)
327 A default merge rule can be specified with the key '__DEFAULT__' in
330 By default, the |other_table| must not have any columns that don't already
331 exist in this table. To allow new columns to be created by virtue of their
332 presence in |other_table| set |allow_new_columns| to true.
334 To sort the final merged table, supply |key| and |reverse| arguments exactly
335 as they work with the Sort method.
337 # If requested, allow columns in other_table to create new columns
338 # in this table if this table does not already have them.
339 if allow_new_columns:
340 # pylint: disable=W0212
341 for ix, col in enumerate(other_table._columns):
342 if not self.HasColumn(col):
343 # Create a merge_rule on the fly for this new column.
346 merge_rules[col] = 'accept_other_val'
349 self.InsertColumn(0, col)
351 prevcol = other_table._columns[ix - 1]
352 previx = self.GetColumnIndex(prevcol)
353 self.InsertColumn(previx + 1, col)
355 for other_row in other_table:
356 self._MergeRow(other_row, id_columns, merge_rules=merge_rules)
358 # Optionally re-sort the merged table.
360 self.Sort(key, reverse=reverse)
363 self.SetName(new_name)
364 elif self.GetName() and other_table.GetName():
365 self.SetName(self.GetName() + ' + ' + other_table.GetName())
def _GetIdValuesForRow(self, row, id_columns):
    """Return a dict with values from |row| in |id_columns|.

    Args:
      row: Row dict to read values from; it must contain every id column.
      id_columns: Identifying column name(s).  The names are obtained by
        passing this through cros_build_lib.iflatten_instance, which
        presumably accepts either one name or a collection of names
        (confirm against that helper).

    Returns:
      Dict mapping each id column name to the value |row| holds for it.
    """
    id_values = dict((col, row[col]) for col in
                     cros_build_lib.iflatten_instance(id_columns))
    return id_values
373 def _MergeRow(self, other_row, id_columns, merge_rules=None):
374 """Merge |other_row| into this table.
376 See MergeTable for description of |id_columns| and |merge_rules|.
378 id_values = self._GetIdValuesForRow(other_row, id_columns)
380 row_indices = self.GetRowIndicesByValue(id_values)
382 row_index = row_indices[0]
383 row = self.GetRowByIndex(row_index)
384 for col in other_row:
386 # Find the merge rule that applies to this column, if there is one.
389 merge_rule = merge_rules.get(col, None)
390 if not merge_rule and merge_rules:
391 merge_rule = merge_rules.get('__DEFAULT__', None)
394 val = self._MergeColValue(col, row[col], other_row[col],
395 merge_rule=merge_rule)
397 msg = "Failed to merge '%s' value in row %r" % (col, id_values)
398 print(msg, file=sys.stderr)
404 # Cannot add new columns to row this way.
405 raise LookupError("Tried merging data to unknown column '%s'" % col)
406 self.SetRowByIndex(row_index, row)
408 self.AppendRow(other_row)
410 def _MergeColValue(self, col, val, other_val, merge_rule):
411 """Merge |col| values |val| and |other_val| according to |merge_rule|.
413 See MergeTable method for explanation of option |merge_rule|.
419 raise ValueError("Cannot merge column values without rule: '%s' vs '%s'" %
421 elif inspect.isfunction(merge_rule):
423 return merge_rule(col, val, other_val)
425 pass # Fall through to exception at end
426 elif merge_rule == 'accept_this_val':
428 elif merge_rule == 'accept_other_val':
431 match = re.match(r'join_with:(.+)$', merge_rule)
433 return match.group(1).join(v for v in (val, other_val) if v)
435 raise ValueError("Invalid merge rule (%s) for values '%s' and '%s'." %
436 (merge_rule, val, other_val))
def Sort(self, key, reverse=False):
    """Sort the rows using the given |key| function.

    The row list is sorted in place; |key| and |reverse| are handed to
    list.sort unchanged.
    """
    sort_options = {'key': key, 'reverse': reverse}
    self._rows.sort(**sort_options)
def WriteCSV(self, filehandle, hiddencols=None):
    """Write this table out as comma-separated values to |filehandle|.

    The column headers are written first, then one line per row.  Cell
    values are joined with ',' exactly as stored; no quoting or escaping is
    applied here.  A row lacking a column is written with EMPTY_CELL.

    To skip certain columns during the write, use the |hiddencols| set.
    """
    def ColFilter(col):
        """Filter function for columns not in hiddencols."""
        return not hiddencols or col not in hiddencols

    cols = [col for col in self._columns if ColFilter(col)]
    filehandle.write(','.join(cols) + '\n')
    for row in self._rows:
        vals = [row.get(col, self.EMPTY_CELL) for col in cols]
        filehandle.write(','.join(vals) + '\n')