2022 day 16 refactor no 1

This commit is contained in:
Maciej Jur 2022-12-16 21:20:53 +01:00
parent 46351760bf
commit a7ee9464d8

View file

@ -2,8 +2,6 @@ use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};
use crate::utils;
use rayon::prelude::*;
pub fn run() -> () {
let lines = utils::read_lines(utils::Source::Day(16));
@ -15,57 +13,57 @@ pub fn run() -> () {
}
struct Valve<'data> {
name: &'data str,
struct Valve {
name: usize,
rate: u32,
next: Vec<&'data str>
next: Vec<usize>,
}
#[derive(Eq, PartialEq)]
struct State<'data> {
name: &'data str,
struct State {
name: usize,
cost: u32,
}
impl Ord for State<'_> {
impl Ord for State {
fn cmp(&self, other: &Self) -> Ordering {
other.cost.cmp(&self.cost)
}
}
impl PartialOrd<Self> for State<'_> {
impl PartialOrd<Self> for State {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
fn closed_valves<'data>(data: &'data [Valve<'data>]) -> HashSet<&'data str> {
fn closed_valves(data: &[Valve]) -> HashSet<usize> {
data.iter()
.filter(|valve| valve.rate != 0)
.map(|valve| valve.name)
.collect()
}
fn build_map<'data, 'a>(valves: &'a [Valve<'data>]) -> HashMap<&'data str, &'a Valve<'data>> {
fn build_map(valves: &[Valve]) -> HashMap<usize, &Valve> {
valves.iter().map(|valve| (valve.name, valve)).collect()
}
fn release_pressure<'data, 'a>(map: &'a HashMap<&'data str, &'a Valve<'data>>, keys: &'a HashSet<&'data str>) -> u32 {
fn release_pressure(map: &HashMap<usize, &Valve>, keys: &HashSet<usize>) -> u32 {
keys.iter()
.map(|&key| map.get(key).unwrap().rate)
.map(|&key| map.get(&key).unwrap().rate)
.sum()
}
fn find_distance<'data, 'a>(map: &'a HashMap<&'data str, &'a Valve<'data>>, start: &'data str, goal: &'data str, ) -> u32 {
fn find_distance(map: &HashMap<usize, &Valve>, start: usize, goal: usize) -> u32 {
let mut frontier: BinaryHeap<State> = BinaryHeap::new();
let mut costs: HashMap<&'data str, u32> = HashMap::from([(start, 0)]);
let mut costs: HashMap<usize, u32> = HashMap::from([(start, 0)]);
frontier.push(State { name: start, cost: 0 });
while let Some(State { name: current, .. }) = frontier.pop() {
if current == goal { break };
for &neighbour in map.get(current).unwrap().next.iter() {
for &neighbour in map.get(&current).unwrap().next.iter() {
let cost = costs.get(&current).unwrap() + 1;
if !costs.contains_key(&neighbour) || cost < *costs.get(&neighbour).unwrap() {
@ -74,10 +72,10 @@ fn find_distance<'data, 'a>(map: &'a HashMap<&'data str, &'a Valve<'data>>, star
}
}
}
*costs.get(goal).unwrap()
*costs.get(&goal).unwrap()
}
fn find_distances<'data, 'a>(map: &'a HashMap<&'data str, &'a Valve<'data>>, data: &'data [Valve]) -> HashMap<(&'data str, &'data str), u32> {
fn find_distances(map: &HashMap<usize, &Valve>, data: &[Valve]) -> HashMap<(usize, usize), u32> {
data.iter()
.flat_map(|start|
data.iter().map(|goal| ((start.name, goal.name), find_distance(&map, start.name, goal.name)))
@ -85,18 +83,18 @@ fn find_distances<'data, 'a>(map: &'a HashMap<&'data str, &'a Valve<'data>>, dat
.collect()
}
struct MoveState<'data> {
curr: &'data str,
next: &'data str,
struct MoveState {
curr: usize,
next: usize,
time_left: u32,
released: u32,
}
fn move_to_open<'data, 'a>(
map: &'a HashMap<&'data str, &'a Valve<'data>>,
distances: &'a HashMap<(&'data str, &'data str), u32>,
closed: &'a HashSet<&'data str>,
opened: &'a HashSet<&'data str>,
fn move_to_open(
map: &HashMap<usize, &Valve>,
distances: &HashMap<(usize, usize), u32>,
closed: &HashSet<usize>,
opened: &HashSet<usize>,
state: MoveState,
) -> u32 {
let distance = state.time_left.min(*distances.get(&(state.curr, state.next)).unwrap());
@ -105,43 +103,44 @@ fn move_to_open<'data, 'a>(
let curr = state.next;
let released = released + release_pressure(map, opened);
let closed = { let mut closed = closed.clone(); closed.remove(curr); closed };
let closed = { let mut closed = closed.clone(); closed.remove(&curr); closed };
let opened = { let mut opened = opened.clone(); opened.insert(curr); opened };
let time_left = state.time_left - distance - 1;
closed.par_iter()
closed.iter()
.map(|&next| move_to_open(map, distances, &closed, &opened, MoveState { curr, next, time_left, released }))
.max()
.unwrap_or_else(|| released + release_pressure(map, &opened) * time_left)
}
fn find_max_for_start(data: &[Valve], start: &str, limit: u32) -> u32 {
fn find_max_for_start(data: &[Valve], start: usize, limit: u32) -> u32 {
let map = build_map(data);
let start_state = MoveState { curr: start, next: start, time_left: limit + 1, released: 0 };
move_to_open(&map, &find_distances(&map, data), &closed_valves(data), &HashSet::new(), start_state)
}
fn solve1(data: &[Valve]) -> u32 {
find_max_for_start(data, "AA", 30)
fn solve1((map, data): &(HashMap<&str, usize>, Vec<Valve>)) -> u32 {
let start = *map.get("AA").unwrap();
find_max_for_start(data, start, 30)
}
// Pray this doesn't blow the stack
struct ParallelMoveState<'data> {
p_curr: &'data str,
p_next: &'data str,
struct ParallelMoveState {
p_curr: usize,
p_next: usize,
p_progress: u32,
e_curr: &'data str,
e_next: &'data str,
e_curr: usize,
e_next: usize,
e_progress: u32,
time_left: u32,
released: u32,
}
fn parallel_to_open<'data, 'a>(
map: &'a HashMap<&'data str, &'a Valve<'data>>,
distances: &'a HashMap<(&'data str, &'data str), u32>,
closed: &'a HashSet<&'data str>,
opened: &'a HashSet<&'data str>,
fn parallel_to_open(
map: &HashMap<usize, &Valve>,
distances: &HashMap<(usize, usize), u32>,
closed: &HashSet<usize>,
opened: &HashSet<usize>,
state: ParallelMoveState,
) -> u32 {
let p_distance_left = *distances.get(&(state.p_curr, state.p_next)).unwrap() - state.p_progress;
@ -155,25 +154,25 @@ fn parallel_to_open<'data, 'a>(
let (closed, opened) = {
let mut closed = closed.clone();
let mut opened = opened.clone();
if distance == p_distance_left { closed.remove(state.p_next); opened.insert(state.p_next); };
if distance == e_distance_left { closed.remove(state.e_next); opened.insert(state.e_next); };
if distance == p_distance_left { closed.remove(&state.p_next); opened.insert(state.p_next); };
if distance == e_distance_left { closed.remove(&state.e_next); opened.insert(state.e_next); };
(closed, opened)
};
let time_left = state.time_left - distance - 1;
match (distance == p_distance_left, distance == e_distance_left) {
(true, true) => closed.par_iter()
.flat_map(|&p_next| closed.par_iter()
.filter(move |&&e_next| e_next != p_next)
.map(|&e_next| parallel_to_open(map, distances, &closed, &opened, ParallelMoveState {
p_curr: state.p_next, p_next, p_progress: 0,
e_curr: state.e_next, e_next, e_progress: 0,
(true, true) => closed.iter()
.flat_map(|p_next| closed.iter()
.filter(move |&e_next| e_next != p_next)
.map(|e_next| parallel_to_open(map, distances, &closed, &opened, ParallelMoveState {
p_curr: state.p_next, p_next: *p_next, p_progress: 0,
e_curr: state.e_next, e_next: *e_next, e_progress: 0,
time_left, released,
}))
)
.max()
.unwrap_or_else(|| released + release_pressure(map, &opened) * time_left),
(true, false) => closed.par_iter()
(true, false) => closed.iter()
.map(|&next| parallel_to_open(map, distances, &closed, &opened, ParallelMoveState {
p_curr: state.p_next, p_next: next, p_progress: 0,
e_curr: state.e_curr, e_next: state.e_next, e_progress: state.e_progress + distance + 1,
@ -181,7 +180,7 @@ fn parallel_to_open<'data, 'a>(
}))
.max()
.unwrap_or_else(|| released + release_pressure(map, &opened) * time_left),
(false, true) => closed.par_iter()
(false, true) => closed.iter()
.map(|&next| parallel_to_open(map, distances, &closed, &opened, ParallelMoveState {
p_curr: state.p_curr, p_next: state.p_next, p_progress: state.p_progress + distance + 1,
e_curr: state.e_next, e_next: next, e_progress: 0,
@ -193,7 +192,7 @@ fn parallel_to_open<'data, 'a>(
}
}
fn parallel_max_for_start(data: &[Valve], start: &str, limit: u32) -> u32 {
fn parallel_max_for_start(data: &[Valve], start: usize, limit: u32) -> u32 {
let map = build_map(data);
let start_state = ParallelMoveState {
p_curr: start, p_next: start, p_progress: 0,
@ -203,17 +202,19 @@ fn parallel_max_for_start(data: &[Valve], start: &str, limit: u32) -> u32 {
parallel_to_open(&map, &find_distances(&map, data), &closed_valves(data), &HashSet::new(), start_state)
}
fn solve2(data: &[Valve]) -> u32 {
parallel_max_for_start(data, "AA", 26)
fn solve2((map, data): &(HashMap<&str, usize>, Vec<Valve>)) -> u32 {
let start = *map.get("AA").unwrap();
parallel_max_for_start(data, start, 26)
}
fn parse_data<T: AsRef<str>>(data: &[T]) -> Vec<Valve> {
data.iter()
.map(|line| {
fn parse_data<T: AsRef<str>>(data: &[T]) -> (HashMap<&str, usize>, Vec<Valve>) {
let valves = data.iter()
.enumerate()
.map(|(index, line)| {
let mut line = line.as_ref().split(" ");
let name = line.nth(1).unwrap();
let rate = line.nth(2).unwrap()
let rate: u32 = line.nth(2).unwrap()
.rsplit("=")
.next().unwrap()
.split(";")
@ -222,9 +223,16 @@ fn parse_data<T: AsRef<str>>(data: &[T]) -> Vec<Valve> {
let next = line.skip(4)
.map(|str| str.split(",").next().unwrap())
.collect::<Vec<_>>();
Valve { name, rate, next }
(index, name, rate, next)
})
.collect()
.collect::<Vec<_>>();
let map = valves.iter().map(|&(index, name, _, _)| (name, index)).collect::<HashMap<_, _>>();
let valves = valves.into_iter()
.map(|(name, _, rate, next)| Valve {
name, rate, next: next.iter().map(|&next| *map.get(next).unwrap()).collect(),
})
.collect();
(map, valves)
}