Skip to content

Commit

Permalink
Gemini LLM based coarse image redaction (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
abdolence authored Aug 16, 2024
1 parent 3feefca commit e5561a2
Show file tree
Hide file tree
Showing 8 changed files with 560 additions and 318 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ pdfium-render = { version = "0.8", features = ["thread_safe", "image"], optional
image = "0.25"
bytes = { version = "1" }
arboard = { version = "3", features = ["image"], optional = true }
serde_json = "1"


[dev-dependencies]
Expand Down
5 changes: 3 additions & 2 deletions src/commands/copy_command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use crate::file_converters::FileConverters;
use crate::file_systems::{DetectFileSystem, FileSystemConnection, FileSystemRef};
use crate::file_tools::{FileMatcher, FileMatcherResult, FileMimeOverride};
use crate::redacters::{
RedactSupportedOptions, Redacter, RedacterBaseOptions, RedacterOptions, Redacters,
redact_stream, RedactSupportedOptions, Redacter, RedacterBaseOptions, RedacterOptions,
Redacters,
};
use crate::reporter::AppReporter;
use crate::AppResult;
Expand Down Expand Up @@ -327,7 +328,7 @@ async fn redact_upload_file<
}
}
if !support_redacters.is_empty() {
match crate::redacters::redact_stream(
match redact_stream(
&support_redacters,
redacter_base_options,
source_reader,
Expand Down
2 changes: 2 additions & 0 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ pub enum AppError {
ClipboardError(#[from] arboard::Error),
#[error("SystemTimeError: {0}")]
SystemTimeError(#[from] SystemTimeError),
#[error("JSON serialization error: {0}")]
JsonSerializeError(#[from] serde_json::Error),
#[error("System error: {message}")]
SystemError { message: String },
}
Expand Down
193 changes: 187 additions & 6 deletions src/redacters/gemini_llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use crate::common_types::GcpProjectId;
use crate::errors::AppError;
use crate::file_systems::FileSystemRef;
use crate::redacters::{
RedactSupportedOptions, Redacter, RedacterDataItem, RedacterDataItemContent, Redacters,
redact_image_at_coords, PiiImageCoords, RedactSupportedOptions, Redacter, RedacterDataItem,
RedacterDataItemContent, Redacters,
};
use crate::reporter::AppReporter;
use crate::AppResult;
Expand Down Expand Up @@ -161,17 +162,191 @@ impl<'a> GeminiLlmRedacter<'a> {
}),
}
}

pub async fn redact_image_file(&self, input: RedacterDataItem) -> AppResult<RedacterDataItem> {
let model_name = self
.gemini_llm_options
.gemini_model
.as_ref()
.map(|model_name| model_name.value().to_string())
.unwrap_or_else(|| Self::DEFAULT_GEMINI_MODEL.to_string());

match input.content {
RedacterDataItemContent::Image { mime_type, data } => {
let image_format =
image::ImageFormat::from_mime_type(&mime_type).ok_or_else(|| {
AppError::SystemError {
message: format!("Unsupported image mime type: {}", mime_type),
}
})?;
let image = image::load_from_memory_with_format(&data, image_format)?;
let resized_image = image.resize(
image.width().min(1024),
image.height().min(1024),
image::imageops::FilterType::Gaussian,
);
let mut resized_image_bytes = std::io::Cursor::new(Vec::new());
resized_image.write_to(&mut resized_image_bytes, image_format)?;
let resized_image_data = resized_image_bytes.into_inner();

let mut request = tonic::Request::new(
gcloud_sdk::google::ai::generativelanguage::v1beta::GenerateContentRequest {
model: model_name,
safety_settings: vec![
gcloud_sdk::google::ai::generativelanguage::v1beta::HarmCategory::HateSpeech,
gcloud_sdk::google::ai::generativelanguage::v1beta::HarmCategory::SexuallyExplicit,
gcloud_sdk::google::ai::generativelanguage::v1beta::HarmCategory::DangerousContent,
gcloud_sdk::google::ai::generativelanguage::v1beta::HarmCategory::Harassment,
].into_iter().map(|category| gcloud_sdk::google::ai::generativelanguage::v1beta::SafetySetting {
category: category.into(),
threshold: gcloud_sdk::google::ai::generativelanguage::v1beta::safety_setting::HarmBlockThreshold::BlockNone.into(),
}).collect(),
contents: vec![
gcloud_sdk::google::ai::generativelanguage::v1beta::Content {
parts: vec![
gcloud_sdk::google::ai::generativelanguage::v1beta::Part {
data: Some(
gcloud_sdk::google::ai::generativelanguage::v1beta::part::Data::Text(
format!("Find anything in the attached image that look like personal information. \
Return their coordinates with x1,y1,x2,y2 as pixel coordinates and the corresponding text. \
The image width is: {}. The image height is: {}.", resized_image.width(), resized_image.height()),
),
),
},
gcloud_sdk::google::ai::generativelanguage::v1beta::Part {
data: Some(
gcloud_sdk::google::ai::generativelanguage::v1beta::part::Data::InlineData(
gcloud_sdk::google::ai::generativelanguage::v1beta::Blob {
mime_type: mime_type.to_string(),
data: resized_image_data.clone(),
}
),
),
}
],
role: "user".to_string(),
},
],
generation_config: Some(
gcloud_sdk::google::ai::generativelanguage::v1beta::GenerationConfig {
candidate_count: Some(1),
temperature: Some(0.2),
response_mime_type: mime::APPLICATION_JSON.to_string(),
response_schema: Some(
gcloud_sdk::google::ai::generativelanguage::v1beta::Schema {
r#type: gcloud_sdk::google::ai::generativelanguage::v1beta::Type::Array.into(),
items: Some(Box::new(
gcloud_sdk::google::ai::generativelanguage::v1beta::Schema {
r#type: gcloud_sdk::google::ai::generativelanguage::v1beta::Type::Object.into(),
properties: vec![
(
"x1".to_string(),
gcloud_sdk::google::ai::generativelanguage::v1beta::Schema {
r#type: gcloud_sdk::google::ai::generativelanguage::v1beta::Type::Number.into(),
..std::default::Default::default()
},
),
(
"y1".to_string(),
gcloud_sdk::google::ai::generativelanguage::v1beta::Schema {
r#type: gcloud_sdk::google::ai::generativelanguage::v1beta::Type::Number.into(),
..std::default::Default::default()
},
),
(
"x2".to_string(),
gcloud_sdk::google::ai::generativelanguage::v1beta::Schema {
r#type: gcloud_sdk::google::ai::generativelanguage::v1beta::Type::Number.into(),
..std::default::Default::default()
},
),
(
"y2".to_string(),
gcloud_sdk::google::ai::generativelanguage::v1beta::Schema {
r#type: gcloud_sdk::google::ai::generativelanguage::v1beta::Type::Number.into(),
..std::default::Default::default()
},
),
(
"text".to_string(),
gcloud_sdk::google::ai::generativelanguage::v1beta::Schema {
r#type: gcloud_sdk::google::ai::generativelanguage::v1beta::Type::String.into(),
..std::default::Default::default()
},
),
].into_iter().collect(),
required: vec!["x1".to_string(), "y1".to_string(), "x2".to_string(), "y2".to_string()],
..std::default::Default::default()
}
)),
..std::default::Default::default()
}
),
..std::default::Default::default()
},
),
..std::default::Default::default()
},
);
request.metadata_mut().insert(
"x-goog-user-project",
gcloud_sdk::tonic::metadata::MetadataValue::<tonic::metadata::Ascii>::try_from(
self.gemini_llm_options.project_id.as_ref(),
)?,
);
let response = self.client.get().generate_content(request).await?;

let inner = response.into_inner();
if let Some(content) = inner.candidates.first().and_then(|c| c.content.as_ref()) {
let content_json =
content
.parts
.iter()
.fold("".to_string(), |acc, entity| match &entity.data {
Some(
gcloud_sdk::google::ai::generativelanguage::v1beta::part::Data::Text(
text,
),
) => acc + text,
_ => acc,
});
let pii_image_coords: Vec<PiiImageCoords> =
serde_json::from_str(&content_json)?;
Ok(RedacterDataItem {
file_ref: input.file_ref,
content: RedacterDataItemContent::Image {
mime_type: mime_type.clone(),
data: redact_image_at_coords(
mime_type.clone(),
resized_image_data.into(),
pii_image_coords,
0.25,
)?,
},
})
} else {
Err(AppError::SystemError {
message: "No content item in the response".to_string(),
})
}
}
_ => Err(AppError::SystemError {
message: "Unsupported item for text redacting".to_string(),
}),
}
}
}

impl<'a> Redacter for GeminiLlmRedacter<'a> {
async fn redact(&self, input: RedacterDataItem) -> AppResult<RedacterDataItem> {
match &input.content {
RedacterDataItemContent::Value(_) => self.redact_text_file(input).await,
RedacterDataItemContent::Table { .. }
| RedacterDataItemContent::Image { .. }
| RedacterDataItemContent::Pdf { .. } => Err(AppError::SystemError {
message: "Attempt to redact of unsupported type".to_string(),
}),
RedacterDataItemContent::Image { .. } => self.redact_image_file(input).await,
RedacterDataItemContent::Table { .. } | RedacterDataItemContent::Pdf { .. } => {
Err(AppError::SystemError {
message: "Attempt to redact of unsupported type".to_string(),
})
}
}
}

Expand All @@ -183,9 +358,15 @@ impl<'a> Redacter for GeminiLlmRedacter<'a> {
Some(media_type) if Redacters::is_mime_text(media_type) => {
RedactSupportedOptions::Supported
}
Some(media_type) if Redacters::is_mime_image(media_type) => {
RedactSupportedOptions::Supported
}
Some(media_type) if Redacters::is_mime_table(media_type) => {
RedactSupportedOptions::SupportedAsText
}
Some(media_type) if Redacters::is_mime_pdf(media_type) => {
RedactSupportedOptions::SupportedAsImages
}
_ => RedactSupportedOptions::Unsupported,
})
}
Expand Down
Loading

0 comments on commit e5561a2

Please sign in to comment.