# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Pivoting
# The goal of the MADlib pivot function is to provide a data summarization tool
# that can do basic OLAP type operations on data stored in one table and output
# the summarized data to a second table. Typical operations are count, average,
# min, max and standard deviation, however user defined aggregates (UDAs) are
# also allowed.
# Please refer to the pivot.sql_in file for the documentation """ @file pivot.py_in """ import plpy import itertools from control import MinWarning from utilities import _assert from utilities import split_quoted_delimited_str from utilities import strip_end_quotes from utilities import extract_keyvalue_params from validate_args import table_exists from validate_args import columns_exist_in_table from validate_args import table_is_empty from validate_args import _get_table_schema_names from validate_args import get_first_schema from validate_args import get_expr_type m4_changequote(`') def pivot(schema_madlib, source_table, out_table, index, pivot_cols, pivot_values, aggregate_func=None, fill_value=None, keep_null=False, output_col_dictionary=False, output_type=None, **kwargs): """ Helper function that can be used to pivot tables Args: @param source_table The original data table @param out_table The output table that contains the dummy variable columns @param index The index columns to group the records by @param pivot_cols The columns to pivot the table @param pivot_values The value columns to be summarized in the pivoted table @param aggregate_func The aggregate function to be applied to the values @param fill_value If specified, determines how to fill NULL values resulting from pivot operation @param keep_null The flag for determining how to handle NULL values in pivot columns Assume we have the following table pivset( id INTEGER, piv FLOAT8, val FLOAT8 ) where the piv column has 3 distinct values (10, 20 and 30). If the pivot function call is : SELECT madlib.pivot('pivset', 'pivout', 'id', 'piv', 'val'); We want to construct the following sql code to pivot the table. 
CREATE TABLE pivout AS (SELECT id, avg(CASE WHEN "piv" = '10' THEN val ELSE NULL END ) as "val_avg_piv_10", avg(CASE WHEN "piv" = '20' THEN val ELSE NULL END ) as "val_avg_piv_20", avg(CASE WHEN "piv" = '30' THEN val ELSE NULL END ) as "val_avg_piv_30" FROM pivset GROUP BY id ORDER BY id) """ def _fill_value_wrapper(sel_str): """ Wrap a given SQL SELECT statement with COALESCE using a given fill value. No-op if the fill value is not provided """ if fill_value is not None: return " COALESCE({0}, {1}) ".format(sel_str, fill_value) else: return sel_str with MinWarning('warning'): # If there are more than 1000 columns for the output table, we give a # warning as it might give an error. MAX_OUTPUT_COLUMN_COUNT = 1000 # If a column name has more than 63 characters it gets trimmed automatically, # which may cause an exception. Enable the output dictionary in this case. MAX_COLUMN_LENGTH = 63 indices = split_quoted_delimited_str(index) pcols = split_quoted_delimited_str(pivot_cols) pvals = split_quoted_delimited_str(pivot_values) # output type for specific supported types output_type = 'column' if not output_type else output_type.lower() all_output_types = sorted(['array', 'column', 'svec']) try: # allow user to specify a prefix substring of # supported output types. This works because the supported # output types have unique prefixes. output_type = next(s for s in all_output_types if s.startswith(output_type)) except StopIteration: # next() returns a StopIteration if no element found plpy.error("Encoding categorical: Output type should be one of {0}". 
format(','.join(all_output_types))) is_array_output = output_type in ('array', 'svec') # always build dictionary table if output is array output_col_dictionary = True if is_array_output else output_col_dictionary validate_pivot_coding(source_table, out_table, indices, pcols, pvals) # Strip the end quotes for building output columns (this can only be # performed after the validation) pcols = [strip_end_quotes(pcol.strip()) for pcol in pcols] pvals = [strip_end_quotes(pval.strip()) for pval in pvals] # Create a dictionary that assigns one or more aggregate functions for every # value column. agg_dict = parse_aggregates(pvals, aggregate_func) validate_output_types(source_table, agg_dict, is_array_output) # Find the distinct values of pivot_cols array_agg_str = ', '.join("array_agg(DISTINCT {pcol}) AS {pcol}_values". format(pcol=pcol) for pcol in pcols) if keep_null: # Some platforms don't include NULL values as part of the array_agg(DISTINCT ...) # Below clause checks explicitly for NULL values null_str = ", " + ', '.join( "bool_or(CASE WHEN {pcol} IS NULL THEN True ELSE False END)" "AS {pcol}_isnull".format(pcol=pcol) for pcol in pcols) else: null_str = "" distinct_values = plpy.execute("SELECT {0} {1} FROM {2}". format(array_agg_str, null_str, source_table))[0] # Collect the distinct values for every pivot column into a dictionary pcol_distinct_values = {} pcol_max_length = 0 for pcol in pcols: pcol_tmp = set(item for item in distinct_values[pcol + "_values"]) if not keep_null: pcol_tmp.discard(None) elif distinct_values[pcol + "_isnull"]: pcol_tmp.add(None) pcol_distinct_values[pcol] = sorted(pcol_tmp) # Max pcol length calculation: the name of column (pcol) + # name of longest value in column (item) + # underscore (1) pcol_max_length += (len(pcol) + max([len(str(item)) for item in pcol_tmp]) + 1) # Create the combination of every possible pivot column # Assume piv and piv2 are pivot columns. 
piv=(1,2) and piv2=(3,4,5) # pivot_comb = ((1,3),(1,4),(1,5),(2,3),(2,4),(2,5)) pivot_comb = list(itertools.product(*([pcol_distinct_values[pcol] for pcol in pcols]))) # Check the max possible length of a output column name # If it is over 63 (postgresql upper limit) create dictionary lookup for pval in pvals: agg_func = agg_dict[pval] # Length calculation: value column length + aggregate length + # 2 underscores + pivots and their values (pcol_max_length) # Example: val _ sum _ piv1_10_piv2_100 col_name_len = (2 + len(pval) + pcol_max_length + max([len(item) for item in agg_func])) if col_name_len > MAX_COLUMN_LENGTH: with MinWarning("warning"): plpy.warning("Pivot: Output columns are renamed to keep them " "under 63 characters. Please refer to " "{source_table}_dictionary for the original names.". format(**locals())) output_col_dictionary = True # Types of pivot columns are needed for building the right columns # in the dictionary table and to decide if a pivot column value needs to # be quoted during comparison (will be quoted if it's a text column) types_str = ', '.join("pg_typeof(\"{pcol}\") as {pcol}". format(pcol=p) for p in pcols) pcol_types = plpy.execute("SELECT {0} FROM {1} LIMIT 1". format(types_str, source_table))[0] if output_col_dictionary: out_dict = out_table + "_dictionary" _assert(not table_exists(out_dict), "Pivot: Output dictionary table already exists!") # Create the empty dictionary table pcol_names_types = ', '.join(" {pcol} {pcol_type} ". 
format(pcol=pcol, pcol_type=pcol_types[pcol]) for pcol in pcols) plpy.execute(""" CREATE TABLE {out_dict} ( __pivot_cid__ VARCHAR, pval VARCHAR, agg VARCHAR, {pcol_names_types}, col_name VARCHAR) """.format(out_dict=out_dict, pcol_names_types=pcol_names_types)) # List of rows to insert into output dictionary dict_insert_str = [] # Counter for the new output column names dict_counter = 1 pivot_sel_list = [] pivot_from_list = [] for pval in pvals: agg_func = agg_dict[pval] for agg in agg_func: # is using array_output, create a new array for each pval-agg combo if is_array_output: # we store information in the dictionary table for each # index in the array. 'index_counter' is the current index # being updated (resets for each new array) index_counter = 1 sub_pivot_sel_list = [] for comb in pivot_comb: pivot_col_condition = [] # note column name starts with double quotes pivot_col_name = ['{pval}_{agg}'.format(pval=pval, agg=agg)] if output_col_dictionary: # Prepare the entry for the dictionary if not is_array_output: index_name = ("__p_{dict_counter}__". format(dict_counter=dict_counter)) else: # for arrays, index_name is just the index into each array index_name = str(index_counter) index_counter += 1 dict_insert_str.append( "(\'{index_name}\', \'{pval}\', \'{agg}\' ". format(index_name=index_name, pval=pval, agg=agg)) # For every pivot column in a given combination for counter, pcol in enumerate(pcols): if comb[counter] is None: quoted_pcol_value = "NULL" elif pcol_types[pcol] in ("text", "varchar", "character varying"): quoted_pcol_value = "'" + comb[counter] + "'" else: quoted_pcol_value = comb[counter] # If we encounter a NULL value that means it is not filtered # because of keep_null. Use "IS NULL" for comparison if comb[counter] is None: pivot_col_condition.append(" \"{0}\" IS NULL".format(pcol)) pivot_col_name.append("_{0}_null".format(pcol)) else: pivot_col_condition.append(" \"{0}\" = {1}". 
format(pcol, quoted_pcol_value)) pivot_col_name.append("_{0}_{1}".format(pcol, comb[counter])) if output_col_dictionary: dict_insert_str.append("{0}".format(quoted_pcol_value)) if output_col_dictionary: # Store the whole string as additional info dict_insert_str.append("'{0}')".format(''.join(pivot_col_name))) pivot_col_name = ["__p_" + str(dict_counter) + "__"] dict_counter += 1 # Collecting the whole sql query # Please refer to the earlier comment for a sample output # Build the pivot column with NULL values in tuples that don't # satisfy that column's condition p_name = '"{0}"'.format(''.join(pivot_col_name)) pivot_str_from = ( "(CASE WHEN {condition} THEN {pval} END) AS {p_name}". format(pval=pval, condition=' AND '.join(pivot_col_condition), p_name=p_name)) pivot_from_list.append(pivot_str_from) # Aggregate over each pivot column, while filtering all NULL # values created by previous query. sub_pivot_str_sel = _fill_value_wrapper( "{agg}({p_name}) " " FILTER (WHERE {p_name} IS NOT NULL)". format(agg=agg, p_name=p_name)) if not is_array_output: # keep spaces around the 'AS' sub_pivot_str_sel += " AS " + p_name sub_pivot_sel_list.append(sub_pivot_str_sel) if sub_pivot_sel_list: if is_array_output: if output_type is 'svec': cast_str = '::FLOAT8[]::{0}.svec'.format(schema_madlib) else: cast_str = '::FLOAT8[]' pivot_sel_list.append( 'ARRAY[{all_pivot_sel}]{cast_str} AS "{pval}_{agg}"'. format(all_pivot_sel=', '.join(sub_pivot_sel_list), cast_str=cast_str, pval=pval, agg=agg)) else: pivot_sel_list += sub_pivot_sel_list try: plpy.execute(""" CREATE TABLE {out_table} AS SELECT {index}, {all_pivot_sel_str} FROM ( SELECT {index}, {all_pivot_from_str} FROM {source_table} ) x GROUP BY {index} """.format(out_table=out_table, index=index, source_table=source_table, all_pivot_from_str=', '.join(pivot_from_list), all_pivot_sel_str=', '.join(pivot_sel_list) )) if output_col_dictionary: plpy.execute("INSERT INTO {out_dict} VALUES {insert_sql}". 
format(out_dict=out_dict, insert_sql=', '.join(dict_insert_str))) except plpy.SPIError: # Warn user if the number of columns is over the limit with MinWarning("warning"): # The column options from value columns and aggregates # times the number of pivot combinations if ((sum([len(item) for item in agg_dict.values()])* len(pivot_comb)) > MAX_OUTPUT_COLUMN_COUNT): plpy.warning( "Pivot: Too many distinct values for pivoting! " "The execution may fail due to too many columns in the " "output table.") else: plpy.warning( "Pivot: Pivoting is only supported over aggregates with " "transition functions defined as STRICT.") raise return None # ------------------------------------------------------------------------------ def parse_aggregates(pvals, aggregate_func): """ Helper function that parses the aggregate function parameter Args: @param pvals The value columns to be summarized in the pivoted table @param aggregate_func The aggregate function to be applied to the values """ """ The aggregate_func can get one of the following forms 1) NULL: Use the default aggregate ('avg') 2) A single aggregate (eg. 'sum') 3) A comma-separated list of aggregates (eg. 'sum,avg') 4) A complete mapping (eg. 'val=sum, val2=[avg,sum]') 5) A partial mapping (eg. 
'val2=sum'): Use the default ('avg') for the missing value columns """ param_types = dict.fromkeys(pvals, tuple) agg_dict = extract_keyvalue_params(aggregate_func, param_types) if not agg_dict: agg_list = tuple(split_quoted_delimited_str(aggregate_func)) agg_dict = dict.fromkeys(pvals, (agg_list if agg_list else ('avg', ))) else: for pval in pvals: if pval not in agg_dict: agg_dict[pval] = ('avg', ) return agg_dict # ------------------------------------------------------------------------------ def validate_pivot_coding(source_table, out_table, indices, pivs, vals): """ Args: @param source_table The original data table @param out_table The output table that will contain dummy columns @param indices An array of index column names @param cols An array of categorical column names """ _assert(out_table and out_table.strip().lower() not in ('null', ''), "Pivot: Invalid output table name!") _assert(not table_exists(out_table), "Pivot: Output table already exists!") _assert(source_table and source_table.strip().lower() not in ('null', ''), "Pivot: Invalid data table name!") _assert(table_exists(source_table), "Pivot: Data table ({0}) is missing!". format(source_table)) _assert(not table_is_empty(source_table), "Pivot: Data table ({0}) is empty!". 
format(source_table)) _assert(indices and indices not in ('null', ''), "Pivot: Invalid index column!") _assert(pivs and pivs not in ('null', ''), "Pivot: Invalid pivot column!") _assert(vals and vals not in ('null', ''), "Pivot: Invalid value column!") _assert(columns_exist_in_table(source_table, indices), "Pivot: Not all columns from {0} present in source table ({1})" .format(indices, source_table)) _assert(columns_exist_in_table(source_table, pivs), "Pivot: Not all columns from {0} present in source table ({1})" .format(pivs, source_table)) _assert(columns_exist_in_table(source_table, vals), "Pivot: Not all columns from {0} present in source table ({1})" .format(vals, source_table)) # ------------------------------------------------------------------------------ def validate_output_types(source_table, agg_dict, is_array_output): """ Args: @param source_table: str, Name of table containing data @param agg_dict: dict, Key-value pair containing aggregates applied for each val column @param is_array_output: bool, Is the pivot output columnar (False) or array (True) Returns: None """ for val, func_iterable in agg_dict.items(): for func in func_iterable: func_call_str = '{0}({1})'.format(func, val) _assert(not ('[]' in get_expr_type(func_call_str, source_table) and is_array_output), "Pivot: Aggregate {0} with an array return type cannot be " "combined with output_type='array' or 'svec'".format(func)) # ---------------------------------------------------------------------- def pivot_help(schema_madlib, message, **kwargs): """ Help function for pivot Args: @param schema_madlib @param message: string, Help message string @param kwargs Returns: String. 
Help/usage information """ if not message: help_string = """ ----------------------------------------------------------------------- SUMMARY ----------------------------------------------------------------------- Provide a data summarization tool that can do basic OLAP type operations on data stored in one table and output the summarized data to a second table. Typical operations are count, average, min, max and standard deviation, however user defined aggregates (UDAs) are also be allowed. For more details on function usage: SELECT {schema_madlib}.pivot('usage') """ elif message in ['usage', 'help', '?']: help_string = """ ----------------------------------------------------------------------- USAGE ----------------------------------------------------------------------- SELECT {schema_madlib}.pivot( source_table, -- Name of source table containing data for pivoting out_table, -- Name of output table taht contains pivoted data index, -- Comma-separated columns that will form the index -- of the output pivot table pivot_cols, -- Comma-separated columns that will form the -- columns of the output pivot table pivot_values, -- Comma-separated columns that contain the values -- to be summarized in the output pivot table fill_value, -- If specified, determines how to fill NULL values -- resulting from pivot operation keep_null, -- The flag for determining how to handle NULL -- values in pivot columns output_col_dictionary, -- The flag for enabling the creation of the -- output dictionary for shorter column names output_type -- This parameter controls the output format -- of the pivoted variables. 
-- If 'column', a column is created for each pivot -- If 'array', an array is created combining all pivots -- If 'svec', the array is cast to madlib.svec ); ----------------------------------------------------------------------- OUTPUT ----------------------------------------------------------------------- The output table ('output_table' above) has all the columns present in index column list, plus additional columns for each distinct value in pivot_cols. The column name for the pivot is set as '_'. A dictionary table ('_dictionary') is created if either 'output_col_dictionary' is True or if the auto-generated column names exceed the PostgreSQL limit of 63 bytes . This table gives a mapping between short column names in 'output_table' and the meaning of those columns i.e. which index, value, agg and pivot column they belong to. """ else: help_string = "No such option. Use {schema_madlib}.pivot()" return help_string.format(schema_madlib=schema_madlib) # ---------------------------------------------------------------------