/*
* Copyright 2004-2006 Luke Quinane and Daniel Frampton
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
using System;
using System.Collections;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using log4net;
using NDns.Message;
using NDns.Message.Records;
using NDns.Configuration;
namespace NDns{
/// <summary>
/// Provides a cache for DNS information. PTR, A and MX lookups each have
/// their own table plus a companion queue ordered by last use; when a table
/// is full the entry at the head of its queue (least recently used) is
/// evicted.
/// </summary>
public class DnsCache {
    #region Attributes
    /// <summary>
    /// Logging support for this class.
    /// </summary>
    protected static readonly ILog log = LogManager.GetLogger(typeof(DnsCache));
    /// <summary>
    /// The port to query remote dns servers on.
    /// </summary>
    private const int DnsPort = 53;
    /// <summary>
    /// Size of the receive buffer: the largest DNS message carried over
    /// plain UDP.
    /// </summary>
    private const int MaxUdpMessageSize = 512;
    /// <summary>
    /// The number of times a server is polled for a response before the
    /// query is considered timed out.
    /// </summary>
    private const int PollAttempts = 1000;
    /// <summary>
    /// The pause, in milliseconds, between two polls.
    /// </summary>
    private const int PollIntervalMilliseconds = 10;
    /// <summary>
    /// The maximum number of CNAME records followed while looking for MX
    /// records; guards against CNAME loops.
    /// </summary>
    private const int MaxCNameDepth = 8;
    /// <summary>
    /// The digits used to build ip6.arpa reverse lookup names.
    /// </summary>
    private const string HexDigits = "0123456789abcdef";
    /// <summary>
    /// A hash table of arrays of MX records indexed by string name
    /// (e.g. an array of MX records for "anu.edu.au").
    /// </summary>
    private readonly Hashtable mxServers;
    /// <summary>
    /// A list of MX cache keys ordered according to last use.
    /// </summary>
    private readonly ArrayList mxQueue;
    /// <summary>
    /// A hash table of PTR record arrays indexed by the string form of
    /// the IP address.
    /// </summary>
    private readonly Hashtable domains;
    /// <summary>
    /// A list of PTR cache keys ordered according to last use.
    /// </summary>
    private readonly ArrayList domainQueue;
    /// <summary>
    /// A hash table of ip address record arrays indexed by string name
    /// (e.g. an array of A records for "www.anu.edu.au").
    /// </summary>
    private readonly Hashtable addresses;
    /// <summary>
    /// A list of address cache keys ordered according to last use.
    /// </summary>
    private readonly ArrayList addressQueue;
    /// <summary>
    /// The maximum number of MX record arrays to store.
    /// </summary>
    private int MXCacheSize = 50;
    /// <summary>
    /// The maximum number of PTR record arrays to store.
    /// </summary>
    private int DomainCacheSize = 50;
    /// <summary>
    /// The maximum number of A record arrays to store.
    /// </summary>
    private int AddressCacheSize = 50;
    /// <summary>
    /// The configuration for this DNS cache.
    /// </summary>
    private NDnsConfiguration config;
    #endregion
    #region Constructor
    /// <summary>
    /// Creates a new instance of the DNS cache.
    /// </summary>
    /// <param name="config">
    /// The configuration supplying the cache sizes and the DNS servers
    /// to query.
    /// </param>
    public DnsCache(NDnsConfiguration config) {
        this.config = config;
        this.AddressCacheSize = this.config.MaximumAddressCacheSize;
        this.DomainCacheSize = this.config.MaximumDomainCacheSize;
        this.MXCacheSize = this.config.MaximumMXCacheSize;
        // setup the MX cache data structures
        this.mxServers = new Hashtable();
        this.mxQueue = new ArrayList();
        // setup the PTR cache data structures
        this.domains = new Hashtable();
        this.domainQueue = new ArrayList();
        // setup the address cache data structures
        this.addresses = new Hashtable();
        this.addressQueue = new ArrayList();
    }
    #endregion
    #region API
    #region DomainMatchesAddress
    /// <summary>
    /// Checks if the given domain and address match.
    /// </summary>
    /// <param name="address">The address to check.</param>
    /// <param name="domain">The domain to check.</param>
    /// <returns>
    /// True if the first hostname found for the address equals the
    /// domain; false otherwise, including when no hostname was found.
    /// </returns>
    public bool DomainMatchesAddress(IPAddress address, string domain) {
        string addrDomain = GetHostName(address);
        // GetHostName yields null when the lookup produced no PTR records
        return addrDomain != null && addrDomain.Equals(domain);
    }
    #endregion
    #region GetHostName
    /// <summary>
    /// Gets a hostname for the specified end point.
    /// </summary>
    /// <param name="endPoint">The end point to lookup; must be an IPEndPoint.</param>
    /// <returns>The hostname, or null if none was found.</returns>
    public string GetHostName(EndPoint endPoint) {
        return GetHostName(((IPEndPoint) endPoint).Address);
    }
    /// <summary>
    /// Gets a hostname for the specified address.
    /// </summary>
    /// <param name="address">The address to lookup</param>
    /// <returns>The first hostname, or null if none was found.</returns>
    public string GetHostName(IPAddress address) {
        string[] hostnames = GetHostNames(address);
        return (hostnames != null && hostnames.Length > 0) ? hostnames[0] : null;
    }
    /// <summary>
    /// Gets the hostnames for the specified address.
    /// </summary>
    /// <param name="address">The address to lookup</param>
    /// <returns>The hostnames, or null if none were found.</returns>
    public string[] GetHostNames(IPAddress address) {
        return ConvertToHostnames(GetPTRRecords(address));
    }
    #endregion
    #region GetIPAddress(es)
    /// <summary>
    /// Gets the first IP address associated with the given domain name.
    /// </summary>
    /// <param name="domainName">The domain name to resolve.</param>
    /// <returns>The IP address, or null if none was found.</returns>
    public IPAddress GetIPAddress(string domainName) {
        IPAddress[] addresses = GetIPAddresses(domainName);
        return (addresses != null && addresses.Length > 0) ? addresses[0] : null;
    }
    /// <summary>
    /// Looks up IP addresses for the given domain name.
    /// </summary>
    /// <param name="domainName">The domain name to lookup the addresses for.</param>
    /// <returns>The IP addresses for the domain name, or null if none were found.</returns>
    public IPAddress[] GetIPAddresses(string domainName) {
        return ConvertToIPAddresses(GetARecords(domainName));
    }
    #endregion
    #region LookupRecords
    /// <summary>
    /// Looks up DNS records using the specified query string and query type.
    /// The result is not cached.
    /// </summary>
    /// <param name="queryString">The query to send.</param>
    /// <param name="queryType">The query type.</param>
    /// <returns>
    /// The records of every answer entry in the response; empty when the
    /// response carries no answers.
    /// </returns>
    public Record[] LookupRecords(string queryString, QType queryType) {
        DNSMessage response = PerformQuery(QueryFactory.CreateQuery(queryString, queryType));
        AnswerSection[] answerEntries = response.AnswerEntries;
        ArrayList records = new ArrayList();
        // a response without an answer section simply yields no records
        if (answerEntries != null) {
            for (int i = 0; i < answerEntries.Length; i++) {
                records.Add(answerEntries[i].Record);
            }
        }
        return (Record[]) records.ToArray(typeof(Record));
    }
    #endregion
    #endregion
    #region Caching
    #region GetPTRRecords
    /// <summary>
    /// Gets the PTR records for the specified IP address caching the results.
    /// A cached entry is looked up again once the TTL of its first record
    /// is no longer positive.
    /// </summary>
    /// <param name="address">The address to lookup</param>
    /// <returns>The PTR records, or null if the lookup found none.</returns>
    public PTRRecord[] GetPTRRecords(IPAddress address) {
        lock (this.domains) {
            string key = address.ToString();
            PTRRecord[] records = (PTRRecord[]) this.domains[key];
            bool stale = records == null || records.Length == 0 || records[0].TTL <= 0;
            if (!stale) {
                MarkRecentlyUsed(this.domainQueue, key);
                return records;
            }
            // Look up before touching the cache so that a failing lookup
            // cannot leave the table and the queue out of step.
            records = LookupPTRRecords(address);
            StoreRecords(this.domains, this.domainQueue, this.DomainCacheSize, key, records);
            return records;
        }
    }
    #endregion
    #region GetARecords
    /// <summary>
    /// Gets address records for the specified hostname caching the results.
    /// A cached entry is looked up again once the TTL of its first record
    /// is no longer positive.
    /// </summary>
    /// <param name="domainName">The hostname to lookup</param>
    /// <returns>The address records, or null if the lookup found none.</returns>
    public ARecord[] GetARecords(string domainName) {
        lock (this.addresses) {
            string key = domainName;
            ARecord[] records = (ARecord[]) this.addresses[key];
            bool stale = records == null || records.Length == 0 || records[0].TTL <= 0;
            if (!stale) {
                MarkRecentlyUsed(this.addressQueue, key);
                return records;
            }
            // Look up before touching the cache so that a failing lookup
            // cannot leave the table and the queue out of step.
            records = LookupARecords(domainName);
            StoreRecords(this.addresses, this.addressQueue, this.AddressCacheSize, key, records);
            return records;
        }
    }
    #endregion
    #region GetMXRecords
    /// <summary>
    /// Looks up MX records for the given domain name caching the results.
    /// A cached entry is looked up again once the TTL of its first record
    /// is no longer positive.
    /// </summary>
    /// <param name="domainName">The domain to get the MX records for.</param>
    /// <returns>The MX records</returns>
    public MXRecord[] GetMXRecords(string domainName) {
        lock (this.mxServers) {
            string key = domainName;
            MXRecord[] records = (MXRecord[]) this.mxServers[key];
            bool stale = records == null || records.Length == 0 || records[0].TTL <= 0;
            if (!stale) {
                MarkRecentlyUsed(this.mxQueue, key);
                return records;
            }
            // Look up before touching the cache so that a failing lookup
            // cannot leave the table and the queue out of step.
            records = LookupMXRecords(domainName);
            StoreRecords(this.mxServers, this.mxQueue, this.MXCacheSize, key, records);
            return records;
        }
    }
    #endregion
    #region Cache bookkeeping
    /// <summary>
    /// Moves the key to the end of the queue, i.e. marks it as the most
    /// recently used entry.
    /// </summary>
    /// <param name="queue">The last-use queue of the cache.</param>
    /// <param name="key">The key that was just used.</param>
    private static void MarkRecentlyUsed(ArrayList queue, string key) {
        queue.Remove(key);
        queue.Add(key);
    }
    /// <summary>
    /// Replaces the cache entry for the key with the given records,
    /// evicting the least recently used entry if the cache is full. Null
    /// records remove the entry instead of being stored, so a later read
    /// never finds a null array in the table.
    /// </summary>
    /// <param name="table">The table holding the record arrays.</param>
    /// <param name="queue">The last-use queue belonging to the table.</param>
    /// <param name="capacity">The maximum number of entries to keep.</param>
    /// <param name="key">The key to store the records under.</param>
    /// <param name="records">The record array to store, or null.</param>
    private static void StoreRecords(Hashtable table, ArrayList queue, int capacity, string key, object records) {
        // drop any previous (expired) entry for this key first
        queue.Remove(key);
        table.Remove(key);
        if (records == null || capacity <= 0) {
            return;
        }
        if (queue.Count >= capacity) {
            // remove the oldest (least recently used) entry
            object oldest = queue[0];
            queue.RemoveAt(0);
            table.Remove(oldest);
        }
        queue.Add(key);
        table[key] = records;
    }
    #endregion
    #endregion
    #region Helper Methods
    #region LookupMXRecords
    /// <summary>
    /// Looks up MX records from the DNS server handling empty answer
    /// sections but not caching records.
    /// </summary>
    /// <param name="domainName">The domain to look up the MX records for.</param>
    /// <returns>The MX records for the given domain; never null.</returns>
    private MXRecord[] LookupMXRecords(string domainName) {
        return LookupMXRecords(domainName, 0);
    }
    /// <summary>
    /// Looks up MX records, following a CNAME answer when the response
    /// contains no MX records.
    /// </summary>
    /// <param name="domainName">The domain to look up the MX records for.</param>
    /// <param name="depth">The number of CNAME records followed so far.</param>
    /// <returns>The MX records for the given domain; never null.</returns>
    private MXRecord[] LookupMXRecords(string domainName, int depth) {
        DNSMessage response = PerformQuery(QueryFactory.CreateQuery(domainName, QType.MX));
        AnswerSection[] answerEntries = response.AnswerEntries;
        if (answerEntries != null && answerEntries.Length > 0) {
            ArrayList records = new ArrayList();
            CNameRecord cname = null;
            for (int i = 0; i < answerEntries.Length; i++) {
                if (answerEntries[i].Record.Type == QType.MX) {
                    records.Add(answerEntries[i].Record);
                } else if (answerEntries[i].Record.Type == QType.CNAME) {
                    cname = (CNameRecord) answerEntries[i].Record;
                }
            }
            if (records.Count > 0) {
                MXRecord[] found = (MXRecord[]) records.ToArray(typeof(MXRecord));
                // Sort the list according to MX preference
                Array.Sort(found);
                return found;
            }
            if (cname != null) {
                if (depth < MaxCNameDepth) {
                    // we only got back a CNAME, look up MX records for that
                    return LookupMXRecords(cname.CName, depth + 1);
                }
                // a CNAME loop would otherwise recurse until the stack overflows
                log.Warn("Too many CNAME records followed looking up MX records for [" + domainName + "]");
            }
        }
        // no records returned, make one up: a single record built from the
        // domain name itself (the 10 is presumably the MX preference)
        MXRecord[] fallback = new MXRecord[1];
        fallback[0] = new MXRecord(domainName, 10);
        return fallback;
    }
    #endregion
    #region LookupPTRRecords
    /// <summary>
    /// Looks up PTR records from the DNS server but does not cache records.
    /// </summary>
    /// <param name="address">The address to look up the PTR records for.</param>
    /// <returns>The PTR records for the given address, or null if there are none.</returns>
    protected PTRRecord[] LookupPTRRecords(IPAddress address) {
        string addressString = BuildReverseLookupName(address);
        DNSMessage response = PerformQuery(QueryFactory.CreateQuery(addressString, QType.PTR));
        AnswerSection[] answerEntries = response.AnswerEntries;
        ArrayList records = new ArrayList();
        // Get all valid PTR records
        if (answerEntries != null) {
            for (int i = 0; i < answerEntries.Length; i++) {
                if (answerEntries[i].Record.Type == QType.PTR) {
                    records.Add(answerEntries[i].Record);
                }
            }
        }
        if (records.Count == 0) {
            return null;
        }
        return (PTRRecord[]) records.ToArray(typeof(PTRRecord));
    }
    /// <summary>
    /// Converts an address to the name used for its reverse lookup:
    /// the reversed octets under in-addr.arpa for IPv4 and the reversed
    /// hexadecimal nibbles under ip6.arpa for IPv6.
    /// </summary>
    /// <param name="address">The address to convert.</param>
    /// <returns>The reverse lookup name.</returns>
    private static string BuildReverseLookupName(IPAddress address) {
        byte[] addressParts = address.GetAddressBytes();
        if (address.AddressFamily == AddressFamily.InterNetworkV6) {
            // two "x." pairs per byte plus the "ip6.arpa" suffix
            System.Text.StringBuilder name = new System.Text.StringBuilder(addressParts.Length * 4 + 8);
            for (int i = addressParts.Length - 1; i >= 0; i--) {
                // low nibble first: the whole address is written backwards
                name.Append(HexDigits[addressParts[i] & 0x0F]).Append('.');
                name.Append(HexDigits[addressParts[i] >> 4]).Append('.');
            }
            name.Append("ip6.arpa");
            return name.ToString();
        }
        return String.Format("{0}.{1}.{2}.{3}.in-addr.arpa",
            addressParts[3], addressParts[2], addressParts[1], addressParts[0]);
    }
    #endregion
    #region LookupARecords
    /// <summary>
    /// Looks up address records from the DNS server but does not cache records.
    /// </summary>
    /// <param name="domainName">The domain name to lookup the addresses for.</param>
    /// <returns>The address records for the given domain, or null if there are none.</returns>
    protected ARecord[] LookupARecords(string domainName) {
        DNSMessage response = PerformQuery(QueryFactory.CreateQuery(domainName, QType.A));
        AnswerSection[] answerEntries = response.AnswerEntries;
        ArrayList records = new ArrayList();
        // Get all valid ARecords
        if (answerEntries != null) {
            for (int i = 0; i < answerEntries.Length; i++) {
                if (answerEntries[i].Record.Type == QType.A) {
                    records.Add(answerEntries[i].Record);
                }
            }
        }
        if (records.Count == 0) {
            return null;
        }
        return (ARecord[]) records.ToArray(typeof(ARecord));
    }
    #endregion
    #region ConvertToIPAddresses
    /// <summary>
    /// Converts an array of address records into an array of IP addresses.
    /// </summary>
    /// <param name="records">The records to convert; may be null.</param>
    /// <returns>The IP addresses, or null if records is null.</returns>
    public IPAddress[] ConvertToIPAddresses(ARecord[] records) {
        if (records == null) {
            return null;
        }
        IPAddress[] result = new IPAddress[records.Length];
        for (int i = 0; i < records.Length; i++) {
            result[i] = records[i].Address;
        }
        return result;
    }
    #endregion
    #region ConvertToHostnames
    /// <summary>
    /// Converts an array of PTR records into an array of hostnames.
    /// </summary>
    /// <param name="records">The records to convert; may be null.</param>
    /// <returns>The hostnames, or null if records is null.</returns>
    public string[] ConvertToHostnames(PTRRecord[] records) {
        if (records == null) {
            return null;
        }
        string[] result = new string[records.Length];
        for (int i = 0; i < records.Length; i++) {
            // NOTE(review): takes the record's Name; confirm that this is the
            // host name pointed to rather than the queried reverse name.
            result[i] = records[i].Name;
        }
        return result;
    }
    #endregion
    #endregion
    #region PerformQuery
    /// <summary>
    /// Queries the configured DNS servers in order using the given DNS
    /// query and returns the first response received. Each server is
    /// polled for up to PollAttempts * PollIntervalMilliseconds before the
    /// next one is tried.
    /// </summary>
    /// <param name="queryMessage">The query to send to the server.</param>
    /// <returns>
    /// The response from the server, or the query message itself if no
    /// server responded.
    /// </returns>
    public DNSMessage PerformQuery(DNSMessage queryMessage) {
        byte[] query = queryMessage.ToByteArray();
        // the using block releases the socket on every exit path
        using (Socket socket = new Socket(
            AddressFamily.InterNetwork,
            SocketType.Dgram,
            ProtocolType.Udp)) {
            foreach (IPAddress dnsServer in this.config.DnsServers) {
                try {
                    // Send the query
                    IPEndPoint serverEndpoint = new IPEndPoint(dnsServer, DnsPort);
                    socket.SendTo(query, serverEndpoint);
                    // Poll for a response
                    bool socketError = false;
                    for (int i = 0; i < PollAttempts; i++) {
                        if (socket.Poll(0, SelectMode.SelectError)) {
                            // Error, stop polling and try the next DNS server in the list.
                            socketError = true;
                            break;
                        }
                        if (socket.Poll(0, SelectMode.SelectRead)) {
                            // Receive into a dedicated buffer: the query array is
                            // only known to be large enough for the query, and a
                            // datagram larger than the buffer makes ReceiveFrom throw.
                            byte[] buffer = new byte[MaxUdpMessageSize];
                            EndPoint senderEndpoint = serverEndpoint;
                            int received = socket.ReceiveFrom(buffer, ref senderEndpoint);
                            byte[] reply = new byte[received];
                            Array.Copy(buffer, 0, reply, 0, received);
                            return new DNSMessage(reply);
                        }
                        Thread.Sleep(PollIntervalMilliseconds);
                    }
                    if (socketError) {
                        log.Warn("Socket error querying DNS server [" + dnsServer + "]");
                    } else {
                        log.Warn("Timed out querying DNS server [" + dnsServer + "]");
                    }
                } catch (SocketException e) {
                    log.Warn("Error querying DNS server [" + dnsServer + "]: " + e.Message);
                } catch (Exception e) {
                    log.Error("Unexpected exception occurred querying DNS server [" + dnsServer + "]: " + e.Message);
                    log.Debug(e.Message, e);
                }
            }
        }
        log.Error("Error no DNS servers could be reached.");
        // No response, return query
        return queryMessage;
    }
    #endregion
}
}
|