Skip to content

Commit

Permalink
[DocString] Support a revision in the docstring `add_code_sample_do…
Browse files Browse the repository at this point in the history
…cstrings` to facilitate integrations (#27645)

* initial commit

* dummy changes

* style

* Update src/transformers/utils/doc.py

Co-authored-by: Alex McKinney <[email protected]>

* nits

* nit use ` if re.match(r'^refs/pr/\d*', revision):`

* restrict

* nit

* test the doc vuilder

* wow

* oke the order was wrong

---------

Co-authored-by: Alex McKinney <[email protected]>
  • Loading branch information
ArthurZucker and vvvm23 authored Nov 24, 2023
1 parent 2098d34 commit a6d178e
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/transformers/modeling_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1267,13 +1267,14 @@ def overwrite_call_docstring(model_class, docstring):
model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__)


def append_call_sample_docstring(model_class, checkpoint, output_type, config_class, mask=None):
def append_call_sample_docstring(model_class, checkpoint, output_type, config_class, mask=None, revision=None):
model_class.__call__ = copy_func(model_class.__call__)
model_class.__call__ = add_code_sample_docstrings(
checkpoint=checkpoint,
output_type=output_type,
config_class=config_class,
model_cls=model_class.__name__,
revision=revision,
)(model_class.__call__)


Expand Down
4 changes: 3 additions & 1 deletion src/transformers/models/albert/modeling_flax_albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,9 @@ class FlaxAlbertForMaskedLM(FlaxAlbertPreTrainedModel):
module_class = FlaxAlbertForMaskedLMModule


append_call_sample_docstring(FlaxAlbertForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC)
append_call_sample_docstring(
FlaxAlbertForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC, revision="refs/pr/11"
)


class FlaxAlbertForSequenceClassificationModule(nn.Module):
Expand Down
10 changes: 10 additions & 0 deletions src/transformers/utils/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,6 +1075,7 @@ def add_code_sample_docstrings(
expected_output=None,
expected_loss=None,
real_checkpoint=None,
revision=None,
):
def docstring_decorator(fn):
# model_class defaults to function's class if not specified otherwise
Expand Down Expand Up @@ -1143,6 +1144,15 @@ def docstring_decorator(fn):
func_doc = (fn.__doc__ or "") + "".join(docstr)
output_doc = "" if output_type is None else _prepare_output_docstrings(output_type, config_class)
built_doc = code_sample.format(**doc_kwargs)
if revision is not None:
if re.match(r"^refs/pr/\\d+", revision):
raise ValueError(
f"The provided revision '{revision}' is incorrect. It should point to"
" a pull request reference on the hub like 'refs/pr/6'"
)
built_doc = built_doc.replace(
f'from_pretrained("{checkpoint}")', f'from_pretrained("{checkpoint}", revision="{revision}")'
)
fn.__doc__ = func_doc + output_doc + built_doc
return fn

Expand Down

0 comments on commit a6d178e

Please sign in to comment.