1#![allow(missing_docs)]
2
3use quote::quote;
21use syn::spanned::Spanned;
23use syn::{Data, GenericParam, Ident, Token, TypeParam, VisRestricted, Visibility};
24
25use super::attr::{FieldKind, Fields, LensAttrs};
26
27pub(crate) fn derive_lens_impl(
28 input: syn::DeriveInput,
29) -> Result<proc_macro2::TokenStream, syn::Error> {
30 match &input.data {
31 Data::Struct(_) => derive_struct(&input),
32 Data::Enum(_) => derive_enum(&input),
33 Data::Union(u) => Err(syn::Error::new(
34 u.union_token.span(),
35 "Lens implementations cannot be derived from unions",
36 )),
37 }
38}
39
40fn derive_struct(input: &syn::DeriveInput) -> Result<proc_macro2::TokenStream, syn::Error> {
41 let struct_type = &input.ident;
42
43 let module_vis = &input.vis;
47 let struct_vis = increase_visibility(module_vis);
48
49 let fields = if let syn::Data::Struct(syn::DataStruct { fields, .. }) = &input.data {
50 Fields::<LensAttrs>::parse_ast(fields)?
51 } else {
52 return Err(syn::Error::new(
53 input.span(),
54 "Lens implementations can only be derived from structs with named fields",
55 ));
56 };
57
58 if fields.kind != FieldKind::Named {
59 return Err(syn::Error::new(
60 input.span(),
61 "Lens implementations can only be derived from structs with named fields",
62 ));
63 }
64
65 let twizzled_name = if is_camel_case(&struct_type.to_string()) {
66 let temp_name = format!("{}_derived_lenses", to_snake_case(&struct_type.to_string()));
67 proc_macro2::Ident::new(&temp_name, proc_macro2::Span::call_site())
68 } else {
69 return Err(syn::Error::new(
70 struct_type.span(),
71 "Lens implementations can only be derived from CamelCase types",
72 ));
73 };
74 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
75
76 let mut lens_ty_idents = Vec::new();
77 let mut phantom_decls = Vec::new();
78 let mut phantom_inits = Vec::new();
79 let mut lens_ty_decls = Vec::new();
80
81 for gp in input.generics.params.iter() {
82 if let GenericParam::Type(TypeParam { ident, .. }) = gp {
83 lens_ty_idents.push(quote! {#ident});
84 lens_ty_decls.push(quote! {#ident: 'static});
85 phantom_decls.push(quote! {::std::marker::PhantomData<*const #ident>});
86 phantom_inits.push(quote! {::std::marker::PhantomData});
87 }
88 }
89
90 let lens_ty_generics = quote! {
91 <#(#lens_ty_idents),*>
92 };
93
94 let lens_ty_generics_decls = quote! {
95 <#(#lens_ty_decls),*>
96 };
97
98 let defs = fields.iter().filter(|f| !f.attrs.ignore).map(|f| {
100 let field_name = &f.ident.unwrap_named();
101 let struct_docs = format!(
102 "Lens for the field `{field}` on [`{ty}`](super::{ty}).",
103 field = field_name,
104 ty = struct_type,
105 );
106
107 let fn_docs = format!(
108 "Creates a new lens for the field `{field}` on [`{ty}`](super::{ty}). \
109 Use [`{ty}::{field}`](super::{ty}::{field}) instead.",
110 field = field_name,
111 ty = struct_type,
112 );
113
114 quote! {
115 #[doc = #struct_docs]
116 #[allow(non_camel_case_types)]
117 #[derive(::std::cmp::PartialEq, ::std::cmp::Eq)]
118 #struct_vis struct #field_name #lens_ty_generics(#(#phantom_decls),*);
119
120 impl #lens_ty_generics #field_name #lens_ty_generics {
121 #[doc = #fn_docs]
122 pub const fn new()->Self{
123 Self(#(#phantom_inits),*)
124 }
125 }
126
127 impl #lens_ty_generics_decls ::std::hash::Hash for #field_name #lens_ty_generics {
128 fn hash<H: ::std::hash::Hasher>(&self, state: &mut H) {
129 ::std::any::TypeId::of::<Self>().hash(state);
130 }
131 }
132
133 impl #lens_ty_generics ::std::fmt::Debug for #field_name #lens_ty_generics {
134 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
135 write!(f, "{}:{}", stringify!(#struct_type), stringify!(#field_name))
136 }
137 }
138
139 impl #lens_ty_generics ::std::clone::Clone for #field_name #lens_ty_generics {
140 fn clone(&self) -> #field_name #lens_ty_generics {
141 *self
142 }
143 }
144
145 impl #lens_ty_generics ::std::marker::Copy for #field_name #lens_ty_generics {}
146 }
147 });
148
149 let impls = fields.iter().filter(|f| !f.attrs.ignore).map(|f| {
173 let field_name = &f.ident.unwrap_named();
174 let field_ty = &f.ty;
175 quote! {
176 impl #impl_generics Lens for #twizzled_name::#field_name #lens_ty_generics #where_clause {
177 type Source = #struct_type #ty_generics;
178 type Target = #field_ty;
179
180 fn view<'a>(&self, source: &'a #struct_type #ty_generics) -> ::std::option::Option<LensValue<'a, Self::Target>> {
181 ::std::option::Option::Some(LensValue::Borrowed(&source.#field_name))
182 }
183 }
184 }
185 });
186
187 let associated_items = fields.iter().filter(|f| !f.attrs.ignore).map(|f| {
188 let field_name = &f.ident.unwrap_named();
189 let lens_field_name = f.attrs.lens_name_override.as_ref().unwrap_or(field_name);
190 let field_vis = &f.vis;
191
192 quote! {
193 #field_vis const #lens_field_name: Wrapper<#twizzled_name::#field_name #lens_ty_generics> = Wrapper(#twizzled_name::#field_name::new());
195 }
196 });
197
198 let mod_docs = format!("Derived lenses for [`{}`].", struct_type);
199 let root_docs = format!("Lens for the whole [`{ty}`](super::{ty}) struct.", ty = struct_type);
200 let expanded = quote! {
203 #[doc = #mod_docs]
204 #module_vis mod #twizzled_name {
205 #(#defs)*
206 #[derive(::std::cmp::PartialEq, ::std::cmp::Eq)]
207 #[doc = #root_docs]
208 #[allow(non_camel_case_types)]
209 #struct_vis struct root #lens_ty_generics(#(#phantom_decls),*);
210
211 impl #lens_ty_generics root #lens_ty_generics {
212 pub const fn new()->Self{
214 Self(#(#phantom_inits),*)
215 }
216 }
217
218 impl #lens_ty_generics_decls ::std::hash::Hash for root #lens_ty_generics {
219 fn hash<H: ::std::hash::Hasher>(&self, state: &mut H) {
220 ::std::any::TypeId::of::<Self>().hash(state);
221 }
222 }
223
224 impl #lens_ty_generics ::std::fmt::Debug for root #lens_ty_generics {
225 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
226 write!(f,"{}",stringify!(#struct_type))
227 }
228 }
229
230 impl #lens_ty_generics ::std::clone::Clone for root #lens_ty_generics {
231 fn clone(&self) -> root #lens_ty_generics {
232 *self
233 }
234 }
235
236 impl #lens_ty_generics ::std::marker::Copy for root #lens_ty_generics {}
237 }
238
239 #(#impls)*
240
241 impl #impl_generics Lens for #twizzled_name::root #lens_ty_generics {
242 type Source = #struct_type #ty_generics;
243 type Target = #struct_type #ty_generics;
244
245 fn view<'a>(&self, source: &'a Self::Source) -> ::std::option::Option<LensValue<'a, Self::Target>> {
246 ::std::option::Option::Some(LensValue::Borrowed(source))
247 }
248 }
249
250 #[allow(non_upper_case_globals)]
251 #[doc(hidden)]
252 impl #impl_generics #struct_type #ty_generics #where_clause {
253 #(#associated_items)*
254
255 pub const root: Wrapper<#twizzled_name::root #lens_ty_generics> = Wrapper(#twizzled_name::root::new());
256 }
257 };
258
259 Ok(expanded)
260}
261
/// Reports whether `c` is a cased character, i.e. it is uppercase or lowercase.
///
/// Digits, punctuation, `_`, and scripts without case (e.g. CJK) yield `false`.
pub(crate) fn char_has_case(c: char) -> bool {
    let upper = c.is_uppercase();
    upper || c.is_lowercase()
}
266
267fn is_camel_case(name: &str) -> bool {
268 let name = name.trim_matches('_');
269 if name.is_empty() {
270 return true;
271 }
272
273 !name.chars().next().unwrap().is_lowercase()
276 && !name.contains("__")
277 && !name.chars().collect::<Vec<_>>().windows(2).any(|pair| {
278 char_has_case(pair[0]) && pair[1] == '_' || char_has_case(pair[1]) && pair[0] == '_'
280 })
281}
282
/// Converts a CamelCase identifier to snake_case, e.g. `AppData` -> `app_data`.
///
/// Leading underscores are kept one-for-one, and empty `_`-separated segments are dropped.
/// A new word starts at each uppercase character that follows a non-uppercase one, so runs
/// of capitals stay in one word (`HTTPServer` -> `httpserver`). A word consisting only of
/// `'` is never split off on its own.
fn to_snake_case(str: &str) -> String {
    let body = str.trim_start_matches('_');
    // `_` is one byte, so the byte-length difference equals the number of leading underscores.
    let mut words = vec![String::new(); str.len() - body.len()];

    for segment in body.split('_').filter(|segment| !segment.is_empty()) {
        let mut word = String::new();
        let mut prev_was_upper = false;
        for ch in segment.chars() {
            let is_upper = ch.is_uppercase();
            if is_upper && !prev_was_upper && !word.is_empty() && word != "'" {
                words.push(std::mem::take(&mut word));
            }
            prev_was_upper = is_upper;
            word.extend(ch.to_lowercase());
        }
        words.push(word);
    }

    words.join("_")
}
312
313fn derive_enum(input: &syn::DeriveInput) -> Result<proc_macro2::TokenStream, syn::Error> {
314 let enum_type = &input.ident;
315
316 let module_vis = &input.vis;
318 let struct_vis = increase_visibility(module_vis);
319
320 let variants = if let syn::Data::Enum(syn::DataEnum { variants, .. }) = &input.data {
321 variants
322 } else {
323 panic!("I don't know what this case being triggered means. Please open an issue!")
324 };
325
326 let usable_variants = variants
327 .iter()
328 .filter_map(|v| match &v.fields {
329 syn::Fields::Unnamed(f) => {
330 if f.unnamed.len() == 1 {
331 Some((&v.ident, &f.unnamed.first().unwrap().ty))
332 } else {
333 None
334 }
335 }
336 _ => None,
337 })
338 .collect::<Vec<_>>();
339 if usable_variants.is_empty() {
340 panic!("This enum has no variants which can have Lenses built. A valid variant has exactly one unnamed field. If you think this is unreasonable, please work on https://github.com/rust-lang/rfcs/pull/2593")
341 }
342
343 let twizzled_name = if is_camel_case(&enum_type.to_string()) {
344 let temp_name = format!("{}_derived_lenses", to_snake_case(&enum_type.to_string()));
345 proc_macro2::Ident::new(&temp_name, proc_macro2::Span::call_site())
346 } else {
347 return Err(syn::Error::new(
348 enum_type.span(),
349 "Lens implementations can only be derived from CamelCase types",
350 ));
351 };
352
353 if !input.generics.params.is_empty() {
354 panic!("Lens implementations can only be derived from non-generic enums (for now)");
355 }
356
357 let defs = usable_variants.iter().map(|(variant_name, _)| {
358 quote! {
359 #[allow(non_camel_case_types)]
360 #[derive(:std::marker::Copy, ::std::clone::Clone, ::std::hash::Hash, ::std::cmp::PartialEq, ::std::cmp::Eq)]
361 #struct_vis struct #variant_name();
362
363 impl #variant_name {
364 pub const fn new() -> Self {
365 Self()
366 }
367 }
368
369 impl ::std::fmt::Debug for #variant_name {
370 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
371 write!(f,"{}:{}",stringify!(#enum_type), stringify!(#variant_name))
372 }
373 }
374 }
375 });
376
377 let impls = usable_variants.iter().map(|(variant_name, variant_type)| {
378 let name = format!("{}:{}", enum_type, variant_name);
379 quote! {
380 impl Lens for #twizzled_name::#variant_name {
381 type Source = #enum_type;
382 type Target = #variant_type;
383
384 fn view<'a>(&self, source: &'a Self::Source) -> Option<LensValue<'a, Self::Target>> {
385 if let #enum_type::#variant_name(inner_value) = source {
386 ::std::option::Option::Some(LensValue::Borrowed(inner_value))
387 } else {
388 ::std::panic!("failed")
389 }
390
391 }
392 }
393
394 impl ::std::fmt::Debug for #twizzled_name::#variant_name {
395 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
396 f.write_str(#name)
397 }
398 }
399 }
400 });
401
402 let associated_items = usable_variants.iter().map(|(variant_name, _)| {
403 let variant_const_name = to_snake_case(&variant_name.to_string());
404 let variant_const_name = proc_macro2::Ident::new(&variant_const_name, proc_macro2::Span::call_site());
405 quote! {
406 pub const #variant_const_name: #twizzled_name::#variant_name = #twizzled_name::#variant_name::new();
407 }
408 });
409
410 let expanded = quote! {
411 #module_vis mod #twizzled_name {
412 #(#defs)*
413 }
414
415 #(#impls)*
416
417 #[allow(non_upper_case_globals)]
418 impl #enum_type {
419 #(#associated_items)*
420 }
421 };
422
423 Ok(expanded)
424}
425
426pub(crate) fn increase_visibility(vis: &Visibility) -> Visibility {
429 match vis {
430 Visibility::Inherited => Visibility::Restricted(VisRestricted {
432 pub_token: Token),
433 paren_token: syn::token::Paren(vis.span()),
434 in_token: None,
435 path: Box::new(syn::Path::from(Token))),
436 }),
437 Visibility::Restricted(vis @ VisRestricted { path, .. })
440 if path.segments.first().map(|segment| &segment.ident)
441 == Some(&Ident::from(Token))) =>
442 {
443 let mut new_path = syn::Path::from(Token));
444 for segment in &path.segments {
445 new_path.segments.push(segment.clone());
446 }
447
448 Visibility::Restricted(VisRestricted {
449 path: Box::new(new_path),
450 in_token: Some(Token)),
452 ..*vis
453 })
454 }
455 vis => vis.clone(),
457 }
458}