(* $Id: rpc_transport.ml 217 2006-06-15 00:20:27Z gerd $
 * ----------------------------------------------------------------------
 *
 *)

open Rtypes
open Xdr
open Rpc
open Rpc_common

(* State of one transport endpoint: the descriptor, the message currently
 * being sent, and the reassembly state of the message being received.
 *)
type t =
      { descr : Unix.file_descr;   (* the descriptor passed to create *)
	prot : protocol;             (* Udp or Tcp: selects the framing *)
	mode : mode;                 (* Socket or BiPipe *)

	(* --- additional socket information (esp. UDP) --- *)

	mutable sender : Unix.sockaddr option; (* originator of last message *)
	     (* receive_part sets Some _ only for UDP in Socket mode
	      * (from recvfrom); it sets None otherwise *)
	mutable next_receiver : Unix.sockaddr option;
	     (* who gets the next message; when Some, send_part uses sendto.
	      * Only set_receiver changes it (UDP sockets only). *)

	(* --- output --- *)

	mutable out_value : string;      (* includes record mark if needed *)
	mutable out_count : int;
	     (* number of bytes of out_value already handled; sending is
	      * complete when this equals String.length out_value. For UDP,
	      * put starts it at 4 so the record mark is never sent. *)

	(* --- input --- *)

	mutable in_buf : string;
	mutable in_buf_pos : int;
	mutable in_buf_len : int;
        (* in_buf_len is an end index, not a length: the bytes
	 * in_buf.[in_buf_pos] .. in_buf.[in_buf_len - 1] are the unprocessed
	 * rest of the input buffer (empty when in_buf_pos >= in_buf_len)
	 *)

	mutable in_values : string list;  (* received parts in rev order *)
	mutable in_total : int;
	     (* summed lengths of the completed fragments of the current
	      * message (counted even while dropping; saturates at max_int).
	      * For UDP: the length of the datagram. *)
	mutable in_dropping : bool;       (* are we dropping currently? *)
	     (* set by drop; while true, fragment bytes are not stored and
	      * get returns an empty message *)
	mutable in_complete : bool;       (* all parts received? *)
	     (* also true initially, i.e. when no message is in progress *)
	mutable in_eof : bool;            (* a read returned 0 bytes *)

        (* In Tcp mode:
	 * - in_rm:        the current record marker, perhaps incomplete
	 * - in_rm_count:  (0..4) number of valid bytes in in_rm
	 * - in_length:    numeric value of in_rm if complete
	 *                 (high "last fragment" bit stripped)
	 * - in_last:      true if last fragment
	 * - in_count:     number of bytes read from fragment
	 *)

	mutable in_rm : string;
	mutable in_rm_count : int;
	mutable in_length : int;
	mutable in_last : bool;
	mutable in_count : int
      }


(* When set (see verbose), receive_part traces its progress on stderr. *)
let debug : bool ref = ref false

  (****)

let create d p m =
    (* A new transport has no pending output and no message in progress:
     * in_complete starts out true, so the first receive_part begins a
     * fresh message. The input buffer size (8192) can be adjusted.
     *)
    let fresh_buffer = String.create 8192 in
    { descr = d; prot = p; mode = m;
      sender = None; next_receiver = None;
      out_value = ""; out_count = 0;
      in_buf = fresh_buffer; in_buf_pos = 0; in_buf_len = 0;
      in_values = []; in_total = 0; in_dropping = false;
      in_complete = true; in_eof = false;
      in_rm = ""; in_rm_count = 0; in_length = 0;
      in_last = false; in_count = 0;
    }

  (****)

(* The file descriptor this transport reads from and writes to. *)
let descriptor t =
    let { descr = fd } = t in
    fd

  (****)

(* Sets the destination address of the next message. Only meaningful for
 * UDP sockets; any other combination fails with
 * Failure "Rpc_transport.set_receiver".
 *)
let set_receiver t r =
    match t.mode, t.prot with
      Socket, Udp -> t.next_receiver <- Some r
    | _ -> failwith "Rpc_transport.set_receiver"

  (****)

(* Stages the packed value pv for sending; the bytes are written later by
 * send_part. Fails (Failure) if previous output is still pending, if the
 * message is empty, if a UDP message exceeds 8000 payload bytes, or if a
 * TCP message does not fit into the 31-bit length of a record mark.
 * All checks run before t is modified, so a rejected message leaves the
 * transport usable.
 *)
let put t pv =
    (* Has four null bytes at the beginning for the record mark *)
    let v = Rpc_packer.rm_string_of_packed_value pv in
    let len = String.length v in

    if String.length t.out_value <> t.out_count then
      failwith "Rpc_transport.put";

    if len <= 4 then
      failwith "Rpc_transport.put: empty message not allowed";

    match t.prot with
      Udp ->
        if len > 8004 then
          failwith "Rpc_transport.put: UDP message too large";
        t.out_value <- v;
        t.out_count <- 4  (* skip RM *)

    | Tcp ->
        let l = len - 4 in
        (* The record mark holds the length in 31 bits; the top bit is the
         * "last fragment" flag. Anything larger would be silently corrupted.
         * (lsr 31 is always 0 on 32-bit platforms, where ints are smaller.)
         *)
        if l lsr 31 <> 0 then
          failwith "Rpc_transport.put: message too large for record mark";
        let l_uint4 = uint4_of_int l in
        let l_uint4_32 = logical_int32_of_uint4 l_uint4 in
        let l_uint4_32' = Int32.logor l_uint4_32 Int32.min_int in
                                               (* 0x80000000: last fragment *)
        let l_uint4' = logical_uint4_of_int32 l_uint4_32' in
        write_uint4 v 0 l_uint4';
        t.out_value <- v;
        t.out_count <- 0

  (****)

(* Writes (part of) the staged message. Returns true when the whole message
 * has been sent, false if TCP output remains. Blocks in select until the
 * descriptor is writable. For UDP the datagram must go out in one call;
 * a short send fails with Failure.
 *)
let send_part t =
    let total = String.length t.out_value in
    if total = t.out_count then
      true                                   (* nothing left to send *)
    else begin
      (* block until writing possible *)
      ignore (Unix.select [] [t.descr] [] (-1.0));
      match t.prot with
        Udp ->
          (* the 4-byte record mark at the front is not transmitted *)
          let payload = total - 4 in
          let written =
            unix_call
              (fun () ->
                 match t.next_receiver with
                   Some addr ->
                     Unix.sendto t.descr t.out_value 4 payload [] addr
                 | None ->
                     Unix.write t.descr t.out_value 4 payload)
          in
          if written < payload then
            failwith "Rpc_transport.send_part: UDP message too large";
          t.out_count <- total;
          true
      | Tcp ->
          let written =
            unix_call
              (fun () ->
                 Unix.write t.descr t.out_value t.out_count
                   (total - t.out_count))
          in
          t.out_count <- t.out_count + written;
          t.out_count = total
    end

  (****)

(* True when every byte of the staged output has been handled. *)
let is_sending_complete t =
    t.out_count = String.length t.out_value

  (****)

(* Discards the staged output, so the next put is accepted. *)
let clean_output t =
    t.out_count <- 0;
    t.out_value <- ""

  (****)

(* Stages pv and sends it completely before returning. If sending raises,
 * the output buffer is cleared and the exception is re-raised.
 *)
let send_sync t pv =
    put t pv;
    let rec drain () =
      if not (is_sending_complete t) then begin
        ignore (send_part t);
        drain ()
      end
    in
    try
      drain ();
      clean_output t
    with
      e ->
        clean_output t;
        raise e

  (****)

(* Receives more input and advances reassembly of the current message.
 * Returns true once the message is complete, false if more input is
 * needed or EOF was seen (then in_eof is set). Blocks in select when
 * the input buffer is exhausted.
 * UDP: each datagram is one complete message.
 * TCP: record-marked fragments are collected until the last fragment
 * has been read completely.
 *)
let receive_part t =
    (* Start a fresh message, but only if the previous one was complete;
     * otherwise keep accumulating fragments of the pending message.
     *)
    let reset_for_next_message() =
      if t.in_complete then begin
        t.in_complete <- false;
        t.in_values <- [];
        t.in_total <- 0;
        t.in_rm <- "";
        t.in_rm_count <- 0;
        t.in_length <- 0;
        t.in_dropping <- false;
        t.in_last <- false;
        t.in_count <- 0
      end
    in

    if t.in_buf_pos >= t.in_buf_len then begin
      ignore(Unix.select [t.descr] [] [] (-1.0))
        (* In the case we are in non-blocking mode:
         * block until reading becomes possible
         *)
    end;

    match t.prot with
      Udp ->
        let n,sender =
          match t.mode with
            Socket ->
              let n, s =
                unix_call
                  (fun () -> Unix.recvfrom
                               t.descr t.in_buf 0 (String.length t.in_buf) [])
              in
              n, Some s
          | BiPipe ->
              unix_call
                (fun () -> Unix.read
                             t.descr t.in_buf 0 (String.length t.in_buf)),
              None
        in
        t.sender <- sender;
        if n = 0 then begin
          t.in_eof <- true;
          false
        end else begin
          t.in_values <- [String.sub t.in_buf 0 n];
          t.in_total <- n;
          (* A previous drop applied only to the previous datagram; this is
           * a new message, so its contents must be kept.
           *)
          t.in_dropping <- false;
          t.in_complete <- true;
          true
        end
    | Tcp ->
        (* Refill the input buffer only when all of it has been consumed. *)
        let at_eof =
          if t.in_buf_pos >= t.in_buf_len then
            unix_call
              (fun () ->
                 let m = Unix.read t.descr t.in_buf 0 (String.length t.in_buf)
                 in
                 t.in_buf_pos <- 0;
                 t.in_buf_len <- m;
                 if !debug then prerr_endline ("len=" ^ string_of_int m);
                 m = 0     (* EOF criterion *)
              )
          else
            false
        in

        t.sender <- None;

        if at_eof then begin
          t.in_eof <- true;
          false
        end
        else begin
          reset_for_next_message();

          (* Interpret the bytes in_buf.[k] .. in_buf.[n-1] until the
           * message becomes complete or the buffer is exhausted.
           *)
          let k = ref t.in_buf_pos in
          let n = t.in_buf_len in
          (* True when the record marker is complete and the fragment has
           * already reached its announced length, i.e. a zero-length
           * fragment. It must be finished even if no bytes remain in the
           * buffer; otherwise the next call would block in select/read
           * waiting for data that belongs to no fragment.
           *)
          let empty_fragment_pending () =
            t.in_rm_count = 4 && t.in_count = t.in_length in
          while not t.in_complete && (!k < n || empty_fragment_pending ()) do
            if t.in_rm_count < 4 then begin
              (* case: read (part of) the 4-byte record marker *)
              if !debug then
                prerr_endline ("Record marker k = " ^ string_of_int !k);
              let m = min (4 - t.in_rm_count) (n - !k) in
              t.in_rm <- t.in_rm ^ (String.sub t.in_buf !k m);
              t.in_rm_count <- t.in_rm_count + m;
              k := !k + m;
              (* case: record marker has become complete. The high bit
               * flags the last fragment, the other 31 bits are the length.
               *)
              if t.in_rm_count = 4 then begin
                let rm = t.in_rm in
                t.in_last <- (Char.code rm.[0]) >= 128;
                if t.in_last && !debug then
                  prerr_endline ("Recognized Last Fragment");
                let rm_0 = (Char.chr ((Char.code rm.[0]) land 0x7f)) in
                t.in_length <-
                  int_of_uint4
                    (mk_uint4 (rm_0,rm.[1],rm.[2],rm.[3]));
                t.in_count <- 0
              end
            end
            else begin
              (* case: read (part of) the fragment body *)
              if !debug then prerr_endline ("fragment k = " ^ string_of_int !k);
              let m = min (t.in_length - t.in_count) (n - !k) in
              if not t.in_dropping then
                t.in_values <-
                  (String.sub t.in_buf !k m) :: t.in_values;
              k := !k + m;
              t.in_count <- t.in_count + m;
              (* case: fragment complete *)
              if t.in_count = t.in_length then begin
                if !debug then
                  prerr_endline ("fragment complete k = " ^ string_of_int !k);
                (* adjust estimated length, saturating on overflow: *)
                let sum = t.in_total + t.in_length in
                let sum' =
                  if sum < 0 then (* numeric overflow *) max_int else sum in
                t.in_total <- sum';
                (* case: last fragment complete *)
                if t.in_last then begin
                  if !debug then prerr_endline ("last fragment");
                  t.in_complete <- true
                end;
                t.in_rm <- "";
                t.in_rm_count <- 0;
                t.in_length <- 0
              end
            end
          done;

          (* Store !k for the next round: *)
          t.in_buf_pos <- !k;

          if !debug then
            prerr_endline ("rest n = " ^ string_of_int (n - !k));

          t.in_complete
        end

  (****)

(* Address the last message came from; raises Not_found when unknown
 * (receive_part only records it for UDP sockets).
 *)
let get_sender t =
    let { sender = s } = t in
    match s with
      None -> raise Not_found
    | Some addr -> addr

  (****)

(* True when a message has been fully received and has some contents. *)
let is_message_complete t =
    match t.in_values with
      [] -> false
    | _ :: _ -> t.in_complete

  (****)

(* True while a message is only partially received. *)
let is_message_incomplete t =
    if t.in_complete then false else true

  (****)

(* True when no message is in progress and none is being held. *)
let no_message t =
    match t.in_values with
      [] -> t.in_complete
    | _ :: _ -> false

  (****)

(* True once a read on the descriptor has returned 0 bytes. *)
let at_eof t =
    let { in_eof = eof } = t in
    eof

  (****)

(* True when every byte read into in_buf has already been processed. *)
let is_buffer_empty t =
    not (t.in_buf_pos < t.in_buf_len)

  (****)

(* Estimated length of the message being received: the summed lengths of
 * the already complete fragments (in_total) plus the announced length of
 * the current fragment (in_length, 0 between fragments).
 * Fails if there is no message.
 *)
let peek_length t =
  if no_message t then
    failwith "Rpc_transport.peek_length: no message"
  else
    t.in_length + t.in_total

  (****)

(* Returns the completely received message as a packed value; a dropped
 * message yields the packed empty string. Fails if the message is not
 * complete yet.
 *)
let get t =
    if not (is_message_complete t) then
      failwith "Rpc_transport.get: message not yet complete";
    let contents =
      if t.in_dropping then ""
      else String.concat "" (List.rev t.in_values)
    in
    Rpc_packer.packed_value_of_string contents

  (****)

(* Marks the current message as dropped: its (remaining) bytes are not
 * stored. Fails if there is no message.
 *)
let drop t =
  if no_message t then
    failwith "Rpc_transport.drop: no message"
  else
    t.in_dropping <- true

  (****)

(* Forgets the received message contents. Allowed only when the message is
 * complete or EOF has been reached.
 *)
let clean_input t =
    if not (t.in_complete || t.in_eof) then
      failwith "Rpc_transport.clean_input: message not yet complete";
    t.in_values <- []

  (****)

(* Receives one complete message, blocking as needed, and returns it.
 * Raises End_of_file at EOF, and fails if a previous message is still
 * incomplete.
 *)
let receive_sync t =
    if not t.in_complete then
      failwith "Rpc_transport.receive_sync: last message not complete";
    let rec wait () =
      let complete = receive_part t in
      if at_eof t then raise End_of_file;
      if not complete then wait ()
    in
    wait ();
    let v = get t in
    clean_input t;
    v


(* Enables or disables debug tracing on stderr. *)
let verbose enabled =
  debug := enabled

