# This code is a Qiskit project.
#
# (C) Copyright IBM 2024.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""Aer MPS state and basic operations that do not require an MPS simulator."""
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
from plum import dispatch
from qiskit.circuit import Gate
from ..abstract import TensorNetworkState
@dataclass
class QiskitAerMPS(TensorNetworkState):
    """Qiskit Aer representation of a matrix-product state.

    This form of a matrix-product state was introduced in the original
    time-evolving block decimation (TEBD) paper,
    https://arxiv.org/abs/quant-ph/0301063.

    See Sec. 7.3.2 of https://arxiv.org/abs/1008.3477v2 for more details on the relationship
    with other MPS representations.

    A state on N sites carries N gamma entries and N - 1 lambda entries; this
    instance does not check that itself.
    """

    #: Gamma matrices (list of 2-tuples of ndarrays, one for each site).
    gamma: list[tuple[np.ndarray, np.ndarray]]

    #: Lambda matrices (list of ndarrays, one for each bond between adjacent qubits).
    lamb: list[np.ndarray]

    def _as_tuple(self) -> tuple[list[tuple[np.ndarray, np.ndarray]], list[np.ndarray]]:
        """Return the state as the plain ``(gamma, lamb)`` pair.

        This form is useful when passing the state to Qiskit Aer. The stored
        lists are returned as-is, not copied.
        """
        return (self.gamma, self.lamb)
def _validate_mps(mps: QiskitAerMPS, /) -> None:
    if len(mps.gamma) != len(mps.lamb) + 1:
        raise RuntimeError("The lambda matrix array has an unexpected length.")
def _validate_mps_compatibility(mps1: QiskitAerMPS, mps2: QiskitAerMPS, /) -> None:
    """Ensure that two matrix-product states can be contracted together.

    The site counts are compared first. Afterwards each state is checked on its
    own with ``_validate_mps``.

    Raises:
        ValueError: if the states are defined on different numbers of qubits.
        RuntimeError: propagated from ``_validate_mps`` when a state's lambda
            list has an unexpected length.
    """
    sites1, sites2 = len(mps1.gamma), len(mps2.gamma)
    if sites1 != sites2:
        raise ValueError(
            "The matrix-product states have different numbers of qubits "
            f"({sites1} vs. {sites2})."
        )
    for state in (mps1, mps2):
        _validate_mps(state)
@dispatch
def compute_overlap(mps1: QiskitAerMPS, mps2: QiskitAerMPS, /) -> complex:
    """Contract two Aer matrix-product states into their overlap.

    ``mps1`` is complex-conjugated, so the result is the inner product
    ``<mps1|mps2>``. The contraction sweeps left to right and keeps a running
    "environment" matrix. Its rows follow the bond of ``mps1`` and its columns
    follow the bond of ``mps2``.

    Raises:
        ValueError: if the states have different numbers of qubits.
        RuntimeError: if either state's lambda list has an unexpected length.
    """
    _validate_mps_compatibility(mps1, mps2)

    # Stacking a site's 2-tuple gives an array the contractions below index as
    # (physical, left bond, right bond). The first site's left bond must be a
    # singleton; squeezing it turns the first contraction into a matrix.
    bra_head = np.squeeze(mps1.gamma[0], axis=1)
    ket_head = np.squeeze(mps2.gamma[0], axis=1)
    env = np.tensordot(np.conj(bra_head), ket_head, axes=([0], [0]))

    # After validation len(lamb) == len(gamma) - 1 for both states. Zipping the
    # bond weights with the remaining sites therefore visits every bond once.
    remaining = zip(mps1.lamb, mps2.lamb, mps1.gamma[1:], mps2.gamma[1:])
    for bra_lam, ket_lam, bra_site, ket_site in remaining:
        # Scale rows by Diag(lambda) of mps1 and columns by Diag(lambda) of mps2.
        env *= np.expand_dims(bra_lam, axis=1)
        env *= np.expand_dims(ket_lam, axis=0)
        # Absorb the next (conjugated) bra tensor. Then contract the ket tensor
        # over both its left bond and its physical index.
        env = np.tensordot(env, np.conj(bra_site), axes=([0], [1]))
        env = np.tensordot(env, ket_site, axes=([0, 1], [1, 0]))

    # At the right edge the environment has a single element.
    return complex(env.item())
@dispatch
def _compute_overlap_with_local_gate_applied(
    mps1: QiskitAerMPS, gate: Gate, qubit: int, mps2: QiskitAerMPS, /
) -> complex:
    """Compute ``<mps1| G |mps2>`` where the single-qubit gate ``G`` acts on ``qubit``.

    The gate matrix is applied to the ``mps2`` (ket) tensor of the chosen site.
    Apart from that, the contraction is the same left-to-right sweep as the
    plain overlap, with ``mps1`` complex-conjugated.

    Raises:
        ValueError: if the states have different numbers of qubits, or if
            ``gate`` does not act on exactly one qubit.
        RuntimeError: if either state's lambda list has an unexpected length.
        IndexError: if ``qubit`` is not in ``range(num_qubits)``.
    """
    num_qubits = len(mps1.gamma)
    _validate_mps_compatibility(mps1, mps2)
    if qubit not in range(num_qubits):
        raise IndexError(f"Invalid qubit index for {num_qubits} qubits: {qubit}")
    if gate.num_qubits != 1:
        raise ValueError("The gate must act on a single qubit.")
    # NOTE(review): dispatch only guarantees that ``gate`` is a ``Gate``.
    # ``Gate.to_matrix()`` can still raise for gates without a matrix form, and
    # any such error propagates to the caller unchanged.
    gate_matrix = gate.to_matrix()

    # The first site's singleton left bond is squeezed away, so the first
    # contraction yields a matrix. Apply the gate here when it targets site 0.
    bra_head = np.squeeze(mps1.gamma[0], axis=1)
    ket_head = np.squeeze(mps2.gamma[0], axis=1)
    if qubit == 0:
        ket_head = gate_matrix @ ket_head
    env = np.tensordot(np.conj(bra_head), ket_head, axes=([0], [0]))

    # Validation guarantees one lambda per remaining site, so this pairs each
    # bond with the site to its right. Sites are numbered from 1.
    remaining = zip(mps1.lamb, mps2.lamb, mps1.gamma[1:], mps2.gamma[1:])
    for site, (bra_lam, ket_lam, bra_site, ket_site) in enumerate(remaining, start=1):
        # Scale rows by Diag(lambda) of mps1 and columns by Diag(lambda) of mps2.
        env *= np.expand_dims(bra_lam, axis=1)
        env *= np.expand_dims(ket_lam, axis=0)
        env = np.tensordot(env, np.conj(bra_site), axes=([0], [1]))
        if site == qubit:
            # Act with the gate on the ket tensor's physical index.
            ket_site = np.tensordot(gate_matrix, ket_site, axes=([1], [0]))
        env = np.tensordot(env, ket_site, axes=([0, 1], [1, 0]))

    return complex(env.item())