mirror of https://github.com/llvm/torch-mlir
Get simple quantized model importing.
This is enough to import the program and get it through the compilation pipeline. It of course fails at the VerifyBackendContract pass since there is a lot missing, but the final IR for a simple quantized MLP is looking pretty decent already: [IR](https://gist.github.com/silvasean/f76bccd76e9b193d396cfb2f9a11f54d)

Main changes:
- Add support for importing torch quantized tensors, including the `torch.per_tensor_affine.create` op and `!torch.qint8` element type.
- Add support for importing `LinearPackedParamsBase` (basically a weight + optional bias, but it requires the `torch.linear_params.create` op + `!torch.LinearParams` type to model it). This was less painful than I expected, as it has the necessary methods to opaquely unpack itself. I factored things so it should be easy to extend to other custom classes like `ConvPackedParamsBase`.
- Add minimal boilerplate for importing `quantized::*` ops, with `quantized::linear` being a motivating example.
- Add an e2e test with a simple quantized MLP (courtesy of @phoenix-meadowlark).

This is somewhat of an abuse of `!numpy.ndarray` / `tensor`, as really the proper semantics of a `!torch.qint8` dtype on a Torch tensor is "check the quantizer object of the tensor for side data (scale/offset, possibly per-channel) that defines the full semantics of the tensor". We don't have any such notion of "side data" for `!numpy.ndarray` / `tensor`, let alone anything that would have the associated behavior of keying off the dtype to determine whether the side data is present. This will be fixed by a proper `!torch.tensor` type.
parent 0c89296075
commit d66e8fe1f8
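For orientation, here is a minimal Python sketch (not part of the commit itself) of the PyTorch-side objects the importer now has to handle; it assumes a standard PyTorch build with a quantized engine available:

```python
import torch

# A quantized linear layer of the kind the new importer code handles. Its
# weight lives behind the LinearPackedParamsBase custom class.
qlinear = torch.nn.quantized.Linear(5, 2, dtype=torch.qint8)

weight = qlinear.weight()                       # quantized weight tensor
print(weight.qscheme())                         # torch.per_tensor_affine
print(weight.q_scale(), weight.q_zero_point())  # the "side data"
print(weight.int_repr().dtype)                  # torch.int8 bulk representation
```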
@@ -19,6 +19,7 @@
 #include "npcomp-c/Types.h"
 
 #include "caffe2/core/scope_guard.h"
+#include "ATen/native/quantized/cpu/packed_params.h"
 
 using namespace torch_mlir;
 
@@ -101,10 +102,11 @@ public:
       : importBlock(importBlock), context(context), typeMapper(context),
         annotator(annotator) {}
 
-  MlirValue importIValue(c10::IValue value);
+  MlirValue importIValue(c10::IValue ivalue);
 
 private:
-  MlirValue rawImportIValue(c10::IValue value);
+  MlirValue rawImportIValue(c10::IValue ivalue);
+  MlirValue importTensor(c10::IValue ivalue);
   MlirValue importModule(torch::jit::Module jitModule);
   void importMethod(torch::jit::Function *function, MlirBlock classTypeBody,
                     const MethodAnnotation &methodAnnotation);
@@ -284,16 +286,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
     return mlirOperationGetResult(operation, 0);
   }
   if (ivalue.isTensor()) {
-    at::Tensor tensor = ivalue.toTensor().contiguous();
-    MlirAttribute denseElements = converTensorToMlirElementsAttr(tensor, loc);
-    MlirOperation constant = createMlirOperationAtEnd(
-        importBlock, "std.constant", loc, mlirAttributeGetType(denseElements),
-        toMlirNamedAttribute("value", denseElements));
-    MlirOperation ndarray = createMlirOperationAtEnd(
-        importBlock, "numpy.create_array_from_tensor", loc,
-        npcompNdArrayTypeGetUnranked(npcompAnyDtypeTypeGet(context)),
-        mlirOperationGetResult(constant, 0));
-    return mlirOperationGetResult(ndarray, 0);
+    return importTensor(ivalue);
   }
   if (ivalue.isModule()) {
     return importModule(ivalue.toModule());
@@ -313,11 +306,83 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
         importBlock, "basicpy.singleton", loc, npcompNoneTypeGet(context));
     return mlirOperationGetResult(operation, 0);
   }
+  if (ivalue.isCustomClass()) {
+    if (ivalue.type().get() ==
+        c10::getCustomClassType<c10::intrusive_ptr<LinearPackedParamsBase>>()
+            .get()) {
+      c10::intrusive_ptr<LinearPackedParamsBase> linearParams =
+          ivalue.toCustomClass<LinearPackedParamsBase>();
+      at::Tensor weight;
+      c10::optional<at::Tensor> bias;
+      std::tie(weight, bias) = linearParams->unpack();
+      MlirValue weightValue = importIValue(c10::IValue(weight));
+      c10::optional<MlirValue> biasValue = c10::nullopt;
+      if (bias.has_value()) {
+        biasValue = importIValue(c10::IValue(*bias));
+      }
+      MlirOperation operation = createMlirOperationAtEnd(
+          importBlock, "torch.linear_params.create", loc,
+          npcompLinearParamsTypeGet(context), weightValue, biasValue);
+      return mlirOperationGetResult(operation, 0);
+    }
+  }
   std::stringstream msg;
   msg << "Unsupported ivalue: " << ivalue;
   throw std::invalid_argument(msg.str());
 }
 
+MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
+  assert(ivalue.isTensor() && "expected a tensor!");
+
+  // TODO: Can we do better?
+  MlirLocation loc = mlirLocationUnknownGet(context);
+
+  // Import the bulk tensor representation.
+  at::Tensor tensor = ivalue.toTensor().contiguous();
+  MlirAttribute denseElements = converTensorToMlirElementsAttr(tensor, loc);
+  MlirOperation constant = createMlirOperationAtEnd(
+      importBlock, "std.constant", loc, mlirAttributeGetType(denseElements),
+      toMlirNamedAttribute("value", denseElements));
+  MlirValue tensorReprValue = mlirOperationGetResult(constant, 0);
+
+  // Construct the complete tensor value. This is trivial for most tensors, but
+  // for quantized tensors (and probably sparse too, TBD) there is more for us
+  // to do.
+  MlirValue tensorValue;
+  if (tensor.is_quantized()) {
+    // Note that Torch models quantization in a type-erased way. So we don't
+    // make an effort here to do any special static modeling. If desired, later
+    // compiler stages that are building a statically modeled quantization
+    // representation will need to convert this to their representation.
+    std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end());
+    MlirType quantizedTensorType = mlirRankedTensorTypeGetChecked(
+        loc, shape.size(), shape.data(),
+        typeMapper.mapFromTorchScalarType(tensor.scalar_type()), {nullptr});
+    if (tensor.qscheme() == c10::kPerTensorAffine) {
+      MlirValue qScale = importIValue(c10::IValue(tensor.q_scale()));
+      MlirValue zeroPoint = importIValue(c10::IValue(tensor.q_zero_point()));
+      MlirOperation quantizedTensor = createMlirOperationAtEnd(
+          importBlock, "torch.per_tensor_affine.create", loc,
+          quantizedTensorType, tensorReprValue, qScale, zeroPoint);
+      tensorValue = mlirOperationGetResult(quantizedTensor, 0);
+    } else {
+      std::stringstream msg;
+      msg << "Unsupported quantization scheme '"
+          << c10::toString(tensor.qscheme()) << "' for tensor: " << ivalue;
+      throw std::invalid_argument(msg.str());
+    }
+  } else {
+    tensorValue = tensorReprValue;
+  }
+
+  // Convert the tensor to ndarray to match Torch's default-mutable semantics.
+  MlirOperation ndarray = createMlirOperationAtEnd(
+      importBlock, "numpy.create_array_from_tensor", loc,
+      npcompNdArrayTypeGetUnranked(npcompAnyDtypeTypeGet(context)),
+      tensorValue);
+  return mlirOperationGetResult(ndarray, 0);
+}
+
 void IValueImporter::importMethod(torch::jit::Function *function,
                                   MlirBlock classTypeBody,
                                   const MethodAnnotation &methodAnnotation) {
@@ -67,12 +67,59 @@ MlirType TypeMapper::rawMapFromTorchScalarType(c10::ScalarType scalarType) {
     return mlirBF16TypeGet(context);
   case ScalarType::Half:
     return mlirF16TypeGet(context);
+  case ScalarType::QInt8:
+    return npcompQInt8TypeGet(context);
   default: {
     return {nullptr};
   }
   }
 }
 
+// Types (such as `LinearPackedParamsBase`) implemented with the
+// `torch::CustomClassHolder` mechanism described at
+// https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html
+// are modeled with ordinary c10::ClassType's, but require special handling
+// for importing.
+//
+// These types correspond to c10::IValue's with `isCustomClass() == true`.
+//
+// Under the hood, Torch represents such "custom classes" using the
+// "object" variant of c10::IValue and a class type with one slot holding a
+// type-erased c10::intrusive_ptr to the custom type. On the side, it keeps
+// a registry of custom classes which is used to implement `isCustomClass()`
+// by checking names against the registry.
+//
+// There is no generic way to import custom classes (or their types), so we
+// have to name match them here (and the relevant code in the ivalue
+// importer) and create special IR constructs for them.
+static MlirType mapCustomClassType(MlirContext context, MlirLocation loc,
+                                   const c10::ClassTypePtr &classType) {
+  // If the type is unnamed, it cannot be a custom class.
+  if (!classType->name().has_value()) {
+    return {nullptr};
+  }
+  std::string name = classType->name()->qualifiedName();
+  // If the type is not stored in the custom class registry, it cannot be a
+  // custom class.
+  if (!torch::getCustomClass(name)) {
+    return {nullptr};
+  }
+
+  // Individually handle the custom classes that we know about.
+  if (name == "__torch__.torch.classes.quantized.LinearPackedParamsBase") {
+    return npcompLinearParamsTypeGet(context);
+  }
+
+  // At this point, we know that the type is indeed a custom class type, but
+  // that we don't know how to specially import it. We cannot proceed, so emit a
+  // diagnostic and halt compilation.
+  std::stringstream message;
+  message << "unable to import Torch CustomClass type '" << classType
+          << "' to MLIR type";
+  mlirEmitError(loc, message.str().c_str());
+  throw mlir_diagnostic_emitted();
+}
+
 MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
                                       const c10::TypePtr &torchType) {
   using c10::TypeKind;
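For readers unfamiliar with the torchbind registry that `mapCustomClassType` consults, a Python-side sketch; it assumes a standard PyTorch build where the quantized classes are registered, and relies on PyTorch's internal `_packed_params` attribute layout:

```python
import torch

# Registered torchbind classes are visible under `torch.classes`; the
# qualified name below is exactly what mapCustomClassType string-matches.
cls = torch.classes.quantized.LinearPackedParamsBase
print(cls)  # wrapper for the registered C++ class

# An instance of it sits inside every quantized linear module.
qlinear = torch.nn.quantized.Linear(4, 2, dtype=torch.qint8)
packed = qlinear._packed_params._packed_params  # a torch.ScriptObject
```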
@@ -106,10 +153,14 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
     return npcompNdArrayTypeGetRanked(dims.size(), dims.data(), elementType);
   }
   case TypeKind::ClassType: {
-    auto maybeName = torchType->cast<c10::ClassType>()->name();
-    return npcompNnModuleTypeGet(
-        context, toMlirStringRef(maybeName ? maybeName->qualifiedName()
-                                           : "unnamed class"));
+    const c10::ClassTypePtr &classType = torchType->cast<c10::ClassType>();
+    MlirType customClassType = mapCustomClassType(context, loc, classType);
+    if (!mlirTypeIsNull(customClassType)) {
+      return customClassType;
+    }
+    auto maybeName = classType->name();
+    std::string name = maybeName ? maybeName->qualifiedName() : "unnamed class";
+    return npcompNnModuleTypeGet(context, toMlirStringRef(name));
   }
   case TypeKind::FloatType: {
     return mlirF64TypeGet(context);
@@ -220,6 +271,14 @@ MlirAttribute torch_mlir::converTensorToMlirElementsAttr(at::Tensor tensor,
     // The generalized (non-Tensor) conversion, assumes that Bool is the
     // Basicpy bool type.
     elementType = mlirIntegerTypeGet(context, 1);
+  } else if (tensor.scalar_type() == ScalarType::QInt8) {
+    // This function returns the underlying integer representation of the tensor
+    // as an elements attr. That underlying representation is of type i8
+    // for a torch.qint8 tensor.
+    // Caller code is responsible for materializing the proper op that
+    // incorporates the quantization scheme to create a tensor of `!torch.qint8`
+    // element type.
+    elementType = mlirIntegerTypeGet(context, 8);
   } else {
     elementType = typeMapper.mapFromTorchScalarType(tensor.scalar_type());
   }
@@ -259,6 +318,9 @@ MlirAttribute torch_mlir::converTensorToMlirElementsAttr(at::Tensor tensor,
     return mlirDenseElementsAttrBoolGet(shapedType, numElements,
                                         static_cast<const int *>(tensorData));
     break;
+  case ScalarType::QInt8:
+    return mlirDenseElementsAttrInt8Get(
+        shapedType, numElements, static_cast<const int8_t *>(tensorData));
   default:
     throwUnsupportedTensorError();
   }
@@ -25,6 +25,7 @@ from torch_mlir.torchscript.e2e_test.configs import (
 import basic
 import vision_models
 import mlp
+import quantized_models
 
 def main():
     parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
@@ -0,0 +1,58 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import torch
+from torch import nn
+
+from torch_mlir.torchscript.e2e_test.framework import TestUtils
+from torch_mlir.torchscript.e2e_test.registry import register_test_case
+from torch_mlir.torchscript.annotations import annotate_args, export
+
+# ==============================================================================
+
+
+class QuantizedMLP(nn.Module):
+    def __init__(self):
+        super().__init__()
+        torch.random.manual_seed(0)
+        self.layers = nn.Sequential(
+            nn.Linear(16, 8),
+            nn.Tanh(),
+            nn.Linear(8, 4),
+        )
+        self.quantize = torch.quantization.QuantStub()
+        self.dequantize = torch.quantization.DeQuantStub()
+
+    @export
+    @export
+    @annotate_args([
+        None,
+        ([1, 16], torch.float32),
+    ])
+    def forward(self, x):
+        x = self.quantize(x)
+        x = self.layers(x)
+        x = self.dequantize(x)
+        return x
+
+
+def get_mlp_input():
+    return 2 * torch.rand((1, 16)) - 1
+
+
+def get_quantized_mlp():
+    model = QuantizedMLP()
+    model.eval()
+    model.qconfig = torch.quantization.default_qconfig
+    torch.quantization.prepare(model, inplace=True)
+    torch.manual_seed(0)
+    for _ in range(32):
+        model(get_mlp_input())
+    torch.quantization.convert(model, inplace=True)
+    return model
+
+
+@register_test_case(module_factory=get_quantized_mlp)
+def QuantizedMLP_basic(module, tu: TestUtils):
+    module.forward(get_mlp_input())
@@ -213,6 +213,8 @@ SIGLIST_TYPE = List[Dict[str, Union[str, int, Dict[str, List[str]]]]]
 # - SIGLIST_TYPE (e.g. {'arguments': [...], 'returns': [...]} )
 OP_INFO_DICT = Dict[str, Union[bool, Tuple[str], SIGLIST_TYPE]]
 
+# Mapping from torch types to their corresponding ODS type predicates.
+# Use `get_ods_type` instead of using this directly.
 TORCH_TYPE_TO_ODS_TYPE = {
     "Tensor": "AnyTorchTensorType",
     "Tensor?": "AnyTorchOptionalTensor",
@@ -229,9 +231,18 @@ TORCH_TYPE_TO_ODS_TYPE = {
     "Any": "AnyTorchType",
     "Device": "Torch_DeviceType",
     "str": "Basicpy_BytesType",
+    "__torch__.torch.classes.quantized.LinearPackedParamsBase": "Torch_LinearParamsType",
 }
 
 
+def get_ods_type(type: str):
+    ods_type = TORCH_TYPE_TO_ODS_TYPE.get(type)
+    if ods_type is None:
+        raise Exception(
+            f"{type!r} not in TORCH_TYPE_TO_ODS_TYPE mapping. Please add it!")
+    return ods_type
+
+
 def _get_main_module_name() -> str:
     # pytype: disable=attribute-error
     return sys.modules["__main__"].__loader__.name
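The switch from direct dict indexing to `get_ods_type` (used in the hunks below) turns a bare `KeyError` into an actionable message. A standalone stand-in, with a trimmed mapping for illustration:

```python
# Trimmed stand-in for the mapping/helper added above.
TORCH_TYPE_TO_ODS_TYPE = {"Tensor": "AnyTorchTensorType"}

def get_ods_type(type: str):
    ods_type = TORCH_TYPE_TO_ODS_TYPE.get(type)
    if ods_type is None:
        raise Exception(
            f"{type!r} not in TORCH_TYPE_TO_ODS_TYPE mapping. Please add it!")
    return ods_type

print(get_ods_type("Tensor"))  # AnyTorchTensorType
# get_ods_type("bogus") now fails with a message naming the missing key,
# rather than the bare KeyError the old direct indexing produced.
```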
@@ -287,7 +298,7 @@ def raw_emit_op(operator: JitOperator, f: TextIO, *, traits: List[str],
         p("Variadic<AnyTorchType>:$operands")
     else:
         p(",\n".join([
-            f"""{TORCH_TYPE_TO_ODS_TYPE[arg["type"]]}:${arg["name"]}"""
+            f"""{get_ods_type(arg["type"])}:${arg["name"]}"""
             for arg in operator.arguments
         ]))
     p(");")
@@ -297,7 +308,7 @@ def raw_emit_op(operator: JitOperator, f: TextIO, *, traits: List[str],
         p("Variadic<AnyTorchType>:$results")
     else:
         p(",\n".join([
-            f"""{TORCH_TYPE_TO_ODS_TYPE[ret["type"]]}:${ret["name"] or "result"}"""
+            f"""{get_ods_type(ret["type"])}:${ret["name"] or "result"}"""
             for ret in operator.returns
         ]))
     p(");")
@@ -444,6 +455,19 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     emit("aten::len.t : (t[]) -> (int)", has_canonicalizer=True)
 
 
+def emit_quantized_ops(torch_ir_dir: str, registry: Registry):
+    td_file = os.path.join(torch_ir_dir, "GeneratedQuantizedOps.td")
+    with open(td_file, "w") as f:
+        f.write(ODS_BANNER)
+
+        def emit(key, **kwargs):
+            emit_op(registry[key], f, **kwargs)
+
+        emit(
+            "quantized::linear : (Tensor, __torch__.torch.classes.quantized.LinearPackedParamsBase, float, int) -> (Tensor)",
+            traits=["HasValueSemantics"])
+
+
 def dump_registered_ops(outfile: TextIO, registry: Registry):
     for _, v in sorted(registry.by_unique_key.items()):
         outfile.write(repr(v))
@@ -460,6 +484,7 @@ def main(args: argparse.Namespace):
         dump_registered_ops(debug_registry_dump, registry)
     emit_prim_ops(args.torch_ir_dir, registry)
     emit_aten_ops(args.torch_ir_dir, registry)
+    emit_quantized_ops(args.torch_ir_dir, registry)
 
 
 def _create_argparse() -> argparse.ArgumentParser:
@@ -0,0 +1,44 @@
+# -*- Python -*-
+# This file is licensed under a pytorch-style license
+# See frontends/pytorch/LICENSE for license information.
+
+import typing
+
+import torch
+import torch_mlir
+
+# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
+
+mb = torch_mlir.ModuleBuilder()
+
+class TestModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.quantized.Linear(5, 2, dtype=torch.qint8)
+        self.linear_no_bias = torch.nn.quantized.Linear(6,
+                                                        2,
+                                                        bias_=False,
+                                                        dtype=torch.qint8)
+
+    # CHECK-DAG: %[[SCALE:.*]] = basicpy.numeric_constant {{.*}} : f64
+    # CHECK-DAG: %[[ZERO_POINT:.*]] = basicpy.numeric_constant 0 : i64
+    # CHECK-DAG: %[[INT_REPR:.*]] = constant dense<{{.*}}> : tensor<2x5xi8>
+    # CHECK-DAG: %[[WEIGHTS:.*]] = torch.per_tensor_affine.create %[[INT_REPR]], %[[SCALE]], %[[ZERO_POINT]] : tensor<2x5xi8>, f64, i64 -> tensor<2x5x!torch.qint8>
+    # CHECK-DAG: %[[WEIGHTS_ARRAY:.*]] = numpy.create_array_from_tensor %[[WEIGHTS]] : (tensor<2x5x!torch.qint8>) -> !numpy.ndarray<*:!numpy.any_dtype>
+    # CHECK-DAG: %[[BIAS:.*]] = constant dense<{{.*}}> : tensor<2xf32>
+    # CHECK-DAG: %[[BIAS_ARRAY:.*]] = numpy.create_array_from_tensor %[[BIAS]] : (tensor<2xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
+    # CHECK-DAG: %[[LINEAR_PARAMS:.*]] = torch.linear_params.create %[[WEIGHTS_ARRAY]], %[[BIAS_ARRAY]] : !numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>
+    @torch.jit.export
+    def test_linear(self, t):
+        return self.linear(t)
+
+    # CHECK: %[[LINEAR_PARAMS_NO_BIAS:.*]] = torch.linear_params.create %{{.*}} : !numpy.ndarray<*:!numpy.any_dtype>{{$}}
+    @torch.jit.export
+    def test_linear_no_bias(self, t):
+        return self.linear_no_bias(t)
+
+
+test_module = TestModule()
+recursivescriptmodule = torch.jit.script(test_module)
+# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
+mb.import_module(recursivescriptmodule._c)
+mb.module.operation.print()
@@ -147,6 +147,26 @@ int npcompTypeIsADevice(MlirType t);
 /** Gets the !torch.Device type. */
 MlirType npcompDeviceTypeGet(MlirContext context);
 
+/*============================================================================*/
+/* torch.LinearParams type. */
+/*============================================================================*/
+
+/** Checks whether the given type is a !torch.LinearParams type */
+int npcompTypeIsALinearParams(MlirType t);
+
+/** Gets the !torch.LinearParams type. */
+MlirType npcompLinearParamsTypeGet(MlirContext context);
+
+/*============================================================================*/
+/* torch.qint8 type. */
+/*============================================================================*/
+
+/** Checks whether the given type is a !torch.qint8 type */
+int npcompTypeIsAQInt8(MlirType t);
+
+/** Gets the !torch.qint8 type. */
+MlirType npcompQInt8TypeGet(MlirContext context);
+
 #ifdef __cplusplus
 }
 #endif
@@ -0,0 +1,35 @@
+//===-------------------------------------------------------*- tablegen -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// Operation summaries and descriptions were systematically derived from public
+// API docstrings and are licensed accordingly:
+//   https://github.com/pytorch/pytorch/blob/master/LICENSE
+//===----------------------------------------------------------------------===//
+//
+// This file is automatically generated. Please do not edit.
+// Generated via:
+//   python -m torch_mlir_utils.codegen.torch_ods_gen
+//
+//===----------------------------------------------------------------------===//
+
+def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
+    HasValueSemantics,
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `quantized::linear : (Tensor, __torch__.torch.classes.quantized.LinearPackedParamsBase, float, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$X,
+    Torch_LinearParamsType:$W_prepack,
+    AnyFloat:$Y_scale_i,
+    AnyTorchIntType:$Y_zero_point_i
+  );
+  let results = (outs
+    AnyTorchTensorType:$Y
+  );
+  let assemblyFormat = "$X `,` $W_prepack `,` $Y_scale_i `,` $Y_zero_point_i attr-dict `:` type($X) `,` type($W_prepack) `,` type($Y_scale_i) `,` type($Y_zero_point_i) `->` type($Y)";
+}
+
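As a rough illustration of the kernel the generated op mirrors, a sketch of invoking `quantized::linear` from Python; it assumes the fbgemm/qnnpack quantized engine and PyTorch's internal `_packed_params` attribute layout:

```python
import torch

# quantized::linear : (Tensor X, LinearPackedParamsBase W_prepack,
#                      float Y_scale_i, int Y_zero_point_i) -> (Tensor Y)
qlinear = torch.nn.quantized.Linear(16, 8, dtype=torch.qint8)
x = torch.quantize_per_tensor(torch.rand(1, 16), scale=0.05, zero_point=64,
                              dtype=torch.quint8)
y = torch.ops.quantized.linear(x, qlinear._packed_params._packed_params,
                               0.1, 0)  # output scale / zero point
print(y.dtype, y.shape)  # torch.quint8 torch.Size([1, 8])
```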
@@ -22,6 +22,7 @@ class Torch_Op<string mnemonic, list<OpTrait> traits = []>
 
 include "npcomp/Dialect/Torch/IR/GeneratedAtenOps.td"
 include "npcomp/Dialect/Torch/IR/GeneratedPrimOps.td"
+include "npcomp/Dialect/Torch/IR/GeneratedQuantizedOps.td"
 
 //===----------------------------------------------------------------------===//
 // TorchScript `torch.nn.Module` object instantiation ops.
@@ -439,7 +440,9 @@ def Torch_DerefineOp : Torch_Op<"derefine", [
   let hasCanonicalizer = 1;
 }
 
-def Torch_OperatorOp : Torch_Op<"operator", []> {
+def Torch_OperatorOp : Torch_Op<"operator", [
+    AllowsTypeRefinement
+  ]> {
   let summary = "Opaque torch operator";
   let description = [{
     Represents an invocation of a `torch::jit::Operator` for which we don't
@@ -458,4 +461,48 @@
   }];
 }
 
+def Torch_LinearParamsCreateOp : Torch_Op<"linear_params.create", [
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Create a `!torch.LinearParams`";
+  let arguments = (ins
+    AnyTorchTensorType:$weight,
+    Optional<AnyTorchTensorType>:$bias
+  );
+  let results = (outs Torch_LinearParamsType:$result);
+
+  let assemblyFormat = [{
+    $weight (`,` $bias^)? attr-dict `:` type($weight) (`,` type($bias)^)?
+  }];
+}
+
+def Torch_PerTensorAffineCreateOp : Torch_Op<"per_tensor_affine.create", [
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Create a per-tensor-affine quantized tensor";
+  let description = [{
+    Create a quantized tensor.
+
+    Quantization formula is:
+    ```
+    Q(x, scale, zero_point) = round(x/scale + zero_point)
+    ```
+
+    See:
+    https://pytorch.org/docs/stable/quantization.html#quantized-tensors
+  }];
+  let arguments = (ins
+    AnyTorchTensorType:$int_repr,
+    AnyFloat:$scale,
+    AnyTorchIntType:$offset
+  );
+  // TODO: Limit to quantized dtypes (e.g. !torch.qint8).
+  let results = (outs AnyTorchTensorType:$result);
+
+  let assemblyFormat = [{
+    $int_repr `,` $scale `,` $offset attr-dict
+    `:` type($int_repr) `,` type($scale) `,` type($offset) `->` type($result)
+  }];
+}
+
 #endif // TORCH_OPS
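A numeric sanity check of the formula in the op description above, via PyTorch's eager-mode quantization (values chosen so the rounding is exact):

```python
import torch

# Q(x, scale, zero_point) = round(x/scale + zero_point)
x = torch.tensor([-1.0, 0.0, 0.5, 1.0])
scale, zero_point = 0.1, 10
q = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8)
print(q.int_repr())    # tensor([ 0, 10, 15, 20], dtype=torch.int8)
# Dequantization inverts it: (int_repr - zero_point) * scale
print(q.dequantize())  # tensor([-1.0000,  0.0000,  0.5000,  1.0000])
```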
@@ -83,6 +83,37 @@ def Torch_DeviceType : Torch_Type<"Device", "Device"> {
   let summary = "Torch device";
 }
 
+def Torch_QInt8Type : Torch_Type<"QInt8", "qint8"> {
+  let summary = "Type modeling `ScalarType::QInt8`";
+  let description = [{
+    This is intended to be a 1:1 match for the Torch `ScalarType` types.
+
+    Looking at the variety / ad-hocness (e.g. `QUInt4x2`) of that set of
+    types, it is deemed preferable to import them as one-off ad-hoc types
+    instead of a single parameterized type.
+  }];
+}
+
+def Torch_LinearParamsType : Torch_Type<"LinearParams", "LinearParams"> {
+  let summary = "Torch packed linear params type";
+  let description = [{
+    A weight and optional bias, packed into one value.
+
+    This is used to model the
+    `__torch__.torch.classes.quantized.LinearPackedParamsBase` custom C++ class
+    type which is the input to some Torch `quantized::` ops.
+
+    We may want to eventually have a full set of ops that model the
+    LinearPackedParamsBase interface, such as `apply`, `apply_relu`, etc.
+    But we instead are likely to just expand the `quantized::` ops directly
+    and fold away the instances of this type.
+    The whole LinearPackedParamsBase abstraction as it stands in PyTorch is a
+    very library-call-y, runtime-y thing that embodies a number of assumptions
+    about the structure of how the program will be executed, which need not hold
+    for npcomp backends.
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Type predicates
 //===----------------------------------------------------------------------===//
@@ -189,6 +220,7 @@ def AnyTorchType : AnyTypeOf<[
   Torch_NnModuleType,
   Torch_OptionalType,
   Torch_DeviceType,
+  Torch_LinearParamsType,
 ], "Any type that is legal to pass to a Torch kernel">;
 
 #endif // TORCH_TYPES
@@ -188,3 +188,31 @@ int npcompTypeIsADevice(MlirType t) {
 MlirType npcompDeviceTypeGet(MlirContext context) {
   return wrap(Torch::DeviceType::get(unwrap(context)));
 }
+
+/*============================================================================*/
+/* torch.LinearParams type. */
+/*============================================================================*/
+
+/** Checks whether the given type is a !torch.LinearParams type */
+int npcompTypeIsALinearParams(MlirType t) {
+  return unwrap(t).isa<Torch::LinearParamsType>();
+}
+
+/** Gets the !torch.LinearParams type. */
+MlirType npcompLinearParamsTypeGet(MlirContext context) {
+  return wrap(Torch::LinearParamsType::get(unwrap(context)));
+}
+
+/*============================================================================*/
+/* torch.qint8 type. */
+/*============================================================================*/
+
+/** Checks whether the given type is a !torch.qint8 type */
+int npcompTypeIsAQInt8(MlirType t) {
+  return unwrap(t).isa<Torch::QInt8Type>();
+}
+
+/** Gets the !torch.qint8 type. */
+MlirType npcompQInt8TypeGet(MlirContext context) {
+  return wrap(Torch::QInt8Type::get(unwrap(context)));
+}
@@ -7,6 +7,12 @@ func @torch.operator(%arg0: !numpy.ndarray<*:!numpy.any_dtype>, %arg1: !numpy.nd
   return %0 : !numpy.ndarray<*:!numpy.any_dtype>
 }
 
+func @torch.linear_params.create(%arg0: !numpy.ndarray<*:!numpy.any_dtype>, %arg1: !numpy.ndarray<*:!numpy.any_dtype>) -> (!torch.LinearParams, !torch.LinearParams) {
+  %with_bias = torch.linear_params.create %arg0, %arg1 : !numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>
+  %without_bias = torch.linear_params.create %arg0 : !numpy.ndarray<*:!numpy.any_dtype>
+  return %with_bias, %without_bias : !torch.LinearParams, !torch.LinearParams
+}
+
 func @derefine(%arg0: tensor<f32>) -> !torch.optional<tensor<f32>> {
   %0 = torch.derefine %arg0 : tensor<f32> to !torch.optional<tensor<f32>>
   return %0 : !torch.optional<tensor<f32>>