Datum vec_mul_with_vec(PG_FUNCTION_ARGS); PG_FUNCTION_INFO_V1(vec_mul_with_vec); /** * Does element-wise multiplication between two vectors. * * Both vectors must be the same length, * and returns a vector of that length. * * If either vector contains a NULL, * the result is NULL for that position. * If either input is NULL itself, * the result is NULL. * * by Paul A. Jungwirth */ Datum vec_mul_with_vec(PG_FUNCTION_ARGS) { Oid elemTypeId; int16 elemTypeWidth; bool elemTypeByValue; char elemTypeAlignmentCode; int lhsLength; ArrayType *lhsArray, *rhsArray, *retArray; Datum *lhsContent, *rhsContent, *retContent; bool *lhsNulls, *rhsNulls, *retNulls; int i; int dims[1]; int lbs[1]; if (PG_ARGISNULL(0) || PG_ARGISNULL(1)) { PG_RETURN_NULL(); } lhsArray = PG_GETARG_ARRAYTYPE_P(0); rhsArray = PG_GETARG_ARRAYTYPE_P(1); if (ARR_NDIM(lhsArray) == 0 || (ARR_NDIM(rhsArray) == 0)) { PG_RETURN_NULL(); } if (ARR_NDIM(lhsArray) > 1 || (ARR_NDIM(rhsArray) > 1)) { ereport(ERROR, (errmsg("vec_mul: one-dimensional arrays are required"))); } elemTypeId = ARR_ELEMTYPE(lhsArray); if (elemTypeId != INT2OID && elemTypeId != INT4OID && elemTypeId != INT8OID && elemTypeId != FLOAT4OID && elemTypeId != FLOAT8OID && elemTypeId != NUMERICOID) { ereport(ERROR, (errmsg("vec_mul input must be array of SMALLINT, INTEGER, BIGINT, REAL, DOUBLE PRECISION, or NUMERIC"))); } if (elemTypeId != ARR_ELEMTYPE(rhsArray)) { ereport(ERROR, (errmsg("vec_mul input arrays must be the same type"))); } get_typlenbyvalalign(elemTypeId, &elemTypeWidth, &elemTypeByValue, &elemTypeAlignmentCode); deconstruct_array(lhsArray, elemTypeId, elemTypeWidth, elemTypeByValue, elemTypeAlignmentCode, &lhsContent, &lhsNulls, &lhsLength); deconstruct_array(rhsArray, elemTypeId, elemTypeWidth, elemTypeByValue, elemTypeAlignmentCode, &rhsContent, &rhsNulls, &lhsLength); retContent = palloc0(sizeof(Datum) * lhsLength); retNulls = palloc0(sizeof(bool) * lhsLength); for (i = 0; i < lhsLength; i++) { if (lhsNulls[i] || rhsNulls[i]) { retNulls[i] = true; continue; } retNulls[i] = false; switch(elemTypeId) { case INT2OID: retContent[i] = Int16GetDatum(DatumGetInt16(lhsContent[i]) * DatumGetInt16(rhsContent[i])); break; case INT4OID: retContent[i] = Int32GetDatum(DatumGetInt32(lhsContent[i]) * DatumGetInt32(rhsContent[i])); break; case INT8OID: retContent[i] = Int64GetDatum(DatumGetInt64(lhsContent[i]) * DatumGetInt64(rhsContent[i])); break; case FLOAT4OID: retContent[i] = Float4GetDatum(DatumGetFloat4(lhsContent[i]) * DatumGetFloat4(rhsContent[i])); break; case FLOAT8OID: retContent[i] = Float8GetDatum(DatumGetFloat8(lhsContent[i]) * DatumGetFloat8(rhsContent[i])); break; case NUMERICOID: #if PG_VERSION_NUM < 120000 retContent[i] = DirectFunctionCall2(numeric_mul, lhsContent[i], rhsContent[i]); #else retContent[i] = NumericGetDatum(numeric_mul_opt_error(DatumGetNumeric(lhsContent[i]), DatumGetNumeric(rhsContent[i]), NULL)); #endif break; } } dims[0] = lhsLength; lbs[0] = 1; retArray = construct_md_array(retContent, retNulls, 1, dims, lbs, elemTypeId, elemTypeWidth, elemTypeByValue, elemTypeAlignmentCode); PG_RETURN_ARRAYTYPE_P(retArray); } Datum vec_mul_with_scalar(PG_FUNCTION_ARGS); PG_FUNCTION_INFO_V1(vec_mul_with_scalar); /** * Multiples a scalar by all elements of a given vector. * * If the vector contains a NULL, * the result is NULL for that position. * * If the vector itself is NULL, * the result is NULL. * * If the scalar is NULL, * then all elements of the resulting vector are NULL. * * by Paul A. Jungwirth */ Datum vec_mul_with_scalar(PG_FUNCTION_ARGS) { Oid elemTypeId1 = get_fn_expr_argtype(fcinfo->flinfo, 0); Oid elemTypeId2 = get_fn_expr_argtype(fcinfo->flinfo, 1); Oid scalarTypeId; int16 elemTypeWidth; bool elemTypeByValue; char elemTypeAlignmentCode; int inputLength; ArrayType *inputArray, *retArray; Datum *inputContent, scalarContent, *retContent; bool *inputNulls, scalarNull, *retNulls; int arrayPos, scalarPos; int i; pgnum scalar; int dims[1]; int lbs[1]; if (!OidIsValid(elemTypeId1) || !OidIsValid(elemTypeId2)) elog(ERROR, "could not determine data type of input"); if (elemTypeId1 == INT2OID || elemTypeId1 == INT4OID || elemTypeId1 == INT8OID || elemTypeId1 == FLOAT4OID || elemTypeId1 == FLOAT8OID || elemTypeId1 == NUMERICOID) { scalarPos = 0; arrayPos = 1; scalarTypeId = elemTypeId1; } else if (elemTypeId2 == INT2OID || elemTypeId2 == INT4OID || elemTypeId2 == INT8OID || elemTypeId2 == FLOAT4OID || elemTypeId2 == FLOAT8OID || elemTypeId2 == NUMERICOID) { scalarPos = 1; arrayPos = 0; scalarTypeId = elemTypeId2; } else { ereport(ERROR, (errmsg("vec_mul scalar operand must be a numeric type"))); } if (PG_ARGISNULL(arrayPos)) { PG_RETURN_NULL(); } inputArray = PG_GETARG_ARRAYTYPE_P(arrayPos); scalarNull = PG_ARGISNULL(scalarPos); if (ARR_ELEMTYPE(inputArray) != scalarTypeId) { ereport(ERROR, (errmsg("vec_mul array elements and scalar operand must be the same type"))); } if (ARR_NDIM(inputArray) == 0) { PG_RETURN_NULL(); } else if (ARR_NDIM(inputArray) != 1) { ereport(ERROR, (errmsg("vec_mul: one-dimensional arrays are required"))); } get_typlenbyvalalign(scalarTypeId, &elemTypeWidth, &elemTypeByValue, &elemTypeAlignmentCode); deconstruct_array(inputArray, scalarTypeId, elemTypeWidth, elemTypeByValue, elemTypeAlignmentCode, &inputContent, &inputNulls, &inputLength); retContent = palloc0(sizeof(Datum) * inputLength); retNulls = palloc0(sizeof(bool) * inputLength); if (!scalarNull) { scalarContent = PG_GETARG_DATUM(scalarPos); switch(scalarTypeId) { case INT2OID: scalar.i16 = DatumGetInt16(scalarContent); break; case INT4OID: scalar.i32 = DatumGetInt32(scalarContent); break; case INT8OID: scalar.i64 = DatumGetInt64(scalarContent); break; case FLOAT4OID: scalar.f4 = DatumGetFloat4(scalarContent); break; case FLOAT8OID: scalar.f8 = DatumGetFloat8(scalarContent); break; case NUMERICOID: scalar.num = DatumGetNumeric(scalarContent); break; } } for (i = 0; i < inputLength; i++) { if (scalarNull || inputNulls[i]) { retNulls[i] = true; continue; } retNulls[i] = false; switch(scalarTypeId) { case INT2OID: retContent[i] = Int16GetDatum(scalar.i16 * DatumGetInt16(inputContent[i])); break; case INT4OID: retContent[i] = Int32GetDatum(scalar.i32 * DatumGetInt32(inputContent[i])); break; case INT8OID: retContent[i] = Int64GetDatum(scalar.i64 * DatumGetInt64(inputContent[i])); break; case FLOAT4OID: retContent[i] = Float4GetDatum(scalar.f4 * DatumGetFloat4(inputContent[i])); break; case FLOAT8OID: retContent[i] = Float8GetDatum(scalar.f8 * DatumGetFloat8(inputContent[i])); break; case NUMERICOID: #if PG_VERSION_NUM < 120000 retContent[i] = DirectFunctionCall2(numeric_mul, NumericGetDatum(scalar.num), inputContent[i]); #else retContent[i] = NumericGetDatum(numeric_mul_opt_error(scalar.num, DatumGetNumeric(inputContent[i]), NULL)); #endif break; } } dims[0] = inputLength; lbs[0] = 1; retArray = construct_md_array(retContent, retNulls, 1, dims, lbs, scalarTypeId, elemTypeWidth, elemTypeByValue, elemTypeAlignmentCode); PG_RETURN_ARRAYTYPE_P(retArray); }