Skip to content

Commit

Permalink
readability
Browse files Browse the repository at this point in the history
  • Loading branch information
caleb-johnson committed Sep 25, 2024
1 parent 986b72c commit bd55063
Showing 1 changed file with 30 additions and 30 deletions.
60 changes: 30 additions & 30 deletions qiskit_addon_dice_solver/dice_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,37 +126,37 @@ def solve_fermion(

# Set up the temp directory
temp_dir = temp_dir or tempfile.gettempdir()
intermediate_dir = Path(tempfile.mkdtemp(prefix="dice_cli_files_", dir=temp_dir))
dice_dir = Path(tempfile.mkdtemp(prefix="dice_cli_files_", dir=temp_dir))

# Write the integrals out as an FCI dump for Dice command line app
active_space_path = os.path.join(intermediate_dir, "fcidump.txt")
active_space_path = os.path.join(dice_dir, "fcidump.txt")
num_orbitals = hcore.shape[0]
tools.fcidump.from_integrals(
active_space_path, hcore, eri, num_orbitals, (num_up + num_dn)
)

_write_input_files(
ci_strs,
active_space_path,
num_up,
num_dn,
num_configurations,
intermediate_dir,
spin_sq,
1,
addresses=ci_strs,
active_space_path=active_space_path,
num_up=num_up,
num_dn=num_dn,
num_configurations=num_configurations,
dice_dir=dice_dir,
spin_sq=spin_sq,
max_iter=1,
)

# Navigate to temp dir and call Dice
_call_dice(intermediate_dir, mpirun_options)
# Navigate to dice dir and call Dice
_call_dice(dice_dir, mpirun_options)

# Read outputs and convert outputs
e_dice, sci_coefficients, avg_occupancies = _read_dice_outputs(
intermediate_dir, num_orbitals
dice_dir, num_orbitals
)

# Clean up the temp directory of intermediate files, if desired
if clean_temp_dir:
shutil.rmtree(intermediate_dir)
shutil.rmtree(dice_dir)

return (
e_dice,
Expand Down Expand Up @@ -268,11 +268,11 @@ def solve_dice(


def _read_dice_outputs(
temp_dir: str | Path, num_orbitals: int
dice_dir: str | Path, num_orbitals: int
) -> tuple[float, np.ndarray, np.ndarray]:
"""Calculate the estimated ground state energy and average orbitals occupancies from Dice outputs."""
# Read in the avg orbital occupancies
spin1_rdm_dice = np.loadtxt(os.path.join(temp_dir, "spin1RDM.0.0.txt"), skiprows=1)
spin1_rdm_dice = np.loadtxt(os.path.join(dice_dir, "spin1RDM.0.0.txt"), skiprows=1)
avg_occupancies = np.zeros(2 * num_orbitals)
for i in range(spin1_rdm_dice.shape[0]):
if spin1_rdm_dice[i, 0] == spin1_rdm_dice[i, 1]:
Expand All @@ -283,23 +283,23 @@ def _read_dice_outputs(
)

# Read in the estimated ground state energy
file_energy = open(os.path.join(temp_dir, "shci.e"), "rb")
file_energy = open(os.path.join(dice_dir, "shci.e"), "rb")
bytestring_energy = file_energy.read(8)
energy_dice = struct.unpack("d", bytestring_energy)[0]

# Construct the SCI wavefunction coefficients from Dice output dets.bin
occs, amps = _read_wave_function_magnitudes(os.path.join(temp_dir, "dets.bin"))
occs, amps = _read_wave_function_magnitudes(os.path.join(dice_dir, "dets.bin"))
addresses = _addresses_from_occupancies(occs)
sci_coefficients = _construct_ci_vec_from_addresses_amplitudes(amps, addresses)

return energy_dice, sci_coefficients, avg_occupancies


def _call_dice(temp_dir: Path, mpirun_options: Sequence[str] | str | None) -> None:
"""Navigate to the temp dir, invoke Dice, and navigate back."""
def _call_dice(dice_dir: Path, mpirun_options: Sequence[str] | str | None) -> None:
"""Navigate to the dice dir, invoke Dice, and navigate back."""
script_dir = os.path.dirname(os.path.abspath(__file__))
dice_path = os.path.join(script_dir, "bin", "Dice")
dice_log_path = os.path.join(temp_dir, "dice_solver_logfile.log")
dice_log_path = os.path.join(dice_dir, "dice_solver_logfile.log")
if mpirun_options:
if isinstance(mpirun_options, str):
mpirun_options = [mpirun_options]
Expand All @@ -310,7 +310,7 @@ def _call_dice(temp_dir: Path, mpirun_options: Sequence[str] | str | None) -> No
with open(dice_log_path, "w") as logfile:
try:
subprocess.run(
dice_call, cwd=temp_dir, stdout=logfile, stderr=logfile, check=True
dice_call, cwd=dice_dir, stdout=logfile, stderr=logfile, check=True
)
except subprocess.CalledProcessError as e:
raise DiceExecutionError(
Expand All @@ -326,13 +326,13 @@ def _write_input_files(
num_up: int,
num_dn: int,
num_configurations: int,
temp_dir: str | Path,
dice_dir: str | Path,
spin_sq: float,
max_iter: int,
) -> None:
"""Prepare the Dice inputs in the temp directory."""
### Move the FCI Dump to temp dir ###
shutil.copy(active_space_path, os.path.join(temp_dir, "fcidump.txt"))
"""Prepare the Dice inputs in the specified directory."""
### Move the FCI Dump to dice dir ###
shutil.copy(active_space_path, os.path.join(dice_dir, "fcidump.txt"))

### Write the input.dat ###
num_elec = num_up + num_dn
Expand Down Expand Up @@ -380,19 +380,19 @@ def _write_input_files(
nocc,
dummy_det,
]
file1 = open(os.path.join(temp_dir, "input.dat"), "w")
file1 = open(os.path.join(dice_dir, "input.dat"), "w")
file1.writelines(input_list)
file1.close()

### Write the determinants to temp dir ###
### Write the determinants to dice dir ###
up_addr, dn_addr = addresses
bytes_up = _address_list_to_bytes(up_addr)
bytes_dn = _address_list_to_bytes(dn_addr)
file1 = open(os.path.join(temp_dir, "AlphaDets.bin"), "wb") # type: ignore
file1 = open(os.path.join(dice_dir, "AlphaDets.bin"), "wb") # type: ignore
for bytestring in bytes_up:
file1.write(bytestring) # type: ignore
file1.close()
file1 = open(os.path.join(temp_dir, "BetaDets.bin"), "wb") # type: ignore
file1 = open(os.path.join(dice_dir, "BetaDets.bin"), "wb") # type: ignore
for bytestring in bytes_dn:
file1.write(bytestring) # type: ignore
file1.close()
Expand Down

0 comments on commit bd55063

Please sign in to comment.