Get simple quantized model importing.

This is enough to import the program and get it through the compilation
pipeline. It fails, of course, at the VerifyBackendContract pass, since a
lot is still missing, but the final IR for a simple quantized MLP already
looks pretty decent:
[IR](https://gist.github.com/silvasean/f76bccd76e9b193d396cfb2f9a11f54d)

Main changes:
- Add support for importing torch quantized tensors, including
  `torch.per_tensor_affine.create` op and `!torch.qint8` element type.
- Add support for importing `LinearPackedParamsBase` (basically a weight
  + optional bias, but requires a `torch.linear_params.create` op and a
  `!torch.LinearParams` type to model it). This was less painful than I
  expected, as it has the necessary methods to opaquely unpack itself
  (see the sketch after this list). I factored things so it should be
  easy to extend to other custom classes like `ConvPackedParamsBase`.
- Add minimal boilerplate for importing `quantized::*` ops, with
  `quantized::linear` being a motivating example.
- Add e2e test with simple quantized MLP (courtesy of @phoenix-meadowlark).
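As a rough Python sketch of what that unpacking looks like on the Torch side
(hedged: `_packed_params` is a PyTorch-internal attribute and
`torch.ops.quantized.linear_unpack` is just one way to get at the data;
neither is introduced by this change):

```python
import torch

# A quantized Linear holds its weight + optional bias behind a
# LinearPackedParamsBase custom-class instance.
qlinear = torch.nn.quantized.Linear(16, 8, dtype=torch.qint8)

# PyTorch-internal detail: the torchbind object lives behind `_packed_params`.
packed = qlinear._packed_params._packed_params

# The object can be opaquely unpacked into (weight, optional bias), which is
# what the ivalue importer does in C++ via `linearParams->unpack()`.
weight, bias = torch.ops.quantized.linear_unpack(packed)
print(weight.dtype)    # torch.qint8
print(bias is None)    # bias is optional; False for this module
```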

This is something of an abuse of `!numpy.ndarray` / `tensor`, since the
proper semantics of a `!torch.qint8` dtype on a Torch tensor are "check
the tensor's quantizer object for side data (scale/offset, possibly
per-channel) that defines the full semantics of the tensor". We don't
have any such notion of "side data" for `!numpy.ndarray` / `tensor`, let
alone anything with the associated behavior of keying off the dtype to
determine whether the side data is present.
This will be fixed by a proper `!torch.tensor` type.
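For concreteness (an eager-mode illustration, not part of this change), the
"side data" on a per-tensor-affine quantized tensor is the scale and zero
point that live next to the raw integer values; `torch.per_tensor_affine.create`
is what carries that pairing in the imported IR:

```python
import torch

x = torch.randn(4)
q = torch.quantize_per_tensor(x, scale=0.1, zero_point=8, dtype=torch.qint8)

print(q.dtype)                         # torch.qint8
print(q.int_repr())                    # raw i8 values (the dense constant in the IR)
print(q.q_scale(), q.q_zero_point())   # the "side data" the dtype implies
```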
pull/217/head
Sean Silva 2021-05-19 11:40:48 -07:00
parent 0c89296075
commit d66e8fe1f8
12 changed files with 442 additions and 19 deletions

View File

@ -19,6 +19,7 @@
#include "npcomp-c/Types.h"
#include "caffe2/core/scope_guard.h"
#include "ATen/native/quantized/cpu/packed_params.h"
using namespace torch_mlir;
@ -101,10 +102,11 @@ public:
: importBlock(importBlock), context(context), typeMapper(context),
annotator(annotator) {}
MlirValue importIValue(c10::IValue value);
MlirValue importIValue(c10::IValue ivalue);
private:
MlirValue rawImportIValue(c10::IValue value);
MlirValue rawImportIValue(c10::IValue ivalue);
MlirValue importTensor(c10::IValue ivalue);
MlirValue importModule(torch::jit::Module jitModule);
void importMethod(torch::jit::Function *function, MlirBlock classTypeBody,
const MethodAnnotation &methodAnnotation);
@ -284,16 +286,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
return mlirOperationGetResult(operation, 0);
}
if (ivalue.isTensor()) {
at::Tensor tensor = ivalue.toTensor().contiguous();
MlirAttribute denseElements = converTensorToMlirElementsAttr(tensor, loc);
MlirOperation constant = createMlirOperationAtEnd(
importBlock, "std.constant", loc, mlirAttributeGetType(denseElements),
toMlirNamedAttribute("value", denseElements));
MlirOperation ndarray = createMlirOperationAtEnd(
importBlock, "numpy.create_array_from_tensor", loc,
npcompNdArrayTypeGetUnranked(npcompAnyDtypeTypeGet(context)),
mlirOperationGetResult(constant, 0));
return mlirOperationGetResult(ndarray, 0);
return importTensor(ivalue);
}
if (ivalue.isModule()) {
return importModule(ivalue.toModule());
@ -313,11 +306,83 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
importBlock, "basicpy.singleton", loc, npcompNoneTypeGet(context));
return mlirOperationGetResult(operation, 0);
}
if (ivalue.isCustomClass()) {
if (ivalue.type().get() ==
c10::getCustomClassType<c10::intrusive_ptr<LinearPackedParamsBase>>()
.get()) {
c10::intrusive_ptr<LinearPackedParamsBase> linearParams =
ivalue.toCustomClass<LinearPackedParamsBase>();
at::Tensor weight;
c10::optional<at::Tensor> bias;
std::tie(weight, bias) = linearParams->unpack();
MlirValue weightValue = importIValue(c10::IValue(weight));
c10::optional<MlirValue> biasValue = c10::nullopt;
if (bias.has_value()) {
biasValue = importIValue(c10::IValue(*bias));
}
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.linear_params.create", loc,
npcompLinearParamsTypeGet(context), weightValue, biasValue);
return mlirOperationGetResult(operation, 0);
}
}
std::stringstream msg;
msg << "Unsupported ivalue: " << ivalue;
throw std::invalid_argument(msg.str());
}
MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
assert(ivalue.isTensor() && "expected a tensor!");
// TODO: Can we do better?
MlirLocation loc = mlirLocationUnknownGet(context);
// Import the bulk tensor representation.
at::Tensor tensor = ivalue.toTensor().contiguous();
MlirAttribute denseElements = converTensorToMlirElementsAttr(tensor, loc);
MlirOperation constant = createMlirOperationAtEnd(
importBlock, "std.constant", loc, mlirAttributeGetType(denseElements),
toMlirNamedAttribute("value", denseElements));
MlirValue tensorReprValue = mlirOperationGetResult(constant, 0);
// Construct the complete tensor value. This is trivial for most tensors, but
// for quantized tensors (and probably sparse too, TBD) there is more for us
// to do.
MlirValue tensorValue;
if (tensor.is_quantized()) {
// Note that Torch models quantization in a type-erased way. So we don't
// make an effort here to do any special static modeling. If desired, later
// compiler stages that are building a statically modeled quantization
// representation will need to convert this to their representation.
std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end());
MlirType quantizedTensorType = mlirRankedTensorTypeGetChecked(
loc, shape.size(), shape.data(),
typeMapper.mapFromTorchScalarType(tensor.scalar_type()), {nullptr});
if (tensor.qscheme() == c10::kPerTensorAffine) {
MlirValue qScale = importIValue(c10::IValue(tensor.q_scale()));
MlirValue zeroPoint = importIValue(c10::IValue(tensor.q_zero_point()));
MlirOperation quantizedTensor = createMlirOperationAtEnd(
importBlock, "torch.per_tensor_affine.create", loc,
quantizedTensorType, tensorReprValue, qScale, zeroPoint);
tensorValue = mlirOperationGetResult(quantizedTensor, 0);
} else {
std::stringstream msg;
msg << "Unsupported quantization scheme '"
<< c10::toString(tensor.qscheme()) << "' for tensor: " << ivalue;
throw std::invalid_argument(msg.str());
}
} else {
tensorValue = tensorReprValue;
}
// Convert the tensor to ndarray to match Torch's default-mutable semantics.
MlirOperation ndarray = createMlirOperationAtEnd(
importBlock, "numpy.create_array_from_tensor", loc,
npcompNdArrayTypeGetUnranked(npcompAnyDtypeTypeGet(context)),
tensorValue);
return mlirOperationGetResult(ndarray, 0);
}
void IValueImporter::importMethod(torch::jit::Function *function,
MlirBlock classTypeBody,
const MethodAnnotation &methodAnnotation) {
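To summarize `importTensor` for the quantized case in eager-PyTorch terms (a
hedged sketch; `torch._make_per_tensor_quantized_tensor` is a private helper
used here only to show that the decomposition is lossless):

```python
import torch

q = torch.quantize_per_tensor(torch.randn(2, 3), 0.05, 0, torch.qint8)

int_repr = q.int_repr()        # -> std.constant with i8 dense elements
scale = q.q_scale()            # -> scale operand of torch.per_tensor_affine.create
zero_point = q.q_zero_point()  # -> zero-point operand

# Reassembling from the pieces recovers the same quantized tensor, i.e. the
# importer's (constant, scale, zero_point) split loses no information.
q2 = torch._make_per_tensor_quantized_tensor(int_repr, scale, zero_point)
assert torch.equal(q.int_repr(), q2.int_repr())
assert q.q_scale() == q2.q_scale() and q.q_zero_point() == q2.q_zero_point()
```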

View File

@ -67,12 +67,59 @@ MlirType TypeMapper::rawMapFromTorchScalarType(c10::ScalarType scalarType) {
return mlirBF16TypeGet(context);
case ScalarType::Half:
return mlirF16TypeGet(context);
case ScalarType::QInt8:
return npcompQInt8TypeGet(context);
default: {
return {nullptr};
}
}
}
// Types (such as `LinearPackedParamsBase`) implemented with the
// `torch::CustomClassHolder` mechanism described at
// https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html
// are modeled with ordinary c10::ClassType's, but require special handling
// for importing.
//
// These types correspond to c10::IValue's with `isCustomClass() == true`.
//
// Under the hood, Torch represents such "custom classes" using the
// "object" variant of c10::IValue and a class type with one slot holding a
type-erased c10::intrusive_ptr to the custom type. On the side, it keeps
// a registry of custom classes which is used to implement `isCustomClass()`
// by checking names against the registry.
//
// There is no generic way to import custom classes (or their types), so we
// have to name match them here (and the relevant code in the ivalue
// importer) and create special IR constructs for them.
static MlirType mapCustomClassType(MlirContext context, MlirLocation loc,
const c10::ClassTypePtr &classType) {
// If the type is unnamed, it cannot be a custom class.
if (!classType->name().has_value()) {
return {nullptr};
}
std::string name = classType->name()->qualifiedName();
// If the type is not stored in the custom class registry, it cannot be a
// custom class.
if (!torch::getCustomClass(name)) {
return {nullptr};
}
// Individually handle the custom classes that we know about.
if (name == "__torch__.torch.classes.quantized.LinearPackedParamsBase") {
return npcompLinearParamsTypeGet(context);
}
// At this point, we know that the type is indeed a custom class type, but
// that we don't know how to specially import it. We cannot proceed, so emit a
// diagnostic and halt compilation.
std::stringstream message;
message << "unable to import Torch CustomClass type '" << classType
<< "' to MLIR type";
mlirEmitError(loc, message.str().c_str());
throw mlir_diagnostic_emitted();
}
MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
const c10::TypePtr &torchType) {
using c10::TypeKind;
@ -106,10 +153,14 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
return npcompNdArrayTypeGetRanked(dims.size(), dims.data(), elementType);
}
case TypeKind::ClassType: {
auto maybeName = torchType->cast<c10::ClassType>()->name();
return npcompNnModuleTypeGet(
context, toMlirStringRef(maybeName ? maybeName->qualifiedName()
: "unnamed class"));
const c10::ClassTypePtr &classType = torchType->cast<c10::ClassType>();
MlirType customClassType = mapCustomClassType(context, loc, classType);
if (!mlirTypeIsNull(customClassType)) {
return customClassType;
}
auto maybeName = classType->name();
std::string name = maybeName ? maybeName->qualifiedName() : "unnamed class";
return npcompNnModuleTypeGet(context, toMlirStringRef(name));
}
case TypeKind::FloatType: {
return mlirF64TypeGet(context);
@ -220,6 +271,14 @@ MlirAttribute torch_mlir::converTensorToMlirElementsAttr(at::Tensor tensor,
// The generalized (non-Tensor) conversion, assumes that Bool is the
// Basicpy bool type.
elementType = mlirIntegerTypeGet(context, 1);
} else if (tensor.scalar_type() == ScalarType::QInt8) {
// This function returns the underlying integer representation of the tensor
// as an elements attr. That underlying representation is of type i8
// for a torch.qint8 tensor.
// Caller code is responsible for materializing the proper op that
// incorporates the quantization scheme to create a tensor of `!torch.qint8`
// element type.
elementType = mlirIntegerTypeGet(context, 8);
} else {
elementType = typeMapper.mapFromTorchScalarType(tensor.scalar_type());
}
@ -259,6 +318,9 @@ MlirAttribute torch_mlir::converTensorToMlirElementsAttr(at::Tensor tensor,
return mlirDenseElementsAttrBoolGet(shapedType, numElements,
static_cast<const int *>(tensorData));
break;
case ScalarType::QInt8:
return mlirDenseElementsAttrInt8Get(
shapedType, numElements, static_cast<const int8_t *>(tensorData));
default:
throwUnsupportedTensorError();
}

View File

@ -25,6 +25,7 @@ from torch_mlir.torchscript.e2e_test.configs import (
import basic
import vision_models
import mlp
import quantized_models
def main():
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')

View File

@ -0,0 +1,58 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import torch
from torch import nn
from torch_mlir.torchscript.e2e_test.framework import TestUtils
from torch_mlir.torchscript.e2e_test.registry import register_test_case
from torch_mlir.torchscript.annotations import annotate_args, export
# ==============================================================================
class QuantizedMLP(nn.Module):
def __init__(self):
super().__init__()
torch.random.manual_seed(0)
self.layers = nn.Sequential(
nn.Linear(16, 8),
nn.Tanh(),
nn.Linear(8, 4),
)
self.quantize = torch.quantization.QuantStub()
self.dequantize = torch.quantization.DeQuantStub()
@export
@annotate_args([
None,
([1, 16], torch.float32),
])
def forward(self, x):
x = self.quantize(x)
x = self.layers(x)
x = self.dequantize(x)
return x
def get_mlp_input():
return 2 * torch.rand((1, 16)) - 1
def get_quantized_mlp():
model = QuantizedMLP()
model.eval()
model.qconfig = torch.quantization.default_qconfig
torch.quantization.prepare(model, inplace=True)
torch.manual_seed(0)
for _ in range(32):
model(get_mlp_input())
torch.quantization.convert(model, inplace=True)
return model
@register_test_case(module_factory=get_quantized_mlp)
def QuantizedMLP_basic(module, tu: TestUtils):
module.forward(get_mlp_input())
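A hedged usage sketch of the factory above (assumes the definitions in this
file; the attribute and method names are the usual eager-mode quantization
API and may differ across PyTorch versions):

```python
import torch

model = get_quantized_mlp()
first_linear = model.layers[0]

print(type(first_linear))            # expected: a torch.nn.quantized Linear module
print(first_linear.weight().dtype)   # torch.qint8 (weight lives in packed params)
print(model(get_mlp_input()).dtype)  # torch.float32 (DeQuantStub at the end)
```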

View File

@ -213,6 +213,8 @@ SIGLIST_TYPE = List[Dict[str, Union[str, int, Dict[str, List[str]]]]]
# - SIGLIST_TYPE (e.g. {'arguments': [...], 'returns': [...]} )
OP_INFO_DICT = Dict[str, Union[bool, Tuple[str], SIGLIST_TYPE]]
# Mapping from torch types to their corresponding ODS type predicates.
# Use `get_ods_type` instead of using this directly.
TORCH_TYPE_TO_ODS_TYPE = {
"Tensor": "AnyTorchTensorType",
"Tensor?": "AnyTorchOptionalTensor",
@ -229,9 +231,18 @@ TORCH_TYPE_TO_ODS_TYPE = {
"Any": "AnyTorchType",
"Device": "Torch_DeviceType",
"str": "Basicpy_BytesType",
"__torch__.torch.classes.quantized.LinearPackedParamsBase": "Torch_LinearParamsType",
}
def get_ods_type(type: str):
ods_type = TORCH_TYPE_TO_ODS_TYPE.get(type)
if ods_type is None:
raise Exception(
f"{type!r} not in TORCH_TYPE_TO_ODS_TYPE mapping. Please add it!")
return ods_type
def _get_main_module_name() -> str:
# pytype: disable=attribute-error
return sys.modules["__main__"].__loader__.name
@ -287,7 +298,7 @@ def raw_emit_op(operator: JitOperator, f: TextIO, *, traits: List[str],
p("Variadic<AnyTorchType>:$operands")
else:
p(",\n".join([
f"""{TORCH_TYPE_TO_ODS_TYPE[arg["type"]]}:${arg["name"]}"""
f"""{get_ods_type(arg["type"])}:${arg["name"]}"""
for arg in operator.arguments
]))
p(");")
@ -297,7 +308,7 @@ def raw_emit_op(operator: JitOperator, f: TextIO, *, traits: List[str],
p("Variadic<AnyTorchType>:$results")
else:
p(",\n".join([
f"""{TORCH_TYPE_TO_ODS_TYPE[ret["type"]]}:${ret["name"] or "result"}"""
f"""{get_ods_type(ret["type"])}:${ret["name"] or "result"}"""
for ret in operator.returns
]))
p(");")
@ -444,6 +455,19 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
emit("aten::len.t : (t[]) -> (int)", has_canonicalizer=True)
def emit_quantized_ops(torch_ir_dir: str, registry: Registry):
td_file = os.path.join(torch_ir_dir, "GeneratedQuantizedOps.td")
with open(td_file, "w") as f:
f.write(ODS_BANNER)
def emit(key, **kwargs):
emit_op(registry[key], f, **kwargs)
emit(
"quantized::linear : (Tensor, __torch__.torch.classes.quantized.LinearPackedParamsBase, float, int) -> (Tensor)",
traits=["HasValueSemantics"])
def dump_registered_ops(outfile: TextIO, registry: Registry):
for _, v in sorted(registry.by_unique_key.items()):
outfile.write(repr(v))
@ -460,6 +484,7 @@ def main(args: argparse.Namespace):
dump_registered_ops(debug_registry_dump, registry)
emit_prim_ops(args.torch_ir_dir, registry)
emit_aten_ops(args.torch_ir_dir, registry)
emit_quantized_ops(args.torch_ir_dir, registry)
def _create_argparse() -> argparse.ArgumentParser:
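A small, hedged usage sketch of `get_ods_type` (illustrative only; run in the
same module where it is defined):

```python
print(get_ods_type("Tensor"))
# AnyTorchTensorType
print(get_ods_type("__torch__.torch.classes.quantized.LinearPackedParamsBase"))
# Torch_LinearParamsType
try:
    get_ods_type("SomeMadeUpType")
except Exception as e:
    print(e)  # 'SomeMadeUpType' not in TORCH_TYPE_TO_ODS_TYPE mapping. Please add it!
```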

View File

@ -0,0 +1,44 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import typing
import torch
import torch_mlir
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
mb = torch_mlir.ModuleBuilder()
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.quantized.Linear(5, 2, dtype=torch.qint8)
self.linear_no_bias = torch.nn.quantized.Linear(6,
2,
bias_=False,
dtype=torch.qint8)
# CHECK-DAG: %[[SCALE:.*]] = basicpy.numeric_constant {{.*}} : f64
# CHECK-DAG: %[[ZERO_POINT:.*]] = basicpy.numeric_constant 0 : i64
# CHECK-DAG: %[[INT_REPR:.*]] = constant dense<{{.*}}> : tensor<2x5xi8>
# CHECK-DAG: %[[WEIGHTS:.*]] = torch.per_tensor_affine.create %[[INT_REPR]], %[[SCALE]], %[[ZERO_POINT]] : tensor<2x5xi8>, f64, i64 -> tensor<2x5x!torch.qint8>
# CHECK-DAG: %[[WEIGHTS_ARRAY:.*]] = numpy.create_array_from_tensor %[[WEIGHTS]] : (tensor<2x5x!torch.qint8>) -> !numpy.ndarray<*:!numpy.any_dtype>
# CHECK-DAG: %[[BIAS:.*]] = constant dense<{{.*}}> : tensor<2xf32>
# CHECK-DAG: %[[BIAS_ARRAY:.*]] = numpy.create_array_from_tensor %[[BIAS]] : (tensor<2xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
# CHECK-DAG: %[[LINEAR_PARAMS:.*]] = torch.linear_params.create %[[WEIGHTS_ARRAY]], %[[BIAS_ARRAY]] : !numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>
@torch.jit.export
def test_linear(self, t):
return self.linear(t)
# CHECK: %[[LINEAR_PARAMS_NO_BIAS:.*]] = torch.linear_params.create %{{.*}} : !numpy.ndarray<*:!numpy.any_dtype>{{$}}
@torch.jit.export
def test_linear_no_bias(self, t):
return self.linear_no_bias(t)
test_module = TestModule()
recursivescriptmodule = torch.jit.script(test_module)
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c)
mb.module.operation.print()

View File

@ -147,6 +147,26 @@ int npcompTypeIsADevice(MlirType t);
/** Gets the !torch.Device type. */
MlirType npcompDeviceTypeGet(MlirContext context);
/*============================================================================*/
/* torch.LinearParams type. */
/*============================================================================*/
/** Checks whether the given type is a !torch.LinearParams type */
int npcompTypeIsALinearParams(MlirType t);
/** Gets the !torch.LinearParams type. */
MlirType npcompLinearParamsTypeGet(MlirContext context);
/*============================================================================*/
/* torch.qint8 type. */
/*============================================================================*/
/** Checks whether the given type is a !torch.qint8 type */
int npcompTypeIsAQInt8(MlirType t);
/** Gets the !torch.qint8 type. */
MlirType npcompQInt8TypeGet(MlirContext context);
#ifdef __cplusplus
}
#endif

View File

@ -0,0 +1,35 @@
//===-------------------------------------------------------*- tablegen -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// Operation summaries and descriptions were systematically derived from public
// API docstrings and are licensed accordingly:
// https://github.com/pytorch/pytorch/blob/master/LICENSE
//===----------------------------------------------------------------------===//
//
// This file is automatically generated. Please do not edit.
// Generated via:
// python -m torch_mlir_utils.codegen.torch_ods_gen
//
//===----------------------------------------------------------------------===//
def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
HasValueSemantics,
AllowsTypeRefinement
]> {
let summary = "Generated op for `quantized::linear : (Tensor, __torch__.torch.classes.quantized.LinearPackedParamsBase, float, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$X,
Torch_LinearParamsType:$W_prepack,
AnyFloat:$Y_scale_i,
AnyTorchIntType:$Y_zero_point_i
);
let results = (outs
AnyTorchTensorType:$Y
);
let assemblyFormat = "$X `,` $W_prepack `,` $Y_scale_i `,` $Y_zero_point_i attr-dict `:` type($X) `,` type($W_prepack) `,` type($Y_scale_i) `,` type($Y_zero_point_i) `->` type($Y)";
}

View File

@ -22,6 +22,7 @@ class Torch_Op<string mnemonic, list<OpTrait> traits = []>
include "npcomp/Dialect/Torch/IR/GeneratedAtenOps.td"
include "npcomp/Dialect/Torch/IR/GeneratedPrimOps.td"
include "npcomp/Dialect/Torch/IR/GeneratedQuantizedOps.td"
//===----------------------------------------------------------------------===//
// TorchScript `torch.nn.Module` object instantiation ops.
@ -439,7 +440,9 @@ def Torch_DerefineOp : Torch_Op<"derefine", [
let hasCanonicalizer = 1;
}
def Torch_OperatorOp : Torch_Op<"operator", []> {
def Torch_OperatorOp : Torch_Op<"operator", [
AllowsTypeRefinement
]> {
let summary = "Opaque torch operator";
let description = [{
Represents an invocation of a `torch::jit::Operator` for which we don't
@ -458,4 +461,48 @@ def Torch_OperatorOp : Torch_Op<"operator", []> {
}];
}
def Torch_LinearParamsCreateOp : Torch_Op<"linear_params.create", [
AllowsTypeRefinement
]> {
let summary = "Create a `!torch.LinearParams`";
let arguments = (ins
AnyTorchTensorType:$weight,
Optional<AnyTorchTensorType>:$bias
);
let results = (outs Torch_LinearParamsType:$result);
let assemblyFormat = [{
$weight (`,` $bias^)? attr-dict `:` type($weight) (`,` type($bias)^)?
}];
}
def Torch_PerTensorAffineCreateOp : Torch_Op<"per_tensor_affine.create", [
AllowsTypeRefinement
]> {
let summary = "Create a per-tensor-affine quantized tensor";
let description = [{
Create a quantized tensor.
Quantization formula is:
```
Q(x, scale, zero_point) = round(x/scale + zero_point)
```
See:
https://pytorch.org/docs/stable/quantization.html#quantized-tensors
}];
let arguments = (ins
AnyTorchTensorType:$int_repr,
AnyFloat:$scale,
AnyTorchIntType:$offset
);
// TODO: Limit to quantized dtypes (e.g. !torch.qint8).
let results = (outs AnyTorchTensorType:$result);
let assemblyFormat = [{
$int_repr `,` $scale `,` $offset attr-dict
`:` type($int_repr) `,` type($scale) `,` type($offset) `->` type($result)
}];
}
#endif // TORCH_OPS
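To sanity-check the formula in the `per_tensor_affine.create` description
above, here is a small worked example (hedged; the numbers are arbitrary and
rounding-mode corner cases are glossed over):

```python
import torch

# Q(x, scale, zero_point) = round(x / scale + zero_point), clamped to i8.
x = torch.tensor([-1.0, 0.0, 0.26, 1.0])
scale, zero_point = 0.1, 3

manual = torch.clamp(torch.round(x / scale + zero_point), -128, 127).to(torch.int8)
q = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8)

print(manual)        # expected: tensor([-7,  3,  6, 13], dtype=torch.int8)
print(q.int_repr())  # should agree with the manual computation
```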

View File

@ -83,6 +83,37 @@ def Torch_DeviceType : Torch_Type<"Device", "Device"> {
let summary = "Torch device";
}
def Torch_QInt8Type : Torch_Type<"QInt8", "qint8"> {
let summary = "Type modeling `ScalarType::QInt8`";
let description = [{
This is intended to be a 1:1 match for the Torch `ScalarType` types.
Looking at the variety / ad-hocness (e.g. `QUInt4x2`) of that set of
types, it is deemed preferable to import them as one-off ad-hoc types
instead of a single parameterized type.
}];
}
def Torch_LinearParamsType : Torch_Type<"LinearParams", "LinearParams"> {
let summary = "Torch packed linear params type";
let description = [{
A weight and optional bias, packed into one value.
This is used to model the
`__torch__.torch.classes.quantized.LinearPackedParamsBase` custom C++ class
type which is the input to some Torch `quantized::` ops.
We may want to eventually have a full set of ops that model the
LinearPackedParamsBase interface, such as `apply`, `apply_relu`, etc.
But we instead are likely to just expand the `quantized::` ops directly
and fold away the instances of this type.
The whole LinearPackedParamsBase abstraction as it stands in PyTorch is a
very library-call-y, runtime-y thing that embodies a number of assumptions
about the structure of how the program will be executed, which need not hold
for npcomp backends.
}];
}
//===----------------------------------------------------------------------===//
// Type predicates
//===----------------------------------------------------------------------===//
@ -189,6 +220,7 @@ def AnyTorchType : AnyTypeOf<[
Torch_NnModuleType,
Torch_OptionalType,
Torch_DeviceType,
Torch_LinearParamsType,
], "Any type that is legal to pass to a Torch kernel">;
#endif // TORCH_TYPES

View File

@ -188,3 +188,31 @@ int npcompTypeIsADevice(MlirType t) {
MlirType npcompDeviceTypeGet(MlirContext context) {
return wrap(Torch::DeviceType::get(unwrap(context)));
}
/*============================================================================*/
/* torch.LinearParams type. */
/*============================================================================*/
/** Checks whether the given type is a !torch.LinearParams type */
int npcompTypeIsALinearParams(MlirType t) {
return unwrap(t).isa<Torch::LinearParamsType>();
}
/** Gets the !torch.LinearParams type. */
MlirType npcompLinearParamsTypeGet(MlirContext context) {
return wrap(Torch::LinearParamsType::get(unwrap(context)));
}
/*============================================================================*/
/* torch.qint8 type. */
/*============================================================================*/
/** Checks whether the given type is a !torch.qint8 type */
int npcompTypeIsAQInt8(MlirType t) {
return unwrap(t).isa<Torch::QInt8Type>();
}
/** Gets the !torch.qint8 type. */
MlirType npcompQInt8TypeGet(MlirContext context) {
return wrap(Torch::QInt8Type::get(unwrap(context)));
}

View File

@ -7,6 +7,12 @@ func @torch.operator(%arg0: !numpy.ndarray<*:!numpy.any_dtype>, %arg1: !numpy.nd
return %0 : !numpy.ndarray<*:!numpy.any_dtype>
}
func @torch.linear_params.create(%arg0: !numpy.ndarray<*:!numpy.any_dtype>, %arg1: !numpy.ndarray<*:!numpy.any_dtype>) -> (!torch.LinearParams, !torch.LinearParams) {
%with_bias = torch.linear_params.create %arg0, %arg1 : !numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>
%without_bias = torch.linear_params.create %arg0 : !numpy.ndarray<*:!numpy.any_dtype>
return %with_bias, %without_bias : !torch.LinearParams, !torch.LinearParams
}
func @derefine(%arg0: tensor<f32>) -> !torch.optional<tensor<f32>> {
%0 = torch.derefine %arg0 : tensor<f32> to !torch.optional<tensor<f32>>
return %0 : !torch.optional<tensor<f32>>