Qrious/Qrious/BERT.swift

62 lines
2.4 KiB
Swift
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/*
See LICENSE folder for this sample's licensing information.
Abstract:
Wrapper class for the BERT model that handles its input and output.
*/
import CoreML
class BERT {
    /// The underlying Core ML model.
    ///
    /// This is a computed property, so every access constructs a new
    /// `BERTQAFP16` with a default configuration. A failure to load the model
    /// is treated as unrecoverable and terminates the process via `fatalError`.
    var bertModel: BERTQAFP16 {
        do {
            return try BERTQAFP16(configuration: .init())
        } catch {
            fatalError("Couldn't load BERT model due to: \(error.localizedDescription)")
        }
    }

    /// Finds an answer to a given question by searching a document's text.
    ///
    /// Every failure is reported through the return value as a human-readable
    /// message; this method does not throw.
    ///
    /// - parameters:
    ///     - question: The user's question about a document.
    ///     - document: The document text that (should) contain an answer.
    /// - returns: The portion of the document's original text selected as the
    ///   answer, or an error description.
    /// - Tag: FindAnswerForQuestionInDocument
    func findAnswer(for question: String, in document: String) -> Substring {
        // Prepare the input for the BERT model.
        let bertInput = BERTInput(documentString: document, questionString: question)

        guard bertInput.totalTokenSize <= BERTInput.maxTokens else {
            var message = "Text and question are too long"
            message += " (\(bertInput.totalTokenSize) tokens)"
            message += " for the BERT model's \(BERTInput.maxTokens) token limit."
            return Substring(message)
        }

        // The MLFeatureProvider that supplies the BERT model with its input MLMultiArrays.
        // Bind it instead of force-unwrapping so a missing input is reported
        // to the caller like every other failure rather than crashing.
        guard let modelInput = bertInput.modelInput else {
            return "Couldn't prepare the input for the BERT model."
        }

        // Make a prediction with the BERT model.
        guard let prediction = try? bertModel.prediction(input: modelInput) else {
            return "The BERT model is unable to make a prediction."
        }

        // Analyze the output from the BERT model.
        let noAnswerMessage: Substring = "Couldn't find a valid answer. Please try again."
        guard let bestLogitIndices = bestLogitsIndices(from: prediction,
                                                       in: bertInput.documentRange) else {
            return noAnswerMessage
        }

        // Find the indices of the original string.
        // Subscripting with an out-of-range token index would trap, so validate first.
        let documentTokens = bertInput.document.tokens
        guard documentTokens.indices.contains(bestLogitIndices.start),
              documentTokens.indices.contains(bestLogitIndices.end) else {
            return noAnswerMessage
        }
        let answerStart = documentTokens[bestLogitIndices.start].startIndex
        let answerEnd = documentTokens[bestLogitIndices.end].endIndex

        // Forming `lower..<upper` with lower > upper is a runtime trap.
        guard answerStart <= answerEnd else {
            return noAnswerMessage
        }

        // Return the portion of the original string as the answer.
        let originalText = bertInput.document.original
        return originalText[answerStart..<answerEnd]
    }
}