From c46ed8ed9ec5b0cc61a12115963c98f7f633578f Mon Sep 17 00:00:00 2001 From: marinthiercelin Date: Wed, 30 Jun 2021 16:49:30 +0200 Subject: [PATCH] Add a streaming api to KeyRing and SessionKey (#131) * barebone streaming functionality * encryption needs to return a writecloser * added eof check * workaround for reader problem with copies * separate mobile wrappers from main api * add a clone in the read result to avoid memory corruption * refactor to reuse code, and fix verification * have to give the verify key at the start of the decryption * enfore readAll before signature verification * streaming api for SessionKey * add split message stream apis * name interface params * fix streaming api so it's supported by go-mobile * hide internal writeCloser * fix nil access * added detached sigs methods * started unit testing * unit testing and fixed a bug where key and data packets where inverted * remove unecessary error wrapping * figured out closing order and error handling * add GC calls to mobile writer and reader * remove debugging values and arrays * writer with builtin sha256 * unit testing the mobile helpers * comments and linting * Typo in error Co-authored-by: wussler * Add GetKeyPacket doc Co-authored-by: wussler * Add rfc reference in comments Co-authored-by: wussler * small improvements * add compatibility tests with normal methods * remove unecessary copies in the tests * update go-crypto to the merged changes commit * update comments of core internal functions * remove unused nolint comment * group message metadata in a struct * fix comments * change default values for metadata * change the mobile reader wrapper to fit the behavior of java * remove gc calls in the wrappers to avoid performance penalties * bring back the former Go2MobileReader to be used for ios * Update crypto/keyring_streaming.go Co-authored-by: wussler * return an error when verifying an embedded sig with no keyring * Update crypto/sessionkey_streaming.go Co-authored-by: wussler * linter error * update changelog * update changelog Co-authored-by: wussler --- CHANGELOG.md | 127 +++++++ crypto/keyring_message.go | 95 ++++-- crypto/keyring_streaming.go | 299 +++++++++++++++++ crypto/keyring_streaming_test.go | 491 ++++++++++++++++++++++++++++ crypto/sessionkey.go | 152 +++++---- crypto/sessionkey_streaming.go | 105 ++++++ crypto/sessionkey_streaming_test.go | 176 ++++++++++ go.mod | 2 +- go.sum | 4 + helper/mobile_stream.go | 182 +++++++++++ helper/mobile_stream_test.go | 182 +++++++++++ 11 files changed, 1718 insertions(+), 97 deletions(-) create mode 100644 crypto/keyring_streaming.go create mode 100644 crypto/keyring_streaming_test.go create mode 100644 crypto/sessionkey_streaming.go create mode 100644 crypto/sessionkey_streaming_test.go create mode 100644 helper/mobile_stream.go create mode 100644 helper/mobile_stream_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 142354d..8df5f49 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,133 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## Unreleased +### Added +- Streaming API: + - New structs: + - `PlainMessageMetadata`: holds the metadata of a plain PGP message + ```go + type PlainMessageMetadata struct { + IsBinary bool + Filename string + ModTime int64 + } + ``` + - `PlainMessageReader` implements `Reader` and: + ```go + func (msg *PlainMessageReader) GetMetadata() *PlainMessageMetadata + func (msg *PlainMessageReader) VerifySignature() (err error) + ``` + - `EncryptSplitResult` implements `WriteCloser` and: + ```go + func (res *EncryptSplitResult) GetKeyPacket() (keyPacket []byte, err error) + ``` + - Keyring methods: + - Encrypt (and optionally sign) a message directly into a `Writer`: + ```go + func (keyRing *KeyRing) EncryptStream( + pgpMessageWriter Writer, + plainMessageMetadata *PlainMessageMetadata, + signKeyRing *KeyRing, + ) (plainMessageWriter WriteCloser, err error) + ``` + - Encrypt (and optionally sign) a message directly into a `Writer` (split keypacket and datapacket): + ```go + func (keyRing *KeyRing) EncryptSplitStream( + dataPacketWriter Writer, + plainMessageMetadata *PlainMessageMetadata, + signKeyRing *KeyRing, + ) (*EncryptSplitResult, error) + ``` + + - Decrypt (and optionally verify) a message from a `Reader`: + ```go + func (keyRing *KeyRing) DecryptStream( + message Reader, + verifyKeyRing *KeyRing, + verifyTime int64, + ) (plainMessage *PlainMessageReader, err error) + ``` + N.B.: to verify the signature, you will need to call `plainMessage.VerifySignature()` after all the data has been read from `plainMessage`. + - Decrypt (and optionally verify) a split message, getting the datapacket from a `Reader`: + ```go + func (keyRing *KeyRing) DecryptSplitStream( + keypacket []byte, + dataPacketReader Reader, + verifyKeyRing *KeyRing, verifyTime int64, + ) (plainMessage *PlainMessageReader, err error) + ``` + N.B.: to verify the signature, you will need to call `plainMessage.VerifySignature()` after all the data has been read from `plainMessage`. + - Generate a detached signature from a `Reader`: + ```go + func (keyRing *KeyRing) SignDetachedStream(message Reader) (*PGPSignature, error) + ``` + - Verify a detached signature for a `Reader`: + ```go + func (keyRing *KeyRing) VerifyDetachedStream( + message Reader, + signature *PGPSignature, + verifyTime int64, + ) error + ``` + - Generate an encrypted detached signature from a `Reader`: + ```go + func (keyRing *KeyRing) SignDetachedEncryptedStream( + message Reader, + encryptionKeyRing *KeyRing, + ) (encryptedSignature *PGPMessage, err error) + ``` + - Verify an encrypted detached signature for a `Reader`: + ```go + func (keyRing *KeyRing) VerifyDetachedEncryptedStream( + message Reader, + encryptedSignature *PGPMessage, + decryptionKeyRing *KeyRing, + verifyTime int64, + ) error + ``` + - SessionKey methods: + - Encrypt (and optionally sign) a message into a `Writer`: + ```go + func (sk *SessionKey) EncryptStream( + dataPacketWriter Writer, + plainMessageMetadata *PlainMessageMetadata, + signKeyRing *KeyRing, + ) (plainMessageWriter WriteCloser, err error) + ``` + - Decrypt (and optionally verify) a message from a `Reader`: + ```go + func (sk *SessionKey) DecryptStream( + dataPacketReader Reader, + verifyKeyRing *KeyRing, + verifyTime int64, + ) (plainMessage *PlainMessageReader, err error) + ``` + N.B.: to verify the signature, you will need to call `plainMessage.VerifySignature()` after all the data has been read from `plainMessage`. + - Mobile apps helpers for `Reader` and `Writer`: + Due to limitations of `gomobile`, mobile apps can't implement the `Reader` and `Writer` interfaces directly. + + - Implementing `Reader`: Apps should implement the interface: + ```go + type MobileReader interface { + Read(max int) (result *MobileReadResult, err error) + } + type MobileReadResult struct { + N int // N, The number of bytes read + IsEOF bool // IsEOF, If true, then the reader has reached the end of the data to read. + Data []byte // Data, the data that has been read + } + ``` + And then wrap it with `Mobile2GoReader(mobileReader)` to turn it into a `Reader`. + + - Implementing `Writer`: + + The apps should implement the `Writer` interface directly, but still need to wrap the writer with `Mobile2GoWriter(mobileWriter)`. We also provide the `Mobile2GoWriterWithSHA256` if you want to compute the SHA256 hash of the written data. + + - Using a `Reader`: To use a reader returned by golang in mobile apps: you should wrap it with: + - Android: `Go2AndroidReader(reader)`, implements the `Reader` interface, but returns `n == -1` instead of `err == io.EOF` + - iOS: `Go2IOSReader(reader)`, implements `MobileReader`. + - Using a `Writer`: you can use a writer returned by golang directly. ## [2.1.10] 2021-06-16 ### Fixed - Removed time interpolation via monotonic clock that can cause signatures in the future diff --git a/crypto/keyring_message.go b/crypto/keyring_message.go index eb6ef07..d819dad 100644 --- a/crypto/keyring_message.go +++ b/crypto/keyring_message.go @@ -119,7 +119,7 @@ func (keyRing *KeyRing) VerifyDetachedEncrypted(message *PlainMessage, encrypted // ------ INTERNAL FUNCTIONS ------- -// Core for encryption+signature functions. +// Core for encryption+signature (non-streaming) functions. func asymmetricEncrypt( plainMessage *PlainMessage, publicKey, privateKey *KeyRing, @@ -127,30 +127,17 @@ func asymmetricEncrypt( ) ([]byte, error) { var outBuf bytes.Buffer var encryptWriter io.WriteCloser - var signEntity *openpgp.Entity var err error - if privateKey != nil && len(privateKey.entities) > 0 { - var err error - signEntity, err = privateKey.getSigningEntity() - if err != nil { - return nil, err - } - } - hints := &openpgp.FileHints{ IsBinary: plainMessage.IsBinary(), FileName: plainMessage.Filename, ModTime: plainMessage.getFormattedTime(), } - if plainMessage.IsBinary() { - encryptWriter, err = openpgp.Encrypt(&outBuf, publicKey.entities, signEntity, hints, config) - } else { - encryptWriter, err = openpgp.EncryptText(&outBuf, publicKey.entities, signEntity, hints, config) - } + encryptWriter, err = asymmetricEncryptStream(hints, &outBuf, &outBuf, publicKey, privateKey, config) if err != nil { - return nil, errors.Wrap(err, "gopenpgp: error in encrypting asymmetrically") + return nil, err } _, err = encryptWriter.Write(plainMessage.GetBinary()) @@ -166,26 +153,46 @@ func asymmetricEncrypt( return outBuf.Bytes(), nil } -// Core for decryption+verification functions. +// Core for encryption+signature (all) functions. +func asymmetricEncryptStream( + hints *openpgp.FileHints, + keyPacketWriter io.Writer, + dataPacketWriter io.Writer, + publicKey, privateKey *KeyRing, + config *packet.Config, +) (encryptWriter io.WriteCloser, err error) { + var signEntity *openpgp.Entity + + if privateKey != nil && len(privateKey.entities) > 0 { + var err error + signEntity, err = privateKey.getSigningEntity() + if err != nil { + return nil, err + } + } + + if hints.IsBinary { + encryptWriter, err = openpgp.EncryptSplit(keyPacketWriter, dataPacketWriter, publicKey.entities, signEntity, hints, config) + } else { + encryptWriter, err = openpgp.EncryptTextSplit(keyPacketWriter, dataPacketWriter, publicKey.entities, signEntity, hints, config) + } + if err != nil { + return nil, errors.Wrap(err, "gopenpgp: error in encrypting asymmetrically") + } + return encryptWriter, nil +} + +// Core for decryption+verification (non streaming) functions. func asymmetricDecrypt( encryptedIO io.Reader, privateKey *KeyRing, verifyKey *KeyRing, verifyTime int64, ) (message *PlainMessage, err error) { - privKeyEntries := privateKey.entities - var additionalEntries openpgp.EntityList - - if verifyKey != nil { - additionalEntries = verifyKey.entities - } - - if additionalEntries != nil { - privKeyEntries = append(privKeyEntries, additionalEntries...) - } - - config := &packet.Config{Time: getTimeGenerator()} - - messageDetails, err := openpgp.ReadMessage(encryptedIO, privKeyEntries, nil, config) + messageDetails, err := asymmetricDecryptStream( + encryptedIO, + privateKey, + verifyKey, + ) if err != nil { - return nil, errors.Wrap(err, "gopenpgp: error in reading message") + return nil, err } body, err := ioutil.ReadAll(messageDetails.UnverifiedBody) @@ -205,3 +212,27 @@ func asymmetricDecrypt( Time: messageDetails.LiteralData.Time, }, err } + +// Core for decryption+verification (all) functions. +func asymmetricDecryptStream( + encryptedIO io.Reader, privateKey *KeyRing, verifyKey *KeyRing, +) (messageDetails *openpgp.MessageDetails, err error) { + privKeyEntries := privateKey.entities + var additionalEntries openpgp.EntityList + + if verifyKey != nil { + additionalEntries = verifyKey.entities + } + + if additionalEntries != nil { + privKeyEntries = append(privKeyEntries, additionalEntries...) + } + + config := &packet.Config{Time: getTimeGenerator()} + + messageDetails, err = openpgp.ReadMessage(encryptedIO, privKeyEntries, nil, config) + if err != nil { + return nil, errors.Wrap(err, "gopenpgp: error in reading message") + } + return messageDetails, err +} diff --git a/crypto/keyring_streaming.go b/crypto/keyring_streaming.go new file mode 100644 index 0000000..e16be5a --- /dev/null +++ b/crypto/keyring_streaming.go @@ -0,0 +1,299 @@ +package crypto + +import ( + "bytes" + "crypto" + "io" + "time" + + "github.com/ProtonMail/go-crypto/openpgp" + "github.com/ProtonMail/go-crypto/openpgp/packet" + "github.com/pkg/errors" +) + +type Reader interface { + Read(b []byte) (n int, err error) +} + +type Writer interface { + Write(b []byte) (n int, err error) +} + +type WriteCloser interface { + Write(b []byte) (n int, err error) + Close() (err error) +} + +type PlainMessageMetadata struct { + IsBinary bool + Filename string + ModTime int64 +} + +func NewPlainMessageMetadata(isBinary bool, filename string, modTime int64) *PlainMessageMetadata { + return &PlainMessageMetadata{IsBinary: isBinary, Filename: filename, ModTime: modTime} +} + +// EncryptStream is used to encrypt data as a Writer. +// It takes a writer for the encrypted data and returns a WriteCloser for the plaintext data +// If signKeyRing is not nil, it is used to do an embedded signature. +func (keyRing *KeyRing) EncryptStream( + pgpMessageWriter Writer, + plainMessageMetadata *PlainMessageMetadata, + signKeyRing *KeyRing, +) (plainMessageWriter WriteCloser, err error) { + config := &packet.Config{DefaultCipher: packet.CipherAES256, Time: getTimeGenerator()} + + if plainMessageMetadata == nil { + // Use sensible default metadata + plainMessageMetadata = &PlainMessageMetadata{ + IsBinary: true, + Filename: "", + ModTime: GetUnixTime(), + } + } + + hints := &openpgp.FileHints{ + FileName: plainMessageMetadata.Filename, + IsBinary: plainMessageMetadata.IsBinary, + ModTime: time.Unix(plainMessageMetadata.ModTime, 0), + } + + plainMessageWriter, err = asymmetricEncryptStream(hints, pgpMessageWriter, pgpMessageWriter, keyRing, signKeyRing, config) + if err != nil { + return nil, err + } + return plainMessageWriter, nil +} + +// EncryptSplitResult is used to wrap the encryption writecloser while storing the key packet. +type EncryptSplitResult struct { + isClosed bool + keyPacketBuf *bytes.Buffer + keyPacket []byte + plainMessageWriter WriteCloser // The writer to writer plaintext data in. +} + +func (res *EncryptSplitResult) Write(b []byte) (n int, err error) { + return res.plainMessageWriter.Write(b) +} + +func (res *EncryptSplitResult) Close() (err error) { + err = res.plainMessageWriter.Close() + if err != nil { + return err + } + res.isClosed = true + res.keyPacket = res.keyPacketBuf.Bytes() + return nil +} + +// GetKeyPacket returns the Public-Key Encrypted Session Key Packets (https://datatracker.ietf.org/doc/html/rfc4880#section-5.1). +// This can be retrieved only after the message has been fully written and the writer is closed. +func (res *EncryptSplitResult) GetKeyPacket() (keyPacket []byte, err error) { + if !res.isClosed { + return nil, errors.New("gopenpgp: can't access key packet until the message writer has been closed") + } + return res.keyPacket, nil +} + +// EncryptSplitStream is used to encrypt data as a stream. +// It takes a writer for the Symmetrically Encrypted Data Packet +// (https://datatracker.ietf.org/doc/html/rfc4880#section-5.7) +// and returns a writer for the plaintext data and the key packet. +// If signKeyRing is not nil, it is used to do an embedded signature. +func (keyRing *KeyRing) EncryptSplitStream( + dataPacketWriter Writer, + plainMessageMetadata *PlainMessageMetadata, + signKeyRing *KeyRing, +) (*EncryptSplitResult, error) { + config := &packet.Config{DefaultCipher: packet.CipherAES256, Time: getTimeGenerator()} + + if plainMessageMetadata == nil { + // Use sensible default metadata + plainMessageMetadata = &PlainMessageMetadata{ + IsBinary: true, + Filename: "", + ModTime: GetUnixTime(), + } + } + + hints := &openpgp.FileHints{ + FileName: plainMessageMetadata.Filename, + IsBinary: plainMessageMetadata.IsBinary, + ModTime: time.Unix(plainMessageMetadata.ModTime, 0), + } + + var keyPacketBuf bytes.Buffer + plainMessageWriter, err := asymmetricEncryptStream(hints, &keyPacketBuf, dataPacketWriter, keyRing, signKeyRing, config) + if err != nil { + return nil, err + } + return &EncryptSplitResult{ + keyPacketBuf: &keyPacketBuf, + plainMessageWriter: plainMessageWriter, + }, nil +} + +// PlainMessageReader is used to wrap the data of the decrypted plain message. +// It can be used to read the decrypted data and verify the embedded signature. +type PlainMessageReader struct { + details *openpgp.MessageDetails + verifyKeyRing *KeyRing + verifyTime int64 + readAll bool +} + +// GetMetadata returns the metadata of the decrypted message. +func (msg *PlainMessageReader) GetMetadata() *PlainMessageMetadata { + return &PlainMessageMetadata{ + Filename: msg.details.LiteralData.FileName, + IsBinary: msg.details.LiteralData.IsBinary, + ModTime: int64(msg.details.LiteralData.Time), + } +} + +// Read is used to access the message decrypted data. +// Makes PlainMessageReader implement the Reader interface. +func (msg *PlainMessageReader) Read(b []byte) (n int, err error) { + n, err = msg.details.UnverifiedBody.Read(b) + if errors.Is(err, io.EOF) { + msg.readAll = true + } + return +} + +// VerifySignature is used to verify that the signature is valid. +// This method needs to be called once all the data has been read. +// It will return an error if the signature is invalid +// or if the message hasn't been read entirely. +func (msg *PlainMessageReader) VerifySignature() (err error) { + if !msg.readAll { + return errors.New("gopenpgp: can't verify the signature until the message reader has been read entirely") + } + if msg.verifyKeyRing != nil { + processSignatureExpiration(msg.details, msg.verifyTime) + err = verifyDetailsSignature(msg.details, msg.verifyKeyRing) + } else { + err = errors.New("gopenpgp: no verify keyring was provided before decryption") + } + return +} + +// DecryptStream is used to decrypt a pgp message as a Reader. +// It takes a reader for the message data +// and returns a PlainMessageReader for the plaintext data. +// If verifyKeyRing is not nil, PlainMessageReader.VerifySignature() will +// verify the embedded signature with the given key ring and verification time. +func (keyRing *KeyRing) DecryptStream( + message Reader, + verifyKeyRing *KeyRing, + verifyTime int64, +) (plainMessage *PlainMessageReader, err error) { + messageDetails, err := asymmetricDecryptStream( + message, + keyRing, + verifyKeyRing, + ) + if err != nil { + return nil, err + } + + return &PlainMessageReader{ + messageDetails, + verifyKeyRing, + verifyTime, + false, + }, err +} + +// DecryptSplitStream is used to decrypt a split pgp message as a Reader. +// It takes a key packet and a reader for the data packet +// and returns a PlainMessageReader for the plaintext data. +// If verifyKeyRing is not nil, PlainMessageReader.VerifySignature() will +// verify the embedded signature with the given key ring and verification time. +func (keyRing *KeyRing) DecryptSplitStream( + keypacket []byte, + dataPacketReader Reader, + verifyKeyRing *KeyRing, verifyTime int64, +) (plainMessage *PlainMessageReader, err error) { + messageReader := io.MultiReader( + bytes.NewReader(keypacket), + dataPacketReader, + ) + return keyRing.DecryptStream( + messageReader, + verifyKeyRing, + verifyTime, + ) +} + +// SignDetachedStream generates and returns a PGPSignature for a given message Reader. +func (keyRing *KeyRing) SignDetachedStream(message Reader) (*PGPSignature, error) { + signEntity, err := keyRing.getSigningEntity() + if err != nil { + return nil, err + } + + config := &packet.Config{DefaultHash: crypto.SHA512, Time: getTimeGenerator()} + var outBuf bytes.Buffer + // sign bin + if err := openpgp.DetachSign(&outBuf, signEntity, message, config); err != nil { + return nil, errors.Wrap(err, "gopenpgp: error in signing") + } + + return NewPGPSignature(outBuf.Bytes()), nil +} + +// VerifyDetachedStream verifies a message reader with a detached PGPSignature +// and returns a SignatureVerificationError if fails. +func (keyRing *KeyRing) VerifyDetachedStream( + message Reader, + signature *PGPSignature, + verifyTime int64, +) error { + return verifySignature( + keyRing.entities, + message, + signature.GetBinary(), + verifyTime, + ) +} + +// SignDetachedEncryptedStream generates and returns a PGPMessage +// containing an encrypted detached signature for a given message Reader. +func (keyRing *KeyRing) SignDetachedEncryptedStream( + message Reader, + encryptionKeyRing *KeyRing, +) (encryptedSignature *PGPMessage, err error) { + if encryptionKeyRing == nil { + return nil, errors.New("gopenpgp: no encryption key ring provided") + } + signature, err := keyRing.SignDetachedStream(message) + if err != nil { + return nil, err + } + plainMessage := NewPlainMessage(signature.GetBinary()) + encryptedSignature, err = encryptionKeyRing.Encrypt(plainMessage, nil) + return +} + +// VerifyDetachedEncryptedStream verifies a PlainMessage +// with a PGPMessage containing an encrypted detached signature +// and returns a SignatureVerificationError if fails. +func (keyRing *KeyRing) VerifyDetachedEncryptedStream( + message Reader, + encryptedSignature *PGPMessage, + decryptionKeyRing *KeyRing, + verifyTime int64, +) error { + if decryptionKeyRing == nil { + return errors.New("gopenpgp: no decryption key ring provided") + } + plainMessage, err := decryptionKeyRing.Decrypt(encryptedSignature, nil, 0) + if err != nil { + return err + } + signature := NewPGPSignature(plainMessage.GetBinary()) + return keyRing.VerifyDetachedStream(message, signature, verifyTime) +} diff --git a/crypto/keyring_streaming_test.go b/crypto/keyring_streaming_test.go new file mode 100644 index 0000000..c9a96de --- /dev/null +++ b/crypto/keyring_streaming_test.go @@ -0,0 +1,491 @@ +package crypto + +import ( + "bytes" + "io" + "reflect" + "testing" + + "github.com/pkg/errors" +) + +var testMeta = &PlainMessageMetadata{ + IsBinary: true, + Filename: "filename.txt", + ModTime: GetUnixTime(), +} + +func TestKeyRing_EncryptDecryptStream(t *testing.T) { + messageBytes := []byte("Hello World!") + messageReader := bytes.NewReader(messageBytes) + var ciphertextBuf bytes.Buffer + messageWriter, err := keyRingTestPublic.EncryptStream( + &ciphertextBuf, + testMeta, + keyRingTestPrivate, + ) + if err != nil { + t.Fatal("Expected no error while encrypting stream with key ring, got:", err) + } + reachedEnd := false + bufferSize := 2 + buffer := make([]byte, bufferSize) + for !reachedEnd { + n, err := messageReader.Read(buffer) + if err != nil { + if errors.Is(err, io.EOF) { + reachedEnd = true + } else { + t.Fatal("Expected no error while reading data, got:", err) + } + } + writtenTotal := 0 + for writtenTotal < n { + written, err := messageWriter.Write(buffer[writtenTotal:n]) + if err != nil { + t.Fatal("Expected no error while writing data, got:", err) + } + writtenTotal += written + } + } + err = messageWriter.Close() + if err != nil { + t.Fatal("Expected no error while closing plaintext writer, got:", err) + } + ciphertextBytes := ciphertextBuf.Bytes() + decryptedReader, err := keyRingTestPrivate.DecryptStream( + bytes.NewReader(ciphertextBytes), + keyRingTestPublic, + GetUnixTime(), + ) + if err != nil { + t.Fatal("Expected no error while calling decrypting stream with key ring, got:", err) + } + err = decryptedReader.VerifySignature() + if err == nil { + t.Fatal("Expected an error while verifying the signature before reading the data, got nil") + } + decryptedBytes, err := io.ReadAll(decryptedReader) + if err != nil { + t.Fatal("Expected no error while reading the decrypted data, got:", err) + } + if !bytes.Equal(decryptedBytes, messageBytes) { + t.Fatalf("Expected the decrypted data to be %s got %s", string(decryptedBytes), string(messageBytes)) + } + err = decryptedReader.VerifySignature() + if err != nil { + t.Fatal("Expected no error while verifying the signature, got:", err) + } + decryptedMeta := decryptedReader.GetMetadata() + if !reflect.DeepEqual(testMeta, decryptedMeta) { + t.Fatalf("Expected the decrypted metadata to be %v got %v", testMeta, decryptedMeta) + } + decryptedReaderNoVerify, err := keyRingTestPrivate.DecryptStream( + bytes.NewReader(ciphertextBytes), + nil, + 0, + ) + if err != nil { + t.Fatal("Expected no error while calling decrypting stream with key ring, got:", err) + } + decryptedBytes, err = io.ReadAll(decryptedReaderNoVerify) + if err != nil { + t.Fatal("Expected no error while reading the decrypted data, got:", err) + } + if !bytes.Equal(decryptedBytes, messageBytes) { + t.Fatalf("Expected the decrypted data to be %s got %s", string(decryptedBytes), string(messageBytes)) + } + decryptedMeta = decryptedReaderNoVerify.GetMetadata() + if !reflect.DeepEqual(testMeta, decryptedMeta) { + t.Fatalf("Expected the decrypted metadata to be %v got %v", testMeta, decryptedMeta) + } + err = decryptedReaderNoVerify.VerifySignature() + if err == nil { + t.Fatal("Expected an error while verifying the signature with no keyring, got nil") + } +} + +func TestKeyRing_EncryptStreamCompatible(t *testing.T) { + messageBytes := []byte("Hello World!") + messageReader := bytes.NewReader(messageBytes) + var ciphertextBuf bytes.Buffer + messageWriter, err := keyRingTestPublic.EncryptStream( + &ciphertextBuf, + testMeta, + keyRingTestPrivate, + ) + if err != nil { + t.Fatal("Expected no error while encrypting stream with key ring, got:", err) + } + reachedEnd := false + bufferSize := 2 + buffer := make([]byte, bufferSize) + for !reachedEnd { + n, err := messageReader.Read(buffer) + if err != nil { + if errors.Is(err, io.EOF) { + reachedEnd = true + } else { + t.Fatal("Expected no error while reading data, got:", err) + } + } + writtenTotal := 0 + for writtenTotal < n { + written, err := messageWriter.Write(buffer[writtenTotal:n]) + if err != nil { + t.Fatal("Expected no error while writing data, got:", err) + } + writtenTotal += written + } + } + err = messageWriter.Close() + if err != nil { + t.Fatal("Expected no error while closing plaintext writer, got:", err) + } + encryptedData := ciphertextBuf.Bytes() + decryptedMsg, err := keyRingTestPrivate.Decrypt( + NewPGPMessage(encryptedData), + keyRingTestPublic, + GetUnixTime(), + ) + if err != nil { + t.Fatal("Expected no error while calling normal decrypt with key ring, got:", err) + } + decryptedBytes := decryptedMsg.GetBinary() + if !bytes.Equal(decryptedBytes, messageBytes) { + t.Fatalf("Expected the normally decrypted data to be %s got %s", string(decryptedBytes), string(messageBytes)) + } + if testMeta.IsBinary != decryptedMsg.IsBinary() { + t.Fatalf("Expected isBinary to be %t got %t", testMeta.IsBinary, decryptedMsg.IsBinary()) + } + if testMeta.Filename != decryptedMsg.GetFilename() { + t.Fatalf("Expected filename to be %s got %s", testMeta.Filename, decryptedMsg.GetFilename()) + } + if testMeta.ModTime != int64(decryptedMsg.GetTime()) { + t.Fatalf("Expected modification time to be %d got %d", testMeta.ModTime, int64(decryptedMsg.GetTime())) + } +} + +func TestKeyRing_DecryptStreamCompatible(t *testing.T) { + messageBytes := []byte("Hello World!") + pgpMessage, err := keyRingTestPublic.Encrypt( + &PlainMessage{ + Data: messageBytes, + TextType: !testMeta.IsBinary, + Time: uint32(testMeta.ModTime), + Filename: testMeta.Filename, + }, + keyRingTestPrivate, + ) + if err != nil { + t.Fatal("Expected no error while encrypting plaintext, got:", err) + } + decryptedReader, err := keyRingTestPrivate.DecryptStream( + bytes.NewReader(pgpMessage.GetBinary()), + keyRingTestPublic, + GetUnixTime(), + ) + if err != nil { + t.Fatal("Expected no error while calling decrypting stream with key ring, got:", err) + } + decryptedBytes, err := io.ReadAll(decryptedReader) + if err != nil { + t.Fatal("Expected no error while reading the decrypted data, got:", err) + } + err = decryptedReader.VerifySignature() + if err != nil { + t.Fatal("Expected no error while verifying the signature, got:", err) + } + if !bytes.Equal(decryptedBytes, messageBytes) { + t.Fatalf("Expected the decrypted data to be %s got %s", string(decryptedBytes), string(messageBytes)) + } + decryptedMeta := decryptedReader.GetMetadata() + if !reflect.DeepEqual(testMeta, decryptedMeta) { + t.Fatalf("Expected the decrypted metadata to be %v got %v", testMeta, decryptedMeta) + } +} + +func TestKeyRing_EncryptDecryptSplitStream(t *testing.T) { + messageBytes := []byte("Hello World!") + messageReader := bytes.NewReader(messageBytes) + var dataPacketBuf bytes.Buffer + encryptionResult, err := keyRingTestPublic.EncryptSplitStream( + &dataPacketBuf, + testMeta, + keyRingTestPrivate, + ) + if err != nil { + t.Fatal("Expected no error while calling encrypting split stream with key ring, got:", err) + } + messageWriter := encryptionResult + reachedEnd := false + bufferSize := 2 + buffer := make([]byte, bufferSize) + for !reachedEnd { + n, err := messageReader.Read(buffer) + if err != nil { + if errors.Is(err, io.EOF) { + reachedEnd = true + } else { + t.Fatal("Expected no error while reading data, got:", err) + } + } + writtenTotal := 0 + for writtenTotal < n { + written, err := messageWriter.Write(buffer[writtenTotal:n]) + if err != nil { + t.Fatal("Expected no error while writing data, got:", err) + } + writtenTotal += written + } + } + err = messageWriter.Close() + if err != nil { + t.Fatal("Expected no error while closing plaintext writer, got:", err) + } + keyPacket, err := encryptionResult.GetKeyPacket() + if err != nil { + t.Fatal("Expected no error while accessing key packet, got:", err) + } + dataPacket := dataPacketBuf.Bytes() + decryptedReader, err := keyRingTestPrivate.DecryptSplitStream( + keyPacket, + bytes.NewReader(dataPacket), + keyRingTestPublic, + GetUnixTime(), + ) + if err != nil { + t.Fatal("Expected no error while decrypting split stream with key ring, got:", err) + } + decryptedBytes, err := io.ReadAll(decryptedReader) + if err != nil { + t.Fatal("Expected no error while reading the decrypted data, got:", err) + } + err = decryptedReader.VerifySignature() + if err != nil { + t.Fatal("Expected no error while verifying the signature, got:", err) + } + if !bytes.Equal(decryptedBytes, messageBytes) { + t.Fatalf("Expected the decrypted data to be %s got %s", string(decryptedBytes), string(messageBytes)) + } + decryptedMeta := decryptedReader.GetMetadata() + if !reflect.DeepEqual(testMeta, decryptedMeta) { + t.Fatalf("Expected the decrypted metadata to be %v got %v", testMeta, decryptedMeta) + } +} + +func TestKeyRing_EncryptSplitStreamCompatible(t *testing.T) { + messageBytes := []byte("Hello World!") + messageReader := bytes.NewReader(messageBytes) + var dataPacketBuf bytes.Buffer + encryptionResult, err := keyRingTestPublic.EncryptSplitStream( + &dataPacketBuf, + testMeta, + keyRingTestPrivate, + ) + if err != nil { + t.Fatal("Expected no error while calling encrypting split stream with key ring, got:", err) + } + messageWriter := encryptionResult + reachedEnd := false + bufferSize := 2 + buffer := make([]byte, bufferSize) + for !reachedEnd { + n, err := messageReader.Read(buffer) + if err != nil { + if errors.Is(err, io.EOF) { + reachedEnd = true + } else { + t.Fatal("Expected no error while reading data, got:", err) + } + } + writtenTotal := 0 + for writtenTotal < n { + written, err := messageWriter.Write(buffer[writtenTotal:n]) + if err != nil { + t.Fatal("Expected no error while writing data, got:", err) + } + writtenTotal += written + } + } + err = messageWriter.Close() + if err != nil { + t.Fatal("Expected no error while closing plaintext writer, got:", err) + } + keyPacket, err := encryptionResult.GetKeyPacket() + if err != nil { + t.Fatal("Expected no error while accessing key packet, got:", err) + } + dataPacket := dataPacketBuf.Bytes() + decryptedMsg, err := keyRingTestPrivate.Decrypt( + NewPGPSplitMessage(keyPacket, dataPacket).GetPGPMessage(), + keyRingTestPublic, + GetUnixTime(), + ) + if err != nil { + t.Fatal("Expected no error while decrypting split stream with key ring, got:", err) + } + decryptedBytes := decryptedMsg.GetBinary() + if err != nil { + t.Fatal("Expected no error while reading the decrypted data, got:", err) + } + if !bytes.Equal(decryptedBytes, messageBytes) { + t.Fatalf("Expected the decrypted data to be %s got %s", string(decryptedBytes), string(messageBytes)) + } + if testMeta.IsBinary != decryptedMsg.IsBinary() { + t.Fatalf("Expected isBinary to be %t got %t", testMeta.IsBinary, decryptedMsg.IsBinary()) + } + if testMeta.Filename != decryptedMsg.GetFilename() { + t.Fatalf("Expected filename to be %s got %s", testMeta.Filename, decryptedMsg.GetFilename()) + } + if testMeta.ModTime != int64(decryptedMsg.GetTime()) { + t.Fatalf("Expected modification time to be %d got %d", testMeta.ModTime, int64(decryptedMsg.GetTime())) + } +} + +func TestKeyRing_DecryptSplitStreamCompatible(t *testing.T) { + messageBytes := []byte("Hello World!") + pgpMessage, err := keyRingTestPublic.Encrypt( + &PlainMessage{ + Data: messageBytes, + TextType: !testMeta.IsBinary, + Time: uint32(testMeta.ModTime), + Filename: testMeta.Filename, + }, + keyRingTestPrivate, + ) + if err != nil { + t.Fatal("Expected no error while encrypting plaintext, got:", err) + } + armored, err := pgpMessage.GetArmored() + if err != nil { + t.Fatal("Expected no error while armoring ciphertext, got:", err) + } + splitMsg, err := NewPGPSplitMessageFromArmored(armored) + if err != nil { + t.Fatal("Expected no error while splitting the ciphertext, got:", err) + } + keyPacket := splitMsg.KeyPacket + if err != nil { + t.Fatal("Expected no error while accessing key packet, got:", err) + } + dataPacket := splitMsg.DataPacket + decryptedReader, err := keyRingTestPrivate.DecryptSplitStream( + keyPacket, + bytes.NewReader(dataPacket), + keyRingTestPublic, + GetUnixTime(), + ) + if err != nil { + t.Fatal("Expected no error while decrypting split stream with key ring, got:", err) + } + decryptedBytes, err := io.ReadAll(decryptedReader) + if err != nil { + t.Fatal("Expected no error while reading the decrypted data, got:", err) + } + if !bytes.Equal(decryptedBytes, messageBytes) { + t.Fatalf("Expected the decrypted data to be %s got %s", string(decryptedBytes), string(messageBytes)) + } + err = decryptedReader.VerifySignature() + if err != nil { + t.Fatal("Expected no error while verifying the signature, got:", err) + } + decryptedMeta := decryptedReader.GetMetadata() + if !reflect.DeepEqual(testMeta, decryptedMeta) { + t.Fatalf("Expected the decrypted metadata to be %v got %v", testMeta, decryptedMeta) + } +} + +func TestKeyRing_SignVerifyDetachedStream(t *testing.T) { + messageBytes := []byte("Hello World!") + messageReader := bytes.NewReader(messageBytes) + signature, err := keyRingTestPrivate.SignDetachedStream(messageReader) + if err != nil { + t.Fatal("Expected no error while signing the message, got:", err) + } + _, err = messageReader.Seek(0, 0) + if err != nil { + t.Fatal("Expected no error while rewinding the message reader, got:", err) + } + err = keyRingTestPublic.VerifyDetachedStream(messageReader, signature, GetUnixTime()) + if err != nil { + t.Fatal("Expected no error while verifying the detached signature, got:", err) + } +} + +func TestKeyRing_SignDetachedStreamCompatible(t *testing.T) { + messageBytes := []byte("Hello World!") + messageReader := bytes.NewReader(messageBytes) + signature, err := keyRingTestPrivate.SignDetachedStream(messageReader) + if err != nil { + t.Fatal("Expected no error while signing the message, got:", err) + } + err = keyRingTestPublic.VerifyDetached(NewPlainMessage(messageBytes), signature, GetUnixTime()) + if err != nil { + t.Fatal("Expected no error while verifying the detached signature, got:", err) + } +} + +func TestKeyRing_VerifyDetachedStreamCompatible(t *testing.T) { + messageBytes := []byte("Hello World!") + messageReader := bytes.NewReader(messageBytes) + signature, err := keyRingTestPrivate.SignDetached(NewPlainMessage(messageBytes)) + if err != nil { + t.Fatal("Expected no error while signing the message, got:", err) + } + _, err = messageReader.Seek(0, 0) + if err != nil { + t.Fatal("Expected no error while rewinding the message reader, got:", err) + } + err = keyRingTestPublic.VerifyDetachedStream(messageReader, signature, GetUnixTime()) + if err != nil { + t.Fatal("Expected no error while verifying the detached signature, got:", err) + } +} + +func TestKeyRing_SignVerifyDetachedEncryptedStream(t *testing.T) { + messageBytes := []byte("Hello World!") + messageReader := bytes.NewReader(messageBytes) + encSignature, err := keyRingTestPrivate.SignDetachedEncryptedStream(messageReader, keyRingTestPublic) + if err != nil { + t.Fatal("Expected no error while signing the message, got:", err) + } + _, err = messageReader.Seek(0, 0) + if err != nil { + t.Fatal("Expected no error while rewinding the message reader, got:", err) + } + err = keyRingTestPublic.VerifyDetachedEncryptedStream(messageReader, encSignature, keyRingTestPrivate, GetUnixTime()) + if err != nil { + t.Fatal("Expected no error while verifying the detached signature, got:", err) + } +} + +func TestKeyRing_SignDetachedEncryptedStreamCompatible(t *testing.T) { + messageBytes := []byte("Hello World!") + messageReader := bytes.NewReader(messageBytes) + encSignature, err := keyRingTestPrivate.SignDetachedEncryptedStream(messageReader, keyRingTestPublic) + if err != nil { + t.Fatal("Expected no error while signing the message, got:", err) + } + err = keyRingTestPublic.VerifyDetachedEncrypted(NewPlainMessage(messageBytes), encSignature, keyRingTestPrivate, GetUnixTime()) + if err != nil { + t.Fatal("Expected no error while verifying the detached signature, got:", err) + } +} + +func TestKeyRing_VerifyDetachedEncryptedStreamCompatible(t *testing.T) { + messageBytes := []byte("Hello World!") + messageReader := bytes.NewReader(messageBytes) + encSignature, err := keyRingTestPrivate.SignDetachedEncrypted(NewPlainMessage(messageBytes), keyRingTestPublic) + if err != nil { + t.Fatal("Expected no error while signing the message, got:", err) + } + _, err = messageReader.Seek(0, 0) + if err != nil { + t.Fatal("Expected no error while rewinding the message reader, got:", err) + } + err = keyRingTestPublic.VerifyDetachedEncryptedStream(messageReader, encSignature, keyRingTestPrivate, GetUnixTime()) + if err != nil { + t.Fatal("Expected no error while verifying the detached signature, got:", err) + } +} diff --git a/crypto/sessionkey.go b/crypto/sessionkey.go index 33f037f..49c0c5a 100644 --- a/crypto/sessionkey.go +++ b/crypto/sessionkey.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "fmt" "io" + "time" "github.com/ProtonMail/gopenpgp/v2/constants" "github.com/pkg/errors" @@ -171,67 +172,87 @@ func (sk *SessionKey) EncryptWithCompression(message *PlainMessage) ([]byte, err func encryptWithSessionKey(message *PlainMessage, sk *SessionKey, signEntity *openpgp.Entity, config *packet.Config) ([]byte, error) { var encBuf = new(bytes.Buffer) - var encryptWriter, signWriter io.WriteCloser - encryptWriter, err := packet.SerializeSymmetricallyEncrypted(encBuf, config.Cipher(), sk.Key, config) + encryptWriter, signWriter, err := encryptStreamWithSessionKey( + message.IsBinary(), + message.Filename, + message.Time, + encBuf, + sk, + signEntity, + config, + ) if err != nil { - return nil, errors.Wrap(err, "gopenpgp: unable to encrypt") + return nil, err } - - if algo := config.Compression(); algo != packet.CompressionNone { - encryptWriter, err = packet.SerializeCompressed(encryptWriter, algo, config.CompressionConfig) - if err != nil { - return nil, errors.Wrap(err, "gopenpgp: error in compression") - } - } - - if signEntity != nil { // nolint:nestif - hints := &openpgp.FileHints{ - IsBinary: message.IsBinary(), - FileName: message.Filename, - ModTime: message.getFormattedTime(), - } - - signWriter, err = openpgp.Sign(encryptWriter, signEntity, hints, config) - if err != nil { - return nil, errors.Wrap(err, "gopenpgp: unable to sign") - } - + if signEntity != nil { _, err = signWriter.Write(message.GetBinary()) if err != nil { return nil, errors.Wrap(err, "gopenpgp: error in writing signed message") } - err = signWriter.Close() if err != nil { return nil, errors.Wrap(err, "gopenpgp: error in closing signing writer") } } else { - encryptWriter, err = packet.SerializeLiteral( - encryptWriter, - message.IsBinary(), - message.Filename, - message.Time, - ) - - if err != nil { - return nil, errors.Wrap(err, "gopenpgp: unable to serialize") - } - _, err = encryptWriter.Write(message.GetBinary()) - if err != nil { - return nil, errors.Wrap(err, "gopenpgp: error in writing message") - } } - + if err != nil { + return nil, errors.Wrap(err, "gopenpgp: error in writing message") + } err = encryptWriter.Close() if err != nil { return nil, errors.Wrap(err, "gopenpgp: error in closing encryption writer") } - return encBuf.Bytes(), nil } +func encryptStreamWithSessionKey( + isBinary bool, + filename string, + modTime uint32, + dataPacketWriter io.Writer, + sk *SessionKey, + signEntity *openpgp.Entity, + config *packet.Config, +) (encryptWriter, signWriter io.WriteCloser, err error) { + encryptWriter, err = packet.SerializeSymmetricallyEncrypted(dataPacketWriter, config.Cipher(), sk.Key, config) + if err != nil { + return nil, nil, errors.Wrap(err, "gopenpgp: unable to encrypt") + } + + if algo := config.Compression(); algo != packet.CompressionNone { + encryptWriter, err = packet.SerializeCompressed(encryptWriter, algo, config.CompressionConfig) + if err != nil { + return nil, nil, errors.Wrap(err, "gopenpgp: error in compression") + } + } + + if signEntity != nil { + hints := &openpgp.FileHints{ + IsBinary: isBinary, + FileName: filename, + ModTime: time.Unix(int64(modTime), 0), + } + + signWriter, err = openpgp.Sign(encryptWriter, signEntity, hints, config) + if err != nil { + return nil, nil, errors.Wrap(err, "gopenpgp: unable to sign") + } + } else { + encryptWriter, err = packet.SerializeLiteral( + encryptWriter, + isBinary, + filename, + modTime, + ) + if err != nil { + return nil, nil, errors.Wrap(err, "gopenpgp: unable to serialize") + } + } + return encryptWriter, signWriter, nil +} + // Decrypt decrypts pgp data packets using directly a session key. // * encrypted: PGPMessage. // * output: PlainMessage. @@ -246,8 +267,32 @@ func (sk *SessionKey) Decrypt(dataPacket []byte) (*PlainMessage, error) { // * output: PlainMessage. func (sk *SessionKey) DecryptAndVerify(dataPacket []byte, verifyKeyRing *KeyRing, verifyTime int64) (*PlainMessage, error) { var messageReader = bytes.NewReader(dataPacket) + + md, err := decryptStreamWithSessionKey(sk, messageReader, verifyKeyRing) + if err != nil { + return nil, err + } + messageBuf := new(bytes.Buffer) + _, err = messageBuf.ReadFrom(md.UnverifiedBody) + if err != nil { + return nil, errors.Wrap(err, "gopenpgp: error in reading message body") + } + + if verifyKeyRing != nil { + processSignatureExpiration(md, verifyTime) + err = verifyDetailsSignature(md, verifyKeyRing) + } + + return &PlainMessage{ + Data: messageBuf.Bytes(), + TextType: !md.LiteralData.IsBinary, + Filename: md.LiteralData.FileName, + Time: md.LiteralData.Time, + }, err +} + +func decryptStreamWithSessionKey(sk *SessionKey, messageReader io.Reader, verifyKeyRing *KeyRing) (*openpgp.MessageDetails, error) { var decrypted io.ReadCloser - var decBuf bytes.Buffer var keyring openpgp.EntityList // Read symmetrically encrypted data packet @@ -273,45 +318,24 @@ func (sk *SessionKey) DecryptAndVerify(dataPacket []byte, verifyKeyRing *KeyRing default: return nil, errors.New("gopenpgp: invalid packet type") } - _, err = decBuf.ReadFrom(decrypted) - if err != nil { - return nil, errors.Wrap(err, "gopenpgp: unable to read from decrypted symmetric packet") - } config := &packet.Config{ Time: getTimeGenerator(), } // Push decrypted packet as literal packet and use openpgp's reader - if verifyKeyRing != nil { keyring = verifyKeyRing.entities } else { keyring = openpgp.EntityList{} } - md, err := openpgp.ReadMessage(&decBuf, keyring, nil, config) + md, err := openpgp.ReadMessage(decrypted, keyring, nil, config) if err != nil { return nil, errors.Wrap(err, "gopenpgp: unable to decode symmetric packet") } - messageBuf := new(bytes.Buffer) - _, err = messageBuf.ReadFrom(md.UnverifiedBody) - if err != nil { - return nil, errors.Wrap(err, "gopenpgp: error in reading message body") - } - - if verifyKeyRing != nil { - processSignatureExpiration(md, verifyTime) - err = verifyDetailsSignature(md, verifyKeyRing) - } - - return &PlainMessage{ - Data: messageBuf.Bytes(), - TextType: !md.LiteralData.IsBinary, - Filename: md.LiteralData.FileName, - Time: md.LiteralData.Time, - }, err + return md, nil } func (sk *SessionKey) checkSize() error { diff --git a/crypto/sessionkey_streaming.go b/crypto/sessionkey_streaming.go new file mode 100644 index 0000000..02e80ab --- /dev/null +++ b/crypto/sessionkey_streaming.go @@ -0,0 +1,105 @@ +package crypto + +import ( + "github.com/ProtonMail/go-crypto/openpgp" + "github.com/ProtonMail/go-crypto/openpgp/packet" + "github.com/pkg/errors" +) + +type signAndEncryptWriteCloser struct { + signWriter WriteCloser + encryptWriter WriteCloser +} + +func (w *signAndEncryptWriteCloser) Write(b []byte) (int, error) { + return w.signWriter.Write(b) +} + +func (w *signAndEncryptWriteCloser) Close() error { + if err := w.signWriter.Close(); err != nil { + return err + } + return w.encryptWriter.Close() +} + +// EncryptStream is used to encrypt data as a Writer. +// It takes a writer for the encrypted data packet and returns a writer for the plaintext data. +// If signKeyRing is not nil, it is used to do an embedded signature. +func (sk *SessionKey) EncryptStream( + dataPacketWriter Writer, + plainMessageMetadata *PlainMessageMetadata, + signKeyRing *KeyRing, +) (plainMessageWriter WriteCloser, err error) { + dc, err := sk.GetCipherFunc() + if err != nil { + return nil, errors.Wrap(err, "gopenpgp: unable to encrypt with session key") + } + + config := &packet.Config{ + Time: getTimeGenerator(), + DefaultCipher: dc, + } + var signEntity *openpgp.Entity + if signKeyRing != nil { + signEntity, err = signKeyRing.getSigningEntity() + if err != nil { + return nil, errors.Wrap(err, "gopenpgp: unable to sign") + } + } + + if plainMessageMetadata == nil { + // Use sensible default metadata + plainMessageMetadata = &PlainMessageMetadata{ + IsBinary: true, + Filename: "", + ModTime: GetUnixTime(), + } + } + + encryptWriter, signWriter, err := encryptStreamWithSessionKey( + plainMessageMetadata.IsBinary, + plainMessageMetadata.Filename, + uint32(plainMessageMetadata.ModTime), + dataPacketWriter, + sk, + signEntity, + config, + ) + + if err != nil { + return nil, err + } + if signWriter != nil { + plainMessageWriter = &signAndEncryptWriteCloser{signWriter, encryptWriter} + } else { + plainMessageWriter = encryptWriter + } + return plainMessageWriter, err +} + +// DecryptStream is used to decrypt a data packet as a Reader. +// It takes a reader for the data packet +// and returns a PlainMessageReader for the plaintext data. +// If verifyKeyRing is not nil, PlainMessageReader.VerifySignature() will +// verify the embedded signature with the given key ring and verification time. +func (sk *SessionKey) DecryptStream( + dataPacketReader Reader, + verifyKeyRing *KeyRing, + verifyTime int64, +) (plainMessage *PlainMessageReader, err error) { + messageDetails, err := decryptStreamWithSessionKey( + sk, + dataPacketReader, + verifyKeyRing, + ) + if err != nil { + return nil, errors.Wrap(err, "gopenpgp: error in reading message") + } + + return &PlainMessageReader{ + messageDetails, + verifyKeyRing, + verifyTime, + false, + }, err +} diff --git a/crypto/sessionkey_streaming_test.go b/crypto/sessionkey_streaming_test.go new file mode 100644 index 0000000..0115814 --- /dev/null +++ b/crypto/sessionkey_streaming_test.go @@ -0,0 +1,176 @@ +package crypto + +import ( + "bytes" + "io" + "reflect" + "testing" + + "github.com/pkg/errors" +) + +func TestSessionKey_EncryptDecryptStream(t *testing.T) { + messageBytes := []byte("Hello World!") + messageReader := bytes.NewReader(messageBytes) + var dataPacketBuf bytes.Buffer + messageWriter, err := testSessionKey.EncryptStream( + &dataPacketBuf, + testMeta, + keyRingTestPrivate, + ) + if err != nil { + t.Fatal("Expected no error while encrypting stream with session key, got:", err) + } + bufferSize := 2 + buffer := make([]byte, bufferSize) + reachedEnd := false + for !reachedEnd { + n, err := messageReader.Read(buffer) + if err != nil { + if errors.Is(err, io.EOF) { + reachedEnd = true + } else { + t.Fatal("Expected no error while reading data, got:", err) + } + } + writtenTotal := 0 + for writtenTotal < n { + written, err := messageWriter.Write(buffer[writtenTotal:n]) + if err != nil { + t.Fatal("Expected no error while writing data, got:", err) + } + writtenTotal += written + } + } + err = messageWriter.Close() + if err != nil { + t.Fatal("Expected no error while closing plaintext writer, got:", err) + } + dataPacket := dataPacketBuf.Bytes() + decryptedReader, err := testSessionKey.DecryptStream( + bytes.NewReader(dataPacket), + keyRingTestPublic, + GetUnixTime(), + ) + if err != nil { + t.Fatal("Expected no error while calling DecryptStream, got:", err) + } + decryptedBytes, err := io.ReadAll(decryptedReader) + if err != nil { + t.Fatal("Expected no error while reading the decrypted data, got:", err) + } + err = decryptedReader.VerifySignature() + if err != nil { + t.Fatal("Expected no error while verifying the signature, got:", err) + } + if !bytes.Equal(decryptedBytes, messageBytes) { + t.Fatalf("Expected the decrypted data to be %s got %s", string(decryptedBytes), string(messageBytes)) + } + decryptedMeta := decryptedReader.GetMetadata() + if !reflect.DeepEqual(testMeta, decryptedMeta) { + t.Fatalf("Expected the decrypted metadata to be %v got %v", testMeta, decryptedMeta) + } +} + +func TestSessionKey_EncryptStreamCompatible(t *testing.T) { + messageBytes := []byte("Hello World!") + messageReader := bytes.NewReader(messageBytes) + var dataPacketBuf bytes.Buffer + messageWriter, err := testSessionKey.EncryptStream( + &dataPacketBuf, + testMeta, + keyRingTestPrivate, + ) + if err != nil { + t.Fatal("Expected no error while encrypting stream with session key, got:", err) + } + bufferSize := 2 + buffer := make([]byte, bufferSize) + reachedEnd := false + for !reachedEnd { + n, err := messageReader.Read(buffer) + if err != nil { + if errors.Is(err, io.EOF) { + reachedEnd = true + } else { + t.Fatal("Expected no error while reading data, got:", err) + } + } + writtenTotal := 0 + for writtenTotal < n { + written, err := messageWriter.Write(buffer[writtenTotal:n]) + if err != nil { + t.Fatal("Expected no error while writing data, got:", err) + } + writtenTotal += written + } + } + err = messageWriter.Close() + if err != nil { + t.Fatal("Expected no error while closing plaintext writer, got:", err) + } + dataPacket := dataPacketBuf.Bytes() + decryptedMsg, err := testSessionKey.DecryptAndVerify( + dataPacket, + keyRingTestPublic, + GetUnixTime(), + ) + if err != nil { + t.Fatal("Expected no error while calling DecryptAndVerify, got:", err) + } + decryptedBytes := decryptedMsg.Data + if err != nil { + t.Fatal("Expected no error while reading the decrypted data, got:", err) + } + if !bytes.Equal(decryptedBytes, messageBytes) { + t.Fatalf("Expected the decrypted data to be %s got %s", string(decryptedBytes), string(messageBytes)) + } + if testMeta.IsBinary != decryptedMsg.IsBinary() { + t.Fatalf("Expected isBinary to be %t got %t", testMeta.IsBinary, decryptedMsg.IsBinary()) + } + if testMeta.Filename != decryptedMsg.GetFilename() { + t.Fatalf("Expected filename to be %s got %s", testMeta.Filename, decryptedMsg.GetFilename()) + } + if testMeta.ModTime != int64(decryptedMsg.GetTime()) { + t.Fatalf("Expected modification time to be %d got %d", testMeta.ModTime, int64(decryptedMsg.GetTime())) + } +} + +func TestSessionKey_DecryptStreamCompatible(t *testing.T) { + messageBytes := []byte("Hello World!") + dataPacket, err := testSessionKey.EncryptAndSign( + &PlainMessage{ + Data: messageBytes, + TextType: !testMeta.IsBinary, + Time: uint32(testMeta.ModTime), + Filename: testMeta.Filename, + }, + keyRingTestPrivate, + ) + if err != nil { + t.Fatal("Expected no error while encrypting plaintext, got:", err) + } + decryptedReader, err := testSessionKey.DecryptStream( + bytes.NewReader(dataPacket), + keyRingTestPublic, + GetUnixTime(), + ) + if err != nil { + t.Fatal("Expected no error while calling DecryptStream, got:", err) + } + decryptedBytes, err := io.ReadAll(decryptedReader) + if err != nil { + t.Fatal("Expected no error while reading the decrypted data, got:", err) + } + err = decryptedReader.VerifySignature() + if err != nil { + t.Fatal("Expected no error while verifying the signature, got:", err) + } + if !bytes.Equal(decryptedBytes, messageBytes) { + t.Fatalf("Expected the decrypted data to be %s got %s", string(decryptedBytes), string(messageBytes)) + } + decryptedMeta := decryptedReader.GetMetadata() + if !reflect.DeepEqual(testMeta, decryptedMeta) { + t.Fatalf("Expected the decrypted metadata to be %v got %v", testMeta, decryptedMeta) + } +} diff --git a/go.mod b/go.mod index 26b75d7..553be87 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/ProtonMail/gopenpgp/v2 go 1.15 require ( - github.com/ProtonMail/go-crypto v0.0.0-20210428141323-04723f9f07d7 + github.com/ProtonMail/go-crypto v0.0.0-20210512092938-c05353c2d58c github.com/ProtonMail/go-mime v0.0.0-20190923161245-9b5a4261663a github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.4.0 diff --git a/go.sum b/go.sum index 972817f..d724f5f 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,10 @@ github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802 h1:1BDTz0u9nC3//pOC github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/ProtonMail/go-crypto v0.0.0-20210428141323-04723f9f07d7 h1:YoJbenK9C67SkzkDfmQuVln04ygHj3vjZfd9FL+GmQQ= github.com/ProtonMail/go-crypto v0.0.0-20210428141323-04723f9f07d7/go.mod h1:z4/9nQmJSSwwds7ejkxaJwO37dru3geImFUdJlaLzQo= +github.com/ProtonMail/go-crypto v0.0.0-20210503074116-6a33223a51e6 h1:GAvhO5jJ2YkkyKuCWa2VKuHHpyZZjKgkcPvj913Mt40= +github.com/ProtonMail/go-crypto v0.0.0-20210503074116-6a33223a51e6/go.mod h1:z4/9nQmJSSwwds7ejkxaJwO37dru3geImFUdJlaLzQo= +github.com/ProtonMail/go-crypto v0.0.0-20210512092938-c05353c2d58c h1:bNpaLLv2Y4kslsdkdCwAYu8Bak1aGVtxwi8Z/wy4Yuo= +github.com/ProtonMail/go-crypto v0.0.0-20210512092938-c05353c2d58c/go.mod h1:z4/9nQmJSSwwds7ejkxaJwO37dru3geImFUdJlaLzQo= github.com/ProtonMail/go-mime v0.0.0-20190923161245-9b5a4261663a h1:W6RrgN/sTxg1msqzFFb+G80MFmpjMw61IU+slm+wln4= github.com/ProtonMail/go-mime v0.0.0-20190923161245-9b5a4261663a/go.mod h1:NYt+V3/4rEeDuaev/zw1zCq8uqVEuPHzDPo3OZrlGJ4= github.com/ProtonMail/go-mobile v0.0.0-20210326110230-f181c70e4e2b h1:XVeh08xp93T+xK6rzpCSQTZ+LwEo+ASHvOifrQ5ZgEE= diff --git a/helper/mobile_stream.go b/helper/mobile_stream.go new file mode 100644 index 0000000..299078c --- /dev/null +++ b/helper/mobile_stream.go @@ -0,0 +1,182 @@ +package helper + +import ( + "crypto/sha256" + "hash" + "io" + + "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/pkg/errors" +) + +// Mobile2GoWriter is used to wrap a writer in the mobile app runtime, +// to be usable in the golang runtime (via gomobile). +type Mobile2GoWriter struct { + writer crypto.Writer +} + +// NewMobile2GoWriter wraps a writer to be usable in the golang runtime (via gomobile). +func NewMobile2GoWriter(writer crypto.Writer) *Mobile2GoWriter { + return &Mobile2GoWriter{writer} +} + +// Write writes the data in the provided buffer in the wrapped writer. +// It clones the provided data to prevent errors with garbage collectors. +func (w *Mobile2GoWriter) Write(b []byte) (n int, err error) { + bufferCopy := clone(b) + return w.writer.Write(bufferCopy) +} + +// Mobile2GoWriterWithSHA256 is used to wrap a writer in the mobile app runtime, +// to be usable in the golang runtime (via gomobile). +// It also computes the SHA256 hash of the data being written on the fly. +type Mobile2GoWriterWithSHA256 struct { + writer crypto.Writer + sha256 hash.Hash +} + +// NewMobile2GoWriterWithSHA256 wraps a writer to be usable in the golang runtime (via gomobile). +// The wrapper also computes the SHA256 hash of the data being written on the fly. +func NewMobile2GoWriterWithSHA256(writer crypto.Writer) *Mobile2GoWriterWithSHA256 { + return &Mobile2GoWriterWithSHA256{writer, sha256.New()} +} + +// Write writes the data in the provided buffer in the wrapped writer. +// It clones the provided data to prevent errors with garbage collectors. +// It also computes the SHA256 hash of the data being written on the fly. +func (w *Mobile2GoWriterWithSHA256) Write(b []byte) (n int, err error) { + bufferCopy := clone(b) + n, err = w.writer.Write(bufferCopy) + if err == nil { + hashedTotal := 0 + for hashedTotal < n { + hashed, err := w.sha256.Write(bufferCopy[hashedTotal:n]) + if err != nil { + return 0, errors.Wrap(err, "gopenpgp: couldn't hash encrypted data") + } + hashedTotal += hashed + } + } + return n, err +} + +// GetSHA256 returns the SHA256 hash of the data that's been written so far. +func (w *Mobile2GoWriterWithSHA256) GetSHA256() []byte { + return w.sha256.Sum(nil) +} + +// MobileReader is the interface that readers in the mobile runtime must use and implement. +// This is a workaround to some of the gomobile limitations. +type MobileReader interface { + Read(max int) (result *MobileReadResult, err error) +} + +// MobileReadResult is what needs to be returned by MobileReader.Read. +// The read data is passed as a return value rather than passed as an argument to the reader. +// This avoids problems introduced by gomobile that prevent the use of native golang readers. +type MobileReadResult struct { + N int // N, The number of bytes read + IsEOF bool // IsEOF, If true, then the reader has reached the end of the data to read. + Data []byte // Data, the data that has been read +} + +// NewMobileReadResult initialize a MobileReadResult with the correct values. +// It clones the data to avoid the garbage collector freeing the data too early. +func NewMobileReadResult(n int, eof bool, data []byte) *MobileReadResult { + return &MobileReadResult{N: n, IsEOF: eof, Data: clone(data)} +} + +func clone(src []byte) (dst []byte) { + dst = make([]byte, len(src)) + copy(dst, src) + return +} + +// Mobile2GoReader is used to wrap a MobileReader in the mobile app runtime, +// to be usable in the golang runtime (via gomobile) as a native Reader. +type Mobile2GoReader struct { + reader MobileReader +} + +// NewMobile2GoReader wraps a MobileReader to be usable in the golang runtime (via gomobile). +func NewMobile2GoReader(reader MobileReader) *Mobile2GoReader { + return &Mobile2GoReader{reader} +} + +// Read reads data from the wrapped MobileReader and copies the read data in the provided buffer. +// It also handles the conversion of EOF to an error. +func (r *Mobile2GoReader) Read(b []byte) (n int, err error) { + result, err := r.reader.Read(len(b)) + if err != nil { + return 0, errors.Wrap(err, "gopenpgp: couldn't read from mobile reader") + } + n = result.N + if n > 0 { + copy(b, result.Data[:n]) + } + if result.IsEOF { + err = io.EOF + } + return n, err +} + +// Go2AndroidReader is used to wrap a native golang Reader in the golang runtime, +// to be usable in the android app runtime (via gomobile). +type Go2AndroidReader struct { + isEOF bool + reader crypto.Reader +} + +// NewGo2AndroidReader wraps a native golang Reader to be usable in the mobile app runtime (via gomobile). +// It doesn't follow the standard golang Reader behavior, and returns n = -1 on EOF. +func NewGo2AndroidReader(reader crypto.Reader) *Go2AndroidReader { + return &Go2AndroidReader{isEOF: false, reader: reader} +} + +// Read reads bytes into the provided buffer and returns the number of bytes read +// It doesn't follow the standard golang Reader behavior, and returns n = -1 on EOF. +func (r *Go2AndroidReader) Read(b []byte) (n int, err error) { + if r.isEOF { + return -1, nil + } + n, err = r.reader.Read(b) + if errors.Is(err, io.EOF) { + if n == 0 { + return -1, nil + } else { + r.isEOF = true + return n, nil + } + } + return +} + +// Go2IOSReader is used to wrap a native golang Reader in the golang runtime, +// to be usable in the iOS app runtime (via gomobile) as a MobileReader. +type Go2IOSReader struct { + reader crypto.Reader +} + +// NewGo2IOSReader wraps a native golang Reader to be usable in the ios app runtime (via gomobile). +func NewGo2IOSReader(reader crypto.Reader) *Go2IOSReader { + return &Go2IOSReader{reader} +} + +// Read reads at most bytes from the wrapped Reader and returns the read data as a MobileReadResult. +func (r *Go2IOSReader) Read(max int) (result *MobileReadResult, err error) { + b := make([]byte, max) + n, err := r.reader.Read(b) + result = &MobileReadResult{} + if err != nil { + if errors.Is(err, io.EOF) { + result.IsEOF = true + } else { + return nil, err + } + } + result.N = n + if n > 0 { + result.Data = b[:n] + } + return result, nil +} diff --git a/helper/mobile_stream_test.go b/helper/mobile_stream_test.go new file mode 100644 index 0000000..3aac607 --- /dev/null +++ b/helper/mobile_stream_test.go @@ -0,0 +1,182 @@ +package helper + +import ( + "bytes" + "crypto/sha256" + "errors" + "io" + "testing" +) + +func cloneTestData() (a, b []byte) { + a = []byte("Hello World!") + b = clone(a) + return a, b +} +func Test_clone(t *testing.T) { + if a, b := cloneTestData(); !bytes.Equal(a, b) { + t.Fatalf("expected %x, got %x", a, b) + } +} + +func TestMobile2GoWriter(t *testing.T) { + testData := []byte("Hello World!") + outBuf := &bytes.Buffer{} + reader := bytes.NewReader(testData) + writer := NewMobile2GoWriter(outBuf) + bufSize := 2 + writeBuf := make([]byte, bufSize) + reachedEnd := false + for !reachedEnd { + n, err := reader.Read(writeBuf) + if err != nil { + if errors.Is(err, io.EOF) { + reachedEnd = true + } else { + t.Fatal("Expected no error while reading, got:", err) + } + } + writtenTotal := 0 + for writtenTotal < n { + written, err := writer.Write(writeBuf[writtenTotal:n]) + if err != nil { + t.Fatal("Expected no error while writing, got:", err) + } + writtenTotal += written + } + } + if writtenData := outBuf.Bytes(); !bytes.Equal(testData, writtenData) { + t.Fatalf("expected %x, got %x", testData, writtenData) + } +} + +func TestMobile2GoWriterWithSHA256(t *testing.T) { + testData := []byte("Hello World!") + testHash := sha256.Sum256(testData) + outBuf := &bytes.Buffer{} + reader := bytes.NewReader(testData) + writer := NewMobile2GoWriterWithSHA256(outBuf) + bufSize := 2 + writeBuf := make([]byte, bufSize) + reachedEnd := false + for !reachedEnd { + n, err := reader.Read(writeBuf) + if err != nil { + if errors.Is(err, io.EOF) { + reachedEnd = true + } else { + t.Fatal("Expected no error while reading, got:", err) + } + } + writtenTotal := 0 + for writtenTotal < n { + written, err := writer.Write(writeBuf[writtenTotal:n]) + if err != nil { + t.Fatal("Expected no error while writing, got:", err) + } + writtenTotal += written + } + } + if writtenData := outBuf.Bytes(); !bytes.Equal(testData, writtenData) { + t.Fatalf("expected data to be %x, got %x", testData, writtenData) + } + + if writtenHash := writer.GetSHA256(); !bytes.Equal(testHash[:], writtenHash) { + t.Fatalf("expected has to be %x, got %x", testHash, writtenHash) + } +} + +func TestGo2AndroidReader(t *testing.T) { + testData := []byte("Hello World!") + reader := NewGo2AndroidReader(bytes.NewReader(testData)) + var readData []byte + bufSize := 2 + buffer := make([]byte, bufSize) + reachedEnd := false + for !reachedEnd { + n, err := reader.Read(buffer) + if err != nil { + t.Fatal("Expected no error while reading, got:", err) + } + reachedEnd = n < 0 + if n > 0 { + readData = append(readData, buffer[:n]...) + } + } + if !bytes.Equal(testData, readData) { + t.Fatalf("expected data to be %x, got %x", testData, readData) + } +} + +func TestGo2IOSReader(t *testing.T) { + testData := []byte("Hello World!") + reader := NewGo2IOSReader(bytes.NewReader(testData)) + var readData []byte + bufSize := 2 + reachedEnd := false + for !reachedEnd { + res, err := reader.Read(bufSize) + if err != nil { + t.Fatal("Expected no error while reading, got:", err) + } + n := res.N + reachedEnd = res.IsEOF + if n > 0 { + readData = append(readData, res.Data[:n]...) + } + } + if !bytes.Equal(testData, readData) { + t.Fatalf("expected data to be %x, got %x", testData, readData) + } +} + +type testMobileReader struct { + reader io.Reader + returnError bool +} + +func (r *testMobileReader) Read(max int) (*MobileReadResult, error) { + if r.returnError { + return nil, errors.New("gopenpgp: test - forced error while reading") + } + buf := make([]byte, max) + n, err := r.reader.Read(buf) + eof := false + if err != nil { + if errors.Is(err, io.EOF) { + eof = true + } else { + return nil, errors.New("gopenpgp: test - error while reading") + } + } + return NewMobileReadResult(n, eof, buf[:n]), nil +} + +func TestMobile2GoReader(t *testing.T) { + testData := []byte("Hello World!") + reader := NewMobile2GoReader(&testMobileReader{bytes.NewReader(testData), false}) + var readData []byte + bufSize := 2 + readBuf := make([]byte, bufSize) + reachedEnd := false + for !reachedEnd { + n, err := reader.Read(readBuf) + if err != nil { + if errors.Is(err, io.EOF) { + reachedEnd = true + } else { + t.Fatal("Expected no error while reading, got:", err) + } + } + if n > 0 { + readData = append(readData, readBuf[:n]...) + } + } + if !bytes.Equal(testData, readData) { + t.Fatalf("expected data to be %x, got %x", testData, readData) + } + readerErr := NewMobile2GoReader(&testMobileReader{bytes.NewReader(testData), true}) + if _, err := readerErr.Read(readBuf); err == nil { + t.Fatal("expected an error while reading, got nil") + } +}