Sunday, October 05, 2008

A first refactoring of A-star in Erlang

Clearly in earlier (unpublished) attempts at using the case construct in Erlang, I had been making some very n00bish mistakes. Today with a set of code that worked, and could be modified step-wise, I find that when used properly, it does make things terser -- fewer intermediate variables, and without the true branch meaning "false" -- and then lets me pattern match on expressions that are functions of arguments, rather than only the arguments themselves.

-module(torch).
-export([main/0]).

%% Run an A* search from Start to Goal.
%% Seeds the open set and the F-score, G-score and came-from tables with Start,
%% then hands over to astar_step/6 and returns whatever that returns.
%% Every table value is stored as a one-element list.
astar(Start,Goal) ->
  Predecessors = dict:store(Start, [none], dict:new()), % Start has no predecessor
  CostSoFar = dict:store(Start, [0], dict:new()), % distance from start along the best known path
  Estimates = dict:store(Start, [h_score(Start)], dict:new()),
  Frontier = sets:from_list([Start]), % tentative nodes still to be evaluated
  Evaluated = sets:new(), % nodes already evaluated
  astar_step(Goal, Evaluated, Frontier, Estimates, CostSoFar, Predecessors).
  
%% One iteration of the search loop.
%% Returns the atom failure when the open set is empty. Otherwise the open node
%% with the lowest F score is chosen via best_step/4: if it is Goal, the path
%% built by reconstruct_path/2 is returned; if not, the node moves from the open
%% set to the closed set, its neighbours are scanned, and the loop recurses.
astar_step(Goal, Closedset, Openset, Fscore, Gscore, CameFrom) ->
  case sets:to_list(Openset) of
    [] ->
      failure;
    Candidates ->
      case best_step(Candidates, Fscore, none, infinity) of
        Goal ->
          reconstruct_path(CameFrom, Goal);
        Current ->
          Closed = sets:add_element(Current, Closedset),
          StillOpen = sets:del_element(Current, Openset),
          {Open1, F1, G1, From1} =
            scan(Current, neighbour_nodes(Current), StillOpen, Closed, Fscore, Gscore, CameFrom),
          astar_step(Goal, Closed, Open1, F1, G1, From1)
      end
  end.
          
%% Relax every neighbour of X in list order.
%% Returns {Open, F, G, From} after all neighbours have been considered.
scan(X, Neighbours, Open, Closed, F, G, From) ->
  Visit = fun(Y, State) -> scan_neighbour(X, Y, Closed, State) end,
  lists:foldl(Visit, {Open, F, G, From}, Neighbours).

%% Consider the single edge X -> Y. Neighbours already in Closed are ignored;
%% otherwise Y is (re)recorded through update/6 when the route via X is better.
scan_neighbour(X, Y, Closed, {Open, F, G, From} = State) ->
  case sets:is_element(Y, Closed) of
    true ->
      State;
    false ->
      [CostToX] = dict:fetch(X, G),
      Tentative = CostToX + dist_between(X, Y),
      case is_better_route(Y, Tentative, Open, G) of
        true ->
          {NewF, NewG, NewFrom} = update(X, Y, F, G, From, Tentative),
          % add_element leaves the set unchanged when Y is already open
          {sets:add_element(Y, Open), NewF, NewG, NewFrom};
        false ->
          State
      end
  end.

%% A node not yet open is always worth recording; an open node only when the
%% tentative cost is strictly lower than its recorded G score.
is_better_route(Y, Tentative, Open, G) ->
  case sets:is_element(Y, Open) of
    true ->
      [Known] = dict:fetch(Y, G),
      Tentative < Known;
    false ->
      true
  end.
          
%% Record that Y is now best reached from X at cost GValue.
%% Returns {NewF, NewG, NewFrom}: the three tables with Y's entry replaced.
%% The other functions read entries back as one-element lists
%% (e.g. [G0] = dict:fetch(X, G)), so each value is stored wrapped in a list;
%% dict:store/3 overwrites any previous entry, so no erase pass is needed.
update(X, Y, OldF, OldG, OldFrom, GValue) ->
  NewFrom = dict:store(Y, [X], OldFrom),
  NewG = dict:store(Y, [GValue], OldG),
  NewF = dict:store(Y, [GValue + h_score(Y)], OldF), % Estimated total distance from start to goal through y.
  {NewF, NewG, NewFrom}.
     
%% Follow the came-from links from Node back to the node whose predecessor is
%% the atom none. The result lists the nodes in the order visited, i.e. Node
%% first and the start node last.
reconstruct_path(CameFrom, Node) ->
  reconstruct_path(CameFrom, Node, []).

%% Visited holds the nodes walked so far, most recent first.
reconstruct_path(CameFrom, Node, Visited) ->
  case dict:fetch(Node, CameFrom) of
    [none] ->
      lists:reverse([Node | Visited]);
    [Previous] ->
      reconstruct_path(CameFrom, Previous, [Node | Visited])
  end.

%% Pick the node with the smallest score from a list of nodes.
%% Score maps each node to a one-element list holding its score. The last two
%% arguments are the running winner and its score; passing none as the winner
%% makes the first node of the list the initial winner. Ties keep the earlier
%% node (the comparison is strict), and an empty list returns the winner as is.
best_step([], _Score, Winner, _WinnerValue) ->
  Winner;
best_step([Node|Rest], Score, none, _Unused) ->
  [First] = dict:fetch(Node, Score),
  best_step(Rest, Score, Node, First);
best_step([Node|Rest], Score, Winner, WinnerValue) ->
  [Candidate] = dict:fetch(Node, Score),
  {NextWinner, NextValue} =
    case Candidate < WinnerValue of
      true -> {Node, Candidate};
      false -> {Winner, WinnerValue}
    end,
  best_step(Rest, Score, NextWinner, NextValue).

%% specialize for the torch-carrier problem
%% bits 0-4 represent torch, 1m, 2m, 5m, 10m
%% set bit = on the near side

%% Every possible crossing of one or two people; the torch bit (1) is set in
%% every entry. Order matters: neighbours are generated in this order.
swaps() ->
  Singles = [3, 5, 9, 17],
  Pairs = [7, 11, 13, 19, 21, 25],
  Singles ++ Pairs.

%% Minutes needed for the group encoded in Swap to cross: the time of the
%% highest bit set among 16, 8 and 4 (10, 5 and 2 minutes respectively),
%% or 1 when none of those bits is set.
crossing_time(Swap) ->
  if
    Swap band 16 =/= 0 -> 10;
    Swap band 8 =/= 0 -> 5;
    Swap band 4 =/= 0 -> 2;
    true -> 1
  end.

%% presentation form
%% Dispatches on the highest person bit (16, 8, 4, 2) set in Swap and lets
%% compound/2 render it; yields the empty string when none of them is set.
display(Swap) ->
  if
    Swap band 16 =/= 0 -> compound(Swap, 16);
    Swap band 8 =/= 0 -> compound(Swap, 8);
    Swap band 4 =/= 0 -> compound(Swap, 4);
    Swap band 2 =/= 0 -> compound(Swap, 2);
    true -> ""
  end.
      
%% Render Swap as the crossing time of its slowest member followed by the
%% display form of what remains once Bit is cleared, separated by decorate/1.
compound(Swap, Bit) ->
  Slowest = integer_to_list(crossing_time(Swap)),
  Remainder = display(Swap bxor Bit),
  Slowest ++ decorate(Remainder).
  
%% Prefix a non-empty string with "+"; the empty string is returned unchanged.
%% Emptiness is checked by matching on "" rather than by measuring the string.
decorate("") ->
  "";
decorate(Value) ->
  "+" ++ Value.
  
%% All states reachable from X by one crossing: each entry of swaps/0 whose
%% members are all on the torch's side is applied to X (see compatible/4).
neighbour_nodes(X) ->
  TorchSide = equivalent_point(X),
  compatible(X, TorchSide, [], swaps()).
  
%% Bitmask of everything on the same side as the torch.
%% When the torch bit (1) is set, that is X itself; when it is clear, the
%% five bits are flipped so the result describes the far side instead.
equivalent_point(X) ->
  TorchBit = X band 1,
  FlipMask = 31 * (1 - TorchBit), % 0 when the torch bit is set, 31 when clear
  X bxor FlipMask.
  
%% For every Swap in Inlist whose bits are all set in Y, prepend X bxor Swap
%% to Outlist. Swaps are processed left to right, so later matches end up
%% nearer the head of the result.
compatible(X, Y, Outlist, Inlist) ->
  Apply = fun(Swap, Acc) ->
              case Y band Swap of
                Swap -> [X bxor Swap | Acc];
                _ -> Acc
              end
          end,
  lists:foldl(Apply, Outlist, Inlist).

%% Cost of moving between states X and Y: the crossing time of the bits in
%% which the two states differ.
dist_between(X,Y) ->
  crossing_time(X bxor Y).
  
%% Heuristic estimate used for the F score: crossing_time/1 applied to the
%% state's own bitmask.
h_score(Node) ->
  Estimate = crossing_time(Node),
  Estimate.
  
%% Solve the puzzle from state 31 to state 0, print each crossing in order
%% and then the total time.
main() ->
  % astar/2 yields the path goal-first, so reverse it to replay from the start
  Path = lists:reverse(astar(31, 0)),
  [Start|Steps] = Path,
  Total = print_result(Steps, Start, 0),
  io:format("Time taken = ~B minutes~n", [Total]).

%% Print one line per step of Trace (Prev is the state before the first step)
%% and return Time plus the crossing time of every step.
print_result(Trace, Prev, Time) ->
  Step = fun(Node, {Before, Elapsed}) ->
             Swap = Node bxor Before,
             print_swap(Swap, Node band 1),
             {Node, Elapsed + crossing_time(Swap)}
         end,
  {_Last, Total} = lists:foldl(Step, {Prev, Time}, Trace),
  Total.

%% Print one crossing. Side is the torch bit after the move: 0 prints an
%% arrow pointing right, 1 an arrow pointing left.
print_swap(Swap, Side) ->
  Format =
    case Side of
      0 -> "    ~s -->~n";
      1 -> "<-- ~s~n"
    end,
  io:format(Format, [display(Swap)]).

The use of if for crossing_time and display gives a natural if/else if cascade (as the most "traditional" analogue for what the Erlang if does). Most of the time, using case has made the code terser (removing the need for most temporaries in simple binary choices, if nothing else -- here best_step is an exception, where the temporary looks to be a necessity); for update, it was about neutral. Maybe there is a trick I am missing, but anything going through a binary looks like it would make crossing_time and display more cumbersome, rather than less.

4 comments:

Steve Vinoski said...

Wow, that code cleaned up nicely!

The remaining "if" statements in crossing_time and display are fine. However, just for the mental exercise, you can also think of them in terms of operations over lists. The crossing_time fun, for example, operates over a list of [16, 8, 4, 2] where 2 represents the current "true" branch of the "if", and display operates over the list [16, 8, 4, 2, 1] where 1 represents the "true" branch. Folding over these lists would work, but note that doing so would require walking the whole list. While that's admittedly not very expensive, it still costs more than the short-circuiting "if". You could do this to walk the list while still short-circuiting:

%% Walk a {Bit, Minutes} table from slowest to fastest and stop at the first
%% bit that is set in Swap.
crossing_time(Swap) ->
  Table = [{16, 10}, {8, 5}, {4, 2}, {2, 1}],
  crossing_time(Swap, Table).

%% The final table entry plays the role of the "true" branch: always 1.
crossing_time(_Swap, [{2, 1}]) ->
  1;
crossing_time(Swap, [{Bit, Minutes} | _Rest]) when Swap band Bit =/= 0 ->
  Minutes;
crossing_time(Swap, [_Miss | Rest]) ->
  crossing_time(Swap, Rest).

I wouldn't argue that that's clearer than the "if" though.

Doug said...

I can't see the Erlang code. It shows up for an instant when I click on the title, then it disappears. Thanks!

Steve said...

It appears that the syntax highlighting for Erlang is working in Firefox, but is doing something weird in IE.

If you go to the whole-month page for October, the Erlang code shows but isn't highlighted; and as noted vanishes altogether if you just go to the single entry.

Steve said...

There was a spurious comma in the highlighting script, which borked IE. Fixed now.