diff --git a/Sources/SwiftMemcache/Extensions/ByteBuffer+SwiftMemcache.swift b/Sources/SwiftMemcache/Extensions/ByteBuffer+SwiftMemcache.swift index aed4895..dc4044f 100644 --- a/Sources/SwiftMemcache/Extensions/ByteBuffer+SwiftMemcache.swift +++ b/Sources/SwiftMemcache/Extensions/ByteBuffer+SwiftMemcache.swift @@ -53,6 +53,9 @@ extension ByteBuffer { /// - parameters: /// - flags: An instance of MemcachedFlags that holds the flags intended to be serialized and written to the ByteBuffer. mutating func writeMemcachedFlags(flags: MemcachedFlags) { + // Ensure that both storageMode and arithmeticMode aren't set at the same time. + precondition(!(flags.storageMode != nil && flags.arithmeticMode != nil), "Cannot specify both a storage and arithmetic mode.") + if let shouldReturnValue = flags.shouldReturnValue, shouldReturnValue { self.writeInteger(UInt8.whitespace) self.writeInteger(UInt8.v) @@ -101,6 +104,23 @@ extension ByteBuffer { self.writeInteger(UInt8.R) } } + + if let arithmeticMode = flags.arithmeticMode { + self.writeInteger(UInt8.whitespace) + self.writeInteger(UInt8.M) + switch arithmeticMode { + case .decrement(let delta): + self.writeInteger(UInt8.decrement) + self.writeInteger(UInt8.whitespace) + self.writeInteger(UInt8.D) + self.writeIntegerAsASCII(delta) + case .increment(let delta): + self.writeInteger(UInt8.increment) + self.writeInteger(UInt8.whitespace) + self.writeInteger(UInt8.D) + self.writeIntegerAsASCII(delta) + } + } } } diff --git a/Sources/SwiftMemcache/Extensions/UInt8+Characters.swift b/Sources/SwiftMemcache/Extensions/UInt8+Characters.swift index 2fb0c2e..06c671d 100644 --- a/Sources/SwiftMemcache/Extensions/UInt8+Characters.swift +++ b/Sources/SwiftMemcache/Extensions/UInt8+Characters.swift @@ -20,6 +20,7 @@ extension UInt8 { static var s: UInt8 = .init(ascii: "s") static var g: UInt8 = .init(ascii: "g") static var d: UInt8 = .init(ascii: "d") + static var a: UInt8 = .init(ascii: "a") static var v: UInt8 = .init(ascii: "v") static var T: UInt8 = .init(ascii: "T") static var M: UInt8 = .init(ascii: "M") @@ -27,6 +28,9 @@ extension UInt8 { static var A: UInt8 = .init(ascii: "A") static var E: UInt8 = .init(ascii: "E") static var R: UInt8 = .init(ascii: "R") + static var D: UInt8 = .init(ascii: "D") static var zero: UInt8 = .init(ascii: "0") static var nine: UInt8 = .init(ascii: "9") + static var increment: UInt8 = .init(ascii: "+") + static var decrement: UInt8 = .init(ascii: "-") } diff --git a/Sources/SwiftMemcache/MemcachedConnection.swift b/Sources/SwiftMemcache/MemcachedConnection.swift index 4a380e6..3af1e4d 100644 --- a/Sources/SwiftMemcache/MemcachedConnection.swift +++ b/Sources/SwiftMemcache/MemcachedConnection.swift @@ -424,4 +424,62 @@ public actor MemcachedConnection { throw MemcachedConnectionError.connectionShutdown } } + + // MARK: - Increment a Value + + /// Increment the value for an existing key in the Memcache server by a specified amount. + /// + /// - Parameters: + /// - key: The key for the value to increment. + /// - amount: The `Int` amount to increment the value by. Must be larger than 0. + /// - Throws: A `MemcachedConnectionError` if the connection to the Memcached server is shut down. + public func increment(_ key: String, amount: Int) async throws { + // Ensure the amount is greater than 0 + precondition(amount > 0, "Amount to increment should be larger than 0") + + switch self.state { + case .initial(_, _, _, _), + .running: + + var flags = MemcachedFlags() + flags.arithmeticMode = .increment(amount) + + let command = MemcachedRequest.ArithmeticCommand(key: key, flags: flags) + let request = MemcachedRequest.arithmetic(command) + + _ = try await self.sendRequest(request) + + case .finished: + throw MemcachedConnectionError.connectionShutdown + } + } + + // MARK: - Decrement a Value + + /// Decrement the value for an existing key in the Memcache server by a specified amount. + /// + /// - Parameters: + /// - key: The key for the value to decrement. + /// - amount: The `Int` amount to decrement the value by. Must be larger than 0. + /// - Throws: A `MemcachedConnectionError` if the connection to the Memcached server is shut down. + public func decrement(_ key: String, amount: Int) async throws { + // Ensure the amount is greater than 0 + precondition(amount > 0, "Amount to decrement should be larger than 0") + + switch self.state { + case .initial(_, _, _, _), + .running: + + var flags = MemcachedFlags() + flags.arithmeticMode = .decrement(amount) + + let command = MemcachedRequest.ArithmeticCommand(key: key, flags: flags) + let request = MemcachedRequest.arithmetic(command) + + _ = try await self.sendRequest(request) + + case .finished: + throw MemcachedConnectionError.connectionShutdown + } + } } diff --git a/Sources/SwiftMemcache/MemcachedFlags.swift b/Sources/SwiftMemcache/MemcachedFlags.swift index 69c464d..e098b6f 100644 --- a/Sources/SwiftMemcache/MemcachedFlags.swift +++ b/Sources/SwiftMemcache/MemcachedFlags.swift @@ -37,6 +37,11 @@ struct MemcachedFlags { /// The default mode is 'set'. var storageMode: StorageMode? + /// Flag 'M' for the 'ma' (meta arithmetic) command. + /// + /// Represents the mode of the 'ma' command, which determines the behavior of the arithmetic operation. + var arithmeticMode: ArithmeticMode? + init() {} } @@ -60,4 +65,12 @@ enum StorageMode: Equatable, Hashable { case replace } +/// Enum representing the mode for the 'ma' (meta arithmetic) command in Memcached (corresponding to the 'M' flag). +enum ArithmeticMode: Equatable, Hashable { + /// 'increment' command. If applied, it increases the numerical value of the item. + case increment(Int) + /// 'decrement' command. If applied, it decreases the numerical value of the item. + case decrement(Int) +} + extension MemcachedFlags: Hashable {} diff --git a/Sources/SwiftMemcache/MemcachedRequest.swift b/Sources/SwiftMemcache/MemcachedRequest.swift index 232d0b7..5b9b03e 100644 --- a/Sources/SwiftMemcache/MemcachedRequest.swift +++ b/Sources/SwiftMemcache/MemcachedRequest.swift @@ -29,7 +29,13 @@ enum MemcachedRequest { let key: String } + struct ArithmeticCommand { + let key: String + var flags: MemcachedFlags + } + case set(SetCommand) case get(GetCommand) case delete(DeleteCommand) + case arithmetic(ArithmeticCommand) } diff --git a/Sources/SwiftMemcache/MemcachedRequestEncoder.swift b/Sources/SwiftMemcache/MemcachedRequestEncoder.swift index 1389bf0..274cdb3 100644 --- a/Sources/SwiftMemcache/MemcachedRequestEncoder.swift +++ b/Sources/SwiftMemcache/MemcachedRequestEncoder.swift @@ -73,6 +73,22 @@ struct MemcachedRequestEncoder: MessageToByteEncoder { out.writeInteger(UInt8.whitespace) out.writeBytes(command.key.utf8) + // write separator + out.writeInteger(UInt8.carriageReturn) + out.writeInteger(UInt8.newline) + + case .arithmetic(let command): + precondition(!command.key.isEmpty, "Key must not be empty") + + // write command and key + out.writeInteger(UInt8.m) + out.writeInteger(UInt8.a) + out.writeInteger(UInt8.whitespace) + out.writeBytes(command.key.utf8) + + // write flags if there are any + out.writeMemcachedFlags(flags: command.flags) + // write separator out.writeInteger(UInt8.carriageReturn) out.writeInteger(UInt8.newline) diff --git a/Sources/SwiftMemcache/MemcachedValue.swift b/Sources/SwiftMemcache/MemcachedValue.swift index f68b365..abd8417 100644 --- a/Sources/SwiftMemcache/MemcachedValue.swift +++ b/Sources/SwiftMemcache/MemcachedValue.swift @@ -33,14 +33,14 @@ extension MemcachedValue where Self: FixedWidthInteger { /// /// - Parameter buffer: The ByteBuffer to which the integer should be written. public func writeToBuffer(_ buffer: inout ByteBuffer) { - buffer.writeInteger(self) + buffer.writeIntegerAsASCII(self) } /// Reads a FixedWidthInteger from a ByteBuffer. /// /// - Parameter buffer: The ByteBuffer from which the value should be read. public static func readFromBuffer(_ buffer: inout ByteBuffer) -> Self? { - return buffer.readInteger() + return buffer.readIntegerFromASCII() } } diff --git a/Tests/SwiftMemcacheTests/IntegrationTest/MemcachedIntegrationTests.swift b/Tests/SwiftMemcacheTests/IntegrationTest/MemcachedIntegrationTests.swift index 628e28b..685aca6 100644 --- a/Tests/SwiftMemcacheTests/IntegrationTest/MemcachedIntegrationTests.swift +++ b/Tests/SwiftMemcacheTests/IntegrationTest/MemcachedIntegrationTests.swift @@ -434,6 +434,62 @@ final class MemcachedIntegrationTest: XCTestCase { } } + func testIncrementValue() async throws { + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + XCTAssertNoThrow(try! group.syncShutdownGracefully()) + } + let memcachedConnection = MemcachedConnection(host: "memcached", port: 11211, eventLoopGroup: group) + + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { try await memcachedConnection.run() } + + // Set key and initial value + let initialValue = 1 + try await memcachedConnection.set("increment", value: initialValue) + + // Increment value + let incrementAmount = 100 + try await memcachedConnection.increment("increment", amount: incrementAmount) + + // Get new value + let newValue: Int? = try await memcachedConnection.get("increment") + + // Check if new value is equal to initial value plus increment amount + XCTAssertEqual(newValue, initialValue + incrementAmount, "Incremented value is incorrect") + + group.cancelAll() + } + } + + func testDecrementValue() async throws { + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + XCTAssertNoThrow(try! group.syncShutdownGracefully()) + } + let memcachedConnection = MemcachedConnection(host: "memcached", port: 11211, eventLoopGroup: group) + + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { try await memcachedConnection.run() } + + // Set key and initial value + let initialValue = 100 + try await memcachedConnection.set("decrement", value: initialValue) + + // Increment value + let decrementAmount = 10 + try await memcachedConnection.decrement("decrement", amount: decrementAmount) + + // Get new value + let newValue: Int? = try await memcachedConnection.get("decrement") + + // Check if new value is equal to initial value plus increment amount + XCTAssertEqual(newValue, initialValue - decrementAmount, "Incremented value is incorrect") + + group.cancelAll() + } + } + func testMemcachedConnectionWithUInt() async throws { let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { diff --git a/Tests/SwiftMemcacheTests/UnitTest/MemcachedRequestEncoderTests.swift b/Tests/SwiftMemcacheTests/UnitTest/MemcachedRequestEncoderTests.swift index cdcf069..92de139 100644 --- a/Tests/SwiftMemcacheTests/UnitTest/MemcachedRequestEncoderTests.swift +++ b/Tests/SwiftMemcacheTests/UnitTest/MemcachedRequestEncoderTests.swift @@ -169,4 +169,32 @@ final class MemcachedRequestEncoderTests: XCTestCase { let expectedEncodedData = "md foo\r\n" XCTAssertEqual(outBuffer.getString(at: 0, length: outBuffer.readableBytes), expectedEncodedData) } + + func testEncodeIncrementRequest() { + // Prepare a MemcachedRequest + var flags = MemcachedFlags() + flags.arithmeticMode = .increment(100) + let command = MemcachedRequest.ArithmeticCommand(key: "foo", flags: flags) + let request = MemcachedRequest.arithmetic(command) + + // pass our request through the encoder + let outBuffer = self.encodeRequest(request) + + let expectedEncodedData = "ma foo M+ D100\r\n" + XCTAssertEqual(outBuffer.getString(at: 0, length: outBuffer.readableBytes), expectedEncodedData) + } + + func testEncodeDecrementRequest() { + // Prepare a MemcachedRequest + var flags = MemcachedFlags() + flags.arithmeticMode = .decrement(100) + let command = MemcachedRequest.ArithmeticCommand(key: "foo", flags: flags) + let request = MemcachedRequest.arithmetic(command) + + // pass our request through the encoder + let outBuffer = self.encodeRequest(request) + + let expectedEncodedData = "ma foo M- D100\r\n" + XCTAssertEqual(outBuffer.getString(at: 0, length: outBuffer.readableBytes), expectedEncodedData) + } }