openjij
Framework for the Ising model and QUBO.
Loading...
Searching...
No Matches
integer_sa_sampler.hpp
Go to the documentation of this file.
1// Copyright 2023 Jij Inc.
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6
7// http://www.apache.org/licenses/LICENSE-2.0
8
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#pragma once
16
17#include "openjij/graph/all.hpp"
#include <cstdint>
#include <exception>
#include <stdexcept>
20
21namespace openjij {
22namespace sampler {
23
// Result of a single integer simulated-annealing run (produced by BaseSA).
// Energy of the returned state: sa_system.GetEnergy() after the sweep loop.
25 double energy = 0.0;
// Value of each variable in the returned state, indexed by variable index.
26 std::vector<std::int64_t> solution = {};
// Energy after each sweep; filled only when log_history is true.
27 std::vector<double> energy_history = {};
// Temperature used in each sweep; filled only when log_history is true.
28 std::vector<double> temperature_history = {};
29};
30
// Performs one simulated-annealing run and returns the resulting state.
//
// Each sweep computes a temperature T from `schedule`. It then asks
// StateUpdater::GenerateNewValue(sa_system, i, T, progress) for a new value of
// every variable i (in index order) and stores it with sa_system.SetValue.
//
// Template parameters:
//   ModelType    - model type; must provide GetNumVariables().
//   RandType     - random engine type; its result_type is the seed type.
//   StateUpdater - default-constructed; must provide GenerateNewValue().
// Parameters:
//   model       - model to anneal.
//   schedule    - how T moves from max_T (sweep 0) to min_T (sweep
//                 num_sweeps - 1): the first get_T branch interpolates
//                 linearly (presumably the LINEAR case); GEOMETRIC uses
//                 max_T * (min_T / max_T)^(sweep / (num_sweeps - 1)).
//   num_sweeps  - number of sweeps; 0 performs no updates.
//   seed        - presumably used to seed RandType when `sa_system` is built
//                 (construction not visible here; TODO confirm).
//   min_T/max_T - final / initial temperature of the schedule.
//   log_history - when true, appends the energy and T after every sweep.
// Returns the final energy and variable values, plus histories if requested.
// Throws std::runtime_error("Unknown temperature schedule") from get_T on the
// first sweep if `schedule` matches neither branch. It never throws this when
// num_sweeps == 0, because get_T is then never called.
//
// NOTE(review): when num_sweeps == 1, both get_T and `progress` divide by
// (num_sweeps - 1) == 0. That evaluates 0.0 / 0.0 = NaN, so the single sweep
// runs with T = NaN and progress = NaN. Consider treating the schedule
// position as 0 (or 1) in that case.
31template <class ModelType, class RandType, class StateUpdater>
32IntegerSAResult
33BaseSA(const ModelType &model, const utility::TemperatureSchedule schedule,
34 const std::int64_t num_sweeps, const typename RandType::result_type seed,
35 const double min_T, const double max_T, const bool log_history) {
36
37 // Initialize the system
// NOTE(review): `sa_system` is declared at this point; its declaration is not
// visible here (presumably built from `model` and `seed` -- confirm).
39
40 // Initialize the updater
41 auto state_updater = StateUpdater{};
42
// Temperature for a given sweep index: max_T at sweep 0, min_T at the last
// sweep (see the NaN note above for num_sweeps == 1).
43 auto get_T = [&](const std::int64_t sweep) {
45 return max_T +
46 (min_T - max_T) * (static_cast<double>(sweep) / (num_sweeps - 1));
47 } else if (schedule == utility::TemperatureSchedule::GEOMETRIC) {
48 return max_T * std::pow(min_T / max_T,
49 static_cast<double>(sweep) / (num_sweeps - 1));
50 } else {
51 throw std::runtime_error("Unknown temperature schedule");
52 }
53 };
54
55 const std::int64_t num_variables = model.GetNumVariables();
56 IntegerSAResult result;
57
58 for (std::int64_t sweep = 0; sweep < num_sweeps; ++sweep) {
59 const double T = get_T(sweep);
// Fraction of the schedule completed: 0 at the first sweep, 1 at the last
// (NaN when num_sweeps == 1).
60 const double progress = static_cast<double>(sweep) / (num_sweeps - 1);
// One sweep: propose and apply a new value for every variable in turn.
61 for (std::int64_t i = 0; i < num_variables; ++i) {
62 const auto new_x =
63 state_updater.GenerateNewValue(sa_system, i, T, progress);
64 sa_system.SetValue(i, new_x);
65 }
66 if (log_history) {
67 result.energy_history.push_back(sa_system.GetEnergy());
68 result.temperature_history.push_back(T);
69 }
70 }
71
// Copy out the final energy and the value of every variable.
72 result.energy = sa_system.GetEnergy();
73 result.solution.resize(num_variables);
74 const auto &state = sa_system.GetState();
75 for (std::int64_t i = 0; i < num_variables; ++i) {
76 result.solution[i] = state[i].value;
77 }
78
79 return result;
80}
81
// Selects the random number engine for one SA run and forwards to BaseSA.
//
// `rand_type` picks utility::Xorshift, std::mt19937 or std::mt19937_64 as
// BaseSA's RandType. UpdaterType is forwarded unchanged as the StateUpdater.
// The int64 `seed` is converted to the chosen engine's result_type, which is
// unsigned. Negative seeds therefore wrap, and high bits can be dropped if that
// result_type is narrower than 64 bits.
// Throws std::runtime_error("Unknown random number engine") for any other
// rand_type. Exceptions from BaseSA propagate unchanged.
82template <class ModelType, class UpdaterType>
84 const std::int64_t num_sweeps,
85 const algorithm::RandomNumberEngine rand_type,
86 const utility::TemperatureSchedule schedule,
87 const std::int64_t seed, const double min_T,
88 const double max_T, const bool log_history) {
89 switch (rand_type) {
// Xorshift engine: seed explicitly cast to its result_type.
91 return BaseSA<ModelType, utility::Xorshift, UpdaterType>(
92 model, schedule, num_sweeps,
93 static_cast<utility::Xorshift::result_type>(seed), min_T, max_T,
94 log_history);
// 32-bit Mersenne Twister: seed explicitly cast to its result_type.
96 return BaseSA<ModelType, std::mt19937, UpdaterType>(
97 model, schedule, num_sweeps,
98 static_cast<std::mt19937::result_type>(seed), min_T, max_T,
99 log_history);
// 64-bit Mersenne Twister: seed converted implicitly to its result_type.
101 return BaseSA<ModelType, std::mt19937_64, UpdaterType>(
102 model, schedule, num_sweeps, seed, min_T, max_T, log_history);
103 default:
104 throw std::runtime_error("Unknown random number engine");
105 }
106}
107
// Runs a single integer simulated-annealing run with the requested update rule.
//
// Dispatches on `update_method` to SolveByIntegerSAImpl, instantiated with
// updater::MetropolisUpdater, HeatBathUpdater, SuwaTodoUpdater or
// OptMetropolisUpdater. All remaining arguments are forwarded unchanged.
//
// Parameters:
//   model         - model to anneal.
//   num_sweeps    - number of sweeps in the run.
//   update_method - per-variable update rule (selects the updater type).
//   rand_type     - random number engine (selected in SolveByIntegerSAImpl).
//   schedule      - temperature schedule from max_T down to min_T.
//   seed          - seed, converted to the engine's result_type downstream.
//   min_T, max_T  - final and initial temperatures.
//   log_history   - whether to record per-sweep energy and temperature.
// Returns the IntegerSAResult of the run.
// Throws std::runtime_error("Unknown update method") for any other
// update_method. Exceptions from the engine/schedule dispatch propagate.
108template <class ModelType>
109IntegerSAResult SolveByIntegerSA(const ModelType &model,
110 const std::int64_t num_sweeps,
111 const algorithm::UpdateMethod update_method,
112 const algorithm::RandomNumberEngine rand_type,
113 const utility::TemperatureSchedule schedule,
114 const std::int64_t seed, const double min_T,
115 const double max_T, const bool log_history) {
116
117 switch (update_method) {
// Metropolis update.
119 return SolveByIntegerSAImpl<ModelType, updater::MetropolisUpdater>(
120 model, num_sweeps, rand_type, schedule, seed, min_T, max_T, log_history);
// Heat-bath update.
122 return SolveByIntegerSAImpl<ModelType, updater::HeatBathUpdater>(
123 model, num_sweeps, rand_type, schedule, seed, min_T, max_T, log_history);
// Suwa-Todo update.
125 return SolveByIntegerSAImpl<ModelType, updater::SuwaTodoUpdater>(
126 model, num_sweeps, rand_type, schedule, seed, min_T, max_T, log_history);
// Metropolis update with optimal transition.
128 return SolveByIntegerSAImpl<ModelType, updater::OptMetropolisUpdater>(
129 model, num_sweeps, rand_type, schedule, seed, min_T, max_T, log_history);
130 default:
131 throw std::runtime_error("Unknown update method");
132 }
133}
134
135template <class ModelType>
136std::vector<IntegerSAResult>
137SampleByIntegerSA(const ModelType &model, const std::int64_t num_sweeps,
138 const algorithm::UpdateMethod update_method,
139 const algorithm::RandomNumberEngine rand_type,
140 const utility::TemperatureSchedule schedule,
141 const std::int64_t num_reads, const std::int64_t seed,
142 const std::int32_t num_threads, const double min_T,
143 const double max_T, const bool log_history) {
144
145 std::vector<IntegerSAResult> results(num_reads);
146
147#pragma omp parallel for schedule(guided) num_threads(num_threads)
148 for (std::int64_t i = 0; i < num_reads; ++i) {
149 results[i] =
150 SolveByIntegerSA(model, num_sweeps, update_method, rand_type, schedule,
151 seed + i, min_T, max_T, log_history);
152 }
153
154 return results;
155}
156
157} // namespace sampler
158} // namespace openjij
Definition sa_system.hpp:24
uint_fast32_t result_type
Definition random.hpp:41
UpdateMethod
Definition algorithm.hpp:63
@ OPT_METROPOLIS
Metropolis update with optimal transition.
@ METROPOLIS
Metropolis update.
RandomNumberEngine
Definition algorithm.hpp:78
@ MT_64
64-bit Mersenne Twister
IntegerSAResult SolveByIntegerSA(const ModelType &model, const std::int64_t num_sweeps, const algorithm::UpdateMethod update_method, const algorithm::RandomNumberEngine rand_type, const utility::TemperatureSchedule schedule, const std::int64_t seed, const double min_T, const double max_T, const bool log_history)
Definition integer_sa_sampler.hpp:109
IntegerSAResult BaseSA(const ModelType &model, const utility::TemperatureSchedule schedule, const std::int64_t num_sweeps, const typename RandType::result_type seed, const double min_T, const double max_T, const bool log_history)
Definition integer_sa_sampler.hpp:33
IntegerSAResult SolveByIntegerSAImpl(const ModelType &model, const std::int64_t num_sweeps, const algorithm::RandomNumberEngine rand_type, const utility::TemperatureSchedule schedule, const std::int64_t seed, const double min_T, const double max_T, const bool log_history)
Definition integer_sa_sampler.hpp:83
std::vector< IntegerSAResult > SampleByIntegerSA(const ModelType &model, const std::int64_t num_sweeps, const algorithm::UpdateMethod update_method, const algorithm::RandomNumberEngine rand_type, const utility::TemperatureSchedule schedule, const std::int64_t num_reads, const std::int64_t seed, const std::int32_t num_threads, const double min_T, const double max_T, const bool log_history)
Definition integer_sa_sampler.hpp:137
TemperatureSchedule
Definition schedule_list.hpp:259
Definition algorithm.hpp:24
IntegerSAResult
Definition integer_sa_sampler.hpp:24
std::vector< std::int64_t > solution
Definition integer_sa_sampler.hpp:26
std::vector< double > temperature_history
Definition integer_sa_sampler.hpp:28
double energy
Definition integer_sa_sampler.hpp:25
std::vector< double > energy_history
Definition integer_sa_sampler.hpp:27