Sunday, October 05, 2008

A first refactoring of A-star in Erlang

Clearly, in earlier (unpublished) attempts at using the case construct in Erlang, I had been making some very n00bish mistakes. Today, starting from a set of code that worked and could be modified step by step, I find that when used properly, case does make things terser -- fewer intermediate variables, and no `true` branch that really means "false" -- and it lets me pattern-match on expressions computed from the arguments, rather than only on the arguments themselves.

-module(torch).
-export([main/0]).

%% A* search from Start to Goal over the torch-crossing state graph.
%%
%% Returns the path as a list of nodes running from Goal back to Start (the
%% order reconstruct_path/2 produces), or the atom `failure` if the open set
%% runs dry first. Every dict maps a node to a singleton list, which is the
%% shape the rest of the search reads back with `[V] = dict:fetch(Node, D)`.
astar(Start,Goal) ->
  Fscore = dict:store(Start, [h_score(Start)], dict:new()),  % estimated total cost through the node
  Gscore = dict:store(Start, [0], dict:new()),               % cost from Start along the best known path
  CameFrom = dict:store(Start, [none], dict:new()),          % predecessor; `none` marks Start
  Openset = sets:from_list([Start]),                         % tentative nodes still to be evaluated
  Closedset = sets:new(),                                    % nodes already evaluated
  astar_step(Goal, Closedset, Openset, Fscore, Gscore, CameFrom).
  
%% One iteration of A*.
%%
%% Picks the open node with the lowest F score. If that node is Goal, the
%% search finishes and returns the path from Goal back to the start.
%% Otherwise the node is closed, its neighbours are relaxed and the search
%% recurses. Returns `failure` once the open set is empty.
astar_step(Goal, Closedset, Openset, Fscore, Gscore, CameFrom) ->
  case sets:size(Openset) of
    0 ->
      failure;
    _ ->
      Current = best_step(sets:to_list(Openset), Fscore, none, infinity),
      expand(Current, Goal, Closedset, Openset, Fscore, Gscore, CameFrom)
  end.

%% The chosen node is the goal: walk the predecessor links back to the start.
expand(Goal, Goal, _Closedset, _Openset, _Fscore, _Gscore, CameFrom) ->
  reconstruct_path(CameFrom, Goal);
%% Otherwise move Current from the open set to the closed set, relax its
%% neighbours and take the next step.
expand(Current, Goal, Closedset, Openset, Fscore, Gscore, CameFrom) ->
  Closed = sets:add_element(Current, Closedset),
  Open = sets:del_element(Current, Openset),
  {Open1, Fscore1, Gscore1, CameFrom1} =
    scan(Current, neighbour_nodes(Current), Open, Closed, Fscore, Gscore, CameFrom),
  astar_step(Goal, Closed, Open1, Fscore1, Gscore1, CameFrom1).
          
%% Relax every neighbour of X, the node that has just been moved to the closed set.
%%
%% Returns the updated {Open, F, G, From}:
%%  * neighbours already in Closed are skipped, so closed nodes are never reopened;
%%  * a neighbour not yet in Open is added to it, with X as its predecessor;
%%  * a neighbour already in Open is re-parented to X only when the route
%%    through X is strictly cheaper than its recorded G score.
%% The F, G and From dicts hold singleton-list values, as written by update/6.
scan(X, Neighbours, Open, Closed, F, G, From) ->
  lists:foldl(fun(Y, State) -> relax(X, Y, Closed, State) end,
              {Open, F, G, From},
              Neighbours).

%% Consider the single edge X -> Y against the accumulated search state.
relax(X, Y, Closed, {_Open, _F, G, _From} = State) ->
  case sets:is_element(Y, Closed) of
    true ->
      State;
    false ->
      [GX] = dict:fetch(X, G),
      offer(X, Y, GX + dist_between(X, Y), State)
  end.

%% Record TrialG as Y's cost via X, if Y is new to the open set or if TrialG
%% beats Y's current G score.
offer(X, Y, TrialG, {Open, F, G, From} = State) ->
  case sets:is_element(Y, Open) of
    false ->
      {NewF, NewG, NewFrom} = update(X, Y, F, G, From, TrialG),
      {sets:add_element(Y, Open), NewF, NewG, NewFrom};
    true ->
      [OldG] = dict:fetch(Y, G),
      case TrialG < OldG of
        true ->
          {NewF, NewG, NewFrom} = update(X, Y, F, G, From, TrialG),
          {Open, NewF, NewG, NewFrom};
        false ->
          State
      end
  end.
          
%% Record that the best known route to Y arrives from X with cost GValue.
%%
%% Returns {NewF, NewG, NewFrom}. Each value is stored as a singleton list,
%% the same shape dict:append/3 gives on a fresh key, because every reader in
%% this module fetches with `[V] = dict:fetch(...)`. dict:store/3 overwrites
%% any previous entry for Y, so the old erase-then-recurse loop is unnecessary.
update(X, Y, OldF, OldG, OldFrom, GValue) ->
  NewFrom = dict:store(Y, [X], OldFrom),
  NewG = dict:store(Y, [GValue], OldG),
  NewF = dict:store(Y, [GValue + h_score(Y)], OldF), % Estimated total distance from start to goal through Y.
  {NewF, NewG, NewFrom}.
     
%% Follow predecessor links from Node until reaching the node whose
%% predecessor is `none` (the start).
%% Returns the nodes in order from Node back to the start.
reconstruct_path(CameFrom, Node) ->
  walk_back(CameFrom, Node, []).

%% Tail-recursive worker: Trail collects the visited nodes, newest first.
walk_back(CameFrom, Node, Trail) ->
  case dict:fetch(Node, CameFrom) of
    [none] ->
      lists:reverse([Node | Trail]);
    [Parent] ->
      walk_back(CameFrom, Parent, [Node | Trail])
  end.

%% Return the node in the first argument with the lowest value in Score (a
%% dict of singleton lists).
%%
%% Best and BestValue seed the search. A Best of `none` means "no candidate
%% chosen yet": BestValue is then ignored and the next node examined becomes
%% the current choice. Only a strictly lower value replaces the current
%% choice, so ties keep the earlier node. An empty list returns Best unchanged.
best_step(Candidates, Score, Best, BestValue) ->
  Challenge =
    fun(Node, {none, _}) ->
          [NodeValue] = dict:fetch(Node, Score),
          {Node, NodeValue};
       (Node, {_, ChampionValue} = Champion) ->
          [NodeValue] = dict:fetch(Node, Score),
          case NodeValue < ChampionValue of
            true -> {Node, NodeValue};
            false -> Champion
          end
    end,
  {Winner, _} = lists:foldl(Challenge, {Best, BestValue}, Candidates),
  Winner.

%% Specialize for the torch-carrier problem.
%% Bits 0-4 of a node represent the torch and the 1m, 2m, 5m and 10m walkers.
%% A set bit means "on the near side".

%% Every possible crossing of one or two walkers, each taking the torch with
%% them: the four solo crossings first, then the six pairs.
swaps() ->
  Torch = 1,
  Walkers = [2, 4, 8, 16],
  Solos = [Torch bor W || W <- Walkers],
  Pairs = [Torch bor Low bor High || High <- Walkers, Low <- Walkers, Low < High],
  Solos ++ Pairs.

%% Minutes taken by a crossing, which is the time of the slowest walker in Swap.
%% The highest walker bit set decides: 10m (bit 4), 5m (bit 3), 2m (bit 2).
%% Anything else, including the 1m walker alone, takes 1 minute.
crossing_time(Swap) when Swap band 16 > 0 -> 10;
crossing_time(Swap) when Swap band 8 > 0 -> 5;
crossing_time(Swap) when Swap band 4 > 0 -> 2;
crossing_time(_Swap) -> 1.

%% presentation form
%% Render a crossing as its walkers' times, slowest first, e.g. 7 -> "2+1".
%% The torch bit contributes nothing, so 0 or 1 renders as "".
display(Swap) when Swap band 16 > 0 -> compound(Swap, 16);
display(Swap) when Swap band 8 > 0 -> compound(Swap, 8);
display(Swap) when Swap band 4 > 0 -> compound(Swap, 4);
display(Swap) when Swap band 2 > 0 -> compound(Swap, 2);
display(_Swap) -> "".
      
%% Render the time of the slowest walker in Swap, then the rest of the
%% crossing (Swap with Bit cleared), joined with "+" when anything remains.
compound(Swap, Bit) ->
  Head = integer_to_list(crossing_time(Swap)),
  Rest = display(Swap bxor Bit),
  Head ++ decorate(Rest).
  
%% Prefix a non-empty rendering with "+" so it can follow another term.
%% An empty string is returned unchanged.
decorate([]) -> [];
decorate([_ | _] = Value) -> [$+ | Value].
  
%% All nodes reachable from X in one crossing.
%% Applies every swap whose bits all lie on the torch's side of X.
neighbour_nodes(X) ->
  OnTorchSide = equivalent_point(X),
  compatible(X, OnTorchSide, [], swaps()).
  
%% Bit mask of everything on the same side as the torch.
%% With the torch on the near side (bit 0 set), that is X itself. Otherwise
%% all five bits are flipped, so the far-side walkers and the torch show up
%% as set bits.
equivalent_point(X) ->
  TorchBit = X band 1,
  if
    TorchBit =:= 1 -> X;
    true -> 31 bxor X
  end.
  
%% Apply to X each swap whose bits are all set in the mask Y, and prepend the
%% resulting node onto Outlist. Y is the torch-side mask from equivalent_point/1.
%% Results appear in reverse order of the swap list.
compatible(X, Y, Outlist, Swaps) ->
  Prepend =
    fun(Swap, Acc) ->
        case Y band Swap =:= Swap of
          true -> [X bxor Swap | Acc];
          false -> Acc
        end
    end,
  lists:foldl(Prepend, Outlist, Swaps).

%% Cost of the edge X -> Y: the crossing is exactly the set of bits that
%% differ between the two nodes, and it takes as long as its slowest walker.
dist_between(X, Y) ->
  crossing_time(X bxor Y).
  
%% Heuristic for A*: a lower bound on the minutes still needed from Node.
%%
%% While any walker remains on the near side (bits 1-4 of Node), at least the
%% slowest of them must still cross, so crossing_time(Node) does not
%% overestimate. Once no walker bits are set there is nothing left to move,
%% so the estimate is 0. crossing_time/1 would report 1 there, which made the
%% heuristic overestimate at the goal (A* expects h(goal) = 0).
h_score(Node) when Node band 30 =:= 0 ->
  0;
h_score(Node) ->
  crossing_time(Node).
  
%% Solve the puzzle from state 31 (every bit set, i.e. everything on the near
%% side) to state 0, print each crossing, then print the total time.
main() ->
  Path = lists:reverse(astar(31, 0)),
  [Origin | Moves] = Path,
  Minutes = print_result(Moves, Origin, 0),
  io:format("Time taken = ~B minutes~n", [Minutes]).

%% Print the crossing leading into each node of Trace, in order, starting from
%% Prev. Returns Time plus the minutes all those crossings take.
print_result(Trace, Prev, Time) ->
  Step =
    fun(Node, {Before, Elapsed}) ->
        Swap = Node bxor Before,
        print_swap(Swap, Node band 1),
        {Node, Elapsed + crossing_time(Swap)}
    end,
  {_Last, Total} = lists:foldl(Step, {Prev, Time}, Trace),
  Total.

%% Print one crossing. Side is the torch bit of the node after the crossing:
%% 0 (torch now on the far side) prints an outbound arrow, and 1 prints a
%% return arrow.
print_swap(Swap, 0) ->
  io:format("    ~s -->~n", [display(Swap)]);
print_swap(Swap, 1) ->
  io:format("<-- ~s~n", [display(Swap)]).

Using if for crossing_time and display gives a natural if / else-if cascade (the most "traditional" analogue of what Erlang's if does). Most of the time, switching to case has made the code terser (if nothing else, it removes the need for most temporaries in simple two-way choices -- best_step is the exception here, where the temporary looks to be a necessity); for update, it was about neutral. Maybe there is a trick I am missing, but anything that goes through a binary (unpacking the node into a bitstring and pattern-matching on the individual bits) looks like it would make crossing_time and display more cumbersome, rather than less.

Post a Comment