Introduction
I maintain an open-source creative drawing board that integrates many interesting brushes and drawing aids, giving users new ways to create. It offers a smooth interactive experience and good rendering on both mobile and desktop.
In this article, I explain in detail how I used Transformers.js to implement background removal and marker-based image segmentation. You can try the result at the link below.
Link: https://songlh.top/paint-board/
GitHub: https://github.com/LHRUN/paint-board (stars are welcome ⭐️)
Transformers.js
Transformers.js is a JavaScript library, designed to mirror Hugging Face's Python Transformers library, that runs models directly in the browser without relying on server-side computation. Because inference happens locally on the user's device, you avoid server round trips and reduce deployment and maintenance costs.
More than 1,000 models on Hugging Face are currently compatible with Transformers.js. They cover domains such as image processing, text generation, translation, and sentiment analysis, so most common tasks can be handled with it. You can find compatible models by searching the Hugging Face Hub and filtering by the Transformers.js library.
The current major version of Transformers.js is v3, which adds many useful features; see the release announcement "Transformers.js v3: WebGPU Support, New Models & Tasks, and More…" for details.
Both features described in this post rely on WebGPU support, which is only available in v3 and has greatly improved processing speed: inference now takes on the order of milliseconds. Note, however, that few browsers support WebGPU yet, so I recommend using the latest version of Google Chrome.
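The code later in this post checks for WebGPU with `navigator?.gpu`, which only tells you the API exists. `navigator.gpu.requestAdapter()` can still resolve to `null` when no suitable GPU is available, so a slightly stricter check is sketched below. `isWebGPUAvailable`, `GPULike` and `NavigatorLike` are hypothetical names; the minimal interfaces stand in for the `@webgpu/types` declarations so the sketch compiles on its own.

```typescript
// Minimal structural types (assumption: stand-ins for @webgpu/types).
interface GPULike {
  requestAdapter(): Promise<unknown>
}
interface NavigatorLike {
  gpu?: GPULike
}

/**
 * Hypothetical helper: true only if the WebGPU API exists AND the browser
 * actually hands out an adapter.
 */
async function isWebGPUAvailable(
  nav: NavigatorLike | undefined = (globalThis as unknown as {
    navigator?: NavigatorLike
  }).navigator
): Promise<boolean> {
  // API not exposed at all (older browsers, non-browser runtimes)
  if (!nav?.gpu) {
    return false
  }
  try {
    // Resolves to null when no suitable adapter is available
    const adapter = await nav.gpu.requestAdapter()
    return adapter != null
  } catch {
    // Treat a rejected request the same as "no adapter"
    return false
  }
}
```

Inside the components' async loading functions, `if (!navigator?.gpu)` could then become `if (!(await isWebGPUAvailable()))`.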
Function 1: Remove background
To remove the background, I use the Xenova/modnet model.
The processing logic can be divided into three steps:
- Initialise the state, and load the model and processor.
- Display the interface. This part is up to your own design; mine is just one option.
- Show the result. This is also up to your own design. A popular approach today is a movable divider line that shows the image before and after background removal side by side.
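For the before/after comparison in the last step, one common approach (an assumption here, not necessarily how the project does it) is to stack the processed image over the original and clip it with CSS `clip-path: inset(...)`, driven by a slider value. `comparisonClipPath` is a hypothetical helper:

```typescript
/**
 * Hypothetical helper for a before/after comparison slider.
 * `percent` (0..100) is how much of the processed image to reveal from the
 * left edge; the original image underneath shows through on the right.
 */
function comparisonClipPath(percent: number): string {
  // Clamp to 0..100 and treat NaN (e.g. an empty input) as 0
  const p = Number.isFinite(percent) ? Math.min(100, Math.max(0, percent)) : 0
  // inset(top right bottom left): hide (100 - p)% from the right edge
  return `inset(0 ${100 - p}% 0 0)`
}
```

On the top `<img>`, this would be used as `style={{ clipPath: comparisonClipPath(sliderValue) }}`, with a thin vertical line absolutely positioned at `left: ${sliderValue}%` as the visible divider.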
The code logic is as follows (React + TypeScript). For details, see the project source code in src/components/boardOperation/uploadImage/index.tsx.
```tsx
import { useState, FC, useRef, useEffect, useMemo } from 'react'
import {
  env,
  AutoModel,
  AutoProcessor,
  RawImage,
  PreTrainedModel,
  Processor
} from '@huggingface/transformers'

const REMOVE_BACKGROUND_STATUS = {
  LOADING: 0,
  NO_SUPPORT_WEBGPU: 1,
  LOAD_ERROR: 2,
  LOAD_SUCCESS: 3,
  PROCESSING: 4,
  PROCESSING_SUCCESS: 5
}

type RemoveBackgroundStatusType =
  (typeof REMOVE_BACKGROUND_STATUS)[keyof typeof REMOVE_BACKGROUND_STATUS]

const UploadImage: FC<{ url: string }> = ({ url }) => {
  const [removeBackgroundStatus, setRemoveBackgroundStatus] =
    useState<RemoveBackgroundStatusType>()
  const [processedImage, setProcessedImage] = useState('')

  const modelRef = useRef<PreTrainedModel>()
  const processorRef = useRef<Processor>()

  // Status text shown under the button
  const removeBackgroundBtnTip = useMemo(() => {
    switch (removeBackgroundStatus) {
      case REMOVE_BACKGROUND_STATUS.LOADING:
        return 'Remove background function loading'
      case REMOVE_BACKGROUND_STATUS.NO_SUPPORT_WEBGPU:
        return 'WebGPU is not supported in this browser, to use the remove background function, please use the latest version of Google Chrome'
      case REMOVE_BACKGROUND_STATUS.LOAD_ERROR:
        return 'Remove background function failed to load'
      case REMOVE_BACKGROUND_STATUS.LOAD_SUCCESS:
        return 'Remove background function loaded successfully'
      case REMOVE_BACKGROUND_STATUS.PROCESSING:
        return 'Remove Background Processing'
      case REMOVE_BACKGROUND_STATUS.PROCESSING_SUCCESS:
        return 'Remove Background Processing Success'
      default:
        return ''
    }
  }, [removeBackgroundStatus])

  // Load the model and processor once on mount
  useEffect(() => {
    ;(async () => {
      try {
        if (removeBackgroundStatus === REMOVE_BACKGROUND_STATUS.LOADING) {
          return
        }
        setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.LOADING)

        // Check WebGPU support
        if (!navigator?.gpu) {
          setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.NO_SUPPORT_WEBGPU)
          return
        }

        const model_id = 'Xenova/modnet'
        // Don't proxy ONNX Runtime's WASM execution to a worker thread
        if (env.backends.onnx.wasm) {
          env.backends.onnx.wasm.proxy = false
        }

        // Load model and processor (reused if already loaded)
        modelRef.current ??= await AutoModel.from_pretrained(model_id, {
          device: 'webgpu'
        })
        processorRef.current ??= await AutoProcessor.from_pretrained(model_id)

        setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.LOAD_SUCCESS)
      } catch (err) {
        console.error('err', err)
        setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.LOAD_ERROR)
      }
    })()
  }, [])

  const processImages = async () => {
    const model = modelRef.current
    const processor = processorRef.current
    if (!model || !processor) {
      return
    }

    setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.PROCESSING)
    try {
      // Load the image
      const img = await RawImage.fromURL(url)

      // Pre-process the image
      const { pixel_values } = await processor(img)

      // Predict the mask, then resize it back to the original image size
      const { output } = await model({ input: pixel_values })
      const maskData = (
        await RawImage.fromTensor(output[0].mul(255).to('uint8')).resize(
          img.width,
          img.height
        )
      ).data

      // Create a new canvas and draw the original image onto it
      const canvas = document.createElement('canvas')
      canvas.width = img.width
      canvas.height = img.height
      const ctx = canvas.getContext('2d') as CanvasRenderingContext2D
      ctx.drawImage(img.toCanvas(), 0, 0)

      // Write the mask (one value per pixel) into each pixel's alpha channel
      const pixelData = ctx.getImageData(0, 0, img.width, img.height)
      for (let i = 0; i < maskData.length; ++i) {
        pixelData.data[4 * i + 3] = maskData[i]
      }
      ctx.putImageData(pixelData, 0, 0)

      // Export as PNG so the transparency is preserved
      setProcessedImage(canvas.toDataURL('image/png'))
      setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.PROCESSING_SUCCESS)
    } catch (err) {
      console.error('err', err)
      // Re-enable the button so the user can retry
      setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.LOAD_SUCCESS)
    }
  }

  return (
    <div className="card shadow-xl">
      <button
        className={`btn btn-primary btn-sm ${
          ![
            REMOVE_BACKGROUND_STATUS.LOAD_SUCCESS,
            REMOVE_BACKGROUND_STATUS.PROCESSING_SUCCESS,
            undefined
          ].includes(removeBackgroundStatus)
            ? 'btn-disabled'
            : ''
        }`}
        onClick={processImages}
      >
        Remove background
      </button>
      <div className="text-xs text-base-content mt-2 flex">
        {removeBackgroundBtnTip}
      </div>
      <div className="relative mt-4 border border-base-content border-dashed rounded-lg overflow-hidden">
        <img
          className={`w-[50vw] max-w-[400px] h-[50vh] max-h-[400px] object-contain`}
          src={url}
          alt="Original image"
        />
        {/* The processed image is stacked on top of the original */}
        {processedImage && (
          <img
            className={`w-full h-full absolute top-0 left-0 z-[2] object-contain`}
            src={processedImage}
            alt="Image with background removed"
          />
        )}
      </div>
    </div>
  )
}

export default UploadImage
```
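The core of `processImages` is the loop that copies each mask value into the alpha channel. Pulled out as a pure function, it can be unit-tested without a canvas. The sketch below does that under the same assumption the component makes (a one-byte-per-pixel mask already resized to the image), and adds a length check that would catch a mask that was not resized. `applyMaskToAlpha` is a hypothetical helper name.

```typescript
/**
 * Hypothetical helper: copies a single-channel mask (one value per pixel,
 * 0 = background, 255 = foreground) into the alpha channel of an RGBA
 * buffer, in place.
 */
function applyMaskToAlpha(
  rgba: Uint8ClampedArray,
  mask: ArrayLike<number>
): Uint8ClampedArray {
  const pixelCount = rgba.length / 4
  // A size mismatch usually means the mask was not resized to the image
  if (mask.length !== pixelCount) {
    throw new Error(`mask has ${mask.length} values, expected ${pixelCount}`)
  }
  for (let i = 0; i < pixelCount; ++i) {
    // Byte 3 of each RGBA quadruple is alpha
    rgba[4 * i + 3] = mask[i]
  }
  return rgba
}
```

Inside `processImages`, the loop would then become `applyMaskToAlpha(pixelData.data, maskData)`.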
Function 2: Image Marker Segmentation
Marker-based image segmentation is implemented with the Xenova/slimsam-77-uniform model. Once the image has been processed, you can click on it, and a segmentation mask is generated from the coordinates of your clicks.
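In the component, a click is stored as a position normalized to 0..1 across the displayed image, and `decode` scales it to the size of the resized model input before building the prompt tensors. Isolated as plain arrays, that step looks roughly like this. `toPromptArrays` is a hypothetical name, and `reshaped` is assumed to be `[height, width]`, matching how the component indexes `reshaped_input_sizes[0]`.

```typescript
interface MarkPoint {
  position: number[] // [x, y], normalized to 0..1 of the displayed image
  label: number // 1 = positive (include), 0 = negative (exclude)
}

/**
 * Hypothetical helper: converts normalized click points into the flat
 * coordinate and label arrays used for SAM's input_points / input_labels.
 */
function toPromptArrays(
  markPoints: MarkPoint[],
  reshaped: [number, number] // [height, width] of the preprocessed image
): { points: number[]; labels: bigint[] } {
  const [height, width] = reshaped
  // x scales with width, y with height; flattened to [x0, y0, x1, y1, ...]
  const points = markPoints.flatMap(({ position: [x, y] }) => [
    x * width,
    y * height
  ])
  // int64 tensors hold BigInt values
  const labels = markPoints.map(({ label }) => BigInt(label))
  return { points, labels }
}
```

The two arrays correspond to the `[1, 1, num_points, 2]` float32 and `[1, 1, num_points]` int64 tensors that the component builds.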
The processing logic can be divided into five steps:
- Initialise the state, and load the model and processor.
- Load the image, then store the processed image data and its embeddings.
- Listen for clicks on the image and record each click as a positive or negative marker. After every click, decode the markers into mask data and draw the segmentation overlay from it.
- Display the interface. This part is up to your own design; mine is just one option.
- Save the image on click: use the mask pixel data to pick out the matching pixels of the original image, then export the result by drawing it to a canvas.
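The click handling in the third step (add a marker, or remove it if you click on an existing marker of the same type) can be written as a pure function that returns a new array, which suits a React state setter. This mirrors the logic of `clickImage` in the component below; `togglePoint` is a hypothetical name, and the default tolerance of 0.01 is in normalized units, as in the original.

```typescript
interface MarkPoint {
  position: number[] // [x, y], normalized to 0..1
  label: number // 1 = positive, 0 = negative
}

/**
 * Hypothetical helper: removes a marker of the same label within `tolerance`
 * of (x, y) if one exists, otherwise appends a new marker.
 */
function togglePoint(
  points: MarkPoint[],
  x: number,
  y: number,
  label: 0 | 1,
  tolerance = 0.01
): MarkPoint[] {
  const index = points.findIndex(
    (p) =>
      Math.abs(p.position[0] - x) < tolerance &&
      Math.abs(p.position[1] - y) < tolerance &&
      p.label === label
  )
  if (index !== -1) {
    // Clicked on an existing marker: drop it
    return points.filter((_, i) => i !== index)
  }
  return [...points, { position: [x, y], label }]
}
```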
The code logic is as follows (React + TypeScript). For details, see the project source code in src/components/boardOperation/uploadImage/imageSegmentation.tsx.
```tsx
import { useState, useRef, useEffect, useMemo, MouseEvent, FC } from 'react'
import {
  SamModel,
  AutoProcessor,
  RawImage,
  PreTrainedModel,
  Processor,
  Tensor,
  SamImageProcessorResult
} from '@huggingface/transformers'

import LoadingIcon from '@/components/icons/loading.svg?react'
import PositiveIcon from '@/components/icons/boardOperation/image-segmentation-positive.svg?react'
import NegativeIcon from '@/components/icons/boardOperation/image-segmentation-negative.svg?react'

interface MarkPoint {
  position: number[]
  label: number
}

const SEGMENTATION_STATUS = {
  LOADING: 0,
  NO_SUPPORT_WEBGPU: 1,
  LOAD_ERROR: 2,
  LOAD_SUCCESS: 3,
  PROCESSING: 4,
  PROCESSING_SUCCESS: 5
}

type SegmentationStatusType =
  (typeof SEGMENTATION_STATUS)[keyof typeof SEGMENTATION_STATUS]

const ImageSegmentation: FC<{ url: string }> = ({ url }) => {
  const [markPoints, setMarkPoints] = useState<MarkPoint[]>([])
  const [segmentationStatus, setSegmentationStatus] =
    useState<SegmentationStatusType>()
  // true = positive marker mode, false = negative marker mode
  const [pointStatus, setPointStatus] = useState<boolean>(true)

  const maskCanvasRef = useRef<HTMLCanvasElement>(null) // Segmentation mask
  const modelRef = useRef<PreTrainedModel>() // model
  const processorRef = useRef<Processor>() // processor
  const imageInputRef = useRef<RawImage>() // original image
  const imageProcessed = useRef<SamImageProcessorResult>() // Processed image
  const imageEmbeddings = useRef<Tensor>() // Embedding data

  const segmentationTip = useMemo(() => {
    switch (segmentationStatus) {
      case SEGMENTATION_STATUS.LOADING:
        return 'Image Segmentation function Loading'
      case SEGMENTATION_STATUS.NO_SUPPORT_WEBGPU:
        return 'WebGPU is not supported in this browser, to use the image segmentation function, please use the latest version of Google Chrome.'
      case SEGMENTATION_STATUS.LOAD_ERROR:
        return 'Image Segmentation function failed to load'
      case SEGMENTATION_STATUS.LOAD_SUCCESS:
        return 'Image Segmentation function loaded successfully'
      case SEGMENTATION_STATUS.PROCESSING:
        return 'Image Processing...'
      case SEGMENTATION_STATUS.PROCESSING_SUCCESS:
        return 'The image has been processed successfully, you can click on the image to mark it, the green mask area is the segmentation area.'
      default:
        return ''
    }
  }, [segmentationStatus])

  // 1. Load model and processor
  useEffect(() => {
    ;(async () => {
      try {
        if (segmentationStatus === SEGMENTATION_STATUS.LOADING) {
          return
        }

        setSegmentationStatus(SEGMENTATION_STATUS.LOADING)
        if (!navigator?.gpu) {
          setSegmentationStatus(SEGMENTATION_STATUS.NO_SUPPORT_WEBGPU)
          return
        }

        const model_id = 'Xenova/slimsam-77-uniform'
        modelRef.current ??= await SamModel.from_pretrained(model_id, {
          dtype: 'fp16', // or "fp32"
          device: 'webgpu'
        })
        processorRef.current ??= await AutoProcessor.from_pretrained(model_id)

        setSegmentationStatus(SEGMENTATION_STATUS.LOAD_SUCCESS)
      } catch (err) {
        console.error('err', err)
        setSegmentationStatus(SEGMENTATION_STATUS.LOAD_ERROR)
      }
    })()
  }, [])

  // 2. Process the image: preprocess it and compute its embeddings once
  useEffect(() => {
    ;(async () => {
      try {
        if (
          !modelRef.current ||
          !processorRef.current ||
          !url ||
          segmentationStatus === SEGMENTATION_STATUS.PROCESSING
        ) {
          return
        }
        setSegmentationStatus(SEGMENTATION_STATUS.PROCESSING)
        clearPoints()

        imageInputRef.current = await RawImage.fromURL(url)
        imageProcessed.current = await processorRef.current(
          imageInputRef.current
        )
        imageEmbeddings.current = await (
          modelRef.current as any
        ).get_image_embeddings(imageProcessed.current)

        setSegmentationStatus(SEGMENTATION_STATUS.PROCESSING_SUCCESS)
      } catch (err) {
        console.error('err', err)
      }
    })()
    // Ref changes don't trigger renders by themselves; this effect re-runs
    // because setSegmentationStatus(LOAD_SUCCESS) re-renders after the refs are set.
  }, [url, modelRef.current, processorRef.current])

  // Draw the best-scoring mask onto the overlay canvas
  function updateMaskOverlay(mask: RawImage, scores: Float32Array) {
    const maskCanvas = maskCanvasRef.current
    if (!maskCanvas) {
      return
    }
    const maskContext = maskCanvas.getContext('2d') as CanvasRenderingContext2D

    // Update canvas dimensions (if different)
    if (maskCanvas.width !== mask.width || maskCanvas.height !== mask.height) {
      maskCanvas.width = mask.width
      maskCanvas.height = mask.height
    }

    // Allocate buffer for pixel data (initially fully transparent)
    const imageData = maskContext.createImageData(
      maskCanvas.width,
      maskCanvas.height
    )

    // Select the mask with the highest IoU score (SAM usually returns 3)
    const numMasks = scores.length
    let bestIndex = 0
    for (let i = 1; i < numMasks; ++i) {
      if (scores[i] > scores[bestIndex]) {
        bestIndex = i
      }
    }

    // Fill the selected mask with colour. mask.data holds numMasks values
    // per pixel, so loop over pixels, not over RGBA bytes.
    const pixelData = imageData.data
    const pixelCount = mask.width * mask.height
    for (let i = 0; i < pixelCount; ++i) {
      if (mask.data[numMasks * i + bestIndex] === 1) {
        const offset = 4 * i
        pixelData[offset] = 101 // r
        pixelData[offset + 1] = 204 // g
        pixelData[offset + 2] = 138 // b
        pixelData[offset + 3] = 255 // a
      }
    }

    // Draw image data to context
    maskContext.putImageData(imageData, 0, 0)
  }

  // 3. Decode a mask from the click data
  const decode = async (markPoints: MarkPoint[]) => {
    if (
      !modelRef.current ||
      !imageEmbeddings.current ||
      !processorRef.current ||
      !imageProcessed.current
    ) {
      return
    }

    // With no markers, simply clear the segmentation overlay
    if (!markPoints.length && maskCanvasRef.current) {
      const maskContext = maskCanvasRef.current.getContext(
        '2d'
      ) as CanvasRenderingContext2D
      maskContext.clearRect(
        0,
        0,
        maskCanvasRef.current.width,
        maskCanvasRef.current.height
      )
      return
    }

    // Prepare inputs for decoding: scale normalized clicks to the
    // resized model input, which is [height, width]
    const reshaped = imageProcessed.current.reshaped_input_sizes[0]
    const points = markPoints.flatMap((x) => [
      x.position[0] * reshaped[1],
      x.position[1] * reshaped[0]
    ])
    // int64 tensors hold BigInt values
    const labels = markPoints.map((x) => BigInt(x.label))

    const num_points = markPoints.length
    const input_points = new Tensor('float32', points, [1, 1, num_points, 2])
    const input_labels = new Tensor('int64', labels, [1, 1, num_points])

    // Generate the mask
    const { pred_masks, iou_scores } = await modelRef.current({
      ...imageEmbeddings.current,
      input_points,
      input_labels
    })

    // Post-process the mask back to the original image size
    const masks = await (processorRef.current as any).post_process_masks(
      pred_masks,
      imageProcessed.current.original_sizes,
      imageProcessed.current.reshaped_input_sizes
    )

    updateMaskOverlay(RawImage.fromTensor(masks[0][0]), iou_scores.data)
  }

  const clamp = (x: number, min = 0, max = 1) => {
    return Math.max(Math.min(x, max), min)
  }

  const clickImage = (e: MouseEvent) => {
    if (segmentationStatus !== SEGMENTATION_STATUS.PROCESSING_SUCCESS) {
      return
    }

    // Normalize the click position to 0..1, accounting for scrolling
    const { clientX, clientY, currentTarget } = e
    const { left, top } = currentTarget.getBoundingClientRect()

    const x = clamp(
      (clientX - left + currentTarget.scrollLeft) / currentTarget.scrollWidth
    )
    const y = clamp(
      (clientY - top + currentTarget.scrollTop) / currentTarget.scrollHeight
    )

    const existingPointIndex = markPoints.findIndex(
      (point) =>
        Math.abs(point.position[0] - x) < 0.01 &&
        Math.abs(point.position[1] - y) < 0.01 &&
        point.label === (pointStatus ? 1 : 0)
    )

    const newPoints = [...markPoints]
    if (existingPointIndex !== -1) {
      // If there is a marker in the currently clicked area, delete it
      newPoints.splice(existingPointIndex, 1)
    } else {
      newPoints.push({
        position: [x, y],
        label: pointStatus ? 1 : 0
      })
    }

    setMarkPoints(newPoints)
    decode(newPoints)
  }

  const clearPoints = () => {
    setMarkPoints([])
    decode([])
  }

  return (
    <div className="card shadow-xl overflow-auto">
      <div className="flex items-center gap-x-3">
        <button className="btn btn-primary btn-sm" onClick={clearPoints}>
          Clear Points
        </button>
        {/* Toggle between positive and negative marker mode */}
        <button
          className="btn btn-primary btn-sm"
          onClick={() => setPointStatus((prev) => !prev)}
        >
          {pointStatus ? 'Positive' : 'Negative'}
        </button>
      </div>
      <div className="text-xs text-base-content mt-2">{segmentationTip}</div>
      <div
        id="test-image-container"
        className="relative mt-4 border border-base-content border-dashed rounded-lg h-[60vh] max-h-[500px] w-fit max-w-[60vw] overflow-x-auto overflow-y-hidden"
        onClick={clickImage}
      >
        {segmentationStatus !== SEGMENTATION_STATUS.PROCESSING_SUCCESS && (
          <div className="absolute z-[3] top-0 left-0 w-full h-full bg-slate-400 bg-opacity-70 flex justify-center items-center">
            <LoadingIcon className="animate-spin" />
          </div>
        )}
        <div className="h-full w-max relative overflow-hidden">
          <img className="h-full max-w-none" src={url} alt="Image to segment" />

          <canvas
            ref={maskCanvasRef}
            className="absolute top-0 left-0 h-full w-full z-[1] opacity-60"
          ></canvas>

          {markPoints.map((point, index) => {
            switch (point.label) {
              case 1:
                return (
                  <PositiveIcon
                    key={index}
                    className="w-[24px] h-[24px] absolute z-[2] -ml-[13px] -mt-[14px] fill-[#FFD401]"
                    style={{
                      top: `${point.position[1] * 100}%`,
                      left: `${point.position[0] * 100}%`
                    }}
                  />
                )
              case 0:
                return (
                  <NegativeIcon
                    key={index}
                    className="w-[24px] h-[24px] absolute z-[2] -ml-[13px] -mt-[14px] fill-[#F44237]"
                    style={{
                      top: `${point.position[1] * 100}%`,
                      left: `${point.position[0] * 100}%`
                    }}
                  />
                )
              default:
                return null
            }
          })}
        </div>
      </div>
    </div>
  )
}

export default ImageSegmentation
```
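The last step in the list, saving the segmented image, is not part of the component above. A minimal sketch of the pixel work, under two assumptions: the mask is the interleaved `RawImage` data that `updateMaskOverlay` reads (`numMasks` values per pixel, `1` meaning inside), and it has the same dimensions as the original image, which should hold because `post_process_masks` is given `original_sizes`. `bestMaskIndex` and `cutOutWithMask` are hypothetical helpers.

```typescript
/** Hypothetical helper: index of the highest IoU score (most confident mask). */
function bestMaskIndex(scores: ArrayLike<number>): number {
  let best = 0
  for (let i = 1; i < scores.length; ++i) {
    if (scores[i] > scores[best]) {
      best = i
    }
  }
  return best
}

/**
 * Hypothetical helper: keeps the original RGBA pixels inside the chosen mask
 * and leaves every other pixel fully transparent.
 */
function cutOutWithMask(
  original: Uint8ClampedArray, // RGBA, same width/height as the mask
  maskData: ArrayLike<number>, // numMasks interleaved values per pixel
  numMasks: number,
  maskIndex: number
): Uint8ClampedArray {
  const pixelCount = original.length / 4
  // A new buffer starts zero-filled, i.e. transparent black
  const out = new Uint8ClampedArray(original.length)
  for (let i = 0; i < pixelCount; ++i) {
    if (maskData[numMasks * i + maskIndex] === 1) {
      // Copy all four RGBA bytes of this pixel
      out.set(original.subarray(4 * i, 4 * i + 4), 4 * i)
    }
  }
  return out
}
```

In the browser, you would keep the last mask and scores (for example in a ref), draw `imageInputRef.current.toCanvas()` onto a canvas, read its pixels with `getImageData`, pass `.data` to `cutOutWithMask`, write the result back with `putImageData(new ImageData(result, width, height), 0, 0)`, and export it with `canvas.toDataURL('image/png')`.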
Conclusion
Thank you for reading. That's all for this article; I hope it helps. If you found it useful, likes and bookmarks are welcome, and if you have any questions, feel free to discuss them in the comments!