diff --git a/keras_cv/models/segmentation/segformer/segformer.py b/keras_cv/models/segmentation/segformer/segformer.py index 64640f33e3..c6fe9763b8 100644 --- a/keras_cv/models/segmentation/segformer/segformer.py +++ b/keras_cv/models/segmentation/segformer/segformer.py @@ -29,10 +29,7 @@ class SegFormer(Task): `tf.keras.Model` that implements the `pyramid_level_inputs` property with keys "P2", "P3", "P4", and "P5" and layer names as values. - num_classes: int, the number of classes for the detection model. Note - that the `num_classes` doesn't contain the background class, and the - classes from the data should be represented by integers with range - [0, `num_classes`). + num_classes: int, the number of classes for the detection model, including the background class. projection_filters: int, default 256, number of filters in the convolution layer projecting the concatenated features into a segmentation map.