vizia_derive/
data.rs

1#![allow(dead_code)]
2
3// Adapted from Druid data.rs
4
5// Copyright 2019 The Druid Authors.
6//
7// Licensed under the Apache License, Version 2.0 (the "License");
8// you may not use this file except in compliance with the License.
9// You may obtain a copy of the License at
10//
11//     http://www.apache.org/licenses/LICENSE-2.0
12//
13// Unless required by applicable law or agreed to in writing, software
14// distributed under the License is distributed on an "AS IS" BASIS,
15// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16// See the License for the specific language governing permissions and
17// limitations under the License.
18
19use crate::attr::{DataAttr, Field, FieldKind, Fields};
20
21use quote::{quote, quote_spanned};
22use syn::{spanned::Spanned, Data, DataEnum, DataStruct};
23
24pub(crate) fn derive_data_impl(
25    input: syn::DeriveInput,
26) -> Result<proc_macro2::TokenStream, syn::Error> {
27    match &input.data {
28        Data::Struct(s) => derive_struct(&input, s),
29        Data::Enum(e) => derive_enum(&input, e),
30        Data::Union(u) => Err(syn::Error::new(
31            u.union_token.span(),
32            "Data implementations cannot be derived from unions",
33        )),
34    }
35}
36
37fn derive_struct(
38    input: &syn::DeriveInput,
39    s: &DataStruct,
40) -> Result<proc_macro2::TokenStream, syn::Error> {
41    let ident = &input.ident;
42    let impl_generics = generics_bounds(&input.generics);
43    let (_, ty_generics, where_clause) = &input.generics.split_for_impl();
44
45    let fields = Fields::<DataAttr>::parse_ast(&s.fields)?;
46
47    let diff = if fields.len() > 0 {
48        let same_fns =
49            fields.iter().filter(|f| f.attrs != DataAttr::Ignore).map(Field::same_fn_path_tokens);
50        let fields = fields.iter().filter(|f| f.attrs != DataAttr::Ignore).map(Field::ident_tokens);
51        quote!( #( #same_fns(&self.#fields, &other.#fields) )&&* )
52    } else {
53        quote!(true)
54    };
55
56    let res = quote! {
57        impl<#impl_generics> Data for #ident #ty_generics #where_clause {
58            fn same(&self, other: &Self) -> bool {
59                #diff
60            }
61        }
62    };
63
64    Ok(res)
65}
66
67fn ident_from_str(s: &str) -> proc_macro2::Ident {
68    proc_macro2::Ident::new(s, proc_macro2::Span::call_site())
69}
70
71fn is_c_style_enum(s: &DataEnum) -> bool {
72    s.variants.iter().all(|variant| match &variant.fields {
73        syn::Fields::Named(fs) => fs.named.is_empty(),
74        syn::Fields::Unnamed(fs) => fs.unnamed.is_empty(),
75        syn::Fields::Unit => true,
76    })
77}
78
79fn derive_enum(
80    input: &syn::DeriveInput,
81    s: &DataEnum,
82) -> Result<proc_macro2::TokenStream, syn::Error> {
83    let ident = &input.ident;
84    let impl_generics = generics_bounds(&input.generics);
85    let (_, ty_generics, where_clause) = &input.generics.split_for_impl();
86
87    if is_c_style_enum(s) {
88        let res = quote! {
89            impl<#impl_generics> Data for #ident #ty_generics #where_clause {
90                fn same(&self, other: &Self) -> bool { self == other }
91            }
92        };
93        return Ok(res);
94    }
95
96    let cases: Vec<proc_macro2::TokenStream> = s
97        .variants
98        .iter()
99        .map(|variant| {
100            let fields = Fields::<DataAttr>::parse_ast(&variant.fields)?;
101            let variant = &variant.ident;
102
103            // the various inner `same()` calls, to the right of the match arm.
104            let tests: Vec<_> = fields
105                .iter()
106                .filter(|f| f.attrs != DataAttr::Ignore)
107                .map(|field| {
108                    let same_fn = field.same_fn_path_tokens();
109                    let var_left = ident_from_str(&format!("__self_{}", field.ident_string()));
110                    let var_right = ident_from_str(&format!("__other_{}", field.ident_string()));
111                    quote!( #same_fn(#var_left, #var_right) )
112                })
113                .collect();
114
115            if let FieldKind::Named = fields.kind {
116                let lefts: Vec<_> = fields
117                    .iter()
118                    .map(|field| {
119                        let ident = field.ident_tokens();
120                        let var = ident_from_str(&format!("__self_{}", field.ident_string()));
121                        quote!( #ident: #var )
122                    })
123                    .collect();
124                let rights: Vec<_> = fields
125                    .iter()
126                    .map(|field| {
127                        let ident = field.ident_tokens();
128                        let var = ident_from_str(&format!("__other_{}", field.ident_string()));
129                        quote!( #ident: #var )
130                    })
131                    .collect();
132
133                Ok(quote! {
134                    (#ident :: #variant { #( #lefts ),* }, #ident :: #variant { #( #rights ),* }) => {
135                        #( #tests )&&*
136                    }
137                })
138            } else {
139                let vars_left: Vec<_> = fields
140                    .iter()
141                    .map(|field| ident_from_str(&format!("__self_{}", field.ident_string())))
142                    .collect();
143                let vars_right: Vec<_> = fields
144                    .iter()
145                    .map(|field| ident_from_str(&format!("__other_{}", field.ident_string())))
146                    .collect();
147
148                if fields.iter().count() > 0 {
149                    Ok(quote! {
150                        ( #ident :: #variant( #(#vars_left),* ),  #ident :: #variant( #(#vars_right),* )) => {
151                            #( #tests )&&*
152                        }
153                    })
154                } else {
155                    Ok(quote! {
156                        ( #ident :: #variant ,  #ident :: #variant ) => { true }
157                    })
158                }
159            }
160        })
161        .collect::<Result<Vec<proc_macro2::TokenStream>, syn::Error>>()?;
162
163    let res = quote! {
164        impl<#impl_generics> Data for #ident #ty_generics #where_clause {
165            fn same(&self, other: &Self) -> bool {
166                match (self, other) {
167                    #( #cases ),*
168                    _ => false,
169                }
170            }
171        }
172    };
173
174    Ok(res)
175}
176
177fn generics_bounds(generics: &syn::Generics) -> proc_macro2::TokenStream {
178    let res = generics.params.iter().map(|gp| {
179        use syn::GenericParam::*;
180        match gp {
181            Type(ty) => {
182                let ident = &ty.ident;
183                let bounds = &ty.bounds;
184                if bounds.is_empty() {
185                    quote_spanned!(ty.span()=> #ident : Data)
186                } else {
187                    quote_spanned!(ty.span()=> #ident : #bounds + Data)
188                }
189            }
190            Lifetime(lf) => quote!(#lf),
191            Const(cst) => quote!(#cst),
192        }
193    });
194
195    quote!( #( #res, )* )
196}