Azure AI Searchの垂直統合されたベクトル検索用LangChain Retrieverを作る

はじめに

前回の記事

tech.mti.co.jp

前回、Rest APIで垂直統合されたベクトル検索を試行しました。この機能はすでに Python の Azure SDK では対応していますが、LangChain の Retriever として利用できるとより便利そうです。

Custom Retriever の学習がてら実験してみたいと思います。

OSSライブラリの実装を確認する

langchain-community ライブラリで AzureAISearchRetriever が公開されています。

python.langchain.com

実装を一部引用します。

class AzureAISearchRetriever(BaseRetriever):
~略~
    def _search(self, query: str) -> List[dict]:
        search_url = self._build_search_url(query)
        response = requests.get(search_url, headers=self._headers)
        if response.status_code != 200:
            raise Exception(f"Error in search request: {response}")
~略~
    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        search_results = self._search(query)

        return [
            Document(page_content=result.pop(self.content_key), metadata=result)
            for result in search_results
        ]

引用終わり

なるほど、Rest APIのURLを作成して、パラメータのテキスト検索クエリと合成して実行しているシンプルな作りのようです。

この実装を、垂直統合されたベクトル検索で利用するJSON BodyをPOSTするように変更するだけでよさそうです。

実装

Custom Retriever を1から実装してもそこまで難しくなさそうですが、今回は AzureAISearchRetriever を継承して実験してみます。

import json
from typing import Any, Dict, List, Optional

import aiohttp
import requests

from langchain_community.retrievers import AzureAISearchRetriever
from langchain_core.utils import get_from_env

DEFAULT_URL_SUFFIX = "search.windows.net"

class AzureAISearchVectorizedTextQueryRetriever(AzureAISearchRetriever):
    count: bool = True
    exhaustive: bool = True
    vector_key: str = "vector"
    api_version: str = "2024-07-01"

    def _build_search_url(self) -> str:
        url_suffix = get_from_env("", "AZURE_AI_SEARCH_URL_SUFFIX", DEFAULT_URL_SUFFIX)
        if url_suffix in self.service_name and "https://" in self.service_name:
            base_url = f"{self.service_name}/"
        elif url_suffix in self.service_name and "https://" not in self.service_name:
            base_url = f"https://{self.service_name}/"
        elif url_suffix not in self.service_name and "https://" in self.service_name:
            base_url = f"{self.service_name}.{url_suffix}/"
        elif (
            url_suffix not in self.service_name and "https://" not in self.service_name
        ):
            base_url = f"https://{self.service_name}.{url_suffix}/"
        else:
            # pass to Azure to throw a specific error
            base_url = self.service_name
        endpoint_path = f"indexes('{self.index_name}')/docs/search.post.search?api-version={self.api_version}"

        return base_url + endpoint_path 

    def _build_search_body(self, query: str) -> Dict[str, Any]:
        return {
            "count": self.count,
            "select": 'id,'+self.content_key,
            "vectorQueries": [{
                "kind": "text",
                "text": query,
                "fields": self.vector_key,
                "k": self.top_k,
                "exhaustive": self.exhaustive
            }]
        }
    
    def _search(self, query: str) -> List[dict]:
        search_url = self._build_search_url()
        response = requests.post(search_url, headers=self._headers, json=self._build_search_body(query))
        if response.status_code != 200:
            raise Exception(f"Error in search request: {response}")

        return json.loads(response.text)["value"]
    
    async def _asearch(self, query: str) -> List[dict]:
        search_url = self._build_search_url()
        if not self.aiosession:
            async with aiohttp.ClientSession() as session:
                async with session.post(search_url, headers=self._headers, json=self._build_search_body(query)) as response:
                    response_json = await response.json()
        else:
            async with self.aiosession.post(
                search_url, headers=self._headers, json=self._build_search_body(query)
            ) as response:
                response_json = await response.json()

        return response_json["value"]

API Versionの初期値を設定して、Azure AI Search のドキュメントに記載された単純例の JSONを送信するようにオーバーライドしているだけですね。ほか、URLを作成している箇所など、元のクラスのほとんどコピペです。 このままライブラリとして配布できるような実装ではありませんが、今回は実験ですのでこれで許容します。

アプリケーションから利用してみる

前回作った「坊っちゃん」のインデックスでRAGをチェインします。

LLMは別途Open AIリソースにデプロイ済みの gpt-4o-mini を利用しています。

from typing import List
from langchain_core.documents import Document
from langchain_openai import AzureChatOpenAI
from langchain_core.runnables import RunnablePassthrough
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

from custom_azure_ai_search import AzureAISearchVectorizedTextQueryRetriever

llm = AzureChatOpenAI(
    azure_deployment="gpt-4o-mini",
    azure_endpoint="https://**********.openai.azure.com/",
    api_version="2024-06-01"
)

retriever = AzureAISearchVectorizedTextQueryRetriever(
    service_name="**********",
    index_name="techblog-index",
    top_k=5,
    content_key="chunk"
)

prompt_template = ChatPromptTemplate.from_template(
"""
次の文章をもとに回答してください。

文章: {chunks}

質問: {question}
"""
)

def format_docs(docs: List[Document]) -> str:
    return "\n\n".join(doc.page_content for doc in docs)

chain = (
    {"chunks": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt_template
    | llm
    | StrOutputParser()
)
result = chain.invoke("主人公は何を4つも食べたのですか?それがわかるシーンの説明とともに教えてください。")

print(result)

主人公は「天麩羅」を4つ食べました。このことがわかるシーンは、彼が教場に入ったときのことです。黒板には「天麩羅先生」と書かれており、生徒たちが笑っています。主人公は「天麩羅を食っちゃ可笑しいか」と尋ねますが、生徒の一人が「四杯は過ぎるぞな」と言います。これに対して、主人公は「四杯食おうが五杯食おうがおれの銭でおれが食うのに文句があるもんか」と反論し、講義を終えて控所に戻ります。その後、次の教場に出ると黒板に「天麩羅四杯なり。但し笑うべからず」と書かれており、主人公はそれを見て癪に障ったと感じます。このように、主人公が天麩羅を4つ食べたことが明示されています。

うまく結果が出ました。