From fee28fe73f8e89ede401edf9e427bab360ffbe0b Mon Sep 17 00:00:00 2001 From: Bernd Schoolmann Date: Sun, 5 Jul 2026 02:20:42 +0900 Subject: [PATCH] Implement ML-DSA --- Cargo.lock | 77 ++++++++++ ssh-key/Cargo.toml | 4 +- ssh-key/src/algorithm.rs | 166 ++++++++++++++++++++ ssh-key/src/lib.rs | 2 +- ssh-key/src/private.rs | 14 ++ ssh-key/src/private/keypair.rs | 44 +++++- ssh-key/src/private/mldsa.rs | 271 +++++++++++++++++++++++++++++++++ ssh-key/src/public.rs | 10 ++ ssh-key/src/public/key_data.rs | 40 ++++- ssh-key/src/public/mldsa.rs | 135 ++++++++++++++++ ssh-key/src/signature.rs | 92 +++++++++++ 11 files changed, 851 insertions(+), 4 deletions(-) create mode 100644 ssh-key/src/private/mldsa.rs create mode 100644 ssh-key/src/public/mldsa.rs diff --git a/Cargo.lock b/Cargo.lock index 5dd5739..4511c62 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -439,6 +439,7 @@ version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "818356c5132c1fede50f837ca96afbe78ff42413047f4abb886217845e1b6c8c" dependencies = [ + "ctutils", "subtle", "typenum", "zeroize", @@ -453,6 +454,16 @@ dependencies = [ "hybrid-array", ] +[[package]] +name = "keccak" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e24a010dd405bd7ed803e5253182815b41bf2e6a80cc3bfc066658e03a198aa" +dependencies = [ + "cfg-if", + "cpufeatures", +] + [[package]] name = "libc" version = "0.2.186" @@ -465,6 +476,34 @@ version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" +[[package]] +name = "ml-dsa" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "add6b9d92e496f16f4526d68ff29da1483aba4b119baeab8bed3b9e3544a6f3d" +dependencies = [ + "crypto-common", + "ctutils", + "hybrid-array", + "module-lattice", + "pkcs8", + "shake", + "signature", + "zeroize", +] + +[[package]] +name = "module-lattice" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c61b87c9683ab7cb1c6871d261ad5479b6b10ceb52c4352aaca3b5d35a8febe" +dependencies = [ + "ctutils", + "hybrid-array", + "num-traits", + "zeroize", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -552,6 +591,16 @@ dependencies = [ "ctutils", ] +[[package]] +name = "pkcs8" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "451913da69c775a56034ea8d9003d27ee8948e12443eae7c038ba100a4f21cb7" +dependencies = [ + "der", + "spki", +] + [[package]] name = "poly1305" version = "0.9.0" @@ -745,6 +794,17 @@ dependencies = [ "digest", ] +[[package]] +name = "shake" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09057cb2149ad4cbd2da1e26b351f9a4c354219421229c69c3063e6f61947c4a" +dependencies = [ + "digest", + "keccak", + "sponge-cursor", +] + [[package]] name = "signature" version = "3.0.0" @@ -755,6 +815,22 @@ dependencies = [ "rand_core", ] +[[package]] +name = "spki" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d9efca8738c78ee9484207732f728b1ef517bbb1833d6fc0879ca898a522f6f" +dependencies = [ + "base64ct", + "der", +] + +[[package]] +name = "sponge-cursor" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a0219bd7d979d58245a4f41f695e1ac9f8befdffadd7f61f1bae9e39abc6620" + [[package]] name = "ssh-cipher" version = "0.3.0" @@ -810,6 +886,7 @@ dependencies = [ "hex", "hex-literal", "hmac", + "ml-dsa", "p256", "p384", "p521", diff --git a/ssh-key/Cargo.toml b/ssh-key/Cargo.toml index ed7683d..eec6ab7 100644 --- a/ssh-key/Cargo.toml +++ b/ssh-key/Cargo.toml @@ -32,6 +32,7 @@ dsa = { version = "0.7.0-rc.16", optional = true, default-features = false, feat ed25519-dalek = { version = "3.0.0-rc.1", optional = true, default-features = false } hex = { version = "0.4", optional = true, default-features = false, features = ["alloc"] } hmac = { version = "0.13", optional = true } +ml-dsa = { version = "0.1", optional = true, default-features = false, features = ["alloc", "zeroize"] } p256 = { version = "0.14.0-rc.15", optional = true, default-features = false, features = ["ecdsa"] } p384 = { version = "0.14.0-rc.15", optional = true, default-features = false, features = ["ecdsa"] } p521 = { version = "0.14.0-rc.15", optional = true, default-features = false, features = ["ecdsa"] } @@ -51,7 +52,7 @@ default = ["ecdsa", "rand_core", "std"] alloc = ["encoding/alloc", "signature/alloc", "zeroize/alloc", ] std = ["alloc"] -crypto = ["ed25519", "p256", "p384", "p521", "rsa"] # NOTE: `dsa` is obsolete/weak +crypto = ["ed25519", "mldsa", "p256", "p384", "p521", "rsa"] # NOTE: `dsa` is obsolete/weak dsa = ["dep:dsa", "dep:sha1", "alloc", "encoding/bigint", "signature/rand_core"] ecdsa = ["dep:sec1"] ed25519 = ["dep:ed25519-dalek", "rand_core"] @@ -63,6 +64,7 @@ encryption = [ "rand_core" ] getrandom = ["cipher/getrandom", "rand_core"] +mldsa = ["dep:ml-dsa", "alloc", "rand_core"] p256 = ["dep:p256", "ecdsa"] p384 = ["dep:p384", "ecdsa"] p521 = ["dep:p521", "ecdsa"] diff --git a/ssh-key/src/algorithm.rs b/ssh-key/src/algorithm.rs index ac64492..e2aba01 100644 --- a/ssh-key/src/algorithm.rs +++ b/ssh-key/src/algorithm.rs @@ -47,6 +47,15 @@ const CERT_SK_ECDSA_SHA2_P256: &str = "sk-ecdsa-sha2-nistp256-cert-v01@openssh.c /// OpenSSH certificate for Ed25519 U2F/FIDO security key const CERT_SK_SSH_ED25519: &str = "sk-ssh-ed25519-cert-v01@openssh.com"; +/// OpenSSH certificate for ML-DSA-44 public key +const CERT_MLDSA_44: &str = "ssh-mldsa-44-cert-v01@openssh.com"; + +/// OpenSSH certificate for ML-DSA-65 public key +const CERT_MLDSA_65: &str = "ssh-mldsa-65-cert-v01@openssh.com"; + +/// OpenSSH certificate for ML-DSA-87 public key +const CERT_MLDSA_87: &str = "ssh-mldsa-87-cert-v01@openssh.com"; + /// ECDSA with SHA-256 + NIST P-256 const ECDSA_SHA2_P256: &str = "ecdsa-sha2-nistp256"; @@ -86,6 +95,15 @@ const SK_ECDSA_SHA2_P256: &str = "sk-ecdsa-sha2-nistp256@openssh.com"; /// U2F/FIDO security key with Ed25519 const SK_SSH_ED25519: &str = "sk-ssh-ed25519@openssh.com"; +/// ML-DSA-44 (FIPS 204, security category 2) +const SSH_MLDSA_44: &str = "ssh-mldsa-44"; + +/// ML-DSA-65 (FIPS 204, security category 3) +const SSH_MLDSA_65: &str = "ssh-mldsa-65"; + +/// ML-DSA-87 (FIPS 204, security category 5) +const SSH_MLDSA_87: &str = "ssh-mldsa-87"; + /// SSH key algorithms, i.e. digital signature algorithms used with SSH private/public keys. #[derive(Clone, Debug, Default, Eq, Hash, PartialEq, PartialOrd, Ord)] #[non_exhaustive] @@ -121,6 +139,12 @@ pub enum Algorithm { /// FIDO/U2F key with Ed25519 SkEd25519, + /// ML-DSA + MlDsa { + /// ML-DSA parameter set to use. + params: MlDsaParams, + }, + /// Other #[cfg(feature = "alloc")] Other(AlgorithmName), @@ -138,6 +162,9 @@ impl Algorithm { /// - `ssh-rsa` /// - `sk-ecdsa-sha2-nistp256@openssh.com` (FIDO/U2F key) /// - `sk-ssh-ed25519@openssh.com` (FIDO/U2F key) + /// - `ssh-mldsa-44` + /// - `ssh-mldsa-65` + /// - `ssh-mldsa-87` /// /// Any other algorithms are mapped to the [`Algorithm::Other`] variant. /// @@ -161,6 +188,9 @@ impl Algorithm { /// - `ssh-ed25519-cert-v01@openssh.com` /// - `sk-ecdsa-sha2-nistp256-cert-v01@openssh.com` (FIDO/U2F key) /// - `sk-ssh-ed25519-cert-v01@openssh.com` (FIDO/U2F key) + /// - `ssh-mldsa-44-cert-v01@openssh.com` + /// - `ssh-mldsa-65-cert-v01@openssh.com` + /// - `ssh-mldsa-87-cert-v01@openssh.com` /// /// Any other algorithms are mapped to the [`Algorithm::Other`] variant. /// @@ -190,6 +220,15 @@ impl Algorithm { }), CERT_SK_ECDSA_SHA2_P256 => Ok(Algorithm::SkEcdsaSha2NistP256), CERT_SK_SSH_ED25519 => Ok(Algorithm::SkEd25519), + CERT_MLDSA_44 => Ok(Algorithm::MlDsa { + params: MlDsaParams::MlDsa44, + }), + CERT_MLDSA_65 => Ok(Algorithm::MlDsa { + params: MlDsaParams::MlDsa65, + }), + CERT_MLDSA_87 => Ok(Algorithm::MlDsa { + params: MlDsaParams::MlDsa87, + }), #[cfg(feature = "alloc")] _ => Ok(Algorithm::Other(AlgorithmName::from_certificate_type(id)?)), #[cfg(not(feature = "alloc"))] @@ -215,6 +254,11 @@ impl Algorithm { }, Algorithm::SkEcdsaSha2NistP256 => SK_ECDSA_SHA2_P256, Algorithm::SkEd25519 => SK_SSH_ED25519, + Algorithm::MlDsa { params } => match params { + MlDsaParams::MlDsa44 => SSH_MLDSA_44, + MlDsaParams::MlDsa65 => SSH_MLDSA_65, + MlDsaParams::MlDsa87 => SSH_MLDSA_87, + }, #[cfg(feature = "alloc")] Algorithm::Other(algorithm) => algorithm.as_str(), } @@ -247,6 +291,11 @@ impl Algorithm { } => CERT_RSA_SHA2_512, Algorithm::SkEcdsaSha2NistP256 => CERT_SK_ECDSA_SHA2_P256, Algorithm::SkEd25519 => CERT_SK_SSH_ED25519, + Algorithm::MlDsa { params } => match params { + MlDsaParams::MlDsa44 => CERT_MLDSA_44, + MlDsaParams::MlDsa65 => CERT_MLDSA_65, + MlDsaParams::MlDsa87 => CERT_MLDSA_87, + }, Algorithm::Other(algorithm) => return algorithm.certificate_type(), } .to_owned() @@ -276,6 +325,12 @@ impl Algorithm { matches!(self, Algorithm::Rsa { .. }) } + /// Is the algorithm ML-DSA? + #[must_use] + pub fn is_mldsa(self) -> bool { + matches!(self, Algorithm::MlDsa { .. }) + } + /// Return an error indicating this algorithm is unsupported. #[allow(dead_code)] pub(crate) fn unsupported_error(self) -> Error { @@ -322,6 +377,15 @@ impl str::FromStr for Algorithm { SSH_RSA => Ok(Algorithm::Rsa { hash: None }), SK_ECDSA_SHA2_P256 => Ok(Algorithm::SkEcdsaSha2NistP256), SK_SSH_ED25519 => Ok(Algorithm::SkEd25519), + SSH_MLDSA_44 => Ok(Algorithm::MlDsa { + params: MlDsaParams::MlDsa44, + }), + SSH_MLDSA_65 => Ok(Algorithm::MlDsa { + params: MlDsaParams::MlDsa65, + }), + SSH_MLDSA_87 => Ok(Algorithm::MlDsa { + params: MlDsaParams::MlDsa87, + }), #[cfg(feature = "alloc")] _ => Ok(Algorithm::Other(AlgorithmName::from_str(id)?)), #[cfg(not(feature = "alloc"))] @@ -412,6 +476,108 @@ impl str::FromStr for EcdsaCurve { } } +/// ML-DSA parameter sets supported for use with SSH as specified in [FIPS204]. +/// +/// Each parameter set corresponds to a NIST security category. +/// +/// [FIPS204]: https://csrc.nist.gov/pubs/fips/204/final +#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)] +pub enum MlDsaParams { + /// ML-DSA-44 (security category 2). + MlDsa44, + + /// ML-DSA-65 (security category 3), the recommended parameter set. + MlDsa65, + + /// ML-DSA-87 (security category 5). + MlDsa87, +} + +impl MlDsaParams { + /// Decode an ML-DSA parameter set from the given SSH algorithm identifier. + /// + /// # Supported identifiers + /// + /// - `ssh-mldsa-44` + /// - `ssh-mldsa-65` + /// - `ssh-mldsa-87` + /// + /// # Errors + /// Returns [`Error::Encoding`] in the event the identifier is not known. + pub fn new(id: &str) -> Result { + Ok(id.parse()?) + } + + /// Get the SSH algorithm identifier which corresponds to this parameter set. + #[must_use] + pub fn as_str(self) -> &'static str { + match self { + MlDsaParams::MlDsa44 => SSH_MLDSA_44, + MlDsaParams::MlDsa65 => SSH_MLDSA_65, + MlDsaParams::MlDsa87 => SSH_MLDSA_87, + } + } + + /// Size in bytes of a FIPS 204 public key for this parameter set. + #[must_use] + pub const fn public_key_size(self) -> usize { + match self { + MlDsaParams::MlDsa44 => 1312, + MlDsaParams::MlDsa65 => 1952, + MlDsaParams::MlDsa87 => 2592, + } + } + + /// Size in bytes of a FIPS 204 signature for this parameter set. + #[must_use] + pub const fn signature_size(self) -> usize { + match self { + MlDsaParams::MlDsa44 => 2420, + MlDsaParams::MlDsa65 => 3309, + MlDsaParams::MlDsa87 => 4627, + } + } + + /// Size in bytes of the seed (ξ) used to derive an ML-DSA key. + /// + /// This is 32 bytes for all parameter sets. + #[must_use] + pub const fn seed_size(self) -> usize { + 32 + } +} + +impl AsRef for MlDsaParams { + fn as_ref(&self) -> &str { + self.as_str() + } +} + +impl From for Algorithm { + fn from(params: MlDsaParams) -> Algorithm { + Algorithm::MlDsa { params } + } +} + +impl fmt::Display for MlDsaParams { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +impl str::FromStr for MlDsaParams { + type Err = LabelError; + + fn from_str(id: &str) -> core::result::Result { + match id { + SSH_MLDSA_44 => Ok(MlDsaParams::MlDsa44), + SSH_MLDSA_65 => Ok(MlDsaParams::MlDsa65), + SSH_MLDSA_87 => Ok(MlDsaParams::MlDsa87), + _ => Err(LabelError::new(id)), + } + } +} + /// Hashing algorithms a.k.a. digest functions. #[derive(Copy, Clone, Debug, Default, Eq, Hash, PartialEq, PartialOrd, Ord)] #[non_exhaustive] diff --git a/ssh-key/src/lib.rs b/ssh-key/src/lib.rs index 820e701..c68f1ee 100644 --- a/ssh-key/src/lib.rs +++ b/ssh-key/src/lib.rs @@ -147,7 +147,7 @@ mod signature; mod sshsig; pub use crate::{ - algorithm::{Algorithm, AssociatedHashAlg, EcdsaCurve, HashAlg, KdfAlg}, + algorithm::{Algorithm, AssociatedHashAlg, EcdsaCurve, HashAlg, KdfAlg, MlDsaParams}, authorized_keys::AuthorizedKeys, error::{Error, Result}, fingerprint::Fingerprint, diff --git a/ssh-key/src/private.rs b/ssh-key/src/private.rs index a7bf4a2..e3f0eaa 100644 --- a/ssh-key/src/private.rs +++ b/ssh-key/src/private.rs @@ -100,6 +100,8 @@ mod ecdsa; mod ed25519; mod keypair; #[cfg(feature = "alloc")] +mod mldsa; +#[cfg(feature = "alloc")] mod opaque; #[cfg(feature = "alloc")] mod rsa; @@ -116,6 +118,7 @@ pub use crate::{ Comment, SshSig, private::{ dsa::{DsaKeypair, DsaPrivateKey}, + mldsa::{MlDsaKeypair, MlDsaPrivateKey}, opaque::{OpaqueKeypair, OpaqueKeypairBytes, OpaquePrivateKeyBytes}, rsa::{RsaKeypair, RsaPrivateKey}, sk::SkEd25519, @@ -622,6 +625,8 @@ impl PrivateKey { Algorithm::Rsa { .. } => { KeypairData::from(RsaKeypair::random(rng, DEFAULT_RSA_KEY_SIZE)?) } + #[cfg(feature = "mldsa")] + Algorithm::MlDsa { params } => KeypairData::from(MlDsaKeypair::random(rng, params)?), _ => return Err(Error::AlgorithmUnknown), }; let public_key = public::KeyData::try_from(&key_data)?; @@ -1002,6 +1007,15 @@ impl From for PrivateKey { } } +#[cfg(feature = "alloc")] +impl From for PrivateKey { + fn from(keypair: MlDsaKeypair) -> PrivateKey { + KeypairData::from(keypair) + .try_into() + .expect(CONVERSION_ERROR_MSG) + } +} + #[cfg(all(feature = "alloc", feature = "ecdsa"))] impl From for PrivateKey { fn from(keypair: SkEcdsaSha2NistP256) -> PrivateKey { diff --git a/ssh-key/src/private/keypair.rs b/ssh-key/src/private/keypair.rs index f187bba..674f68b 100644 --- a/ssh-key/src/private/keypair.rs +++ b/ssh-key/src/private/keypair.rs @@ -7,7 +7,7 @@ use encoding::{CheckedSum, Decode, Encode, Reader, Writer}; #[cfg(feature = "alloc")] use { - super::{DsaKeypair, OpaqueKeypair, RsaKeypair, SkEd25519}, + super::{DsaKeypair, MlDsaKeypair, OpaqueKeypair, RsaKeypair, SkEd25519}, alloc::vec::Vec, }; @@ -44,6 +44,10 @@ pub enum KeypairData { #[cfg(feature = "alloc")] Rsa(RsaKeypair), + /// ML-DSA (FIPS 204) keypair. + #[cfg(feature = "alloc")] + MlDsa(MlDsaKeypair), + /// Security Key (FIDO/U2F) using ECDSA/NIST P-256 as specified in [PROTOCOL.u2f]. /// /// [PROTOCOL.u2f]: https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL.u2f?annotate=HEAD @@ -77,6 +81,8 @@ impl KeypairData { Self::Encrypted(_) => return Err(Error::Encrypted), #[cfg(feature = "alloc")] Self::Rsa(_) => Algorithm::Rsa { hash: None }, + #[cfg(feature = "alloc")] + Self::MlDsa(key) => key.algorithm(), #[cfg(all(feature = "alloc", feature = "ecdsa"))] Self::SkEcdsaSha2NistP256(_) => Algorithm::SkEcdsaSha2NistP256, #[cfg(feature = "alloc")] @@ -136,6 +142,16 @@ impl KeypairData { } } + /// Get ML-DSA keypair if this key is the correct type. + #[cfg(feature = "alloc")] + #[must_use] + pub fn mldsa(&self) -> Option<&MlDsaKeypair> { + match self { + Self::MlDsa(key) => Some(key), + _ => None, + } + } + /// Get FIDO/U2F ECDSA/NIST P-256 private key if this key is the correct type. #[cfg(all(feature = "alloc", feature = "ecdsa"))] #[must_use] @@ -207,6 +223,13 @@ impl KeypairData { matches!(self, Self::Rsa(_)) } + /// Is this key an ML-DSA key? + #[cfg(feature = "alloc")] + #[must_use] + pub fn is_mldsa(&self) -> bool { + matches!(self, Self::MlDsa(_)) + } + /// Is this key a FIDO/U2F ECDSA/NIST P-256 key? #[cfg(all(feature = "alloc", feature = "ecdsa"))] #[must_use] @@ -243,6 +266,8 @@ impl KeypairData { Self::Encrypted(ciphertext) => ciphertext.as_ref(), #[cfg(feature = "alloc")] Self::Rsa(rsa) => rsa.private().d().as_bytes(), + #[cfg(feature = "alloc")] + Self::MlDsa(mldsa) => mldsa.private.as_ref(), #[cfg(all(feature = "alloc", feature = "ecdsa"))] Self::SkEcdsaSha2NistP256(sk) => sk.key_handle(), #[cfg(feature = "alloc")] @@ -278,6 +303,8 @@ impl KeypairData { Algorithm::Ed25519 => Ed25519Keypair::decode(reader).map(Self::Ed25519), #[cfg(feature = "alloc")] Algorithm::Rsa { .. } => RsaKeypair::decode(reader).map(Self::Rsa), + #[cfg(feature = "alloc")] + Algorithm::MlDsa { params } => MlDsaKeypair::decode_as(reader, params).map(Self::MlDsa), #[cfg(all(feature = "alloc", feature = "ecdsa"))] Algorithm::SkEcdsaSha2NistP256 => { SkEcdsaSha2NistP256::decode(reader).map(Self::SkEcdsaSha2NistP256) @@ -307,6 +334,8 @@ impl CtEq for KeypairData { (Self::Encrypted(a), Self::Encrypted(b)) => a.ct_eq(b), #[cfg(feature = "alloc")] (Self::Rsa(a), Self::Rsa(b)) => a.ct_eq(b), + #[cfg(feature = "alloc")] + (Self::MlDsa(a), Self::MlDsa(b)) => a.ct_eq(b), #[cfg(all(feature = "alloc", feature = "ecdsa"))] (Self::SkEcdsaSha2NistP256(a), Self::SkEcdsaSha2NistP256(b)) => { // Security Keys store the actual private key in hardware. @@ -363,6 +392,8 @@ impl Encode for KeypairData { Self::Encrypted(ciphertext) => return Ok(ciphertext.len()), #[cfg(feature = "alloc")] Self::Rsa(key) => key.encoded_len()?, + #[cfg(feature = "alloc")] + Self::MlDsa(key) => key.encoded_len()?, #[cfg(all(feature = "alloc", feature = "ecdsa"))] Self::SkEcdsaSha2NistP256(sk) => sk.encoded_len()?, #[cfg(feature = "alloc")] @@ -389,6 +420,8 @@ impl Encode for KeypairData { Self::Encrypted(ciphertext) => writer.write(ciphertext)?, #[cfg(feature = "alloc")] Self::Rsa(key) => key.encode(writer)?, + #[cfg(feature = "alloc")] + Self::MlDsa(key) => key.encode(writer)?, #[cfg(all(feature = "alloc", feature = "ecdsa"))] Self::SkEcdsaSha2NistP256(sk) => sk.encode(writer)?, #[cfg(feature = "alloc")] @@ -415,6 +448,8 @@ impl TryFrom<&KeypairData> for public::KeyData { KeypairData::Encrypted(_) => return Err(Error::Encrypted), #[cfg(feature = "alloc")] KeypairData::Rsa(rsa) => public::KeyData::Rsa(rsa.into()), + #[cfg(feature = "alloc")] + KeypairData::MlDsa(mldsa) => public::KeyData::MlDsa(mldsa.into()), #[cfg(all(feature = "alloc", feature = "ecdsa"))] KeypairData::SkEcdsaSha2NistP256(sk) => { public::KeyData::SkEcdsaSha2NistP256(sk.public().clone()) @@ -454,6 +489,13 @@ impl From for KeypairData { } } +#[cfg(feature = "alloc")] +impl From for KeypairData { + fn from(keypair: MlDsaKeypair) -> KeypairData { + Self::MlDsa(keypair) + } +} + #[cfg(all(feature = "alloc", feature = "ecdsa"))] impl From for KeypairData { fn from(keypair: SkEcdsaSha2NistP256) -> KeypairData { diff --git a/ssh-key/src/private/mldsa.rs b/ssh-key/src/private/mldsa.rs new file mode 100644 index 0000000..afa7674 --- /dev/null +++ b/ssh-key/src/private/mldsa.rs @@ -0,0 +1,271 @@ +//! ML-DSA private keys. +//! +//! Module-Lattice-Based Digital Signature Algorithm (ML-DSA) as specified in [FIPS204]. +//! +//! [FIPS204]: https://csrc.nist.gov/pubs/fips/204/final + +use crate::{Algorithm, Error, MlDsaParams, Result, public::MlDsaPublicKey}; +use core::fmt; +use ctutils::{Choice, CtEq}; +use encoding::{CheckedSum, Encode, Reader, Writer}; +use zeroize::{Zeroize, Zeroizing}; + +#[cfg(feature = "mldsa")] +use alloc::vec::Vec; + +#[cfg(feature = "rand_core")] +use rand_core::CryptoRng; + +/// Size of an ML-DSA seed in bytes. This is the same for all parameter sets. +const SEED_SIZE: usize = 32; + +/// ML-DSA private key. +/// This is the seed representation, not the expanded private key. +#[derive(Clone)] +pub struct MlDsaPrivateKey([u8; SEED_SIZE]); + +impl MlDsaPrivateKey { + /// Size of an ML-DSA seed in bytes. + pub const BYTE_SIZE: usize = SEED_SIZE; + + /// Generate a random ML-DSA seed. + #[cfg(feature = "rand_core")] + pub fn random(rng: &mut R) -> Self { + let mut seed = [0u8; SEED_SIZE]; + rng.fill_bytes(&mut seed); + Self(seed) + } + + /// Parse an ML-DSA seed from bytes. + #[must_use] + pub fn from_bytes(bytes: &[u8; SEED_SIZE]) -> Self { + Self(*bytes) + } + + /// Convert to the inner seed byte array. + #[must_use] + pub fn to_bytes(&self) -> [u8; SEED_SIZE] { + self.0 + } +} + +impl AsRef<[u8; SEED_SIZE]> for MlDsaPrivateKey { + fn as_ref(&self) -> &[u8; SEED_SIZE] { + &self.0 + } +} + +impl CtEq for MlDsaPrivateKey { + fn ct_eq(&self, other: &Self) -> Choice { + self.as_ref().ct_eq(other.as_ref()) + } +} + +impl Eq for MlDsaPrivateKey {} + +impl PartialEq for MlDsaPrivateKey { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl TryFrom<&[u8]> for MlDsaPrivateKey { + type Error = Error; + + fn try_from(bytes: &[u8]) -> Result { + Ok(MlDsaPrivateKey::from_bytes(bytes.try_into()?)) + } +} + +impl fmt::Debug for MlDsaPrivateKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MlDsaPrivateKey").finish_non_exhaustive() + } +} + +impl Drop for MlDsaPrivateKey { + fn drop(&mut self) { + self.0.zeroize(); + } +} + +/// ML-DSA private/public keypair. +/// +/// The SSH encoding of the keypair consists of the [`MlDsaPublicKey`] followed +/// by the 32-byte seed encoded as an SSH `string`. +#[derive(Clone)] +pub struct MlDsaKeypair { + /// Public key. + pub public: MlDsaPublicKey, + + /// Private key (seed). + pub private: MlDsaPrivateKey, +} + +impl MlDsaKeypair { + /// Get the [`MlDsaParams`] parameter set for this keypair. + #[must_use] + pub fn params(&self) -> MlDsaParams { + self.public.params() + } + + /// Get the [`Algorithm`] for this keypair. + #[must_use] + pub fn algorithm(&self) -> Algorithm { + self.public.algorithm() + } + + /// Decode an ML-DSA keypair for the given parameter set. + /// + /// The parameter set is not encoded in the key body; it is taken from the + /// SSH algorithm identifier. + /// + /// # Errors + /// - Returns [`Error::Encoding`] in the event of an encoding error. + /// - Returns [`Error::PublicKey`] if the encoded public key does not match + /// the key derived from the seed (only checked when the `mldsa` feature is + /// enabled). + pub(crate) fn decode_as(reader: &mut impl Reader, params: MlDsaParams) -> Result { + let public = MlDsaPublicKey::decode_as(reader, params)?; + + let mut seed = Zeroizing::new([0u8; SEED_SIZE]); + reader.read_prefixed(|reader| reader.read(&mut *seed))?; + + let keypair = Self { + public, + private: MlDsaPrivateKey::from_bytes(&seed), + }; + + #[cfg(feature = "mldsa")] + keypair.validate()?; + + Ok(keypair) + } +} + +impl CtEq for MlDsaKeypair { + fn ct_eq(&self, other: &Self) -> Choice { + Choice::from(u8::from(self.public == other.public)) & self.private.ct_eq(&other.private) + } +} + +impl Eq for MlDsaKeypair {} + +impl PartialEq for MlDsaKeypair { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Encode for MlDsaKeypair { + fn encoded_len(&self) -> encoding::Result { + [self.public.encoded_len()?, 4, SEED_SIZE].checked_sum() + } + + fn encode(&self, writer: &mut impl Writer) -> encoding::Result<()> { + self.public.encode(writer)?; + Zeroizing::new(self.private.to_bytes()) + .as_slice() + .encode(writer)?; + Ok(()) + } +} + +impl fmt::Debug for MlDsaKeypair { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MlDsaKeypair") + .field("public", &self.public) + .finish_non_exhaustive() + } +} + +impl From for MlDsaPublicKey { + fn from(keypair: MlDsaKeypair) -> MlDsaPublicKey { + keypair.public + } +} + +impl From<&MlDsaKeypair> for MlDsaPublicKey { + fn from(keypair: &MlDsaKeypair) -> MlDsaPublicKey { + keypair.public.clone() + } +} + +/// Derive the raw FIPS 204 public key bytes from a seed for the concrete parameter set `P`. +#[cfg(feature = "mldsa")] +fn derive_public(seed: &[u8; SEED_SIZE]) -> Vec { + use signature::Keypair; + + let seed = ml_dsa::B32::from(*seed); + let signing_key = ml_dsa::SigningKey::

::from_seed(&seed); + signing_key.verifying_key().encode().as_slice().to_vec() +} + +/// Sign a message with "pure" ML-DSA (empty context) for the concrete parameter set `P`. +#[cfg(feature = "mldsa")] +fn sign_with_params(seed: &[u8; SEED_SIZE], msg: &[u8]) -> Result> { + use signature::Signer; + + let seed = ml_dsa::B32::from(*seed); + let signing_key = ml_dsa::SigningKey::

::from_seed(&seed); + + // The `Signer` impl uses "pure" ML-DSA with an empty context string, as + // required by draft-sfluhrer-ssh-mldsa. + let signature = signing_key.try_sign(msg)?; + Ok(signature.encode().as_slice().to_vec()) +} + +#[cfg(feature = "mldsa")] +impl MlDsaKeypair { + /// Generate a random ML-DSA keypair for the given parameter set. + /// + /// # Errors + /// Returns [`Error::Encoding`] if key derivation produces a malformed key + /// (should not occur for valid parameter sets). + pub fn random(rng: &mut R, params: MlDsaParams) -> Result { + Self::from_seed(params, &MlDsaPrivateKey::random(rng).to_bytes()) + } + + /// Derive an ML-DSA keypair from a 32-byte seed (ξ) for the given parameter set. + /// + /// # Errors + /// Returns [`Error::Encoding`] if key derivation produces a malformed key + /// (should not occur for valid parameter sets). + pub fn from_seed(params: MlDsaParams, seed: &[u8; SEED_SIZE]) -> Result { + let key = match params { + MlDsaParams::MlDsa44 => derive_public::(seed), + MlDsaParams::MlDsa65 => derive_public::(seed), + MlDsaParams::MlDsa87 => derive_public::(seed), + }; + + Ok(Self { + public: MlDsaPublicKey::new(params, key)?, + private: MlDsaPrivateKey::from_bytes(seed), + }) + } + + /// Sign a message, producing the raw ML-DSA signature bytes. + pub(crate) fn sign_msg(&self, msg: &[u8]) -> Result> { + let seed = self.private.as_ref(); + match self.public.params() { + MlDsaParams::MlDsa44 => sign_with_params::(seed, msg), + MlDsaParams::MlDsa65 => sign_with_params::(seed, msg), + MlDsaParams::MlDsa87 => sign_with_params::(seed, msg), + } + } + + /// Verify that the stored public key matches the key derived from the seed. + fn validate(&self) -> Result<()> { + let expected = match self.public.params() { + MlDsaParams::MlDsa44 => derive_public::(self.private.as_ref()), + MlDsaParams::MlDsa65 => derive_public::(self.private.as_ref()), + MlDsaParams::MlDsa87 => derive_public::(self.private.as_ref()), + }; + + if expected.as_slice() == self.public.as_bytes() { + Ok(()) + } else { + Err(Error::PublicKey) + } + } +} diff --git a/ssh-key/src/public.rs b/ssh-key/src/public.rs index 9dd0b10..68707a6 100644 --- a/ssh-key/src/public.rs +++ b/ssh-key/src/public.rs @@ -9,6 +9,8 @@ mod ecdsa; mod ed25519; mod key_data; #[cfg(feature = "alloc")] +mod mldsa; +#[cfg(feature = "alloc")] mod opaque; #[cfg(feature = "alloc")] mod rsa; @@ -20,6 +22,7 @@ pub use self::{ed25519::Ed25519PublicKey, key_data::KeyData, sk::SkEd25519}; #[cfg(feature = "alloc")] pub use self::{ dsa::DsaPublicKey, + mldsa::MlDsaPublicKey, opaque::{OpaquePublicKey, OpaquePublicKeyBytes}, rsa::RsaPublicKey, }; @@ -448,6 +451,13 @@ impl From for PublicKey { } } +#[cfg(feature = "alloc")] +impl From for PublicKey { + fn from(public_key: MlDsaPublicKey) -> PublicKey { + KeyData::from(public_key).into() + } +} + #[cfg(feature = "alloc")] impl From for PublicKey { fn from(public_key: RsaPublicKey) -> PublicKey { diff --git a/ssh-key/src/public/key_data.rs b/ssh-key/src/public/key_data.rs index 5924ca9..c2094fb 100644 --- a/ssh-key/src/public/key_data.rs +++ b/ssh-key/src/public/key_data.rs @@ -6,7 +6,7 @@ use encoding::{CheckedSum, Decode, Encode, Reader, Writer}; #[cfg(feature = "alloc")] use { - super::{DsaPublicKey, OpaquePublicKey, RsaPublicKey}, + super::{DsaPublicKey, MlDsaPublicKey, OpaquePublicKey, RsaPublicKey}, crate::Certificate, alloc::boxed::Box, }; @@ -33,6 +33,10 @@ pub enum KeyData { #[cfg(feature = "alloc")] Rsa(RsaPublicKey), + /// ML-DSA public key data.q + #[cfg(feature = "alloc")] + MlDsa(MlDsaPublicKey), + /// Security Key (FIDO/U2F) using ECDSA/NIST P-256 as specified in [PROTOCOL.u2f]. /// /// [PROTOCOL.u2f]: https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL.u2f?annotate=HEAD @@ -69,6 +73,8 @@ impl KeyData { Self::Ed25519(_) => Algorithm::Ed25519, #[cfg(feature = "alloc")] Self::Rsa(_) => Algorithm::Rsa { hash: None }, + #[cfg(feature = "alloc")] + Self::MlDsa(key) => key.algorithm(), #[cfg(feature = "ecdsa")] Self::SkEcdsaSha2NistP256(_) => Algorithm::SkEcdsaSha2NistP256, Self::SkEd25519(_) => Algorithm::SkEd25519, @@ -127,6 +133,16 @@ impl KeyData { } } + /// Get ML-DSA public key if this key is the correct type. + #[cfg(feature = "alloc")] + #[must_use] + pub fn mldsa(&self) -> Option<&MlDsaPublicKey> { + match self { + Self::MlDsa(key) => Some(key), + _ => None, + } + } + /// Get FIDO/U2F ECDSA/NIST P-256 public key if this key is the correct type. #[cfg(feature = "ecdsa")] #[must_use] @@ -203,6 +219,13 @@ impl KeyData { matches!(self, Self::Rsa(_)) } + /// Is this key an ML-DSA key? + #[cfg(feature = "alloc")] + #[must_use] + pub fn is_mldsa(&self) -> bool { + matches!(self, Self::MlDsa(_)) + } + /// Is this key a FIDO/U2F ECDSA/NIST P-256 key? #[cfg(feature = "ecdsa")] #[must_use] @@ -248,6 +271,10 @@ impl KeyData { Algorithm::Ed25519 => Ed25519PublicKey::decode(reader).map(Self::Ed25519), #[cfg(feature = "alloc")] Algorithm::Rsa { .. } => RsaPublicKey::decode(reader).map(Self::Rsa), + #[cfg(feature = "alloc")] + Algorithm::MlDsa { params } => { + MlDsaPublicKey::decode_as(reader, params).map(Self::MlDsa) + } #[cfg(feature = "ecdsa")] Algorithm::SkEcdsaSha2NistP256 => { SkEcdsaSha2NistP256::decode(reader).map(Self::SkEcdsaSha2NistP256) @@ -280,6 +307,8 @@ impl KeyData { Self::Ed25519(key) => key.encoded_len(), #[cfg(feature = "alloc")] Self::Rsa(key) => key.encoded_len(), + #[cfg(feature = "alloc")] + Self::MlDsa(key) => key.encoded_len(), #[cfg(feature = "ecdsa")] Self::SkEcdsaSha2NistP256(sk) => sk.encoded_len(), Self::SkEd25519(sk) => sk.encoded_len(), @@ -300,6 +329,8 @@ impl KeyData { Self::Ed25519(key) => key.encode(writer), #[cfg(feature = "alloc")] Self::Rsa(key) => key.encode(writer), + #[cfg(feature = "alloc")] + Self::MlDsa(key) => key.encode(writer), #[cfg(feature = "ecdsa")] Self::SkEcdsaSha2NistP256(sk) => sk.encode(writer), Self::SkEd25519(sk) => sk.encode(writer), @@ -380,6 +411,13 @@ impl From for KeyData { } } +#[cfg(feature = "alloc")] +impl From for KeyData { + fn from(public_key: MlDsaPublicKey) -> KeyData { + Self::MlDsa(public_key) + } +} + #[cfg(feature = "ecdsa")] impl From for KeyData { fn from(public_key: SkEcdsaSha2NistP256) -> KeyData { diff --git a/ssh-key/src/public/mldsa.rs b/ssh-key/src/public/mldsa.rs new file mode 100644 index 0000000..e6f855b --- /dev/null +++ b/ssh-key/src/public/mldsa.rs @@ -0,0 +1,135 @@ +//! ML-DSA public keys. +//! +//! Module-Lattice-Based Digital Signature Algorithm (ML-DSA) as specified in [FIPS204]. +//! +//! [FIPS204]: https://csrc.nist.gov/pubs/fips/204/final + +use crate::{Algorithm, MlDsaParams, Result}; +use alloc::vec::Vec; +use encoding::{Decode, Encode, Reader, Writer}; + +#[cfg(feature = "mldsa")] +use crate::Error; + +/// ML-DSA public key. +/// +/// SSH encodings for ML-DSA public keys are described in +/// [draft-sfluhrer-ssh-mldsa]: +/// +/// +/// Here, 'key' is the public key as described in [FIPS204]. +/// +/// This type represents the `key` portion of the encoding together with the +/// [`MlDsaParams`] parameter set which is derived from the algorithm name. +/// +/// [draft-sfluhrer-ssh-mldsa]: https://datatracker.ietf.org/doc/draft-sfluhrer-ssh-mldsa/ +/// [FIPS204]: https://csrc.nist.gov/pubs/fips/204/final +#[derive(Clone, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)] +pub struct MlDsaPublicKey { + /// ML-DSA parameter set. + params: MlDsaParams, + + /// Raw FIPS 204 public key bytes. + key: Vec, +} + +impl MlDsaPublicKey { + /// Create a new ML-DSA public key from raw public key bytes + pub fn new(params: MlDsaParams, key: impl Into>) -> Result { + let key = key.into(); + + if key.len() != params.public_key_size() { + return Err(encoding::Error::Length.into()); + } + + Ok(Self { params, key }) + } + + /// Get the [`MlDsaParams`] parameter set for this public key. + #[must_use] + pub fn params(&self) -> MlDsaParams { + self.params + } + + /// Get the [`Algorithm`] for this public key. + #[must_use] + pub fn algorithm(&self) -> Algorithm { + Algorithm::MlDsa { + params: self.params, + } + } + + /// Get the raw [FIPS204] public key bytes. + #[must_use] + pub fn as_bytes(&self) -> &[u8] { + &self.key + } + + /// Decode an ML-DSA public key for the given parameter set. + /// + /// The parameter set is not encoded in the key body; it is taken from the + /// SSH algorithm identifier (see [`MlDsaPublicKey`]). + pub(crate) fn decode_as(reader: &mut impl Reader, params: MlDsaParams) -> Result { + let key = Vec::decode(reader)?; + Self::new(params, key) + } +} + +impl AsRef<[u8]> for MlDsaPublicKey { + fn as_ref(&self) -> &[u8] { + self.as_bytes() + } +} + +impl Encode for MlDsaPublicKey { + fn encoded_len(&self) -> encoding::Result { + self.key.encoded_len() + } + + fn encode(&self, writer: &mut impl Writer) -> encoding::Result<()> { + self.key.encode(writer) + } +} + +/// Verify an ML-DSA signature over `msg` with the concrete ML-DSA parameter set `P`. +#[cfg(feature = "mldsa")] +fn verify_with_params( + key: &[u8], + msg: &[u8], + signature: &[u8], +) -> Result<()> { + use signature::Verifier; + + let encoded = ml_dsa::EncodedVerifyingKey::

::try_from(key).map_err(|_| Error::PublicKey)?; + let verifying_key = ml_dsa::VerifyingKey::

::decode(&encoded); + let signature = ml_dsa::Signature::

::try_from(signature).map_err(|_| Error::Signature)?; + + // The `Verifier` impl uses "pure" ML-DSA with an empty context string, as + // required by draft-sfluhrer-ssh-mldsa. + verifying_key + .verify(msg, &signature) + .map_err(|_| Error::Signature) +} + +#[cfg(feature = "mldsa")] +impl MlDsaPublicKey { + /// Verify a raw ML-DSA signature over the given message using "pure" ML-DSA + /// with an empty context string. + /// + /// # Errors + /// Returns [`Error::Signature`] if the signature is invalid, or + /// [`Error::PublicKey`] if the public key is malformed. + pub(crate) fn verify_msg(&self, msg: &[u8], signature: &[u8]) -> Result<()> { + match self.params { + MlDsaParams::MlDsa44 => { + verify_with_params::(&self.key, msg, signature) + } + MlDsaParams::MlDsa65 => { + verify_with_params::(&self.key, msg, signature) + } + MlDsaParams::MlDsa87 => { + verify_with_params::(&self.key, msg, signature) + } + } + } +} diff --git a/ssh-key/src/signature.rs b/ssh-key/src/signature.rs index db132c3..10884ed 100644 --- a/ssh-key/src/signature.rs +++ b/ssh-key/src/signature.rs @@ -9,6 +9,9 @@ use signature::{SignatureEncoding, Signer, Verifier}; #[cfg(feature = "ed25519")] use crate::{private::Ed25519Keypair, public::Ed25519PublicKey}; +#[cfg(feature = "mldsa")] +use crate::{private::MlDsaKeypair, public::MlDsaPublicKey}; + #[cfg(feature = "dsa")] use { crate::{private::DsaKeypair, public::DsaPublicKey}, @@ -112,6 +115,7 @@ impl Signature { Algorithm::SkEd25519 if data.len() == SK_ED25519_SIGNATURE_SIZE => (), Algorithm::SkEcdsaSha2NistP256 => ecdsa_sig_size(&data, EcdsaCurve::NistP256, true)?, Algorithm::Rsa { .. } => (), + Algorithm::MlDsa { params } if data.len() == params.signature_size() => (), Algorithm::Other(_) if !data.is_empty() => (), _ => return Err(encoding::Error::Length.into()), } @@ -293,6 +297,8 @@ impl Signer for private::KeypairData { Self::Ed25519(keypair) => keypair.try_sign(message), #[cfg(feature = "rsa")] Self::Rsa(keypair) => keypair.try_sign(message), + #[cfg(feature = "mldsa")] + Self::MlDsa(keypair) => keypair.try_sign(message), _ => Err(self.algorithm()?.unsupported_error().into()), } } @@ -320,6 +326,8 @@ impl Verifier for public::KeyData { Self::SkEcdsaSha2NistP256(pk) => pk.verify(message, signature), #[cfg(feature = "rsa")] Self::Rsa(pk) => pk.verify(message, signature), + #[cfg(feature = "mldsa")] + Self::MlDsa(pk) => pk.verify(message, signature), #[allow(unreachable_patterns)] _ => Err(self.algorithm().unsupported_error().into()), } @@ -452,6 +460,30 @@ impl Verifier for Ed25519PublicKey { } } +#[cfg(feature = "mldsa")] +impl Signer for MlDsaKeypair { + fn try_sign(&self, message: &[u8]) -> signature::Result { + let data = self.sign_msg(message)?; + + Ok(Signature { + algorithm: self.algorithm(), + data, + }) + } +} + +#[cfg(feature = "mldsa")] +impl Verifier for MlDsaPublicKey { + fn verify(&self, message: &[u8], signature: &Signature) -> signature::Result<()> { + // The signature's algorithm (including parameter set) must match this key. + if signature.algorithm() != self.algorithm() { + return Err(Error::Signature.into()); + } + + Ok(self.verify_msg(message, signature.as_bytes())?) + } +} + #[cfg(feature = "ed25519")] impl Verifier for public::SkEd25519 { fn verify(&self, message: &[u8], signature: &Signature) -> signature::Result<()> { @@ -772,6 +804,13 @@ mod tests { #[cfg(feature = "ed25519")] use {super::Ed25519Keypair, signature::Signer}; + #[cfg(feature = "mldsa")] + use { + super::MlDsaKeypair, + crate::MlDsaParams, + signature::{Signer as _, Verifier as _}, + }; + #[cfg(feature = "p256")] use super::{Mpint, zero_pad_field_bytes}; @@ -1001,6 +1040,59 @@ mod tests { assert!(keypair.public.verify(EXAMPLE_MSG, &signature).is_ok()); } + #[cfg(feature = "mldsa")] + #[test] + fn sign_and_verify_mldsa() { + let msg = b"Hello, world!"; + + for params in [ + MlDsaParams::MlDsa44, + MlDsaParams::MlDsa65, + MlDsaParams::MlDsa87, + ] { + let keypair = MlDsaKeypair::from_seed(params, &[42; 32]).unwrap(); + let signature = keypair.sign(msg); + + assert_eq!(signature.algorithm(), Algorithm::MlDsa { params }); + assert_eq!(signature.as_bytes().len(), params.signature_size()); + assert!(keypair.public.verify(msg, &signature).is_ok()); + + // Signing is deterministic (empty context, deterministic variant). + assert_eq!(keypair.sign(msg), signature); + + // A tampered message must fail verification. + assert!(keypair.public.verify(b"tampered", &signature).is_err()); + + // Signature encode/decode round-trips. + let encoded = signature.encode_vec().unwrap(); + let decoded = Signature::try_from(&encoded[..]).unwrap(); + assert_eq!(decoded, signature); + } + } + + #[cfg(feature = "mldsa")] + #[test] + fn mldsa_key_roundtrip() { + use crate::{private::KeypairData, public::KeyData}; + use encoding::Decode; + + let params = MlDsaParams::MlDsa65; + let keypair = MlDsaKeypair::from_seed(params, &[7; 32]).unwrap(); + + // Public key wire-format round-trips and reports the correct algorithm. + let public: KeyData = keypair.public.clone().into(); + let encoded = public.encode_vec().unwrap(); + let decoded = KeyData::decode(&mut &encoded[..]).unwrap(); + assert_eq!(decoded, public); + assert_eq!(decoded.algorithm(), Algorithm::MlDsa { params }); + + // Keypair wire-format round-trips. + let keypair_data = KeypairData::from(keypair); + let encoded = keypair_data.encode_vec().unwrap(); + let decoded = KeypairData::decode(&mut &encoded[..]).unwrap(); + assert_eq!(decoded, keypair_data); + } + #[test] fn placeholder() { assert!(