Skip to content

Commit 0fafa21

Browse files
authored
feat(server): Add support for command aliasing (#4932)
Add support for command aliasing using command_alias flag Signed-off-by: Abhijat Malviya <abhijat@dragonflydb.io>
1 parent 7ffe812 commit 0fafa21

10 files changed

+124
-84
lines changed

src/server/acl/acl_family.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ void AclFamily::DryRun(CmdArgList args, const CommandContext& cmd_cntx) {
591591

592592
string command = absl::AsciiStrToUpper(ArgS(args, 1));
593593
auto* cid = cmd_registry_->Find(command);
594-
if (!cid) {
594+
if (!cid || cid->IsAlias()) {
595595
auto error = absl::StrCat("Command '", command, "' not found");
596596
rb->SendError(error);
597597
return;
@@ -1062,7 +1062,7 @@ std::pair<AclFamily::OptCommand, bool> AclFamily::MaybeParseAclCommand(
10621062
std::string_view command) const {
10631063
if (absl::StartsWith(command, "+")) {
10641064
auto res = cmd_registry_->Find(command.substr(1));
1065-
if (!res) {
1065+
if (!res || res->IsAlias()) {
10661066
return {};
10671067
}
10681068
std::pair<size_t, uint64_t> cmd{res->GetFamily(), res->GetBitIndex()};
@@ -1071,7 +1071,7 @@ std::pair<AclFamily::OptCommand, bool> AclFamily::MaybeParseAclCommand(
10711071

10721072
if (absl::StartsWith(command, "-")) {
10731073
auto res = cmd_registry_->Find(command.substr(1));
1074-
if (!res) {
1074+
if (!res || res->IsAlias()) {
10751075
return {};
10761076
}
10771077
std::pair<size_t, uint64_t> cmd{res->GetFamily(), res->GetBitIndex()};

src/server/acl/acl_family_test.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
using namespace testing;
2020

2121
ABSL_DECLARE_FLAG(std::vector<std::string>, rename_command);
22+
ABSL_DECLARE_FLAG(std::vector<std::string>, command_alias);
2223

2324
namespace dfly {
2425

@@ -29,6 +30,7 @@ class AclFamilyTest : public BaseFamilyTest {
2930
class AclFamilyTestRename : public BaseFamilyTest {
3031
void SetUp() override {
3132
absl::SetFlag(&FLAGS_rename_command, {"ACL=ROCKS"});
33+
absl::SetFlag(&FLAGS_command_alias, {"___SET=SET"});
3234
ResetService();
3335
}
3436
};
@@ -538,4 +540,22 @@ TEST_F(AclFamilyTest, TestPubSub) {
538540
EXPECT_THAT(vec[9], "resetchannels &foo");
539541
}
540542

543+
TEST_F(AclFamilyTest, TestAlias) {
544+
auto resp = Run({"ACL", "SETUSER", "luke", "+___SET"});
545+
EXPECT_THAT(resp, ErrArg("ERR Unrecognized parameter +___SET"));
546+
547+
resp = Run({"ACL", "SETUSER", "leia", "-___SET"});
548+
EXPECT_THAT(resp, ErrArg("ERR Unrecognized parameter -___SET"));
549+
550+
resp = Run({"ACL", "SETUSER", "anakin", "+SET"});
551+
EXPECT_EQ(resp, "OK");
552+
553+
resp = Run({"ACL", "SETUSER", "jarjar", "allcommands"});
554+
EXPECT_EQ(resp, "OK");
555+
556+
resp = Run({"ACL", "DRYRUN", "jarjar", "___SET"});
557+
EXPECT_THAT(resp, ErrArg("ERR Command '___SET' not found"));
558+
EXPECT_EQ(Run({"ACL", "DRYRUN", "jarjar", "SET"}), "OK");
559+
}
560+
541561
} // namespace dfly

src/server/acl/validator.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ bool ValidateCommand(const std::vector<uint64_t>& acl_commands, const CommandId&
6666
return true;
6767
}
6868

69+
if (id.IsAlias()) {
70+
return false;
71+
}
72+
6973
std::pair<bool, AclLog::Reason> auth_res;
7074

7175
if (id.IsPubSub() || id.IsShardedPSub()) {

src/server/command_registry.cc

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,16 @@ using namespace std;
2222
ABSL_FLAG(vector<string>, rename_command, {},
2323
"Change the name of commands, format is: <cmd1_name>=<cmd1_new_name>, "
2424
"<cmd2_name>=<cmd2_new_name>");
25-
ABSL_FLAG(vector<string>, command_alias, {},
26-
"Add an alias for given commands, format is: <alias>=<original>, "
27-
"<alias>=<original>");
2825
ABSL_FLAG(vector<string>, restricted_commands, {},
2926
"Commands restricted to connections on the admin port");
3027

3128
ABSL_FLAG(vector<string>, oom_deny_commands, {},
3229
"Additinal commands that will be marked as denyoom");
30+
31+
ABSL_FLAG(vector<string>, command_alias, {},
32+
"Add an alias for given command(s), format is: <alias>=<original>, <alias>=<original>. "
33+
"Aliases must be set identically on replicas, if applicable");
34+
3335
namespace dfly {
3436

3537
using namespace facade;
@@ -75,16 +77,17 @@ uint32_t ImplicitAclCategories(uint32_t mask) {
7577
return out;
7678
}
7779

78-
absl::flat_hash_map<std::string, std::string> ParseCmdlineArgMap(
79-
const absl::Flag<std::vector<std::string>>& flag, const bool allow_duplicates = false) {
80+
using CmdLineMapping = absl::flat_hash_map<std::string, std::string>;
81+
82+
CmdLineMapping ParseCmdlineArgMap(const absl::Flag<std::vector<std::string>>& flag) {
8083
const auto& mappings = absl::GetFlag(flag);
81-
absl::flat_hash_map<std::string, std::string> parsed_mappings;
84+
CmdLineMapping parsed_mappings;
8285
parsed_mappings.reserve(mappings.size());
8386

8487
for (const std::string& mapping : mappings) {
85-
std::vector<std::string_view> kv = absl::StrSplit(mapping, '=');
88+
absl::InlinedVector<std::string_view, 2> kv = absl::StrSplit(mapping, '=');
8689
if (kv.size() != 2) {
87-
LOG(ERROR) << "Malformed command " << mapping << " for " << flag.Name()
90+
LOG(ERROR) << "Malformed command '" << mapping << "' for " << flag.Name()
8891
<< ", expected key=value";
8992
exit(1);
9093
}
@@ -97,15 +100,27 @@ absl::flat_hash_map<std::string, std::string> ParseCmdlineArgMap(
97100
exit(1);
98101
}
99102

100-
const bool inserted = parsed_mappings.emplace(std::move(key), std::move(value)).second;
101-
if (!allow_duplicates && !inserted) {
103+
if (!parsed_mappings.emplace(std::move(key), std::move(value)).second) {
102104
LOG(ERROR) << "Duplicate insert to " << flag.Name() << " not allowed";
103105
exit(1);
104106
}
105107
}
106108
return parsed_mappings;
107109
}
108110

111+
CmdLineMapping OriginalToAliasMap() {
112+
CmdLineMapping original_to_alias;
113+
CmdLineMapping alias_to_original = ParseCmdlineArgMap(FLAGS_command_alias);
114+
original_to_alias.reserve(alias_to_original.size());
115+
std::for_each(std::make_move_iterator(alias_to_original.begin()),
116+
std::make_move_iterator(alias_to_original.end()),
117+
[&original_to_alias](auto&& pair) {
118+
original_to_alias.emplace(std::move(pair.second), std::move(pair.first));
119+
});
120+
121+
return original_to_alias;
122+
}
123+
109124
} // namespace
110125

111126
CommandId::CommandId(const char* name, uint32_t mask, int8_t arity, int8_t first_key,
@@ -115,6 +130,17 @@ CommandId::CommandId(const char* name, uint32_t mask, int8_t arity, int8_t first
115130
implicit_acl_ = !acl_categories.has_value();
116131
}
117132

133+
CommandId CommandId::Clone(const std::string_view name) const {
134+
CommandId cloned =
135+
CommandId{name.data(), opt_mask_, arity_, first_key_, last_key_, acl_categories_};
136+
cloned.handler_ = handler_;
137+
cloned.opt_mask_ = opt_mask_ | CO::HIDDEN;
138+
cloned.acl_categories_ = acl_categories_;
139+
cloned.implicit_acl_ = implicit_acl_;
140+
cloned.is_alias_ = true;
141+
return cloned;
142+
}
143+
118144
bool CommandId::IsTransactional() const {
119145
if (first_key_ > 0 || (opt_mask_ & CO::GLOBAL_TRANS) || (opt_mask_ & CO::NO_KEY_TRANSACTIONAL))
120146
return true;
@@ -130,16 +156,15 @@ bool CommandId::IsMultiTransactional() const {
130156
return CO::IsTransKind(name()) || CO::IsEvalKind(name());
131157
}
132158

133-
uint64_t CommandId::Invoke(CmdArgList args, const CommandContext& cmd_cntx,
134-
std::string_view orig_cmd_name) const {
159+
uint64_t CommandId::Invoke(CmdArgList args, const CommandContext& cmd_cntx) const {
135160
int64_t before = absl::GetCurrentTimeNanos();
136161
handler_(args, cmd_cntx);
137162
int64_t after = absl::GetCurrentTimeNanos();
138163

139164
ServerState* ss = ServerState::tlocal(); // Might have migrated thread, read after invocation
140165
int64_t execution_time_usec = (after - before) / 1000;
141166

142-
auto& ent = command_stats_[ss->thread_index()][orig_cmd_name];
167+
auto& ent = command_stats_[ss->thread_index()];
143168

144169
++ent.first;
145170
ent.second += execution_time_usec;
@@ -169,7 +194,6 @@ optional<facade::ErrorReply> CommandId::Validate(CmdArgList tail_args) const {
169194

170195
CommandRegistry::CommandRegistry() {
171196
cmd_rename_map_ = ParseCmdlineArgMap(FLAGS_rename_command);
172-
cmd_aliases_ = ParseCmdlineArgMap(FLAGS_command_alias, true);
173197

174198
for (string name : GetFlag(FLAGS_restricted_commands)) {
175199
restricted_cmds_.emplace(AsciiStrToUpper(name));
@@ -181,9 +205,20 @@ CommandRegistry::CommandRegistry() {
181205
}
182206

183207
void CommandRegistry::Init(unsigned int thread_count) {
208+
const CmdLineMapping original_to_alias = OriginalToAliasMap();
209+
absl::flat_hash_map<std::string, CommandId> alias_to_command_id;
210+
alias_to_command_id.reserve(original_to_alias.size());
184211
for (auto& [_, cmd] : cmd_map_) {
185212
cmd.Init(thread_count);
213+
if (auto it = original_to_alias.find(cmd.name()); it != original_to_alias.end()) {
214+
auto alias_cmd = cmd.Clone(it->second);
215+
alias_cmd.Init(thread_count);
216+
alias_to_command_id.insert({it->second, std::move(alias_cmd)});
217+
}
186218
}
219+
std::copy(std::make_move_iterator(alias_to_command_id.begin()),
220+
std::make_move_iterator(alias_to_command_id.end()),
221+
std::inserter(cmd_map_, cmd_map_.end()));
187222
}
188223

189224
CommandRegistry& CommandRegistry::operator<<(CommandId cmd) {
@@ -212,7 +247,7 @@ CommandRegistry& CommandRegistry::operator<<(CommandId cmd) {
212247

213248
if (!is_sub_command || absl::StartsWith(cmd.name(), "ACL")) {
214249
cmd.SetBitIndex(1ULL << bit_index_);
215-
family_of_commands_.back().push_back(std::string(k));
250+
family_of_commands_.back().emplace_back(k);
216251
++bit_index_;
217252
} else {
218253
DCHECK(absl::StartsWith(k, family_of_commands_.back().back()));
@@ -266,10 +301,6 @@ std::pair<const CommandId*, ArgSlice> CommandRegistry::FindExtended(string_view
266301
return {res, tail_args};
267302
}
268303

269-
bool CommandRegistry::IsAlias(std::string_view cmd) const {
270-
return cmd_aliases_.contains(cmd);
271-
}
272-
273304
namespace CO {
274305

275306
const char* OptName(CO::CommandOpt fl) {

src/server/command_registry.h

Lines changed: 21 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,8 @@ static_assert(!IsEvalKind(""));
7171

7272
}; // namespace CO
7373

74-
// Per thread vector of command stats. Each entry is:
75-
// command invocation string -> {cmd_calls, cmd_latency_agg in usec}.
76-
using CmdCallStats = absl::flat_hash_map<std::string, std::pair<uint64_t, uint64_t>>;
74+
// Per thread vector of command stats. Each entry is {cmd_calls, cmd_latency_agg in usec}.
75+
using CmdCallStats = std::pair<uint64_t, uint64_t>;
7776

7877
struct CommandContext {
7978
CommandContext(Transaction* _tx, facade::SinkReplyBuilder* _rb, ConnectionContext* cntx)
@@ -94,6 +93,8 @@ class CommandId : public facade::CommandId {
9493

9594
CommandId(CommandId&&) = default;
9695

96+
[[nodiscard]] CommandId Clone(std::string_view name) const;
97+
9798
void Init(unsigned thread_count) {
9899
command_stats_ = std::make_unique<CmdCallStats[]>(thread_count);
99100
}
@@ -103,10 +104,8 @@ class CommandId : public facade::CommandId {
103104
using ArgValidator = fu2::function_base<true, true, fu2::capacity_default, false, false,
104105
std::optional<facade::ErrorReply>(CmdArgList) const>;
105106

106-
// Invokes the command handler. Returns the invoke time in usec. The invoked_by parameter is set
107-
// to the string passed in by user, if available. If not set, defaults to command name.
108-
uint64_t Invoke(CmdArgList args, const CommandContext& cmd_cntx,
109-
std::string_view orig_cmd_name) const;
107+
// Returns the invoke time in usec.
108+
uint64_t Invoke(CmdArgList args, const CommandContext& cmd_cntx) const;
110109

111110
// Returns error if validation failed, otherwise nullopt
112111
std::optional<facade::ErrorReply> Validate(CmdArgList tail_args) const;
@@ -144,7 +143,7 @@ class CommandId : public facade::CommandId {
144143
}
145144

146145
void ResetStats(unsigned thread_index) {
147-
command_stats_[thread_index].clear();
146+
command_stats_[thread_index] = {0, 0};
148147
}
149148

150149
CmdCallStats GetStats(unsigned thread_index) const {
@@ -156,11 +155,16 @@ class CommandId : public facade::CommandId {
156155
acl_categories_ |= mask;
157156
}
158157

158+
bool IsAlias() const {
159+
return is_alias_;
160+
}
161+
159162
private:
160163
bool implicit_acl_;
161164
std::unique_ptr<CmdCallStats[]> command_stats_;
162165
Handler3 handler_;
163166
ArgValidator validator_;
167+
bool is_alias_{false};
164168
};
165169

166170
class CommandRegistry {
@@ -172,16 +176,8 @@ class CommandRegistry {
172176
CommandRegistry& operator<<(CommandId cmd);
173177

174178
const CommandId* Find(std::string_view cmd) const {
175-
if (const auto it = cmd_map_.find(cmd); it != cmd_map_.end()) {
176-
return &it->second;
177-
}
178-
179-
if (const auto it = cmd_aliases_.find(cmd); it != cmd_aliases_.end()) {
180-
if (const auto alias_lookup = cmd_map_.find(it->second); alias_lookup != cmd_map_.end()) {
181-
return &alias_lookup->second;
182-
}
183-
}
184-
return nullptr;
179+
auto it = cmd_map_.find(cmd);
180+
return it == cmd_map_.end() ? nullptr : &it->second;
185181
}
186182

187183
CommandId* Find(std::string_view cmd) {
@@ -203,17 +199,13 @@ class CommandRegistry {
203199
}
204200
}
205201

206-
void MergeCallStats(
207-
unsigned thread_index,
208-
std::function<void(std::string_view, const CmdCallStats::mapped_type&)> cb) const {
209-
for (const auto& [_, cmd_id] : cmd_map_) {
210-
for (const auto& [cmd_name, call_stats] : cmd_id.GetStats(thread_index)) {
211-
if (call_stats.first == 0) {
212-
continue;
213-
}
214-
215-
cb(cmd_name, call_stats);
216-
}
202+
void MergeCallStats(unsigned thread_index,
203+
std::function<void(std::string_view, const CmdCallStats&)> cb) const {
204+
for (const auto& k_v : cmd_map_) {
205+
auto src = k_v.second.GetStats(thread_index);
206+
if (src.first == 0)
207+
continue;
208+
cb(k_v.second.name(), src);
217209
}
218210
}
219211

@@ -227,16 +219,9 @@ class CommandRegistry {
227219
std::pair<const CommandId*, facade::ArgSlice> FindExtended(std::string_view cmd,
228220
facade::ArgSlice tail_args) const;
229221

230-
bool IsAlias(std::string_view cmd) const;
231-
232222
private:
233223
absl::flat_hash_map<std::string, CommandId> cmd_map_;
234224
absl::flat_hash_map<std::string, std::string> cmd_rename_map_;
235-
// Stores a mapping from alias to original command. During the find operation, the first lookup is
236-
// done in the cmd_map_, then in the alias map. This results in two lookups but only for commands
237-
// which are not in original map, ie either typos or aliases. While it would be faster, we cannot
238-
// store iterators into cmd_map_ here as they may be invalidated on rehashing.
239-
absl::flat_hash_map<std::string, std::string> cmd_aliases_;
240225
absl::flat_hash_set<std::string> restricted_cmds_;
241226
absl::flat_hash_set<std::string> oomdeny_cmds_;
242227

0 commit comments

Comments
 (0)