From de13a951b38b85195984164819f1ab05fe508677 Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Fri, 26 Jan 2024 18:20:39 +0000 Subject: [PATCH] [Flax] Update no init test for Flax v0.7.1 (#28735) --- tests/test_modeling_flax_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 58ada0226a51bc..ef99786fdff19f 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -984,7 +984,7 @@ def test_no_automatic_init(self): # Check if we params can be properly initialized when calling init_weights params = model.init_weights(model.key, model.input_shape) - self.assertIsInstance(params, FrozenDict) + assert isinstance(params, (dict, FrozenDict)), f"params are not an instance of {FrozenDict}" # Check if all required parmas are initialized keys = set(flatten_dict(unfreeze(params)).keys()) self.assertTrue(all(k in keys for k in model.required_params))