1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
// Copyright 2020 The Tink-Rust Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

use std::{
    cell::{RefCell, RefMut},
    io,
    rc::Rc,
};

/// Progress of a [`DecryptReader`] in identifying a primitive that can decrypt its stream.
enum State {
    /// No primitive has been selected yet; holds the raw (still encrypted) ciphertext reader.
    Pending(Box<dyn io::Read>),
    /// A primitive that decrypts correctly has been selected; holds its decrypting reader.
    Found(Box<dyn io::Read>),
    /// No available primitive could decrypt the stream; further reads return an error.
    Failed,
}

/// Reader that decrypts a ciphertext stream with whichever key from the wrapped primitive set is
/// able to decrypt it; the key is chosen lazily, on the first read.
pub(crate) struct DecryptReader {
    /// Wrapped primitive set whose raw entries are tried as candidate decryption keys.
    wrapped: crate::WrappedStreamingAead,
    /// Associated data handed to every candidate decrypting reader.
    aad: Vec<u8>,
    /// Key-selection progress; starts out as [`State::Pending`].
    state: State,
}

impl DecryptReader {
    /// Creates a reader over the ciphertext produced by `reader`.
    ///
    /// `aad` is copied so the reader owns it. No key is chosen until the first call to `read`.
    pub fn new(
        wrapped: crate::WrappedStreamingAead,
        reader: Box<dyn io::Read>,
        aad: &[u8],
    ) -> Self {
        let state = State::Pending(reader);
        let owned_aad = aad.to_vec();
        Self {
            wrapped,
            aad: owned_aad,
            state,
        }
    }
}

impl io::Read for DecryptReader {
    /// Reads decrypted plaintext into `p`.
    ///
    /// On the first call, the raw entries of the wrapped primitive set are tried in turn. The
    /// first primitive whose decrypting reader both constructs and reads successfully is kept,
    /// and all later calls are forwarded to it.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` if no primitive could decrypt the stream. Once that has happened,
    /// every later call also fails with `InvalidInput`. Errors from the selected decrypting reader
    /// are passed through unchanged.
    fn read(&mut self, p: &mut [u8]) -> io::Result<usize> {
        match &mut self.state {
            State::Pending(_) => {}
            State::Found(matched) => return matched.read(p),
            State::Failed => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "read previously failed",
                ))
            }
        }

        // Take ownership of the raw ciphertext reader, leaving `Failed` behind; the state only
        // becomes `Found` once some primitive has read successfully.
        let raw_reader = match std::mem::replace(&mut self.state, State::Failed) {
            State::Pending(reader) => reader,
            State::Found(_) | State::Failed => unreachable!(), // excluded by the match above
        };
        // Every candidate shares this reader. It records what it hands out, so the same
        // ciphertext can be replayed to the next candidate after a failure.
        let mut shared = SharedCopyReader::new(raw_reader);

        for entry in self.wrapped.ps.raw_entries().into_iter().flatten() {
            let candidate = entry
                .primitive
                .new_decrypting_reader(Box::new(shared.clone()), &self.aad);
            let mut decrypter = match candidate {
                Ok(decrypter) => decrypter,
                Err(_) => {
                    // Construction may already have consumed ciphertext; replay it next time.
                    shared.rewind();
                    continue;
                }
            };
            match decrypter.read(p) {
                Ok(n) => {
                    // This key works: keep using it and stop recording ciphertext.
                    shared.stop_copying();
                    self.state = State::Found(decrypter);
                    return Ok(n);
                }
                Err(_) => shared.rewind(),
            }
        }

        Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "no matching key found for the ciphertext in the stream",
        ))
    }
}

/// Wrapper around an [`io::Read`] trait object that, while copying is enabled, keeps a copy of
/// all data read from the underlying object, so that the data can be replayed after
/// [`CopyReader::rewind`].
struct CopyReader {
    /// Underlying source of data.
    reader: Box<dyn io::Read>,
    /// Whether data pulled from `reader` is appended to `copied_data`.
    copying: bool,
    /// Offset into `copied_data` of the next byte to hand out.
    read_pos: usize,
    /// Data already pulled from `reader` that may need to be handed out (again).
    copied_data: Vec<u8>,
}

impl CopyReader {
    /// Creates a `CopyReader` over `reader`, with copying enabled and an empty buffer.
    fn new(reader: Box<dyn io::Read>) -> Self {
        Self {
            reader,
            copying: true,
            read_pos: 0,
            copied_data: vec![],
        }
    }

    /// Arranges for subsequent reads to replay the copied data from its beginning.
    fn rewind(&mut self) {
        self.read_pos = 0;
    }

    /// Stops copying data pulled from the underlying reader.
    ///
    /// Buffered data that has not yet been handed out since the last rewind is kept. Subsequent
    /// reads return it before reading more from the underlying reader. Only the
    /// already-consumed prefix is discarded.
    fn stop_copying(&mut self) {
        self.copying = false;
        // An earlier (failed) attempt may have pulled more data from the underlying reader than
        // the successful reader has consumed so far. That unread tail cannot be fetched from the
        // source again, so it must be kept rather than thrown away.
        self.copied_data.drain(..self.read_pos);
        self.read_pos = 0;
    }
}

impl io::Read for CopyReader {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if self.read_pos < self.copied_data.len() {
            // Serve from the buffered copy of the data first.
            let available_data = self.copied_data.len() - self.read_pos;
            let n = std::cmp::min(buf.len(), available_data);
            buf[..n].copy_from_slice(&self.copied_data[self.read_pos..self.read_pos + n]);
            self.read_pos += n;
            if !self.copying && self.read_pos == self.copied_data.len() {
                // The replay buffer is drained and will never be rewound into; release it.
                self.copied_data = vec![];
                self.read_pos = 0;
            }
            Ok(n)
        } else {
            // Buffer exhausted: read from the underlying object.
            let n = self.reader.read(buf)?;
            if self.copying {
                // Store a copy so this data can be replayed after a rewind.
                self.copied_data.extend_from_slice(&buf[..n]);
                self.read_pos += n;
            }
            Ok(n)
        }
    }
}

/// Cheaply clonable handle to a single [`CopyReader`].
///
/// All clones share one underlying reader and replay buffer through `Rc<RefCell<_>>`, so a
/// rewind or stop-copying request made through any handle affects every other handle.
#[derive(Clone)]
struct SharedCopyReader(Rc<RefCell<CopyReader>>);

impl SharedCopyReader {
    /// Wraps `reader` in a new, copying [`CopyReader`] behind a shared handle.
    fn new(reader: Box<dyn io::Read>) -> Self {
        let inner = CopyReader::new(reader);
        Self(Rc::new(RefCell::new(inner)))
    }

    /// Rewinds the shared [`CopyReader`] so its copied data is replayed.
    fn rewind(&mut self) {
        self.0.borrow_mut().rewind();
    }

    /// Tells the shared [`CopyReader`] to stop copying further data.
    fn stop_copying(&mut self) {
        self.0.borrow_mut().stop_copying();
    }
}

impl io::Read for SharedCopyReader {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let mut inner: RefMut<'_, CopyReader> = self.0.borrow_mut();
        io::Read::read(&mut *inner, buf)
    }
}