From b4e1610d50617cf3cdca55ee72e0c4ab8a70bf7d Mon Sep 17 00:00:00 2001 From: Abdulla Abdurakhmanov Date: Sun, 25 Aug 2024 00:15:47 +0200 Subject: [PATCH] --gcp-vertex-ai-block-none-harmful argument support --- src/args.rs | 8 ++++++++ src/redacters/gcp_vertex_ai.rs | 17 ++++++++++++++--- src/redacters/simple_image_redacter.rs | 8 ++++---- src/redacters/stream_redacter.rs | 2 +- 4 files changed, 27 insertions(+), 8 deletions(-) diff --git a/src/args.rs b/src/args.rs index b84fd89..b18fd1e 100644 --- a/src/args.rs +++ b/src/args.rs @@ -175,6 +175,13 @@ pub struct RedacterArgs { )] pub gcp_vertex_ai_image_model: Option, + #[arg( + long, + help = "Block none harmful content threshold for Vertex AI redacter. Default is BlockOnlyHigh since BlockNone is required a special billing settings.", + default_value = "false" + )] + pub gcp_vertex_ai_block_none_harmful: bool, + #[arg( long, help = "Disable CSV headers (if they are not present)", @@ -309,6 +316,7 @@ impl TryInto for RedacterArgs { native_image_support: self.gcp_vertex_ai_native_image_support, text_model: self.gcp_vertex_ai_text_model.clone(), image_model: self.gcp_vertex_ai_image_model.clone(), + block_none_harmful: self.gcp_vertex_ai_block_none_harmful, }, )), }?; diff --git a/src/redacters/gcp_vertex_ai.rs b/src/redacters/gcp_vertex_ai.rs index 8e44d50..95ed06e 100644 --- a/src/redacters/gcp_vertex_ai.rs +++ b/src/redacters/gcp_vertex_ai.rs @@ -19,6 +19,7 @@ pub struct GcpVertexAiRedacterOptions { pub native_image_support: bool, pub text_model: Option, pub image_model: Option, + pub block_none_harmful: bool, } #[derive(Debug, Clone, ValueStruct)] @@ -30,6 +31,7 @@ pub struct GcpVertexAiRedacter<'a> { options: GcpVertexAiRedacterOptions, #[allow(dead_code)] reporter: &'a AppReporter<'a>, + safety_setting: gcloud_sdk::google::cloud::aiplatform::v1beta1::safety_setting::HarmBlockThreshold } impl<'a> GcpVertexAiRedacter<'a> { @@ -46,10 +48,18 @@ impl<'a> GcpVertexAiRedacter<'a> { format!("https://{}-aiplatform.googleapis.com",options.gcp_region.value()), None, ).await?; + + let safety_setting = if options.block_none_harmful { + gcloud_sdk::google::cloud::aiplatform::v1beta1::safety_setting::HarmBlockThreshold::BlockNone + } else { + gcloud_sdk::google::cloud::aiplatform::v1beta1::safety_setting::HarmBlockThreshold::BlockOnlyHigh + }; + Ok(GcpVertexAiRedacter { client, options, reporter, + safety_setting, }) } @@ -82,7 +92,7 @@ impl<'a> GcpVertexAiRedacter<'a> { gcloud_sdk::google::cloud::aiplatform::v1beta1::HarmCategory::Harassment, ].into_iter().map(|category| gcloud_sdk::google::cloud::aiplatform::v1beta1::SafetySetting { category: category.into(), - threshold: gcloud_sdk::google::cloud::aiplatform::v1beta1::safety_setting::HarmBlockThreshold::BlockNone.into(), + threshold: self.safety_setting.into(), method: gcloud_sdk::google::cloud::aiplatform::v1beta1::safety_setting::HarmBlockMethod::Unspecified.into(), }).collect(), contents: vec![ @@ -216,7 +226,7 @@ impl<'a> GcpVertexAiRedacter<'a> { gcloud_sdk::google::cloud::aiplatform::v1beta1::HarmCategory::Harassment, ].into_iter().map(|category| gcloud_sdk::google::cloud::aiplatform::v1beta1::SafetySetting { category: category.into(), - threshold: gcloud_sdk::google::cloud::aiplatform::v1beta1::safety_setting::HarmBlockThreshold::BlockNone.into(), + threshold: self.safety_setting.into(), method: gcloud_sdk::google::cloud::aiplatform::v1beta1::safety_setting::HarmBlockMethod::Unspecified.into(), }).collect(), contents: vec![ @@ -334,7 +344,7 @@ impl<'a> GcpVertexAiRedacter<'a> { gcloud_sdk::google::cloud::aiplatform::v1beta1::HarmCategory::Harassment, ].into_iter().map(|category| gcloud_sdk::google::cloud::aiplatform::v1beta1::SafetySetting { category: category.into(), - threshold: gcloud_sdk::google::cloud::aiplatform::v1beta1::safety_setting::HarmBlockThreshold::BlockNone.into(), + threshold: self.safety_setting.into(), method: gcloud_sdk::google::cloud::aiplatform::v1beta1::safety_setting::HarmBlockMethod::Unspecified.into(), }).collect(), contents: vec![ @@ -537,6 +547,7 @@ mod tests { native_image_support: false, text_model: None, image_model: None, + block_none_harmful: false, }, &reporter, ) diff --git a/src/redacters/simple_image_redacter.rs b/src/redacters/simple_image_redacter.rs index c8c678c..852360f 100644 --- a/src/redacters/simple_image_redacter.rs +++ b/src/redacters/simple_image_redacter.rs @@ -2,7 +2,7 @@ use crate::common_types::TextImageCoords; use crate::errors::AppError; use crate::AppResult; use bytes::Bytes; -use image::{ImageFormat, RgbaImage}; +use image::{ImageFormat, RgbImage}; use mime::Mime; pub fn redact_image_at_coords( @@ -15,7 +15,7 @@ pub fn redact_image_at_coords( message: format!("Unsupported image mime type: {}", mime), })?; let image = image::load_from_memory_with_format(&data, image_format)?; - let mut image = image.to_rgba8(); + let mut image = image.to_rgb8(); redact_rgba_image_at_coords(&mut image, &pii_coords, approximation_factor); let mut output = std::io::Cursor::new(Vec::new()); image.write_to(&mut output, image_format)?; @@ -23,7 +23,7 @@ pub fn redact_image_at_coords( } pub fn redact_rgba_image_at_coords( - image: &mut RgbaImage, + image: &mut RgbImage, pii_coords: &Vec, approximation_factor: f32, ) { @@ -36,7 +36,7 @@ pub fn redact_rgba_image_at_coords( { let safe_x = x.min(image.width() - 1).max(0); let safe_y = y.min(image.height() - 1).max(0); - image.put_pixel(safe_x, safe_y, image::Rgba([0, 0, 0, 255])); + image.put_pixel(safe_x, safe_y, image::Rgb([0, 0, 0])); } } } diff --git a/src/redacters/stream_redacter.rs b/src/redacters/stream_redacter.rs index ffdea6a..c5c3c6d 100644 --- a/src/redacters/stream_redacter.rs +++ b/src/redacters/stream_redacter.rs @@ -515,7 +515,7 @@ impl<'a> StreamRedacter<'a> { RedacterDataItemContent::Value(content) => { let words_set: HashSet<&str> = HashSet::from_iter(content.split(" ").collect::>()); - let mut redacted_image = image.to_rgba8(); + let mut redacted_image = image.to_rgb8(); for text_coord in text_coords { if let Some(text) = &text_coord.text { if !words_set.contains(text.as_str()) {