/*
 * safeupdate: reject UPDATE and DELETE statements that have no WHERE clause.
 *
 * The check runs from post_parse_analyze_hook on the analyzed Query tree,
 * including data-modifying statements nested in WITH clauses. It is
 * controlled by the superuser-settable GUC "safeupdate.enabled".
 */
#include "postgres.h"
#include "fmgr.h"
#include "utils/guc.h"
#include "parser/analyze.h"
#include "nodes/nodeFuncs.h"

PG_MODULE_MAGIC;

void		_PG_init(void);

/* Backing variable for the "safeupdate.enabled" GUC (boot value: on). */
bool		safeupdate_enabled;

/* Hook installed before ours; we chain to it exactly once per invocation. */
static post_parse_analyze_hook_type prev_post_parse_analyze_hook = NULL;

/*
 * check_where_clause
 *
 * Raise ERROR if `query`, or any data-modifying CTE attached to it, is a
 * DELETE or UPDATE whose jointree has no qualification (no WHERE clause).
 * Recurses into cteList only when the query reports a modifying CTE.
 *
 * This is deliberately separate from the hook itself so that recursion does
 * not re-invoke the previously installed hook for every nested CTE query.
 */
static void
check_where_clause(Query *query)
{
	ListCell   *lc;

	if (query->hasModifyingCTE)
	{
		foreach(lc, query->cteList)
		{
			CommonTableExpr *cte = (CommonTableExpr *) lfirst(lc);

			/* After parse analysis, ctequery holds an analyzed Query. */
			check_where_clause((Query *) cte->ctequery);
		}
	}

	switch (query->commandType)
	{
		case CMD_DELETE:
			Assert(query->jointree != NULL);
			if (query->jointree->quals == NULL)
				ereport(ERROR,
						(errcode(ERRCODE_CARDINALITY_VIOLATION),
						 errmsg("DELETE requires a WHERE clause")));
			break;

		case CMD_UPDATE:
			Assert(query->jointree != NULL);
			if (query->jointree->quals == NULL)
				ereport(ERROR,
						(errcode(ERRCODE_CARDINALITY_VIOLATION),
						 errmsg("UPDATE requires a WHERE clause")));
			break;

		default:
			break;
	}
}

/*
 * delete_needs_where_check
 *
 * post_parse_analyze_hook entry point. When safeupdate.enabled is on, it
 * validates the query; it then always forwards to the previously installed
 * hook (if any), so other extensions keep working even when this check is
 * disabled. PostgreSQL 14 added a JumbleState argument to the hook type.
 */
#if PG_VERSION_NUM >= 140000
static void
delete_needs_where_check(ParseState *pstate, Query *query, JumbleState *jstate)
#else
static void
delete_needs_where_check(ParseState *pstate, Query *query)
#endif
{
	if (safeupdate_enabled)
		check_where_clause(query);

	if (prev_post_parse_analyze_hook != NULL)
	{
#if PG_VERSION_NUM >= 140000
		(*prev_post_parse_analyze_hook) (pstate, query, jstate);
#else
		(*prev_post_parse_analyze_hook) (pstate, query);
#endif
	}
}

/*
 * _PG_init
 *
 * Module load callback: registers the "safeupdate.enabled" GUC and installs
 * delete_needs_where_check as post_parse_analyze_hook, remembering the hook
 * that was previously installed so it can be chained.
 */
void
_PG_init(void)
{
	DefineCustomBoolVariable("safeupdate.enabled",
							 "Enforce qualified updates",
							 "Prevent DML without a WHERE clause",
							 &safeupdate_enabled,
							 true,
							 PGC_SUSET,
							 0,
							 NULL,
							 NULL,
							 NULL);

	prev_post_parse_analyze_hook = post_parse_analyze_hook;
	post_parse_analyze_hook = delete_needs_where_check;
}