2022 day 16 refactor no 1

This commit is contained in:
Maciej Jur 2022-12-16 21:20:53 +01:00
parent 46351760bf
commit a7ee9464d8

View file

@ -2,8 +2,6 @@ use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet}; use std::collections::{BinaryHeap, HashMap, HashSet};
use crate::utils; use crate::utils;
use rayon::prelude::*;
pub fn run() -> () { pub fn run() -> () {
let lines = utils::read_lines(utils::Source::Day(16)); let lines = utils::read_lines(utils::Source::Day(16));
@ -15,57 +13,57 @@ pub fn run() -> () {
} }
struct Valve<'data> { struct Valve {
name: &'data str, name: usize,
rate: u32, rate: u32,
next: Vec<&'data str> next: Vec<usize>,
} }
#[derive(Eq, PartialEq)] #[derive(Eq, PartialEq)]
struct State<'data> { struct State {
name: &'data str, name: usize,
cost: u32, cost: u32,
} }
impl Ord for State<'_> { impl Ord for State {
fn cmp(&self, other: &Self) -> Ordering { fn cmp(&self, other: &Self) -> Ordering {
other.cost.cmp(&self.cost) other.cost.cmp(&self.cost)
} }
} }
impl PartialOrd<Self> for State<'_> { impl PartialOrd<Self> for State {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> { fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other)) Some(self.cmp(other))
} }
} }
fn closed_valves<'data>(data: &'data [Valve<'data>]) -> HashSet<&'data str> { fn closed_valves(data: &[Valve]) -> HashSet<usize> {
data.iter() data.iter()
.filter(|valve| valve.rate != 0) .filter(|valve| valve.rate != 0)
.map(|valve| valve.name) .map(|valve| valve.name)
.collect() .collect()
} }
fn build_map<'data, 'a>(valves: &'a [Valve<'data>]) -> HashMap<&'data str, &'a Valve<'data>> { fn build_map(valves: &[Valve]) -> HashMap<usize, &Valve> {
valves.iter().map(|valve| (valve.name, valve)).collect() valves.iter().map(|valve| (valve.name, valve)).collect()
} }
fn release_pressure<'data, 'a>(map: &'a HashMap<&'data str, &'a Valve<'data>>, keys: &'a HashSet<&'data str>) -> u32 { fn release_pressure(map: &HashMap<usize, &Valve>, keys: &HashSet<usize>) -> u32 {
keys.iter() keys.iter()
.map(|&key| map.get(key).unwrap().rate) .map(|&key| map.get(&key).unwrap().rate)
.sum() .sum()
} }
fn find_distance<'data, 'a>(map: &'a HashMap<&'data str, &'a Valve<'data>>, start: &'data str, goal: &'data str, ) -> u32 { fn find_distance(map: &HashMap<usize, &Valve>, start: usize, goal: usize) -> u32 {
let mut frontier: BinaryHeap<State> = BinaryHeap::new(); let mut frontier: BinaryHeap<State> = BinaryHeap::new();
let mut costs: HashMap<&'data str, u32> = HashMap::from([(start, 0)]); let mut costs: HashMap<usize, u32> = HashMap::from([(start, 0)]);
frontier.push(State { name: start, cost: 0 }); frontier.push(State { name: start, cost: 0 });
while let Some(State { name: current, .. }) = frontier.pop() { while let Some(State { name: current, .. }) = frontier.pop() {
if current == goal { break }; if current == goal { break };
for &neighbour in map.get(current).unwrap().next.iter() { for &neighbour in map.get(&current).unwrap().next.iter() {
let cost = costs.get(&current).unwrap() + 1; let cost = costs.get(&current).unwrap() + 1;
if !costs.contains_key(&neighbour) || cost < *costs.get(&neighbour).unwrap() { if !costs.contains_key(&neighbour) || cost < *costs.get(&neighbour).unwrap() {
@ -74,10 +72,10 @@ fn find_distance<'data, 'a>(map: &'a HashMap<&'data str, &'a Valve<'data>>, star
} }
} }
} }
*costs.get(goal).unwrap() *costs.get(&goal).unwrap()
} }
fn find_distances<'data, 'a>(map: &'a HashMap<&'data str, &'a Valve<'data>>, data: &'data [Valve]) -> HashMap<(&'data str, &'data str), u32> { fn find_distances(map: &HashMap<usize, &Valve>, data: &[Valve]) -> HashMap<(usize, usize), u32> {
data.iter() data.iter()
.flat_map(|start| .flat_map(|start|
data.iter().map(|goal| ((start.name, goal.name), find_distance(&map, start.name, goal.name))) data.iter().map(|goal| ((start.name, goal.name), find_distance(&map, start.name, goal.name)))
@ -85,18 +83,18 @@ fn find_distances<'data, 'a>(map: &'a HashMap<&'data str, &'a Valve<'data>>, dat
.collect() .collect()
} }
struct MoveState<'data> { struct MoveState {
curr: &'data str, curr: usize,
next: &'data str, next: usize,
time_left: u32, time_left: u32,
released: u32, released: u32,
} }
fn move_to_open<'data, 'a>( fn move_to_open(
map: &'a HashMap<&'data str, &'a Valve<'data>>, map: &HashMap<usize, &Valve>,
distances: &'a HashMap<(&'data str, &'data str), u32>, distances: &HashMap<(usize, usize), u32>,
closed: &'a HashSet<&'data str>, closed: &HashSet<usize>,
opened: &'a HashSet<&'data str>, opened: &HashSet<usize>,
state: MoveState, state: MoveState,
) -> u32 { ) -> u32 {
let distance = state.time_left.min(*distances.get(&(state.curr, state.next)).unwrap()); let distance = state.time_left.min(*distances.get(&(state.curr, state.next)).unwrap());
@ -105,43 +103,44 @@ fn move_to_open<'data, 'a>(
let curr = state.next; let curr = state.next;
let released = released + release_pressure(map, opened); let released = released + release_pressure(map, opened);
let closed = { let mut closed = closed.clone(); closed.remove(curr); closed }; let closed = { let mut closed = closed.clone(); closed.remove(&curr); closed };
let opened = { let mut opened = opened.clone(); opened.insert(curr); opened }; let opened = { let mut opened = opened.clone(); opened.insert(curr); opened };
let time_left = state.time_left - distance - 1; let time_left = state.time_left - distance - 1;
closed.par_iter() closed.iter()
.map(|&next| move_to_open(map, distances, &closed, &opened, MoveState { curr, next, time_left, released })) .map(|&next| move_to_open(map, distances, &closed, &opened, MoveState { curr, next, time_left, released }))
.max() .max()
.unwrap_or_else(|| released + release_pressure(map, &opened) * time_left) .unwrap_or_else(|| released + release_pressure(map, &opened) * time_left)
} }
fn find_max_for_start(data: &[Valve], start: &str, limit: u32) -> u32 { fn find_max_for_start(data: &[Valve], start: usize, limit: u32) -> u32 {
let map = build_map(data); let map = build_map(data);
let start_state = MoveState { curr: start, next: start, time_left: limit + 1, released: 0 }; let start_state = MoveState { curr: start, next: start, time_left: limit + 1, released: 0 };
move_to_open(&map, &find_distances(&map, data), &closed_valves(data), &HashSet::new(), start_state) move_to_open(&map, &find_distances(&map, data), &closed_valves(data), &HashSet::new(), start_state)
} }
fn solve1(data: &[Valve]) -> u32 { fn solve1((map, data): &(HashMap<&str, usize>, Vec<Valve>)) -> u32 {
find_max_for_start(data, "AA", 30) let start = *map.get("AA").unwrap();
find_max_for_start(data, start, 30)
} }
// Pray this doesn't blow the stack // Pray this doesn't blow the stack
struct ParallelMoveState<'data> { struct ParallelMoveState {
p_curr: &'data str, p_curr: usize,
p_next: &'data str, p_next: usize,
p_progress: u32, p_progress: u32,
e_curr: &'data str, e_curr: usize,
e_next: &'data str, e_next: usize,
e_progress: u32, e_progress: u32,
time_left: u32, time_left: u32,
released: u32, released: u32,
} }
fn parallel_to_open<'data, 'a>( fn parallel_to_open(
map: &'a HashMap<&'data str, &'a Valve<'data>>, map: &HashMap<usize, &Valve>,
distances: &'a HashMap<(&'data str, &'data str), u32>, distances: &HashMap<(usize, usize), u32>,
closed: &'a HashSet<&'data str>, closed: &HashSet<usize>,
opened: &'a HashSet<&'data str>, opened: &HashSet<usize>,
state: ParallelMoveState, state: ParallelMoveState,
) -> u32 { ) -> u32 {
let p_distance_left = *distances.get(&(state.p_curr, state.p_next)).unwrap() - state.p_progress; let p_distance_left = *distances.get(&(state.p_curr, state.p_next)).unwrap() - state.p_progress;
@ -155,25 +154,25 @@ fn parallel_to_open<'data, 'a>(
let (closed, opened) = { let (closed, opened) = {
let mut closed = closed.clone(); let mut closed = closed.clone();
let mut opened = opened.clone(); let mut opened = opened.clone();
if distance == p_distance_left { closed.remove(state.p_next); opened.insert(state.p_next); }; if distance == p_distance_left { closed.remove(&state.p_next); opened.insert(state.p_next); };
if distance == e_distance_left { closed.remove(state.e_next); opened.insert(state.e_next); }; if distance == e_distance_left { closed.remove(&state.e_next); opened.insert(state.e_next); };
(closed, opened) (closed, opened)
}; };
let time_left = state.time_left - distance - 1; let time_left = state.time_left - distance - 1;
match (distance == p_distance_left, distance == e_distance_left) { match (distance == p_distance_left, distance == e_distance_left) {
(true, true) => closed.par_iter() (true, true) => closed.iter()
.flat_map(|&p_next| closed.par_iter() .flat_map(|p_next| closed.iter()
.filter(move |&&e_next| e_next != p_next) .filter(move |&e_next| e_next != p_next)
.map(|&e_next| parallel_to_open(map, distances, &closed, &opened, ParallelMoveState { .map(|e_next| parallel_to_open(map, distances, &closed, &opened, ParallelMoveState {
p_curr: state.p_next, p_next, p_progress: 0, p_curr: state.p_next, p_next: *p_next, p_progress: 0,
e_curr: state.e_next, e_next, e_progress: 0, e_curr: state.e_next, e_next: *e_next, e_progress: 0,
time_left, released, time_left, released,
})) }))
) )
.max() .max()
.unwrap_or_else(|| released + release_pressure(map, &opened) * time_left), .unwrap_or_else(|| released + release_pressure(map, &opened) * time_left),
(true, false) => closed.par_iter() (true, false) => closed.iter()
.map(|&next| parallel_to_open(map, distances, &closed, &opened, ParallelMoveState { .map(|&next| parallel_to_open(map, distances, &closed, &opened, ParallelMoveState {
p_curr: state.p_next, p_next: next, p_progress: 0, p_curr: state.p_next, p_next: next, p_progress: 0,
e_curr: state.e_curr, e_next: state.e_next, e_progress: state.e_progress + distance + 1, e_curr: state.e_curr, e_next: state.e_next, e_progress: state.e_progress + distance + 1,
@ -181,7 +180,7 @@ fn parallel_to_open<'data, 'a>(
})) }))
.max() .max()
.unwrap_or_else(|| released + release_pressure(map, &opened) * time_left), .unwrap_or_else(|| released + release_pressure(map, &opened) * time_left),
(false, true) => closed.par_iter() (false, true) => closed.iter()
.map(|&next| parallel_to_open(map, distances, &closed, &opened, ParallelMoveState { .map(|&next| parallel_to_open(map, distances, &closed, &opened, ParallelMoveState {
p_curr: state.p_curr, p_next: state.p_next, p_progress: state.p_progress + distance + 1, p_curr: state.p_curr, p_next: state.p_next, p_progress: state.p_progress + distance + 1,
e_curr: state.e_next, e_next: next, e_progress: 0, e_curr: state.e_next, e_next: next, e_progress: 0,
@ -193,7 +192,7 @@ fn parallel_to_open<'data, 'a>(
} }
} }
fn parallel_max_for_start(data: &[Valve], start: &str, limit: u32) -> u32 { fn parallel_max_for_start(data: &[Valve], start: usize, limit: u32) -> u32 {
let map = build_map(data); let map = build_map(data);
let start_state = ParallelMoveState { let start_state = ParallelMoveState {
p_curr: start, p_next: start, p_progress: 0, p_curr: start, p_next: start, p_progress: 0,
@ -203,17 +202,19 @@ fn parallel_max_for_start(data: &[Valve], start: &str, limit: u32) -> u32 {
parallel_to_open(&map, &find_distances(&map, data), &closed_valves(data), &HashSet::new(), start_state) parallel_to_open(&map, &find_distances(&map, data), &closed_valves(data), &HashSet::new(), start_state)
} }
fn solve2(data: &[Valve]) -> u32 { fn solve2((map, data): &(HashMap<&str, usize>, Vec<Valve>)) -> u32 {
parallel_max_for_start(data, "AA", 26) let start = *map.get("AA").unwrap();
parallel_max_for_start(data, start, 26)
} }
fn parse_data<T: AsRef<str>>(data: &[T]) -> Vec<Valve> { fn parse_data<T: AsRef<str>>(data: &[T]) -> (HashMap<&str, usize>, Vec<Valve>) {
data.iter() let valves = data.iter()
.map(|line| { .enumerate()
.map(|(index, line)| {
let mut line = line.as_ref().split(" "); let mut line = line.as_ref().split(" ");
let name = line.nth(1).unwrap(); let name = line.nth(1).unwrap();
let rate = line.nth(2).unwrap() let rate: u32 = line.nth(2).unwrap()
.rsplit("=") .rsplit("=")
.next().unwrap() .next().unwrap()
.split(";") .split(";")
@ -222,9 +223,16 @@ fn parse_data<T: AsRef<str>>(data: &[T]) -> Vec<Valve> {
let next = line.skip(4) let next = line.skip(4)
.map(|str| str.split(",").next().unwrap()) .map(|str| str.split(",").next().unwrap())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
Valve { name, rate, next } (index, name, rate, next)
}) })
.collect() .collect::<Vec<_>>();
let map = valves.iter().map(|&(index, name, _, _)| (name, index)).collect::<HashMap<_, _>>();
let valves = valves.into_iter()
.map(|(name, _, rate, next)| Valve {
name, rate, next: next.iter().map(|&next| *map.get(next).unwrap()).collect(),
})
.collect();
(map, valves)
} }