diff --git a/GCryptLib/include/GCrypt/Block.h b/GCryptLib/include/GCrypt/Block.h index 2184e75..3fefbd5 100644 --- a/GCryptLib/include/GCrypt/Block.h +++ b/GCryptLib/include/GCrypt/Block.h @@ -35,8 +35,8 @@ namespace Leonetienne::GCrypt { //! Since the matrices values are pretty much sudo-random, //! they will most likely integer-overflow. //! So see this as a one-way function. - Block MMul(const Block& other) const; - Block operator*(const Block& other) const; + [[nodiscard]] Block MMul(const Block& other) const; + [[nodiscard]] Block operator*(const Block& other) const; //! Will matrix-multiply two blocks together, //! and directly write into this same block. @@ -47,67 +47,98 @@ namespace Leonetienne::GCrypt { Block& operator*=(const Block& other); //! Will xor two blocks together - Block Xor(const Block& other) const; + [[nodiscard]] Block Xor(const Block& other) const; //! Will xor two blocks together - Block operator^(const Block& other) const; + [[nodiscard]] Block operator^(const Block& other) const; //! Will xor two blocks together, inplace void XorInplace(const Block& other); //! Will xor two blocks together, inplace Block& operator^=(const Block& other); - // # TO BE IMPLEMENTED - //! Will shift rows upwards by n - void ShiftRowsUp(const std::size_t n); + //! Will add all the integer making up this block, one by one + [[nodiscard]] Block Add(const Block& other) const; + //! Will add all the integer making up this block, one by one + [[nodiscard]] Block operator+(const Block& other) const; - // # TO BE IMPLEMENTED - //! Will shift matrix rows downwards by n - void ShiftRowsDown(const std::size_t n); + //! Will add all the integer making up this block, one by one, inplace + void AddInplace(const Block& other); + //! Will add all the integer making up this block, one by one, inplace + Block& operator+=(const Block& other); - // # TO BE IMPLEMENTED - //! Will shift matrix columns to the left by n - void ShiftColumnsLeft(const std::size_t n); + //! Will subtract all the integer making up this block, one by one + [[nodiscard]] Block Sub(const Block& other) const; + //! Will subtract all the integer making up this block, one by one + [[nodiscard]] Block operator-(const Block& other) const; - // # TO BE IMPLEMENTED - //! Will shift matrix columns to the right by n - void ShiftColumnsRight(const std::size_t n); + //! Will subtract all the integer making up this block, one by one, inplace + void SubInplace(const Block& other); + //! Will subtract all the integer making up this block, one by one, inplace + Block& operator-=(const Block& other); - // # TO BE IMPLEMENTED - //! Will shift array cells to the left by n - void ShiftCellsLeft(const std::size_t n); + //! Will shift rows upwards by 1 + [[nodiscard]] Block ShiftRowsUp() const; - // # TO BE IMPLEMENTED - //! Will shift array cells to the right by n - void ShiftCellsRight(const std::size_t n); + //! Will shift rows upwards by 1 + void ShiftRowsUpInplace(); + + //! Will shift matrix rows downwards by 1 + [[nodiscard]] Block ShiftRowsDown() const; + + //! Will shift matrix rows downwards by 1 + void ShiftRowsDownInplace(); + + //! Will shift matrix columns to the left by 1 + [[nodiscard]] Block ShiftColumnsLeft() const; + + //! Will shift matrix columns to the left by 1 + void ShiftColumnsLeftInplace(); + + //! Will shift matrix columns to the right by 1 + [[nodiscard]] Block ShiftColumnsRight() const; + + //! Will shift matrix columns to the right by 1 + void ShiftColumnsRightInplace(); + + //! Will shift array cells to the left by 1 + [[nodiscard]] Block ShiftCellsLeft() const; + + //! Will shift array cells to the left by 1 + void ShiftCellsLeftInplace(); + + //! Will shift array cells to the right by 1 + [[nodiscard]] Block ShiftCellsRight() const; + + //! Will shift array cells to the right by 1 + void ShiftCellsRightInplace(); //! Will copy a block Block& operator=(const Block& other); //! Will compare whether or not two blocks are equal - bool operator==(const Block& other) const; + [[nodiscard]] bool operator==(const Block& other) const; //! Will compare whether or not two blocks are unequal - bool operator!=(const Block& other) const; + [[nodiscard]] bool operator!=(const Block& other) const; //! Will zero all data void Reset(); - //! Returns 32-bit chunks of data, indexed by matrix coordinates (0-3) - std::uint32_t& Get(const std::uint8_t row, const std::uint8_t column); + [[nodiscard]] std::uint32_t& Get(const std::uint8_t row, const std::uint8_t column); //! Returns 32-bit chunks of data, indexed by matrix coordinates (0-3) - const std::uint32_t& Get(const std::uint8_t row, const std::uint8_t column) const; + [[nodiscard]] const std::uint32_t& Get(const std::uint8_t row, const std::uint8_t column) const; //! Returns 32-bit chunks of data, indexed by a 1d-index (0-16) - std::uint32_t& Get(const std::uint8_t index); + [[nodiscard]] std::uint32_t& Get(const std::uint8_t index); //! Returns 32-bit chunks of data, indexed by a 1d-index (0-16) - const std::uint32_t& Get(const std::uint8_t index) const; + [[nodiscard]] const std::uint32_t& Get(const std::uint8_t index) const; //! Returns 32-bit chunks of data, indexed by a 1d-index (0-16) - std::uint32_t& operator[](const std::uint8_t index); + [[nodiscard]] std::uint32_t& operator[](const std::uint8_t index); //! Returns 32-bit chunks of data, indexed by a 1d-index (0-16) - const std::uint32_t& operator[](const std::uint8_t index) const; + [[nodiscard]] const std::uint32_t& operator[](const std::uint8_t index) const; static constexpr std::size_t CHUNK_SIZE = sizeof(std::uint32_t); static constexpr std::size_t CHUNK_SIZE_BITS = CHUNK_SIZE * 8; diff --git a/GCryptLib/src/Block.cpp b/GCryptLib/src/Block.cpp index 0152541..bd16dd1 100644 --- a/GCryptLib/src/Block.cpp +++ b/GCryptLib/src/Block.cpp @@ -5,6 +5,10 @@ #include #include +// Just to be sure, the compiler will optimize this +// little formula out, let's do it in the preprocessor +#define MAT_INDEX(row, column) (column*4 + row) + namespace Leonetienne::GCrypt { Block::Block() { @@ -132,28 +136,310 @@ namespace Leonetienne::GCrypt { return *this; } - void ShiftRowsUp(const std::size_t n) { - // TO BE IMPLEMENTED + Block Block::Add(const Block& other) const { + + Block m; + for (std::size_t i = 0; i < data.size(); i++) { + m.Get(i) = this->Get(i) + other.Get(i); + } + return m; } - void ShiftRowsDown(const std::size_t n) { - // TO BE IMPLEMENTED + Block Block::operator+(const Block& other) const { + return Add(other); } - void ShiftColumnsLeft(const std::size_t n) { - // TO BE IMPLEMENTED + void Block::AddInplace(const Block& other) { + for (std::size_t i = 0; i < data.size(); i++) { + this->Get(i) += other.Get(i); + } + return; } - void ShiftColumnsRight(const std::size_t n) { - // TO BE IMPLEMENTED + Block& Block::operator+=(const Block& other) { + AddInplace(other); + return *this; } - void ShiftCellsLeft(const std::size_t n) { - // TO BE IMPLEMENTED + Block Block::Sub(const Block& other) const { + + Block m; + for (std::size_t i = 0; i < data.size(); i++) { + m.Get(i) = this->Get(i) - other.Get(i); + } + return m; } - void ShiftCellsRight(const std::size_t n) { - // TO BE IMPLEMENTED + Block Block::operator-(const Block& other) const { + return Sub(other); + } + + void Block::SubInplace(const Block& other) { + for (std::size_t i = 0; i < data.size(); i++) { + this->Get(i) -= other.Get(i); + } + return; + } + + Block& Block::operator-=(const Block& other) { + SubInplace(other); + return *this; + } + + void Block::ShiftRowsUpInplace() { + Block tmp = *this; + + Get(MAT_INDEX(0, 0)) = tmp.Get(MAT_INDEX(1, 0)); + Get(MAT_INDEX(0, 1)) = tmp.Get(MAT_INDEX(1, 1)); + Get(MAT_INDEX(0, 2)) = tmp.Get(MAT_INDEX(1, 2)); + Get(MAT_INDEX(0, 3)) = tmp.Get(MAT_INDEX(1, 3)); + + Get(MAT_INDEX(1, 0)) = tmp.Get(MAT_INDEX(2, 0)); + Get(MAT_INDEX(1, 1)) = tmp.Get(MAT_INDEX(2, 1)); + Get(MAT_INDEX(1, 2)) = tmp.Get(MAT_INDEX(2, 2)); + Get(MAT_INDEX(1, 3)) = tmp.Get(MAT_INDEX(2, 3)); + + Get(MAT_INDEX(2, 0)) = tmp.Get(MAT_INDEX(3, 0)); + Get(MAT_INDEX(2, 1)) = tmp.Get(MAT_INDEX(3, 1)); + Get(MAT_INDEX(2, 2)) = tmp.Get(MAT_INDEX(3, 2)); + Get(MAT_INDEX(2, 3)) = tmp.Get(MAT_INDEX(3, 3)); + + Get(MAT_INDEX(3, 0)) = tmp.Get(MAT_INDEX(0, 0)); + Get(MAT_INDEX(3, 1)) = tmp.Get(MAT_INDEX(0, 1)); + Get(MAT_INDEX(3, 2)) = tmp.Get(MAT_INDEX(0, 2)); + Get(MAT_INDEX(3, 3)) = tmp.Get(MAT_INDEX(0, 3)); + + return; + } + + Block Block::ShiftRowsUp() const { + Block b; + + b.Get(MAT_INDEX(0, 0)) = Get(MAT_INDEX(1, 0)); + b.Get(MAT_INDEX(0, 1)) = Get(MAT_INDEX(1, 1)); + b.Get(MAT_INDEX(0, 2)) = Get(MAT_INDEX(1, 2)); + b.Get(MAT_INDEX(0, 3)) = Get(MAT_INDEX(1, 3)); + + b.Get(MAT_INDEX(1, 0)) = Get(MAT_INDEX(2, 0)); + b.Get(MAT_INDEX(1, 1)) = Get(MAT_INDEX(2, 1)); + b.Get(MAT_INDEX(1, 2)) = Get(MAT_INDEX(2, 2)); + b.Get(MAT_INDEX(1, 3)) = Get(MAT_INDEX(2, 3)); + + b.Get(MAT_INDEX(2, 0)) = Get(MAT_INDEX(3, 0)); + b.Get(MAT_INDEX(2, 1)) = Get(MAT_INDEX(3, 1)); + b.Get(MAT_INDEX(2, 2)) = Get(MAT_INDEX(3, 2)); + b.Get(MAT_INDEX(2, 3)) = Get(MAT_INDEX(3, 3)); + + b.Get(MAT_INDEX(3, 0)) = Get(MAT_INDEX(0, 0)); + b.Get(MAT_INDEX(3, 1)) = Get(MAT_INDEX(0, 1)); + b.Get(MAT_INDEX(3, 2)) = Get(MAT_INDEX(0, 2)); + b.Get(MAT_INDEX(3, 3)) = Get(MAT_INDEX(0, 3)); + + return b; + } + + void Block::ShiftRowsDownInplace() { + Block tmp = *this; + + Get(MAT_INDEX(0, 0)) = tmp.Get(MAT_INDEX(3, 0)); + Get(MAT_INDEX(0, 1)) = tmp.Get(MAT_INDEX(3, 1)); + Get(MAT_INDEX(0, 2)) = tmp.Get(MAT_INDEX(3, 2)); + Get(MAT_INDEX(0, 3)) = tmp.Get(MAT_INDEX(3, 3)); + + Get(MAT_INDEX(1, 0)) = tmp.Get(MAT_INDEX(0, 0)); + Get(MAT_INDEX(1, 1)) = tmp.Get(MAT_INDEX(0, 1)); + Get(MAT_INDEX(1, 2)) = tmp.Get(MAT_INDEX(0, 2)); + Get(MAT_INDEX(1, 3)) = tmp.Get(MAT_INDEX(0, 3)); + + Get(MAT_INDEX(2, 0)) = tmp.Get(MAT_INDEX(1, 0)); + Get(MAT_INDEX(2, 1)) = tmp.Get(MAT_INDEX(1, 1)); + Get(MAT_INDEX(2, 2)) = tmp.Get(MAT_INDEX(1, 2)); + Get(MAT_INDEX(2, 3)) = tmp.Get(MAT_INDEX(1, 3)); + + Get(MAT_INDEX(3, 0)) = tmp.Get(MAT_INDEX(2, 0)); + Get(MAT_INDEX(3, 1)) = tmp.Get(MAT_INDEX(2, 1)); + Get(MAT_INDEX(3, 2)) = tmp.Get(MAT_INDEX(2, 2)); + Get(MAT_INDEX(3, 3)) = tmp.Get(MAT_INDEX(2, 3)); + + return; + } + + Block Block::ShiftRowsDown() const { + Block b; + + b.Get(MAT_INDEX(0, 0)) = Get(MAT_INDEX(3, 0)); + b.Get(MAT_INDEX(0, 1)) = Get(MAT_INDEX(3, 1)); + b.Get(MAT_INDEX(0, 2)) = Get(MAT_INDEX(3, 2)); + b.Get(MAT_INDEX(0, 3)) = Get(MAT_INDEX(3, 3)); + + b.Get(MAT_INDEX(1, 0)) = Get(MAT_INDEX(0, 0)); + b.Get(MAT_INDEX(1, 1)) = Get(MAT_INDEX(0, 1)); + b.Get(MAT_INDEX(1, 2)) = Get(MAT_INDEX(0, 2)); + b.Get(MAT_INDEX(1, 3)) = Get(MAT_INDEX(0, 3)); + + b.Get(MAT_INDEX(2, 0)) = Get(MAT_INDEX(1, 0)); + b.Get(MAT_INDEX(2, 1)) = Get(MAT_INDEX(1, 1)); + b.Get(MAT_INDEX(2, 2)) = Get(MAT_INDEX(1, 2)); + b.Get(MAT_INDEX(2, 3)) = Get(MAT_INDEX(1, 3)); + + b.Get(MAT_INDEX(3, 0)) = Get(MAT_INDEX(2, 0)); + b.Get(MAT_INDEX(3, 1)) = Get(MAT_INDEX(2, 1)); + b.Get(MAT_INDEX(3, 2)) = Get(MAT_INDEX(2, 2)); + b.Get(MAT_INDEX(3, 3)) = Get(MAT_INDEX(2, 3)); + + return b; + } + + void Block::ShiftColumnsLeftInplace() { + Block tmp = *this; + + Get(MAT_INDEX(0, 0)) = tmp.Get(MAT_INDEX(0, 1)); + Get(MAT_INDEX(1, 0)) = tmp.Get(MAT_INDEX(1, 1)); + Get(MAT_INDEX(2, 0)) = tmp.Get(MAT_INDEX(2, 1)); + Get(MAT_INDEX(3, 0)) = tmp.Get(MAT_INDEX(3, 1)); + + Get(MAT_INDEX(0, 1)) = tmp.Get(MAT_INDEX(0, 2)); + Get(MAT_INDEX(1, 1)) = tmp.Get(MAT_INDEX(1, 2)); + Get(MAT_INDEX(2, 1)) = tmp.Get(MAT_INDEX(2, 2)); + Get(MAT_INDEX(3, 1)) = tmp.Get(MAT_INDEX(3, 2)); + + Get(MAT_INDEX(0, 2)) = tmp.Get(MAT_INDEX(0, 3)); + Get(MAT_INDEX(1, 2)) = tmp.Get(MAT_INDEX(1, 3)); + Get(MAT_INDEX(2, 2)) = tmp.Get(MAT_INDEX(2, 3)); + Get(MAT_INDEX(3, 2)) = tmp.Get(MAT_INDEX(3, 3)); + + Get(MAT_INDEX(0, 3)) = tmp.Get(MAT_INDEX(0, 0)); + Get(MAT_INDEX(1, 3)) = tmp.Get(MAT_INDEX(1, 0)); + Get(MAT_INDEX(2, 3)) = tmp.Get(MAT_INDEX(2, 0)); + Get(MAT_INDEX(3, 3)) = tmp.Get(MAT_INDEX(3, 0)); + + return; + } + + Block Block::ShiftColumnsLeft() const { + Block b; + + b.Get(MAT_INDEX(0, 0)) = Get(MAT_INDEX(0, 1)); + b.Get(MAT_INDEX(1, 0)) = Get(MAT_INDEX(1, 1)); + b.Get(MAT_INDEX(2, 0)) = Get(MAT_INDEX(2, 1)); + b.Get(MAT_INDEX(3, 0)) = Get(MAT_INDEX(3, 1)); + + b.Get(MAT_INDEX(0, 1)) = Get(MAT_INDEX(0, 2)); + b.Get(MAT_INDEX(1, 1)) = Get(MAT_INDEX(1, 2)); + b.Get(MAT_INDEX(2, 1)) = Get(MAT_INDEX(2, 2)); + b.Get(MAT_INDEX(3, 1)) = Get(MAT_INDEX(3, 2)); + + b.Get(MAT_INDEX(0, 2)) = Get(MAT_INDEX(0, 3)); + b.Get(MAT_INDEX(1, 2)) = Get(MAT_INDEX(1, 3)); + b.Get(MAT_INDEX(2, 2)) = Get(MAT_INDEX(2, 3)); + b.Get(MAT_INDEX(3, 2)) = Get(MAT_INDEX(3, 3)); + + b.Get(MAT_INDEX(0, 3)) = Get(MAT_INDEX(0, 0)); + b.Get(MAT_INDEX(1, 3)) = Get(MAT_INDEX(1, 0)); + b.Get(MAT_INDEX(2, 3)) = Get(MAT_INDEX(2, 0)); + b.Get(MAT_INDEX(3, 3)) = Get(MAT_INDEX(3, 0)); + + return b; + } + + void Block::ShiftColumnsRightInplace() { + Block tmp = *this; + + Get(MAT_INDEX(0, 1)) = tmp.Get(MAT_INDEX(0, 0)); + Get(MAT_INDEX(1, 1)) = tmp.Get(MAT_INDEX(1, 0)); + Get(MAT_INDEX(2, 1)) = tmp.Get(MAT_INDEX(2, 0)); + Get(MAT_INDEX(3, 1)) = tmp.Get(MAT_INDEX(3, 0)); + + Get(MAT_INDEX(0, 2)) = tmp.Get(MAT_INDEX(0, 1)); + Get(MAT_INDEX(1, 2)) = tmp.Get(MAT_INDEX(1, 1)); + Get(MAT_INDEX(2, 2)) = tmp.Get(MAT_INDEX(2, 1)); + Get(MAT_INDEX(3, 2)) = tmp.Get(MAT_INDEX(3, 1)); + + Get(MAT_INDEX(0, 3)) = tmp.Get(MAT_INDEX(0, 2)); + Get(MAT_INDEX(1, 3)) = tmp.Get(MAT_INDEX(1, 2)); + Get(MAT_INDEX(2, 3)) = tmp.Get(MAT_INDEX(2, 2)); + Get(MAT_INDEX(3, 3)) = tmp.Get(MAT_INDEX(3, 2)); + + Get(MAT_INDEX(0, 0)) = tmp.Get(MAT_INDEX(0, 3)); + Get(MAT_INDEX(1, 0)) = tmp.Get(MAT_INDEX(1, 3)); + Get(MAT_INDEX(2, 0)) = tmp.Get(MAT_INDEX(2, 3)); + Get(MAT_INDEX(3, 0)) = tmp.Get(MAT_INDEX(3, 3)); + + return; + } + + Block Block::ShiftColumnsRight() const { + Block b; + + b.Get(MAT_INDEX(0, 1)) = Get(MAT_INDEX(0, 0)); + b.Get(MAT_INDEX(1, 1)) = Get(MAT_INDEX(1, 0)); + b.Get(MAT_INDEX(2, 1)) = Get(MAT_INDEX(2, 0)); + b.Get(MAT_INDEX(3, 1)) = Get(MAT_INDEX(3, 0)); + + b.Get(MAT_INDEX(0, 2)) = Get(MAT_INDEX(0, 1)); + b.Get(MAT_INDEX(1, 2)) = Get(MAT_INDEX(1, 1)); + b.Get(MAT_INDEX(2, 2)) = Get(MAT_INDEX(2, 1)); + b.Get(MAT_INDEX(3, 2)) = Get(MAT_INDEX(3, 1)); + + b.Get(MAT_INDEX(0, 3)) = Get(MAT_INDEX(0, 2)); + b.Get(MAT_INDEX(1, 3)) = Get(MAT_INDEX(1, 2)); + b.Get(MAT_INDEX(2, 3)) = Get(MAT_INDEX(2, 2)); + b.Get(MAT_INDEX(3, 3)) = Get(MAT_INDEX(3, 2)); + + b.Get(MAT_INDEX(0, 0)) = Get(MAT_INDEX(0, 3)); + b.Get(MAT_INDEX(1, 0)) = Get(MAT_INDEX(1, 3)); + b.Get(MAT_INDEX(2, 0)) = Get(MAT_INDEX(2, 3)); + b.Get(MAT_INDEX(3, 0)) = Get(MAT_INDEX(3, 3)); + + return b; + } + + void Block::ShiftCellsLeftInplace() { + Block tmp = *this; + + Get(15) = tmp.Get(0); + + for (std::size_t i = 0; i < 15; i++) { + Get(i) = tmp.Get(i+1); + } + + return; + } + + Block Block::ShiftCellsLeft() const { + Block b; + + b.Get(15) = Get(0); + + for (std::size_t i = 0; i < 15; i++) { + b.Get(i) = Get(i+1); + } + + return b; + } + + void Block::ShiftCellsRightInplace() { + Block tmp = *this; + + Get(0) = tmp.Get(15); + + for (std::size_t i = 1; i < 16; i++) { + Get(i) = tmp.Get(i-1); + } + + return; + } + + Block Block::ShiftCellsRight() const { + Block b; + + b.Get(0) = Get(15); + + for (std::size_t i = 1; i < 16; i++) { + b.Get(i) = Get(i-1); + } + + return b; } Block& Block::operator=(const Block& other) { @@ -162,11 +448,11 @@ namespace Leonetienne::GCrypt { } std::uint32_t& Block::Get(const std::uint8_t row, const std::uint8_t column){ - return data[column*4 + row]; + return data[MAT_INDEX(row, column)]; } const std::uint32_t& Block::Get(const std::uint8_t row, const std::uint8_t column) const { - return data[column*4 + row]; + return data[MAT_INDEX(row, column)]; } std::uint32_t& Block::Get(const std::uint8_t index) { @@ -222,3 +508,5 @@ namespace Leonetienne::GCrypt { } +#undef MAT_INDEX + diff --git a/GCryptLib/src/Feistel.cpp b/GCryptLib/src/Feistel.cpp index 565d54a..9f2d20d 100644 --- a/GCryptLib/src/Feistel.cpp +++ b/GCryptLib/src/Feistel.cpp @@ -1,7 +1,6 @@ #include #include "GCrypt/Feistel.h" #include "GCrypt/Util.h" -#include "GCrypt/BlockMatrix.h" #include "GCrypt/Config.h" namespace Leonetienne::GCrypt { diff --git a/GCryptLib/test/Block.cpp b/GCryptLib/test/Block.cpp index faa715e..09730f7 100644 --- a/GCryptLib/test/Block.cpp +++ b/GCryptLib/test/Block.cpp @@ -1,4 +1,5 @@ #include +#include #include "Catch2.h" #include #include @@ -148,6 +149,116 @@ TEST_CASE(__FILE__"/operator^&=", "[Block]") { REQUIRE(block1 == block3); } +// Tests that operator+ (add) works +TEST_CASE(__FILE__"/add", "[Block]") { + + // Setup + Block block; + for (std::size_t i = 0; i < 16; i++) { + block.Get(i) = i * 1024; + } + + Block addRH; + for (std::size_t i = 0; i < 16; i++) { + addRH.Get(i) = i * 5099; + } + + // Exercise + Block result = block + addRH; + + Block manualResult; + for (std::size_t i = 0; i < 16; i++) { + manualResult.Get(i) = (i * 1024) + (i * 5099); + } + + // Verify + REQUIRE(result == manualResult); +} + +// Tests that operator+ is the same as += +TEST_CASE(__FILE__"/operator+&=", "[Block]") { + + // Setup + Block block1; + for (std::size_t i = 0; i < 16; i++) { + block1.Get(i) = i * 1024; + } + + Block block2; + for (std::size_t i = 0; i < 16; i++) { + block2.Get(i) = i * 5099 * 2; + } + + // Exercise + Block block3 = block1 + block2; + block1 += block2; + + // Verify + REQUIRE(block1 == block3); +} + +// Tests that operator- (subtract) works +TEST_CASE(__FILE__"/subtract", "[Block]") { + + // Setup + Block block; + for (std::size_t i = 0; i < 16; i++) { + block.Get(i) = i * 1024; + } + + Block subRH; + for (std::size_t i = 0; i < 16; i++) { + subRH.Get(i) = i * 5099; + } + + // Exercise + Block result = block - subRH; + + Block manualResult; + for (std::size_t i = 0; i < 16; i++) { + manualResult.Get(i) = (i * 1024) - (i * 5099); + } + + // Verify + REQUIRE(result == manualResult); +} + +// Tests that operator- is the same as -= +TEST_CASE(__FILE__"/operator-&=", "[Block]") { + + // Setup + Block block1; + for (std::size_t i = 0; i < 16; i++) { + block1.Get(i) = i * 1024; + } + + Block block2; + for (std::size_t i = 0; i < 16; i++) { + block2.Get(i) = i * 5099 * 2; + } + + // Exercise + Block block3 = block1 - block2; + block1 -= block2; + + // Verify + REQUIRE(block1 == block3); +} + +// Tests that subtraction undoes addition, and vica versa +TEST_CASE(__FILE__"/subtraction-undoes-addition", "[Block]") { + // Setup + const Block a = Key::FromPassword("Halleluja"); + const Block b = Key::FromPassword("Ananas"); + + // Exercise + const Block a_plus_b = a + b; + const Block a_plus_b_minus_b = a_plus_b - b; + + // Verify + REQUIRE(a == a_plus_b_minus_b); +} + // Tests that operator== works correctly TEST_CASE(__FILE__"/operator==", "[Block]") { @@ -249,3 +360,276 @@ TEST_CASE(__FILE__"/reset", "[Block]") { REQUIRE(block[i] == 0); } } + +// Tests that shift rows up works +TEST_CASE(__FILE__"/shift-rows-up", "[Block]") { + + // Setup + Block a; + a.Get(0,0) = 10; a.Get(0,1) = 11; a.Get(0,2) = 12; a.Get(0,3) = 13; + a.Get(1,0) = 20; a.Get(1,1) = 21; a.Get(1,2) = 22; a.Get(1,3) = 23; + a.Get(2,0) = 30; a.Get(2,1) = 31; a.Get(2,2) = 32; a.Get(2,3) = 33; + a.Get(3,0) = 40; a.Get(3,1) = 41; a.Get(3,2) = 42; a.Get(3,3) = 43; + + Block e; /* expected */ + e.Get(0,0) = 20; e.Get(0,1) = 21; e.Get(0,2) = 22; e.Get(0,3) = 23; + e.Get(1,0) = 30; e.Get(1,1) = 31; e.Get(1,2) = 32; e.Get(1,3) = 33; + e.Get(2,0) = 40; e.Get(2,1) = 41; e.Get(2,2) = 42; e.Get(2,3) = 43; + e.Get(3,0) = 10; e.Get(3,1) = 11; e.Get(3,2) = 12; e.Get(3,3) = 13; + + // Exercise + a.ShiftRowsUpInplace(); + + // Verify + REQUIRE(a == e); +} + +// Tests that ShiftRowsUpInplace() does the exact same thing as ShiftRowsUp() +TEST_CASE(__FILE__"/shift-rows-up-same-as-inplace", "[Block]") { + // Setup + Block a = Key::FromPassword("Halleluja"); + const Block initial_a = a; + + // Exercise and verify + a.ShiftRowsUpInplace(); + REQUIRE(a == initial_a.ShiftRowsUp()); +} + +// Tests that shift rows down works +TEST_CASE(__FILE__"/shift-rows-down", "[Block]") { + + // Setup + Block a; + a.Get(0,0) = 10; a.Get(0,1) = 11; a.Get(0,2) = 12; a.Get(0,3) = 13; + a.Get(1,0) = 20; a.Get(1,1) = 21; a.Get(1,2) = 22; a.Get(1,3) = 23; + a.Get(2,0) = 30; a.Get(2,1) = 31; a.Get(2,2) = 32; a.Get(2,3) = 33; + a.Get(3,0) = 40; a.Get(3,1) = 41; a.Get(3,2) = 42; a.Get(3,3) = 43; + + Block e; /* expected */ + e.Get(0,0) = 40; e.Get(0,1) = 41; e.Get(0,2) = 42; e.Get(0,3) = 43; + e.Get(1,0) = 10; e.Get(1,1) = 11; e.Get(1,2) = 12; e.Get(1,3) = 13; + e.Get(2,0) = 20; e.Get(2,1) = 21; e.Get(2,2) = 22; e.Get(2,3) = 23; + e.Get(3,0) = 30; e.Get(3,1) = 31; e.Get(3,2) = 32; e.Get(3,3) = 33; + + // Exercise + a.ShiftRowsDownInplace(); + + // Verify + REQUIRE(a == e); +} + +// Tests that ShiftRowsDownInplace() does the exact same thing as ShiftRowsDown() +TEST_CASE(__FILE__"/shift-rows-down-same-as-inplace", "[Block]") { + // Setup + Block a = Key::FromPassword("Halleluja"); + const Block initial_a = a; + + // Exercise and verify + a.ShiftRowsDownInplace(); + REQUIRE(a == initial_a.ShiftRowsDown()); +} + +// Tests that shift columns left works +TEST_CASE(__FILE__"/shift-columns-left", "[Block]") { + + // Setup + Block a; + a.Get(0,0) = 10; a.Get(0,1) = 11; a.Get(0,2) = 12; a.Get(0,3) = 13; + a.Get(1,0) = 20; a.Get(1,1) = 21; a.Get(1,2) = 22; a.Get(1,3) = 23; + a.Get(2,0) = 30; a.Get(2,1) = 31; a.Get(2,2) = 32; a.Get(2,3) = 33; + a.Get(3,0) = 40; a.Get(3,1) = 41; a.Get(3,2) = 42; a.Get(3,3) = 43; + + Block e; /* expected */ + e.Get(0,0) = 11; e.Get(0,1) = 12; e.Get(0,2) = 13; e.Get(0,3) = 10; + e.Get(1,0) = 21; e.Get(1,1) = 22; e.Get(1,2) = 23; e.Get(1,3) = 20; + e.Get(2,0) = 31; e.Get(2,1) = 32; e.Get(2,2) = 33; e.Get(2,3) = 30; + e.Get(3,0) = 41; e.Get(3,1) = 42; e.Get(3,2) = 43; e.Get(3,3) = 40; + + // Exercise + a.ShiftColumnsLeftInplace(); + + // Verify + REQUIRE(a == e); +} + +// Tests that ShiftColumnsLeftInplace()() does the exact same thing as ShiftColumnsLeft() +TEST_CASE(__FILE__"/shift-columns-left-same-as-inplace", "[Block]") { + // Setup + Block a = Key::FromPassword("Halleluja"); + const Block initial_a = a; + + // Exercise and verify + a.ShiftColumnsLeftInplace(); + REQUIRE(a == initial_a.ShiftColumnsLeft()); +} + +// Tests that shift columns right works +TEST_CASE(__FILE__"/shift-columns-right", "[Block]") { + + // Setup + Block a; + a.Get(0,0) = 10; a.Get(0,1) = 11; a.Get(0,2) = 12; a.Get(0,3) = 13; + a.Get(1,0) = 20; a.Get(1,1) = 21; a.Get(1,2) = 22; a.Get(1,3) = 23; + a.Get(2,0) = 30; a.Get(2,1) = 31; a.Get(2,2) = 32; a.Get(2,3) = 33; + a.Get(3,0) = 40; a.Get(3,1) = 41; a.Get(3,2) = 42; a.Get(3,3) = 43; + + Block e; /* expected */ + e.Get(0,0) = 13; e.Get(0,1) = 10; e.Get(0,2) = 11; e.Get(0,3) = 12; + e.Get(1,0) = 23; e.Get(1,1) = 20; e.Get(1,2) = 21; e.Get(1,3) = 22; + e.Get(2,0) = 33; e.Get(2,1) = 30; e.Get(2,2) = 31; e.Get(2,3) = 32; + e.Get(3,0) = 43; e.Get(3,1) = 40; e.Get(3,2) = 41; e.Get(3,3) = 42; + + // Exercise + a.ShiftColumnsRightInplace(); + + // Verify + REQUIRE(a == e); +} + +// Tests that ShiftColumnsRightInplace()() does the exact same thing as ShiftColumnsRight() +TEST_CASE(__FILE__"/shift-columns-right-same-as-inplace", "[Block]") { + // Setup + Block a = Key::FromPassword("Halleluja"); + const Block initial_a = a; + + // Exercise and verify + a.ShiftColumnsRightInplace(); + REQUIRE(a == initial_a.ShiftColumnsRight()); +} + +// Tests that shift cells left works +TEST_CASE(__FILE__"/shift-cells-left", "[Block]") { + + // Setup + Block a; + for (std::size_t i = 0; i < 16; i++) { + a.Get(i) = i; + } + + Block expected; + for (std::size_t i = 0; i < 15; i++) { + expected.Get(i) = i+1; + } + expected.Get(15) = 0; + + // Exercise + a.ShiftCellsLeftInplace(); + + // Verify + REQUIRE(a == expected); +} + +// Tests that ShiftCellsLeftInplace()() does the exact same thing as ShiftCellsLeft() +TEST_CASE(__FILE__"/shift-cells-left-same-as-inplace", "[Block]") { + // Setup + Block a = Key::FromPassword("Halleluja"); + const Block initial_a = a; + + // Exercise and verify + a.ShiftCellsLeftInplace(); + REQUIRE(a == initial_a.ShiftCellsLeft()); +} + +// Tests that shift cells right works +TEST_CASE(__FILE__"/shift-cells-right", "[Block]") { + + // Setup + Block a; + for (std::size_t i = 0; i < 16; i++) { + a.Get(i) = i; + } + + Block expected; + for (std::size_t i = 1; i < 16; i++) { + expected.Get(i) = i-1; + } + expected.Get(0) = 15; + + // Exercise + a.ShiftCellsRightInplace(); + + // Verify + REQUIRE(a == expected); +} + +// Tests that ShiftCellsRightInplace()() does the exact same thing as ShiftCellsRight() +TEST_CASE(__FILE__"/shift-cells-right-same-as-inplace", "[Block]") { + // Setup + Block a = Key::FromPassword("Halleluja"); + const Block initial_a = a; + + // Exercise and verify + a.ShiftCellsRightInplace(); + REQUIRE(a == initial_a.ShiftCellsRight()); +} + +// Tests that shifting down undoes shifting up, and vica versa +TEST_CASE(__FILE__"/shift-down-undoes-shift-up", "[Block]") { + // Setup + Block a = Key::FromPassword("Halleluja"); + const Block initial_a = a; + + // Exercise + a.ShiftRowsUpInplace(); + a.ShiftRowsDownInplace(); + + // Verify + REQUIRE(a == initial_a); +} + +// Tests that shifting left undoes shifting right, and vica versa +TEST_CASE(__FILE__"/shift-left-undoes-shift-right", "[Block]") { + // Setup + Block a = Key::FromPassword("Halleluja"); + const Block initial_a = a; + + // Exercise + a.ShiftColumnsRightInplace(); + a.ShiftColumnsLeftInplace(); + + // Verify + REQUIRE(a == initial_a); +} + +// Tests that shifting cells left undoes shifting cells right, and vica versa +TEST_CASE(__FILE__"/cellshift-left-undoes-cellshift-right", "[Block]") { + // Setup + Block a = Key::FromPassword("Halleluja"); + const Block initial_a = a; + + // Exercise + a.ShiftCellsRightInplace(); + a.ShiftCellsLeftInplace(); + + // Verify + REQUIRE(a == initial_a); +} + +// Tests that multiple, combined shifts and additions can be undone +TEST_CASE(__FILE__"/multiple-combined-shifts-and-additions-can-be-undone", "[Block]") { + // Setup + Block a = Key::FromPassword("Halleluja"); + Block key = Key::FromPassword("Papaya"); + + const Block initial_a = a; + + // Exercise (mix-up) + for (std::size_t i = 0; i < 64; i++) { + a.ShiftRowsUpInplace(); + a.ShiftColumnsLeftInplace(); + a += key; + a.ShiftCellsRightInplace(); + } + + // Exercise (un-mix) + for (std::size_t i = 0; i < 64; i++) { + a.ShiftCellsLeftInplace(); + a -= key; + a.ShiftColumnsRightInplace(); + a.ShiftRowsDownInplace(); + } + + // Verify + REQUIRE(a == initial_a); +} +