Single Round Match 731 Editorials
SRM 731, Division 2 Easy - RingLex - Editorial by square1001
This problem can be solved by brute-forcing the offset x and the prime number p, both of which are less than n. There are at most n^2 combinations of (x,p).
Once x and p are fixed, we can generate the corresponding string of length n. The lexicographically smallest generated string is the answer. The computation time is proportional to n^3. Since n<=50, this solution fits comfortably in the execution time limit of 2 seconds.
Since n<=50, we can test primality by checking whether the number is one of 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47. But there’s a more general way. A prime is an integer p >= 2 that is not divisible by any integer between 2 and p-1. If we check divisibility by every number between 2 and p-1, we can tell whether p is a prime.
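As a quick sanity check of the two approaches above, trial division can be used to enumerate the primes below a limit. This is only a minimal sketch; the names `is_prime_trial` and `primes_below` are hypothetical helpers and not part of the solution.

```cpp
#include <vector>

// Trial division: p is prime if p >= 2 and no integer in [2, p-1] divides it.
bool is_prime_trial(int p) {
    if (p < 2) return false;
    for (int d = 2; d < p; ++d) {
        if (p % d == 0) return false;
    }
    return true;
}

// Collects every prime strictly below `limit`, in increasing order.
std::vector<int> primes_below(int limit) {
    std::vector<int> result;
    for (int p = 2; p < limit; ++p) {
        if (is_prime_trial(p)) result.push_back(p);
    }
    return result;
}
```

For a limit of 50 this yields 15 primes, from 2 up to 47.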
#include <string>
#include <algorithm>
using namespace std;
class RingLex {
public:
bool isprime(int n) {
if (n <= 1) return false;
for (int i = 2; i < n; ++i) {
if (n % i == 0) return false;
}
return true;
}
string getmin(string s) {
int n = s.size();
string t;
for (int i = 0; i < n; ++i) t += s;
string ans = "~"; // '~' is ASCII code 126, which compares greater than any letter
for (int i = 0; i < n; ++i) {
for (int j = 1; j < n; ++j) {
if (isprime(j)) {
string gen;
for (int k = 0; k < n; ++k) {
gen += t[i + j * k];
}
ans = min(ans, gen);
}
}
}
return ans;
}
};
SRM 731, Division 2 Medium - DancingClass - Editorial by square1001
Some problems require integers wider than 64 bits. This problem is one example.
First, we explain how to calculate the probability that the number of boy-girl pairs is at least K when each of the N people in the dancing class is randomly a boy or a girl.
If there are i boys and N-i girls, there are i(N-i) boy-girl pairs. This happens with probability (1/2)^N * C(N,i), where C(N,i) is the number of ways to choose i items from N distinguishable items. This is called the “binomial coefficient”, and it can be calculated with Pascal’s Triangle. The probability that there are at least K boy-girl pairs is the sum of (1/2)^N * C(N,i) over all integers i with i*(N-i) >= K.
Second, we explain how to deal with the very big numbers. One possible way is to use double (in C++, C#, Java) or long double (in C++) variables, which store real numbers. A 64-bit double keeps the relative error to about 10^-15 or less per arithmetic operation. Since there are no cases in which the probability is 0.5-x or 0.5+x for a very small x like 10^-12, this is accurate enough. An example is eriksuenderhauf’s program written during the contest.
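The floating-point approach can be sketched as follows. This is not the contest program mentioned above, only an illustration under the assumption that a Pascal’s triangle scaled by 1/2 per row is accurate enough; the name `prob_at_least` is hypothetical.

```cpp
#include <vector>

// Probability that i*(N-i) >= K, where i is the number of boys among N people
// and each person is a boy with probability 1/2.
long double prob_at_least(int N, long long K) {
    // row[i] holds C(n, i) / 2^n for the current n (Pascal's triangle, halved each row).
    std::vector<long double> row(N + 1, 0.0L);
    row[0] = 1.0L;
    for (int n = 1; n <= N; ++n) {
        // Go from high i to low i so that row[i - 1] is still the previous row's value.
        for (int i = n; i >= 1; --i) {
            row[i] = (row[i] + row[i - 1]) / 2.0L;
        }
        row[0] /= 2.0L;
    }
    long double total = 0.0L;
    for (int i = 0; i <= N; ++i) {
        if (1LL * i * (N - i) >= K) total += row[i];
    }
    return total;
}
```

Halving every row keeps all intermediate values between 0 and 1, so no huge binomial coefficient is ever stored.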
However, we can also use big integers. To solve the problem with integers only (no real numbers), we calculate the sum S of C(N,i) over all i with i*(N-i) >= K, and check whether S is larger than, equal to, or less than 2^(N-1). We can calculate S by using Pascal’s triangle, but we can also use the formula C(N,i) = C(N,i-1) * (N-i+1) / i and calculate C(N,i) in increasing order of i. This achieves a time complexity of O(N^2). Here is my implementation in C++. (Since it includes my personal big-integer library, the source code is relatively long.)
#ifndef ___CLASS_MODINT
#define ___CLASS_MODINT
#include <vector>
#include <cstdint>
using singlebit = uint32_t;
using doublebit = uint64_t;
static constexpr singlebit find_inv(singlebit n, int d = 5, singlebit x = 1) {
return d == 0 ? x : find_inv(n, d - 1, x * (2 - x * n));
}
template <singlebit mod, singlebit primroot> class modint {
// Fast Modulo Integer, Assertion: mod < 2^31
private:
singlebit n;
static constexpr int level = 32; // LIMIT OF singlebit
static constexpr singlebit max_value = -1;
static constexpr singlebit r2 = (((1ull << level) % mod) << level) % mod;
static constexpr singlebit inv = singlebit(-1) * find_inv(mod);
static singlebit reduce(doublebit x) {
singlebit res = (x + doublebit(singlebit(x) * inv) * mod) >> level;
return res < mod ? res : res - mod;
}
public:
modint() : n(0) {};
modint(singlebit n_) { n = reduce(doublebit(n_) * r2); };
modint& operator=(const singlebit x) { n = reduce(doublebit(x) * r2); return *this; }
bool operator==(const modint& x) const { return n == x.n; }
bool operator!=(const modint& x) const { return n != x.n; }
modint& operator+=(const modint& x) { n += x.n; n -= (n < mod ? 0 : mod); return *this; }
modint& operator-=(const modint& x) { n += mod - x.n; n -= (n < mod ? 0 : mod); return *this; }
modint& operator*=(const modint& x) { n = reduce(1ull * n * x.n); return *this; }
modint operator+(const modint& x) const { return modint(*this) += x; }
modint operator-(const modint& x) const { return modint(*this) -= x; }
modint operator*(const modint& x) const { return modint(*this) *= x; }
static singlebit get_mod() { return mod; }
static singlebit get_primroot() { return primroot; }
singlebit get() { return reduce(doublebit(n)); }
modint binpow(singlebit b) {
modint ans(1), cur(*this);
while (b > 0) {
if (b & 1) ans *= cur;
cur *= cur;
b >>= 1;
}
return ans;
}
};
template<typename modulo>
std::vector<modulo> get_modvector(std::vector<int> v) {
std::vector<modulo> ans(v.size());
for (int i = 0; i < v.size(); ++i) {
ans[i] = v[i];
}
return ans;
}
#endif
#ifndef ___CLASS_NTT
#define ___CLASS_NTT
#include <vector>
template<typename modulo>
class ntt {
// Number Theoretic Transform
private:
int depth;
std::vector<modulo> roots;
std::vector<modulo> powinv;
public:
ntt() {
depth = 0;
uint32_t div_number = modulo::get_mod() - 1;
while (div_number % 2 == 0) div_number >>= 1, ++depth;
modulo b = modulo::get_primroot();
for (int i = 0; i < depth; ++i) b *= b;
modulo baseroot = modulo::get_primroot(), bb = b;
while (bb != 1) bb *= b, baseroot *= modulo::get_primroot();
roots = std::vector<modulo>(depth + 1, 0);
powinv = std::vector<modulo>(depth + 1, 0);
powinv[1] = (modulo::get_mod() + 1) / 2;
for (int i = 2; i <= depth; ++i) powinv[i] = powinv[i - 1] * powinv[1];
roots[depth] = 1;
for (int i = 0; i < modulo::get_mod() - 1; i += 1 << depth) roots[depth] *= baseroot;
for (int i = depth - 1; i >= 1; --i) roots[i] = roots[i + 1] * roots[i + 1];
}
void fourier_transform(std::vector<modulo> &v, bool inverse) {
int s = v.size();
for (int i = 0, j = 1; j < s - 1; ++j) {
for (int k = s >> 1; k > (i ^= k); k >>= 1);
if (i < j) std::swap(v[i], v[j]);
}
int sc = 0, sz = 1;
while (sz < s) sz *= 2, ++sc;
std::vector<modulo> pw(s + 1); pw[0] = 1;
for (int i = 1; i <= s; i++) pw[i] = pw[i - 1] * roots[sc];
int qs = s;
for (int b = 1; b < s; b <<= 1) {
qs >>= 1;
for (int i = 0; i < s; i += b * 2) {
for (int j = i; j < i + b; ++j) {
modulo delta = pw[(inverse ? b * 2 - j + i : j - i) * qs] * v[j + b];
v[j + b] = v[j] - delta;
v[j] += delta;
}
}
}
if (inverse) {
for (int i = 0; i < s; ++i) v[i] *= powinv[sc];
}
}
std::vector<modulo> convolve(std::vector<modulo> v1, std::vector<modulo> v2) {
const int threshold = 16;
if (v1.size() < v2.size()) swap(v1, v2);
int s1 = 1; while (s1 < v1.size()) s1 <<= 1; v1.resize(s1);
int s2 = 1; while (s2 < v2.size()) s2 <<= 1; v2.resize(s2 * 2);
std::vector<modulo> ans(s1 + s2);
if (s2 <= threshold) {
for (int i = 0; i < s1; ++i) {
for (int j = 0; j < s2; ++j) {
ans[i + j] += v1[i] * v2[j];
}
}
}
else {
fourier_transform(v2, false);
for (int i = 0; i < s1; i += s2) {
std::vector<modulo> v(v1.begin() + i, v1.begin() + i + s2);
v.resize(s2 * 2);
fourier_transform(v, false);
for (int j = 0; j < v.size(); ++j) v[j] *= v2[j];
fourier_transform(v, true);
for (int j = 0; j < s2 * 2; ++j) {
ans[i + j] += v[j];
}
}
}
return ans;
}
};
#endif
#ifndef __CLASS_BASICINTEGER
#define __CLASS_BASICINTEGER
#include <vector>
using modulo1 = modint<469762049, 3>; ntt<modulo1> ntt_base1;
using modulo2 = modint<167772161, 3>; ntt<modulo2> ntt_base2;
const modulo1 magic_inv = modulo1(modulo2::get_mod()).binpow(modulo1::get_mod() - 2);
template<int base>
class basic_integer {
protected:
std::vector<int> a;
public:
basic_integer() : a(std::vector<int>({ 0 })) {};
basic_integer(const std::vector<int>& a_) : a(a_) {};
int size() const { return a.size(); }
int nth_digit(int n) const { return a[n]; }
basic_integer& resize() {
int lim = 1;
for (int i = 0; i < a.size(); ++i) {
if (a[i] != 0) lim = i + 1;
}
a.resize(lim);
return *this;
}
basic_integer& shift() {
for (int i = 0; i < int(a.size()) - 1; ++i) {
if (a[i] >= 0) {
a[i + 1] += a[i] / base;
a[i] %= base;
}
else {
int x = (-a[i] + base - 1) / base;
a[i] += x * base;
a[i + 1] -= x;
}
}
while (a.back() >= base) {
a.push_back(a.back() / base);
a[a.size() - 2] %= base;
}
return *this;
}
bool operator==(const basic_integer& b) const { return a == b.a; }
bool operator!=(const basic_integer& b) const { return a != b.a; }
bool operator<(const basic_integer& b) const {
if (a.size() != b.a.size()) return a.size() < b.a.size();
for (int i = a.size() - 1; i >= 0; --i) {
if (a[i] != b.a[i]) return a[i] < b.a[i];
}
return false;
}
bool operator>(const basic_integer& b) const { return b < (*this); }
bool operator<=(const basic_integer& b) const { return !((*this) > b); }
bool operator>=(const basic_integer& b) const { return !((*this) < b); }
basic_integer& operator<<=(const uint32_t x) {
if (a.back() >= 1 || a.size() >= 2) {
std::vector<int> v(x, 0);
a.insert(a.begin(), v.begin(), v.end());
}
return (*this);
}
basic_integer& operator>>=(const uint32_t x) {
if (x == 0) return *this;
if (x >= a.size()) a = { 0 };
else a = std::vector<int>(a.begin() + x, a.end());
return (*this);
}
basic_integer& operator+=(const basic_integer& b) {
if (a.size() < b.a.size()) a.resize(b.a.size(), 0);
for (int i = 0; i < b.a.size(); ++i) a[i] += b.a[i];
return (*this).shift();
}
basic_integer& operator-=(const basic_integer& b) {
for (int i = 0; i < b.a.size(); ++i) a[i] -= b.a[i];
return (*this).shift().resize();
}
basic_integer& operator*=(const basic_integer& b) {
std::vector<modulo1> mul_base1 = ntt_base1.convolve(get_modvector<modulo1>(a), get_modvector<modulo1>(b.a));
std::vector<modulo2> mul_base2 = ntt_base2.convolve(get_modvector<modulo2>(a), get_modvector<modulo2>(b.a));
const int margin = 20;
a = std::vector<int>(mul_base1.size() + margin);
for (int i = 0; i < a.size() - margin; ++i) {
// s * p1 + a1 = val = t * p2 + a2's solution is t = (a1 - a2) / p2 (mod p1)
long long val = (long long)(((mul_base1[i] - modulo1(mul_base2[i].get())) * magic_inv).get()) * modulo2::get_mod() + mul_base2[i].get();
for (int j = i; val > 0 && j < a.size(); ++j) {
a[j] += val % base;
if (a[j] >= base) {
a[j] -= base;
a[j + 1] += 1;
}
val /= base;
}
}
return (*this).resize();
}
basic_integer& operator/=(const basic_integer& b) {
int preci = a.size() - b.a.size();
basic_integer t({ 1 });
basic_integer two = basic_integer({ 2 }) << b.a.size();
basic_integer pre;
int lim = std::min(preci, 3);
int blim = std::min(int(b.a.size()), 6);
t <<= lim;
while (pre != t) {
basic_integer rb = b >> (b.a.size() - blim);
if (blim != b.a.size()) rb += basic_integer({ 1 });
pre = t;
t *= (basic_integer({ 2 }) << (blim + lim)) - rb * t;
t.a = std::vector<int>(t.a.begin() + lim + blim, t.a.end());
}
if (lim != preci) {
pre = basic_integer();
while (pre != t) {
basic_integer rb = b >> (b.a.size() - blim);
if (blim != b.a.size()) rb += basic_integer({ 1 });
pre = t;
t *= (basic_integer({ 2 }) << (blim + lim)) - rb * t;
t.a = std::vector<int>(t.a.begin() + lim + blim, t.a.end());
int next_lim = std::min(lim * 2 + 1, preci);
if (next_lim != lim) t <<= next_lim - lim;
int next_blim = std::min(blim * 2 + 1, int(b.a.size()));
lim = next_lim;
blim = next_blim;
}
}
basic_integer ans = (*this) * t;
ans.a = std::vector<int>(ans.a.begin() + a.size(), ans.a.end());
while ((ans + basic_integer({ 1 })) * b <= (*this)) {
ans += basic_integer({ 1 });
}
(*this) = ans.resize();
return *this;
}
basic_integer& divide_by_2() {
for (int i = a.size() - 1; i >= 0; --i) {
int carry = a[i] % 2;
a[i] /= 2;
if (i != 0) a[i - 1] += carry * base;
}
if (a.size() >= 2 && a.back() == 0) a.pop_back();
return *this;
}
basic_integer operator<<(int x) const { return basic_integer(*this) <<= x; }
basic_integer operator>>(int x) const { return basic_integer(*this) >>= x; }
basic_integer operator+(const basic_integer& b) const { return basic_integer(*this) += b; }
basic_integer operator-(const basic_integer& b) const { return basic_integer(*this) -= b; }
basic_integer operator*(const basic_integer& b) const { return basic_integer(*this) *= b; }
basic_integer operator/(const basic_integer& b) const { return basic_integer(*this) /= b; }
};
#endif
#ifndef ___CLASS_NEWBIGINT
#define ___CLASS_NEWBIGINT
#include <string>
#include <iostream>
#include <algorithm>
const int digit = 4;
const int digit_base = 10000;
class bigint : public basic_integer<digit_base> {
public:
bigint() { a = std::vector<int>({ 0 }); };
bigint(long long x) {
a.clear();
for (int i = 0; x > 0; ++i) {
a.push_back(x % digit_base);
x /= digit_base;
}
if (a.size() == 0) a = { 0 };
}
bigint(const std::string& s) {
a.clear();
for (int i = 0; digit * i < s.size(); ++i) {
a.push_back(std::stoi(s.substr(std::max(int(s.size()) - i * digit - digit, 0), digit - std::max(digit + i * digit - int(s.size()), 0))));
}
if (a.size() == 0) a = { 0 };
}
std::string to_string() const {
std::string ret;
bool flag = false;
for (int i = a.size() - 1; i >= 0; --i) {
if (a[i] > 0 && !flag) {
ret += std::to_string(a[i]);
flag = true;
}
else if (flag) {
std::string sub = std::to_string(a[i]);
ret += std::string(digit - sub.size(), '0') + sub;
}
}
return ret.empty() ? "0" : ret;
}
int convert_int() const { return std::stoi((*this).to_string()); }
long long convert_ll() const { return std::stoll((*this).to_string()); }
bigint& operator<<=(int x) { return reinterpret_cast<bigint&>(reinterpret_cast<basic_integer&>(a) <<= x); }
bigint& operator>>=(int x) { return reinterpret_cast<bigint&>(reinterpret_cast<basic_integer&>(a) >>= x); }
bigint& operator+=(const bigint& b) { return reinterpret_cast<bigint&>(reinterpret_cast<basic_integer&>(a) += basic_integer(b)); }
bigint& operator-=(const bigint& b) { return reinterpret_cast<bigint&>(reinterpret_cast<basic_integer&>(a) -= basic_integer(b)); }
bigint& operator*=(const bigint& b) { return reinterpret_cast<bigint&>(reinterpret_cast<basic_integer&>(a) *= basic_integer(b)); }
bigint& operator/=(const bigint& b) { return reinterpret_cast<bigint&>(reinterpret_cast<basic_integer&>(a) /= basic_integer(b)); }
bigint& divide_by_2() { return reinterpret_cast<bigint&>(reinterpret_cast<basic_integer&>(a).divide_by_2()); }
bigint operator<<(int x) const { return bigint(*this) <<= x; }
bigint operator>>(int x) const { return bigint(*this) >>= x; }
bigint operator+(const bigint& b) const { return bigint(*this) += b; }
bigint operator-(const bigint& b) const { return bigint(*this) -= b; }
bigint operator*(const bigint& b) const { return bigint(*this) *= b; }
bigint operator/(const bigint& b) const { return bigint(*this) /= b; }
friend std::istream& operator>>(std::istream& is, bigint& x) { std::string s; is >> s; x = bigint(s); return is; }
friend std::ostream& operator<<(std::ostream& os, const bigint& x) { os << x.to_string(); return os; }
};
#endif
#include <string>
using namespace std;
class DancingClass {
public:
string checkOdds(int N, int K) {
vector<bigint> b(N + 1);
b[0] = 1;
for (int i = 1; i <= N; ++i) {
b[i] = (b[i - 1] * (N - i + 1)) / i;
}
bigint all = 0, sum = 0;
for (int i = 0; i <= N; ++i) {
all += b[i];
if (i * (N - i) >= K) sum += b[i];
}
if (sum * 2 == all) return "Equal";
return sum * 2 > all ? "High" : "Low";
}
};
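For small inputs, the same integer comparison can be tried without any big-integer library. The sketch below is an illustration only and assumes 0 <= N <= 50 so that every intermediate value fits in an unsigned 64-bit integer; the name `check_odds_small` is hypothetical.

```cpp
#include <cstdint>
#include <string>

// Compares S (the sum of C(N, i) over all i with i*(N-i) >= K) against 2^(N-1),
// written as 2*S versus 2^N. Assumes 0 <= N <= 50, so nothing overflows 64 bits.
std::string check_odds_small(int N, long long K) {
    std::uint64_t c = 1;   // current binomial coefficient, starts at C(N, 0)
    std::uint64_t sum = 0; // S
    for (int i = 0; i <= N; ++i) {
        // C(N, i) = C(N, i-1) * (N-i+1) / i; the division is always exact.
        if (i > 0) c = c * static_cast<std::uint64_t>(N - i + 1) / static_cast<std::uint64_t>(i);
        if (1LL * i * (N - i) >= K) sum += c;
    }
    std::uint64_t all = std::uint64_t(1) << N; // 2^N, the sum of all C(N, i)
    if (sum * 2 == all) return "Equal";
    return sum * 2 > all ? "High" : "Low";
}
```

For larger N the coefficients no longer fit in 64 bits, which is where a big-integer type becomes necessary.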
SRM 731, Division 2 Hard - JustBrackets - Editorial by square1001
Solving this problem is difficult. Most competitors came up with wrong solutions and failed the system tests. Only one competitor submitted a correct solution and passed.
My solution uses SRM 731 Division 1 Easy - TreesAndBrackets as a subproblem. That problem asks: given a bracket string S, check whether we can turn it into a bracket string T by repeatedly removing an adjacent “()”.
This problem (Div2 Hard) asks for the lexicographically smallest string obtainable from S by repeatedly removing an adjacent “()”.
Let’s take a greedy approach. We decide the answer character by character, starting from the first one. An example of the approach is as follows:
Can the resulting string start with “(”? → Yes (The 1st letter of the optimal answer is “(”)
Can the resulting string start with “((”? → Yes (The 2nd letter of the optimal answer is “(”)
Can the resulting string start with “(((”? → No (The 3rd letter of the optimal answer is “)”)
Can the resulting string start with “(()(”? → Yes (The 4th letter of the optimal answer is “(”)
And so on…
Here, once the resulting string becomes a non-empty valid bracket sequence, we can terminate this loop, because appending anything more is pointless (the string would only become lexicographically larger).
So, how can we answer questions like “Can the resulting string start with “(()(”?” Actually, it is equivalent to the question “Can the resulting string start with “(()())”?”, which is also equivalent to “Can the resulting string be “(()())”?”. It means that, instead of asking whether the resulting string can start with a not necessarily valid bracket sequence, we can ask whether the resulting string can be the valid bracket sequence obtained by appending some “)”.
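The closing step described above can be sketched as a small helper. The name `close_prefix` is hypothetical, and the sketch assumes the prefix never contains more “)” than “(” at any point.

```cpp
#include <string>

// Appends the ')' characters needed to turn `prefix` into a balanced sequence.
// Assumes no prefix of `prefix` has more ')' than '('.
std::string close_prefix(const std::string& prefix) {
    int depth = 0; // number of currently unmatched '('
    for (char ch : prefix) depth += (ch == '(') ? 1 : -1;
    return prefix + std::string(depth, ')');
}
```

With this helper, the question about the prefix “(()(” turns into a question about the full string `close_prefix("(()(")`, which is “(()())”.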
You may notice that this subproblem is the same as the Div1 Easy one, which we can solve in O(N^2) (see my editorial of Div1 Easy). Since the greedy approach takes at most N steps, we can solve this problem in O(N^3) time complexity.
Of course, I think there can be a faster solution. Let’s think about it :)
Time Complexity: O(N^3)
#include <string>
#include <vector>
#include <iostream>
#include <algorithm>
using namespace std;
class JustBrackets {
public:
bool valid(string S, string T) {
vector<string> sp, tp;
int depth = 0, pre = 0;
for(int i = 0; i < S.size(); ++i) {
if(S[i] == '(') ++depth;
else --depth;
if(depth == 0) {
string sub = S.substr(pre + 1, i - 1 - pre);
sp.push_back(sub);
pre = i + 1;
}
}
depth = 0, pre = 0;
for(int i = 0; i < T.size(); ++i) {
if(T[i] == '(') ++depth;
else --depth;
if(depth == 0) {
string sub = T.substr(pre + 1, i - 1 - pre);
tp.push_back(sub);
pre = i + 1;
}
}
int ptr = 0;
for(int i = 0; i < tp.size(); ++i) {
while(ptr < sp.size()) {
bool res = valid(sp[ptr], tp[i]);
if(res) break;
++ptr;
}
if(ptr == sp.size()) return false;
++ptr;
}
return true;
}
string getSmallest(string S) {
string ans = "";
while(true) {
int depth = 0;
for(int i = 0; i < ans.size(); ++i) {
if(ans[i] == '(') ++depth;
else --depth;
}
if(ans != "" && depth == 0) break;
string nxt = ans + "(" + string(depth + 1, ')');
if(valid(S, nxt)) ans += '(';
else ans += ')';
}
return ans;
}
};
SRM 731, Division 1 Easy - TreesAndBrackets - Editorial by square1001
Solution with Reversed Thinking + Dynamic Programming
In this problem, we remove leaves to transform t1 into t2. However, we can also view the problem as adding leaves to transform t2 into t1.
If we think about it this way, we can solve it by DP. To make a valid tree from t2, we can repeatedly insert a “valid forest string” between the i-th character and the (i+1)-th character (1 <= i <= |t2|-1).
Here, a “valid forest string” is a string which has the following properties:
The number of ‘(’ and the number of ‘)’ are the same.
For any prefix, the number of ‘(’ is greater than or equal to the number of ‘)’.
Such a string is sometimes called a “parenthesis sequence” or “parenthesis string”. There are many problems in competitive programming which use parenthesis sequences, for example, SRM 521 Div1 Easy - MissingParentheses.
Now, let dp[i][j] be true if “we can transform validly when looking at just the first i characters of t2 and the first j characters of t1”.
Here, dp[0][0] is true. Any other dp[i][j] is true if it meets either of these conditions:
t2[i-1] = t1[j-1] and dp[i-1][j-1] is true
t1[j-k,j) is a valid forest string for some k >= 1 and dp[i][j-k] is true
If dp[|t2|][|t1|] is true, the answer is “Possible”, otherwise “Impossible”. We can solve this problem in time complexity O(N^4), and this can be improved to O(N^3) if we can judge in constant time whether a substring is a valid forest string.
During the contest, ainta solved the problem in this way in just 4 minutes and 36 seconds! We can see his source code for an implementation.
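The DP described in this section can be sketched as follows. This is my own illustration and not ainta’s code: the name `can_transform` is hypothetical, the base case is taken to be dp[0][0], and a table `bal` is precomputed so that each forest-string check takes constant time, giving O(N^3) overall.

```cpp
#include <string>
#include <vector>

// Returns true if t1 can be turned into t2 by deleting valid forest strings from t1.
bool can_transform(const std::string& t1, const std::string& t2) {
    int n = t1.size(), m = t2.size();
    // bal[a][b] is true if t1[a, b) is a non-empty valid forest string.
    std::vector<std::vector<char>> bal(n + 1, std::vector<char>(n + 1, 0));
    for (int a = 0; a < n; ++a) {
        int depth = 0;
        for (int b = a; b < n; ++b) {
            depth += (t1[b] == '(') ? 1 : -1;
            if (depth < 0) break;            // some prefix has more ')' than '('
            if (depth == 0) bal[a][b + 1] = 1;
        }
    }
    // dp[i][j]: the first i characters of t2 can be extended to the first j characters of t1.
    std::vector<std::vector<char>> dp(m + 1, std::vector<char>(n + 1, 0));
    dp[0][0] = 1;
    for (int i = 0; i <= m; ++i) {
        for (int j = 1; j <= n; ++j) {
            // Condition 1: the last characters match.
            if (i >= 1 && t2[i - 1] == t1[j - 1] && dp[i - 1][j - 1]) dp[i][j] = 1;
            // Condition 2: the last k characters of t1 form an inserted forest string.
            for (int k = 1; k <= j && !dp[i][j]; ++k) {
                if (bal[j - k][j] && dp[i][j - k]) dp[i][j] = 1;
            }
        }
    }
    return dp[m][n] != 0;
}
```

For example, `can_transform("((())())", "(()())")` holds because the innermost leaf of the first child can be removed.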
Solution with DFS + Greedy Algorithm
For some people this solution could be simpler. Let’s try to solve the problem recursively.
First, let’s think about the following situation:
t1 is expressed as ((u1)(u2)....(ux))
t2 is expressed as ((v1)(v2)....(vy))
We can use a greedy algorithm. If (u1) can be transformed to (v1), then we only need to care about whether ((u2)(u3)....(ux)) can be transformed to ((v2)(v3)....(vy)). In other words, we “match (u1) and (v1)”. If it cannot, we only need to care about whether ((u2)(u3)....(ux)) can be transformed to ((v1)(v2)....(vy)). In other words, (u1) will not be “matched” to anything.
We can judge it recursively. The total number of recursive calls is at most n, so the time complexity will be O(n^2) (depending on the implementation, though), which is faster than the DP solution.
#include <string>
#include <vector>
#include <iostream>
#include <algorithm>
using namespace std;
class TreesAndBrackets {
public:
bool valid(string S, string T) {
vector<string> sp, tp;
int depth = 0, pre = 0;
for(int i = 0; i < S.size(); ++i) {
if(S[i] == '(') ++depth;
else --depth;
if(depth == 0) {
string sub = S.substr(pre + 1, i - 1 - pre);
sp.push_back(sub);
pre = i + 1;
}
}
depth = 0, pre = 0;
for(int i = 0; i < T.size(); ++i) {
if(T[i] == '(') ++depth;
else --depth;
if(depth == 0) {
string sub = T.substr(pre + 1, i - 1 - pre);
tp.push_back(sub);
pre = i + 1;
}
}
int ptr = 0;
for(int i = 0; i < tp.size(); ++i) {
while(ptr < sp.size()) {
bool res = valid(sp[ptr], tp[i]);
if(res) break;
++ptr;
}
if(ptr == sp.size()) return false;
++ptr;
}
return true;
}
string check(string S, string T) {
bool res = valid(S, T);
return (res ? "Possible" : "Impossible");
}
};
SRM 731, Division 1 Medium - RndSubTree - Editorial by square1001
I think that this problem is difficult, even more difficult than Div1 Hard of this contest, although the number of solvers was larger for Div1 Medium.
Prerequisite: Knowledge about Binomial Distribution
A binomial distribution X = B(n,p) is the probability distribution of the number of successes in n independent trials, each of which succeeds with probability p.
The expected value is E(X) = np. That’s because each of the n independent trials contributes p to the expected number of successes.
The variance is V(X) = np(1-p). If n=1, the value is 0 with probability 1-p and 1 with probability p, so the variance is V(X) = (1-p)(0-p)^2 + p(1-p)^2 = p(1-p). Since the n trials are independent, their variances add up, which means V(X) = np(1-p).
Now let’s find the expected square of the value, E(X^2). There is a famous formula that, for any probability distribution, V(X) = E(X^2) - (E(X))^2. It means that, for the binomial distribution X = B(n,p), the expected square of the value is E(X^2) = (E(X))^2 + V(X) = (np)^2 + np(1-p) = np(np-p+1).
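The closed form for E(X^2) can be checked numerically against the definition. The sketch below is an illustration only; both function names are hypothetical.

```cpp
#include <cmath>

// E(X^2) for X = B(n, p), computed by summing k^2 * P(X = k) directly.
double second_moment_bruteforce(int n, double p) {
    double total = 0.0;
    double coeff = 1.0; // C(n, k), updated incrementally
    for (int k = 0; k <= n; ++k) {
        if (k > 0) coeff = coeff * (n - k + 1) / k;
        total += static_cast<double>(k) * k * coeff * std::pow(p, k) * std::pow(1.0 - p, n - k);
    }
    return total;
}

// Closed form from the text: E(X^2) = np(np - p + 1).
double second_moment_formula(int n, double p) {
    return n * p * (n * p - p + 1.0);
}
```

For n = 5 and p = 0.3, the formula gives 1.5 * 2.2 = 3.3, and the direct sum agrees up to floating-point error.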
Solution Part #1 - How can we calculate the expected probability?
Let G be a rooted tree with N vertices. This time, G is the (induced) subgraph consisting of all red vertices of the infinite binary tree. The total distance over the N(N-1)/2 pairs of vertices can be calculated in this way: letting Ci be the size of the subtree of G rooted at vertex i, the total distance is equal to the sum of Ci * (N-Ci) over all non-root vertices i. This is intuitive once you notice that the number of shortest paths which pass through the edge from vertex i to its parent is Ci * (N-Ci).
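The subtree-size identity for the total distance can be sketched on an explicit tree. The parent-array representation, including the convention that parent[v] < v, is an assumption made for this illustration, and the name `total_distance` is hypothetical.

```cpp
#include <vector>

// Sum of distances over all unordered pairs of vertices in a rooted tree.
// Assumes parent[0] == -1 (vertex 0 is the root) and parent[v] < v for every v >= 1.
long long total_distance(const std::vector<int>& parent) {
    int n = parent.size();
    std::vector<long long> size(n, 1); // size[v] = number of vertices in the subtree of v
    for (int v = n - 1; v >= 1; --v) size[parent[v]] += size[v];
    long long total = 0;
    // The edge from v to its parent lies on size[v] * (n - size[v]) shortest paths.
    for (int v = 1; v < n; ++v) total += size[v] * (n - size[v]);
    return total;
}
```

On a path of three vertices this gives 2*1 + 1*2 = 4, which matches the pairwise distances 1 + 1 + 2.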
So, we should calculate the expected value of Ci * (N-Ci) for all red vertices. But we need one additional piece of information j: the index of the token-putting operation in which the vertex turns red. Let P(j,i) be the probability that vertex i turns red in exactly the j-th token-putting operation. Also, let Ri be the probability that a token passes through vertex i when one token is put after vertex i has turned red. The expected contribution to Ci * (N-Ci) is P(j,i) * E((X+1) * (N-X-1)) where X = B(N-j, Ri), because N-j tokens are put after the j-th one (counting j from 1).
Solution Part #2 - Calculating exact expected values
Here, P(j,i) is the same for all vertices i of the same depth. So, we define p(j,i) as the probability that the leftmost vertex of depth i turns red in exactly the j-th token-putting operation. Here, each vertex of the infinite binary tree has two children (call them the left and right child), and the leftmost vertex of depth i is the vertex reached by going to the left child i times from the root.
Let q(j,i) be the probability that the maximum depth of a red leftmost vertex is i when exactly j tokens have been put. Here, it holds:
p(j,i) = (q(j,i) + q(j,i+1) + q(j,i+2) + ...) - (q(j-1,i) + q(j-1,i+1) + q(j-1,i+2) + ...)
We can calculate q(j,i) by dynamic programming.
q(1,0) = 1
q(j,i) = q(j-1,i) * (1 - 2^-(i+1)) + q(j-1,i-1) * 2^-i
Note that the maximum depth of red leftmost vertices changes from i to i+1 with probability 2^-(i+1), because the token needs to go left i+1 times consecutively.
Because i < j, we can calculate all values of q(j,i) up to N tokens in O(N^2) time complexity. After this, with the cumulative sum technique, we can calculate p(j,i) in O(N^2) time complexity.
So, the expected contribution of the case where the leftmost vertex of depth i turns red at exactly the j-th token is p(j,i) * E((X+1) * (N-X-1)) where X = B(N-j, 2^-i). Note that Ri (in part #1) has become 2^-i here, because the probability that a token passes through the leftmost vertex of depth i is equal to the probability that the token goes left i times consecutively.
Now, let’s calculate E((X+1) * (N-X-1)). If you define Y = X+1, where X = B(N-j, 2^-i), you now have to calculate E(Y*(N-Y)). Here, E(Y) = (N-j) * 2^-i + 1 and V(Y) = (N-j) * 2^-i * (1-2^-i), so we can calculate E(Y^2) = (E(Y))^2 + V(Y).
And we can then calculate E(Y*(N-Y)) = N*E(Y) - E(Y^2). That’s all: now we only have to calculate the sum of p(j,i) * E(Y*(N-Y)) * 2^i (note that there are 2^i vertices of depth i, and their contributions to the expected total distance are equal by symmetry). And it’s the answer.
Total Time Complexity: O(N^2)
#ifndef CLASS_MODINT
#define CLASS_MODINT
#include <cstdint>
template<std::uint32_t mod>
class modint {
private:
std::uint32_t n;
public:
modint() : n(0) {};
modint(std::int64_t n_) : n((n_ >= 0 ? n_ : mod - (-n_) % mod) % mod) {};
std::uint32_t get() const { return n; }
bool operator==(const modint& m) const { return n == m.n; }
bool operator!=(const modint& m) const { return n != m.n; }
modint& operator+=(const modint& m) { n += m.n; n = (n < mod ? n : n - mod); return *this; }
modint& operator-=(const modint& m) { n += mod - m.n; n = (n < mod ? n : n - mod); return *this; }
modint& operator*=(const modint& m) { n = std::uint64_t(n) * m.n % mod; return *this; }
modint operator+(const modint& m) const { return modint(*this) += m; }
modint operator-(const modint& m) const { return modint(*this) -= m; }
modint operator*(const modint& m) const { return modint(*this) *= m; }
modint pow(std::uint64_t b) const {
modint ans = 1, m = modint(*this);
while (b) {
if (b & 1) ans *= m;
m *= m;
b >>= 1;
}
return ans;
}
modint inv() const { return (*this).pow(mod - 2); }
};
#endif // CLASS_MODINT
#include <vector>
#include <algorithm>
using namespace std;
const int mod = 1000000007;
using modulo = modint<mod>;
class RndSubTree {
private:
int N;
vector<modulo> pw;
public:
void init() {
pw.resize(2 * N + 1);
pw[N] = 1;
modulo inv2 = modulo(2).inv();
for(int i = 0; i < N; ++i) pw[N + i + 1] = pw[N + i] * 2;
for(int i = 0; i < N; ++i) pw[N - i - 1] = pw[N - i] * inv2;
}
modulo pow2(int x) {
// assuming -N <= x <= N, returns 2^x modulo 10^9+7
return pw[x + N];
}
int count(int N_) {
N = N_;
init();
vector<vector<modulo>> dp(N + 1, vector<modulo>(N, modulo(0)));
dp[1][0] = 1;
for(int i = 2; i <= N; ++i) {
for(int j = 0; j < i; ++j) {
dp[i][j] += dp[i - 1][j] * (modulo(1) - pow2(-(j + 1)));
if(j >= 1) dp[i][j] += dp[i - 1][j - 1] * pow2(-j);
}
}
for(int i = 1; i <= N; ++i) {
for(int j = i - 2; j >= 0; --j) {
dp[i][j] += dp[i][j + 1];
}
}
modulo ans = 0;
for(int i = 1; i <= N; ++i) {
for(int j = 1; j < i; ++j) {
modulo prob = pow2(-j);
modulo ex = prob * (N - i) + modulo(1);
modulo var = prob * (modulo(1) - prob) * (N - i);
modulo ex_square = var + ex * ex;
modulo ex_contrib = modulo(N) * ex - ex_square; // expectation of x*(N-x)
ans += ex_contrib * (dp[i][j] - dp[i - 1][j]) * pow2(j);
}
}
return ans.get();
}
};
SRM 731, Division 1 Hard - SquadConstructor2 - Editorial by square1001
Slow DP Solution
Brute-forcing all possibilities is really difficult. There are C(256,8) ≈ 4.0966 * 10^14 of them, which would take days or even years to calculate the optimal answer. This may cause things like this.
This problem can be solved by dynamic programming more efficiently than by brute force. Let dp[i][j][seq] (seq = (seq0, seq1, ..., seqN-1)) be the possibility (as a boolean value) of “choosing j players from friends 0 through i-1, such that the current number of players who can play strategy k is seqk”.
Let S be the number of friends. The number of combinations of (i,j,seq) is at most S * (K+1) * (K+1)^N = S * (K+1)^(N+1). And there are two ways to move from (i,j,seq): to choose friend i, or not. The calculation of the new seq takes O(N) time, so the total time complexity of this DP is O(S * N * (K+1)^(N+1)).
Let’s assign S=256, N=8, K=8 for the maximum case: S * N * (K+1)^(N+1) is about 7.9344 * 10^11, which is far better than the brute-force solution!
Fast DP Optimization with “bitset”
Let’s optimize this DP with bitset! We can think of a bitset as a boolean array of fixed length s. We can see it as an s-digit binary number, and we can calculate bitwise AND, bitwise OR, bitwise XOR, and bit shifts on it. It is very fast because it is split into blocks of (usually) 64 digits, stored in an integer array, and processed with bitwise operations. So, let’s assume that it is 64 times faster than the naive calculation.
Actually, this DP can be calculated with bitset! But how? Here are two tips for using bitset by transforming seq into an integer.
Let h = seq0 * (K+1)^0 + seq1 * (K+1)^1 + ... + seqN-1 * (K+1)^(N-1).
If we add friend i, the new h becomes h+d, where d is the sum of (K+1)^x over all x such that friend i can use strategy x.
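The value d from the second tip can be sketched as below, under the assumption that a friend’s strategies are given as a bitmask in which bit x is set when the friend can use strategy x. The name `strategy_delta` is hypothetical.

```cpp
#include <cstdint>

// Returns d = sum of (K+1)^x over every strategy x (0 <= x < N) whose bit is set in `mask`.
std::int64_t strategy_delta(int mask, int N, int K) {
    std::int64_t d = 0;
    std::int64_t weight = 1; // (K+1)^x for the current x
    for (int x = 0; x < N; ++x) {
        if ((mask >> x) & 1) d += weight;
        weight *= (K + 1);
    }
    return d;
}
```

For example, with K = 8 a friend who can use strategies 0 and 2 gets d = 9^0 + 9^2 = 82, so choosing that friend shifts the encoded state h by 82.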
So this becomes a variant of the well-known “partial sum problem” (also known as the subset sum problem): we choose some elements from a1, a2, ..., an, and judge whether it is possible to choose elements whose sum is exactly k. The partial sum problem can be solved with bitset as follows.
-------------------------------
bitset<1000007> bs;
bs[0] = 1;
for(int i = 0; i < n; ++i) {
bs |= bs << a[i];
}
return (bs[k] ? "Possible" : "Impossible");
-------------------------------
And the time complexity will be O(nk) times 1/64. We can similarly do the DP with bitset, and the time complexity will be O(S * (K+1)^(N+1)) times 1/64. In this problem, the maximum case is S=256, K=8, N=8, and S * (K+1)^(N+1) / 64 is about 1.5497 * 10^9. Actually, this is slightly tight for the execution time limit of 3 seconds.
We now eliminate the unnecessary elements. Let Ti be the set of strategies friend i can use. If there are more than K indices j with Ti ⊆ Tj (note: i = j is allowed), friend i will never be in the squad. The maximum case I can think of is “N=8 and all cases of |Ti| = 3 or 4”, in which no element is eliminated. In this case, S = C(8,3) + C(8,4) = 56 + 70 = 126, which halves the execution time.
My following source code passed all test cases in less than 2.5 seconds. I think solving this problem in this way using Java, C#, Python, or VB would be difficult, and that’s one of the reasons why most people use C++ in competitive programming.
#include <bitset>
#include <vector>
#include <algorithm>
using namespace std;
bitset<43046721> dp[9];
int digit[11];
class SquadConstructor2 {
public:
int teamget(int N, int K, vector<int> A) {
vector<int> fixedA;
for(int i = 0; i < A.size(); ++i) {
int excel = 0;
for(int j = 0; j < A.size(); ++j) {
if((A[i] | A[j]) == A[j]) ++excel;
}
if(excel <= K) {
fixedA.push_back(A[i]);
}
}
A = fixedA;
dp[0][0] = 1;
for(int i = 0; i < A.size(); ++i) {
int mul = 1, sum = 0;
for(int j = 0; j < N; ++j) {
if((A[i] >> j) & 1) {
sum += mul;
}
mul *= K + 1;
}
for(int j = K - 1; j >= 0; --j) {
dp[j + 1] |= dp[j] << sum;
}
}
int cnt = 0, ans = 0;
while(digit[N] != 1) {
if(dp[K][cnt++]) {
int sum = 0;
for(int i = 0; i < N; ++i) {
sum += digit[i] * digit[i];
}
ans = max(ans, sum);
}
++digit[0];
int pos = 0;
while(digit[pos] == K + 1) {
digit[pos] = 0;
++digit[++pos];
}
}
return ans;
}
};