// Copyright (c) 1999-2004 Brian Wellington (bwelling@xbill.org)

package pluto.DNS;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.net.DatagramPacket;
import java.net.DatagramSocket;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.UnknownHostException;
import java.util.Iterator;
import java.util.List;

import lombok.extern.slf4j.Slf4j;
import pluto.DNS.utils.hexdump;

/**
 * An implementation of Resolver that sends one query to one server.
 * SimpleResolver handles TCP retries, transaction security (TSIG), and EDNS 0.
 * 
 * @see Resolver
 * @see TSIG
 * @see OPTRecord
 * 
 * @author Brian Wellington
 */
@Slf4j
public class SimpleResolver implements Resolver {

	/** The default port to send queries to */
	public static final int			DEFAULT_PORT	= 53;

	/** Address of the server that queries are sent to. */
	private InetAddress				addr;

	/** Port of the server that queries are sent to. */
	private int						port			= DEFAULT_PORT;

	private boolean					useTCP, ignoreTruncation;

	/** EDNS level; a negative value means no OPT record is added to queries. */
	private byte					EDNSlevel		= -1;

	/** Key applied to outgoing queries and used to verify responses; may be null. */
	private TSIG					tsig;

	/** Socket timeout in milliseconds. */
	private int						timeoutValue	= 10 * 1000;

	private static final short		DEFAULT_UDPSIZE	= 512;

	private static final short		EDNS_UDPSIZE	= 1280;

	/** Source of the identifiers returned by {@link #sendAsync}; shared by all instances. */
	private static int				uniqueID		= 0;

	/**
	 * Creates a SimpleResolver that will query the specified host
	 * 
	 * @param hostname
	 *            Address of the server to query.
	 * @exception UnknownHostException
	 *                Declared for callers that already handle it; this
	 *                constructor only stores the address.
	 */
	public SimpleResolver(InetAddress hostname) throws UnknownHostException {
		this.addr = hostname;
	}

	/** Sets the port to send queries to. */
	public void setPort(int port) {
		this.port = port;
	}

	/** Returns the target as "host-address:port". */
	public String toString() {
		return this.addr.getHostAddress() + ":" + String.valueOf(this.port);

	}

	/** Sets whether every query is sent over TCP. */
	public void setTCP(boolean flag) {
		this.useTCP = flag;
	}

	/** Sets whether a truncated UDP response is returned as is instead of being retried over TCP. */
	public void setIgnoreTruncation(boolean flag) {
		this.ignoreTruncation = flag;
	}

	/**
	 * Sets the EDNS level used when sending queries.
	 * 
	 * @param level
	 *            0 to add an OPT record to queries, -1 to add none.
	 * @throws UnsupportedOperationException
	 *             The level is neither 0 nor -1.
	 */
	public void setEDNS(int level) {
		if( level != 0 && level != -1 )
			throw new UnsupportedOperationException("invalid EDNS level " + "- must be 0 or -1");
		this.EDNSlevel = (byte) level;
	}

	/** Sets the TSIG key used to sign queries and verify responses. */
	public void setTSIGKey(TSIG key) {
		tsig = key;
	}

	/** Sets the TSIG key, built from a key name and raw key data. */
	public void setTSIGKey(Name name, byte[] key) {
		tsig = new TSIG(name, key);
	}

	/** Sets the TSIG key, built from string forms of the key name and key. */
	public void setTSIGKey(String name, String key) {
		tsig = new TSIG(name, key);
	}

	/**
	 * Sets the TSIG key, using the local host name as the key name.
	 * 
	 * @throws UnknownHostException
	 *             The local host name could not be determined.
	 */
	public void setTSIGKey(String key) throws UnknownHostException {
		setTSIGKey(InetAddress.getLocalHost().getHostName(), key);
	}

	/**
	 * Sets the timeout applied to socket operations.
	 * 
	 * @param secs
	 *            Timeout in seconds; stored internally in milliseconds.
	 */
	public void setTimeout(int secs) {
		timeoutValue = secs * 1000;
	}

	/**
	 * Receives one datagram of at most max bytes and returns exactly the bytes
	 * that were received.
	 */
	private byte[] readUDP(DatagramSocket s, int max) throws IOException {
		DatagramPacket dp = new DatagramPacket(new byte[max], max);
		s.receive(dp);
		// Trim the receive buffer down to the actual datagram length.
		byte[] in = new byte[dp.getLength()];
		System.arraycopy(dp.getData(), 0, in, 0, in.length);
		if( Options.check("verbosemsg") )
			System.err.println(hexdump.dump("UDP read", in));
		return (in);
	}

	/** Sends out as a single datagram to addr:port. */
	private void writeUDP(DatagramSocket s, byte[] out, InetAddress addr, int port) throws IOException {
		if( Options.check("verbosemsg") )
			System.err.println(hexdump.dump("UDP write", out));
		s.send(new DatagramPacket(out, out.length, addr, port));
	}

	/** Reads one message from the socket: an unsigned 16-bit length followed by that many bytes. */
	private byte[] readTCP(Socket s) throws IOException {
		DataInputStream dataIn;

		dataIn = new DataInputStream(s.getInputStream());
		int inLength = dataIn.readUnsignedShort();
		byte[] in = new byte[inLength];
		dataIn.readFully(in);
		if( Options.check("verbosemsg") )
			System.err.println(hexdump.dump("TCP read", in));
		return (in);
	}

	/** Writes one message to the socket: a 16-bit length followed by the message bytes. */
	private void writeTCP(Socket s, byte[] out) throws IOException {
		DataOutputStream dataOut;

		if( Options.check("verbosemsg") )
			System.err.println(hexdump.dump("TCP write", out));
		dataOut = new DataOutputStream(s.getOutputStream());
		dataOut.writeShort(out.length);
		dataOut.write(out);
		dataOut.flush();
	}

	/**
	 * Parses wire-format bytes into a Message.
	 * 
	 * @throws WireParseException
	 *             Parsing failed; any other IOException raised while parsing
	 *             is replaced by a WireParseException.
	 */
	private Message parseMessage(byte[] b) throws WireParseException {
		try {
			return (new Message(b));
		}
		catch(IOException e) {
			if( Options.check("verbose") )
				log.error(e.getMessage());
			if( !(e instanceof WireParseException) )
				e = new WireParseException("Error parsing message");
			throw (WireParseException) e;
		}
	}

	/**
	 * Verifies the response against the query's TSIG and records the outcome
	 * in response.tsigState. Does nothing when no key is configured.
	 */
	private void verifyTSIG(Message query, Message response, byte[] b, TSIG tsig) {
		if( tsig == null )
			return;
		int error = tsig.verify(response, b, query.getTSIG());
		if( error == Rcode.NOERROR )
			response.tsigState = Message.TSIG_VERIFIED;
		else
			response.tsigState = Message.TSIG_FAILED;
		if( Options.check("verbose") )
			System.err.println("TSIG verify: " + Rcode.string(error));
	}

	/**
	 * Adds an OPT record built with EDNS_UDPSIZE to the additional section,
	 * unless EDNS is disabled or the query already carries an OPT record.
	 */
	private void applyEDNS(Message query) {
		if( EDNSlevel < 0 || query.getOPT() != null )
			return;
		OPTRecord opt = new OPTRecord(EDNS_UDPSIZE, Rcode.NOERROR, (byte) 0);
		query.addRecord(opt, Section.ADDITIONAL);
	}

	/**
	 * Returns the payload size of the query's OPT record, or DEFAULT_UDPSIZE
	 * when the query has none.
	 */
	private int maxUDPSize(Message query) {
		OPTRecord opt = query.getOPT();
		if( opt == null ) {
			return DEFAULT_UDPSIZE;
		}

		return opt.getPayloadSize();
	}

	/**
	 * Extracts the message ID from the first two bytes of a raw response.
	 * 
	 * @throws WireParseException
	 *             The response is shorter than Header.LENGTH.
	 */
	private static int responseID(byte[] in) throws WireParseException {
		if( in.length < Header.LENGTH ) {
			throw new WireParseException("invalid DNS header - " + "too short");
		}
		return ((in[0] & 0xFF) << 8) + (in[1] & 0xFF);
	}

	/**
	 * Opens a TCP connection to the configured server, bounding the connect
	 * attempt by timeoutValue (0 means no limit). The socket is closed again
	 * if the connect attempt fails.
	 */
	private Socket connectTCP() throws IOException {
		Socket s = new Socket();
		boolean connected = false;
		try {
			s.connect(new InetSocketAddress(addr, port), timeoutValue);
			connected = true;
			return s;
		}
		finally {
			if( !connected ) {
				try {
					s.close();
				}
				catch(IOException ignored) {
					// The connect failure is the error worth reporting.
				}
			}
		}
	}

	/**
	 * Sends the query over a new TCP connection and returns the raw response.
	 * 
	 * @throws WireParseException
	 *             The response is too short or its ID differs from qid.
	 */
	private byte[] exchangeTCP(byte[] out, int qid) throws IOException {
		if( log.isDebugEnabled() ) {
			log.debug("TCP TYPE");
		}
		Socket s = connectTCP();
		SureShutdownSocket monitor = null;
		byte[] in;
		try {
			monitor = new SureShutdownSocket(s, timeoutValue);
			writeTCP(s, out);
			in = readTCP(s);
		}
		catch(IOException e) {
			if( log.isDebugEnabled() )
				log.error(e.getMessage());
			throw e;
		}
		finally {
			// If the monitor could not be created the socket must still be released.
			if( monitor != null )
				monitor.close();
			else
				s.close();
		}
		int id = responseID(in);
		if( id != qid ) {
			throw new WireParseException("invalid message id: expected " + qid + "; got id " + id);
		}
		return in;
	}

	/**
	 * Sends the query as one datagram and returns the first response whose ID
	 * equals qid. Datagrams carrying a different ID are skipped, and reading
	 * continues on the same socket: the reply is addressed to the port the
	 * query was sent from, so a fresh socket could never receive it.
	 * 
	 * @throws WireParseException
	 *             A received datagram is shorter than a DNS header.
	 */
	private byte[] exchangeUDP(byte[] out, int udpSize, int qid) throws IOException {
		if( log.isDebugEnabled() )
			log.debug("UDP TYPE");
		DatagramSocket s = new DatagramSocket();
		SureShutdownDatagramSocket monitor = null;
		try {
			monitor = new SureShutdownDatagramSocket(s, timeoutValue);
			writeUDP(s, out, addr, port);
			while (true) {
				byte[] in = readUDP(s, udpSize);
				/*
				 * Check the ID before the message is parsed, so that a
				 * malformed response that's not ours doesn't confuse us.
				 */
				int id = responseID(in);
				if( id == qid ) {
					return in;
				}
				if( Options.check("verbose") ) {
					System.err.println("invalid message id: expected " + qid + "; got id " + id);
				}
			}
		}
		finally {
			if( monitor != null )
				monitor.close();
			else
				s.close();
		}
	}

	/**
	 * Sends a message to a single server and waits for a response. The only
	 * association check made is that the response ID equals the query ID.
	 * A query for AXFR is handed to a zone transfer instead. A truncated UDP
	 * response is retried over TCP unless truncation is being ignored.
	 * 
	 * @param query
	 *            The query to send.
	 * @return The response.
	 * @throws IOException
	 *             An error occurred while sending or receiving.
	 */
	public Message send(Message query) throws IOException {
		if( Options.check("verbose") )
			System.err.println("Sending to " + addr.getHostAddress() + ":" + port);

		if( query.getHeader().getOpcode() == Opcode.QUERY ) {
			Record question = query.getQuestion();
			if( question != null && question.getType() == Type.AXFR )
				return sendAXFR(query);
		}

		// Work on a copy so that EDNS/TSIG additions do not alter the caller's message.
		query = (Message) query.clone();
		applyEDNS(query);
		if( tsig != null )
			tsig.apply(query, null);

		byte[] out = query.toWire(Message.MAXLENGTH);
		int udpSize = maxUDPSize(query);
		int qid = query.getHeader().getID();
		boolean tcp = useTCP || out.length > udpSize;
		while (true) {
			byte[] in = tcp ? exchangeTCP(out, qid) : exchangeUDP(out, udpSize, qid);
			Message response = parseMessage(in);
			verifyTSIG(query, response, in, tsig);
			if( !tcp && !ignoreTruncation && response.getHeader().getFlag(Flags.TC) ) {
				// Truncated over UDP: ask again over TCP.
				tcp = true;
				continue;
			}
			return response;
		}
	}

	/**
	 * Hands out the next asynchronous-query identifier. The counter is static,
	 * so it is guarded by the class lock rather than by an instance lock.
	 */
	private static synchronized Integer nextUniqueID() {
		return Integer.valueOf(uniqueID++);
	}

	/**
	 * Asynchronously sends a message to a single server, registering a listener
	 * to receive a callback on success or exception. Multiple asynchronous
	 * lookups can be performed in parallel. Since the callback may be invoked
	 * before the function returns, external synchronization is necessary.
	 * 
	 * @param query
	 *            The query to send
	 * @param listener
	 *            The object containing the callbacks.
	 * @return An identifier, which is also a parameter in the callback
	 */
	public Object sendAsync(final Message query, final ResolverListener listener) {
		if( log.isDebugEnabled() )
			log.debug("call sendAsync ");
		final Object id = nextUniqueID();
		Record question = query.getQuestion();
		String qname;
		if( question != null )
			qname = question.getName().toString();
		else
			qname = "(none)";
		String name = this.getClass() + ": " + qname;
		Thread thread = new ResolveThread(this, query, id, listener);
		thread.setName(name);
		thread.setDaemon(true);
		thread.start();
		return id;
	}

	/**
	 * Runs a zone transfer for the query's question name and wraps the
	 * transferred records in a response message that reuses the query's ID
	 * and question and has the AA and QR flags set.
	 * 
	 * @throws WireParseException
	 *             The zone transfer failed with a ZoneTransferException.
	 */
	private Message sendAXFR(Message query) throws IOException {
		Name qname = query.getQuestion().getName();
		ZoneTransferIn xfrin = ZoneTransferIn.newAXFR(qname, this);
		try {
			xfrin.run();
		}
		catch(ZoneTransferException e) {
			throw new WireParseException(e.getMessage());
		}
		List records = xfrin.getAXFR();
		Message response = new Message(query.getHeader().getID());
		response.getHeader().setFlag(Flags.AA);
		response.getHeader().setFlag(Flags.QR);
		response.addRecord(query.getQuestion(), Section.QUESTION);
		for (Object rec : records) {
			response.addRecord((Record) rec, Section.ANSWER);
		}
		return response;
	}

	/**
	 * Watchdog that keeps a TCP query from hanging on an unresponsive socket:
	 * unless close() is called first, the socket is shut down and closed once
	 * the timeout has elapsed.
	 */
	static class SureShutdownSocket implements Runnable {
		long				timeout;

		Socket				socket;

		// Written by the thread calling close(), read by the watchdog thread.
		volatile boolean	connection_end;

		Thread				inner_monitor_thread	= null;

		/**
		 * Applies t as the socket's SO_TIMEOUT and starts the watchdog thread.
		 * A timeout of 0 means "no timeout" for the socket, so no watchdog is
		 * started in that case (it would otherwise close the socket at once).
		 */
		SureShutdownSocket(Socket s, int t) throws IOException {
			socket = s;
			timeout = t;
			socket.setSoTimeout(t);
			connection_end = false;
			if( t > 0 ) {
				inner_monitor_thread = new Thread(this, "DNS Session Monitor");
				// The watchdog must never keep the JVM alive on its own.
				inner_monitor_thread.setDaemon(true);
				inner_monitor_thread.start();
			}
		}

		/** Marks the exchange as finished, wakes the watchdog and closes the socket. */
		void close() throws IOException {
			connection_end = true;
			if( inner_monitor_thread != null )
				inner_monitor_thread.interrupt();
			socket.close();
		}

		public void run() {
			try {
				Thread.sleep(timeout);
			}
			catch(InterruptedException e) {
				// Woken early, normally by close(); the flag below decides what to do.
			}

			if( connection_end ) {
				// Already closed, so there is nothing to do.
				if( log.isDebugEnabled() )
					log.debug("Already Close");
			}
			else {
				if( log.isDebugEnabled() )
					log.debug("DNS Session Not Complete in : " + String.valueOf(timeout) + "ms So Close Socket");
				// Shut the streams down first...
				try {
					socket.shutdownInput();
					socket.shutdownOutput();
				}
				catch(Throwable thw) {
					log.error("socket shudown error", thw);
				}
				// ...then close the socket.
				try {
					socket.close();
				}
				catch(Throwable thw) {
					log.error("socket close error", thw);
				}
			}
		}
	}

	/**
	 * Watchdog that keeps a UDP query from hanging on an unresponsive socket:
	 * unless close() is called first, the socket is closed once the timeout
	 * has elapsed.
	 */
	static class SureShutdownDatagramSocket implements Runnable {
		long				timeout;

		DatagramSocket		socket;

		// Written by the thread calling close(), read by the watchdog thread.
		volatile boolean	connection_end;

		Thread				inner_monitor_thread	= null;

		/**
		 * Applies t as the socket's SO_TIMEOUT and starts the watchdog thread.
		 * A timeout of 0 means "no timeout" for the socket, so no watchdog is
		 * started in that case (it would otherwise close the socket at once).
		 */
		SureShutdownDatagramSocket(DatagramSocket s, int t) throws IOException {
			socket = s;
			timeout = t;
			socket.setSoTimeout(t);
			connection_end = false;
			if( t > 0 ) {
				inner_monitor_thread = new Thread(this, "DNS Session Monitor");
				// The watchdog must never keep the JVM alive on its own.
				inner_monitor_thread.setDaemon(true);
				inner_monitor_thread.start();
			}
		}

		/** Marks the exchange as finished, wakes the watchdog and closes the socket. */
		void close() throws IOException {
			connection_end = true;
			if( inner_monitor_thread != null )
				inner_monitor_thread.interrupt();
			socket.close();
		}

		public void run() {
			try {
				Thread.sleep(timeout);
			}
			catch(InterruptedException e) {
				// Woken early, normally by close(); the flag below decides what to do.
			}

			if( connection_end ) {
				// Already closed, so there is nothing to do.
				if( log.isDebugEnabled() )
					log.debug("Already Close");
			}
			else {
				if( log.isDebugEnabled() )
					log.debug("DNS Session Not Complete in : " + String.valueOf(timeout) + "ms So Close Socket");
				// Close the socket.
				try {
					socket.close();
				}
				catch(Throwable thw) {
					log.error("socket error", thw);
				}
			}
		}
	}

	/**
	 * A TCP connection to the resolver's server over which one query is sent
	 * and a sequence of response messages is read.
	 */
	static class Stream {
		SimpleResolver		res;

		Socket				sock;

		TSIG				tsig;

		TSIG.StreamVerifier	verifier;

		/**
		 * Connects to the resolver's server, using the resolver's timeout for
		 * both the connect attempt and subsequent reads, and copies its TSIG key.
		 */
		Stream(SimpleResolver res) throws IOException {
			if( log.isDebugEnabled() )
				log.debug("CREATE INNER STREAM INSTANCE");
			this.res = res;
			sock = res.connectTCP();
			try {
				sock.setSoTimeout(res.timeoutValue);
			}
			catch(IOException e) {
				// Do not leak the connected socket when it cannot be configured.
				close();
				throw e;
			}
			tsig = res.tsig;
		}

		/**
		 * Signs the query when a TSIG key is present (creating the verifier
		 * used by next()) and writes it to the connection.
		 */
		void send(Message query) throws IOException {
			if( tsig != null ) {
				tsig.apply(query, null);
				verifier = new TSIG.StreamVerifier(tsig, query.getTSIG());
			}

			byte[] out = query.toWire(Message.MAXLENGTH);
			res.writeTCP(sock, out);
		}

		/**
		 * Reads and parses the next response message. A response whose rcode
		 * is not NOERROR is returned without TSIG verification; otherwise,
		 * when a verifier exists, the outcome is stored in tsigState.
		 */
		Message next() throws IOException {
			byte[] in = res.readTCP(sock);
			Message response = res.parseMessage(in);
			if( response.getHeader().getRcode() != Rcode.NOERROR )
				return response;
			if( verifier != null ) {
				TSIGRecord tsigrec = response.getTSIG();

				int error = verifier.verify(response, in);
				if( error == Rcode.NOERROR && tsigrec != null )
					response.tsigState = Message.TSIG_VERIFIED;
				else if( error == Rcode.NOERROR )
					response.tsigState = Message.TSIG_INTERMEDIATE;
				else
					response.tsigState = Message.TSIG_FAILED;
			}
			return response;
		}

		/** Closes the connection, ignoring any error raised while closing. */
		void close() {
			try {
				sock.close();
			}
			catch(IOException ignored) {
				// Best-effort close: there is nothing useful to do on failure.
			}
		}
	}

}
