Skip to content

Commit

Permalink
updates to current main
Browse files Browse the repository at this point in the history
  • Loading branch information
swfsql committed Dec 4, 2023
1 parent 28b5049 commit 82c314b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 14 deletions.
20 changes: 10 additions & 10 deletions dfdx-derives/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1167,10 +1167,10 @@ pub fn input_wrapper(
);
quote! {
#[doc = #doc1]
#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, crate::prelude::CustomModule)]
#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, ::dfdx::CustomModule)]
pub struct FromTuple;
#[doc = #doc2]
#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, crate::prelude::CustomModule)]
#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, ::dfdx::prelude::CustomModule)]
pub struct IntoTuple;
}
};
Expand Down Expand Up @@ -1235,16 +1235,16 @@ pub fn input_wrapper(
let doc2 = format!("Module to convert a [`{}`] into a tuple.", wrapper_ident,);
quote! {
#[doc = #doc1]
impl<#(#wrapper_generic_names), *> crate::prelude::Module<(#(#field_ty_names), *)> for FromTuple {
impl<#(#wrapper_generic_names), *> ::dfdx::prelude::Module<(#(#field_ty_names), *)> for FromTuple {
type Output = #wrapper_ident<#(#wrapper_generic_names), *>;
fn try_forward(&self, x: (#(#field_ty_names), *)) -> Result<Self::Output, crate::prelude::Error> {
fn try_forward(&self, x: (#(#field_ty_names), *)) -> Result<Self::Output, ::dfdx::prelude::Error> {
Ok(x.into())
}
}
#[doc = #doc2]
impl<#(#wrapper_generic_names), *> crate::prelude::Module<#wrapper_ident<#(#wrapper_generic_names), *>> for IntoTuple {
impl<#(#wrapper_generic_names), *> ::dfdx::prelude::Module<#wrapper_ident<#(#wrapper_generic_names), *>> for IntoTuple {
type Output = (#(#field_ty_names), *);
fn try_forward(&self, x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result<Self::Output, crate::prelude::Error> {
fn try_forward(&self, x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result<Self::Output, ::dfdx::prelude::Error> {
Ok(x.into())
}
}
Expand Down Expand Up @@ -1295,7 +1295,7 @@ pub fn input_wrapper(
);
}
contains_ident = true;
quote!(<M as crate::prelude::Module<#ty>>::Output)
quote!(<M as ::dfdx::prelude::Module<#ty>>::Output)
} else {
quote!(#ty_ident)
}
Expand Down Expand Up @@ -1334,15 +1334,15 @@ pub fn input_wrapper(

let field_access_module = quote! {
#[doc = #doc]
impl<M: crate::prelude::Module<#ty>, #(#wrapper_generic_names), *> crate::prelude::Module<#wrapper_ident<#(#wrapper_generic_names), *>> for crate::prelude::On<#on_acccess, M> {
impl<M: ::dfdx::prelude::Module<#ty>, #(#wrapper_generic_names), *> ::dfdx::prelude::Module<#wrapper_ident<#(#wrapper_generic_names), *>> for ::dfdx::prelude::On<#on_acccess, M> {
type Output = #wrapper_ident<#(#output_generics), *>;
fn try_forward(&self, x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result<Self::Output, crate::prelude::Error> {
fn try_forward(&self, x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result<Self::Output, ::dfdx::prelude::Error> {
#(#field_extraction)*
let #forward = self.t.try_forward(#forward)?;
let x = #field_replacement;
Ok(x)
}
fn try_forward_mut(&mut self, x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result<Self::Output, crate::prelude::Error> {
fn try_forward_mut(&mut self, x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result<Self::Output, ::dfdx::prelude::Error> {
#(#field_extraction)*
let #forward = self.t.try_forward_mut(#forward)?;
let x = #field_replacement;
Expand Down
7 changes: 3 additions & 4 deletions dfdx/src/nn/layers/on.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@ use std::marker::PhantomData;
// TODO: try making a Call module, whih allows calling an arbitrary method on the input.

/// Applies module `T` into an input field from a wrapper.
#[derive(
Debug, Default, Clone, ResetParams, ZeroGrads, UpdateParams, LoadSafeTensors, SaveSafeTensors,
)]
#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
#[derive(Debug, Default, Clone, ResetParams, ZeroGrads, UpdateParams)]
#[repr(transparent)]
pub struct On<N, T> {
#[module]
#[serialize]
#[cfg_attr(feature = "safetensors", serialize)]
pub t: T,

pub _n: PhantomData<N>,
Expand Down

0 comments on commit 82c314b

Please sign in to comment.