1use std::collections::HashMap;
2use std::error::Error as StdError;
3use std::fmt::{self, Debug};
4use std::io::{self, Write};
5use std::str::FromStr;
6use std::result;
7
8use lazy_static::lazy_static;
9use regex::{Captures, Regex};
10use serde::de;
11use serde::de::IntoDeserializer;
12
13use crate::parse::Parser;
14use crate::synonym::SynonymMap;
15
16use self::Value::{Switch, Counted, Plain, List};
17use self::Error::{Usage, Argv, NoMatch, Deserialize, WithProgramUsage, Help, Version};
18
19use crate::cap_or_empty;
20
/// Errors (and informational exits such as help/version requests) produced
/// while parsing a usage string, matching argv against it, or deserializing
/// the matched values.
///
/// `Error::fatal` distinguishes real failures from informational exits, and
/// `Error::exit` prints the error and terminates the process accordingly.
#[derive(Debug)]
pub enum Error {
    /// The usage string could not be parsed; carries the parser's message.
    /// Returned by `Docopt::new`.
    Usage(String),

    /// The command-line arguments could not be parsed; carries the message
    /// reported by the argv parser.
    Argv(String),

    /// The arguments parsed, but matched none of the usage patterns.
    NoMatch,

    /// A matched value could not be deserialized into the requested type;
    /// carries a description of the problem.
    Deserialize(String),

    /// Wraps another error together with usage text that is displayed after
    /// the wrapped error's message.
    WithProgramUsage(Box<Error>, String),

    /// `--help` was present while help handling was enabled.
    Help,

    /// `--version` was present while a version string was configured;
    /// carries that version string.
    Version(String),
}
90
91impl Error {
92 pub fn fatal(&self) -> bool {
98 match *self {
99 Help | Version(..) => false,
100 Usage(..) | Argv(..) | NoMatch | Deserialize(..) => true,
101 WithProgramUsage(ref b, _) => b.fatal(),
102 }
103 }
104
105 pub fn exit(&self) -> ! {
112 if self.fatal() {
113 werr!("{}\n", self);
114 ::std::process::exit(1)
115 } else {
116 let _ = writeln!(&mut io::stdout(), "{}", self);
117 ::std::process::exit(0)
118 }
119 }
120}
121
/// Module-local shorthand for a `Result` whose error type is this module's `Error`.
type Result<T> = result::Result<T, Error>;
123
124impl fmt::Display for Error {
125 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126 match *self {
127 WithProgramUsage(ref other, ref usage) => {
128 let other = other.to_string();
129 if other.is_empty() {
130 write!(f, "{}", usage)
131 } else {
132 write!(f, "{}\n\n{}", other, usage)
133 }
134 }
135 Help => write!(f, ""),
136 NoMatch => write!(f, "Invalid arguments."),
137 Usage(ref s) |
138 Argv(ref s) |
139 Deserialize(ref s) |
140 Version(ref s) => write!(f, "{}", s),
141 }
142 }
143}
144
145impl StdError for Error {
146 fn source(&self) -> Option<&(dyn StdError + 'static)> {
147 match *self {
148 WithProgramUsage(ref cause, _) => Some(&**cause),
149 _ => None,
150 }
151 }
152}
153
154impl de::Error for Error {
155 fn custom<T: fmt::Display>(msg: T) -> Self {
156 Error::Deserialize(msg.to_string())
157 }
158}
159
/// A parsed usage description plus the settings used to match argv against it.
///
/// Create one with `Docopt::new`, adjust it with the builder methods (`argv`,
/// `options_first`, `help`, `version`), then call `parse` or `deserialize`.
#[derive(Clone, Debug)]
pub struct Docopt {
    // Parser built from the usage string by `Docopt::new`.
    p: Parser,
    // Explicit arguments with the first element already dropped, or `None`
    // to read `std::env::args()` (minus its first element) at parse time.
    argv: Option<Vec<String>>,
    // Forwarded to `Parser::parse_argv`.
    options_first: bool,
    // When true, a `--help` flag makes `parse` return `Help`.
    help: bool,
    // When set, a `--version` flag makes `parse` return `Version` with this text.
    version: Option<String>,
}
171
172impl Docopt {
173 pub fn new<S>(usage: S) -> Result<Docopt>
181 where S: ::std::ops::Deref<Target=str> {
182 Parser::new(usage.deref())
183 .map_err(Usage)
184 .map(|p| Docopt {
185 p: p,
186 argv: None,
187 options_first: false,
188 help: true,
189 version: None,
190 })
191 }
192
193 pub fn deserialize<'a, 'de: 'a, D>(&'a self) -> Result<D>
201 where D: de::Deserialize<'de>
202 {
203 self.parse().and_then(|vals| vals.deserialize())
204 }
205
206 pub fn parse(&self) -> Result<ArgvMap> {
221 let argv = self.argv.clone().unwrap_or_else(Docopt::get_argv);
222 let vals =
223 self.p.parse_argv(argv, self.options_first)
224 .map_err(|s| self.err_with_usage(Argv(s)))
225 .and_then(|argv|
226 match self.p.matches(&argv) {
227 Some(m) => Ok(ArgvMap { map: m }),
228 None => Err(self.err_with_usage(NoMatch)),
229 })?;
230 if self.help && vals.get_bool("--help") {
231 return Err(self.err_with_full_doc(Help));
232 }
233 match self.version {
234 Some(ref v) if vals.get_bool("--version") => {
235 return Err(Version(v.clone()))
236 }
237 _ => {},
238 }
239 Ok(vals)
240 }
241
242 pub fn argv<I, S>(mut self, argv: I) -> Docopt
251 where I: IntoIterator<Item=S>, S: AsRef<str> {
252 self.argv = Some(
253 argv.into_iter().skip(1).map(|s| s.as_ref().to_owned()).collect()
254 );
255 self
256 }
257
258 pub fn options_first(mut self, yes: bool) -> Docopt {
265 self.options_first = yes;
266 self
267 }
268
269 pub fn help(mut self, yes: bool) -> Docopt {
280 self.help = yes;
281 self
282 }
283
284 pub fn version(mut self, version: Option<String>) -> Docopt {
294 self.version = version;
295 self
296 }
297
298 #[doc(hidden)]
299 pub fn parser(&self) -> &Parser {
300 &self.p
301 }
302
303 fn err_with_usage(&self, e: Error) -> Error {
304 WithProgramUsage(
305 Box::new(e), self.p.usage.trim().into())
306 }
307
308 fn err_with_full_doc(&self, e: Error) -> Error {
309 WithProgramUsage(
310 Box::new(e), self.p.full_doc.trim().into())
311 }
312
313 fn get_argv() -> Vec<String> {
314 ::std::env::args().skip(1).collect()
316 }
317}
318
/// The argument values produced by a successful `Docopt::parse`.
///
/// Values are looked up by their usage key (for example `--help`) through
/// `find` and the `get_*` helpers, or converted wholesale with `deserialize`.
#[derive(Clone)]
pub struct ArgvMap {
    #[doc(hidden)]
    pub map: SynonymMap<String, Value>,
}
330
331impl ArgvMap {
332 pub fn deserialize<'de, T: de::Deserialize<'de>>(self) -> Result<T> {
382 de::Deserialize::deserialize(&mut Deserializer {
383 vals: self,
384 stack: vec![],
385 })
386 }
387
388 pub fn get_bool(&self, key: &str) -> bool {
391 self.find(key).map_or(false, |v| v.as_bool())
392 }
393
394 pub fn get_count(&self, key: &str) -> u64 {
397 self.find(key).map_or(0, |v| v.as_count())
398 }
399
400 pub fn get_str(&self, key: &str) -> &str {
403 self.find(key).map_or("", |v| v.as_str())
404 }
405
406 pub fn get_vec(&self, key: &str) -> Vec<&str> {
409 self.find(key).map(|v| v.as_vec()).unwrap_or(vec!())
410 }
411
412 pub fn find(&self, key: &str) -> Option<&Value> {
417 self.map.find(&key.into())
418 }
419
420 pub fn len(&self) -> usize {
422 self.map.len()
423 }
424
425 #[doc(hidden)]
430 pub fn key_to_struct_field(name: &str) -> String {
431 lazy_static! {
432 static ref RE: Regex = regex!(
433 r"^(?:--?(?P<flag>\S+)|(?:(?P<argu>\p{Lu}+)|<(?P<argb>[^>]+)>)|(?P<cmd>\S+))$"
434 );
435 }
436 fn sanitize(name: &str) -> String {
437 name.replace("-", "_")
438 }
439
440 RE.replace(name, |cap: &Captures<'_>| {
441 let (flag, cmd) = (
442 cap_or_empty(cap, "flag"),
443 cap_or_empty(cap, "cmd"),
444 );
445 let (argu, argb) = (
446 cap_or_empty(cap, "argu"),
447 cap_or_empty(cap, "argb"),
448 );
449 let (prefix, name) =
450 if !flag.is_empty() {
451 ("flag_", flag)
452 } else if !argu.is_empty() {
453 ("arg_", argu)
454 } else if !argb.is_empty() {
455 ("arg_", argb)
456 } else if !cmd.is_empty() {
457 ("cmd_", cmd)
458 } else {
459 panic!("Unknown ArgvMap key: '{}'", name)
460 };
461 let mut prefix = prefix.to_owned();
462 prefix.push_str(&sanitize(name));
463 prefix
464 }).into_owned()
465 }
466
467 #[doc(hidden)]
469 pub fn struct_field_to_key(field: &str) -> String {
470 lazy_static! {
471 static ref FLAG: Regex = regex!(r"^flag_");
472 static ref ARG: Regex = regex!(r"^arg_");
473 static ref LETTERS: Regex = regex!(r"^\p{Lu}+$");
474 static ref CMD: Regex = regex!(r"^cmd_");
475 }
476 fn desanitize(name: &str) -> String {
477 name.replace("_", "-")
478 }
479 let name =
480 if field.starts_with("flag_") {
481 let name = FLAG.replace(field, "");
482 let mut pre_name = (if name.len() == 1 { "-" } else { "--" })
483 .to_owned();
484 pre_name.push_str(&*name);
485 pre_name
486 } else if field.starts_with("arg_") {
487 let name = ARG.replace(field, "").into_owned();
488 if LETTERS.is_match(&name) {
489 name
490 } else {
491 let mut pre_name = "<".to_owned();
492 pre_name.push_str(&*name);
493 pre_name.push('>');
494 pre_name
495 }
496 } else if field.starts_with("cmd_") {
497 CMD.replace(field, "").into_owned()
498 } else {
499 panic!("Unrecognized struct field: '{}'", field)
500 };
501 desanitize(&*name)
502 }
503}
504
505impl fmt::Debug for ArgvMap {
506 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
507 if self.len() == 0 {
508 return write!(f, "{{EMPTY}}");
509 }
510
511 let reverse: HashMap<&String, &String> =
514 self.map.synonyms().map(|(from, to)| (to, from)).collect();
515 let mut keys: Vec<&String> = self.map.keys().collect();
516 keys.sort();
517 let mut first = true;
518 for &k in &keys {
519 if !first { write!(f, "\n")?; } else { first = false; }
520 match reverse.get(&k) {
521 None => {
522 write!(f, "{} => {:?}", k, self.map.get(k))?
523 }
524 Some(s) => {
525 write!(f, "{}, {} => {:?}", s, k, self.map.get(k))?
526 }
527 }
528 }
529 Ok(())
530 }
531}
532
/// A single matched argument value.
///
/// The `as_*` accessors never fail. Each projects the variant onto the
/// requested shape and falls back to an "empty" result (`false`, `0`, `""`,
/// or an empty vector) when the variant carries nothing of that shape.
#[derive(Clone, Debug, PartialEq)]
pub enum Value {
    /// A boolean value.
    Switch(bool),

    /// A repetition count.
    Counted(u64),

    /// A single string value, which may be absent.
    Plain(Option<String>),

    /// Zero or more string values.
    List(Vec<String>),
}

impl Value {
    /// Interprets the value as a boolean.
    ///
    /// `Switch(b)` yields `b`. Any other variant is `true` exactly when it
    /// carries something: a non-zero count, a present string, or a non-empty
    /// list.
    pub fn as_bool(&self) -> bool {
        match self {
            Value::Switch(on) => *on,
            Value::Counted(times) => *times != 0,
            Value::Plain(text) => text.is_some(),
            Value::List(items) => !items.is_empty(),
        }
    }

    /// Interprets the value as a count.
    ///
    /// The mapping is:
    /// - `Switch` counts as 1 when on, 0 when off.
    /// - `Counted(n)` yields `n`.
    /// - `Plain` counts as 1 when present, 0 when absent.
    /// - `List` yields its length.
    pub fn as_count(&self) -> u64 {
        match self {
            Value::Switch(on) => u64::from(*on),
            Value::Counted(times) => *times,
            Value::Plain(text) => u64::from(text.is_some()),
            Value::List(items) => items.len() as u64,
        }
    }

    /// Returns the string held by `Plain(Some(_))`, or `""` for any other value.
    pub fn as_str(&self) -> &str {
        if let Value::Plain(Some(text)) = self {
            text.as_str()
        } else {
            ""
        }
    }

    /// Returns the held strings as a vector.
    ///
    /// `List` yields all of its items and `Plain(Some(_))` yields a single
    /// element. Every other value yields an empty vector.
    pub fn as_vec(&self) -> Vec<&str> {
        match self {
            Value::List(items) => items.iter().map(String::as_str).collect(),
            Value::Plain(Some(text)) => vec![text.as_str()],
            Value::Switch(_) | Value::Counted(_) | Value::Plain(None) => Vec::new(),
        }
    }
}
617
/// A serde `Deserializer` that reads struct fields out of an `ArgvMap`.
///
/// Struct field names are translated to usage keys with
/// `ArgvMap::struct_field_to_key`. Values awaiting deserialization are kept
/// on an internal stack, with the most recently pushed on top.
pub struct Deserializer<'de> {
    // The matched arguments being read.
    vals: ArgvMap,
    // Pending items; the top of the stack is the next value to consume.
    stack: Vec<DeserializerItem<'de>>,
}
642
/// One pending value on the `Deserializer` stack.
#[derive(Debug)]
struct DeserializerItem<'de> {
    // Usage key derived from `struct_field`.
    key: String,
    // Struct field name the value is destined for (used in error messages).
    struct_field: &'de str,
    // The value found for `key`, or `None` if the key was absent.
    val: Option<Value>,
}
649
// Early-returns `Err(Deserialize(..))` from the enclosing function, building
// the message with `format!` syntax.
macro_rules! derr {
    ($($arg:tt)*) => {
        return Err(Deserialize(format!($($arg)*)))
    };
}
653
654impl<'de> Deserializer<'de> {
655 fn push(&mut self, struct_field: &'de str) {
656 let key = ArgvMap::struct_field_to_key(struct_field);
657 self.stack
658 .push(DeserializerItem {
659 key: key.clone(),
660 struct_field: struct_field,
661 val: self.vals.find(&*key).cloned(),
662 });
663 }
664
665 fn pop(&mut self) -> Result<DeserializerItem<'_>> {
666 match self.stack.pop() {
667 None => derr!("Could not deserialize value into unknown key."),
668 Some(it) => Ok(it),
669 }
670 }
671
672 fn pop_key_val(&mut self) -> Result<(String, Value)> {
673 let it = self.pop()?;
674 match it.val {
675 None => {
676 derr!("Could not find argument '{}' (from struct field '{}').
677Note that each struct field must have the right key prefix, which must
678be one of `cmd_`, `flag_` or `arg_`.",
679 it.key,
680 it.struct_field)
681 }
682 Some(v) => Ok((it.key, v)),
683 }
684 }
685
686 fn pop_val(&mut self) -> Result<Value> {
687 let (_, v) = self.pop_key_val()?;
688 Ok(v)
689 }
690
691 fn to_number<T>(&mut self, expect: &str) -> Result<T>
692 where T: FromStr + ToString,
693 <T as FromStr>::Err: Debug
694 {
695 let (k, v) = self.pop_key_val()?;
696 match v {
697 Counted(n) => Ok(n.to_string().parse().unwrap()), _ => {
699 if v.as_str().trim().is_empty() {
700 Ok("0".parse().unwrap()) } else {
702 match v.as_str().parse() {
703 Err(_) => {
704 derr!("Could not deserialize '{}' to {} for '{}'.",
705 v.as_str(),
706 expect,
707 k)
708 }
709 Ok(v) => Ok(v),
710 }
711 }
712 }
713 }
714 }
715
716 fn to_float(&mut self, expect: &str) -> Result<f64> {
717 let (k, v) = self.pop_key_val()?;
718 match v {
719 Counted(n) => Ok(n as f64),
720 _ => {
721 match v.as_str().parse() {
722 Err(_) => {
723 derr!("Could not deserialize '{}' to {} for '{}'.",
724 v.as_str(),
725 expect,
726 k)
727 }
728 Ok(v) => Ok(v),
729 }
730 }
731 }
732 }
733}
734
// Generates a `deserialize_*` method for the integer type `$ty`. The value is
// parsed by `Deserializer::to_number` and handed to `visitor.$method`.
macro_rules! deserialize_num {
    ($name:ident, $method:ident, $ty:ty) => {
        fn $name<V>(self, visitor: V) -> Result<V::Value>
            where V: de::Visitor<'de>
        {
            let n: $ty = self.to_number::<$ty>(stringify!($ty))?;
            visitor.$method(n)
        }
    };
}
744
745impl<'a, 'de> ::serde::Deserializer<'de> for &'a mut Deserializer<'de> {
746 type Error = Error;
747
748 fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value>
749 where V: de::Visitor<'de>
750 {
751 unimplemented!()
752 }
753
754 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
755 where V: de::Visitor<'de>
756 {
757 visitor.visit_bool(self.pop_val().map(|v| v.as_bool())?)
758 }
759
760 deserialize_num!(deserialize_i8, visit_i8, i8);
762 deserialize_num!(deserialize_i16, visit_i16, i16);
763 deserialize_num!(deserialize_i32, visit_i32, i32);
764 deserialize_num!(deserialize_i64, visit_i64, i64);
765 deserialize_num!(deserialize_u8, visit_u8, u8);
766 deserialize_num!(deserialize_u16, visit_u16, u16);
767 deserialize_num!(deserialize_u32, visit_u32, u32);
768 deserialize_num!(deserialize_u64, visit_u64, u64);
769
770 fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value>
771 where V: de::Visitor<'de>
772 {
773 visitor.visit_f32(self.to_float("f32").map(|n| n as f32)?)
774 }
775
776 fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value>
777 where V: de::Visitor<'de>
778 {
779 visitor.visit_f64(self.to_float("f64")?)
780 }
781
782 fn deserialize_char<V>(self, visitor: V) -> Result<V::Value>
783 where V: de::Visitor<'de>
784 {
785 let (k, v) = self.pop_key_val()?;
786 let vstr = v.as_str();
787 match vstr.chars().count() {
788 1 => visitor.visit_char(vstr.chars().next().unwrap()),
789 _ => derr!("Could not deserialize '{}' into char for '{}'.", vstr, k),
790 }
791 }
792
793 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
794 where V: de::Visitor<'de>
795 {
796 let s = self.pop_val()?;
797 visitor.visit_str(s.as_str())
798 }
799
800 fn deserialize_string<V>(self, visitor:V) -> Result<V::Value>
801 where V: de::Visitor<'de>
802 {
803 self.deserialize_str(visitor)
804 }
805
806 fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value>
807 where V: de::Visitor<'de>
808 {
809 unimplemented!()
810 }
811
812 fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value>
813 where V: de::Visitor<'de>
814 {
815 unimplemented!()
816 }
817
818 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
819 where V: de::Visitor<'de>
820 {
821 let is_some = match self.stack.last() {
822 None => derr!("Could not deserialize value into unknown key."),
823 Some(it) => it.val.as_ref().map_or(false, |v| v.as_bool()),
824 };
825 if is_some {
826 visitor.visit_some(self)
827 } else {
828 visitor.visit_none()
829 }
830 }
831
832 fn deserialize_unit<V>(self, _visitor: V) -> Result<V::Value>
833 where V: de::Visitor<'de>
834 {
835 panic!("I don't know how to read into a nil value.")
837 }
838
839 fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
840 where V: de::Visitor<'de>
841 {
842 visitor.visit_unit()
843 }
844
845 fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
846 where V: de::Visitor<'de>
847 {
848 visitor.visit_newtype_struct(self)
849 }
850
851 fn deserialize_tuple<V>(self, _len: usize, _visitor: V) -> Result<V::Value>
852 where V: de::Visitor<'de>
853 {
854 unimplemented!()
855 }
856
857 fn deserialize_tuple_struct<V>(self,
858 _name: &'static str,
859 _len: usize,
860 _visitor: V)
861 -> Result<V::Value>
862 where V: de::Visitor<'de>
863 {
864 unimplemented!()
865 }
866
867 fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value>
868 where V: de::Visitor<'de>
869 {
870 unimplemented!()
871 }
872
873 fn deserialize_seq<V>(mut self, visitor: V) -> Result<V::Value>
874 where V: de::Visitor<'de>
875 {
876 let (key, struct_field, val) = match self.stack.pop() {
877 None => derr!("Could not deserialize value into unknown key."),
878 Some(DeserializerItem {key, struct_field, val}) => (key, struct_field, val),
879 };
880 let list = val.unwrap_or(List(vec![]));
881 let vals = list.as_vec();
882 for val in vals.iter().rev() {
883 self.stack
884 .push(DeserializerItem {
885 key: key.clone(),
886 struct_field: struct_field,
887 val: Some(Plain(Some((*val).into()))),
888 });
889 }
890 visitor.visit_seq(SeqDeserializer::new(&mut self, vals.len()))
891 }
892
893 fn deserialize_struct<V>(mut self,
894 _: &str,
895 fields: &'static [&'static str],
896 visitor: V)
897 -> Result<V::Value>
898 where V: de::Visitor<'de>
899 {
900 visitor.visit_seq(StructDeserializer::new(&mut self, fields))
901 }
902
903 fn deserialize_enum<V>(self, _name: &str, variants: &[&str], visitor: V) -> Result<V::Value>
904 where V: de::Visitor<'de>
905 {
906 let v = self.pop_val()?.as_str().to_lowercase();
907 let s = match variants.iter().find(|&n| n.to_lowercase() == v) {
908 Some(s) => s,
909 None => {
910 derr!("Could not match '{}' with any of \
911 the allowed variants: {:?}",
912 v,
913 variants)
914 }
915 };
916 visitor.visit_enum(s.into_deserializer())
917 }
918
919 fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value>
920 where V: de::Visitor<'de>
921 {
922 self.deserialize_str(visitor)
923 }
924
925 fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value>
926 where V: de::Visitor<'de>
927 {
928 self.deserialize_any(visitor)
929 }
930}
931
932struct SeqDeserializer<'a, 'de: 'a> {
933 de: &'a mut Deserializer<'de>,
934 len: usize,
935}
936
937impl<'a, 'de> SeqDeserializer<'a, 'de> {
938 fn new(de: &'a mut Deserializer<'de>, len: usize) -> Self {
939 SeqDeserializer { de: de, len: len }
940 }
941}
942
943impl<'a, 'de> de::SeqAccess<'de> for SeqDeserializer<'a, 'de> {
944 type Error = Error;
945
946 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
947 where T: de::DeserializeSeed<'de>
948 {
949 if self.len == 0 {
950 return Ok(None);
951 }
952 self.len -= 1;
953 seed.deserialize(&mut *self.de).map(Some)
954 }
955
956 fn size_hint(&self) -> Option<usize> {
957 return Some(self.len);
958 }
959}
960
961struct StructDeserializer<'a, 'de: 'a> {
962 de: &'a mut Deserializer<'de>,
963 fields: &'static [&'static str],
964}
965
966impl<'a, 'de> StructDeserializer<'a, 'de> {
967 fn new(de: &'a mut Deserializer<'de>, fields: &'static [&'static str]) -> Self {
968 StructDeserializer {
969 de: de,
970 fields: fields,
971 }
972 }
973}
974
975impl<'a, 'de> de::SeqAccess<'de> for StructDeserializer<'a, 'de> {
976 type Error = Error;
977
978 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
979 where T: de::DeserializeSeed<'de>
980 {
981 if self.fields.len() == 0 {
982 return Ok(None);
983 }
984 self.de.push(self.fields[0]);
985 self.fields = &self.fields[1..];
986 seed.deserialize(&mut *self.de).map(Some)
987 }
988
989 fn size_hint(&self) -> Option<usize> {
990 return Some(self.fields.len());
991 }
992}