Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions include/tvm/runtime/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

#include <atomic>
#include <functional>
#include <string>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -188,14 +189,25 @@ class Tensor : public tvm::ffi::Tensor {
*/
TVM_DLL static void CopyFromBytes(const DLTensor* to, void* from, size_t nbytes,
TVMStreamHandle stream = nullptr);

TVM_DLL void SetScope(ffi::String scope);
TVM_DLL ffi::String GetScope() const;

protected:
/*!
* \brief The memory scope
* represents the underlying scope information of the device
*/
ffi::String scope = "global";
};

/*!
* \brief Save a DLTensor to stream
* \param strm The output stream
* \param tensor The tensor to be saved.
* \param scope The tensor storage scope.
*/
inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor);
inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor, ffi::String scope = "global");

inline void Tensor::CopyFrom(const DLTensor* other) {
ICHECK(data_ != nullptr);
Expand All @@ -220,10 +232,11 @@ inline void Tensor::CopyTo(const Tensor& other) const {
}

/*! \brief Magic number for Tensor file */
constexpr uint64_t kTVMTensorMagic = 0xDD5E40F096B4A13F;
constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F;
constexpr uint64_t kTVMNDArrayScopedMagic = 0xDD5E40F096B4A13E;

inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor) {
uint64_t header = kTVMTensorMagic, reserved = 0;
inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor, ffi::String scope) {
uint64_t header = kTVMNDArrayScopedMagic, reserved = 0;
strm->Write(header);
strm->Write(reserved);
// Always save data as CPU context
Expand All @@ -243,6 +256,7 @@ inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor) {
strm->Write(tensor->dtype);
int ndim = tensor->ndim;
strm->WriteArray(tensor->shape, ndim);
strm->Write(std::string(scope));
int type_bytes = (tensor->dtype.bits + 7) / 8;
int64_t num_elems = 1;
for (int i = 0; i < ndim; ++i) {
Expand All @@ -266,13 +280,14 @@ inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor) {
return true;
}

inline void Tensor::Save(dmlc::Stream* strm) const { SaveDLTensor(strm, operator->()); }
inline void Tensor::Save(dmlc::Stream* strm) const { SaveDLTensor(strm, operator->(), GetScope()); }

inline bool Tensor::Load(dmlc::Stream* strm) {
uint64_t header, reserved;
ICHECK(strm->Read(&header)) << "Invalid DLTensor file format";
ICHECK(strm->Read(&reserved)) << "Invalid DLTensor file format";
ICHECK(header == kTVMTensorMagic) << "Invalid DLTensor file format";
ICHECK((header == kTVMNDArrayMagic) || (header == kTVMNDArrayScopedMagic))
<< "Invalid DLTensor file format";
Device dev;
int ndim;
DLDataType dtype;
Expand All @@ -290,6 +305,11 @@ inline bool Tensor::Load(dmlc::Stream* strm) {
for (int i = 0; i < ret->ndim; ++i) {
num_elems *= ret->shape[i];
}
if (header == kTVMNDArrayScopedMagic) {
std::string scope;
strm->Read(&scope);
ret.SetScope(scope);
}
int64_t data_byte_size;
ICHECK(strm->Read(&data_byte_size)) << "Invalid DLTensor file format";
ICHECK(data_byte_size == num_elems * elem_bytes) << "Invalid DLTensor file format";
Expand Down
6 changes: 6 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,12 @@ TVM_DLL Pass DefaultGPUSchedule();
*/
TVM_DLL Pass UseAssumeToReduceBranches();

/*!
* \brief Inject Texture Allocation intrinsic.
* \return The pass.
*/
TVM_DLL Pass InjectTextureAlloc();

} // namespace transform
} // namespace tir
} // namespace tvm
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/dlight/adreno/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,8 @@
Adreno schedule rules.
"""
from .convolution import Conv2d
from .layout_transform import LayoutTransform
from .fallback import Fallback
from .pool import Pool2D

# from .fallback import Fallback
249 changes: 63 additions & 186 deletions python/tvm/dlight/adreno/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,215 +16,92 @@
# under the License.
# pylint: disable=missing-docstring, invalid-name
"""A Conv2d schedule rule for Adreno GPU operators."""
from dataclasses import dataclass
from typing import List, Optional
from typing import Optional, Union

from tvm import tir
from tvm.target import Target
from tvm.tir import IterVar
from tvm.tir.schedule.schedule import BlockRV

from ..analysis import BlockInfo, IterInfo
from .utils import schedule_inline_blocks, schedule_storage_annotate, schedule_default
from .. import analysis
from .base import AdrenoScheduleRule


def is_spatial_block(sch: tir.Schedule, block: BlockRV) -> bool:
    """Tell whether ``block`` carries data-parallel iteration variables only.

    Parameters
    ----------
    sch : tir.Schedule
        Schedule used to resolve ``block`` into its block statement.
    block : BlockRV
        The block to inspect.

    Returns
    -------
    bool
        True when the set of iter types collected from the block's iter vars
        is exactly ``{IterVar.DataPar}`` (so a block without iter vars yields False).
    """
    observed_kinds = set()
    for axis in sch.get(block).iter_vars:
        observed_kinds.add(axis.iter_type)
    return observed_kinds == {IterVar.DataPar}


def is_reduction_block(sch: tir.Schedule, block: BlockRV) -> bool:
    """Tell whether ``block`` mixes reduction and data-parallel iteration variables.

    Parameters
    ----------
    sch : tir.Schedule
        Schedule used to resolve ``block`` into its block statement.
    block : BlockRV
        The block to inspect.

    Returns
    -------
    bool
        True when the set of iter types collected from the block's iter vars
        is exactly ``{IterVar.CommReduce, IterVar.DataPar}``; both kinds must
        be present and no other kind may appear.
    """
    expected_kinds = {IterVar.CommReduce, IterVar.DataPar}
    observed_kinds = set()
    for axis in sch.get(block).iter_vars:
        observed_kinds.add(axis.iter_type)
    return observed_kinds == expected_kinds


def _collect_producers(sch: tir.Schedule, block: tir.schedule.BlockRV):
    """Gather the direct and transitive producers of ``block``.

    Each direct producer is listed first, immediately followed by everything
    collected recursively from it. No de-duplication is performed.
    """
    upstream = []
    for producer in sch.get_producers(block):
        upstream += [producer, *_collect_producers(sch, producer)]
    return upstream


def _collect_consumers(sch: tir.Schedule, block: tir.schedule.BlockRV):
    """Gather the direct and transitive consumers of ``block``.

    Each direct consumer is listed first, immediately followed by everything
    collected recursively from it. No de-duplication is performed.
    """
    downstream = []
    for consumer in sch.get_consumers(block):
        downstream += [consumer, *_collect_consumers(sch, consumer)]
    return downstream


def get_block_info(sch: tir.Schedule, block: tir.schedule.BlockRV) -> BlockInfo:
    """Summarise ``block`` as a ``BlockInfo`` record.

    Parameters
    ----------
    sch : tir.Schedule
        Schedule that owns ``block``.
    block : tir.schedule.BlockRV
        The block to describe.

    Returns
    -------
    BlockInfo
        Carries the block's ``name_hint``, one ``IterInfo`` per (loop, iter var)
        pair obtained by zipping ``sch.get_loops(block)`` with the block's iter
        vars, the block itself, and a flag telling whether any iter var is a
        reduction ("R") axis.
    """
    # "S" for data-parallel, "R" for commutative reduction, "O" for anything else.
    kind_by_iter_type = {tir.IterVar.DataPar: "S", tir.IterVar.CommReduce: "R"}

    def _iter_kind(loop: tir.IterVar) -> str:
        return kind_by_iter_type.get(loop.iter_type, "O")

    block_stmt = sch.get(block)

    iter_infos = []
    for loop_rv, iter_var in zip(sch.get_loops(block), block_stmt.iter_vars):
        iter_infos.append(
            IterInfo(
                kind=_iter_kind(iter_var),
                var=iter_var.var,
                dom=iter_var.dom.extent,
                loop_rv=loop_rv,
            )
        )

    # The reduction flag looks at every iter var, not only those paired with a loop.
    has_reduction_axis = any(_iter_kind(iter_var) == "R" for iter_var in block_stmt.iter_vars)

    return BlockInfo(
        name=block_stmt.name_hint,
        iters=iter_infos,
        block_rv=block,
        reduction_block=has_reduction_axis,
    )


def get_reduction_blocks(
    sch: tir.Schedule, blocks: List[tir.schedule.BlockRV]
) -> Optional[tir.schedule.BlockRV]:
    """Return the single reduction block among ``blocks``, if there is exactly one.

    Parameters
    ----------
    sch : tir.Schedule
        Schedule that owns ``blocks``.
    blocks : List[tir.schedule.BlockRV]
        Candidate blocks; each must be either a reduction block or a spatial block.

    Returns
    -------
    Optional[tir.schedule.BlockRV]
        The unique reduction block, or None when some block is neither
        reduction nor spatial, or when the number of reduction blocks is not one.
    """
    # NOTE: We assume there is only one reduction block in the function.
    # Classify every block once: reduction blocks are collected, and any block
    # that is neither reduction nor spatial disqualifies the whole function.
    reduction_blocks = []
    for block in blocks:
        if is_reduction_block(sch, block):
            reduction_blocks.append(block)
        elif not is_spatial_block(sch, block):
            return None

    # There must be exactly one reduction block.
    if len(reduction_blocks) != 1:
        return None
    return reduction_blocks[0]


def is_convolution(sch: tir.Schedule, block: tir.schedule.BlockRV):
    """Name- and iterator-pattern-based test for a conv2d_NCHWc_OIHWo block.

    Returns the (falsy) occurrence count ``0`` when the block's name hint does
    not contain "conv2d_NCHWc_OIHWo"; otherwise returns whether the block's
    iterator kinds, concatenated in order, spell "SSSSSRRR".
    """
    # TODO: Use buffer access patterns to discover convolution type kernels instead of using name.
    name_hits = sch.get(block).name_hint.count("conv2d_NCHWc_OIHWo")
    if not name_hits:
        return name_hits
    iter_pattern = "".join(info.kind for info in get_block_info(sch, block).iters)
    return iter_pattern == "SSSSSRRR"


class Conv2d(AdrenoScheduleRule):
"""The schedule rule for convolution computation"""

# Bundle of scheduling knobs for Conv2d; the values below are the defaults used
# when no field is overridden.
# NOTE(review): the exact meaning of each knob is not visible in this hunk; confirm
# against the code that consumes the config before relying on the names alone.
@dataclass
class Config:
block_size_x: int = 8
block_size_y: int = 8
vector_size: int = 1
unroll: int = 256 # 0 means no unroll
use_shared: bool = True
storage_align: bool = False
inner_x: bool = False

def get_configs(self, target: Target) -> Config:
    """Pick the schedule config matching ``target``.

    "cuda" and "rocm" targets share one preset. An "opencl" target whose host
    string contains "android" or whose attrs string contains "adreno" gets a
    second preset. Every other target receives the default ``Conv2d.Config()``.
    """
    kind_name = target.kind.name

    if kind_name in ("cuda", "rocm"):
        return Conv2d.Config(
            block_size_x=8,
            block_size_y=16,
            vector_size=2,
            unroll=256,
            use_shared=True,
            storage_align=True,
            inner_x=False,
        )

    if kind_name != "opencl":
        return Conv2d.Config()

    looks_like_adreno = ("android" in str(target.host)) or ("adreno" in str(target.attrs))
    if not looks_like_adreno:
        return Conv2d.Config()

    return Conv2d.Config(
        block_size_x=32,
        block_size_y=4,
        vector_size=8,
        unroll=16,
        use_shared=False,
        storage_align=False,
        inner_x=True,
    )
@staticmethod
def schedule_conv2d(sch: tir.Schedule, blk: tir.schedule.BlockRV):
# Apply a fixed tiling/binding schedule to the convolution block `blk` in place on `sch`.
# The block must have exactly eight loops, unpacked below as
# n, oc, oh, ow, ob, ic, kh, kw; any other loop count makes the unpacking fail.
# TODO: The loop pattern may not be reliable; better analysis is needed.
n, oc, oh, ow, ob, ic, kh, kw = sch.get_loops(blk)

# Sampled-tiling variant, currently disabled in favour of the fixed factors below.
# bz, vz, tz = sch.split(oc, sch.sample_perfect_tile(oc, 3, 32))
# by, vy, ty = sch.split(oh, sch.sample_perfect_tile(oh, 3, 32))
# bx, vx, tx = sch.split(ow, sch.sample_perfect_tile(ow, 3, 32))

# Split each of oc / oh / ow three ways with fixed inner factors (8, 1) and (1, 16).
bz, vz, tz = sch.split(oc, [None, 8, 1], preserve_unit_iters=True)
by, vy, ty = sch.split(oh, [None, 1, 16], preserve_unit_iters=True)
bx, vx, tx = sch.split(ow, [None, 1, 16], preserve_unit_iters=True)

# Fold the n loop into the outermost oc part, then order the loops as
# block-level, vthread-level, thread-level, followed by ob.
bz = sch.fuse(n, bz, preserve_unit_iters=True)
sch.reorder(bz, by, bx, vz, vy, vx, tz, ty, tx, ob)
sch.bind(bz, "blockIdx.z")
sch.bind(by, "blockIdx.y")
sch.bind(bx, "blockIdx.x")
sch.bind(vz, "vthread.z")
sch.bind(vy, "vthread.y")
sch.bind(vx, "vthread.x")
sch.bind(tz, "threadIdx.z")
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")

# Add a "local" read cache for index 0 of the block, split ic by 4 and
# place the reduction loops as ico, kh, kw, icb, ob.
rblk = sch.cache_read(blk, 0, "local")
ico, icb = sch.split(ic, [None, 4], preserve_unit_iters=True)
sch.reorder(ico, kh, kw, icb, ob)

# Attach the read cache under kw and vectorize its last loop.
sch.compute_at(rblk, kw, preserve_unit_loops=True)
sch.vectorize(sch.get_loops(rblk)[-1])
# Add a "local" write cache for index 0, attach it under tx and vectorize its last loop.
wblk = sch.cache_write(blk, 0, "local")
sch.reverse_compute_at(wblk, tx, preserve_unit_loops=True)
sch.vectorize(sch.get_loops(wblk)[-1])
# Split the reduction init out at tx and vectorize the init block's last loop.
init_blk = sch.decompose_reduction(blk, tx)
sch.vectorize(sch.get_loops(init_blk)[-1])

def apply( # pylint: disable=too-many-locals,missing-docstring
self,
func: tir.PrimFunc,
func: Union[tir.PrimFunc],
target: Target,
_: bool,
) -> Optional[tir.Schedule]:
if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target):
if not (isinstance(func, (tir.PrimFunc, tir.Schedule))) or not self.is_target_available(
target
):
return None

if isinstance(func, tir.PrimFunc):
sch = tir.Schedule(func)
sch.work_on("main")
elif isinstance(func, tir.Schedule):
sch = func

# config = self.get_configs(target)
root_block = analysis.get_root_block(sch)
root_block = analysis.get_root_block(sch, sch.func_working_on)
blocks = sch.get_child_blocks(root_block)
reduction_block = get_reduction_blocks(sch, blocks)
reduction_blocks = list(
filter(lambda block: analysis.get_block_info(sch, block).is_reduction(), blocks)
)
remaining_blocks = [blk for blk in blocks if blk not in reduction_blocks]

if reduction_block is None:
return None
if not is_convolution(sch, reduction_block):
def is_convolution(blk):
block_info = analysis.get_block_info(sch, blk)
return "conv2d_NCHWc" in block_info.name

if len(reduction_blocks) != 1 or not is_convolution(reduction_blocks[0]):
return None

def schedule_data_pad(blk):
axes = sch.get_loops(blk)
axes, vec = axes[:-1], axes[-1]
axis = sch.fuse(*axes)
bx, ty, tx = sch.split(axis, [None, 16, 16])
sch.bind(bx, "blockIdx.x")
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")
sch.vectorize(vec)

def schedule_conv2d(blk):
# TODO: The loop pattern may not be reliable; better analysis is needed.
n, oc, oh, ow, ob, ic, kh, kw = sch.get_loops(blk)
sch.reorder(n, oc, oh, ow, ic, kh, kw, ob)
main_lp = sch.fuse(n, oc, oh, ow)
bx, ty, tx = sch.split(main_lp, [None, 16, 16])
sch.bind(tx, "threadIdx.x")
sch.bind(ty, "threadIdx.y")
sch.bind(bx, "blockIdx.x")

ico, icv = sch.split(ic, [None, 4])
sch.reorder(ico, kh, kw, icv, ob)
rblk = sch.cache_read(blk, 0, "local")
sch.compute_at(rblk, kw)
sch.vectorize(sch.get_loops(rblk)[-1])
wblk = sch.cache_write(blk, 0, "local")
sch.reverse_compute_at(wblk, tx)
sch.vectorize(sch.get_loops(wblk)[-1])
sch.vectorize(ob)
init_blk = sch.decompose_reduction(blk, ico)
sch.vectorize(sch.get_loops(init_blk)[-1])

def is_data_pad(block: tir.stmt.Block):
return is_spatial_block(sch, block) and tir.analysis.has_if_then_else(sch.get(block))

def schedule_conv2d_blocks():

# Do analysis to find block type
blocks = sch.get_child_blocks(root_block)
passed_reduction = False
for blk in blocks:
if is_reduction_block(sch, blk):
schedule_conv2d(blk)
passed_reduction = True
elif is_data_pad(blk):
schedule_data_pad(blk)
elif is_spatial_block(sch, blk):
try:
if not passed_reduction:
sch.compute_inline(blk)
else:
sch.reverse_compute_inline(blk)
except: # pylint: disable=W0702
pass
else:
raise TypeError("Can't Schedule this Block", sch.get(blk))

schedule_conv2d_blocks()
# sch.set_scope(blocks[0], 0, "global.texture")
conv_blk = reduction_blocks[0]
Conv2d.schedule_conv2d(sch, conv_blk)
remaining_blocks = schedule_inline_blocks(sch, remaining_blocks)
schedule_default(sch, remaining_blocks)
schedule_storage_annotate(sch, remaining_blocks)

return sch
Loading
Loading