Add a number of kernels and new patterns.

* convolution, convolution_backward, _log_softmax, _log_softmax_backward_data, nll_loss_forward, nll_loss_backward, nll_loss2d_forward, nll_loss2d_backward, copy_
* Extends the recognition logic and metadata to handle in-place transformations, optional tensors, ints, lists, and dropped args.
* The kernel_calls generated by test_conv_nllloss_grads.py now convert to ATen ops.
* The result *almost* comes out as a pure tensor program, with the exception of the copy_ op, which I will address in follow-up work (see the sketch after this list).
* More progress on #97
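
For reference, a minimal sketch of the dropped-arg handling added here. It mirrors the copy_ registration in torch_signature_ods_gen.py from this change; g is the OpGenerator passed to generate_ops, and the argument name non_blocking comes from the PyTorch signature:

g.ordinary_inplace_op("aten::copy_(Tensor,Tensor,bool)",
                      "CopyInplaceOp",
                      "copy.inplace",
                      drop_arg_indices=[2])
# The trailing bool (non_blocking) is dropped from the ODS def and marked
# kDrop; the return is dropped and aliased to arg0 (kDropReturnAndAliasArg0).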
pull/108/head
Stella Laurenzo 2020-11-03 19:24:28 -08:00
parent 3dab9056f0
commit 6c702b149f
17 changed files with 920 additions and 454 deletions

View File

@ -11,4 +11,5 @@ export PYTHONPATH="${build_dir}/python"
python -m torch_mlir_utils.codegen.torch_signature_ods_gen \
--ods_td_file="${aten_dir}/GeneratedATenOps.td" \
--ods_impl_file="${aten_dir}/GeneratedATenOps.cpp.inc"
--ods_impl_file="${aten_dir}/GeneratedATenOps.cpp.inc" \
--debug_op_reg_file="${aten_dir}/ATenOpRegistrations.txt"

View File

@ -189,9 +189,8 @@ MlirValue FuncBuilder::getGeneralConstant(MlirLocation loc,
return constValue;
}
MlirValue
FuncBuilder::buildList(MlirLocation loc,
llvm::SmallVectorImpl<MlirValue> &elements) {
MlirValue FuncBuilder::buildList(MlirLocation loc,
llvm::SmallVectorImpl<MlirValue> &elements) {
MlirType resultType = npcompListTypeGet(context);
OperationStateHolder state{"basicpy.build_list", loc};
mlirOperationStateAddResults(state, 1, &resultType);

View File

@ -13,6 +13,7 @@ import logging
import re
import sys
import textwrap
import traceback
# Note that this utility exists only in the c-extension.
from _torch_mlir import get_registered_ops
@ -75,6 +76,55 @@ def generate_ops(g: "OpGenerator"):
g.ordinary_unary_op(f"aten::{uname}(Tensor)",
f"{snakecase_to_camelcase(uname)}Op", uname)
# Convolution ops. Note that these are special in PyTorch and the importer,
# and we model them after the signatures of the convolution_overrideable
# ops (generic for non-CPU/GPU backends) but set the names according to
# how they come in.
g.print_banner("NN ops")
g.ordinary_immutable_op(
"aten::convolution_overrideable(Tensor,Tensor,Tensor?,int[],int[],int[],bool,int[],int)",
"ConvolutionOp",
"convolution",
alias_kernel_names=["aten::convolution"])
g.ordinary_immutable_op(
"aten::convolution_backward_overrideable(Tensor,Tensor,Tensor,int[],int[],int[],bool,int[],int,bool[])",
"ConvolutionBackwardOp",
"convolution_backward",
alias_kernel_names=["aten::convolution_backward"],
# These do return as None but are not coded optional in the registry :(
override_return_types=["Tensor?", "Tensor?", "Tensor?"])
g.ordinary_immutable_op("aten::_log_softmax(Tensor,int,bool)",
"LogSoftmaxOp", "log_softmax")
g.ordinary_immutable_op(
"aten::_log_softmax_backward_data(Tensor,Tensor,int,Tensor)",
"LogSoftmaxBackwardDataOp", "log_softmax_backward_data")
# Loss functions.
g.print_banner("Loss function ops")
g.ordinary_immutable_op(
"aten::nll_loss_forward(Tensor,Tensor,Tensor?,int,int)",
"NllLossForwardOp", "nll_loss_forward")
# Note also a grad_input 8-arg variant.
g.ordinary_immutable_op(
"aten::nll_loss_backward(Tensor,Tensor,Tensor,Tensor?,int,int,Tensor)",
"NllLossBackwardOp", "nll_loss_backward")
g.ordinary_immutable_op(
"aten::nll_loss2d_forward(Tensor,Tensor,Tensor?,int,int)",
"NllLoss2dForwardOp", "nll_loss2d_forward")
# Note also a grad_input 8-arg variant.
g.ordinary_immutable_op(
"aten::nll_loss2d_backward(Tensor,Tensor,Tensor,Tensor?,int,int,Tensor)",
"NllLoss2dBackwardOp", "nll_loss2d_backward")
# One-off in-place ops (note that many in-place arithmetic ops are handled
# as a transformation from their immutable forms).
g.ordinary_inplace_op("aten::copy_(Tensor,Tensor,bool)",
"CopyInplaceOp",
"copy.inplace",
drop_arg_indices=[2])
def dump_registered_ops(outfile, reg_ops_dict):
for k in sorted(reg_ops_dict.keys()):
@ -104,7 +154,20 @@ class OpGenerator:
)
em.print("")
def ordinary_binary_op(self, kernel_sig, ods_name, op_name):
def define_op(self, kernel_sig, ods_name, op_name, **kwargs):
return InflightOpDef(self,
kernel_sig=kernel_sig,
ods_name=ods_name,
op_name=op_name,
**kwargs)
def ordinary_binary_op(self,
kernel_sig,
ods_name,
op_name,
promote_trailing_out_tensor=True,
traits=(),
**kwargs):
""""Binary"-ops. These ops typically have:
- '.Tensor' variant where the second arg is a Tensor
- '.Scalar' variant where the second arg is a Scalar
@ -124,75 +187,130 @@ class OpGenerator:
- Setting all arguments and returns to kImmutableTensor
- Enabling kPromoteScalar on the second argument.
"""
reg_record = self._get_reg_record(kernel_sig)
ods_ins, arg_type_flags = self._map_sigtypes(
reg_record["arguments"],
type_transforms={
"Tensor:0": "AnyTorchImmutableTensor",
"Tensor:1": "AnyTorchImmutableTensor",
"Scalar:1": "AnyTorchImmutableTensor",
"Scalar": "AnyTorchScalarType",
},
flag_transforms={
":0": ["kImmutableTensor"],
":1": ["kImmutableTensor", "kPromoteScalar"],
})
ods_outs, return_type_flags = self._map_sigtypes(
reg_record["returns"],
type_transforms={
"Tensor:0": "AnyTorchImmutableTensor",
},
flag_transforms={
":0": ["kImmutableTensor"],
})
self.ods_emitter.emit_opdef(ods_name,
op_name,
reg_record,
ods_ins=ods_ins,
ods_outs=ods_outs,
traits=["NoSideEffect"])
self.impl_emitter.emit_kernel_methods(ods_name,
reg_record,
arg_type_flags=arg_type_flags,
return_type_flags=return_type_flags,
promote_trailing_out_tensor=True)
opdef = self.define_op(
kernel_sig=kernel_sig,
ods_name=ods_name,
op_name=op_name,
promote_trailing_out_tensor=promote_trailing_out_tensor,
traits=list(traits) + ["NoSideEffect"],
**kwargs)
opdef.arg_transforms(type_transforms={
"Tensor:0": "AnyTorchImmutableTensor",
"Tensor:1": "AnyTorchImmutableTensor",
"Scalar:1": "AnyTorchImmutableTensor",
"Scalar": "AnyTorchScalarType",
},
flag_transforms={
":0": ["kImmutableTensor"],
":1": ["kImmutableTensor", "kPromoteScalar"],
})
opdef.return_transforms(type_transforms={
"Tensor:0": "AnyTorchImmutableTensor",
},
flag_transforms={
":0": ["kImmutableTensor"],
})
opdef.emit()
def ordinary_unary_op(self, kernel_sig, ods_name, op_name):
def ordinary_immutable_op(self,
kernel_sig,
ods_name,
op_name,
promote_trailing_out_tensor=True,
traits=(),
**kwargs):
""""An ordinary immutable-tensor based op."""
opdef = self.define_op(
kernel_sig=kernel_sig,
ods_name=ods_name,
op_name=op_name,
promote_trailing_out_tensor=promote_trailing_out_tensor,
traits=list(traits) + ["NoSideEffect"],
**kwargs)
opdef.transforms(type_transforms={
"Tensor": "AnyTorchImmutableTensor",
"Tensor?": "AnyTorchOptionalImmutableTensor",
"int": "AnyTorchIntType",
"int[]": "AnyTorchIntListType",
"bool": "AnyTorchBoolType",
"bool[]": "AnyTorchBoolListType",
},
flag_transforms={
"Tensor": ["kImmutableTensor"],
"Tensor?": ["kImmutableTensor"],
})
opdef.emit()
def ordinary_inplace_op(self, kernel_sig, ods_name, op_name, **kwargs):
"""In-place ops (ending in '_').
These ops take a mutable first argument and then standard immutable
conversions for subsequent arguments. When emitting into MLIR, the return
value is dropped.
"""
opdef = self.define_op(kernel_sig=kernel_sig,
ods_name=ods_name,
op_name=op_name,
**kwargs)
opdef.arg_transforms(type_transforms={
":0": "AnyTorchMutableTensor",
"Tensor": "AnyTorchImmutableTensor",
"Tensor?": "AnyTorchOptionalImmutableTensor",
"int": "AnyTorchIntType",
"int[]": "AnyTorchIntListType",
"bool": "AnyTorchBoolType",
"bool[]": "AnyTorchBoolListType",
},
flag_transforms={
":0": [],
"Tensor": ["kImmutableTensor"],
"Tensor?": ["kImmutableTensor"],
})
opdef.return_transforms(
type_transforms={
":0": "DROP_UNUSED", # Ignored because we clear the outs below.
},
flag_transforms={
":0": ["kDropReturnAndAliasArg0"],
})
opdef.map_signatures()
opdef.ods_outs = [] # Clear the computed outs.
opdef.emit()
def ordinary_unary_op(self,
kernel_sig,
ods_name,
op_name,
promote_trailing_out_tensor=True,
traits=(),
**kwargs):
"""Unary ops.
These take and return a tensor and typically have an out and in-place
variant (they may not, but we generate patterns to match them anyway).
"""
reg_record = self._get_reg_record(kernel_sig)
ods_ins, arg_type_flags = self._map_sigtypes(
reg_record["arguments"],
type_transforms={
"Tensor:0": "AnyTorchImmutableTensor",
},
flag_transforms={
":0": ["kImmutableTensor"],
})
ods_outs, return_type_flags = self._map_sigtypes(
reg_record["returns"],
type_transforms={
"Tensor:0": "AnyTorchImmutableTensor",
},
flag_transforms={
":0": ["kImmutableTensor"],
})
self.ods_emitter.emit_opdef(ods_name,
op_name,
reg_record,
ods_ins=ods_ins,
ods_outs=ods_outs,
traits=["NoSideEffect"])
self.impl_emitter.emit_kernel_methods(ods_name,
reg_record,
arg_type_flags=arg_type_flags,
return_type_flags=return_type_flags,
promote_trailing_out_tensor=True)
opdef = self.define_op(
kernel_sig=kernel_sig,
ods_name=ods_name,
op_name=op_name,
promote_trailing_out_tensor=promote_trailing_out_tensor,
traits=list(traits) + ["NoSideEffect"],
**kwargs)
opdef.arg_transforms(type_transforms={
"Tensor:0": "AnyTorchImmutableTensor",
},
flag_transforms={
":0": ["kImmutableTensor"],
})
opdef.return_transforms(type_transforms={
"Tensor:0": "AnyTorchImmutableTensor",
},
flag_transforms={
":0": ["kImmutableTensor"],
})
opdef.emit()
def _get_reg_record(self, kernel_sig):
def get_reg_record(self, kernel_sig):
"""Gets the op-dict for a given registered op name.
Args:
@ -212,8 +330,14 @@ class OpGenerator:
raise ValueError(f"Could not find registry op matching '{kernel_sig}'. "
f"Possible matches:\n {dym_message}")
def _map_sigtypes(self, siglist: List[Dict], type_transforms: Dict[str, str],
flag_transforms: Dict[str, List[str]]) -> List[Tuple[str]]:
def _map_sigtypes(
self,
siglist: List[Dict],
type_transforms: Dict[str, str],
flag_transforms: Dict[str, List[str]],
drop_indices: Sequence[int] = (),
override_types: Optional[Sequence[str]] = None,
) -> List[Tuple[str]]:
"""Maps a list of signature entries to ods dags and flag lists.
The torch signature list contains dicts that minimally have keys 'name' and
@ -233,15 +357,23 @@ class OpGenerator:
- An ods dag list of (ods_name, ods_type) tuples
- List of (torch_type, [conversion_flag]) for specifying conversions.
"""
# Make sure any override types are sane.
if override_types:
assert len(override_types) == len(siglist), (
"Mismatch override and actual types")
# Generate to ods dag list.
ods_dag_list = []
for i, sigitem in enumerate(siglist):
if i in drop_indices:
# Do not emit in ODS.
continue
torch_name = sigitem["name"]
torch_type = sigitem["type"]
torch_type = (sigitem["type"]
if override_types is None else override_types[i])
# Look up the type transform.
ods_type = (type_transforms.get(f"{torch_type}:{i}") or
type_transforms.get(f":{i}") or
type_transforms.get(torch_type))
ods_type = _first_non_none(type_transforms.get(f"{torch_type}:{i}"),
type_transforms.get(f":{i}"),
type_transforms.get(torch_type))
if not ods_type:
raise ValueError(f"Signature item {i}, type {torch_type} did not match "
f"a type transform {type_transforms}")
@ -250,16 +382,130 @@ class OpGenerator:
# Generate the type conversion flags.
type_flag_list = []
for i, sigitem in enumerate(siglist):
torch_type = sigitem["type"]
torch_type = (sigitem["type"]
if override_types is None else override_types[i])
# Look up the type transform.
flags = (flag_transforms.get(f"{torch_type}:{i}") or
flag_transforms.get(f":{i}") or flag_transforms.get(torch_type))
if not flags:
flags = []
if i in drop_indices:
flags = ["kDrop"]
else:
flags = _first_non_none(flag_transforms.get(f"{torch_type}:{i}"),
flag_transforms.get(f":{i}"),
flag_transforms.get(torch_type))
if flags is None:
flags = []
type_flag_list.append((torch_type, flags))
return ods_dag_list, type_flag_list
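For illustration, a hedged sketch (a hypothetical call, assuming g is an OpGenerator instance) of how the new drop_indices parameter interacts with the type/flag transforms, using the aten::copy_ arguments from above. Note that an explicit empty flag list such as ":0": [] is now honored because _first_non_none distinguishes a missing entry from an empty one:

siglist = [
    {"name": "self", "type": "Tensor"},
    {"name": "src", "type": "Tensor"},
    {"name": "non_blocking", "type": "bool"},
]
ods_ins, arg_flags = g._map_sigtypes(
    siglist,
    type_transforms={":0": "AnyTorchMutableTensor",
                     "Tensor": "AnyTorchImmutableTensor"},
    flag_transforms={":0": [], "Tensor": ["kImmutableTensor"]},
    drop_indices=[2])
# ods_ins   == [("self", "AnyTorchMutableTensor"),
#               ("src", "AnyTorchImmutableTensor")]
# arg_flags == [("Tensor", []), ("Tensor", ["kImmutableTensor"]),
#               ("bool", ["kDrop"])]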
class InflightOpDef:
def __init__(self,
g: OpGenerator,
kernel_sig,
ods_name,
op_name,
traits=(),
alias_kernel_names=(),
promote_trailing_out_tensor=False,
override_arg_types=None,
override_return_types=None,
drop_arg_indices=(),
drop_return_indices=()):
super().__init__()
self.g = g
self.kernel_sig = kernel_sig
self.ods_name = ods_name
self.op_name = op_name
self.traits = list(traits)
self.alias_kernel_names = list(alias_kernel_names)
self.promote_trailing_out_tensor = promote_trailing_out_tensor
self.override_arg_types = override_arg_types
self.override_return_types = override_return_types
self.drop_arg_indices = drop_arg_indices
self.drop_return_indices = drop_return_indices
self.reg_record = g.get_reg_record(self.kernel_sig)
self._emitted = False
self._traceback = traceback.extract_stack()[0:-2]
# Arg and flag transform dicts.
self.arg_type_transforms = dict()
self.arg_flag_transforms = dict()
self.return_type_transforms = dict()
self.return_flag_transforms = dict()
# Signature mapping.
self._sigs_mapped = False
self.ods_ins = None
self.ods_outs = None
self.arg_type_flags = None
self.return_type_flags = None
def __del__(self):
if not self._emitted:
print("WARNING: Op defined but not emitted. Defined at:", file=sys.stderr)
for line in traceback.format_list(self._traceback):
sys.stderr.write(line)
def transforms(self, type_transforms=None, flag_transforms=None):
self.arg_transforms(type_transforms=type_transforms,
flag_transforms=flag_transforms)
self.return_transforms(type_transforms=type_transforms,
flag_transforms=flag_transforms)
return self
def arg_transforms(self, type_transforms=None, flag_transforms=None):
"""Adds arg type and flag transforms dicts."""
if type_transforms:
self.arg_type_transforms.update(type_transforms)
if flag_transforms:
self.arg_flag_transforms.update(flag_transforms)
return self
def return_transforms(self, type_transforms=None, flag_transforms=None):
"""Adds return type and flag transform dicts."""
if type_transforms:
self.return_type_transforms.update(type_transforms)
if flag_transforms:
self.return_flag_transforms.update(flag_transforms)
return self
def map_signatures(self):
assert not self._sigs_mapped, "Signatures already mapped"
self._sigs_mapped = True
self.ods_ins, self.arg_type_flags = self.g._map_sigtypes(
self.reg_record["arguments"],
type_transforms=self.arg_type_transforms,
flag_transforms=self.arg_flag_transforms,
override_types=self.override_arg_types,
drop_indices=self.drop_arg_indices)
self.ods_outs, self.return_type_flags = self.g._map_sigtypes(
self.reg_record["returns"],
type_transforms=self.return_type_transforms,
flag_transforms=self.return_flag_transforms,
override_types=self.override_return_types,
drop_indices=self.drop_return_indices)
return self
def emit(self):
assert not self._emitted, "Op already emitted"
self._emitted = True
if not self._sigs_mapped:
self.map_signatures()
self.g.ods_emitter.emit_opdef(self.ods_name,
self.op_name,
self.reg_record,
ods_ins=self.ods_ins,
ods_outs=self.ods_outs,
traits=self.traits)
self.g.impl_emitter.emit_kernel_methods(
self.ods_name,
self.reg_record,
arg_type_flags=self.arg_type_flags,
return_type_flags=self.return_type_flags,
promote_trailing_out_tensor=self.promote_trailing_out_tensor,
alias_kernel_names=self.alias_kernel_names)
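For illustration, a hedged sketch of driving InflightOpDef directly. It mirrors what ordinary_immutable_op does for aten::_log_softmax above; g is assumed to be an OpGenerator instance:

opdef = g.define_op(
    kernel_sig="aten::_log_softmax(Tensor,int,bool)",
    ods_name="LogSoftmaxOp",
    op_name="log_softmax",
    promote_trailing_out_tensor=True,
    traits=["NoSideEffect"])
opdef.transforms(
    type_transforms={
        "Tensor": "AnyTorchImmutableTensor",
        "int": "AnyTorchIntType",
        "bool": "AnyTorchBoolType",
    },
    flag_transforms={"Tensor": ["kImmutableTensor"]})
# emit() maps the signatures lazily (if map_signatures() was not called first)
# and then writes both the ODS definition and the kernel metadata impl.
opdef.emit()

If an InflightOpDef is created but never emitted, the __del__ hook above prints a warning along with the traceback of where the op was defined.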
class EmitterBase:
_INDENT = " "
@ -355,7 +601,8 @@ class CCImplEmitter(EmitterBase):
reg_record,
arg_type_flags: List[Tuple[str, List[Tuple[str]]]],
return_type_flags: List[Tuple[str, List[Tuple[str]]]],
promote_trailing_out_tensor=False):
promote_trailing_out_tensor=False,
alias_kernel_names: Sequence[str] = ()):
# getTorchKernelMetadata() method.
self.print(
f"Torch::KernelMetadata {ods_def_name}::getTorchKernelMetadata() {{")
@ -374,6 +621,9 @@ class CCImplEmitter(EmitterBase):
with self.indent():
self.print("Torch::BuildKernelMetadata m;")
self.print(f"m.kernelName = {self.quote(kernel_name)};")
for alias_kernel_name in alias_kernel_names:
self.print(
f"m.aliasKernelNames.push_back({self.quote(alias_kernel_name)});")
if promote_trailing_out_tensor:
self.print("m.promoteTrailingOutTensor = true;")
# Arg types/flags.
@ -393,7 +643,7 @@ class CCImplEmitter(EmitterBase):
self.print("return m;")
self.print("})();")
self.print("return metadata;")
self.print("}")
self.print("}\n")
def _format_cpp_str_initlist(self, strings):
quoted = [self.quote(s) for s in strings]
@ -416,6 +666,13 @@ def snakecase_to_camelcase(ident: str):
return "".join(x.capitalize() or "_" for x in re.split(r"[\._]", ident))
def _first_non_none(*args):
for arg in args:
if arg is not None:
return arg
return None
def _load_ops_as_dict():
# Returns a list of dicts, each with a name that is a tuple of the form:
# (kernel_signature, variant)

View File

@ -0,0 +1 @@
ATenOpRegistrations.txt

View File

@ -61,55 +61,6 @@ def aten_ConstantOp: aten_Op<"constant", [NoSideEffect]>,
}
// Our jit library only supports 6 argument convolutions, rather than 9
// arguments supported by pytorch. This operation allows us to represent this
// limitation temporarily.
def aten_ConvolutionOp: aten_Op<"_convolution", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$input,
AnyTensor:$weight,
AnyTensor:$bias,
AnyType:$stride,
AnyType:$padding,
AnyType:$dilation
);
let summary = "Convolution operator";
let description = [{
Convolution operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
uint64_t getOperandTransferVolume(unsigned int idx, bool read);
uint64_t getResultTransferVolume(unsigned int idx, bool read);
}];
}
// Our jit library only supports 6 argument convolutions, rather than 9
// arguments supported by pytorch. This operation allows us to represent this
// limitation temporarily.
def aten_ConvolutionBackwardOp: aten_Op<"_convolution_backward", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor:$dx, AnyTensor:$dw, AnyTensor:$db)> {
let arguments = (
ins AnyTensor:$grad_output,
AnyTensor:$input,
AnyTensor:$weight,
AnyType:$stride,
AnyType:$padding,
AnyType:$dilation
);
let summary = "ConvolutionBackward operator";
let description = [{
ConvolutionBackward operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_FlattenOp: aten_Op<"flatten", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (

View File

@ -37,6 +37,7 @@ const Torch::BuildKernelMetadata &AddOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata Atan2Op::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -55,6 +56,7 @@ const Torch::BuildKernelMetadata &Atan2Op::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata DivOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -73,6 +75,7 @@ const Torch::BuildKernelMetadata &DivOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata FloorDivideOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -91,6 +94,7 @@ const Torch::BuildKernelMetadata &FloorDivideOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata MulOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -109,6 +113,7 @@ const Torch::BuildKernelMetadata &MulOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata RemainderOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -127,6 +132,7 @@ const Torch::BuildKernelMetadata &RemainderOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata TrueDivideOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -145,6 +151,7 @@ const Torch::BuildKernelMetadata &TrueDivideOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
// -----------------------------------------------------------------------------
// Unary arithmetic ops
// -----------------------------------------------------------------------------
@ -167,6 +174,7 @@ const Torch::BuildKernelMetadata &AbsOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata AcosOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -185,6 +193,7 @@ const Torch::BuildKernelMetadata &AcosOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata AngleOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -203,6 +212,7 @@ const Torch::BuildKernelMetadata &AngleOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata AsinOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -221,6 +231,7 @@ const Torch::BuildKernelMetadata &AsinOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata AtanOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -239,6 +250,7 @@ const Torch::BuildKernelMetadata &AtanOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata CeilOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -257,6 +269,7 @@ const Torch::BuildKernelMetadata &CeilOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata ConjOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -275,6 +288,7 @@ const Torch::BuildKernelMetadata &ConjOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata CosOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -293,6 +307,7 @@ const Torch::BuildKernelMetadata &CosOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata CoshOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -311,6 +326,7 @@ const Torch::BuildKernelMetadata &CoshOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata DigammaOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -329,6 +345,7 @@ const Torch::BuildKernelMetadata &DigammaOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata ErfOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -347,6 +364,7 @@ const Torch::BuildKernelMetadata &ErfOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata ErfcOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -365,6 +383,7 @@ const Torch::BuildKernelMetadata &ErfcOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata ErfinvOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -383,6 +402,7 @@ const Torch::BuildKernelMetadata &ErfinvOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata ExpOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -401,6 +421,7 @@ const Torch::BuildKernelMetadata &ExpOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata Expm1Op::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -419,6 +440,7 @@ const Torch::BuildKernelMetadata &Expm1Op::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata FloorOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -437,6 +459,7 @@ const Torch::BuildKernelMetadata &FloorOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata FracOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -455,6 +478,7 @@ const Torch::BuildKernelMetadata &FracOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata LgammaOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -473,6 +497,7 @@ const Torch::BuildKernelMetadata &LgammaOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata LogOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -491,6 +516,7 @@ const Torch::BuildKernelMetadata &LogOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata Log10Op::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -509,6 +535,7 @@ const Torch::BuildKernelMetadata &Log10Op::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata Log1pOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -527,6 +554,7 @@ const Torch::BuildKernelMetadata &Log1pOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata Log2Op::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -545,6 +573,7 @@ const Torch::BuildKernelMetadata &Log2Op::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata NegOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -563,6 +592,7 @@ const Torch::BuildKernelMetadata &NegOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata ReluOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -581,6 +611,7 @@ const Torch::BuildKernelMetadata &ReluOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata ReciprocalOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -599,6 +630,7 @@ const Torch::BuildKernelMetadata &ReciprocalOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata RoundOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -617,6 +649,7 @@ const Torch::BuildKernelMetadata &RoundOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata RsqrtOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -635,6 +668,7 @@ const Torch::BuildKernelMetadata &RsqrtOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata SigmoidOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -653,6 +687,7 @@ const Torch::BuildKernelMetadata &SigmoidOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata SignOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -671,6 +706,7 @@ const Torch::BuildKernelMetadata &SignOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata SinOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -689,6 +725,7 @@ const Torch::BuildKernelMetadata &SinOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata SinhOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -707,6 +744,7 @@ const Torch::BuildKernelMetadata &SinhOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata SqrtOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -725,6 +763,7 @@ const Torch::BuildKernelMetadata &SqrtOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata TanOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -743,6 +782,7 @@ const Torch::BuildKernelMetadata &TanOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata TanhOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -761,6 +801,7 @@ const Torch::BuildKernelMetadata &TanhOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
Torch::KernelMetadata TruncOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
@ -779,3 +820,184 @@ const Torch::BuildKernelMetadata &TruncOp::getTorchBuildKernelMetadata() {
})();
return metadata;
}
// -----------------------------------------------------------------------------
// NN ops
// -----------------------------------------------------------------------------
Torch::KernelMetadata ConvolutionOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &ConvolutionOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::convolution_overrideable";
m.aliasKernelNames.push_back("aten::convolution");
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "Tensor", "Tensor?", "int[]", "int[]", "int[]", "bool", "int[]", "int"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata ConvolutionBackwardOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &ConvolutionBackwardOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::convolution_backward_overrideable";
m.aliasKernelNames.push_back("aten::convolution_backward");
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "Tensor", "Tensor", "int[]", "int[]", "int[]", "bool", "int[]", "int", "bool[]"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone});
m.addReturnTypes({"Tensor?", "Tensor?", "Tensor?"});
m.addReturnConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata LogSoftmaxOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &LogSoftmaxOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::_log_softmax";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "int", "bool"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kNone, KVC::kNone});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata LogSoftmaxBackwardDataOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &LogSoftmaxBackwardDataOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::_log_softmax_backward_data";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "Tensor", "int", "Tensor"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
// -----------------------------------------------------------------------------
// Loss function ops
// -----------------------------------------------------------------------------
Torch::KernelMetadata NllLossForwardOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &NllLossForwardOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::nll_loss_forward";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "Tensor", "Tensor?", "int", "int"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone});
m.addReturnTypes({"Tensor", "Tensor"});
m.addReturnConversions({KVC::kImmutableTensor, KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata NllLossBackwardOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &NllLossBackwardOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::nll_loss_backward";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "Tensor", "Tensor", "Tensor?", "int", "int", "Tensor"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata NllLoss2dForwardOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &NllLoss2dForwardOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::nll_loss2d_forward";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "Tensor", "Tensor?", "int", "int"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone});
m.addReturnTypes({"Tensor", "Tensor"});
m.addReturnConversions({KVC::kImmutableTensor, KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata NllLoss2dBackwardOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &NllLoss2dBackwardOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::nll_loss2d_backward";
m.promoteTrailingOutTensor = true;
m.addArgTypes({"Tensor", "Tensor", "Tensor", "Tensor?", "int", "int", "Tensor"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kImmutableTensor});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kImmutableTensor});
return m;
})();
return metadata;
}
Torch::KernelMetadata CopyInplaceOp::getTorchKernelMetadata() {
return getTorchBuildKernelMetadata();
}
const Torch::BuildKernelMetadata &CopyInplaceOp::getTorchBuildKernelMetadata() {
using KVC = Torch::KernelValueConversion::BitMask;
static Torch::BuildKernelMetadata metadata = ([]() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::copy_";
m.addArgTypes({"Tensor", "Tensor", "bool"});
m.addArgConversions({KVC::kNone, KVC::kImmutableTensor, KVC::kDrop});
m.addReturnTypes({"Tensor"});
m.addReturnConversions({KVC::kDropReturnAndAliasArg0});
return m;
})();
return metadata;
}

View File

@ -450,3 +450,147 @@ def aten_TruncOp: aten_Op<"trunc", [NoSideEffect, DeclareOpInterfaceMethods<Torc
);
}
// -----------------------------------------------------------------------------
// NN ops
// -----------------------------------------------------------------------------
def aten_ConvolutionOp: aten_Op<"convolution", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
let summary = "Recognized op for kernel aten::convolution_overrideable";
let arguments = (ins
AnyTorchImmutableTensor:$input,
AnyTorchImmutableTensor:$weight,
AnyTorchOptionalImmutableTensor:$bias,
AnyTorchIntListType:$stride,
AnyTorchIntListType:$padding,
AnyTorchIntListType:$dilation,
AnyTorchBoolType:$transposed,
AnyTorchIntListType:$output_padding,
AnyTorchIntType:$groups
);
let results = (outs
AnyTorchImmutableTensor
);
}
def aten_ConvolutionBackwardOp: aten_Op<"convolution_backward", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
let summary = "Recognized op for kernel aten::convolution_backward_overrideable";
let arguments = (ins
AnyTorchImmutableTensor:$grad_output,
AnyTorchImmutableTensor:$input,
AnyTorchImmutableTensor:$weight,
AnyTorchIntListType:$stride,
AnyTorchIntListType:$padding,
AnyTorchIntListType:$dilation,
AnyTorchBoolType:$transposed,
AnyTorchIntListType:$output_padding,
AnyTorchIntType:$groups,
AnyTorchBoolListType:$output_mask
);
let results = (outs
AnyTorchOptionalImmutableTensor:$grad_input,
AnyTorchOptionalImmutableTensor:$grad_weight,
AnyTorchOptionalImmutableTensor:$grad_bias
);
}
def aten_LogSoftmaxOp: aten_Op<"log_softmax", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
let summary = "Recognized op for kernel aten::_log_softmax";
let arguments = (ins
AnyTorchImmutableTensor:$self,
AnyTorchIntType:$dim,
AnyTorchBoolType:$half_to_float
);
let results = (outs
AnyTorchImmutableTensor
);
}
def aten_LogSoftmaxBackwardDataOp: aten_Op<"log_softmax_backward_data", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
let summary = "Recognized op for kernel aten::_log_softmax_backward_data";
let arguments = (ins
AnyTorchImmutableTensor:$grad_output,
AnyTorchImmutableTensor:$output,
AnyTorchIntType:$dim,
AnyTorchImmutableTensor:$self
);
let results = (outs
AnyTorchImmutableTensor
);
}
// -----------------------------------------------------------------------------
// Loss function ops
// -----------------------------------------------------------------------------
def aten_NllLossForwardOp: aten_Op<"nll_loss_forward", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
let summary = "Recognized op for kernel aten::nll_loss_forward";
let arguments = (ins
AnyTorchImmutableTensor:$self,
AnyTorchImmutableTensor:$target,
AnyTorchOptionalImmutableTensor:$weight,
AnyTorchIntType:$reduction,
AnyTorchIntType:$ignore_index
);
let results = (outs
AnyTorchImmutableTensor:$output,
AnyTorchImmutableTensor:$total_weight
);
}
def aten_NllLossBackwardOp: aten_Op<"nll_loss_backward", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
let summary = "Recognized op for kernel aten::nll_loss_backward";
let arguments = (ins
AnyTorchImmutableTensor:$grad_output,
AnyTorchImmutableTensor:$self,
AnyTorchImmutableTensor:$target,
AnyTorchOptionalImmutableTensor:$weight,
AnyTorchIntType:$reduction,
AnyTorchIntType:$ignore_index,
AnyTorchImmutableTensor:$total_weight
);
let results = (outs
AnyTorchImmutableTensor
);
}
def aten_NllLoss2dForwardOp: aten_Op<"nll_loss2d_forward", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
let summary = "Recognized op for kernel aten::nll_loss2d_forward";
let arguments = (ins
AnyTorchImmutableTensor:$self,
AnyTorchImmutableTensor:$target,
AnyTorchOptionalImmutableTensor:$weight,
AnyTorchIntType:$reduction,
AnyTorchIntType:$ignore_index
);
let results = (outs
AnyTorchImmutableTensor:$output,
AnyTorchImmutableTensor:$total_weight
);
}
def aten_NllLoss2dBackwardOp: aten_Op<"nll_loss2d_backward", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
let summary = "Recognized op for kernel aten::nll_loss2d_backward";
let arguments = (ins
AnyTorchImmutableTensor:$grad_output,
AnyTorchImmutableTensor:$self,
AnyTorchImmutableTensor:$target,
AnyTorchOptionalImmutableTensor:$weight,
AnyTorchIntType:$reduction,
AnyTorchIntType:$ignore_index,
AnyTorchImmutableTensor:$total_weight
);
let results = (outs
AnyTorchImmutableTensor
);
}
def aten_CopyInplaceOp: aten_Op<"copy.inplace", [DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
let summary = "Recognized op for kernel aten::copy_";
let arguments = (ins
AnyTorchMutableTensor:$self,
AnyTorchImmutableTensor:$src
);
let results = (outs
);
}

View File

@ -44,52 +44,6 @@ def aten_AsStridedOp: aten_Op<"as_strided", [NoSideEffect, StatisticsOpInterface
}];
}
def aten_ConvolutionOverrideableOp: aten_Op<"convolution_overrideable", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$input,
AnyTensor:$weight,
AnyTensor:$bias,
AnyType:$stride,
AnyType:$padding,
AnyType:$dilation,
AnyScalar:$transposed,
AnyType:$output_padding,
AnyScalar:$groups
);
let summary = "aten convolution_overrideable operator";
let description = [{
ConvolutionOverrideableOp
aten convolution_overrideable operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_ConvolutionBackwardOverrideableOp: aten_Op<"convolution_backward_overrideable", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor, AnyTensor, AnyTensor)> {
let arguments = (
ins AnyTensor:$grad_output,
AnyTensor:$input,
AnyTensor:$weight,
AnyType:$stride,
AnyType:$padding,
AnyType:$dilation,
AnyScalar:$transposed,
AnyType:$output_padding,
AnyScalar:$groups
);
let summary = "aten convolution_backward_overrideable operator";
let description = [{
ConvolutionBackwardOverrideableOp
aten convolution_backward_overrideable operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_DivUnderOp: aten_Op<"div_", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
@ -123,35 +77,6 @@ def aten_ExpandOp: aten_Op<"expand", [NoSideEffect, StatisticsOpInterface]>,
}];
}
def aten_LogSoftmaxOp: aten_Op<"_log_softmax", [NoSideEffect]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyScalar:$dim,
AnyScalar:$half_to_float
);
let summary = "aten _log_softmax operator";
let description = [{
LogSoftmaxOp
aten _log_softmax operator
}];
}
def aten_LogSoftmaxBackwardDataOp: aten_Op<"_log_softmax_backward_data", [NoSideEffect]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$grad_output,
AnyTensor:$output,
AnyScalar:$dim,
AnyTensor:$self
);
let summary = "aten _log_softmax_backward_data operator";
let description = [{
LogSoftmaxBackwardDataOp
aten _log_softmax_backward_data operator
}];
}
def aten_MeanOp: aten_Op<"mean", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
@ -445,86 +370,6 @@ def aten_GatherOp: aten_Op<"gather", [NoSideEffect, StatisticsOpInterface]>,
}];
}
def aten_NllLossForwardOp: aten_Op<"nll_loss_forward", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor, AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyTensor:$target,
AnyTensor:$weight,
AnyScalar:$reduction,
AnyScalar:$ignore_index
);
let summary = "aten nll_loss_forward operator";
let description = [{
NllLossForwardOp
aten nll_loss_forward operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_NllLossBackwardOp: aten_Op<"nll_loss_backward", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$grad_output,
AnyTensor:$self,
AnyTensor:$target,
AnyTensor:$weight,
AnyScalar:$reduction,
AnyScalar:$ignore_index,
AnyTensor:$total_weight
);
let summary = "aten nll_loss_backward operator";
let description = [{
NllLossBackwardOp
aten nll_loss_backward operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_NllLoss2dForwardOp: aten_Op<"nll_loss2d_forward", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor, AnyTensor)> {
let arguments = (
ins AnyTensor:$self,
AnyTensor:$target,
AnyTensor:$weight,
AnyScalar:$reduction,
AnyScalar:$ignore_index
);
let summary = "aten nll_loss2d_forward operator";
let description = [{
NllLoss2dForwardOp
aten nll_loss2d_forward operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_NllLoss2dBackwardOp: aten_Op<"nll_loss2d_backward", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (
ins AnyTensor:$grad_output,
AnyTensor:$self,
AnyTensor:$target,
AnyTensor:$weight,
AnyScalar:$reduction,
AnyScalar:$ignore_index,
AnyTensor:$total_weight
);
let summary = "aten nll_loss2d_backward operator";
let description = [{
NllLoss2dBackwardOp
aten nll_loss2d_backward operator
}];
let extraClassDeclaration = [{
std::map<std::string, uint64_t> getStatistics();
}];
}
def aten_HardtanhOp: aten_Op<"hardtanh", [NoSideEffect, StatisticsOpInterface]>,
Results<(outs AnyTensor)> {
let arguments = (

View File

@ -12,15 +12,4 @@
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "npcomp/Dialect/ATen/IR/ATenOps.td"
// The pytorch convolution operator has 9 arguments, but we only have a jit
// library that supports the first six at the moment.
def : Pat<(aten_ConvolutionOverrideableOp $a1, $a2, $a3, $a4, $a5, $a6,
$a7, $a8, $a9),
(aten_ConvolutionOp $a1, $a2, $a3, $a4, $a5, $a6)>;
def : Pat<(aten_ConvolutionBackwardOverrideableOp $a1, $a2, $a3, $a4, $a5, $a6,
$a7, $a8, $a9),
(aten_ConvolutionBackwardOp $a1, $a2, $a3, $a4, $a5, $a6)>;
#endif

View File

@ -18,7 +18,7 @@ namespace Torch {
/// Conversion rule to apply to a value (argument or return).
namespace KernelValueConversion {
enum BitMask {
enum BitMask : uint32_t {
// No coercion.
kNone = 0,
@ -32,7 +32,16 @@ enum BitMask {
// to a 0d tensor.
kPromoteScalar = 8,
LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ kPromoteScalar)
// Drops the return value and aliases to argument 0.
// TODO: Remove this in favor of general alias metadata processing (note that
// the vast majority are this one case so it isn't so bad to have a special
// case for it if necessary).
kDropReturnAndAliasArg0 = 16,
// Drops the argument/return.
kDrop = 32,
LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ kDrop)
};
LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
} // namespace KernelValueConversion
@ -74,6 +83,9 @@ struct BuildKernelMetadata : public KernelMetadata {
SmallVector<KernelValueConversion::BitMask, 4> argConversions;
SmallVector<KernelValueConversion::BitMask, 4> returnConversions;
/// Additional alias kernel names to match.
SmallVector<StringRef, 1> aliasKernelNames;
void addArgConversions(
std::initializer_list<KernelValueConversion::BitMask> ilist) {
argConversions.insert(argConversions.end(), ilist);

View File

@ -72,6 +72,11 @@ def AnyTorchImmutableTensor : AnyTypeOf<[
AnyTensor,
], "allowable torch immutable tensor">;
def AnyTorchOptionalImmutableTensor : AnyTypeOf<[
AnyTorchImmutableTensor,
Basicpy_NoneType,
], "allowable torch immutable tensor (or None)">;
def AnyTorchMutableTensor : AnyTypeOf<[
// "Numpy-style" mutable NDArray. While not offering the full generality
// of a Torch tensor, it models the same access patterns and implies the
@ -95,7 +100,28 @@ def AnyTorchScalarType : AnyTypeOf<[
AnySignlessInteger,
], "Any primitive type suitable to be passed as a Torch Scalar">;
def AnyTorchBoolType : AnyTypeOf<[
I1,
Basicpy_BoolType,
], "Any permissible bool type">;
def AnyTorchBoolListType : AnyTypeOf<[
Basicpy_ListType,
// TODO: Support typed list when available.
], "Any bool list type (bool[])">;
def AnyTorchIntType : AnyTypeOf<[
AnySignedInteger,
AnySignlessInteger,
], "Any primitive integer type suitable to be passed as a Torch 'int'">;
def AnyTorchIntListType : AnyTypeOf<[
Basicpy_ListType,
// TODO: Support typed list when available.
], "Any int list type (int[])">;
def AnyTorchType : AnyTypeOf<[
AnyTorchBoolType,
AnyTorchScalarType,
AnyTorchTensorType,
Basicpy_ListType,

View File

@ -170,40 +170,6 @@ std::map<std::string, uint64_t> BatchNormOp::getStatistics() {
return toReturn;
}
// _convolution
std::map<std::string, uint64_t> ConvolutionOp::getStatistics() {
return getConv2dStatistics(this, /*groups*/ 1);
}
std::map<std::string, uint64_t> ConvolutionOverrideableOp::getStatistics() {
// FIXME
auto co = cast<mlir::NPCOMP::aten::ConstantOp>(groups().getDefiningOp());
auto ia = co.template getAttrOfType<IntegerAttr>("value");
uint64_t groups = ia.getValue().getZExtValue();
return getConv2dStatistics(this, groups);
}
uint64_t ConvolutionOp::getOperandTransferVolume(unsigned int idx, bool read) {
return getConv2dOperandTransferVolume<ConvolutionOp>(this, idx, read);
}
uint64_t ConvolutionOp::getResultTransferVolume(unsigned int idx, bool write) {
return getConv2dResultTransferVolume<ConvolutionOp>(this, idx, write);
}
// _convolution_backward
std::map<std::string, uint64_t> ConvolutionBackwardOp::getStatistics() {
return getConv2dBackwardStatistics(*this, 1);
}
std::map<std::string, uint64_t>
ConvolutionBackwardOverrideableOp::getStatistics() {
auto co = cast<mlir::NPCOMP::aten::ConstantOp>(groups().getDefiningOp());
auto ia = co.template getAttrOfType<IntegerAttr>("value");
uint64_t groups = ia.getValue().getZExtValue();
return getConv2dBackwardStatistics(*this, groups);
}
// div_
std::map<std::string, uint64_t> DivUnderOp::getStatistics() {
@ -559,35 +525,6 @@ std::map<std::string, uint64_t> NativeBatchNormBackwardOp::getStatistics() {
return toReturn;
}
std::map<std::string, uint64_t> NllLossForwardOp::getStatistics() {
std::map<std::string, uint64_t> toReturn;
// FIXME: unimplemented
toReturn["reads"] = -1;
toReturn["writes"] = -1;
return toReturn;
}
std::map<std::string, uint64_t> NllLossBackwardOp::getStatistics() {
std::map<std::string, uint64_t> toReturn;
// FIXME: unimplemented
toReturn["reads"] = -1;
toReturn["writes"] = -1;
return toReturn;
}
std::map<std::string, uint64_t> NllLoss2dForwardOp::getStatistics() {
std::map<std::string, uint64_t> toReturn;
// FIXME: unimplemented
toReturn["reads"] = -1;
toReturn["writes"] = -1;
return toReturn;
}
std::map<std::string, uint64_t> NllLoss2dBackwardOp::getStatistics() {
std::map<std::string, uint64_t> toReturn;
// FIXME: unimplemented
toReturn["reads"] = -1;
toReturn["writes"] = -1;
return toReturn;
}
// std::map<std::string, uint64_t> ReLUUnderOp::getStatistics() {
// return getReLUOpStatistics(*this);
// }

View File

@ -12,6 +12,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
#include "npcomp/Dialect/ATen/Transforms/Passes.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
#include "npcomp/Dialect/Torch/IR/OpInterfaces.h"
@ -28,6 +29,14 @@ using namespace mlir::NPCOMP::Torch;
namespace {
bool isTorchTensorType(StringRef torchType) {
return torchType == "Tensor" || torchType == "Tensor?";
}
bool isTorchOptionalType(StringRef torchType) {
return torchType.endswith("?");
}
struct TypeConversion {
Type targetType;
std::function<Value(Location loc, Value originalValue,
@ -49,9 +58,15 @@ convertTorchArgType(StringRef sourceTorchType, StringRef targetTorchType,
// Immutable tensor conversion.
if (flag & KVC::kImmutableTensor) {
// TODO: Support the kPromoteScalar flag.
if (sourceTorchType != "Tensor" || targetTorchType != "Tensor")
if (!isTorchTensorType(sourceTorchType) ||
!isTorchTensorType(targetTorchType))
return None;
// If the target is optional and the type is NoneType, passthrough.
if (isTorchOptionalType(targetTorchType) &&
sourceMlirType.isa<Basicpy::NoneType>())
return TypeConversion{sourceMlirType, nullptr};
// Already immutable.
if (sourceMlirType.isa<TensorType>())
return TypeConversion{sourceMlirType, nullptr};
@ -86,30 +101,51 @@ convertTorchReturnType(StringRef sourceTorchType, StringRef targetTorchType,
Type sourceMlirType) {
using KVC = KernelValueConversion::BitMask;
// Default trivial case.
if (sourceTorchType == targetTorchType && flag == 0)
if (sourceTorchType == targetTorchType && flag == 0) {
LLVM_DEBUG(llvm::dbgs() << " * Return types already match\n");
return TypeConversion{sourceMlirType, nullptr};
}
// Immutable tensor conversion.
if (flag & KVC::kImmutableTensor) {
if (sourceTorchType != "Tensor" || targetTorchType != "Tensor")
LLVM_DEBUG(llvm::dbgs()
<< " * Return conversion flag kImmutableTensor\n");
if (!isTorchTensorType(sourceTorchType) ||
!isTorchTensorType(targetTorchType)) {
LLVM_DEBUG(llvm::dbgs()
<< " * Source or target not a Tensor type\n");
return None;
}
// Already immutable.
if (sourceMlirType.isa<TensorType>())
if (sourceMlirType.isa<TensorType>()) {
LLVM_DEBUG(llvm::dbgs() << " * Source is already immutable\n");
return TypeConversion{sourceMlirType, nullptr};
}
// Convert NdArray type.
if (auto ndArrayType = sourceMlirType.dyn_cast<Numpy::NdArrayType>()) {
if (sourceMlirType.isa<Basicpy::NoneType>() &&
isTorchOptionalType(targetTorchType)) {
LLVM_DEBUG(llvm::dbgs() << " * None Tensor type passthrough\n");
return TypeConversion{sourceMlirType, nullptr};
} else if (auto ndArrayType =
sourceMlirType.dyn_cast<Numpy::NdArrayType>()) {
auto tensorType = ndArrayType.toTensorType();
auto callback = [=](Location loc, Value newOpResultValue,
PatternRewriter &rewriter) -> Value {
return rewriter.create<Numpy::CreateArrayFromTensorOp>(
loc, ndArrayType, newOpResultValue);
};
LLVM_DEBUG(llvm::dbgs() << " * Convert return type\n");
return TypeConversion{tensorType, callback};
} else {
LLVM_DEBUG(llvm::dbgs()
<< " * Return type is not a supported tensor type\n");
return None;
}
}
LLVM_DEBUG(llvm::dbgs() << " * Return type conversion fallthrough\n");
return None;
}
@ -142,9 +178,16 @@ public:
const BuildKernelMetadata &buildMetadata) {
LLVM_DEBUG(llvm::dbgs()
<< "Register kernel call translation for: " << opName << "\n");
CandidateTransformList &candidates =
kernelTransforms[buildMetadata.kernelName];
candidates.emplace_back(opName, buildMetadata);
{
CandidateTransformList &candidates =
kernelTransforms[buildMetadata.kernelName];
candidates.emplace_back(opName, buildMetadata);
}
for (StringRef aliasKernelName : buildMetadata.aliasKernelNames) {
CandidateTransformList &candidates = kernelTransforms[aliasKernelName];
candidates.emplace_back(opName, buildMetadata);
}
}
LogicalResult transformKernelCall(KernelCallOp kernelCall,
@ -229,6 +272,8 @@ public:
"arg arity mismatch");
// Convert fixed return types.
using PostConversionCallback = std::function<void()>;
SmallVector<PostConversionCallback, 4> postConversionCallbacks;
struct ConversionInfo {
Value originalValue;
TypeConversion conversion;
@ -241,25 +286,49 @@ public:
KVC flag = candidate.buildMetadata.getReturnConversion(i);
Value sourceValue = kernelCall.getResult(i);
Type sourceMlirType = kernelCall.getResultTypes()[i];
auto conversion = convertTorchReturnType(sourceTorchType, targetTorchType,
flag, sourceMlirType);
if (!conversion) {
LLVM_DEBUG(llvm::dbgs() << " - Return type[" << i
<< "] incompatible: source=" << sourceTorchType
<< ", target=" << targetTorchType
<< ", flag=" << flag << "\n");
return failure();
if (flag & KVC::kDropReturnAndAliasArg0) {
// Reduce result arity and alias any uses to arg0.
if (kernelCall.args().empty()) {
LLVM_DEBUG(llvm::dbgs()
<< " - Cannot alias arg0 (no arguments)\n");
return failure();
}
Value arg0 = kernelCall.args()[0];
postConversionCallbacks.push_back(
[sourceValue, arg0]() { sourceValue.replaceAllUsesWith(arg0); });
} else {
// General, arity-preserving type conversion.
auto conversion = convertTorchReturnType(
sourceTorchType, targetTorchType, flag, sourceMlirType);
if (!conversion) {
LLVM_DEBUG(llvm::dbgs()
<< " - Return type[" << i << "] incompatible: source="
<< sourceTorchType << ", target=" << targetTorchType
<< ", flag=" << flag << "\n");
return failure();
}
resultTypes.push_back(conversion->targetType);
resultConversions.push_back({sourceValue, std::move(*conversion)});
}
resultTypes.push_back(conversion->targetType);
resultConversions.push_back({sourceValue, std::move(*conversion)});
}
// Convert fixed arg types.
SmallVector<ConversionInfo, 4> operandInfos;
for (size_t i = 0; i < fixedArgArity; ++i) {
for (size_t i = 0, operandIndex = 0; i < fixedArgArity; ++i) {
// Drop this arg?
if (candidate.buildMetadata.argConversions[i] & KVC::kDrop)
continue;
if (kernelCall.getNumOperands() <= operandIndex) {
LLVM_DEBUG(llvm::dbgs()
<< " - Arg operand " << i
<< " does not exist in kernel call (missing default?)\n");
return failure();
}
// Normal type conversion of the operand.
operandInfos.emplace_back();
ConversionInfo &info = operandInfos.back();
info.originalValue = kernelCall.getOperand(i);
info.originalValue = kernelCall.getOperand(operandIndex++);
Type sourceMlirType = info.originalValue.getType();
auto conversion = convertTorchArgType(
/*sourceTorchType=*/sourceMetadata.argTypes[i],
@ -312,6 +381,10 @@ public:
origOpResultValue.replaceAllUsesWith(convertedValue);
}
// Post conversion callbacks.
for (auto &callback : postConversionCallbacks)
callback();
// Done.
rewriter.eraseOp(kernelCall);
return success();

View File

@ -1,25 +0,0 @@
// RUN: npcomp-opt %s -aten-layer-name -aten-op-report |& FileCheck %s
// CHECK-LABEL: "L0-_convolution-0": {
// CHECK-NEXT: "activation_in": 32768,
// CHECK-NEXT: "activation_out": 65536,
// CHECK-NEXT: "ops:+": 65536,
// CHECK-NEXT: "ops:MAC": 6422528,
// CHECK-NEXT: "parameters_in": 1584,
// CHECK-NEXT: "reads": 34352,
// CHECK-NEXT: "writes": 65536
module {
func @graph(%arg0: tensor<1x2x128x128xf32>, %arg1: tensor<16x2x7x7xf32>, %arg2: tensor<16xf32>) -> tensor<1x16x64x64xf32> {
%0 = "aten.constant"() {type = "List[i32]", value = dense<2> : vector<2xi64>} : () -> !aten.list<i32>
%1 = "aten.constant"() {type = "List[i32]", value = dense<3> : vector<2xi64>} : () -> !aten.list<i32>
%2 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi64>} : () -> !aten.list<i32>
%3 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
%4 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi64>} : () -> !aten.list<i32>
%5 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32
%6 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
%7 = "aten.constant"() {type = "bool", value = 0 : i1} : () -> i1
%8 = "aten.constant"() {type = "bool", value = 1 : i1} : () -> i1
%9 = "aten._convolution"(%arg0, %arg1, %arg2, %0, %1, %2) : (tensor<1x2x128x128xf32>, tensor<16x2x7x7xf32>, tensor<16xf32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>) -> tensor<1x16x64x64xf32>
"std.return"(%9) : (tensor<1x16x64x64xf32>) -> ()
}
}

@ -1,23 +0,0 @@
// RUN: npcomp-opt %s -aten-layer-name -aten-op-report |& FileCheck %s
// CHECK-LABEL: "L0-convolution_backward_overrideable-0": {
// CHECK-NEXT: "activation_in": 5568,
// CHECK-NEXT: "grad": 5380,
// CHECK-NEXT: "ops:+": 768,
// CHECK-NEXT: "ops:MAC": 345600,
// CHECK-NEXT: "parameters_in": 576,
// CHECK-NEXT: "reads": 6144,
// CHECK-NEXT: "writes": 5380
// CHECK-NEXT: }
// RUN: npcomp-opt %s -aten-to-std |& FileCheck %s --check-prefix=CHECK-CONVERSION
// CHECK-CONVERSION-LABEL: @graph
module {
func @graph(%arg0: tensor<3x4x8x8xf32>, %arg1: tensor<3x16x10x10xf32>, %arg2: tensor<4x16x3x3xf32>) -> tensor<4x16x3x3xf32> {
%0 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list<i32>
%1 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi32>} : () -> !aten.list<i32>
%2 = "aten.constant"() {type = "bool", value = false} : () -> i1
%3 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32
%10:3 = "aten.convolution_backward_overrideable"(%arg0, %arg1, %arg2, %0, %1, %0, %2, %1, %3) {layer_name = "L5-convolution_backward_overrideable-0"} : (tensor<3x4x8x8xf32>, tensor<3x16x10x10xf32>, tensor<4x16x3x3xf32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>, i1, !aten.list<i32>, i32) -> (tensor<3x16x10x10xf32>, tensor<4x16x3x3xf32>, tensor<4xf32>)
return %10#1 : tensor<4x16x3x3xf32>
}
}

@ -1,26 +0,0 @@
// RUN: npcomp-opt %s -aten-to-std |& FileCheck %s --check-prefix=CHECK-CONVERSION
// CHECK-CONVERSION-LABEL: @graph
module {
func @graph(%arg0: tensor<10xf32>, %arg1: tensor<128xf32>, %arg2: tensor<4x1x28x28xf32>, %arg3: tensor<32x1x3x3xf32>, %arg4: tensor<32xf32>, %arg5: tensor<64x32x3x3xf32>, %arg6: tensor<64xf32>, %arg7: tensor<128x9216xf32>, %arg8: tensor<10x128xf32>) -> tensor<4x10xf32> {
%0 = "aten.constant"() {type = "List[i32]", value = dense<1> : vector<2xi32>} : () -> !aten.list<i32>
%1 = "aten.constant"() {type = "List[i32]", value = dense<0> : vector<2xi32>} : () -> !aten.list<i32>
%2 = "aten.constant"() {type = "bool", value = false} : () -> i1
%3 = "aten.constant"() {type = "i32", value = 1 : i32} : () -> i32
%4 = "aten.constant"() {type = "bool", value = true} : () -> i1
%5 = "aten._convolution"(%arg2, %arg3, %arg4, %0, %1, %0) {layer_name = "L0-_convolution-0"} : (tensor<4x1x28x28xf32>, tensor<32x1x3x3xf32>, tensor<32xf32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>) -> tensor<4x32x26x26xf32>
%6 = "aten.relu"(%5) {layer_name = "L1-relu-0"} : (tensor<4x32x26x26xf32>) -> tensor<4x32x26x26xf32>
%7 = "aten._convolution"(%6, %arg5, %arg6, %0, %1, %0) {layer_name = "L2-_convolution-1"} : (tensor<4x32x26x26xf32>, tensor<64x32x3x3xf32>, tensor<64xf32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>) -> tensor<4x64x24x24xf32>
%8 = "aten.constant"() {type = "List[i32]", value = dense<2> : vector<2xi32>} : () -> !aten.list<i32>
%9:2 = "aten.max_pool2d_with_indices"(%7, %8, %8, %1, %0, %2) {layer_name = "L3-max_pool2d_with_indices-0"} : (tensor<4x64x24x24xf32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>, !aten.list<i32>, i1) -> (tensor<4x64x12x12xf32>, tensor<4x64x12x12xi64>)
%10 = "aten.constant"() {type = "List[i32]", value = dense<[4, 9216]> : vector<2xi32>} : () -> !aten.list<i32>
%11 = "aten.view"(%9#0, %10) {layer_name = "L4-view-0"} : (tensor<4x64x12x12xf32>, !aten.list<i32>) -> tensor<4x9216xf32>
%12 = "aten.t"(%arg7) {layer_name = "L5-t-0"} : (tensor<128x9216xf32>) -> tensor<9216x128xf32>
%13 = "aten.addmm"(%arg1, %11, %12, %3, %3) {layer_name = "L6-addmm-0"} : (tensor<128xf32>, tensor<4x9216xf32>, tensor<9216x128xf32>, i32, i32) -> tensor<4x128xf32>
%14 = "aten.relu"(%13) {layer_name = "L7-relu-1"} : (tensor<4x128xf32>) -> tensor<4x128xf32>
%15 = "aten.t"(%arg8) {layer_name = "L8-t-1"} : (tensor<10x128xf32>) -> tensor<128x10xf32>
%16 = "aten.addmm"(%arg0, %14, %15, %3, %3) {layer_name = "L9-addmm-1"} : (tensor<10xf32>, tensor<4x128xf32>, tensor<128x10xf32>, i32, i32) -> tensor<4x10xf32>
%17 = "aten._log_softmax"(%16, %3, %2) {layer_name = "L10-_log_softmax-0"} : (tensor<4x10xf32>, i32, i1) -> tensor<4x10xf32>
return %17 : tensor<4x10xf32>
}
}

@ -13,3 +13,86 @@ func @graph(%arg0: !numpy.ndarray<*:?>, %arg1 : !numpy.ndarray<*:?>, %arg2 : si6
// CHECK: return %[[RESULT_MUT]]
return %0 : !numpy.ndarray<*:?>
}
// -----
// CHECK-LABEL: func @nll_loss2d_forward
// Contains a Tensor? type mapped to None.
func @nll_loss2d_forward(
%arg0: !numpy.ndarray<[3,4,8,8]:f32>,
%arg1: !numpy.ndarray<[3,8,8]:i64>,
%arg2: !basicpy.NoneType,
%arg3: i64,
%arg4: i64) -> (!numpy.ndarray<[]:f32>, !numpy.ndarray<[]:f32>) {
// CHECK: %[[TARG0:.*]] = numpy.copy_to_tensor %arg0
// CHECK: %[[TARG1:.*]] = numpy.copy_to_tensor %arg1
// CHECK: %[[TOUTPUT:.*]], %[[TTOTAL_WEIGHT:.*]] = "aten.nll_loss2d_forward"(%[[TARG0]], %[[TARG1]], %arg2, %arg3, %arg4) : (tensor<3x4x8x8xf32>, tensor<3x8x8xi64>, !basicpy.NoneType, i64, i64) -> (tensor<f32>, tensor<f32>)
// CHECK: %[[AOUTPUT:.*]] = numpy.create_array_from_tensor %[[TOUTPUT]]
// CHECK: %[[ATOTAL_WEIGHT:.*]] = numpy.create_array_from_tensor %[[TTOTAL_WEIGHT]]
%0:2 = torch.kernel_call "aten::nll_loss2d_forward"
%arg0, %arg1, %arg2, %arg3, %arg4 :
(!numpy.ndarray<[3,4,8,8]:f32>, !numpy.ndarray<[3,8,8]:i64>, !basicpy.NoneType, i64, i64) ->
(!numpy.ndarray<[]:f32>, !numpy.ndarray<[]:f32>)
{sigArgTypes = ["Tensor", "Tensor", "Tensor?", "int", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor", "Tensor"]}
// CHECK: return %[[AOUTPUT]], %[[ATOTAL_WEIGHT]]
return %0#0, %0#1 : !numpy.ndarray<[]:f32>, !numpy.ndarray<[]:f32>
}
// -----
// CHECK-LABEL: func @convolution
// Contains Tensor?, bool, int, and list types.
func @convolution(
%arg0: !numpy.ndarray<[3,16,10,10]:f32>, %arg1: !numpy.ndarray<[4,16,3,3]:f32>,
%arg2: !numpy.ndarray<[4]:f32>, %arg3: !basicpy.ListType, %arg4: !basicpy.ListType,
%arg5: !basicpy.ListType, %arg6: i1, %arg7: !basicpy.ListType, %arg8: i64) -> !numpy.ndarray<[3,4,8,8]:f32> {
// CHECK: %[[TARG0:.*]] = numpy.copy_to_tensor %arg0
// CHECK: %[[TARG1:.*]] = numpy.copy_to_tensor %arg1
// CHECK: %[[TARG2:.*]] = numpy.copy_to_tensor %arg2
// CHECK: %[[TRESULT:.*]] = "aten.convolution"(%[[TARG0]], %[[TARG1]], %[[TARG2]], %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (tensor<3x16x10x10xf32>, tensor<4x16x3x3xf32>, tensor<4xf32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i1, !basicpy.ListType, i64) -> tensor<3x4x8x8xf32>
%0 = torch.kernel_call "aten::convolution"
%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8 :
(!numpy.ndarray<[3,16,10,10]:f32>, !numpy.ndarray<[4,16,3,3]:f32>,
!numpy.ndarray<[4]:f32>, !basicpy.ListType, !basicpy.ListType,
!basicpy.ListType, i1, !basicpy.ListType, i64) -> !numpy.ndarray<[3,4,8,8]:f32>
{sigArgTypes = ["Tensor", "Tensor", "Tensor?", "int[]", "int[]", "int[]", "bool", "int[]", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
return %0 : !numpy.ndarray<[3,4,8,8]:f32>
}
// -----
// CHECK-LABEL: func @convolution_backward
// Interesting because it has optional tensor returns.
func @convolution_backward(
%arg0: !numpy.ndarray<[3,4,8,8]:f32>,
%arg1: !numpy.ndarray<[3,16,10,10]:f32>,
%arg2: !numpy.ndarray<[4,16,3,3]:f32>,
%arg3: !basicpy.ListType,
%arg4: !basicpy.ListType,
%arg5: !basicpy.ListType,
%arg6: i1,
%arg7: !basicpy.ListType,
%arg8: i64,
%arg9: !basicpy.ListType) -> (!basicpy.NoneType, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>) {
// CHECK: %[[GRAD_INPUT:.*]], %[[GRAD_WEIGHT:.*]], %[[GRAD_BIAS:.*]] = "aten.convolution_backward"
  // Note that this kernel call masks out the input gradient, which is returned as NoneType.
%0:3 = torch.kernel_call "aten::convolution_backward"
%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9 :
(!numpy.ndarray<[3,4,8,8]:f32>, !numpy.ndarray<[3,16,10,10]:f32>, !numpy.ndarray<[4,16,3,3]:f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i1, !basicpy.ListType, i64, !basicpy.ListType) ->
(!basicpy.NoneType, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>) {sigArgTypes = ["Tensor", "Tensor", "Tensor", "int[]", "int[]", "int[]", "bool", "int[]", "int", "bool[]"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor", "Tensor", "Tensor"]}
// CHECK: %[[AGRAD_WEIGHT:.*]] = numpy.create_array_from_tensor %[[GRAD_WEIGHT]]
// CHECK: %[[AGRAD_BIAS:.*]] = numpy.create_array_from_tensor %[[GRAD_BIAS]]
  // Key point: the return passes through the raw NoneType from the masked input gradient
  // without converting it to an array.
// CHECK: return %[[GRAD_INPUT]], %[[AGRAD_WEIGHT]], %[[AGRAD_BIAS]]
return %0#0, %0#1, %0#2 : !basicpy.NoneType, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>
}
// -----
// CHECK-LABEL: func @copy_inplace
// Mutable/in-place op conversion; the kernel result is dropped and aliased to the in-place argument.
func @copy_inplace(%arg0: !numpy.ndarray<[4]:f32>, %arg1: !numpy.ndarray<[4]:f32>) -> !numpy.ndarray<[4]:f32> {
// CHECK: %[[TARG1:.*]] = numpy.copy_to_tensor %arg1
// CHECK: "aten.copy.inplace"(%arg0, %[[TARG1]]) : (!numpy.ndarray<[4]:f32>, tensor<4xf32>) -> ()
%0 = torch.kernel_call "aten::copy_" %arg0, %arg1 : (!numpy.ndarray<[4]:f32>, !numpy.ndarray<[4]:f32>) -> !numpy.ndarray<[4]:f32> {sigArgTypes = ["Tensor", "Tensor", "bool"], sigIsMutable = true, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
// CHECK: return %arg0
return %0 : !numpy.ndarray<[4]:f32>
}