diff --git a/src/commands/epgsql_cmd_prepared_query2.erl b/src/commands/epgsql_cmd_prepared_query2.erl index 198ac10..4ce4c63 100644 --- a/src/commands/epgsql_cmd_prepared_query2.erl +++ b/src/commands/epgsql_cmd_prepared_query2.erl @@ -32,7 +32,7 @@ init({Name, Parameters}) -> #pquery2{name = Name, params = Parameters}. -execute(Sock, #pquery2{name = Name, params = Params} = State) -> +execute(Sock, #pquery2{name = Name} = State) -> case maps:get(Name, epgsql_sock:get_stmts(Sock), undefined) of undefined -> Error = #error{ @@ -43,19 +43,28 @@ execute(Sock, #pquery2{name = Name, params = Params} = State) -> extra = [] }, {finish, {error, Error}, Sock}; - #statement{types = Types} = Stmt -> - TypedParams = lists:zip(Types, Params), - #statement{name = StatementName, columns = Columns} = Stmt, - Codec = epgsql_sock:get_codec(Sock), - Bin1 = epgsql_wire:encode_parameters(TypedParams, Codec), - Bin2 = epgsql_wire:encode_formats(Columns), - Commands = - [ - epgsql_wire:encode_bind("", StatementName, Bin1, Bin2), - epgsql_wire:encode_execute("", 0), - epgsql_wire:encode_sync() - ], - {send_multi, Commands, Sock, State#pquery2{stmt = Stmt}} + #statement{} = Stmt -> + do_execute(Sock, State, Stmt) + end. + +do_execute(Sock, State, Statement) -> + #pquery2{params = Params} = State, + #statement{types = Types, name = StatementName, columns = Columns} = Statement, + TypedParams = lists:zip(Types, Params), + Codec = epgsql_sock:get_codec(Sock), + try + Bin1 = epgsql_wire:encode_parameters(TypedParams, Codec), + Bin2 = epgsql_wire:encode_formats(Columns), + Commands = + [ + epgsql_wire:encode_bind("", StatementName, Bin1, Bin2), + epgsql_wire:encode_execute("", 0), + epgsql_wire:encode_sync() + ], + {send_multi, Commands, Sock, State#pquery2{stmt = Statement}} + catch + throw:{bad_param, Type, Value} -> + {finish, {error, {bad_param, Type, Value}}, Sock} end. %% prepared query diff --git a/src/epgsql_idatetime.erl b/src/epgsql_idatetime.erl index 13d6f03..c6fceec 100644 --- a/src/epgsql_idatetime.erl +++ b/src/epgsql_idatetime.erl @@ -96,7 +96,9 @@ i2timestamp2(D, T) -> timestamp2i({Date, Time}) -> D = date2j(Date) - ?POSTGRES_EPOC_JDATE, - D * ?USECS_PER_DAY + time2i(Time). + D * ?USECS_PER_DAY + time2i(Time); +timestamp2i(X) -> + throw(bad_param). now2i({MegaSecs, Secs, MicroSecs}) -> (MegaSecs * 1000000 + Secs) * 1000000 + MicroSecs - ?POSTGRES_EPOC_USECS. diff --git a/src/epgsql_wire.erl b/src/epgsql_wire.erl index 8e7cb5c..c75232a 100644 --- a/src/epgsql_wire.erl +++ b/src/epgsql_wire.erl @@ -265,10 +265,16 @@ encode_parameters([], Count, Formats, Values, _Codec) -> [<>, Formats, <> | lists:reverse(Values)]; encode_parameters([P | T], Count, Formats, Values, Codec) -> - {Format, Value} = encode_parameter(P, Codec), - Formats2 = <>, - Values2 = [Value | Values], - encode_parameters(T, Count + 1, Formats2, Values2, Codec). + try + {Format, Value} = encode_parameter(P, Codec), + Formats2 = <>, + Values2 = [Value | Values], + encode_parameters(T, Count + 1, Formats2, Values2, Codec) + catch + throw:bad_param -> + {Type, Value0} = P, + throw({bad_param, Type, Value0}) + end. %% @doc encode single 'typed' parameter -spec encode_parameter({Type, Val :: any()}, diff --git a/test/epgsql_SUITE.erl b/test/epgsql_SUITE.erl index cbaecfe..4b6a4ce 100644 --- a/test/epgsql_SUITE.erl +++ b/test/epgsql_SUITE.erl @@ -97,6 +97,7 @@ groups() -> prepared_query, prepared_query2, + prepared_query2_bad_params, select, insert, update, @@ -488,6 +489,20 @@ prepared_query2(Config) -> Module:prepared_query2(C, "non_existent_query", [4]) end). +%% Checks that we don't crash ungracefully if user provides parameters that cannot be +%% encoded by the intended codec. +prepared_query2_bad_params(Config) -> + Module = ?config(module, Config), + epgsql_ct:with_connection(Config, fun(C) -> + Name = "bad_params", + Column = get_type_col(timestamp), + SQL = io_lib:format("insert into test_table2(~s) values ($1)", [Column]), + {ok, _} = Module:parse2(C, Name, SQL, []), + {error, {bad_param, timestamp, <<"2024-06-30 00:10:00">>}} = + Module:prepared_query2(C, Name, [<<"2024-06-30 00:10:00">>]), + ok + end). + select(Config) -> Module = ?config(module, Config), epgsql_ct:with_connection(Config, fun(C) ->