Skip to content

Commit

Permalink
Proof of concept AXFR (for the dig format only) and TSIG support.
Browse files Browse the repository at this point in the history
  • Loading branch information
ximon18 committed Jul 25, 2024
1 parent 7098eb2 commit d970e03
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 36 deletions.
9 changes: 4 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ exclude = [ ".github", ".gitignore" ]
bytes = "1"
clap = { version = "4", features = ["derive", "unstable-doc"] }
chrono = { version = "0.4.38", features = [ "alloc", "clock" ] }
domain = { version = "0.10", features = ["resolv", "unstable-client-transport"]}
domain = { version = "0.10.1", features = [ "resolv", "tsig", "unstable-client-transport" ], git = "https://github.com/NLnetLabs/domain", branch = "xfr" }
tempfile = "3.1.0"
tokio = { version = "1.33", features = ["rt-multi-thread"] }
tokio-rustls = { version = "0.26.0", default-features = false, features = [ "ring", "logging", "tls12" ] }
Expand Down
110 changes: 83 additions & 27 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@ use domain::base::message_builder::MessageBuilder;
use domain::base::name::ToName;
use domain::base::question::Question;
use domain::net::client::protocol::UdpConnect;
use domain::net::client::request::{RequestMessage, SendRequest};
use domain::net::client::{dgram, stream};
use domain::net::client::request::{
ComposeRequest, GetResponse, RequestMessage, SendRequest,
};
use domain::net::client::{dgram, stream, tsig, xfr};
use domain::resolv::stub::conf;
use domain::tsig::Key;
use std::fmt;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpStream;
Expand Down Expand Up @@ -53,6 +57,7 @@ impl Client {
pub async fn query<N: ToName, Q: Into<Question<N>>>(
&self,
question: Q,
tsig_key: Option<Key>,
) -> Result<Answer, Error> {
let mut res = MessageBuilder::new_vec();

Expand All @@ -62,16 +67,20 @@ impl Client {
let mut res = res.question();
res.push(question.into()).unwrap();

self.request(RequestMessage::new(res)).await
self.request(RequestMessage::new(res), tsig_key).await
}

pub async fn request(
&self,
request: RequestMessage<Vec<u8>>,
tsig_key: Option<Key>,
) -> Result<Answer, Error> {
let mut servers = self.servers.as_slice();
while let Some((server, tail)) = servers.split_first() {
match self.request_server(request.clone(), server).await {
match self
.request_server(request.clone(), tsig_key.clone(), server)
.await
{
Ok(answer) => return Ok(answer),
Err(err) => {
if tail.is_empty() {
Expand All @@ -87,24 +96,56 @@ impl Client {
pub async fn request_server(
&self,
request: RequestMessage<Vec<u8>>,
tsig_key: Option<Key>,
server: &Server,
) -> Result<Answer, Error> {
match server.transport {
Transport::Udp => self.request_udp(request, server).await,
Transport::UdpTcp => self.request_udptcp(request, server).await,
Transport::Tcp => self.request_tcp(request, server).await,
Transport::Tls => self.request_tls(request, server).await,
Transport::Udp => {
self.request_udp(request, tsig_key, server).await
}
Transport::UdpTcp => {
self.request_udptcp(request, tsig_key, server).await
}
Transport::Tcp => {
self.request_tcp(request, tsig_key, server).await
}
Transport::Tls => {
self.request_tls(request, tsig_key, server).await
}
}
}

async fn finalize_request(
mut send_request: Box<dyn GetResponse>,
mut stats: Stats,
streaming: bool,
) -> Result<Answer, Error> {
let mut msgs = Vec::with_capacity(1);
while !send_request.is_stream_complete() {
msgs.push(send_request.get_response().await?);
if !streaming {
break;
}
}
stats.finalize();
Ok(Answer {
msgs,
stats,
cur_idx: Default::default(),
})
}

pub async fn request_udptcp(
&self,
request: RequestMessage<Vec<u8>>,
tsig_key: Option<Key>,
server: &Server,
) -> Result<Answer, Error> {
let answer = self.request_udp(request.clone(), server).await?;
if answer.message.header().tc() {
self.request_tcp(request, server).await
let answer = self
.request_udp(request.clone(), tsig_key.clone(), server)
.await?;
if answer.message().header().tc() {
self.request_tcp(request, tsig_key, server).await
} else {
Ok(answer)
}
Expand All @@ -113,38 +154,45 @@ impl Client {
pub async fn request_udp(
&self,
request: RequestMessage<Vec<u8>>,
tsig_key: Option<Key>,
server: &Server,
) -> Result<Answer, Error> {
let mut stats = Stats::new(server.addr, Protocol::Udp);
let stats = Stats::new(server.addr, Protocol::Udp);
let conn = dgram::Connection::with_config(
UdpConnect::new(server.addr),
Self::dgram_config(server),
);
let message = conn.send_request(request).get_response().await?;
stats.finalize();
Ok(Answer { message, stats })
let conn =
tsig::Connection::new(tsig_key, xfr::Connection::new(conn));
let streaming = request.is_streaming();
let send_request = conn.send_request(request);
Self::finalize_request(send_request, stats, streaming).await
}

pub async fn request_tcp(
&self,
request: RequestMessage<Vec<u8>>,
tsig_key: Option<Key>,
server: &Server,
) -> Result<Answer, Error> {
let mut stats = Stats::new(server.addr, Protocol::Tcp);
let stats = Stats::new(server.addr, Protocol::Tcp);
let socket = TcpStream::connect(server.addr).await?;
let (conn, tran) = stream::Connection::with_config(
socket,
Self::stream_config(server),
);
let conn =
tsig::Connection::new(tsig_key, xfr::Connection::new(conn));
tokio::spawn(tran.run());
let message = conn.send_request(request).get_response().await?;
stats.finalize();
Ok(Answer { message, stats })
let streaming = request.is_streaming();
let send_request = conn.send_request(request);
Self::finalize_request(send_request, stats, streaming).await
}

pub async fn request_tls(
&self,
request: RequestMessage<Vec<u8>>,
tsig_key: Option<Key>,
server: &Server,
) -> Result<Answer, Error> {
let root_store = RootCertStore {
Expand All @@ -156,7 +204,7 @@ impl Client {
.with_no_client_auth(),
);

let mut stats = Stats::new(server.addr, Protocol::Tls);
let stats = Stats::new(server.addr, Protocol::Tls);
let tcp_socket = TcpStream::connect(server.addr).await?;
let tls_connector = tokio_rustls::TlsConnector::from(client_config);
let server_name = server
Expand All @@ -174,10 +222,12 @@ impl Client {
tls_socket,
Self::stream_config(server),
);
let conn =
tsig::Connection::new(tsig_key, xfr::Connection::new(conn));
tokio::spawn(tran.run());
let message = conn.send_request(request).get_response().await?;
stats.finalize();
Ok(Answer { message, stats })
let streaming = request.is_streaming();
let send_request = conn.send_request(request);
Self::finalize_request(send_request, stats, streaming).await
}

fn dgram_config(server: &Server) -> dgram::Config {
Expand Down Expand Up @@ -230,8 +280,9 @@ impl From<conf::Transport> for Transport {

/// An answer for a query.
pub struct Answer {
message: Message<Bytes>,
msgs: Vec<Message<Bytes>>,
stats: Stats,
cur_idx: AtomicUsize,
}

impl Answer {
Expand All @@ -240,17 +291,22 @@ impl Answer {
}

pub fn message(&self) -> &Message<Bytes> {
&self.message
&self.msgs[self.cur_idx.load(Ordering::SeqCst)]
}

pub fn msg_slice(&self) -> Message<&[u8]> {
self.message.for_slice_ref()
self.msgs[self.cur_idx.load(Ordering::SeqCst)].for_slice_ref()
}

pub fn has_next(&self) -> bool {
let old_cur_idx = self.cur_idx.fetch_add(1, Ordering::SeqCst);
(old_cur_idx + 1) < self.msgs.len()
}
}

impl AsRef<Message<Bytes>> for Answer {
fn as_ref(&self) -> &Message<Bytes> {
&self.message
&self.msgs[self.cur_idx.load(Ordering::SeqCst)]
}
}

Expand Down
53 changes: 50 additions & 3 deletions src/commands/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ use domain::net::client::request::{ComposeRequest, RequestMessage};
use domain::rdata::{AllRecordData, Ns, Soa};
use domain::resolv::stub::conf::ResolvConf;
use domain::resolv::stub::StubResolver;
use domain::tsig::{Algorithm, Key, KeyName};
use domain::utils::base64;
use std::collections::HashSet;
use std::fmt;
use std::net::{IpAddr, SocketAddr};
Expand All @@ -28,7 +30,7 @@ pub struct Query {
qname: NameOrAddr,

/// The record type to look up
#[arg(value_name = "QUERY_TYPE")]
#[arg(value_name = "QUERY_TYPE", default_value = "AAAA or PTR")]
qtype: Option<Rtype>,

/// The server to send the query to. System servers used if missing
Expand Down Expand Up @@ -114,6 +116,10 @@ pub struct Query {
#[arg(long = "no-rd")]
no_rd: bool,

/// TSIG signing key to use: <name>:[<alg>]:<base64 key>
#[arg(long = "tsig-key")]
tsig_key: Option<String>,

// No need to set the TC flag in the request.
/// Disable all sanity checks.
#[arg(long, short = 'f')]
Expand Down Expand Up @@ -181,7 +187,13 @@ impl Query {
}
};

let answer = client.request(self.create_request()).await?;
let tsig_key = if let Some(key_str) = &self.tsig_key {
key_from_str(key_str)?
} else {
None
};

let answer = client.request(self.create_request(), tsig_key).await?;
self.output.format.print(&answer)?;
if self.verify {
let auth_answer = self.auth_answer().await?;
Expand All @@ -202,6 +214,41 @@ impl Query {
}
}

fn key_from_str(key_str: &str) -> Result<Option<Key>, Error> {
let key_parts = key_str
.split(':')
.map(ToString::to_string)
.collect::<Vec<String>>();
if key_parts.len() < 2 {
return Err(
"--tsig-key format error: value should be colon ':' separated"
.into(),
);
}
let key_name = key_parts[0].trim_matches('"');
let (alg, base64) = match key_parts.len() {
2 => (Algorithm::Sha256, key_parts[1].clone()),
3 => {
let alg = Algorithm::from_str(&key_parts[1])
.map_err(|_| format!("--tsig-key format error: '{}' is not a valid TSIG algorithm", key_parts[1]))?;
(alg, key_parts[2].clone())
}
_ => return Err(
"--tsig-key format error: should be <name>:[<alg>]:<base64 key>"
.into(),
),
};
let key_name = KeyName::from_str(key_name).map_err(|err| {
format!("--tsig-key format error: '{key_name}' is not a valid key name: {err}")
})?;
let secret = base64::decode::<Vec<u8>>(&base64).map_err(|err| {
format!("--tsig-key format error: base64 decoding error: {err}")
})?;
let key = Key::new(alg, &secret, key_name, None, None)
.map_err(|err| format!("--tsig-key format error: {err}"))?;
Ok(Some(key))
}

/// # Configuration
///
impl Query {
Expand Down Expand Up @@ -339,7 +386,7 @@ impl Query {
self.get_ns_addrs(&ns_set, &resolver).await?
};
Client::with_servers(servers)
.query((self.qname.to_name(), self.qtype()))
.query((self.qname.to_name(), self.qtype()), None)
.await
}

Expand Down
8 changes: 8 additions & 0 deletions src/output/dig.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,14 @@ pub fn write(
for item in section {
write_record_item(target, &item?)?;
}

while answer.has_next() {
let msg = &mut answer.msg_slice();
let section = msg.answer().unwrap();
for item in section {
write_record_item(target, &item?)?;
}
}
}

// Authority
Expand Down

0 comments on commit d970e03

Please sign in to comment.