Skip to content

Commit

Permalink
--gcp-vertex-ai-block-none-harmful argument support
Browse files Browse the repository at this point in the history
  • Loading branch information
abdolence committed Aug 24, 2024
1 parent 68074e1 commit b4e1610
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 8 deletions.
8 changes: 8 additions & 0 deletions src/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,13 @@ pub struct RedacterArgs {
)]
pub gcp_vertex_ai_image_model: Option<GcpVertexAiModelName>,

#[arg(
long,
help = "Block none harmful content threshold for Vertex AI redacter. Default is BlockOnlyHigh since BlockNone is required a special billing settings.",
default_value = "false"
)]
pub gcp_vertex_ai_block_none_harmful: bool,

#[arg(
long,
help = "Disable CSV headers (if they are not present)",
Expand Down Expand Up @@ -309,6 +316,7 @@ impl TryInto<RedacterOptions> for RedacterArgs {
native_image_support: self.gcp_vertex_ai_native_image_support,
text_model: self.gcp_vertex_ai_text_model.clone(),
image_model: self.gcp_vertex_ai_image_model.clone(),
block_none_harmful: self.gcp_vertex_ai_block_none_harmful,
},
)),
}?;
Expand Down
17 changes: 14 additions & 3 deletions src/redacters/gcp_vertex_ai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub struct GcpVertexAiRedacterOptions {
pub native_image_support: bool,
pub text_model: Option<GcpVertexAiModelName>,
pub image_model: Option<GcpVertexAiModelName>,
pub block_none_harmful: bool,
}

#[derive(Debug, Clone, ValueStruct)]
Expand All @@ -30,6 +31,7 @@ pub struct GcpVertexAiRedacter<'a> {
options: GcpVertexAiRedacterOptions,
#[allow(dead_code)]
reporter: &'a AppReporter<'a>,
safety_setting: gcloud_sdk::google::cloud::aiplatform::v1beta1::safety_setting::HarmBlockThreshold
}

impl<'a> GcpVertexAiRedacter<'a> {
Expand All @@ -46,10 +48,18 @@ impl<'a> GcpVertexAiRedacter<'a> {
format!("https://{}-aiplatform.googleapis.com",options.gcp_region.value()),
None,
).await?;

let safety_setting = if options.block_none_harmful {
gcloud_sdk::google::cloud::aiplatform::v1beta1::safety_setting::HarmBlockThreshold::BlockNone
} else {
gcloud_sdk::google::cloud::aiplatform::v1beta1::safety_setting::HarmBlockThreshold::BlockOnlyHigh
};

Ok(GcpVertexAiRedacter {
client,
options,
reporter,
safety_setting,
})
}

Expand Down Expand Up @@ -82,7 +92,7 @@ impl<'a> GcpVertexAiRedacter<'a> {
gcloud_sdk::google::cloud::aiplatform::v1beta1::HarmCategory::Harassment,
].into_iter().map(|category| gcloud_sdk::google::cloud::aiplatform::v1beta1::SafetySetting {
category: category.into(),
threshold: gcloud_sdk::google::cloud::aiplatform::v1beta1::safety_setting::HarmBlockThreshold::BlockNone.into(),
threshold: self.safety_setting.into(),
method: gcloud_sdk::google::cloud::aiplatform::v1beta1::safety_setting::HarmBlockMethod::Unspecified.into(),
}).collect(),
contents: vec![
Expand Down Expand Up @@ -216,7 +226,7 @@ impl<'a> GcpVertexAiRedacter<'a> {
gcloud_sdk::google::cloud::aiplatform::v1beta1::HarmCategory::Harassment,
].into_iter().map(|category| gcloud_sdk::google::cloud::aiplatform::v1beta1::SafetySetting {
category: category.into(),
threshold: gcloud_sdk::google::cloud::aiplatform::v1beta1::safety_setting::HarmBlockThreshold::BlockNone.into(),
threshold: self.safety_setting.into(),
method: gcloud_sdk::google::cloud::aiplatform::v1beta1::safety_setting::HarmBlockMethod::Unspecified.into(),
}).collect(),
contents: vec![
Expand Down Expand Up @@ -334,7 +344,7 @@ impl<'a> GcpVertexAiRedacter<'a> {
gcloud_sdk::google::cloud::aiplatform::v1beta1::HarmCategory::Harassment,
].into_iter().map(|category| gcloud_sdk::google::cloud::aiplatform::v1beta1::SafetySetting {
category: category.into(),
threshold: gcloud_sdk::google::cloud::aiplatform::v1beta1::safety_setting::HarmBlockThreshold::BlockNone.into(),
threshold: self.safety_setting.into(),
method: gcloud_sdk::google::cloud::aiplatform::v1beta1::safety_setting::HarmBlockMethod::Unspecified.into(),
}).collect(),
contents: vec![
Expand Down Expand Up @@ -537,6 +547,7 @@ mod tests {
native_image_support: false,
text_model: None,
image_model: None,
block_none_harmful: false,
},
&reporter,
)
Expand Down
8 changes: 4 additions & 4 deletions src/redacters/simple_image_redacter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::common_types::TextImageCoords;
use crate::errors::AppError;
use crate::AppResult;
use bytes::Bytes;
use image::{ImageFormat, RgbaImage};
use image::{ImageFormat, RgbImage};
use mime::Mime;

pub fn redact_image_at_coords(
Expand All @@ -15,15 +15,15 @@ pub fn redact_image_at_coords(
message: format!("Unsupported image mime type: {}", mime),
})?;
let image = image::load_from_memory_with_format(&data, image_format)?;
let mut image = image.to_rgba8();
let mut image = image.to_rgb8();
redact_rgba_image_at_coords(&mut image, &pii_coords, approximation_factor);
let mut output = std::io::Cursor::new(Vec::new());
image.write_to(&mut output, image_format)?;
Ok(output.into_inner().into())
}

pub fn redact_rgba_image_at_coords(
image: &mut RgbaImage,
image: &mut RgbImage,
pii_coords: &Vec<TextImageCoords>,
approximation_factor: f32,
) {
Expand All @@ -36,7 +36,7 @@ pub fn redact_rgba_image_at_coords(
{
let safe_x = x.min(image.width() - 1).max(0);
let safe_y = y.min(image.height() - 1).max(0);
image.put_pixel(safe_x, safe_y, image::Rgba([0, 0, 0, 255]));
image.put_pixel(safe_x, safe_y, image::Rgb([0, 0, 0]));
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/redacters/stream_redacter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ impl<'a> StreamRedacter<'a> {
RedacterDataItemContent::Value(content) => {
let words_set: HashSet<&str> =
HashSet::from_iter(content.split(" ").collect::<Vec<_>>());
let mut redacted_image = image.to_rgba8();
let mut redacted_image = image.to_rgb8();
for text_coord in text_coords {
if let Some(text) = &text_coord.text {
if !words_set.contains(text.as_str()) {
Expand Down

0 comments on commit b4e1610

Please sign in to comment.