diff options
-rw-r--r-- | 2017/src/bin/day07.rs | 69 |
1 files changed, 65 insertions, 4 deletions
diff --git a/2017/src/bin/day07.rs b/2017/src/bin/day07.rs index 9d2d75a..f579cbc 100644 --- a/2017/src/bin/day07.rs +++ b/2017/src/bin/day07.rs @@ -9,7 +9,42 @@ struct Program { disc: Vec<Rc<RefCell<Program>>>, } -fn solve1(input: &str) -> String { +impl Program { + fn total_weight(&self) -> u32 { + self.weight + self.disc.iter().map(|p| p.borrow().total_weight()).sum::<u32>() + } + + fn balance(&self) -> Option<u32> { + for child in &self.disc { + match child.borrow().balance() { + Some(x) => return Some(x), + None => (), + } + } + if self.disc.is_empty() { + return None; + } + + let mut disc = self.disc.clone(); + disc.sort_by_key(|p| p.borrow().total_weight()); + let max = disc.iter().map(|p| p.borrow().total_weight()).max().unwrap(); + let min = disc.iter().map(|p| p.borrow().total_weight()).min().unwrap(); + let avg = disc.iter().map(|p| p.borrow().total_weight()).sum::<u32>() / disc.len() as u32; + if min == max { + return None; + } else if avg - min < max - avg { + let child = disc.last().unwrap().borrow(); + let diff = child.total_weight() - min; + return Some(child.weight - diff); + } else { + let child = disc.first().unwrap().borrow(); + let diff = max - child.total_weight(); + return Some(child.weight + diff); + } + } +} + +fn solve1(input: &str) -> (String, Rc<RefCell<Program>>) { let mut programs: HashMap<String, Rc<RefCell<Program>>> = HashMap::new(); for line in input.lines() { let mut words = line.split_whitespace(); @@ -43,18 +78,23 @@ fn solve1(input: &str) -> String { program.weight = weight; program.disc = disc; } - programs.into_iter() .find(|&(_, ref p)| Rc::strong_count(p) == 1) .unwrap() - .0 +} + +fn solve2(input: &str) -> u32 { + let root = solve1(input).1; + let balance = root.borrow().balance().unwrap(); + balance } fn main() { let mut input = String::new(); io::stdin().read_to_string(&mut input).unwrap(); - println!("Part 1: {}", solve1(&input)); + println!("Part 1: {}", solve1(&input).0); + println!("Part 2: {}", solve2(&input)); } #[test] @@ -75,5 +115,26 @@ ugml (68) -> gyxo, ebii, jptl gyxo (61) cntj (57) " + ).0); +} + +#[test] +fn part2() { + assert_eq!(60, solve2( +"\ +pbga (66) +xhth (57) +ebii (61) +havc (66) +ktlj (57) +fwft (72) -> ktlj, cntj, xhth +qoyq (66) +padx (45) -> pbga, havc, qoyq +tknk (41) -> ugml, padx, fwft +jptl (61) +ugml (68) -> gyxo, ebii, jptl +gyxo (61) +cntj (57) +" )); } |