EonaCat.DnsTester/EonaCat.DnsTester/Helpers/DnsHelper.cs

333 lines
12 KiB
C#

using System;
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Net.Sockets;
using System.Text;
using System.Threading.Tasks;
namespace EonaCat.DnsTester.Helpers
{
internal class DnsHelper
{
    /// <summary>
    /// Raised when a single answer, authority or additional record cannot be parsed.
    /// The sender is always <c>null</c>; the argument is a human-readable message.
    /// </summary>
    public static event EventHandler<string> OnLog;

    // Source of DNS transaction IDs for CreateDnsQueryPacket.
    private static readonly Random random = new Random();

    /// <summary>
    /// Sends a pre-built DNS query to <paramref name="server"/> over UDP and parses the reply.
    /// When <see cref="FakeResponse"/> is set, nothing is sent and
    /// <see cref="GetExampleResponse"/> is parsed instead.
    /// </summary>
    /// <param name="dnsId">Caller-supplied identifier copied into the returned response.</param>
    /// <param name="server">Resolver IP address literal (parsed with <see cref="IPAddress.Parse(string)"/>).</param>
    /// <param name="port">Resolver UDP port.</param>
    /// <param name="queryBytes">Wire-format query, e.g. produced by <see cref="CreateDnsQueryPacket"/>.</param>
    /// <returns>The parsed response.</returns>
    /// <exception cref="FormatException"><paramref name="server"/> is not a valid IP address.</exception>
    /// <exception cref="InvalidDataException">The reply is not a well-formed DNS response.</exception>
    public static async Task<DnsResponse> SendDnsQueryPacket(string dnsId, string server, int port, byte[] queryBytes)
    {
        var stopwatch = System.Diagnostics.Stopwatch.StartNew();
        var endPoint = new IPEndPoint(IPAddress.Parse(server), port);
        using (var client = new UdpClient(endPoint.AddressFamily))
        {
            client.DontFragment = true;
            client.EnableBroadcast = false;

            // NOTE(review): Socket.SendTimeout/ReceiveTimeout are in milliseconds and only apply to
            // synchronous Send/Receive calls. SendAsync/ReceiveAsync below do not observe them, so a
            // lost reply leaves this method awaiting indefinitely. TODO: settle the intended unit of
            // DnsSendTimeout/DnsReceiveTimeout and enforce the receive timeout explicitly.
            client.Client.SendTimeout = DnsSendTimeout;
            client.Client.ReceiveTimeout = DnsReceiveTimeout;

            byte[] responseBytes;
            if (FakeResponse)
            {
                responseBytes = GetExampleResponse();
            }
            else
            {
                await client.SendAsync(queryBytes, queryBytes.Length, endPoint);
                var responseResult = await client.ReceiveAsync();
                responseBytes = responseResult.Buffer;
            }

            // ElapsedTicks is the time since StartNew() above, in Stopwatch ticks
            // (Stopwatch.Frequency ticks per second), not a start timestamp.
            return ParseDnsResponsePacket(dnsId, stopwatch.ElapsedTicks, server, responseBytes);
        }
    }

    /// <summary>
    /// For testing purposes: when true, <see cref="SendDnsQueryPacket"/> parses
    /// <see cref="GetExampleResponse"/> instead of using the network.
    /// </summary>
    public static bool FakeResponse { get; set; }

    /// <summary>Value assigned to the UDP socket's SendTimeout (a millisecond-based socket option).</summary>
    public static int DnsSendTimeout { get; set; } = 5;

    /// <summary>Value assigned to the UDP socket's ReceiveTimeout (a millisecond-based socket option).</summary>
    public static int DnsReceiveTimeout { get; set; } = 5;

    /// <summary>
    /// Builds a wire-format DNS query (one question, class IN, recursion desired) with a random
    /// transaction ID. All multi-byte fields are written in network (big-endian) byte order.
    /// </summary>
    /// <param name="domainName">Domain to query, e.g. "example.com". Empty labels (such as the one
    /// produced by a trailing dot) are skipped.</param>
    /// <param name="recordType">QTYPE to request.</param>
    /// <returns>The encoded query packet.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="domainName"/> is null.</exception>
    /// <exception cref="ArgumentException">A label is longer than 63 bytes.</exception>
    public static byte[] CreateDnsQueryPacket(string domainName, DnsRecordType recordType)
    {
        if (domainName == null)
        {
            throw new ArgumentNullException(nameof(domainName));
        }

        var id = (ushort)random.Next(0, 65536);
        const ushort flags = 0x0100; // recursion desired

        using (var stream = new MemoryStream())
        using (var writer = new BinaryWriter(stream))
        {
            // BinaryWriter.Write(ushort) is always little-endian, but DNS requires big-endian,
            // so every 16-bit field goes through WriteUInt16BigEndian.
            WriteUInt16BigEndian(writer, id);
            WriteUInt16BigEndian(writer, flags);
            WriteUInt16BigEndian(writer, 1); // QDCOUNT
            WriteUInt16BigEndian(writer, 0); // ANCOUNT
            WriteUInt16BigEndian(writer, 0); // NSCOUNT
            WriteUInt16BigEndian(writer, 0); // ARCOUNT

            foreach (var label in domainName.Split('.'))
            {
                if (label.Length == 0)
                {
                    // A zero-length label would terminate the name early.
                    continue;
                }

                var labelBytes = Encoding.ASCII.GetBytes(label);
                if (labelBytes.Length > 63)
                {
                    // Lengths >= 64 would be read back as compression pointers / reserved values.
                    throw new ArgumentException($"DNS label '{label}' exceeds 63 bytes.", nameof(domainName));
                }

                writer.Write((byte)labelBytes.Length);
                writer.Write(labelBytes);
            }

            writer.Write((byte)0); // Null terminator (root label)
            WriteUInt16BigEndian(writer, (ushort)recordType);
            WriteUInt16BigEndian(writer, 1); // Record class: IN (Internet)
            writer.Flush();
            return stream.ToArray();
        }
    }

    /// <summary>
    /// Returns a canned response: an A record for google.com (172.217.3.61, TTL 61 s)
    /// with query ID 0x9da9. Used when <see cref="FakeResponse"/> is true.
    /// </summary>
    public static byte[] GetExampleResponse()
    {
        // Example response bytes for the A record of google.com
        return new byte[]
        {
            0x9d, 0xa9, // Query ID
            0x81, 0x80, // Flags
            0x00, 0x01, // Questions: 1
            0x00, 0x01, // Answer RRs: 1
            0x00, 0x00, // Authority RRs: 0
            0x00, 0x00, // Additional RRs: 0
            0x06, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, // Query: google
            0x03, 0x63, 0x6f, 0x6d, // Query: com
            0x00,
            0x00, 0x01, // Record type: A
            0x00, 0x01, // Record class: IN
            0xc0, 0x0c, // Name pointer to google.com
            0x00, 0x01, // Record type: A
            0x00, 0x01, // Record class: IN
            0x00, 0x00, 0x00, 0x3d, // TTL: 61 seconds
            0x00, 0x04, // Data length: 4 bytes
            0xac, 0xd9, 0x03, 0x3d // Data: 172.217.3.61
        };
    }

    /// <summary>
    /// Parses a complete DNS response message: the 12-byte header, then the question, answer,
    /// authority and additional sections. Records that fail to parse are reported through
    /// <see cref="OnLog"/> and skipped.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="responseBytes"/> is null.</exception>
    /// <exception cref="InvalidDataException">The header is truncated, the QR bit is clear,
    /// or the question section is truncated.</exception>
    private static DnsResponse ParseDnsResponsePacket(string dnsId, long startTime, string server, byte[] responseBytes)
    {
        if (responseBytes == null)
        {
            throw new ArgumentNullException(nameof(responseBytes));
        }

        if (responseBytes.Length < 12)
        {
            throw new InvalidDataException("Invalid DNS response");
        }

        var offset = 0;
        var id = ReadUInt16BigEndian(responseBytes, ref offset);
        var flags = ReadUInt16BigEndian(responseBytes, ref offset);

        // QR bit: set on responses, clear on queries.
        if ((flags & 0x8000) == 0)
        {
            throw new InvalidDataException("Invalid DNS response");
        }

        // Header counts, in wire order: QDCOUNT, ANCOUNT, NSCOUNT, ARCOUNT.
        var qdcount = ReadUInt16BigEndian(responseBytes, ref offset);
        var ancount = ReadUInt16BigEndian(responseBytes, ref offset);
        var nscount = ReadUInt16BigEndian(responseBytes, ref offset);
        var arcount = ReadUInt16BigEndian(responseBytes, ref offset);

        var questions = new List<DnsQuestion>();
        for (int i = 0; i < qdcount; i++)
        {
            var question = ParseDnsQuestionRecord(responseBytes, ref offset);
            if (question != null)
            {
                questions.Add(question);
            }
        }

        var answers = ParseResourceRecordSection(responseBytes, ref offset, ancount, "Answer");
        var authorities = ParseResourceRecordSection(responseBytes, ref offset, nscount, "Authority answer");
        var additionals = ParseResourceRecordSection(responseBytes, ref offset, arcount, "Additional answer");

        return new DnsResponse
        {
            Id = id,
            StartTime = startTime,
            Resolver = server,
            Flags = flags,
            // NOTE(review): bits 3-6 of the flags word are not a record class (the header carries
            // none); kept unchanged for compatibility - confirm what this field should represent.
            Class = (DnsRecordClass)((flags >> 3) & 0x0f),
            DnsId = dnsId,
            Questions = questions,
            Answers = answers,
            Authorities = authorities,
            Additionals = additionals,
        };
    }

    // Parses `count` consecutive resource records starting at `offset`. A record that throws is
    // logged as "<sectionLabel> exception: <message>" and skipped (best-effort parsing).
    private static List<ResourceRecord> ParseResourceRecordSection(byte[] responseBytes, ref int offset, int count, string sectionLabel)
    {
        var records = new List<ResourceRecord>();
        for (int i = 0; i < count; i++)
        {
            try
            {
                var record = ParseDnsAnswerRecord(responseBytes, ref offset);
                if (record != null)
                {
                    records.Add(record);
                }
            }
            catch (Exception exception)
            {
                OnLog?.Invoke(null, $"{sectionLabel} exception: {exception.Message}");
            }
        }

        return records;
    }

    // Parses one question entry: name, QTYPE, QCLASS. Returns null if the name parser does.
    private static DnsQuestion ParseDnsQuestionRecord(byte[] queryBytes, ref int offset)
    {
        var name = DnsNameParser.ParseName(queryBytes, ref offset);
        if (name == null)
        {
            return null;
        }

        var type = (DnsRecordType)ReadUInt16BigEndian(queryBytes, ref offset);
        var qclass = (DnsRecordClass)ReadUInt16BigEndian(queryBytes, ref offset);
        return new DnsQuestion
        {
            Name = name,
            Type = type,
            Class = qclass,
        };
    }

    /// <summary>
    /// Parses one resource record. A, AAAA, CNAME, NS, MX and TXT data are rendered as text;
    /// other types keep <c>Data = null</c>. On return (including via exception after the
    /// fixed fields were read) <paramref name="offset"/> points just past the record's RDATA,
    /// so one malformed record does not desynchronize the ones that follow.
    /// </summary>
    /// <returns>The record, or null when the owner name cannot be parsed or an A record's
    /// data length is not 4.</returns>
    private static ResourceRecord ParseDnsAnswerRecord(byte[] responseBytes, ref int offset)
    {
        var name = DnsNameParser.ExtractDomainName(responseBytes, ref offset);
        if (name == null)
        {
            return null;
        }

        var type = (DnsRecordType)ReadUInt16BigEndian(responseBytes, ref offset);
        var klass = (DnsRecordClass)ReadUInt16BigEndian(responseBytes, ref offset);
        var rawTtl = ReadUInt32BigEndian(responseBytes, ref offset);
        var dataLength = ReadUInt16BigEndian(responseBytes, ref offset);

        var dataStart = offset;
        var dataEnd = dataStart + dataLength;
        if (dataEnd > responseBytes.Length)
        {
            throw new InvalidDataException("Resource record data extends past the end of the DNS response");
        }

        // RFC 2181 section 8: a TTL with the most significant bit set is treated as zero.
        var ttlSeconds = (rawTtl & 0x80000000u) != 0 ? 0u : rawTtl;

        string dataAsString = null;
        try
        {
            switch (type)
            {
                case DnsRecordType.A:
                    if (dataLength != 4)
                    {
                        return null;
                    }

                    // IPAddress(byte[]) with exactly 4 bytes yields an IPv4 address.
                    dataAsString = new IPAddress(CopyBytes(responseBytes, dataStart, 4)).ToString();
                    break;
                case DnsRecordType.AAAA:
                    if (dataLength == 16)
                    {
                        dataAsString = new IPAddress(CopyBytes(responseBytes, dataStart, 16)).ToString();
                    }
                    break;
                case DnsRecordType.CNAME:
                case DnsRecordType.NS:
                    dataAsString = DnsNameParser.ExtractDomainName(responseBytes, ref offset);
                    break;
                case DnsRecordType.MX:
                    var preference = ReadUInt16BigEndian(responseBytes, ref offset);
                    var exchange = DnsNameParser.ExtractDomainName(responseBytes, ref offset);
                    dataAsString = $"{preference} {exchange}";
                    break;
                case DnsRecordType.TXT:
                    dataAsString = DecodeTxtData(responseBytes, dataStart, dataEnd);
                    break;
            }
        }
        finally
        {
            // Always resume at the end of RDATA, whatever the per-type decoder consumed.
            offset = dataEnd;
        }

        return new ResourceRecord
        {
            Name = name,
            Type = type,
            Class = klass,
            Ttl = TimeSpan.FromSeconds(ttlSeconds),
            Data = dataAsString,
            DataLength = dataLength,
        };
    }

    // TXT RDATA is a sequence of <length byte><bytes> character-strings (RFC 1035 3.3.14);
    // they are decoded as ASCII and concatenated without a separator.
    private static string DecodeTxtData(byte[] buffer, int start, int end)
    {
        var text = new StringBuilder();
        var position = start;
        while (position < end)
        {
            int length = buffer[position++];
            // Clamp a length byte that overruns RDATA instead of reading into the next record.
            length = Math.Min(length, end - position);
            text.Append(Encoding.ASCII.GetString(buffer, position, length));
            position += length;
        }

        return text.ToString();
    }

    // Returns a new array holding `count` bytes of `source` starting at `start`.
    private static byte[] CopyBytes(byte[] source, int start, int count)
    {
        var copy = new byte[count];
        Array.Copy(source, start, copy, 0, count);
        return copy;
    }

    // Reads a big-endian 16-bit value at `offset` and advances it by 2.
    private static ushort ReadUInt16BigEndian(byte[] buffer, ref int offset)
    {
        if (offset + 2 > buffer.Length)
        {
            throw new InvalidDataException("Unexpected end of DNS response");
        }

        var value = (ushort)((buffer[offset] << 8) | buffer[offset + 1]);
        offset += 2;
        return value;
    }

    // Reads a big-endian 32-bit value at `offset` and advances it by 4.
    private static uint ReadUInt32BigEndian(byte[] buffer, ref int offset)
    {
        if (offset + 4 > buffer.Length)
        {
            throw new InvalidDataException("Unexpected end of DNS response");
        }

        var value = ((uint)buffer[offset] << 24)
                    | ((uint)buffer[offset + 1] << 16)
                    | ((uint)buffer[offset + 2] << 8)
                    | buffer[offset + 3];
        offset += 4;
        return value;
    }

    // Writes a 16-bit value in network (big-endian) byte order.
    private static void WriteUInt16BigEndian(BinaryWriter writer, ushort value)
    {
        writer.Write((byte)(value >> 8));
        writer.Write((byte)(value & 0xFF));
    }
}
/// <summary>
/// One entry of a DNS message's question section: the queried name, type and class,
/// as read back from a response by DnsHelper.
/// </summary>
public class DnsQuestion
{
/// <summary>Queried domain name, as returned by the DNS name parser.</summary>
public string Name { get; set; }
/// <summary>Query type (QTYPE field).</summary>
public DnsRecordType Type { get; set; }
/// <summary>Query class (QCLASS field).</summary>
public DnsRecordClass Class { get; set; }
}
/// <summary>
/// A resource record parsed from the answer, authority or additional section of a DNS response.
/// </summary>
public class ResourceRecord
{
/// <summary>Owner name of the record.</summary>
public string Name { get; set; }
/// <summary>Record type (TYPE field).</summary>
public DnsRecordType Type { get; set; }
/// <summary>
/// Text form of the record data as decoded by DnsHelper; null when the parser does not
/// decode this record type.
/// </summary>
public string Data { get; set; }
/// <summary>Record class (CLASS field).</summary>
public DnsRecordClass Class { get; set; }
/// <summary>Time-to-live, built from the record's TTL field interpreted as seconds.</summary>
public TimeSpan Ttl { get; set; }
/// <summary>Length in bytes of the raw record data (RDLENGTH field).</summary>
public ushort DataLength { get; set; }
}
/// <summary>
/// A parsed DNS response plus the test metadata (resolver, caller id, timing) recorded by
/// DnsHelper.SendDnsQueryPacket.
/// </summary>
public class DnsResponse
{
/// <summary>Transaction ID from the response header.</summary>
public ushort Id { get; set; }
/// <summary>The 16-bit flags word from the response header; bit 0x8000 (QR) is set on responses.</summary>
public ushort Flags { get; set; }
/// <summary>
/// Value the parser derives from bits of the header flags word.
/// NOTE(review): the DNS header carries no record class - confirm what this is meant to represent.
/// </summary>
public DnsRecordClass Class { get; set; }
/// <summary>Records parsed for the answer section.</summary>
public List<ResourceRecord> Answers { get; set; }
/// <summary>
/// Stopwatch ticks elapsed between the start of SendDnsQueryPacket and parsing of the reply
/// (see System.Diagnostics.Stopwatch.Frequency); despite its name this is a duration, not a timestamp.
/// </summary>
public long StartTime { get; set; }
/// <summary>Resolver address string the query was sent to.</summary>
public string Resolver { get; set; }
/// <summary>Caller-supplied identifier, passed through unchanged.</summary>
public string DnsId { get; set; }
/// <summary>Entries parsed from the question section.</summary>
public List<DnsQuestion> Questions { get; set; }
/// <summary>Records parsed for the authority section.</summary>
public List<ResourceRecord> Authorities { get; set; }
/// <summary>Records parsed for the additional section.</summary>
public List<ResourceRecord> Additionals { get; set; }
}
/// <summary>
/// DNS resource record TYPE values (RFC 1035 section 3.2.2; AAAA from RFC 3596).
/// </summary>
public enum DnsRecordType : ushort
{
    /// <summary>IPv4 host address.</summary>
    A = 1,
    /// <summary>Authoritative name server.</summary>
    NS = 2,
    /// <summary>Canonical name (alias). RFC 1035 assigns 5; 6 is SOA.</summary>
    CNAME = 5,
    /// <summary>Start of a zone of authority.</summary>
    SOA = 6,
    /// <summary>Mail exchange.</summary>
    MX = 15,
    /// <summary>Text strings.</summary>
    TXT = 16,
    /// <summary>IPv6 host address.</summary>
    AAAA = 28,
}
/// <summary>
/// DNS CLASS / QCLASS values carried in questions and resource records
/// (RFC 1035 sections 3.2.4-3.2.5; NONE from RFC 2136).
/// </summary>
public enum DnsRecordClass : ushort
{
    /// <summary>IN - the Internet.</summary>
    Internet = 0x0001,
    /// <summary>CS - CSNET (obsolete).</summary>
    CS = 0x0002,
    /// <summary>CH - CHAOS.</summary>
    CH = 0x0003,
    /// <summary>HS - Hesiod.</summary>
    HS = 0x0004,
    /// <summary>NONE - used in dynamic update messages.</summary>
    None = 0x00FE,
    /// <summary>ANY - query class matching any class.</summary>
    ANY = 0x00FF,
}
}