N 以下の正整数のうち、各桁に特定の数字のみを含むものの数を数える

はじめに

この問題を解くことが幼少期*1からの夢でした。2021年度後期で『アルゴリズムとデータ構造入門』を履修し、少しアルゴリズムに興味が持てたため、プログラムを作成してこの問題を解いてみようと思います。使用言語は一応 C++ としています。

そもそもどういう問題?

タイトルの通りですが、例えば  N = 100 で、特定の (使ってよい) 数字が  5, 7 の場合、条件を満たす数は「 N 以下で  5, 7 しか桁に現れない数」ということで、  5, 7, 55, 57, 75, 77 6 種類となります。よって答は 6 です。そんな感じです。

愚直にやってみる (解法1)

まず初めに思いつくのは、

  1.  N 以下の正整数をfor文で全列挙する
  2. 全列挙した各整数の各桁を取り出し、使用可能な数字かどうか確かめる

という方法がありそうです。プログラムの一例としては、以下のようなものがあると思います。

#include <iostream>
#include <vector>

using namespace std;

long long digits_restricted_number(long long x, vector<int> &specified_digit) {
    // 返り値
    long long ret = 0; 

    // x 以下の正整数 i を全走査
    for (long long i = 1; i <= x; i++) {
        bool ok = true;
        long long j = i;
        while (j) {
            // i の各桁を順番に取り出していく
            long long n = j % 10;
            bool is_n_included = false;
            // 今見ている i の桁が specified_digit に含まれていればよい
            for (auto &e : specified_digit) {
                if (n == e) {
                    is_n_included = true;
                }
            }
            if (!is_n_included) {
                ok = false;
            }
            j /= 10;
        }
        // i の各桁がすべて specified_digit に含まれていれば、ret を 1 増やす
        if (ok) ret++;
    }

    return ret;
}

int main() {
    // 標準入力
    long long N; cin >> N;
    int n; cin >> n;

    vector<int> specified_digit(n);
    for (int i = 0; i < n; i++) {
        cin >> specified_digit[i];
    }

    cout << digits_restricted_number(N, specified_digit) << endl;

    return 0;
}

時間計算量ですが、 N 個の各整数の各桁について調べ、さらにその中で使用可能な数字を線形探索しているため、使用可能な数字の種類数 (正とします) を  K として、 O(K N \log N) くらいです*2
 N = 10^9 だとしても少なくとも  10^9 回の計算ステップがあり、現実的な計算時間とはならなそうで、悲しくなってしまいました。

もうちょっと考えてみる (解法2)

桁ごとに考えてみれば、もう少し良い解法が思いつくかもしれません。
動的計画法 の1つである Digit DP*3 というものを用いれば、この問題がもう少し高速に解けそうです。動的計画法では使用する配列の名前を慣習的に dp とするようなので、自分もそうしてみました。
0 を使用する数字に含めるか否かの場合分けが面倒でした。

#include <iostream>
#include <vector>

using namespace std;

long long digits_restricted_number(long long x, vector<int> &specified_digit) {
    // specified_digit に 0 が含まれているかチェック
    bool is_zero_permitted = false;
    for (auto &i : specified_digit) {
        if (i == 0) {
            is_zero_permitted = true;
        }
    }
    // どのみち 0 はゼロ埋めに使うので、specified_digit に入れておく
    if (!is_zero_permitted) specified_digit.push_back(0);

    string x_str = to_string(x);
    int digit = int(x_str.size());

    // dp[i][j][k] :
    // 上から i 桁目まで見た時点で、
    // x ギリギリである ⇒ j = 1, でない ⇒ j = 0
    // 今までの桁で 0 以外を使った ⇒ k = 1, 使っていない ⇒ k = 0
    vector<vector<vector<long long>>> dp(digit+1, vector<vector<long long>>(2, vector<long long>(2, 0)));

    dp[0][0][0] = 1;
    for (int dgt = 0; dgt < digit; dgt++) {
        int cur = x_str[dgt] - '0';
        for (int is_less = 0; is_less < 2; is_less++) {
            for (int is_nonzero_included = 0; is_nonzero_included < 2; is_nonzero_included++) {
                for (auto &nxt : specified_digit) {

                    int is_less_new = is_less;
                    int is_nonzero_included_new = is_nonzero_included;

                    if (!is_less and nxt < cur) {
                        is_less_new = 1;
                    }
                    if (nxt != 0) {
                        is_nonzero_included_new = 1;
                    }

                    if (!is_less and nxt > cur) {
                        continue;
                    }
                    if (!is_zero_permitted and is_nonzero_included and nxt == 0) {
                        continue;
                    }
                    dp[dgt+1][is_less_new][is_nonzero_included_new] += dp[dgt][is_less][is_nonzero_included];
                }
            }
        }
    }

    long long ret = 0;
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            ret += dp[digit][i][j];
        }
    }

    ret--;

    return ret;
}

int main() {
    // 標準入力
    long long N; cin >> N;
    int n; cin >> n;

    vector<int> specified_digit(n);
    for (int i = 0; i < n; i++) {
        cin >> specified_digit[i];
    }

    cout << digits_restricted_number(N, specified_digit) << endl;

    return 0;
}

時間計算量は、各桁に対して使用可能な数字を線形探索するため、  O(K \log N)*4 となり、かなり高速になりました*5

おわりに

意識の底にあった問題をついに解くことができ (たと思われ) 、少しうれしい気持ちになっています。作成したプログラムはかなり乱雑なものとなってしまいました。記事内に誤りがあれば、教えて頂けますと幸いです。
また、 n進法を用いた考え方などもできそうです。またゆっくり考えてみたいと思っています。

*1:2021年9月下旬

*2:あやしい。

*3:よくわかっていません。

*4:あやしい。

*5:おおざっぱに考えると、 N 倍高速になっているので、 N = 10^9 とすると 1兆倍くらい速くなりました (テキトーですみません)。