vizia_derive/
lens.rs

1#![allow(missing_docs)]
2
3// Adapted from Druid lens.rs
4
5// Copyright 2019 The Druid Authors.
6//
7// Licensed under the Apache License, Version 2.0 (the "License");
8// you may not use this file except in compliance with the License.
9// You may obtain a copy of the License at
10//
11//     http://www.apache.org/licenses/LICENSE-2.0
12//
13// Unless required by applicable law or agreed to in writing, software
14// distributed under the License is distributed on an "AS IS" BASIS,
15// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16// See the License for the specific language governing permissions and
17// limitations under the License.
18
19// use proc_macro2::{Ident, Span};
20use quote::quote;
21// use std::collections::HashSet;
22use syn::spanned::Spanned;
23use syn::{Data, GenericParam, Ident, Token, TypeParam, VisRestricted, Visibility};
24
25use super::attr::{FieldKind, Fields, LensAttrs};
26
27pub(crate) fn derive_lens_impl(
28    input: syn::DeriveInput,
29) -> Result<proc_macro2::TokenStream, syn::Error> {
30    match &input.data {
31        Data::Struct(_) => derive_struct(&input),
32        Data::Enum(_) => derive_enum(&input),
33        Data::Union(u) => Err(syn::Error::new(
34            u.union_token.span(),
35            "Lens implementations cannot be derived from unions",
36        )),
37    }
38}
39
40fn derive_struct(input: &syn::DeriveInput) -> Result<proc_macro2::TokenStream, syn::Error> {
41    let struct_type = &input.ident;
42
43    // The generated module should have the same visibility as the struct. If the struct is private
44    // then the generated structs within the new module should be visible only to the module the
45    // original struct was in.
46    let module_vis = &input.vis;
47    let struct_vis = increase_visibility(module_vis);
48
49    let fields = if let syn::Data::Struct(syn::DataStruct { fields, .. }) = &input.data {
50        Fields::<LensAttrs>::parse_ast(fields)?
51    } else {
52        return Err(syn::Error::new(
53            input.span(),
54            "Lens implementations can only be derived from structs with named fields",
55        ));
56    };
57
58    if fields.kind != FieldKind::Named {
59        return Err(syn::Error::new(
60            input.span(),
61            "Lens implementations can only be derived from structs with named fields",
62        ));
63    }
64
65    let twizzled_name = if is_camel_case(&struct_type.to_string()) {
66        let temp_name = format!("{}_derived_lenses", to_snake_case(&struct_type.to_string()));
67        proc_macro2::Ident::new(&temp_name, proc_macro2::Span::call_site())
68    } else {
69        return Err(syn::Error::new(
70            struct_type.span(),
71            "Lens implementations can only be derived from CamelCase types",
72        ));
73    };
74    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
75
76    let mut lens_ty_idents = Vec::new();
77    let mut phantom_decls = Vec::new();
78    let mut phantom_inits = Vec::new();
79    let mut lens_ty_decls = Vec::new();
80
81    for gp in input.generics.params.iter() {
82        if let GenericParam::Type(TypeParam { ident, .. }) = gp {
83            lens_ty_idents.push(quote! {#ident});
84            lens_ty_decls.push(quote! {#ident: 'static});
85            phantom_decls.push(quote! {::std::marker::PhantomData<*const #ident>});
86            phantom_inits.push(quote! {::std::marker::PhantomData});
87        }
88    }
89
90    let lens_ty_generics = quote! {
91        <#(#lens_ty_idents),*>
92    };
93
94    let lens_ty_generics_decls = quote! {
95        <#(#lens_ty_decls),*>
96    };
97
98    // Define lens types for each field
99    let defs = fields.iter().filter(|f| !f.attrs.ignore).map(|f| {
100        let field_name = &f.ident.unwrap_named();
101        let struct_docs = format!(
102            "Lens for the field `{field}` on [`{ty}`](super::{ty}).",
103            field = field_name,
104            ty = struct_type,
105        );
106
107        let fn_docs = format!(
108            "Creates a new lens for the field `{field}` on [`{ty}`](super::{ty}). \
109            Use [`{ty}::{field}`](super::{ty}::{field}) instead.",
110            field = field_name,
111            ty = struct_type,
112        );
113
114        quote! {
115            #[doc = #struct_docs]
116            #[allow(non_camel_case_types)]
117            #[derive(::std::cmp::PartialEq, ::std::cmp::Eq)]
118            #struct_vis struct #field_name #lens_ty_generics(#(#phantom_decls),*);
119
120            impl #lens_ty_generics #field_name #lens_ty_generics {
121                #[doc = #fn_docs]
122                pub const fn new()->Self{
123                    Self(#(#phantom_inits),*)
124                }
125            }
126
127            impl #lens_ty_generics_decls ::std::hash::Hash for #field_name #lens_ty_generics {
128                fn hash<H: ::std::hash::Hasher>(&self, state: &mut H) {
129                    ::std::any::TypeId::of::<Self>().hash(state);
130                }
131            }
132
133            impl #lens_ty_generics ::std::fmt::Debug for #field_name #lens_ty_generics {
134                fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
135                    write!(f, "{}:{}", stringify!(#struct_type), stringify!(#field_name))
136                }
137            }
138
139            impl #lens_ty_generics ::std::clone::Clone for #field_name #lens_ty_generics  {
140                fn clone(&self) -> #field_name #lens_ty_generics {
141                    *self
142                }
143            }
144
145            impl #lens_ty_generics ::std::marker::Copy for #field_name #lens_ty_generics {}
146        }
147    });
148
149    // let used_params: HashSet<String> = input
150    //     .generics
151    //     .params
152    //     .iter()
153    //     .flat_map(|gp: &GenericParam| match gp {
154    //         GenericParam::Type(TypeParam { ident, .. }) => Some(ident.to_string()),
155    //         _ => None,
156    //     })
157    //     .collect();
158
159    // let gen_new_param = |name: &str| {
160    //     let mut candidate: String = name.into();
161    //     let mut count = 1usize;
162    //     while used_params.contains(&candidate) {
163    //         candidate = format!("{}_{}", name, count);
164    //         count += 1;
165    //     }
166    //     Ident::new(&candidate, Span::call_site())
167    // };
168
169    //let func_ty_par = gen_new_param("F");
170    //let val_ty_par = gen_new_param("V");
171
172    let impls = fields.iter().filter(|f| !f.attrs.ignore).map(|f| {
173        let field_name = &f.ident.unwrap_named();
174        let field_ty = &f.ty;
175        quote! {
176            impl #impl_generics Lens for #twizzled_name::#field_name #lens_ty_generics #where_clause {
177                type Source = #struct_type #ty_generics;
178                type Target = #field_ty;
179
180                fn view<'a>(&self, source: &'a #struct_type #ty_generics) -> ::std::option::Option<LensValue<'a, Self::Target>> {
181                    ::std::option::Option::Some(LensValue::Borrowed(&source.#field_name))
182                }
183            }
184        }
185    });
186
187    let associated_items = fields.iter().filter(|f| !f.attrs.ignore).map(|f| {
188        let field_name = &f.ident.unwrap_named();
189        let lens_field_name = f.attrs.lens_name_override.as_ref().unwrap_or(field_name);
190        let field_vis = &f.vis;
191
192        quote! {
193            /// Lens for the corresponding field.
194            #field_vis const #lens_field_name: Wrapper<#twizzled_name::#field_name #lens_ty_generics> = Wrapper(#twizzled_name::#field_name::new());
195        }
196    });
197
198    let mod_docs = format!("Derived lenses for [`{}`].", struct_type);
199    let root_docs = format!("Lens for the whole [`{ty}`](super::{ty}) struct.", ty = struct_type);
200    //let lens_docs = format!("# Lenses for [`{ty}`](super::{ty})", ty = struct_type);
201
202    let expanded = quote! {
203        #[doc = #mod_docs]
204        #module_vis mod #twizzled_name {
205            #(#defs)*
206            #[derive(::std::cmp::PartialEq, ::std::cmp::Eq)]
207            #[doc = #root_docs]
208            #[allow(non_camel_case_types)]
209            #struct_vis struct root #lens_ty_generics(#(#phantom_decls),*);
210
211            impl #lens_ty_generics root #lens_ty_generics {
212                ///
213                pub const fn new()->Self{
214                    Self(#(#phantom_inits),*)
215                }
216            }
217
218            impl #lens_ty_generics_decls ::std::hash::Hash for root #lens_ty_generics {
219                fn hash<H: ::std::hash::Hasher>(&self, state: &mut H) {
220                    ::std::any::TypeId::of::<Self>().hash(state);
221                }
222            }
223
224            impl #lens_ty_generics ::std::fmt::Debug for root #lens_ty_generics {
225                fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
226                    write!(f,"{}",stringify!(#struct_type))
227                }
228            }
229
230            impl #lens_ty_generics ::std::clone::Clone for root #lens_ty_generics  {
231                fn clone(&self) -> root #lens_ty_generics {
232                    *self
233                }
234            }
235
236            impl #lens_ty_generics ::std::marker::Copy for root #lens_ty_generics {}
237        }
238
239        #(#impls)*
240
241        impl #impl_generics Lens for #twizzled_name::root #lens_ty_generics {
242            type Source = #struct_type #ty_generics;
243            type Target = #struct_type #ty_generics;
244
245            fn view<'a>(&self, source: &'a Self::Source) -> ::std::option::Option<LensValue<'a, Self::Target>> {
246                ::std::option::Option::Some(LensValue::Borrowed(source))
247            }
248        }
249
250        #[allow(non_upper_case_globals)]
251        #[doc(hidden)]
252        impl #impl_generics #struct_type #ty_generics #where_clause {
253            #(#associated_items)*
254
255            pub const root: Wrapper<#twizzled_name::root #lens_ty_generics> = Wrapper(#twizzled_name::root::new());
256        }
257    };
258
259    Ok(expanded)
260}
261
262//I stole these from rustc!
/// Returns `true` when `c` is a cased character, i.e. it is either uppercase or lowercase.
///
/// Digits, underscores and characters from scripts without letter case yield `false`.
pub(crate) fn char_has_case(c: char) -> bool {
    let upper = c.is_uppercase();
    let lower = c.is_lowercase();
    upper || lower
}
266
267fn is_camel_case(name: &str) -> bool {
268    let name = name.trim_matches('_');
269    if name.is_empty() {
270        return true;
271    }
272
273    // start with a non-lowercase letter rather than non-uppercase
274    // ones (some scripts don't have a concept of upper/lowercase)
275    !name.chars().next().unwrap().is_lowercase()
276        && !name.contains("__")
277        && !name.chars().collect::<Vec<_>>().windows(2).any(|pair| {
278            // contains a capitalisable character followed by, or preceded by, an underscore
279            char_has_case(pair[0]) && pair[1] == '_' || char_has_case(pair[1]) && pair[0] == '_'
280        })
281}
282
/// Converts an identifier to snake_case.
///
/// Every leading underscore is preserved. The remainder is split at underscores (empty pieces
/// are dropped), and each piece is further split in front of an uppercase character that follows
/// a non-uppercase one. All words are lowercased and joined with `_`, so `FooBar` becomes
/// `foo_bar` while a run of capitals such as `HTTPServer` stays one word (`httpserver`).
fn to_snake_case(mut str: &str) -> String {
    // One empty word per leading underscore; joining them reproduces the underscores.
    let leading = str.chars().take_while(|&c| c == '_').count();
    let mut words: Vec<String> = vec![String::new(); leading];
    // `_` is a single byte, so the character count doubles as a byte offset.
    str = &str[leading..];

    for segment in str.split('_').filter(|piece| !piece.is_empty()) {
        let mut current = String::new();
        let mut prev_upper = false;

        for ch in segment.chars() {
            let upper = ch.is_uppercase();
            // A new word starts at an uppercase character following a non-uppercase one, unless
            // nothing (or only a lone `'`) has been collected so far.
            let starts_word = upper && !prev_upper && !current.is_empty() && current != "'";
            if starts_word {
                words.push(::std::mem::take(&mut current));
            }
            current.extend(ch.to_lowercase());
            prev_upper = upper;
        }

        words.push(current);
    }

    words.join("_")
}
312
313fn derive_enum(input: &syn::DeriveInput) -> Result<proc_macro2::TokenStream, syn::Error> {
314    let enum_type = &input.ident;
315
316    // See `derive_struct`
317    let module_vis = &input.vis;
318    let struct_vis = increase_visibility(module_vis);
319
320    let variants = if let syn::Data::Enum(syn::DataEnum { variants, .. }) = &input.data {
321        variants
322    } else {
323        panic!("I don't know what this case being triggered means. Please open an issue!")
324    };
325
326    let usable_variants = variants
327        .iter()
328        .filter_map(|v| match &v.fields {
329            syn::Fields::Unnamed(f) => {
330                if f.unnamed.len() == 1 {
331                    Some((&v.ident, &f.unnamed.first().unwrap().ty))
332                } else {
333                    None
334                }
335            }
336            _ => None,
337        })
338        .collect::<Vec<_>>();
339    if usable_variants.is_empty() {
340        panic!("This enum has no variants which can have Lenses built. A valid variant has exactly one unnamed field. If you think this is unreasonable, please work on https://github.com/rust-lang/rfcs/pull/2593")
341    }
342
343    let twizzled_name = if is_camel_case(&enum_type.to_string()) {
344        let temp_name = format!("{}_derived_lenses", to_snake_case(&enum_type.to_string()));
345        proc_macro2::Ident::new(&temp_name, proc_macro2::Span::call_site())
346    } else {
347        return Err(syn::Error::new(
348            enum_type.span(),
349            "Lens implementations can only be derived from CamelCase types",
350        ));
351    };
352
353    if !input.generics.params.is_empty() {
354        panic!("Lens implementations can only be derived from non-generic enums (for now)");
355    }
356
357    let defs = usable_variants.iter().map(|(variant_name, _)| {
358        quote! {
359            #[allow(non_camel_case_types)]
360            #[derive(:std::marker::Copy, ::std::clone::Clone, ::std::hash::Hash, ::std::cmp::PartialEq, ::std::cmp::Eq)]
361            #struct_vis struct #variant_name();
362
363            impl #variant_name {
364                pub const fn new() -> Self {
365                    Self()
366                }
367            }
368
369            impl ::std::fmt::Debug for #variant_name {
370                fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
371                    write!(f,"{}:{}",stringify!(#enum_type), stringify!(#variant_name))
372                }
373            }
374        }
375    });
376
377    let impls = usable_variants.iter().map(|(variant_name, variant_type)| {
378        let name = format!("{}:{}", enum_type, variant_name);
379        quote! {
380            impl Lens for #twizzled_name::#variant_name {
381                type Source = #enum_type;
382                type Target = #variant_type;
383
384                fn view<'a>(&self, source: &'a Self::Source) -> Option<LensValue<'a, Self::Target>> {
385                    if let #enum_type::#variant_name(inner_value) = source {
386                        ::std::option::Option::Some(LensValue::Borrowed(inner_value))
387                    } else {
388                        ::std::panic!("failed")
389                    }
390
391                }
392            }
393
394            impl ::std::fmt::Debug for #twizzled_name::#variant_name {
395                fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
396                    f.write_str(#name)
397                }
398            }
399        }
400    });
401
402    let associated_items = usable_variants.iter().map(|(variant_name, _)| {
403        let variant_const_name = to_snake_case(&variant_name.to_string());
404        let variant_const_name = proc_macro2::Ident::new(&variant_const_name, proc_macro2::Span::call_site());
405        quote! {
406            pub const #variant_const_name: #twizzled_name::#variant_name = #twizzled_name::#variant_name::new();
407        }
408    });
409
410    let expanded = quote! {
411        #module_vis mod #twizzled_name {
412            #(#defs)*
413        }
414
415        #(#impls)*
416
417        #[allow(non_upper_case_globals)]
418        impl #enum_type {
419            #(#associated_items)*
420        }
421    };
422
423    Ok(expanded)
424}
425
426/// Increase privite/inherited visiblity to `pub(super)`, `pub(super)` or anything else relative to
427/// `super` to one module higher than that, and leave everything else as is.
428pub(crate) fn increase_visibility(vis: &Visibility) -> Visibility {
429    match vis {
430        // Private structs are promoted to `pub(super)`
431        Visibility::Inherited => Visibility::Restricted(VisRestricted {
432            pub_token: Token![pub](vis.span()),
433            paren_token: syn::token::Paren(vis.span()),
434            in_token: None,
435            path: Box::new(syn::Path::from(Token![super](vis.span()))),
436        }),
437        // `pub(super(::...))` should be promoted to `pub(super::super(:...))`. Checking for this
438        // looks a bit funky.
439        Visibility::Restricted(vis @ VisRestricted { path, .. })
440            if path.segments.first().map(|segment| &segment.ident)
441                == Some(&Ident::from(Token![super](vis.span()))) =>
442        {
443            let mut new_path = syn::Path::from(Token![super](vis.span()));
444            for segment in &path.segments {
445                new_path.segments.push(segment.clone());
446            }
447
448            Visibility::Restricted(VisRestricted {
449                path: Box::new(new_path),
450                // Anything other than `crate` or `super` always needs to be prefixed with `in`
451                in_token: Some(Token![in](vis.span())),
452                ..*vis
453            })
454        }
455        // Everything else stays the same
456        vis => vis.clone(),
457    }
458}